Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 200 additions & 1 deletion interpax/_spline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Functions for interpolating splines that are JAX differentiable."""

from collections import OrderedDict
from functools import partial
from typing import Any, Union

import equinox as eqx
Expand All @@ -12,7 +13,7 @@

from ._coefs import A_BICUBIC, A_CUBIC, A_TRICUBIC
from ._fd_derivs import approx_df
from .utils import asarray_inexact, errorif, isbool, wrap_jit
from .utils import asarray_inexact, errorif, isbool, safediv, wrap_jit

CUBIC_METHODS = (
"cubic",
Expand Down Expand Up @@ -1209,3 +1210,201 @@ def noclip(fq, *_):
)

return fq


def _subtract_last(c, k):
"""Subtract ``k`` from last index of last axis of ``c``.

Semantically same as ``return c.at[...,-1].subtract(k)``,
but allows dimension to increase.
"""
c_1 = c[..., -1] - k
return jnp.concatenate(
[
jnp.broadcast_to(c[..., :-1], (*c_1.shape, c.shape[-1] - 1)),
c_1[..., jnp.newaxis],
],
axis=-1,
)


def _filter_distinct(r, sentinel, eps):
"""Set all but one of matching adjacent elements in ``r`` to ``sentinel``."""
# eps needs to be low enough that close distinct roots do not get removed.
# Otherwise, algorithms relying on continuity will fail.
mask = jnp.isclose(jnp.diff(r, axis=-1, prepend=sentinel), 0, atol=eps)
return jnp.where(mask, sentinel, r)


_roots_companion = jnp.vectorize(
partial(jnp.roots, strip_zeros=False), signature="(m)->(n)"
)


def _polyroot_vec(
c,
k=0.0,
a_min=None,
a_max=None,
sort=False,
sentinel=jnp.nan,
eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12),
distinct=False,
):
"""Roots of polynomial with given coefficients.

Parameters
----------
c : jnp.ndarray
Last axis should store coefficients of a polynomial. For a polynomial given by
∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at
``c[...,n-i]``.
k : jnp.ndarray
Shape (..., *c.shape[:-1]).
Specify to find solutions to ∑ᵢⁿ cᵢ xⁱ = ``k``.
a_min : jnp.ndarray
Shape (..., *c.shape[:-1]).
Minimum ``a_min`` and maximum ``a_max`` value to return roots between.
If specified only real roots are returned, otherwise returns all complex roots.
a_max : jnp.ndarray
Shape (..., *c.shape[:-1]).
Minimum ``a_min`` and maximum ``a_max`` value to return roots between.
If specified only real roots are returned, otherwise returns all complex roots.
sort : bool
Whether to sort the roots.
sentinel : float
Value with which to pad array in place of filtered elements.
Anything less than ``a_min`` or greater than ``a_max`` plus some floating point
error buffer will work just like nan while avoiding ``nan`` gradient.
eps : float
Absolute tolerance with which to consider value as zero.
distinct : bool
Whether to only return the distinct roots. If true, when the multiplicity is
greater than one, the repeated roots are set to ``sentinel``.

Returns
-------
r : jnp.ndarray
Shape (..., *c.shape[:-1], c.shape[-1] - 1).
The roots of the polynomial, iterated over the last axis.

"""
get_only_real_roots = not (a_min is None and a_max is None)
num_coef = c.shape[-1]
distinct = distinct and num_coef > 2
func = {2: _root_linear, 3: _root_quadratic, 4: _root_cubic}

if (
num_coef in func
and get_only_real_roots
and not (jnp.iscomplexobj(c) or jnp.iscomplexobj(k))
):
# Compute from analytic formula to avoid the issue of complex roots with small
# imaginary parts and to avoid nan in gradient. Also consumes less memory.
c = jnp.moveaxis(c, -1, 0)
r = func[num_coef](*c[:-1], c[-1] - k, sentinel, eps, distinct)
if num_coef == 2:
r = r[jnp.newaxis]
r = jnp.moveaxis(r, 0, -1)

# We already filtered distinct roots for quadratics.
distinct = distinct and num_coef > 3
else:
r = _roots_companion(_subtract_last(c, k))

if get_only_real_roots:
a_min = -jnp.inf if a_min is None else a_min[..., jnp.newaxis]
a_max = +jnp.inf if a_max is None else a_max[..., jnp.newaxis]
r = jnp.where(
(jnp.abs(r.imag) <= eps) & (a_min <= r.real) & (r.real <= a_max),
r.real,
sentinel,
)

if sort or distinct:
r = jnp.sort(r, axis=-1)
if distinct:
r = _filter_distinct(r, sentinel, eps)
assert r.shape[-1] == num_coef - 1
return r


def _root_cubic(a, b, c, d, sentinel, eps, distinct):
"""Return real cubic root assuming real coefficients."""
# numerical.recipes/book.html, page 228

def irreducible(Q, R, b, mask):
# Three irrational real roots.
theta = R / jnp.sqrt(jnp.where(mask, Q**3, 1.0))
theta = jnp.arccos(jnp.where(jnp.abs(theta) < 1.0, theta, 0.0))
return (
-2
* jnp.sqrt(Q)
* jnp.stack(
[
jnp.cos(theta / 3),
jnp.cos((theta + 2 * jnp.pi) / 3),
jnp.cos((theta - 2 * jnp.pi) / 3),
]
)
- b / 3
)

def reducible(Q, R, b):
# One real and two complex roots.
A = -jnp.sign(R) * jnp.cbrt(jnp.abs(R) + jnp.sqrt(jnp.abs(R**2 - Q**3)))
B = safediv(Q, A)
r1 = (A + B) - b / 3
return _concat_sentinel(r1[jnp.newaxis], sentinel, num=2)

def root(b, c, d):
b = safediv(b, a)
c = safediv(c, a)
Q = (b**2 - 3 * c) / 9
R = (2 * b**3 - 9 * b * c) / 54 + safediv(d, 2 * a)
mask = R**2 < Q**3
return jnp.where(
mask,
irreducible(jnp.abs(Q), R, b, mask),
reducible(Q, R, b),
)

return jnp.where(
# Tests catch failure here if eps < 1e-12 for double precision.
jnp.abs(a) <= eps,
_concat_sentinel(
_root_quadratic(b, c, d, sentinel, eps, distinct),
sentinel,
),
root(b, c, d),
)


def _root_quadratic(a, b, c, sentinel, eps, distinct):
"""Return real quadratic root assuming real coefficients."""
# numerical.recipes/book.html, page 227

discriminant = b**2 - 4 * a * c
q = -0.5 * (b + jnp.sign(b) * jnp.sqrt(jnp.abs(discriminant)))
r1 = jnp.where(
discriminant < 0,
sentinel,
safediv(q, a, _root_linear(b, c, sentinel, eps)),
)
r2 = jnp.where(
# more robust to remove repeated roots with discriminant
(discriminant < 0) | (distinct & (discriminant <= eps)),
sentinel,
safediv(c, q, sentinel),
)
return jnp.stack([r1, r2])


def _root_linear(a, b, sentinel, eps, distinct=False):
"""Return real linear root assuming real coefficients."""
return safediv(-b, a, jnp.where(jnp.abs(b) <= eps, 0, sentinel))


def _concat_sentinel(r, sentinel, num=1):
"""Concatenate ``sentinel`` ``num`` times to ``r`` on first axis."""
return jnp.concatenate((r, jnp.broadcast_to(sentinel, (num, *r.shape[1:]))))
19 changes: 19 additions & 0 deletions interpax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,22 @@ def wrapper(fun):
return foo

return wrapper


def safediv(a, b, fill=0.0, threshold=0.0):
"""Divide a/b with guards for division by zero.

