diff --git a/noxfile.py b/noxfile.py index f9cbbec..dc35c3e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -15,7 +15,15 @@ @nox.session def lint(session: nox.Session) -> None: - """Run the linter.""" + """ + Run the linter. + + Parameters + ---------- + session : nox.Session + The Nox session object. + + """ session.install("pre-commit") session.run( "pre-commit", @@ -28,7 +36,15 @@ def lint(session: nox.Session) -> None: @nox.session def pylint(session: nox.Session) -> None: - """Run PyLint.""" + """ + Run PyLint. + + Parameters + ---------- + session : nox.Session + The Nox session object. + + """ # This needs to be installed into the package environment, and is slower # than a pre-commit check session.install(".", "pylint") @@ -37,14 +53,30 @@ def pylint(session: nox.Session) -> None: @nox.session def tests(session: nox.Session) -> None: - """Run the unit and regular tests.""" + """ + Run the unit and regular tests. + + Parameters + ---------- + session : nox.Session + The Nox session object. + + """ session.install(".[test]") session.run("pytest", *session.posargs) @nox.session(reuse_venv=True) def docs(session: nox.Session) -> None: - """Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.""" + """ + Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links. + + Parameters + ---------- + session : nox.Session + The Nox session object. + + """ parser = argparse.ArgumentParser() parser.add_argument("--serve", action="store_true", help="Serve after building") parser.add_argument( @@ -91,7 +123,15 @@ def docs(session: nox.Session) -> None: @nox.session def build_api_docs(session: nox.Session) -> None: - """Build (regenerate) API docs.""" + """ + Build (regenerate) API docs. + + Parameters + ---------- + session : nox.Session + The Nox session object. + + """ session.install("sphinx") session.chdir("docs") session.run( @@ -107,7 +147,15 @@ def build_api_docs(session: nox.Session) -> None: @nox.session def build(session: nox.Session) -> None: - """Build an SDist and wheel.""" + """ + Build an SDist and wheel. + + Parameters + ---------- + session : nox.Session + The Nox session object. + + """ build_path = DIR.joinpath("build") if build_path.exists(): shutil.rmtree(build_path) diff --git a/src/multinterp/backend/__init__.py b/src/multinterp/backend/__init__.py index 4f68d7b..1ca78fc 100644 --- a/src/multinterp/backend/__init__.py +++ b/src/multinterp/backend/__init__.py @@ -1,9 +1,10 @@ from __future__ import annotations import contextlib +from typing import Callable, Dict, List, Union -backend_functions = {} -backends = ["cupy", "jax", "numba", "scipy"] +backend_functions: Dict[str, Callable] = {} +backends: List[str] = ["cupy", "jax", "numba", "scipy"] for backend in backends: with contextlib.suppress(ImportError): @@ -13,7 +14,37 @@ ) -def multinterp(grids, values, args, backend="numba"): +def multinterp( + grids: List[Union[List[float], List[int]]], + values: Union[List[float], List[int]], + args: Union[List[float], List[int]], + backend: str = "numba", +) -> Union[List[float], List[int]]: + """ + Perform multivariate interpolation using the specified backend. + + Parameters + ---------- + grids : list of list of float or int + Grid points in the domain. + values : list of float or int + Functional values at the grid points. + args : list of float or int + Points at which to interpolate data. + backend : str, optional + Backend to use for interpolation. Default is "numba". + + Returns + ------- + list of float or int + Interpolated values of the function. + + Raises + ------ + ValueError + If the specified backend is not valid. + + """ if backend not in backends: msg = f"Invalid backend: {backend}" raise ValueError(msg) diff --git a/src/multinterp/backend/_cupy.py b/src/multinterp/backend/_cupy.py index e6b4603..51d1b2d 100644 --- a/src/multinterp/backend/_cupy.py +++ b/src/multinterp/backend/_cupy.py @@ -25,7 +25,21 @@ def cupy_multinterp(grids, values, args, options=None): array-like Interpolated values of the function. + Raises + ------ + ValueError + If the input parameters are not of the expected types. + """ + if not isinstance(grids, list): + raise ValueError("grids should be a list of arrays.") + if not isinstance(values, cp.ndarray): + raise ValueError("values should be a cupy array.") + if not isinstance(args, cp.ndarray): + raise ValueError("args should be a cupy array.") + if options is not None and not isinstance(options, dict): + raise ValueError("options should be a dictionary.") + mc_kwargs = update_mc_kwargs(options) args = cp.asarray(args) @@ -57,7 +71,23 @@ def cupy_gradinterp(grids, values, args, axis=None, options=None): array-like Interpolated values of the gradient. + Raises + ------ + ValueError + If the input parameters are not of the expected types or if the axis parameter is not an integer. + """ + if not isinstance(grids, list): + raise ValueError("grids should be a list of arrays.") + if not isinstance(values, cp.ndarray): + raise ValueError("values should be a cupy array.") + if not isinstance(args, cp.ndarray): + raise ValueError("args should be a cupy array.") + if options is not None and not isinstance(options, dict): + raise ValueError("options should be a dictionary.") + if axis is not None and not isinstance(axis, int): + raise ValueError("Axis should be an integer.") + mc_kwargs = update_mc_kwargs(options) eo = options.get("edge_order", 1) if options else 1 @@ -68,9 +98,6 @@ def cupy_gradinterp(grids, values, args, axis=None, options=None): coords = cupy_get_coordinates(grids, args) if axis is not None: - if not isinstance(axis, int): - msg = "Axis should be an integer." - raise ValueError(msg) gradient = cp.gradient(values, grids[axis], axis=axis, edge_order=eo) return cupy_map_coordinates(gradient, coords, **mc_kwargs) gradient = cp.gradient(values, *grids, edge_order=eo) @@ -94,7 +121,17 @@ def cupy_get_coordinates(grids, args): cp.array Coordinates with respect to the grid. + Raises + ------ + ValueError + If the input parameters are not of the expected types. + """ + if not isinstance(grids, list): + raise ValueError("grids should be a list of arrays.") + if not isinstance(args, cp.ndarray): + raise ValueError("args should be a cupy array.") + coords = cp.empty_like(args) for dim, grid in enumerate(grids): grid_size = cp.arange(grid.size) diff --git a/src/multinterp/backend/_jax.py b/src/multinterp/backend/_jax.py index a5b8254..f8bd88e 100644 --- a/src/multinterp/backend/_jax.py +++ b/src/multinterp/backend/_jax.py @@ -29,7 +29,21 @@ def jax_multinterp(grids, values, args, options=None): array-like Interpolated values. + Raises + ------ + ValueError + If the input parameters are not of the expected types. + """ + if not isinstance(grids, list): + raise ValueError("grids should be a list of arrays.") + if not isinstance(values, jnp.ndarray): + raise ValueError("values should be a jax array.") + if not isinstance(args, jnp.ndarray): + raise ValueError("args should be a jax array.") + if options is not None and not isinstance(options, dict): + raise ValueError("options should be a dictionary.") + mc_kwargs = update_mc_kwargs(options, jax=True) args = jnp.asarray(args) @@ -61,7 +75,23 @@ def jax_gradinterp(grids, values, args, axis=None, options=None): array-like Interpolated values of the gradient. + Raises + ------ + ValueError + If the input parameters are not of the expected types or if the axis parameter is not an integer. + """ + if not isinstance(grids, list): + raise ValueError("grids should be a list of arrays.") + if not isinstance(values, jnp.ndarray): + raise ValueError("values should be a jax array.") + if not isinstance(args, jnp.ndarray): + raise ValueError("args should be a jax array.") + if options is not None and not isinstance(options, dict): + raise ValueError("options should be a dictionary.") + if axis is not None and not isinstance(axis, int): + raise ValueError("Axis should be an integer.") + mc_kwargs = update_mc_kwargs(options, jax=True) eo = options.get("edge_order", 1) if options else 1 @@ -72,9 +102,6 @@ def jax_gradinterp(grids, values, args, axis=None, options=None): coords = jax_get_coordinates(grids, args) if axis is not None: - if not isinstance(axis, int): - msg = "Axis should be an integer." - raise ValueError(msg) gradient = jnp.gradient(values, grids[axis], axis=axis, edge_order=eo) return jax_map_coordinates(gradient, coords, **mc_kwargs) gradient = jnp.gradient(values, *grids, edge_order=eo) @@ -99,7 +126,17 @@ def jax_get_coordinates(grids, args): jnp.array Coordinates of the specified input points with respect to the grid. + Raises + ------ + ValueError + If the input parameters are not of the expected types. + """ + if not isinstance(grids, list): + raise ValueError("grids should be a list of arrays.") + if not isinstance(args, jnp.ndarray): + raise ValueError("args should be a jax array.") + grid_sizes = [jnp.arange(grid.size) for grid in grids] return jnp.array( [ @@ -132,6 +169,17 @@ def jax_map_coordinates(values, coords, order=None, mode=None, cval=None): Interpolated values at specified coordinates. """ + if not isinstance(values, jnp.ndarray): + raise ValueError("values should be a jax array.") + if not isinstance(coords, jnp.ndarray): + raise ValueError("coords should be a jax array.") + if order is not None and not isinstance(order, int): + raise ValueError("order should be an integer.") + if mode is not None and not isinstance(mode, str): + raise ValueError("mode should be a string.") + if cval is not None and not isinstance(cval, (int, float)): + raise ValueError("cval should be a number.") + original_shape = coords[0].shape coords = coords.reshape(len(values.shape), -1) output = map_coordinates(values, coords, order, mode, cval) diff --git a/src/multinterp/backend/_numba.py b/src/multinterp/backend/_numba.py index 7bfca76..c33780e 100644 --- a/src/multinterp/backend/_numba.py +++ b/src/multinterp/backend/_numba.py @@ -26,7 +26,21 @@ def numba_multinterp(grids, values, args, options=None): array-like Interpolated values of the function. + Raises + ------ + ValueError + If the input parameters are not of the expected types. + """ + if not isinstance(grids, (list, typed.List)): + raise ValueError("grids should be a list or typed.List of arrays.") + if not isinstance(values, np.ndarray): + raise ValueError("values should be a numpy array.") + if not isinstance(args, np.ndarray): + raise ValueError("args should be a numpy array.") + if options is not None and not isinstance(options, dict): + raise ValueError("options should be a dictionary.") + mc_kwargs = update_mc_kwargs(options) args = np.asarray(args) @@ -91,7 +105,21 @@ def nb_interp_piecewise(args, grids, values, axis): np.ndarray Interpolated values on arguments. + Raises + ------ + ValueError + If the input parameters are not of the expected types. + """ + if not isinstance(args, np.ndarray): + raise ValueError("args should be a numpy array.") + if not isinstance(grids, np.ndarray): + raise ValueError("grids should be a numpy array.") + if not isinstance(values, np.ndarray): + raise ValueError("values should be a numpy array.") + if not isinstance(axis, int): + raise ValueError("axis should be an integer.") + shape = args[0].shape # original shape of arguments size = args[0].size # number of points in arguments shape_axis = values.shape[axis] # number of points in axis diff --git a/src/multinterp/backend/_scipy.py b/src/multinterp/backend/_scipy.py index 3e46948..331d60c 100644 --- a/src/multinterp/backend/_scipy.py +++ b/src/multinterp/backend/_scipy.py @@ -25,7 +25,21 @@ def scipy_multinterp(grids, values, args, options=None): array-like Interpolated values of the function. + Raises + ------ + ValueError + If the input parameters are not of the expected types. + """ + if not isinstance(grids, list): + raise ValueError("grids should be a list of arrays.") + if not isinstance(values, np.ndarray): + raise ValueError("values should be a numpy array.") + if not isinstance(args, np.ndarray): + raise ValueError("args should be a numpy array.") + if options is not None and not isinstance(options, dict): + raise ValueError("options should be a dictionary.") + mc_kwargs = update_mc_kwargs(options) args = np.asarray(args) @@ -57,7 +71,23 @@ def scipy_gradinterp(grids, values, args, axis=None, options=None): array-like Interpolated values of the gradient. + Raises + ------ + ValueError + If the input parameters are not of the expected types or if the axis parameter is not an integer. + """ + if not isinstance(grids, list): + raise ValueError("grids should be a list of arrays.") + if not isinstance(values, np.ndarray): + raise ValueError("values should be a numpy array.") + if not isinstance(args, np.ndarray): + raise ValueError("args should be a numpy array.") + if options is not None and not isinstance(options, dict): + raise ValueError("options should be a dictionary.") + if axis is not None and not isinstance(axis, int): + raise ValueError("Axis should be an integer.") + mc_kwargs = update_mc_kwargs(options) eo = options.get("edge_order", 1) if options else 1 @@ -68,9 +98,6 @@ def scipy_gradinterp(grids, values, args, axis=None, options=None): coords = scipy_get_coordinates(grids, args) if axis is not None: - if not isinstance(axis, int): - msg = "Axis should be an integer." - raise ValueError(msg) gradient = np.gradient(values, grids[axis], axis=axis, edge_order=eo) return scipy_map_coordinates(gradient, coords, **mc_kwargs) gradient = np.gradient(values, *grids, edge_order=eo) @@ -94,7 +121,17 @@ def scipy_get_coordinates(grids, args): np.array Coordinates with respect to the grid. + Raises + ------ + ValueError + If the input parameters are not of the expected types. + """ + if not isinstance(grids, list): + raise ValueError("grids should be a list of arrays.") + if not isinstance(args, np.ndarray): + raise ValueError("args should be a numpy array.") + coords = np.empty_like(args) for dim, grid in enumerate(grids): grid_size = np.arange(grid.size) diff --git a/src/multinterp/backend/_torch.py b/src/multinterp/backend/_torch.py index 7d2acc2..cf0841e 100644 --- a/src/multinterp/backend/_torch.py +++ b/src/multinterp/backend/_torch.py @@ -9,7 +9,28 @@ from multinterp.utilities import update_mc_kwargs -def as_tensor(arrs, device="cpu"): +def as_tensor(arrs: Sequence[np.ndarray | torch.Tensor | list], device: str = "cpu") -> torch.Tensor: + """ + Convert input arrays to a PyTorch tensor on the specified device. + + Parameters + ---------- + arrs : Sequence[np.ndarray | torch.Tensor | list] + Input arrays to be converted. + device : str, optional + Target device for the tensor, by default "cpu". + + Returns + ------- + torch.Tensor + Converted tensor on the specified device. + + Raises + ------ + TypeError + If the input arrays are not of the expected types. + + """ target_device = torch.device(device) if isinstance(arrs, (torch.Tensor, np.ndarray)): @@ -21,7 +42,27 @@ def as_tensor(arrs, device="cpu"): raise TypeError(msg) -def torch_multinterp(grids, values, args, options=None): +def torch_multinterp(grids: Sequence[np.ndarray], values: np.ndarray, args: np.ndarray, options: dict = None) -> torch.Tensor: + """ + Perform multivariate interpolation using PyTorch. + + Parameters + ---------- + grids : Sequence[np.ndarray] + Grid points in the domain. + values : np.ndarray + Functional values at the grid points. + args : np.ndarray + Points at which to interpolate data. + options : dict, optional + Additional options for interpolation. + + Returns + ------- + torch.Tensor + Interpolated values of the function. + + """ mc_kwargs = update_mc_kwargs(options) target_device = options.get("device", "cpu") if options else "cpu" @@ -33,7 +74,34 @@ def torch_multinterp(grids, values, args, options=None): return torch_map_coordinates(values, coords, **mc_kwargs) -def torch_gradinterp(grids, values, args, axis=None, options=None): +def torch_gradinterp(grids: Sequence[np.ndarray], values: np.ndarray, args: np.ndarray, axis: int = None, options: dict = None) -> torch.Tensor: + """ + Computes the interpolated value of the gradient evaluated at specified points using PyTorch. + + Parameters + ---------- + grids : Sequence[np.ndarray] + Grid points in the domain. + values : np.ndarray + Functional values at the grid points. + args : np.ndarray + Points at which to interpolate data. + axis : int, optional + Axis along which to compute the gradient. + options : dict, optional + Additional options for interpolation. + + Returns + ------- + torch.Tensor + Interpolated values of the gradient. + + Raises + ------ + ValueError + If the axis parameter is not an integer. + + """ mc_kwargs = update_mc_kwargs(options) eo = options.get("edge_order", 1) if options else 1 @@ -55,7 +123,23 @@ def torch_gradinterp(grids, values, args, axis=None, options=None): ) -def torch_get_coordinates(grids, args): +def torch_get_coordinates(grids: Sequence[torch.Tensor], args: torch.Tensor) -> torch.Tensor: + """ + Takes input values and converts them to coordinates with respect to the specified grid. + + Parameters + ---------- + grids : Sequence[torch.Tensor] + Grid points for each dimension. + args : torch.Tensor + Points at which to interpolate data. + + Returns + ------- + torch.Tensor + Coordinates with respect to the grid. + + """ coords = torch.empty_like(args) for dim, grid in enumerate(grids): grid_size = torch.arange(grid.numel(), device=grid.device) @@ -64,7 +148,23 @@ def torch_get_coordinates(grids, args): return coords -def torch_map_coordinates(values, coords, **kwargs): +def torch_map_coordinates(values: torch.Tensor, coords: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Run the map_coordinates function on the specified values. + + Parameters + ---------- + values : torch.Tensor + Functional values from which to interpolate. + coords : torch.Tensor + Coordinates at which to interpolate values. + + Returns + ------- + torch.Tensor + Interpolated values. + + """ original_shape = coords[0].shape coords = coords.reshape(len(values.shape), -1) output = map_coordinates(values, coords, **kwargs) @@ -178,22 +278,23 @@ def map_coordinates( return _map_coordinates(input, coordinates, order, mode, cval) -def torch_interp(x, xp, fp): - """One-dimensional linear interpolation in PyTorch. +def torch_interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor: + """ + One-dimensional linear interpolation in PyTorch. Parameters ---------- - x: array_like - The x-coordinates of the interpolated values. - xp: 1-D sequence of floats - The x-coordinates of the data points. - fp: 1-D sequence of floats - The y-coordinates of the data points, same length as xp. + x : torch.Tensor + The x-coordinates of the interpolated values. + xp : torch.Tensor + The x-coordinates of the data points. + fp : torch.Tensor + The y-coordinates of the data points, same length as xp. Returns ------- - array_like - The interpolated values, same shape as x. + torch.Tensor + The interpolated values, same shape as x. """ # Sort and get sorted indices diff --git a/src/multinterp/rectilinear/_multi.py b/src/multinterp/rectilinear/_multi.py index daa4f98..1924724 100644 --- a/src/multinterp/rectilinear/_multi.py +++ b/src/multinterp/rectilinear/_multi.py @@ -13,30 +13,38 @@ class MultivariateInterp(_RegularGrid): to index coordinates and uses `map_coordinates` from scipy, cupy, or jax. """ - def __init__(self, values, grids, backend="scipy", options=None): + def __init__(self, values: np.ndarray, grids: list[np.ndarray], backend: str = "scipy", options: dict | None = None): """Initialize a multivariate interpolator. Parameters ---------- values : np.ndarray Functional values on a regular grid. - grids : _type_ + grids : list[np.ndarray] 1D grids for each dimension. backend : str, optional One of "scipy", "numba", "cupy", or "jax". Determines hardware to use for interpolation. + options : dict, optional + Additional options for interpolation. """ super().__init__(values, grids, backend=backend) self.mc_kwargs = update_mc_kwargs(options, jax=self.backend == "jax") self._gradient = {} - def compile(self): + def compile(self) -> None: + """Compile the interpolator by precomputing the coordinates.""" self(*self.grids) - def __call__(self, *args): + def __call__(self, *args: np.ndarray) -> np.ndarray: """Interpolates arguments on the regular grid. + Parameters + ---------- + args : np.ndarray + Arguments to be interpolated. + Returns ------- np.ndarray @@ -57,7 +65,7 @@ def __call__(self, *args): coords = self._get_coordinates(args) return self._map_coordinates(coords) - def _get_coordinates(self, args): + def _get_coordinates(self, args: np.ndarray) -> np.ndarray: """For each argument, finds the index coordinates for interpolation. Parameters @@ -73,13 +81,13 @@ def _get_coordinates(self, args): """ return get_coords(self.grids, args, backend=self.backend) - def _map_coordinates(self, coords): + def _map_coordinates(self, coords: np.ndarray) -> np.ndarray: """Uses coordinates to interpolate on the regular grid with `map_coordinates` from scipy or cupy, depending on backend. Parameters ---------- - coordinates : np.ndarray + coords : np.ndarray Index coordinates for interpolation. Returns @@ -90,7 +98,7 @@ def _map_coordinates(self, coords): """ return map_coords(self.values, coords, **self.mc_kwargs, backend=self.backend) - def diff(self, axis=None, edge_order=1): + def diff(self, axis: int | None = None, edge_order: int = 1) -> MultivaluedInterp | MultivariateInterp: """Differentiates the interpolator along the specified axis. If axis is None, then returns a MultivaluedInterp object that approximates the partial derivative of the function across all axes. Otherwise, returns a MultivariateInterp object that approximates the partial derivative of the function along the specified axis. Parameters @@ -98,18 +106,20 @@ def diff(self, axis=None, edge_order=1): axis : int, optional Axis along which to differentiate the function. edge_order : int, optional - TODO: Add description + Order of the finite difference approximation used to compute the gradient. Returns ------- MultivaluedInterp or MultivariateInterp Interpolator object that approximates the partial derivative(s) of the function. + Raises + ------ + ValueError + If the specified axis is not valid. + """ - # if axis is not an integer less than or equal to the number - # of dimensions of the input array, then a ValueError is raised. if axis is None: - # return MultivaluedInterp for ax in range(self.ndim): if ax not in self._gradient: self._gradient[ax] = get_grad( @@ -154,30 +164,38 @@ class MultivaluedInterp(_MultivaluedRegularGrid): to index coordinates and uses `map_coordinates` from scipy, cupy, or jax. """ - def __init__(self, values, grids, backend="scipy", options=None): + def __init__(self, values: np.ndarray, grids: list[np.ndarray], backend: str = "scipy", options: dict | None = None): """Initialize a multivariate interpolator. Parameters ---------- values : np.ndarray Functional values on a regular grid. - grids : _type_ + grids : list[np.ndarray] 1D grids for each dimension. backend : str, optional One of "scipy", "numba", "cupy", or "jax". Determines hardware to use for interpolation. + options : dict, optional + Additional options for interpolation. """ super().__init__(values, grids, backend=backend) self.mc_kwargs = update_mc_kwargs(options) self._gradient = {} - def compile(self): + def compile(self) -> None: + """Compile the interpolator by precomputing the coordinates.""" self(*self.grids) - def __call__(self, *args): + def __call__(self, *args: np.ndarray) -> np.ndarray: """Interpolates arguments on the regular grid. + Parameters + ---------- + args : np.ndarray + Arguments to be interpolated. + Returns ------- np.ndarray @@ -198,7 +216,7 @@ def __call__(self, *args): coords = self._get_coordinates(args) return self._map_coordinates(coords) - def _get_coordinates(self, args): + def _get_coordinates(self, args: np.ndarray) -> np.ndarray: """For each argument, finds the index coordinates for interpolation. Parameters @@ -214,13 +232,13 @@ def _get_coordinates(self, args): """ return get_coords(self.grids, args, self.backend) - def _map_coordinates(self, coords): + def _map_coordinates(self, coords: np.ndarray) -> np.ndarray: """Uses coordinates to interpolate on the regular grid with `map_coordinates` from scipy or cupy, depending on backend. Parameters ---------- - coordinates : np.ndarray + coords : np.ndarray Index coordinates for interpolation. Returns @@ -236,24 +254,29 @@ def _map_coordinates(self, coords): return asarray(fvals, backend=self.backend) - def diff(self, axis=None, argnum=None, edge_order=1): + def diff(self, axis: int | None = None, argnum: int | None = None, edge_order: int = 1) -> MultivaluedInterp | MultivariateInterp: """Differentiates the interpolator along the specified axis. If both axis and argnum are specified, then returns the partial derivative of the specified function argument along the specified axis. If axis is None, then returns a MultivaluedInterp object that approximates the partial derivatives of the specified function argument along each axis. If argnum is None, then returns a MultivaluedInterp object that approximates the partial derivatives of all arguments of the function along the specified axes. Parameters ---------- axis : int, optional Axis along which to differentiate the function. + argnum : int, optional + Argument number to differentiate. edge_order : int, optional - TODO: Add description + Order of the finite difference approximation used to compute the gradient. Returns ------- MultivaluedInterp or MultivariateInterp Interpolator object that approximates the partial derivative(s) of the function. + Raises + ------ + ValueError + If the specified axis is not valid. + """ - # if axis is not an integer less than or equal to the number - # of dimensions of the input array, then a ValueError is raised. if axis is None: msg = "Must specify axis (function) to differentiate." raise ValueError(msg) @@ -263,7 +286,6 @@ def diff(self, axis=None, argnum=None, edge_order=1): raise ValueError(msg) if argnum is None: - # return MultivaluedInterp for arg in range(self.ndim): if (axis, arg) not in self._gradient: self._gradient[(axis, arg)] = get_grad( diff --git a/src/multinterp/rectilinear/_utils.py b/src/multinterp/rectilinear/_utils.py index 20f96b5..cd678d1 100644 --- a/src/multinterp/rectilinear/_utils.py +++ b/src/multinterp/rectilinear/_utils.py @@ -1,36 +1,35 @@ from __future__ import annotations import contextlib - import numpy as np +from typing import Any, Callable from multinterp.backend._numba import numba_get_coordinates, numba_map_coordinates from multinterp.backend._scipy import scipy_get_coordinates, scipy_map_coordinates -GET_COORDS = { +GET_COORDS: dict[str, Callable[..., Any]] = { "scipy": scipy_get_coordinates, "numba": numba_get_coordinates, } -MAP_COORDS = { +MAP_COORDS: dict[str, Callable[..., Any]] = { "scipy": scipy_map_coordinates, "numba": numba_map_coordinates, } -GET_GRAD = { +GET_GRAD: dict[str, Callable[..., Any]] = { "scipy": np.gradient, "numba": np.gradient, } with contextlib.suppress(ImportError): import cupy as cp - from multinterp.backend._cupy import cupy_get_coordinates, cupy_map_coordinates GET_COORDS["cupy"] = cupy_get_coordinates MAP_COORDS["cupy"] = cupy_map_coordinates GET_GRAD["cupy"] = cp.gradient + with contextlib.suppress(ImportError): import jax.numpy as jnp - from multinterp.backend._jax import jax_get_coordinates, jax_map_coordinates GET_COORDS["jax"] = jax_get_coordinates @@ -38,19 +37,70 @@ GET_GRAD["jax"] = jnp.gradient -def get_coords(grids, args, backend="scipy"): - """Wrapper function for the get_coordinates function from the chosen backend.""" +def get_coords(grids: list[np.ndarray], args: np.ndarray, backend: str = "scipy") -> np.ndarray: + """Wrapper function for the get_coordinates function from the chosen backend. + + Parameters + ---------- + grids : list of np.ndarray + Grid points in the domain. + args : np.ndarray + Points at which to interpolate data. + backend : str, optional + Backend to use for interpolation. Default is "scipy". + + Returns + ------- + np.ndarray + Coordinates with respect to the grid. + + """ return GET_COORDS[backend](grids, args) -def map_coords(values, coords, backend="scipy", **kwargs): - """Wrapper function for the map_coordinates function from the chosen backend.""" +def map_coords(values: np.ndarray, coords: np.ndarray, backend: str = "scipy", **kwargs: Any) -> np.ndarray: + """Wrapper function for the map_coordinates function from the chosen backend. + + Parameters + ---------- + values : np.ndarray + Functional values from which to interpolate. + coords : np.ndarray + Coordinates at which to interpolate values. + backend : str, optional + Backend to use for interpolation. Default is "scipy". + **kwargs : dict + Additional keyword arguments for the map_coordinates function. + + Returns + ------- + np.ndarray + Interpolated values of the function. + + """ return MAP_COORDS[backend](values, coords, **kwargs) -def get_grad(values, grids, axis=None, edge_order=None, backend="scipy"): +def get_grad(values: np.ndarray, grids: np.ndarray, axis: int | None = None, edge_order: int | None = None, backend: str = "scipy") -> np.ndarray: """Wrapper function for the gradient function from the chosen backend. - TODO: use appropriate gradient functions from each backend. + Parameters + ---------- + values : np.ndarray + Functional values at the grid points. + grids : np.ndarray + Grid points in the domain. + axis : int, optional + Axis along which to compute the gradient. + edge_order : int, optional + Order of the finite difference approximation used to compute the gradient. + backend : str, optional + Backend to use for interpolation. Default is "scipy". + + Returns + ------- + np.ndarray + Gradient of the function. + """ return GET_GRAD[backend](values, grids, axis=axis, edge_order=edge_order) diff --git a/src/multinterp/utilities.py b/src/multinterp/utilities.py index f64539d..2b689aa 100644 --- a/src/multinterp/utilities.py +++ b/src/multinterp/utilities.py @@ -43,7 +43,23 @@ } -def update_mc_kwargs(options=None, jax=False): +def update_mc_kwargs(options: dict | None = None, jax: bool = False) -> dict: + """ + Update the keyword arguments for the map_coordinates function based on the provided options. + + Parameters + ---------- + options : dict, optional + Additional options for interpolation. + jax : bool, optional + Flag indicating whether to use JAX-specific options. + + Returns + ------- + dict + Updated keyword arguments for the map_coordinates function. + + """ mc_kwargs = SHORT_MC_KWARGS if jax else LONG_MC_KWARGS if options: mc_kwargs = SHORT_MC_KWARGS.copy() if jax else LONG_MC_KWARGS.copy() @@ -52,7 +68,28 @@ def update_mc_kwargs(options=None, jax=False): return mc_kwargs -def asarray(values, backend): +def asarray(values: np.ndarray, backend: str) -> np.ndarray: + """ + Convert the input values to an array using the specified backend. + + Parameters + ---------- + values : np.ndarray + Input values to be converted. + backend : str + Backend to use for conversion. Must be one of "scipy", "numba", "cupy", "jax", or "torch". + + Returns + ------- + np.ndarray + Converted array. + + Raises + ------ + ValueError + If the specified backend is not valid. + + """ if backend not in BACKENDS: msg = f"Invalid backend. Must be one of: {BACKENDS}" raise ValueError(msg) @@ -60,7 +97,23 @@ def asarray(values, backend): return MODULES[backend].asarray(values) -def aslist(grids, backend): +def aslist(grids: list[np.ndarray], backend: str) -> list[np.ndarray]: + """ + Convert the input grids to a list of arrays using the specified backend. + + Parameters + ---------- + grids : list of np.ndarray + Input grids to be converted. + backend : str + Backend to use for conversion. Must be one of "scipy", "numba", "cupy", "jax", or "torch". + + Returns + ------- + list of np.ndarray + Converted list of arrays. + + """ if backend == "numba": grids = typed.List([np.asarray(grid) for grid in grids]) else: @@ -69,21 +122,109 @@ def aslist(grids, backend): return grids -def empty(shape, backend): +def empty(shape: tuple[int, ...], backend: str) -> np.ndarray: + """ + Create an empty array with the specified shape using the specified backend. + + Parameters + ---------- + shape : tuple of int + Shape of the empty array. + backend : str + Backend to use for creating the array. Must be one of "scipy", "numba", "cupy", "jax", or "torch". + + Returns + ------- + np.ndarray + Empty array with the specified shape. + + """ return MODULES[backend].empty(shape) -def empty_like(values, backend): +def empty_like(values: np.ndarray, backend: str) -> np.ndarray: + """ + Create an empty array with the same shape and type as the input values using the specified backend. + + Parameters + ---------- + values : np.ndarray + Input values to determine the shape and type of the empty array. + backend : str + Backend to use for creating the array. Must be one of "scipy", "numba", "cupy", "jax", or "torch". + + Returns + ------- + np.ndarray + Empty array with the same shape and type as the input values. + + """ return MODULES[backend].empty_like(values) -def interp(x, y, z, backend): +def interp(x: np.ndarray, y: np.ndarray, z: np.ndarray, backend: str) -> np.ndarray: + """ + Perform one-dimensional linear interpolation using the specified backend. + + Parameters + ---------- + x : np.ndarray + The x-coordinates of the interpolated values. + y : np.ndarray + The x-coordinates of the data points. + z : np.ndarray + The y-coordinates of the data points, same length as y. + backend : str + Backend to use for interpolation. Must be one of "scipy", "numba", "cupy", "jax", or "torch". + + Returns + ------- + np.ndarray + The interpolated values, same shape as x. + + """ return MODULES[backend].interp(x, y, z) -def take(arr, indices, axis, backend): +def take(arr: np.ndarray, indices: int, axis: int, backend: str) -> np.ndarray: + """ + Take elements from an array along an axis using the specified backend. + + Parameters + ---------- + arr : np.ndarray + Input array from which to take elements. + indices : int + Indices of elements to take. + axis : int + Axis along which to take elements. + backend : str + Backend to use for taking elements. Must be one of "scipy", "numba", "cupy", "jax", or "torch". + + Returns + ------- + np.ndarray + Array of taken elements. + + """ return MODULES[backend].take(arr, indices, axis=axis) -def mgrid(args, backend): +def mgrid(args: tuple[slice, ...], backend: str) -> np.ndarray: + """ + Return coordinate matrices from coordinate vectors using the specified backend. + + Parameters + ---------- + args : tuple of slice + Coordinate vectors. + backend : str + Backend to use for creating coordinate matrices. Must be one of "scipy", "numba", "cupy", "jax", or "torch". + + Returns + ------- + np.ndarray + Coordinate matrices. + + """ return MODULES[backend].mgrid[args] diff --git a/tests/test_backend.py b/tests/test_backend.py index 16a7035..e3388ae 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -137,3 +137,40 @@ def test_torch_3d(setup_data_3d): grids, values, args, true_values = setup_data_3d result_multinterp = torch_multinterp(grids, values, args) assert np.allclose(true_values, result_multinterp.cpu(), atol=1e-05) + + +def test_scipy_multinterp_invalid_grids(): + """Test scipy_multinterp with invalid grids parameter.""" + grids = "invalid_grids" + values = np.array([1, 2, 3]) + args = np.array([1, 2, 3]) + with pytest.raises(ValueError, match="grids should be a list of arrays."): + scipy_multinterp(grids, values, args) + + +def test_scipy_multinterp_invalid_values(): + """Test scipy_multinterp with invalid values parameter.""" + grids = [np.array([1, 2, 3])] + values = "invalid_values" + args = np.array([1, 2, 3]) + with pytest.raises(ValueError, match="values should be a numpy array."): + scipy_multinterp(grids, values, args) + + +def test_scipy_multinterp_invalid_args(): + """Test scipy_multinterp with invalid args parameter.""" + grids = [np.array([1, 2, 3])] + values = np.array([1, 2, 3]) + args = "invalid_args" + with pytest.raises(ValueError, match="args should be a numpy array."): + scipy_multinterp(grids, values, args) + + +def test_scipy_multinterp_invalid_options(): + """Test scipy_multinterp with invalid options parameter.""" + grids = [np.array([1, 2, 3])] + values = np.array([1, 2, 3]) + args = np.array([1, 2, 3]) + options = "invalid_options" + with pytest.raises(ValueError, match="options should be a dictionary."): + scipy_multinterp(grids, values, args, options)