Skip to content

Add RBFInterpolator with neighbour supports (#4) #111

Open
mtruel wants to merge 5 commits intof0uriest:mainfrom
mtruel:rbf
Open

Add RBFInterpolator with neighbour supports (#4) #111
mtruel wants to merge 5 commits intof0uriest:mainfrom
mtruel:rbf

Conversation

@mtruel
Copy link

@mtruel mtruel commented Jul 8, 2025

Adapt scipy.interpolate.RBFInterpolator to be jax compatible using jaxkd for kd-trees. Tests are done against Scipy implementation. Resolves #4

@mtruel mtruel marked this pull request as ready for review July 8, 2025 17:06
Copy link
Owner

@f0uriest f0uriest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this, it looks really good!
Detailed comments below, the main things are some potential dtype issues and possible problems with reverse mode AD due to sqrts.

It also looks like the current code isn't jit compiled anywhere. For most other things in interpax I've added either jax.jit or eqx.filter_jit to the main entry points to speed things up. Would be good to do the same, but may need some care to deal with static arguments.

interpax/_rbf.py Outdated
Array where each row contains the powers for each variable in a
monomial.
"""
nmonos = jnp.prod(jnp.arange(degree + 1, degree + ndim + 1)) // jnp.prod(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this need to be np rather than jnp? I thought if this is traced/jitted then nmonos will be traced so can't be used as an array size for out

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interpax/_rbf.py Outdated
ndarray
Value of the RBF kernel
"""
return jax.lax.switch(kernel_index, _KERNEL_FUNCTIONS, r)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't want a switch here, instead just do regular dictionary lookup. The kernel is specified by a string, which is always static, and switch will compile all branches even when they're not needed

interpax/_rbf.py Outdated
yhat = (y - shift) / scale

# Build the RBF matrix - use epsilon-scaled coordinates directly
r = jnp.sqrt(jnp.sum((yeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there could be gradient issues here at r=0 due to the sqrt. It looks like for most kernels this can be avoided by using r**2 directly and never taking the sqrt.

interpax/_rbf.py Outdated
xhat = (x - shift) / scale

# Build the RBF matrix using epsilon-scaled coordinates
r = jnp.sqrt(jnp.sum((xeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see previous comment about sqrt

interpax/_rbf.py Outdated
Comment on lines 293 to 324
_KERNEL_FUNCTIONS = [
_cubic_kernel,
_gaussian_kernel,
_inverse_multiquadric_kernel,
_inverse_quadratic_kernel,
_linear_kernel,
_multiquadric_kernel,
_quintic_kernel,
_thin_plate_spline_kernel,
]

# Ordered list of kernel names (sorted alphabetically for consistency)
_KERNEL_NAMES = [
"cubic",
"gaussian",
"inverse_multiquadric",
"inverse_quadratic",
"linear",
"multiquadric",
"quintic",
"thin_plate_spline",
]

# Create mapping from name to index
_KERNEL_NAME_TO_INDEX = {name: i for i, name in enumerate(_KERNEL_NAMES)}


def _get_kernel_index(kernel: str) -> int:
"""Get the index for a kernel name."""
if kernel not in _AVAILABLE:
raise ValueError(f"`kernel` must be one of {_AVAILABLE}.")
return _KERNEL_NAME_TO_INDEX[kernel]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this probably isn't needed, we can just use a regular dictionary with string keys like scipy.

interpax/_rbf.py Outdated
# jaxkd may squeeze the output when k=1, ensure it's 2D
neighbors = jnp.atleast_2d(neighbors).T

out = jnp.empty((nx, self.d.shape[1]), dtype=float)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check dtype

interpax/_rbf.py Outdated
x_chunk = x[start_idx:end_idx]
neighbors_chunk = neighbors[start_idx:end_idx]

# Use vmap to process points in this chunk in parallel
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see previous comment about possibly using lax.map instead of vmap + for loop

requirements.txt Outdated
jaxtyping >= 0.2.24, < 0.4.0
lineax >= 0.0.5, <= 0.0.8
numpy >= 1.20.0, < 2.4
jaxkd >= 0.1.1
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need 0.1.2, as 0.1.1 tries to install jax[cuda] even on non-GPU systems.

I also like to pin a max version to guard against future breaking changes, so that pip doesn't pull a version that we haven't tested against. so something like
jaxkd >= 0.1.2, < 0.2

assert_raises(ValueError, CubicHermiteSpline, x, y, dydx_with_nan)


class TestRBFInterpolator:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few other tests we should add:

  • test against accidental upcasting. If all inputs are float32, the output should be as well. Same with complex64 etc.
  • Sanity checks for derivatives. IE, make sure jax.grad/jax.jacfwd/jax.jacrev work without error and don't return NaN (due to sqrt at r=0). Ideally would also include a simple check against finite differences to make sure there aren't any numerical issues.


for degree in [-1, 0, 1, 2]:
# Both implementations should issue the same warnings for invalid degrees
with warnings.catch_warnings():
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think you want something like with pytest.warns(...) to ensure that it throws the warning?

Switch to squared-distance kernels, improve dtype safety, and simplify evaluation chunking.
Assert warning behavior, dtype preservation, gradients, and kernel tolerances.
Bump jax/lineax bounds and require jaxkd >= 0.1.2.
@mtruel
Copy link
Author

mtruel commented Feb 6, 2026

Hi, Thanks for the review !

I took some time make changes from your feedback. I implemented most of them. I also updated the docs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add RBF interpolation for unstructured data

2 participants