From 87f39024f407d03b2c1fc259b6476666d6c5348a Mon Sep 17 00:00:00 2001 From: Mathias Truel <28542694+mtruel@users.noreply.github.com> Date: Mon, 7 Jul 2025 09:21:03 +0200 Subject: [PATCH 1/5] Add RBFInterpolator with neighbours supports and tests against Scipy (#4) Using scipy API and using jaxkd for kd-trees. --- README.rst | 1 - interpax/__init__.py | 1 + interpax/_rbf.py | 674 +++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + tests/test_scipy.py | 281 ++++++++++++++++++ 5 files changed, 957 insertions(+), 1 deletion(-) create mode 100644 interpax/_rbf.py diff --git a/README.rst b/README.rst index adbdab5..bc8f4d8 100644 --- a/README.rst +++ b/README.rst @@ -14,7 +14,6 @@ in 1d, 2d, and 3d, as well as Fourier interpolation for periodic functions in Coming soon: - Spline interpolation for rectilinear grids in N-dimensions -- RBF interpolation for unstructured data in N-dimensions - Smoothing splines for noisy data diff --git a/interpax/__init__.py b/interpax/__init__.py index 45a1017..1cc5678 100644 --- a/interpax/__init__.py +++ b/interpax/__init__.py @@ -10,6 +10,7 @@ PchipInterpolator, PPoly, ) +from ._rbf import RBFInterpolator from ._spline import ( AbstractInterpolator, Interpolator1D, diff --git a/interpax/_rbf.py b/interpax/_rbf.py new file mode 100644 index 0000000..5b41f87 --- /dev/null +++ b/interpax/_rbf.py @@ -0,0 +1,674 @@ +"""Module for RBF interpolation using JAX. Based on scipy implementation.""" + +from itertools import combinations_with_replacement +from typing import Optional, Union + +import equinox as eqx +import jax +import jax.debug +import jax.lax +import jax.numpy as jnp +import jaxkd as jk +from jax.scipy.linalg import solve +from jaxtyping import Array, Float, Int, Shaped + +from .utils import asarray_inexact + +__all__ = ["RBFInterpolator"] + +# These RBFs are implemented +_AVAILABLE = { + "linear", + "thin_plate_spline", + "cubic", + "quintic", + "multiquadric", + "inverse_multiquadric", + "inverse_quadratic", + "gaussian", +} + +# The shape parameter does not need to be specified when using these RBFs +_SCALE_INVARIANT = {"linear", "thin_plate_spline", "cubic", "quintic"} + +# For RBFs that are conditionally positive definite of order m, the interpolant +# should include polynomial terms with degree >= m - 1 +_NAME_TO_MIN_DEGREE = { + "multiquadric": 0, + "linear": 0, + "thin_plate_spline": 1, + "cubic": 1, + "quintic": 2, +} + + +def _monomial_powers(ndim: int, degree: int) -> Int[Array, " nmonos ndim"]: + """Return the powers for each monomial in a polynomial. + + Parameters + ---------- + ndim : int + Number of variables in the polynomial. + degree : int + Degree of the polynomial. + + Returns + ------- + (nmonos, ndim) int ndarray + Array where each row contains the powers for each variable in a + monomial. + """ + nmonos = jnp.prod(jnp.arange(degree + 1, degree + ndim + 1)) // jnp.prod( + jnp.arange(1, ndim + 1) + ) + out = jnp.zeros((nmonos, ndim), dtype=jnp.int32) + count = 0 + for deg in range(degree + 1): + for mono in combinations_with_replacement(range(ndim), deg): + # `mono` is a tuple of variables in the current monomial with + # multiplicity indicating power (e.g., (0, 1, 1) represents x*y**2) + for var in mono: + out = out.at[count, var].add(1) + count += 1 + + return out + + +def _rbf_kernel( + r: Float[Array, "..."], kernel_index: int, epsilon: float +) -> Float[Array, "..."]: + """Evaluate the RBF kernel function. + + Parameters + ---------- + r : ndarray + Distance between points + kernel_index : int + Index of the RBF kernel in _KERNEL_FUNCTIONS + epsilon : float + Shape parameter + + Returns + ------- + ndarray + Value of the RBF kernel + """ + r = epsilon * r + return _rbf_kernel_direct(r, kernel_index) + + +def _rbf_kernel_direct( + r: Float[Array, "..."], kernel_index: int +) -> Float[Array, "..."]: + """Evaluate the RBF kernel function with pre-scaled distances. + + Parameters + ---------- + r : ndarray + Distance between points (already scaled by epsilon) + kernel_index : int + Index of the RBF kernel in _KERNEL_FUNCTIONS + + Returns + ------- + ndarray + Value of the RBF kernel + """ + return jax.lax.switch(kernel_index, _KERNEL_FUNCTIONS, r) + + +def _build_system( + y: Float[Array, " P N"], + d: Shaped[Array, " P *d_shape"], + smoothing: Float[Array, " P"], + kernel_index: int, + epsilon: float, + powers: Int[Array, " R N"], +) -> tuple[ + Float[Array, " P+R P+R"], + Shaped[Array, " P+R *d_shape"], + Float[Array, " N"], + Float[Array, " N"], +]: + """Build the RBF interpolation system of equations. + + Parameters + ---------- + y : (P, N) float ndarray + Data point coordinates. + d : (P, S) float ndarray + Data values at `y`. + smoothing : (P,) float ndarray + Smoothing parameter for each data point. + kernel_index : int + Index of the RBF kernel in _KERNEL_FUNCTIONS. + epsilon : float + Shape parameter. + powers : (R, N) int ndarray + The exponents for each monomial in the polynomial. + + Returns + ------- + lhs : (P + R, P + R) float ndarray + Left-hand side of the system. + rhs : (P + R, S) float ndarray + Right-hand side of the system. + shift : (N,) float ndarray + Domain shift used to create the polynomial matrix. + scale : (N,) float ndarray + Domain scaling used to create the polynomial matrix. + """ + P, N = y.shape + R = powers.shape[0] + + # Shift and scale the polynomial domain to be between -1 and 1 (match SciPy) + mins = jnp.min(y, axis=0) + maxs = jnp.max(y, axis=0) + shift = (maxs + mins) / 2 + scale = (maxs - mins) / 2 + # The scale may be zero if there is a single point or all the points have + # the same value for some dimension. Avoid division by zero by replacing + # zeros with ones. + scale = jnp.where(scale == 0.0, 1.0, scale) + + # Apply epsilon scaling to coordinates (match SciPy order) + yeps = y * epsilon + yhat = (y - shift) / scale + + # Build the RBF matrix - use epsilon-scaled coordinates directly + r = jnp.sqrt(jnp.sum((yeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2)) + K = _rbf_kernel_direct(r, kernel_index) + + # Add smoothing to diagonal + K = K + jnp.diag(smoothing) + + # Build the polynomial matrix using transformed coordinates + if R > 0: + poly_matrix = jnp.prod(yhat[:, None, :] ** powers[None, :, :], axis=2) + lhs = jnp.block([[K, poly_matrix], [poly_matrix.T, jnp.zeros((R, R))]]) + rhs = jnp.block([[d], [jnp.zeros((R, d.shape[1]))]]) + else: + lhs = K + rhs = d + + return lhs, rhs, shift, scale + + +def _build_evaluation_coefficients( + x: Float[Array, " Q N"], + y: Float[Array, " P N"], + kernel_index: int, + epsilon: float, + powers: Int[Array, " R N"], + shift: Float[Array, " N"], + scale: Float[Array, " N"], +) -> Float[Array, " Q P+R"]: + """Build the coefficients for evaluating the RBF interpolant. + + Parameters + ---------- + x : (Q, N) float ndarray + Evaluation point coordinates. + y : (P, N) float ndarray + Data point coordinates. + kernel_index : int + Index of the RBF kernel in _KERNEL_FUNCTIONS. + epsilon : float + Shape parameter. + powers : (R, N) int ndarray + The exponents for each monomial in the polynomial. + shift : (N,) float ndarray + Domain shift used to create the polynomial matrix. + scale : (N,) float ndarray + Domain scaling used to create the polynomial matrix. + + Returns + ------- + (Q, P + R) float ndarray + Coefficients for evaluating the RBF interpolant. + """ + Q, N = x.shape + P, _ = y.shape + R = powers.shape[0] + + # Apply epsilon scaling to coordinates (match SciPy order) + yeps = y * epsilon + xeps = x * epsilon + xhat = (x - shift) / scale + + # Build the RBF matrix using epsilon-scaled coordinates + r = jnp.sqrt(jnp.sum((xeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2)) + K = _rbf_kernel_direct(r, kernel_index) + + # Build the polynomial matrix using transformed coordinates + if R > 0: + poly_matrix = jnp.prod(xhat[:, None, :] ** powers[None, :, :], axis=2) + return jnp.block([K, poly_matrix]) + else: + return K + + +# Define individual kernel functions for JAX compatibility +def _linear_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: + """Linear RBF kernel: -r.""" + return -r + + +def _thin_plate_spline_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: + """Thin plate spline RBF kernel: r^2 * log(r).""" + return jnp.where(r > 0, r**2 * jnp.log(r), 0.0) + + +def _cubic_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: + """Cubic RBF kernel: r^3.""" + return r**3 + + +def _quintic_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: + """Quintic RBF kernel: -r^5.""" + return -(r**5) + + +def _multiquadric_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: + """Multiquadric RBF kernel: -sqrt(1 + r^2).""" + return -jnp.sqrt(1 + r**2) + + +def _inverse_multiquadric_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: + """Inverse multiquadric RBF kernel: 1/sqrt(1 + r^2).""" + return 1 / jnp.sqrt(1 + r**2) + + +def _inverse_quadratic_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: + """Inverse quadratic RBF kernel: 1/(1 + r^2).""" + return 1 / (1 + r**2) + + +def _gaussian_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: + """Gaussian RBF kernel: exp(-r^2).""" + return jnp.exp(-(r**2)) + + +# Kernel function list for jax.lax.switch (order must match _KERNEL_NAMES) +_KERNEL_FUNCTIONS = [ + _cubic_kernel, + _gaussian_kernel, + _inverse_multiquadric_kernel, + _inverse_quadratic_kernel, + _linear_kernel, + _multiquadric_kernel, + _quintic_kernel, + _thin_plate_spline_kernel, +] + +# Ordered list of kernel names (sorted alphabetically for consistency) +_KERNEL_NAMES = [ + "cubic", + "gaussian", + "inverse_multiquadric", + "inverse_quadratic", + "linear", + "multiquadric", + "quintic", + "thin_plate_spline", +] + +# Create mapping from name to index +_KERNEL_NAME_TO_INDEX = {name: i for i, name in enumerate(_KERNEL_NAMES)} + + +def _get_kernel_index(kernel: str) -> int: + """Get the index for a kernel name.""" + if kernel not in _AVAILABLE: + raise ValueError(f"`kernel` must be one of {_AVAILABLE}.") + return _KERNEL_NAME_TO_INDEX[kernel] + + +class RBFInterpolator(eqx.Module): + """Radial basis function (RBF) interpolation in N dimensions. + + Parameters + ---------- + y : (npoints, ndims) array_like + 2-D array of data point coordinates. + d : (npoints, ...) array_like + N-D array of data values at `y`. The length of `d` along the first + axis must be equal to the length of `y`. Unlike some interpolators, the + interpolation axis cannot be changed. + neighbors : int, optional + If specified, the value of the interpolant at each evaluation point + will be computed using only this many nearest data points. All the data + points are used by default. + smoothing : float or (npoints, ) array_like, optional + Smoothing parameter. The interpolant perfectly fits the data when this + is set to 0. For large values, the interpolant approaches a least + squares fit of a polynomial with the specified degree. Default is 0. + kernel : str, optional + Type of RBF. This should be one of + + - 'linear' : ``-r`` + - 'thin_plate_spline' : ``r**2 * log(r)`` + - 'cubic' : ``r**3`` + - 'quintic' : ``-r**5`` + - 'multiquadric' : ``-sqrt(1 + r^2)`` + - 'inverse_multiquadric' : ``1/sqrt(1 + r^2)`` + - 'inverse_quadratic' : ``1/(1 + r^2)`` + - 'gaussian' : ``exp(-r^2)`` + + Default is 'thin_plate_spline'. + epsilon : float, optional + Shape parameter that scales the input to the RBF. If `kernel` is + 'linear', 'thin_plate_spline', 'cubic', or 'quintic', this defaults to + 1 and can be ignored because it has the same effect as scaling the + smoothing parameter. Otherwise, this must be specified. + degree : int, optional + Degree of the added polynomial. For some RBFs the interpolant may not + be well-posed if the polynomial degree is too small. Those RBFs and + their corresponding minimum degrees are + + - 'multiquadric' : 0 + - 'linear' : 0 + - 'thin_plate_spline' : 1 + - 'cubic' : 1 + - 'quintic' : 2 + + The default value is the minimum degree for `kernel` or 0 if there is + no minimum degree. Set this to -1 for no added polynomial. + """ + + y: Float[Array, " P N"] + d: Shaped[Array, " P *d_shape"] + d_shape: tuple + d_dtype: type = eqx.field(static=True) + neighbors: Optional[int] + smoothing: Float[Array, " P"] + kernel: str = eqx.field(static=True) + kernel_index: int = eqx.field(static=True) + epsilon: float + powers: Int[Array, " R N"] + _shift: Optional[Float[Array, " N"]] + _scale: Optional[Float[Array, " N"]] + _coeffs: Optional[Shaped[Array, " P+R *d_shape"]] + _tree: Optional[Array] + + def __init__( + self, + y: Float[Array, " P N"], + d: Shaped[Array, " P *d_shape"], + neighbors: Optional[int] = None, + smoothing: Union[float, Float[Array, " P"]] = 0.0, + kernel: str = "thin_plate_spline", + epsilon: Optional[float] = None, + degree: Optional[int] = None, + ): + y = asarray_inexact(y) + if y.ndim != 2: + raise ValueError("`y` must be a 2-dimensional array.") + + ny, ndim = y.shape + + d_dtype = complex if jnp.iscomplexobj(d) else float + d = asarray_inexact(d) + if d.dtype != d_dtype: + d = d.astype(d_dtype) + if d.shape[0] != ny: + raise ValueError(f"Expected the first axis of `d` to have length {ny}.") + + d_shape = d.shape[1:] + d = d.reshape((ny, -1)) + # If `d` is complex, convert it to a float array with twice as many + # columns. Otherwise, the LHS matrix would need to be converted to + # complex and take up 2x more memory than necessary. + d = d.view(float) + + if jnp.isscalar(smoothing): + smoothing = jnp.full(ny, smoothing, dtype=float) + else: + smoothing = asarray_inexact(smoothing) + if smoothing.shape != (ny,): + raise ValueError( + f"Expected `smoothing` to be a scalar or have shape ({ny},)." + ) + + kernel = kernel.lower() + if kernel not in _AVAILABLE: + raise ValueError(f"`kernel` must be one of {_AVAILABLE}.") + + if epsilon is None: + if kernel in _SCALE_INVARIANT: + epsilon = 1.0 + else: + raise ValueError( + "`epsilon` must be specified if `kernel` is not one of " + f"{_SCALE_INVARIANT}." + ) + else: + epsilon = float(epsilon) + + min_degree = _NAME_TO_MIN_DEGREE.get(kernel, -1) + if degree is None: + degree = max(min_degree, 0) + else: + degree = int(degree) + if degree < -1: + raise ValueError("`degree` must be at least -1.") + elif -1 < degree < min_degree: + # Use JAX debug print instead of Python warnings for JAX compatibility + warning_msg = ( + f"WARNING: `degree` should not be below {min_degree} except -1 " + f"when `kernel` is '{kernel}'. " + f"The interpolant may not be uniquely " + f"solvable, and the smoothing parameter may have an " + f"unintuitive effect." + ) + jax.debug.print("RBF Interpolator Warning: {}", warning_msg) + + if neighbors is None: + nobs = ny + else: + # Make sure the number of nearest neighbors used for interpolation + # does not exceed the number of observations. + neighbors = int(min(neighbors, ny)) + nobs = neighbors + + powers = _monomial_powers(ndim, degree) + # The polynomial matrix must have full column rank in order for the + # interpolant to be well-posed, which is not possible if there are + # fewer observations than monomials. + if powers.shape[0] > nobs: + raise ValueError( + f"At least {powers.shape[0]} data points are required when " + f"`degree` is {degree} and the number of dimensions is {ndim}." + ) + + # Get kernel index for JAX-compatible dispatch + kernel_index = _get_kernel_index(kernel) + + if neighbors is None: + lhs, rhs, shift, scale = _build_system( + y, d, smoothing, kernel_index, epsilon, powers + ) + coeffs = solve(lhs, rhs) + self._shift = shift + self._scale = scale + self._coeffs = coeffs + self._tree = None + else: + self._shift = None + self._scale = None + self._coeffs = None + # Build the tree for nearest neighbor queries + self._tree = jk.build_tree(y) + + self.y = y + self.d = d + self.d_shape = d_shape + self.d_dtype = d_dtype + self.neighbors = neighbors + self.smoothing = smoothing + self.kernel = kernel + self.kernel_index = kernel_index + self.epsilon = epsilon + self.powers = powers + + def _chunk_evaluator( + self, + x: Float[Array, " Q N"], + y: Float[Array, " P N"], + shift: Float[Array, " N"], + scale: Float[Array, " N"], + coeffs: Shaped[Array, " P+R *d_shape"], + memory_budget: int = 1000000, + ) -> Shaped[Array, " Q *d_shape"]: + """Evaluate the interpolation while controlling memory consumption. + + Parameters + ---------- + x : (Q, N) float ndarray + Array of points on which to evaluate + y : (P, N) float ndarray + Array of points on which we know function values + shift : (N,) float ndarray + Domain shift used to create the polynomial matrix. + scale : (N,) float ndarray + Domain scaling used to create the polynomial matrix. + coeffs : (P + R, S) float ndarray + Coefficients in front of basis functions + memory_budget : int + Total amount of memory (in units of sizeof(float)) we wish + to devote for storing the array of coefficients for + interpolated points. If we need more memory than that, we + chunk the input. + + Returns + ------- + (Q, S) float ndarray + Interpolated array + """ + nx, ndim = x.shape + if self.neighbors is None: + nnei = len(y) + else: + nnei = self.neighbors + + # in each chunk we consume the same space we already occupy + chunksize = memory_budget // (self.powers.shape[0] + nnei) + 1 + + if chunksize <= nx: + out = jnp.empty((nx, self.d.shape[1]), dtype=float) + for i in range(0, nx, chunksize): + vec = _build_evaluation_coefficients( + x[i : i + chunksize, :], + y, + self.kernel_index, + self.epsilon, + self.powers, + shift, + scale, + ) + out = out.at[i : i + chunksize, :].set(jnp.dot(vec, coeffs)) + else: + vec = _build_evaluation_coefficients( + x, y, self.kernel_index, self.epsilon, self.powers, shift, scale + ) + out = jnp.dot(vec, coeffs) + + return out + + def __call__(self, x: Float[Array, " Q N"]) -> Shaped[Array, " Q *d_shape"]: + """Evaluate the interpolant at `x`. + + Parameters + ---------- + x : (Q, N) array_like + Evaluation point coordinates. + + Returns + ------- + (Q, ...) ndarray + Values of the interpolant at `x`. + """ + x = asarray_inexact(x) + if x.ndim != 2: + raise ValueError("`x` must be a 2-dimensional array.") + + nx, ndim = x.shape + if ndim != self.y.shape[1]: + raise ValueError( + f"Expected the second axis of `x` to have length {self.y.shape[1]}." + ) + + # Our memory budget for storing RBF coefficients is + # based on how many floats in memory we already occupy + # If this number is below 1e6 we just use 1e6 + # This memory budget is used to decide how we chunk + # the inputs + memory_budget = max(x.size + self.y.size + self.d.size, 1000000) + + if self.neighbors is None: + out = self._chunk_evaluator( + x, + self.y, + self._shift, + self._scale, + self._coeffs, + memory_budget=memory_budget, + ) + else: + # Get the indices of the k nearest observation points to each + # evaluation point. + neighbors, _ = jk.query_neighbors(self._tree, x, k=self.neighbors) + if self.neighbors == 1: + # jaxkd may squeeze the output when k=1, ensure it's 2D + neighbors = jnp.atleast_2d(neighbors).T + + out = jnp.empty((nx, self.d.shape[1]), dtype=float) + + # Process each evaluation point individually + # This is simpler but less optimized than the scipy version + def process_single_point(xi, neighbors_i): + # Extract the neighborhood data + ynbr = self.y[neighbors_i] + dnbr = self.d[neighbors_i] + snbr = self.smoothing[neighbors_i] + + # Build and solve the local system + lhs, rhs, shift, scale = _build_system( + ynbr, dnbr, snbr, self.kernel_index, self.epsilon, self.powers + ) + coeffs = solve(lhs, rhs) + + # Evaluate at the single query point (no chunking needed) + xnbr = xi[None, :] # Add batch dimension + vec = _build_evaluation_coefficients( + xnbr, + ynbr, + self.kernel_index, + self.epsilon, + self.powers, + shift, + scale, + ) + result = jnp.dot(vec, coeffs) + + return result[0] # Extract the single result + + # Process points in chunks to stay within memory budget + chunk_size = max(1, int(memory_budget / (self.neighbors * self.d.shape[1]))) + num_chunks = (nx + chunk_size - 1) // chunk_size # Ceiling division + + # Process each chunk using vmap + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min(start_idx + chunk_size, nx) + x_chunk = x[start_idx:end_idx] + neighbors_chunk = neighbors[start_idx:end_idx] + + # Use vmap to process points in this chunk in parallel + out = out.at[start_idx:end_idx].set( + jax.vmap(process_single_point)(x_chunk, neighbors_chunk) + ) + + out = out.view(self.d_dtype) + out = out.reshape((nx,) + self.d_shape) + return out diff --git a/requirements.txt b/requirements.txt index 600f7f2..f80100b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ jax >= 0.4.30, < 0.10 jaxtyping >= 0.2.24, < 0.4.0 lineax >= 0.0.5, <= 0.1.0 numpy >= 1.20.0, < 2.4 +jaxkd >= 0.1.1 diff --git a/tests/test_scipy.py b/tests/test_scipy.py index a54280f..062ce20 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -58,6 +58,7 @@ CubicSpline, PchipInterpolator, PPoly, + RBFInterpolator, ) jax_config.update("jax_enable_x64", True) @@ -912,3 +913,283 @@ def test_CubicHermiteSpline_error_handling(): dydx_with_nan = [1, 0, np.nan] assert_raises(ValueError, CubicHermiteSpline, x, y, dydx_with_nan) + + +class TestRBFInterpolator: + """Test RBF interpolation for SciPy API compatibility.""" + + def _make_test_data_1d(self): + """Create 1D test data.""" + x = np.linspace(0, 1, 10) + y = np.sin(2 * np.pi * x) + return x[:, None], y + + def _make_test_data_2d(self): + """Create 2D test data.""" + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + X, Y = np.meshgrid(x, y) + Z = np.sin(2 * np.pi * X) * np.cos(2 * np.pi * Y) + points = np.column_stack((X.ravel(), Y.ravel())) + values = Z.ravel() + return points, values + + def test_eval_1d(self): + """Test basic evaluation in 1D against SciPy.""" + x, y = self._make_test_data_1d() + + # Test with thin plate spline (doesn't require epsilon) + rbf_jax = RBFInterpolator(x, y, kernel="thin_plate_spline") + rbf_scipy = scipy.interpolate.RBFInterpolator(x, y, kernel="thin_plate_spline") + + # Test points + x_test = np.linspace(0, 1, 20)[:, None] + + # Evaluate + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + # Compare + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_eval_2d(self): + """Test basic evaluation in 2D against SciPy.""" + points, values = self._make_test_data_2d() + + # Test with cubic kernel + rbf_jax = RBFInterpolator(points, values, kernel="cubic") + rbf_scipy = scipy.interpolate.RBFInterpolator(points, values, kernel="cubic") + + # Test points + x_test = np.linspace(0, 1, 10) + y_test = np.linspace(0, 1, 10) + X_test, Y_test = np.meshgrid(x_test, y_test) + points_test = np.column_stack((X_test.ravel(), Y_test.ravel())) + + # Evaluate + values_jax = rbf_jax(points_test) + values_scipy = rbf_scipy(points_test) + + # Compare + assert_allclose(values_jax, values_scipy, rtol=1e-10, atol=1e-10) + + def test_vector_valued(self): + """Test vector-valued functions.""" + x, y = self._make_test_data_1d() + # Create vector-valued function + y_vector = np.column_stack((y, 2.0 * y)) + + rbf_jax = RBFInterpolator(x, y_vector, kernel="thin_plate_spline") + rbf_scipy = scipy.interpolate.RBFInterpolator( + x, y_vector, kernel="thin_plate_spline" + ) + + x_test = np.linspace(0, 1, 15)[:, None] + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_multidimensional_values(self): + """Test 3D array of values.""" + x, y_ = self._make_test_data_1d() + # Create 3D array of values + y = np.empty((10, 2, 2)) + y[:, 0, 0] = y_ + y[:, 1, 0] = 2.0 * y_ + y[:, 0, 1] = 3.0 * y_ + y[:, 1, 1] = 4.0 * y_ + + rbf_jax = RBFInterpolator(x, y, kernel="thin_plate_spline") + rbf_scipy = scipy.interpolate.RBFInterpolator(x, y, kernel="thin_plate_spline") + + x_test = np.linspace(0, 1, 8)[:, None] + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_complex_values(self): + """Test interpolation with complex values.""" + x, _ = self._make_test_data_1d() + y = np.exp(2j * np.pi * x.ravel()) + + rbf_jax = RBFInterpolator(x, y, kernel="thin_plate_spline") + rbf_scipy = scipy.interpolate.RBFInterpolator(x, y, kernel="thin_plate_spline") + + x_test = np.linspace(0, 1, 20)[:, None] + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_smoothing(self): + """Test smoothing parameter.""" + x, y = self._make_test_data_1d() + smoothing = 0.1 + + rbf_jax = RBFInterpolator(x, y, kernel="thin_plate_spline", smoothing=smoothing) + rbf_scipy = scipy.interpolate.RBFInterpolator( + x, y, kernel="thin_plate_spline", smoothing=smoothing + ) + + x_test = np.linspace(0, 1, 20)[:, None] + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_polynomial_degree(self): + """Test different polynomial degrees.""" + x, y = self._make_test_data_1d() + + for degree in [-1, 0, 1, 2]: + # Both implementations should issue the same warnings for invalid degrees + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + rbf_jax = RBFInterpolator( + x, y, kernel="thin_plate_spline", degree=degree + ) + rbf_scipy = scipy.interpolate.RBFInterpolator( + x, y, kernel="thin_plate_spline", degree=degree + ) + + x_test = np.linspace(0, 1, 20)[:, None] + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_kernels_basic(self): + """Test basic kernels that don't require epsilon.""" + x, y = self._make_test_data_1d() + kernels = ["linear", "thin_plate_spline", "cubic", "quintic"] + + for kernel in kernels: + rbf_jax = RBFInterpolator(x, y, kernel=kernel) + rbf_scipy = scipy.interpolate.RBFInterpolator(x, y, kernel=kernel) + + x_test = np.linspace(0, 1, 15)[:, None] + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_epsilon_kernels(self): + """Test kernels that require epsilon parameter.""" + x, y = self._make_test_data_1d() + kernels = [ + "multiquadric", + "inverse_multiquadric", + "inverse_quadratic", + "gaussian", + ] + epsilon = 1.0 + + for kernel in kernels: + rbf_jax = RBFInterpolator(x, y, kernel=kernel, epsilon=epsilon) + rbf_scipy = scipy.interpolate.RBFInterpolator( + x, y, kernel=kernel, epsilon=epsilon + ) + + x_test = np.linspace(0, 1, 15)[:, None] + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-8, atol=1e-8) + + def test_dtypes(self): + """Test different data types.""" + # Integer coordinates and values + x = np.array([[0], [1], [2], [3]], dtype=int) + y = np.array([1, 4, 2, 5], dtype=int) + + rbf_jax = RBFInterpolator(x, y, kernel="linear") + rbf_scipy = scipy.interpolate.RBFInterpolator( + x, + y, + kernel="linear", + ) + + x_test = np.array([[0.5], [1.5], [2.5]], dtype=float) + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_incorrect_inputs(self): + """Test error handling for incorrect inputs.""" + x, y = self._make_test_data_1d() + + # Test invalid kernel + with assert_raises(ValueError): + RBFInterpolator(x, y, kernel="invalid_kernel") + + # Test invalid degree + with assert_raises(ValueError): + RBFInterpolator(x, y, degree=-2) + + # Test epsilon required for certain kernels + with assert_raises(ValueError): + RBFInterpolator(x, y, kernel="multiquadric") # No epsilon provided + + # Test invalid smoothing shape + with assert_raises(ValueError): + RBFInterpolator(x, y, smoothing=np.ones(len(x) + 1)) + + def test_single_point(self): + """Test with single data point (degenerate case).""" + x = np.array([[0.5]]) + y = np.array([2.0]) + + rbf_jax = RBFInterpolator(x, y, kernel="linear") + rbf_scipy = scipy.interpolate.RBFInterpolator(x, y, kernel="linear") + + x_test = np.array([[0.5], [0.3], [0.7]]) + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_two_points(self): + """Test with minimal two-point setup.""" + x = np.array([[0], [1]]) + y = np.array([0, 2]) + + rbf_jax = RBFInterpolator(x, y, kernel="linear") + rbf_scipy = scipy.interpolate.RBFInterpolator(x, y, kernel="linear") + + x_test = np.linspace(0, 1, 11)[:, None] + + y_jax = rbf_jax(x_test) + y_scipy = rbf_scipy(x_test) + + assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + + def test_neighbors_basic(self): + """Test neighbors functionality against SciPy.""" + points, values = self._make_test_data_2d() + + rbf_jax = RBFInterpolator( + points, values, kernel="thin_plate_spline", neighbors=10 + ) + rbf_scipy = scipy.interpolate.RBFInterpolator( + points, values, kernel="thin_plate_spline", neighbors=10 + ) + + points_test = np.array([[0.25, 0.75], [0.6, 0.4]]) + + values_jax = rbf_jax(points_test) + values_scipy = rbf_scipy(points_test) + + # Should be very close to SciPy + assert_allclose(values_jax, values_scipy, rtol=1e-8, atol=1e-8) From 1601382253f13834c0c13bbc5d1ab9e940091845 Mon Sep 17 00:00:00 2001 From: Mathias Truel <28542694+mtruel@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:33:27 +0100 Subject: [PATCH 2/5] Refine RBF kernel handling Switch to squared-distance kernels, improve dtype safety, and simplify evaluation chunking. --- interpax/_rbf.py | 286 +++++++++++++++++++++++++---------------------- 1 file changed, 153 insertions(+), 133 deletions(-) diff --git a/interpax/_rbf.py b/interpax/_rbf.py index 5b41f87..2eae960 100644 --- a/interpax/_rbf.py +++ b/interpax/_rbf.py @@ -1,16 +1,18 @@ """Module for RBF interpolation using JAX. Based on scipy implementation.""" +import math +import warnings from itertools import combinations_with_replacement -from typing import Optional, Union +from typing import Any, Callable, Optional, Union, cast import equinox as eqx import jax -import jax.debug import jax.lax import jax.numpy as jnp import jaxkd as jk from jax.scipy.linalg import solve from jaxtyping import Array, Float, Int, Shaped +from typing_extensions import Literal from .utils import asarray_inexact @@ -58,9 +60,8 @@ def _monomial_powers(ndim: int, degree: int) -> Int[Array, " nmonos ndim"]: Array where each row contains the powers for each variable in a monomial. """ - nmonos = jnp.prod(jnp.arange(degree + 1, degree + ndim + 1)) // jnp.prod( - jnp.arange(1, ndim + 1) - ) + with jax.ensure_compile_time_eval(): + nmonos = math.comb(degree + ndim, ndim) out = jnp.zeros((nmonos, ndim), dtype=jnp.int32) count = 0 for deg in range(degree + 1): @@ -74,54 +75,32 @@ def _monomial_powers(ndim: int, degree: int) -> Int[Array, " nmonos ndim"]: return out -def _rbf_kernel( - r: Float[Array, "..."], kernel_index: int, epsilon: float -) -> Float[Array, "..."]: - """Evaluate the RBF kernel function. - - Parameters - ---------- - r : ndarray - Distance between points - kernel_index : int - Index of the RBF kernel in _KERNEL_FUNCTIONS - epsilon : float - Shape parameter - - Returns - ------- - ndarray - Value of the RBF kernel - """ - r = epsilon * r - return _rbf_kernel_direct(r, kernel_index) - - def _rbf_kernel_direct( - r: Float[Array, "..."], kernel_index: int + r2: Float[Array, "..."], + kernel_func: Callable[[Float[Array, "..."]], Float[Array, "..."]], ) -> Float[Array, "..."]: - """Evaluate the RBF kernel function with pre-scaled distances. + """Evaluate the RBF kernel function with pre-scaled squared distances. Parameters ---------- - r : ndarray - Distance between points (already scaled by epsilon) - kernel_index : int - Index of the RBF kernel in _KERNEL_FUNCTIONS + r2 : ndarray + Squared distance between points (already scaled by epsilon) + kernel_func : callable + RBF kernel function for squared distances Returns ------- ndarray Value of the RBF kernel """ - return jax.lax.switch(kernel_index, _KERNEL_FUNCTIONS, r) + return kernel_func(r2) def _build_system( y: Float[Array, " P N"], d: Shaped[Array, " P *d_shape"], smoothing: Float[Array, " P"], - kernel_index: int, + kernel_func: Callable[[Float[Array, "..."]], Float[Array, "..."]], epsilon: float, powers: Int[Array, " R N"], ) -> tuple[ @@ -140,8 +119,8 @@ def _build_system( Data values at `y`. smoothing : (P,) float ndarray Smoothing parameter for each data point. - kernel_index : int - Index of the RBF kernel in _KERNEL_FUNCTIONS. + kernel_func : callable + RBF kernel function for squared distances. epsilon : float Shape parameter. powers : (R, N) int ndarray @@ -176,8 +155,8 @@ def _build_system( yhat = (y - shift) / scale # Build the RBF matrix - use epsilon-scaled coordinates directly - r = jnp.sqrt(jnp.sum((yeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2)) - K = _rbf_kernel_direct(r, kernel_index) + r2 = jnp.sum((yeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2) + K = _rbf_kernel_direct(r2, kernel_func) # Add smoothing to diagonal K = K + jnp.diag(smoothing) @@ -185,8 +164,10 @@ def _build_system( # Build the polynomial matrix using transformed coordinates if R > 0: poly_matrix = jnp.prod(yhat[:, None, :] ** powers[None, :, :], axis=2) - lhs = jnp.block([[K, poly_matrix], [poly_matrix.T, jnp.zeros((R, R))]]) - rhs = jnp.block([[d], [jnp.zeros((R, d.shape[1]))]]) + lhs = jnp.block( + [[K, poly_matrix], [poly_matrix.T, jnp.zeros((R, R), dtype=K.dtype)]] + ) + rhs = jnp.block([[d], [jnp.zeros((R, d.shape[1]), dtype=d.dtype)]]) else: lhs = K rhs = d @@ -197,7 +178,7 @@ def _build_system( def _build_evaluation_coefficients( x: Float[Array, " Q N"], y: Float[Array, " P N"], - kernel_index: int, + kernel_func: Callable[[Float[Array, "..."]], Float[Array, "..."]], epsilon: float, powers: Int[Array, " R N"], shift: Float[Array, " N"], @@ -211,8 +192,8 @@ def _build_evaluation_coefficients( Evaluation point coordinates. y : (P, N) float ndarray Data point coordinates. - kernel_index : int - Index of the RBF kernel in _KERNEL_FUNCTIONS. + kernel_func : callable + RBF kernel function for squared distances. epsilon : float Shape parameter. powers : (R, N) int ndarray @@ -237,8 +218,8 @@ def _build_evaluation_coefficients( xhat = (x - shift) / scale # Build the RBF matrix using epsilon-scaled coordinates - r = jnp.sqrt(jnp.sum((xeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2)) - K = _rbf_kernel_direct(r, kernel_index) + r2 = jnp.sum((xeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2) + K = _rbf_kernel_direct(r2, kernel_func) # Build the polynomial matrix using transformed coordinates if R > 0: @@ -249,79 +230,65 @@ def _build_evaluation_coefficients( # Define individual kernel functions for JAX compatibility -def _linear_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: +def _linear_kernel(r2: Float[Array, "..."]) -> Float[Array, "..."]: """Linear RBF kernel: -r.""" - return -r + return jnp.where(r2 > 0, -jnp.sqrt(r2), 0.0) -def _thin_plate_spline_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: +def _thin_plate_spline_kernel(r2: Float[Array, "..."]) -> Float[Array, "..."]: """Thin plate spline RBF kernel: r^2 * log(r).""" - return jnp.where(r > 0, r**2 * jnp.log(r), 0.0) + safe_r2 = jnp.where(r2 > 0, r2, jnp.ones_like(r2)) + return jnp.where(r2 > 0, r2 * 0.5 * jnp.log(safe_r2), 0.0) -def _cubic_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: +def _cubic_kernel(r2: Float[Array, "..."]) -> Float[Array, "..."]: """Cubic RBF kernel: r^3.""" - return r**3 + return jnp.where(r2 > 0, r2 * jnp.sqrt(r2), 0.0) -def _quintic_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: +def _quintic_kernel(r2: Float[Array, "..."]) -> Float[Array, "..."]: """Quintic RBF kernel: -r^5.""" - return -(r**5) + return jnp.where(r2 > 0, -(r2 * r2 * jnp.sqrt(r2)), 0.0) -def _multiquadric_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: +def _multiquadric_kernel(r2: Float[Array, "..."]) -> Float[Array, "..."]: """Multiquadric RBF kernel: -sqrt(1 + r^2).""" - return -jnp.sqrt(1 + r**2) + return -jnp.sqrt(1 + r2) -def _inverse_multiquadric_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: +def _inverse_multiquadric_kernel(r2: Float[Array, "..."]) -> Float[Array, "..."]: """Inverse multiquadric RBF kernel: 1/sqrt(1 + r^2).""" - return 1 / jnp.sqrt(1 + r**2) + return 1 / jnp.sqrt(1 + r2) -def _inverse_quadratic_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: +def _inverse_quadratic_kernel(r2: Float[Array, "..."]) -> Float[Array, "..."]: """Inverse quadratic RBF kernel: 1/(1 + r^2).""" - return 1 / (1 + r**2) + return 1 / (1 + r2) -def _gaussian_kernel(r: Float[Array, "..."]) -> Float[Array, "..."]: +def _gaussian_kernel(r2: Float[Array, "..."]) -> Float[Array, "..."]: """Gaussian RBF kernel: exp(-r^2).""" - return jnp.exp(-(r**2)) - - -# Kernel function list for jax.lax.switch (order must match _KERNEL_NAMES) -_KERNEL_FUNCTIONS = [ - _cubic_kernel, - _gaussian_kernel, - _inverse_multiquadric_kernel, - _inverse_quadratic_kernel, - _linear_kernel, - _multiquadric_kernel, - _quintic_kernel, - _thin_plate_spline_kernel, -] - -# Ordered list of kernel names (sorted alphabetically for consistency) -_KERNEL_NAMES = [ - "cubic", - "gaussian", - "inverse_multiquadric", - "inverse_quadratic", - "linear", - "multiquadric", - "quintic", - "thin_plate_spline", -] - -# Create mapping from name to index -_KERNEL_NAME_TO_INDEX = {name: i for i, name in enumerate(_KERNEL_NAMES)} + return jnp.exp(-r2) + + +# Kernel function list +_KERNEL_FUNCTIONS = { + "cubic": _cubic_kernel, + "gaussian": _gaussian_kernel, + "inverse_multiquadric": _inverse_multiquadric_kernel, + "inverse_quadratic": _inverse_quadratic_kernel, + "linear": _linear_kernel, + "multiquadric": _multiquadric_kernel, + "quintic": _quintic_kernel, + "thin_plate_spline": _thin_plate_spline_kernel, +} -def _get_kernel_index(kernel: str) -> int: - """Get the index for a kernel name.""" +def _get_kernel(kernel: str) -> Callable[[Float[Array, "..."]], Float[Array, "..."]]: + """Get the kernel function for a kernel name.""" if kernel not in _AVAILABLE: raise ValueError(f"`kernel` must be one of {_AVAILABLE}.") - return _KERNEL_NAME_TO_INDEX[kernel] + return _KERNEL_FUNCTIONS[kernel] class RBFInterpolator(eqx.Module): @@ -376,20 +343,33 @@ class RBFInterpolator(eqx.Module): no minimum degree. Set this to -1 for no added polynomial. """ + __hash__ = object.__hash__ + y: Float[Array, " P N"] d: Shaped[Array, " P *d_shape"] d_shape: tuple - d_dtype: type = eqx.field(static=True) + d_dtype: jnp.dtype = eqx.field(static=True) neighbors: Optional[int] smoothing: Float[Array, " P"] - kernel: str = eqx.field(static=True) - kernel_index: int = eqx.field(static=True) + kernel: Literal[ + "cubic", + "gaussian", + "inverse_multiquadric", + "inverse_quadratic", + "linear", + "multiquadric", + "quintic", + "thin_plate_spline", + ] = eqx.field(static=True) + kernel_func: Callable[[Float[Array, "..."]], Float[Array, "..."]] = eqx.field( + static=True + ) epsilon: float powers: Int[Array, " R N"] _shift: Optional[Float[Array, " N"]] _scale: Optional[Float[Array, " N"]] _coeffs: Optional[Shaped[Array, " P+R *d_shape"]] - _tree: Optional[Array] + _tree: Optional[Any] def __init__( self, @@ -407,10 +387,8 @@ def __init__( ny, ndim = y.shape - d_dtype = complex if jnp.iscomplexobj(d) else float d = asarray_inexact(d) - if d.dtype != d_dtype: - d = d.astype(d_dtype) + d_dtype = d.dtype if d.shape[0] != ny: raise ValueError(f"Expected the first axis of `d` to have length {ny}.") @@ -419,12 +397,14 @@ def __init__( # If `d` is complex, convert it to a float array with twice as many # columns. Otherwise, the LHS matrix would need to be converted to # complex and take up 2x more memory than necessary. - d = d.view(float) + if jnp.iscomplexobj(d): + float_dtype = jnp.float32 if d_dtype == jnp.complex64 else jnp.float64 + d = d.view(float_dtype) if jnp.isscalar(smoothing): - smoothing = jnp.full(ny, smoothing, dtype=float) + smoothing = jnp.full(ny, smoothing, dtype=y.dtype) else: - smoothing = asarray_inexact(smoothing) + smoothing = asarray_inexact(smoothing).astype(y.dtype) if smoothing.shape != (ny,): raise ValueError( f"Expected `smoothing` to be a scalar or have shape ({ny},)." @@ -433,6 +413,19 @@ def __init__( kernel = kernel.lower() if kernel not in _AVAILABLE: raise ValueError(f"`kernel` must be one of {_AVAILABLE}.") + kernel = cast( + Literal[ + "cubic", + "gaussian", + "inverse_multiquadric", + "inverse_quadratic", + "linear", + "multiquadric", + "quintic", + "thin_plate_spline", + ], + kernel, + ) if epsilon is None: if kernel in _SCALE_INVARIANT: @@ -443,7 +436,7 @@ def __init__( f"{_SCALE_INVARIANT}." ) else: - epsilon = float(epsilon) + epsilon = float(asarray_inexact(epsilon)) min_degree = _NAME_TO_MIN_DEGREE.get(kernel, -1) if degree is None: @@ -453,7 +446,7 @@ def __init__( if degree < -1: raise ValueError("`degree` must be at least -1.") elif -1 < degree < min_degree: - # Use JAX debug print instead of Python warnings for JAX compatibility + # Use standard warnings since inputs are static warning_msg = ( f"WARNING: `degree` should not be below {min_degree} except -1 " f"when `kernel` is '{kernel}'. " @@ -461,7 +454,7 @@ def __init__( f"solvable, and the smoothing parameter may have an " f"unintuitive effect." ) - jax.debug.print("RBF Interpolator Warning: {}", warning_msg) + warnings.warn(warning_msg, UserWarning) if neighbors is None: nobs = ny @@ -481,12 +474,11 @@ def __init__( f"`degree` is {degree} and the number of dimensions is {ndim}." ) - # Get kernel index for JAX-compatible dispatch - kernel_index = _get_kernel_index(kernel) + kernel_func = _get_kernel(kernel) if neighbors is None: lhs, rhs, shift, scale = _build_system( - y, d, smoothing, kernel_index, epsilon, powers + y, d, smoothing, kernel_func, epsilon, powers ) coeffs = solve(lhs, rhs) self._shift = shift @@ -507,7 +499,7 @@ def __init__( self.neighbors = neighbors self.smoothing = smoothing self.kernel = kernel - self.kernel_index = kernel_index + self.kernel_func = kernel_func self.epsilon = epsilon self.powers = powers @@ -555,21 +547,32 @@ def _chunk_evaluator( chunksize = memory_budget // (self.powers.shape[0] + nnei) + 1 if chunksize <= nx: - out = jnp.empty((nx, self.d.shape[1]), dtype=float) - for i in range(0, nx, chunksize): + pad = (-nx) % chunksize + if pad: + x_pad = jnp.pad(x, ((0, pad), (0, 0)), mode="edge") + else: + x_pad = x + + num_chunks = x_pad.shape[0] // chunksize + x_chunks = x_pad.reshape((num_chunks, chunksize, ndim)) + + def process_chunk(x_chunk): vec = _build_evaluation_coefficients( - x[i : i + chunksize, :], + x_chunk, y, - self.kernel_index, + self.kernel_func, self.epsilon, self.powers, shift, scale, ) - out = out.at[i : i + chunksize, :].set(jnp.dot(vec, coeffs)) + return jnp.dot(vec, coeffs) + + out_chunks = jax.lax.map(process_chunk, x_chunks) + out = out_chunks.reshape((-1, self.d.shape[1]))[:nx] else: vec = _build_evaluation_coefficients( - x, y, self.kernel_index, self.epsilon, self.powers, shift, scale + x, y, self.kernel_func, self.epsilon, self.powers, shift, scale ) out = jnp.dot(vec, coeffs) @@ -606,6 +609,8 @@ def __call__(self, x: Float[Array, " Q N"]) -> Shaped[Array, " Q *d_shape"]: memory_budget = max(x.size + self.y.size + self.d.size, 1000000) if self.neighbors is None: + if self._shift is None or self._scale is None or self._coeffs is None: + raise ValueError("RBFInterpolator coefficients are not initialized.") out = self._chunk_evaluator( x, self.y, @@ -615,6 +620,8 @@ def __call__(self, x: Float[Array, " Q N"]) -> Shaped[Array, " Q *d_shape"]: memory_budget=memory_budget, ) else: + if self._tree is None: + raise ValueError("RBFInterpolator neighbor tree is not initialized.") # Get the indices of the k nearest observation points to each # evaluation point. neighbors, _ = jk.query_neighbors(self._tree, x, k=self.neighbors) @@ -622,7 +629,7 @@ def __call__(self, x: Float[Array, " Q N"]) -> Shaped[Array, " Q *d_shape"]: # jaxkd may squeeze the output when k=1, ensure it's 2D neighbors = jnp.atleast_2d(neighbors).T - out = jnp.empty((nx, self.d.shape[1]), dtype=float) + out = jnp.empty((nx, self.d.shape[1]), dtype=self.d.dtype) # Process each evaluation point individually # This is simpler but less optimized than the scipy version @@ -634,7 +641,7 @@ def process_single_point(xi, neighbors_i): # Build and solve the local system lhs, rhs, shift, scale = _build_system( - ynbr, dnbr, snbr, self.kernel_index, self.epsilon, self.powers + ynbr, dnbr, snbr, self.kernel_func, self.epsilon, self.powers ) coeffs = solve(lhs, rhs) @@ -643,7 +650,7 @@ def process_single_point(xi, neighbors_i): vec = _build_evaluation_coefficients( xnbr, ynbr, - self.kernel_index, + self.kernel_func, self.epsilon, self.powers, shift, @@ -655,20 +662,33 @@ def process_single_point(xi, neighbors_i): # Process points in chunks to stay within memory budget chunk_size = max(1, int(memory_budget / (self.neighbors * self.d.shape[1]))) - num_chunks = (nx + chunk_size - 1) // chunk_size # Ceiling division - - # Process each chunk using vmap - for i in range(num_chunks): - start_idx = i * chunk_size - end_idx = min(start_idx + chunk_size, nx) - x_chunk = x[start_idx:end_idx] - neighbors_chunk = neighbors[start_idx:end_idx] - - # Use vmap to process points in this chunk in parallel - out = out.at[start_idx:end_idx].set( - jax.vmap(process_single_point)(x_chunk, neighbors_chunk) + if chunk_size < nx: + pad = (-nx) % chunk_size + if pad: + x_pad = jnp.pad(x, ((0, pad), (0, 0)), mode="edge") + neighbors_pad = jnp.pad(neighbors, ((0, pad), (0, 0)), mode="edge") + else: + x_pad = x + neighbors_pad = neighbors + + num_chunks = x_pad.shape[0] // chunk_size + x_chunks = x_pad.reshape((num_chunks, chunk_size, ndim)) + neighbors_chunks = neighbors_pad.reshape( + (num_chunks, chunk_size, self.neighbors) ) - out = out.view(self.d_dtype) + def process_chunk(chunk_data): + x_chunk, neighbors_chunk = chunk_data + return jax.vmap(process_single_point)(x_chunk, neighbors_chunk) + + out_chunks = jax.lax.map(process_chunk, (x_chunks, neighbors_chunks)) + out = out_chunks.reshape((-1, self.d.shape[1]))[:nx] + else: + out = jax.vmap(process_single_point)(x, neighbors) + + if jnp.issubdtype(self.d_dtype, jnp.complexfloating): + out = out.astype(self.d.dtype).view(self.d_dtype) + else: + out = out.astype(self.d_dtype) out = out.reshape((nx,) + self.d_shape) return out From e66af165c5b729c8ae36df6367a934a77fc71cab Mon Sep 17 00:00:00 2001 From: Mathias Truel <28542694+mtruel@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:33:36 +0100 Subject: [PATCH 3/5] Expand RBFInterpolator tests Assert warning behavior, dtype preservation, gradients, and kernel tolerances. --- tests/test_scipy.py | 71 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 6 deletions(-) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index 062ce20..c7ef607 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -784,7 +784,10 @@ def check_correctness(S, bc_start="not-a-knot", bc_end="not-a-knot", tol=1e-14): else: order, value = bc_start assert_allclose( - S(x[0], order), value, rtol=tol, atol=tol # pyright: ignore + S(x[0], order), + value, + rtol=tol, + atol=tol, # pyright: ignore ) if bc_end == "not-a-knot": @@ -800,7 +803,10 @@ def check_correctness(S, bc_start="not-a-knot", bc_end="not-a-knot", tol=1e-14): else: order, value = bc_end assert_allclose( - S(x[-1], order), value, rtol=tol, atol=tol # pyright: ignore + S(x[-1], order), + value, + rtol=tol, + atol=tol, # pyright: ignore ) def check_all_bc(self, x, y, axis): @@ -1048,9 +1054,17 @@ def test_polynomial_degree(self): x, y = self._make_test_data_1d() for degree in [-1, 0, 1, 2]: - # Both implementations should issue the same warnings for invalid degrees - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) + warns = degree < 1 and degree != -1 + if warns: + with pytest.warns(UserWarning): + rbf_jax = RBFInterpolator( + x, y, kernel="thin_plate_spline", degree=degree + ) + with pytest.warns(UserWarning): + rbf_scipy = scipy.interpolate.RBFInterpolator( + x, y, kernel="thin_plate_spline", degree=degree + ) + else: rbf_jax = RBFInterpolator( x, y, kernel="thin_plate_spline", degree=degree ) @@ -1103,7 +1117,10 @@ def test_epsilon_kernels(self): y_jax = rbf_jax(x_test) y_scipy = rbf_scipy(x_test) - assert_allclose(y_jax, y_scipy, rtol=1e-8, atol=1e-8) + if kernel == "gaussian": + assert_allclose(y_jax, y_scipy, rtol=5e-8, atol=5e-8) + else: + assert_allclose(y_jax, y_scipy, rtol=1e-8, atol=1e-8) def test_dtypes(self): """Test different data types.""" @@ -1125,6 +1142,48 @@ def test_dtypes(self): assert_allclose(y_jax, y_scipy, rtol=1e-10, atol=1e-10) + x32 = np.array([[0.0], [1.0], [2.0]], dtype=np.float32) + y32 = np.array([0.0, 1.0, -0.5], dtype=np.float32) + rbf_jax = RBFInterpolator(x32, y32, kernel="linear") + y32_out = rbf_jax(np.array([[0.5]], dtype=np.float32)) + assert_equal(y32_out.dtype, np.dtype(np.float32)) + + y64 = np.array([0.0, 1.0, -0.5], dtype=np.float64) + rbf_jax = RBFInterpolator(x32, y64, kernel="linear") + y64_out = rbf_jax(np.array([[0.5]], dtype=np.float32)) + assert_equal(y64_out.dtype, np.dtype(np.float64)) + + y_complex = np.array([0.0 + 0.0j, 1.0 + 2.0j, -0.5 + 0.25j], dtype=np.complex64) + rbf_jax = RBFInterpolator(x32, y_complex, kernel="linear") + y_complex_out = rbf_jax(np.array([[0.5]], dtype=np.float32)) + assert_equal(y_complex_out.dtype, np.dtype(np.complex64)) + + def test_gradients(self): + """Test that gradients are finite and reasonable.""" + x = jnp.array([[0.0], [1.0], [2.0]], dtype=jnp.float32) + y = jnp.array([0.0, 1.0, -0.5], dtype=jnp.float32) + rbf = RBFInterpolator(x, y, kernel="thin_plate_spline") + + def f(p): + return rbf(p[None, :])[0] + + def g(p): + return rbf(p[None, :]) + + p0 = jnp.array([0.0], dtype=jnp.float32) + grad0 = jax.grad(f)(p0) + assert_(jnp.all(jnp.isfinite(grad0))) + jac_fwd = jax.jacfwd(g)(p0) + jac_rev = jax.jacrev(g)(p0) + assert_(jnp.all(jnp.isfinite(jac_fwd))) + assert_(jnp.all(jnp.isfinite(jac_rev))) + + p1 = jnp.array([0.3], dtype=jnp.float32) + grad1 = jax.grad(f)(p1) + eps = 1e-3 + fd = (f(p1 + eps) - f(p1 - eps)) / (2 * eps) + assert_allclose(np.asarray(grad1), np.asarray(fd), rtol=1e-2, atol=1e-2) + def test_incorrect_inputs(self): """Test error handling for incorrect inputs.""" x, y = self._make_test_data_1d() From f527b391391c90789c4203eab66cb484e9c24abe Mon Sep 17 00:00:00 2001 From: Mathias Truel <28542694+mtruel@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:33:41 +0100 Subject: [PATCH 4/5] Update RBF dependency pins Bump jax/lineax bounds and require jaxkd >= 0.1.2. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f80100b..3d17a71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ jax >= 0.4.30, < 0.10 jaxtyping >= 0.2.24, < 0.4.0 lineax >= 0.0.5, <= 0.1.0 numpy >= 1.20.0, < 2.4 -jaxkd >= 0.1.1 +jaxkd >= 0.1.2, < 0.2 From d47a23e0fa98131149dcac28eda61eccea26be7b Mon Sep 17 00:00:00 2001 From: Mathias Truel <28542694+mtruel@users.noreply.github.com> Date: Fri, 6 Feb 2026 15:49:09 +0100 Subject: [PATCH 5/5] Add RBFInterpolator to documentation --- README.rst | 2 +- docs/api.rst | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index bc8f4d8..433e569 100644 --- a/README.rst +++ b/README.rst @@ -10,7 +10,7 @@ interpax is a library for interpolation and function approximation using JAX. Includes methods for nearest neighbor, linear, and several cubic interpolation schemes in 1d, 2d, and 3d, as well as Fourier interpolation for periodic functions in -1d and 2d. +1d and 2d. Also includes a scipy like RBFInterpolator. Coming soon: - Spline interpolation for rectilinear grids in N-dimensions diff --git a/docs/api.rst b/docs/api.rst index 3e09dd6..389e763 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -31,6 +31,7 @@ except where noted in the documentation. interpax.CubicSpline interpax.PchipInterpolator interpax.PPoly + interpax.RBFInterpolator Functional interface for 1D, 2D, 3D interpolation