Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@

@nox.session
def lint(session: nox.Session) -> None:
"""Run the linter."""
"""
Run the linter.

Parameters
----------
session : nox.Session
The Nox session object.

"""
session.install("pre-commit")
session.run(
"pre-commit",
Expand All @@ -28,7 +36,15 @@ def lint(session: nox.Session) -> None:

@nox.session
def pylint(session: nox.Session) -> None:
"""Run PyLint."""
"""
Run PyLint.

Parameters
----------
session : nox.Session
The Nox session object.

"""
# This needs to be installed into the package environment, and is slower
# than a pre-commit check
session.install(".", "pylint")
Expand All @@ -37,14 +53,30 @@ def pylint(session: nox.Session) -> None:

@nox.session
def tests(session: nox.Session) -> None:
"""Run the unit and regular tests."""
"""
Run the unit and regular tests.

Parameters
----------
session : nox.Session
The Nox session object.

"""
session.install(".[test]")
session.run("pytest", *session.posargs)


@nox.session(reuse_venv=True)
def docs(session: nox.Session) -> None:
"""Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links."""
"""
Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.

Parameters
----------
session : nox.Session
The Nox session object.

"""
parser = argparse.ArgumentParser()
parser.add_argument("--serve", action="store_true", help="Serve after building")
parser.add_argument(
Expand Down Expand Up @@ -91,7 +123,15 @@ def docs(session: nox.Session) -> None:

@nox.session
def build_api_docs(session: nox.Session) -> None:
"""Build (regenerate) API docs."""
"""
Build (regenerate) API docs.

Parameters
----------
session : nox.Session
The Nox session object.

"""
session.install("sphinx")
session.chdir("docs")
session.run(
Expand All @@ -107,7 +147,15 @@ def build_api_docs(session: nox.Session) -> None:

@nox.session
def build(session: nox.Session) -> None:
"""Build an SDist and wheel."""
"""
Build an SDist and wheel.

Parameters
----------
session : nox.Session
The Nox session object.

"""
build_path = DIR.joinpath("build")
if build_path.exists():
shutil.rmtree(build_path)
Expand Down
37 changes: 34 additions & 3 deletions src/multinterp/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import contextlib
from typing import Callable, Dict, List, Union

backend_functions = {}
backends = ["cupy", "jax", "numba", "scipy"]
backend_functions: Dict[str, Callable] = {}
backends: List[str] = ["cupy", "jax", "numba", "scipy"]

for backend in backends:
with contextlib.suppress(ImportError):
Expand All @@ -13,7 +14,37 @@
)


def multinterp(grids, values, args, backend="numba"):
def multinterp(
grids: List[Union[List[float], List[int]]],
values: Union[List[float], List[int]],
args: Union[List[float], List[int]],
backend: str = "numba",
) -> Union[List[float], List[int]]:
"""
Perform multivariate interpolation using the specified backend.

Parameters
----------
grids : list of list of float or int
Grid points in the domain.
values : list of float or int
Functional values at the grid points.
args : list of float or int
Points at which to interpolate data.
backend : str, optional
Backend to use for interpolation. Default is "numba".

Returns
-------
list of float or int
Interpolated values of the function.

Raises
------
ValueError
If the specified backend is not valid.

"""
if backend not in backends:
msg = f"Invalid backend: {backend}"
raise ValueError(msg)
Expand Down
43 changes: 40 additions & 3 deletions src/multinterp/backend/_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,21 @@ def cupy_multinterp(grids, values, args, options=None):
array-like
Interpolated values of the function.

Raises
------
ValueError
If the input parameters are not of the expected types.

"""
if not isinstance(grids, list):
raise ValueError("grids should be a list of arrays.")
if not isinstance(values, cp.ndarray):
raise ValueError("values should be a cupy array.")
if not isinstance(args, cp.ndarray):
raise ValueError("args should be a cupy array.")
if options is not None and not isinstance(options, dict):
raise ValueError("options should be a dictionary.")

mc_kwargs = update_mc_kwargs(options)

args = cp.asarray(args)
Expand Down Expand Up @@ -57,7 +71,23 @@ def cupy_gradinterp(grids, values, args, axis=None, options=None):
array-like
Interpolated values of the gradient.

Raises
------
ValueError
If the input parameters are not of the expected types or if the axis parameter is not an integer.

"""
if not isinstance(grids, list):
raise ValueError("grids should be a list of arrays.")
if not isinstance(values, cp.ndarray):
raise ValueError("values should be a cupy array.")
if not isinstance(args, cp.ndarray):
raise ValueError("args should be a cupy array.")
if options is not None and not isinstance(options, dict):
raise ValueError("options should be a dictionary.")
if axis is not None and not isinstance(axis, int):
raise ValueError("Axis should be an integer.")

mc_kwargs = update_mc_kwargs(options)
eo = options.get("edge_order", 1) if options else 1

Expand All @@ -68,9 +98,6 @@ def cupy_gradinterp(grids, values, args, axis=None, options=None):
coords = cupy_get_coordinates(grids, args)

if axis is not None:
if not isinstance(axis, int):
msg = "Axis should be an integer."
raise ValueError(msg)
gradient = cp.gradient(values, grids[axis], axis=axis, edge_order=eo)
return cupy_map_coordinates(gradient, coords, **mc_kwargs)
gradient = cp.gradient(values, *grids, edge_order=eo)
Expand All @@ -94,7 +121,17 @@ def cupy_get_coordinates(grids, args):
cp.array
Coordinates with respect to the grid.

Raises
------
ValueError
If the input parameters are not of the expected types.

"""
if not isinstance(grids, list):
raise ValueError("grids should be a list of arrays.")
if not isinstance(args, cp.ndarray):
raise ValueError("args should be a cupy array.")

coords = cp.empty_like(args)
for dim, grid in enumerate(grids):
grid_size = cp.arange(grid.size)
Expand Down
54 changes: 51 additions & 3 deletions src/multinterp/backend/_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,21 @@ def jax_multinterp(grids, values, args, options=None):
array-like
Interpolated values.

Raises
------
ValueError
If the input parameters are not of the expected types.

"""
if not isinstance(grids, list):
raise ValueError("grids should be a list of arrays.")
if not isinstance(values, jnp.ndarray):
raise ValueError("values should be a jax array.")
if not isinstance(args, jnp.ndarray):
raise ValueError("args should be a jax array.")
if options is not None and not isinstance(options, dict):
raise ValueError("options should be a dictionary.")

mc_kwargs = update_mc_kwargs(options, jax=True)

args = jnp.asarray(args)
Expand Down Expand Up @@ -61,7 +75,23 @@ def jax_gradinterp(grids, values, args, axis=None, options=None):
array-like
Interpolated values of the gradient.

Raises
------
ValueError
If the input parameters are not of the expected types or if the axis parameter is not an integer.

"""
if not isinstance(grids, list):
raise ValueError("grids should be a list of arrays.")
if not isinstance(values, jnp.ndarray):
raise ValueError("values should be a jax array.")
if not isinstance(args, jnp.ndarray):
raise ValueError("args should be a jax array.")
if options is not None and not isinstance(options, dict):
raise ValueError("options should be a dictionary.")
if axis is not None and not isinstance(axis, int):
raise ValueError("Axis should be an integer.")

mc_kwargs = update_mc_kwargs(options, jax=True)
eo = options.get("edge_order", 1) if options else 1

Expand All @@ -72,9 +102,6 @@ def jax_gradinterp(grids, values, args, axis=None, options=None):
coords = jax_get_coordinates(grids, args)

if axis is not None:
if not isinstance(axis, int):
msg = "Axis should be an integer."
raise ValueError(msg)
gradient = jnp.gradient(values, grids[axis], axis=axis, edge_order=eo)
return jax_map_coordinates(gradient, coords, **mc_kwargs)
gradient = jnp.gradient(values, *grids, edge_order=eo)
Expand All @@ -99,7 +126,17 @@ def jax_get_coordinates(grids, args):
jnp.array
Coordinates of the specified input points with respect to the grid.

Raises
------
ValueError
If the input parameters are not of the expected types.

"""
if not isinstance(grids, list):
raise ValueError("grids should be a list of arrays.")
if not isinstance(args, jnp.ndarray):
raise ValueError("args should be a jax array.")

grid_sizes = [jnp.arange(grid.size) for grid in grids]
return jnp.array(
[
Expand Down Expand Up @@ -132,6 +169,17 @@ def jax_map_coordinates(values, coords, order=None, mode=None, cval=None):
Interpolated values at specified coordinates.

"""
if not isinstance(values, jnp.ndarray):
raise ValueError("values should be a jax array.")
if not isinstance(coords, jnp.ndarray):
raise ValueError("coords should be a jax array.")
if order is not None and not isinstance(order, int):
raise ValueError("order should be an integer.")
if mode is not None and not isinstance(mode, str):
raise ValueError("mode should be a string.")
if cval is not None and not isinstance(cval, (int, float)):
raise ValueError("cval should be a number.")

original_shape = coords[0].shape
coords = coords.reshape(len(values.shape), -1)
output = map_coordinates(values, coords, order, mode, cval)
Expand Down
28 changes: 28 additions & 0 deletions src/multinterp/backend/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,21 @@ def numba_multinterp(grids, values, args, options=None):
array-like
Interpolated values of the function.

Raises
------
ValueError
If the input parameters are not of the expected types.

"""
if not isinstance(grids, (list, typed.List)):
raise ValueError("grids should be a list or typed.List of arrays.")
if not isinstance(values, np.ndarray):
raise ValueError("values should be a numpy array.")
if not isinstance(args, np.ndarray):
raise ValueError("args should be a numpy array.")
if options is not None and not isinstance(options, dict):
raise ValueError("options should be a dictionary.")

mc_kwargs = update_mc_kwargs(options)

args = np.asarray(args)
Expand Down Expand Up @@ -91,7 +105,21 @@ def nb_interp_piecewise(args, grids, values, axis):
np.ndarray
Interpolated values on arguments.

Raises
------
ValueError
If the input parameters are not of the expected types.

"""
if not isinstance(args, np.ndarray):
raise ValueError("args should be a numpy array.")
if not isinstance(grids, np.ndarray):
raise ValueError("grids should be a numpy array.")
if not isinstance(values, np.ndarray):
raise ValueError("values should be a numpy array.")
if not isinstance(axis, int):
raise ValueError("axis should be an integer.")

shape = args[0].shape # original shape of arguments
size = args[0].size # number of points in arguments
shape_axis = values.shape[axis] # number of points in axis
Expand Down
Loading