Parameters
----------
a, b : ndarray
Numerator and denominator.
fill : float, ndarray, optional
Value to return where b is zero.
threshold : float >= 0
How small is b allowed to be.

"""
mask = jnp.abs(b) <= threshold
num = jnp.where(mask, fill, a)
den = jnp.where(mask, 1.0, b)
return num / den
61 changes: 60 additions & 1 deletion tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
interp2d,
interp3d,
)
from interpax._spline import _polyroot_vec

jax_config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -412,7 +413,7 @@ def test_fft_interp2d(dtype):
sx=0.2,
sy=0.3,
dx=np.diff(x[spx][1])[0],
dy=np.diff(y[spy][1])[0]
dy=np.diff(y[spy][1])[0],
).squeeze(),
)
for epx in ["o", "e"]: # eval parity x
Expand Down Expand Up @@ -576,3 +577,61 @@ def test_extrap_float():
np.testing.assert_allclose(interpol(4.5, 5.3), 1.0)
np.testing.assert_allclose(interpol(-4.5, 5.3), 0.0)
np.testing.assert_allclose(interpol(4.5, -5.3), 0.0)


def test_polyroot_vec(eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12)):
"""Test vectorized computation of cubic polynomial exact roots."""
c = np.arange(-24, 24).reshape(4, 6, -1).transpose(-1, 1, 0)
# Ensure broadcasting won't hide error in implementation.
assert np.unique(c.shape).size == c.ndim

k = np.broadcast_to(np.arange(c.shape[-2]), c.shape[:-1])
# Now increase dimension so that shapes still broadcast, but stuff like
# c[...,-1]-=k is not allowed because it grows the dimension of c.
# This is needed functionality in polyroot_vec that requires an awkward
# loop to obtain if using jnp.vectorize.
k = np.stack([k, k * 2 + 1])
r = _polyroot_vec(c, k, sort=True, eps=eps)

for i in range(k.shape[0]):
d = c.copy()
d[..., -1] -= k[i]
# np.roots cannot be vectorized because it strips leading zeros and
# output shape is therefore dynamic.
for idx in np.ndindex(d.shape[:-1]):
np.testing.assert_allclose(
r[(i, *idx)],
np.sort(np.roots(d[idx])),
err_msg=f"Eigenvalue branch of polyroot_vec failed at {i, *idx}.",
)

# Now test analytic formula branch, Ensure it filters distinct roots,
# and ensure zero coefficients don't bust computation due to singularities
# in analytic formulae which are not present in iterative eigenvalue scheme.
c = np.array(
[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[1, -1, -8, 12],
[1, -6, 11, -6],
[0, -6, 11, -2],
]
)
r = _polyroot_vec(c, sort=True, distinct=True, eps=eps)
for j in range(c.shape[0]):
root = r[j][~np.isnan(r[j])]
unique_root = np.unique(np.roots(c[j]))
assert root.size == unique_root.size
np.testing.assert_allclose(
root,
unique_root,
err_msg=f"Analytic branch of polyroot_vec failed at {j}.",
)
c = np.array([0, 1, -1, -8, 12])
r = _polyroot_vec(c, sort=True, distinct=True, eps=eps)
r = r[~np.isnan(r)]
unique_r = np.unique(np.roots(c))
assert r.size == unique_r.size
np.testing.assert_allclose(r, unique_r)