From 4a8bdb7bd786361de36116ff29e2f315f425bece Mon Sep 17 00:00:00 2001 From: unalmis Date: Mon, 15 Dec 2025 23:32:09 -0800 Subject: [PATCH 1/9] Add polyroot --- interpax/_spline.py | 200 +++++++++++++++++++++++++++++++++++++- interpax/utils.py | 19 ++++ tests/test_interpolate.py | 61 +++++++++++- 3 files changed, 278 insertions(+), 2 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 54b9c38..12a8362 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1,6 +1,7 @@ """Functions for interpolating splines that are JAX differentiable.""" from collections import OrderedDict +from functools import partial from typing import Any, Union import equinox as eqx @@ -12,7 +13,7 @@ from ._coefs import A_BICUBIC, A_CUBIC, A_TRICUBIC from ._fd_derivs import approx_df -from .utils import asarray_inexact, errorif, isbool, wrap_jit +from .utils import asarray_inexact, errorif, isbool, safediv, wrap_jit CUBIC_METHODS = ( "cubic", @@ -1209,3 +1210,200 @@ def noclip(fq, *_): ) return fq + + +def _subtract_last(c, k): + """Subtract ``k`` from last index of last axis of ``c``. + + Semantically same as ``return c.at[...,-1].subtract(k)``, + but allows dimension to increase. + """ + c_1 = c[..., -1] - k + return jnp.concatenate( + [ + jnp.broadcast_to(c[..., :-1], (*c_1.shape, c.shape[-1] - 1)), + c_1[..., jnp.newaxis], + ], + axis=-1, + ) + + +def _filter_distinct(r, sentinel, eps): + """Set all but one of matching adjacent elements in ``r`` to ``sentinel``.""" + # eps needs to be low enough that close distinct roots do not get removed. + # Otherwise, algorithms relying on continuity will fail. + mask = jnp.isclose(jnp.diff(r, axis=-1, prepend=sentinel), 0, atol=eps) + return jnp.where(mask, sentinel, r) + + +_roots_companion = jnp.vectorize( + partial(jnp.roots, strip_zeros=False), signature="(m)->(n)" +) + + +def _polyroot_vec( + c, + k=0.0, + a_min=None, + a_max=None, + sort=False, + sentinel=jnp.nan, + eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12), + distinct=False, +): + """Roots of polynomial with given coefficients. + + Parameters + ---------- + c : jnp.ndarray + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. + k : jnp.ndarray + Shape (..., *c.shape[:-1]). + Specify to find solutions to ∑ᵢⁿ cᵢ xⁱ = ``k``. + a_min : jnp.ndarray + Shape (..., *c.shape[:-1]). + Minimum ``a_min`` and maximum ``a_max`` value to return roots between. + If specified only real roots are returned, otherwise returns all complex roots. + a_max : jnp.ndarray + Shape (..., *c.shape[:-1]). + Minimum ``a_min`` and maximum ``a_max`` value to return roots between. + If specified only real roots are returned, otherwise returns all complex roots. + sort : bool + Whether to sort the roots. + sentinel : float + Value with which to pad array in place of filtered elements. + Anything less than ``a_min`` or greater than ``a_max`` plus some floating point + error buffer will work just like nan while avoiding ``nan`` gradient. + eps : float + Absolute tolerance with which to consider value as zero. + distinct : bool + Whether to only return the distinct roots. If true, when the multiplicity is + greater than one, the repeated roots are set to ``sentinel``. + + Returns + ------- + r : jnp.ndarray + Shape (..., *c.shape[:-1], c.shape[-1] - 1). + The roots of the polynomial, iterated over the last axis. + + """ + get_only_real_roots = not (a_min is None and a_max is None) + num_coef = c.shape[-1] + func = {2: _root_linear, 3: _root_quadratic, 4: _root_cubic} + + if ( + num_coef in func + and get_only_real_roots + and not (jnp.iscomplexobj(c) or jnp.iscomplexobj(k)) + ): + # Compute from analytic formula to avoid the issue of complex roots with small + # imaginary parts and to avoid nan in gradient. Also consumes less memory. + r = func[num_coef]( + *jnp.moveaxis(c[..., :-1], -1, 0), c[..., -1] - k, sentinel, eps, distinct + ) + # We already filtered distinct roots for quadratics. + distinct = distinct and num_coef > 3 + else: + r = _roots_companion(_subtract_last(c, k)) + + if get_only_real_roots: + a_min = -jnp.inf if a_min is None else a_min[..., jnp.newaxis] + a_max = +jnp.inf if a_max is None else a_max[..., jnp.newaxis] + r = jnp.where( + (jnp.abs(r.imag) <= eps) & (a_min <= r.real) & (r.real <= a_max), + r.real, + sentinel, + ) + + if sort or distinct: + r = jnp.sort(r, axis=-1) + r = _filter_distinct(r, sentinel, eps) if distinct else r + assert r.shape[-1] == num_coef - 1 + return r + + +def _root_cubic(a, b, c, d, sentinel, eps, distinct): + """Return real cubic root assuming real coefficients.""" + # numerical.recipes/book.html, page 228 + + def irreducible(Q, R, b, mask): + # Three irrational real roots. + theta = R / jnp.sqrt(jnp.where(mask, Q**3, 1.0)) + theta = jnp.arccos(jnp.where(jnp.abs(theta) < 1.0, theta, 0.0)) + return jnp.moveaxis( + -2 + * jnp.sqrt(Q) + * jnp.stack( + [ + jnp.cos(theta / 3), + jnp.cos((theta + 2 * jnp.pi) / 3), + jnp.cos((theta - 2 * jnp.pi) / 3), + ] + ) + - b / 3, + source=0, + destination=-1, + ) + + def reducible(Q, R, b): + # One real and two complex roots. + A = -jnp.sign(R) * jnp.cbrt(jnp.abs(R) + jnp.sqrt(jnp.abs(R**2 - Q**3))) + B = safediv(Q, A) + r1 = (A + B) - b / 3 + return _concat_sentinel(r1[..., jnp.newaxis], sentinel, num=2) + + def root(b, c, d): + b = safediv(b, a) + c = safediv(c, a) + d = safediv(d, a) + Q = (b**2 - 3 * c) / 9 + R = (2 * b**3 - 9 * b * c) / 54 + 0.5 * d + mask = R**2 < Q**3 + return jnp.where( + mask[..., jnp.newaxis], + irreducible(jnp.abs(Q), R, b, mask), + reducible(Q, R, b), + ) + + return jnp.where( + # Tests catch failure here if eps < 1e-12 for 64 bit precision. + jnp.expand_dims(jnp.abs(a) <= eps, axis=-1), + _concat_sentinel( + _root_quadratic(b, c, d, sentinel, eps, distinct), + sentinel, + ), + root(b, c, d), + ) + + +def _root_quadratic(a, b, c, sentinel, eps, distinct): + """Return real quadratic root assuming real coefficients.""" + # numerical.recipes/book.html, page 227 + + discriminant = b**2 - 4 * a * c + q = -0.5 * (b + jnp.sign(b) * jnp.sqrt(jnp.abs(discriminant))) + r1 = jnp.where( + discriminant < 0, + sentinel, + safediv(q, a, _root_linear(b, c, sentinel, eps)), + ) + r2 = jnp.where( + # more robust to remove repeated roots with discriminant + (discriminant < 0) | (distinct & (discriminant <= eps)), + sentinel, + safediv(c, q, sentinel), + ) + return jnp.stack([r1, r2], axis=-1) + + +def _root_linear(a, b, sentinel, eps, distinct=False): + """Return real linear root assuming real coefficients.""" + return safediv(-b, a, jnp.where(jnp.abs(b) <= eps, 0, sentinel)) + + +def _concat_sentinel(r, sentinel, num=1): + """Concatenate ``sentinel`` ``num`` times to ``r`` on last axis.""" + sent = jnp.broadcast_to(sentinel, (*r.shape[:-1], num)) + return jnp.append(r, sent, axis=-1) diff --git a/interpax/utils.py b/interpax/utils.py index 1d6c091..28b32f5 100644 --- a/interpax/utils.py +++ b/interpax/utils.py @@ -61,3 +61,22 @@ def wrapper(fun): return foo return wrapper + + +def safediv(a, b, fill=0.0, threshold=0.0): + """Divide a/b with guards for division by zero. + + Parameters + ---------- + a, b : ndarray + Numerator and denominator. + fill : float, ndarray, optional + Value to return where b is zero. + threshold : float >= 0 + How small is b allowed to be. + + """ + mask = jnp.abs(b) <= threshold + num = jnp.where(mask, fill, a) + den = jnp.where(mask, 1.0, b) + return num / den diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index b378de2..b54e836 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -16,6 +16,7 @@ interp2d, interp3d, ) +from interpax._spline import _polyroot_vec jax_config.update("jax_enable_x64", True) @@ -412,7 +413,7 @@ def test_fft_interp2d(dtype): sx=0.2, sy=0.3, dx=np.diff(x[spx][1])[0], - dy=np.diff(y[spy][1])[0] + dy=np.diff(y[spy][1])[0], ).squeeze(), ) for epx in ["o", "e"]: # eval parity x @@ -576,3 +577,61 @@ def test_extrap_float(): np.testing.assert_allclose(interpol(4.5, 5.3), 1.0) np.testing.assert_allclose(interpol(-4.5, 5.3), 0.0) np.testing.assert_allclose(interpol(4.5, -5.3), 0.0) + + +def test_polyroot_vec(eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12)): + """Test vectorized computation of cubic polynomial exact roots.""" + c = np.arange(-24, 24).reshape(4, 6, -1).transpose(-1, 1, 0) + # Ensure broadcasting won't hide error in implementation. + assert np.unique(c.shape).size == c.ndim + + k = np.broadcast_to(np.arange(c.shape[-2]), c.shape[:-1]) + # Now increase dimension so that shapes still broadcast, but stuff like + # c[...,-1]-=k is not allowed because it grows the dimension of c. + # This is needed functionality in polyroot_vec that requires an awkward + # loop to obtain if using jnp.vectorize. + k = np.stack([k, k * 2 + 1]) + r = _polyroot_vec(c, k, sort=True, eps=eps) + + for i in range(k.shape[0]): + d = c.copy() + d[..., -1] -= k[i] + # np.roots cannot be vectorized because it strips leading zeros and + # output shape is therefore dynamic. + for idx in np.ndindex(d.shape[:-1]): + np.testing.assert_allclose( + r[(i, *idx)], + np.sort(np.roots(d[idx])), + err_msg=f"Eigenvalue branch of polyroot_vec failed at {i, *idx}.", + ) + + # Now test analytic formula branch, Ensure it filters distinct roots, + # and ensure zero coefficients don't bust computation due to singularities + # in analytic formulae which are not present in iterative eigenvalue scheme. + c = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + [1, -1, -8, 12], + [1, -6, 11, -6], + [0, -6, 11, -2], + ] + ) + r = _polyroot_vec(c, sort=True, distinct=True, eps=eps) + for j in range(c.shape[0]): + root = r[j][~np.isnan(r[j])] + unique_root = np.unique(np.roots(c[j])) + assert root.size == unique_root.size + np.testing.assert_allclose( + root, + unique_root, + err_msg=f"Analytic branch of polyroot_vec failed at {j}.", + ) + c = np.array([0, 1, -1, -8, 12]) + r = _polyroot_vec(c, sort=True, distinct=True, eps=eps) + r = r[~np.isnan(r)] + unique_r = np.unique(np.roots(c)) + assert r.size == unique_r.size + np.testing.assert_allclose(r, unique_r) From b6739656b395f4b6574b26e089504a151cf7afe0 Mon Sep 17 00:00:00 2001 From: unalmis Date: Mon, 15 Dec 2025 23:41:13 -0800 Subject: [PATCH 2/9] inline thing --- interpax/_spline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 12a8362..919a286 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1405,5 +1405,4 @@ def _root_linear(a, b, sentinel, eps, distinct=False): def _concat_sentinel(r, sentinel, num=1): """Concatenate ``sentinel`` ``num`` times to ``r`` on last axis.""" - sent = jnp.broadcast_to(sentinel, (*r.shape[:-1], num)) - return jnp.append(r, sent, axis=-1) + return jnp.append(r, jnp.broadcast_to(sentinel, (*r.shape[:-1], num)), axis=-1) From b59371dbf93bd98765160dcdf54302c957b1c454 Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Thu, 18 Dec 2025 03:37:46 -0800 Subject: [PATCH 3/9] . --- interpax/_spline.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 919a286..9c98448 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1300,9 +1300,8 @@ def _polyroot_vec( ): # Compute from analytic formula to avoid the issue of complex roots with small # imaginary parts and to avoid nan in gradient. Also consumes less memory. - r = func[num_coef]( - *jnp.moveaxis(c[..., :-1], -1, 0), c[..., -1] - k, sentinel, eps, distinct - ) + c = jnp.moveaxis(c, -1, 0) + r = func[num_coef](*c[:-1], c[-1] - k, sentinel, eps, distinct) # We already filtered distinct roots for quadratics. distinct = distinct and num_coef > 3 else: From 8d7e180bfff894fcf427c9f5ff1d7d56c5bd71b5 Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Fri, 19 Dec 2025 17:59:12 -0800 Subject: [PATCH 4/9] . --- interpax/_spline.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 9c98448..4e87bb2 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1241,14 +1241,14 @@ def _filter_distinct(r, sentinel, eps): ) -def _polyroot_vec( +def polyroot_vec( c, k=0.0, a_min=None, a_max=None, sort=False, sentinel=jnp.nan, - eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12), + eps=_eps, distinct=False, ): """Roots of polynomial with given coefficients. @@ -1302,6 +1302,10 @@ def _polyroot_vec( # imaginary parts and to avoid nan in gradient. Also consumes less memory. c = jnp.moveaxis(c, -1, 0) r = func[num_coef](*c[:-1], c[-1] - k, sentinel, eps, distinct) + if num_coef == 2: + r = r[jnp.newaxis] + r = jnp.moveaxis(r, 0, -1) + # We already filtered distinct roots for quadratics. distinct = distinct and num_coef > 3 else: @@ -1318,7 +1322,8 @@ def _polyroot_vec( if sort or distinct: r = jnp.sort(r, axis=-1) - r = _filter_distinct(r, sentinel, eps) if distinct else r + if distinct: + r = _filter_distinct(r, sentinel, eps) assert r.shape[-1] == num_coef - 1 return r @@ -1331,7 +1336,7 @@ def irreducible(Q, R, b, mask): # Three irrational real roots. theta = R / jnp.sqrt(jnp.where(mask, Q**3, 1.0)) theta = jnp.arccos(jnp.where(jnp.abs(theta) < 1.0, theta, 0.0)) - return jnp.moveaxis( + return ( -2 * jnp.sqrt(Q) * jnp.stack( @@ -1341,9 +1346,7 @@ def irreducible(Q, R, b, mask): jnp.cos((theta - 2 * jnp.pi) / 3), ] ) - - b / 3, - source=0, - destination=-1, + - b / 3 ) def reducible(Q, R, b): @@ -1351,7 +1354,7 @@ def reducible(Q, R, b): A = -jnp.sign(R) * jnp.cbrt(jnp.abs(R) + jnp.sqrt(jnp.abs(R**2 - Q**3))) B = safediv(Q, A) r1 = (A + B) - b / 3 - return _concat_sentinel(r1[..., jnp.newaxis], sentinel, num=2) + return _concat_sentinel(r1[jnp.newaxis], sentinel, num=2) def root(b, c, d): b = safediv(b, a) @@ -1361,14 +1364,14 @@ def root(b, c, d): R = (2 * b**3 - 9 * b * c) / 54 + 0.5 * d mask = R**2 < Q**3 return jnp.where( - mask[..., jnp.newaxis], + mask, irreducible(jnp.abs(Q), R, b, mask), reducible(Q, R, b), ) return jnp.where( # Tests catch failure here if eps < 1e-12 for 64 bit precision. - jnp.expand_dims(jnp.abs(a) <= eps, axis=-1), + jnp.abs(a) <= eps, _concat_sentinel( _root_quadratic(b, c, d, sentinel, eps, distinct), sentinel, @@ -1394,7 +1397,7 @@ def _root_quadratic(a, b, c, sentinel, eps, distinct): sentinel, safediv(c, q, sentinel), ) - return jnp.stack([r1, r2], axis=-1) + return jnp.stack([r1, r2]) def _root_linear(a, b, sentinel, eps, distinct=False): @@ -1404,4 +1407,4 @@ def _root_linear(a, b, sentinel, eps, distinct=False): def _concat_sentinel(r, sentinel, num=1): """Concatenate ``sentinel`` ``num`` times to ``r`` on last axis.""" - return jnp.append(r, jnp.broadcast_to(sentinel, (*r.shape[:-1], num)), axis=-1) + return jnp.concatenate((r, jnp.broadcast_to(sentinel, (num, *r.shape[1:])))) From 49663287bd3d2a66986cadebad2366806d26f75f Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Fri, 19 Dec 2025 18:00:24 -0800 Subject: [PATCH 5/9] edge case for linear stuff see previous commit --- interpax/_spline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 4e87bb2..3e963ff 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1241,14 +1241,14 @@ def _filter_distinct(r, sentinel, eps): ) -def polyroot_vec( +def _polyroot_vec( c, k=0.0, a_min=None, a_max=None, sort=False, sentinel=jnp.nan, - eps=_eps, + eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12), distinct=False, ): """Roots of polynomial with given coefficients. From 9bcf12a559c2fc1d049d2311bb96c87c02fbc372 Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Sat, 20 Dec 2025 11:42:40 -0800 Subject: [PATCH 6/9] edge case in polyroot --- interpax/_spline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/interpax/_spline.py b/interpax/_spline.py index 3e963ff..f139d12 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1291,6 +1291,7 @@ def _polyroot_vec( """ get_only_real_roots = not (a_min is None and a_max is None) num_coef = c.shape[-1] + distinct = distinct and num_coef > 2 func = {2: _root_linear, 3: _root_quadratic, 4: _root_cubic} if ( From 4391062bf1ca2aca143d604061935735a0083836 Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Sat, 20 Dec 2025 11:55:10 -0800 Subject: [PATCH 7/9] fix docstring --- interpax/_spline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index f139d12..63b8517 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1407,5 +1407,5 @@ def _root_linear(a, b, sentinel, eps, distinct=False): def _concat_sentinel(r, sentinel, num=1): - """Concatenate ``sentinel`` ``num`` times to ``r`` on last axis.""" + """Concatenate ``sentinel`` ``num`` times to ``r`` on first axis.""" return jnp.concatenate((r, jnp.broadcast_to(sentinel, (num, *r.shape[1:])))) From c7c50d360c10b39d005dcd8f8e0541b4b3d1ea90 Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Fri, 2 Jan 2026 21:47:32 -0800 Subject: [PATCH 8/9] reduce flops --- interpax/_spline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 63b8517..34e3fee 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1360,9 +1360,8 @@ def reducible(Q, R, b): def root(b, c, d): b = safediv(b, a) c = safediv(c, a) - d = safediv(d, a) Q = (b**2 - 3 * c) / 9 - R = (2 * b**3 - 9 * b * c) / 54 + 0.5 * d + R = (2 * b**3 - 9 * b * c) / 54 + safediv(0.5 * d, a) mask = R**2 < Q**3 return jnp.where( mask, From ddbf1eecb4353127e291bb62f9f79916ee58602b Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Sat, 10 Jan 2026 15:18:57 -0800 Subject: [PATCH 9/9] . --- interpax/_spline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 34e3fee..e5820e0 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1361,7 +1361,7 @@ def root(b, c, d): b = safediv(b, a) c = safediv(c, a) Q = (b**2 - 3 * c) / 9 - R = (2 * b**3 - 9 * b * c) / 54 + safediv(0.5 * d, a) + R = (2 * b**3 - 9 * b * c) / 54 + safediv(d, 2 * a) mask = R**2 < Q**3 return jnp.where( mask, @@ -1370,7 +1370,7 @@ def root(b, c, d): ) return jnp.where( - # Tests catch failure here if eps < 1e-12 for 64 bit precision. + # Tests catch failure here if eps < 1e-12 for double precision. jnp.abs(a) <= eps, _concat_sentinel( _root_quadratic(b, c, d, sentinel, eps, distinct),