Add RBFInterpolator with neighbour supports (#4) #111
Add RBFInterpolator with neighbour supports (#4) #111mtruel wants to merge 5 commits intof0uriest:mainfrom
Conversation
There was a problem hiding this comment.
Thanks for this, it looks really good!
Detailed comments below, the main things are some potential dtype issues and possible problems with reverse mode AD due to sqrts.
It also looks like the current code isn't jit compiled anywhere. For most other things in interpax I've added either jax.jit or eqx.filter_jit to the main entry points to speed things up. Would be good to do the same, but may need some care to deal with static arguments.
interpax/_rbf.py
Outdated
| Array where each row contains the powers for each variable in a | ||
| monomial. | ||
| """ | ||
| nmonos = jnp.prod(jnp.arange(degree + 1, degree + ndim + 1)) // jnp.prod( |
There was a problem hiding this comment.
does this need to be np rather than jnp? I thought if this is traced/jitted then nmonos will be traced so can't be used as an array size for out
There was a problem hiding this comment.
we may also want to wrap this whole function in https://docs.jax.dev/en/latest/_autosummary/jax.ensure_compile_time_eval.html
interpax/_rbf.py
Outdated
| ndarray | ||
| Value of the RBF kernel | ||
| """ | ||
| return jax.lax.switch(kernel_index, _KERNEL_FUNCTIONS, r) |
There was a problem hiding this comment.
I think we don't want a switch here, instead just do regular dictionary lookup. The kernel is specified by a string, which is always static, and switch will compile all branches even when they're not needed
interpax/_rbf.py
Outdated
| yhat = (y - shift) / scale | ||
|
|
||
| # Build the RBF matrix - use epsilon-scaled coordinates directly | ||
| r = jnp.sqrt(jnp.sum((yeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2)) |
There was a problem hiding this comment.
there could be gradient issues here at r=0 due to the sqrt. It looks like for most kernels this can be avoided by using r**2 directly and never taking the sqrt.
interpax/_rbf.py
Outdated
| xhat = (x - shift) / scale | ||
|
|
||
| # Build the RBF matrix using epsilon-scaled coordinates | ||
| r = jnp.sqrt(jnp.sum((xeps[:, None, :] - yeps[None, :, :]) ** 2, axis=2)) |
There was a problem hiding this comment.
see previous comment about sqrt
interpax/_rbf.py
Outdated
| _KERNEL_FUNCTIONS = [ | ||
| _cubic_kernel, | ||
| _gaussian_kernel, | ||
| _inverse_multiquadric_kernel, | ||
| _inverse_quadratic_kernel, | ||
| _linear_kernel, | ||
| _multiquadric_kernel, | ||
| _quintic_kernel, | ||
| _thin_plate_spline_kernel, | ||
| ] | ||
|
|
||
| # Ordered list of kernel names (sorted alphabetically for consistency) | ||
| _KERNEL_NAMES = [ | ||
| "cubic", | ||
| "gaussian", | ||
| "inverse_multiquadric", | ||
| "inverse_quadratic", | ||
| "linear", | ||
| "multiquadric", | ||
| "quintic", | ||
| "thin_plate_spline", | ||
| ] | ||
|
|
||
| # Create mapping from name to index | ||
| _KERNEL_NAME_TO_INDEX = {name: i for i, name in enumerate(_KERNEL_NAMES)} | ||
|
|
||
|
|
||
| def _get_kernel_index(kernel: str) -> int: | ||
| """Get the index for a kernel name.""" | ||
| if kernel not in _AVAILABLE: | ||
| raise ValueError(f"`kernel` must be one of {_AVAILABLE}.") | ||
| return _KERNEL_NAME_TO_INDEX[kernel] |
There was a problem hiding this comment.
I think this probably isn't needed, we can just use a regular dictionary with string keys like scipy.
interpax/_rbf.py
Outdated
| # jaxkd may squeeze the output when k=1, ensure it's 2D | ||
| neighbors = jnp.atleast_2d(neighbors).T | ||
|
|
||
| out = jnp.empty((nx, self.d.shape[1]), dtype=float) |
interpax/_rbf.py
Outdated
| x_chunk = x[start_idx:end_idx] | ||
| neighbors_chunk = neighbors[start_idx:end_idx] | ||
|
|
||
| # Use vmap to process points in this chunk in parallel |
There was a problem hiding this comment.
see previous comment about possibly using lax.map instead of vmap + for loop
requirements.txt
Outdated
| jaxtyping >= 0.2.24, < 0.4.0 | ||
| lineax >= 0.0.5, <= 0.0.8 | ||
| numpy >= 1.20.0, < 2.4 | ||
| jaxkd >= 0.1.1 |
There was a problem hiding this comment.
I think we need 0.1.2, as 0.1.1 tries to install jax[cuda] even on non-GPU systems.
I also like to pin a max version to guard against future breaking changes, so that pip doesn't pull a version that we haven't tested against. so something like
jaxkd >= 0.1.2, < 0.2
| assert_raises(ValueError, CubicHermiteSpline, x, y, dydx_with_nan) | ||
|
|
||
|
|
||
| class TestRBFInterpolator: |
There was a problem hiding this comment.
A few other tests we should add:
- test against accidental upcasting. If all inputs are float32, the output should be as well. Same with complex64 etc.
- Sanity checks for derivatives. IE, make sure
jax.grad/jax.jacfwd/jax.jacrevwork without error and don't return NaN (due to sqrt at r=0). Ideally would also include a simple check against finite differences to make sure there aren't any numerical issues.
tests/test_scipy.py
Outdated
|
|
||
| for degree in [-1, 0, 1, 2]: | ||
| # Both implementations should issue the same warnings for invalid degrees | ||
| with warnings.catch_warnings(): |
There was a problem hiding this comment.
i think you want something like with pytest.warns(...) to ensure that it throws the warning?
…0uriest#4) Using scipy API and using jaxkd for kd-trees.
Switch to squared-distance kernels, improve dtype safety, and simplify evaluation chunking.
Assert warning behavior, dtype preservation, gradients, and kernel tolerances.
Bump jax/lineax bounds and require jaxkd >= 0.1.2.
|
Hi, Thanks for the review ! I took some time make changes from your feedback. I implemented most of them. I also updated the docs. |
Adapt
scipy.interpolate.RBFInterpolatorto be jax compatible usingjaxkdfor kd-trees. Tests are done against Scipy implementation. Resolves #4