Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

- Adds ``"even_spacing"`` option to ``interpXd`` and ``InterpolatorXD``. If the user
sets this to True, the neighboring indices will be computed without a binary search.


v0.3.0
------
- Adds a number of classes that replicate most of the functionality of the
corresponding classes from scipy.interpolate :
- ``scipy.interpolate.PPoly`` -> ``interpax.PPoly``
Expand Down
97 changes: 72 additions & 25 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ class Interpolator1D(AbstractInterpolator):
period : float > 0, None
periodicity of the function. If given, function is assumed to be periodic
on the interval [0,period]. None denotes no periodicity
even_spacing : bool
if True, user ensures that x is evenly spaced with constant
dx (there won't be internal checks). This helps finding the
neighboring points for the query positions faster. Defalts to False.

"""

Expand All @@ -88,6 +92,7 @@ class Interpolator1D(AbstractInterpolator):
extrap: Union[bool, float, tuple]
period: Union[None, float]
axis: int
even_spacing: bool = eqx.field(static=True)

def __init__(
self,
Expand All @@ -96,6 +101,7 @@ def __init__(
method: str = "cubic",
extrap: Union[bool, float, tuple] = False,
period: Union[None, float] = None,
even_spacing: bool = False,
**kwargs,
) -> None:
x, f = map(asarray_inexact, (x, f))
Expand All @@ -115,6 +121,7 @@ def __init__(
self.method = method
self.extrap = extrap
self.period = period # pyright: ignore
self.even_spacing = even_spacing

if fx is None:
fx = approx_df(x, f, method, axis, **kwargs)
Expand Down Expand Up @@ -146,6 +153,7 @@ def __call__(
dx,
self.extrap,
self.period,
self.even_spacing,
**self.derivs,
)

Expand Down Expand Up @@ -186,6 +194,10 @@ class Interpolator2D(AbstractInterpolator):
periodicity of the function in x, y directions. None denotes no periodicity,
otherwise function is assumed to be periodic on the interval [0,period]. Use a
single value for the same in both directions.
even_spacing : bool
if True, user ensures that each array x, y is evenly spaced with constant
dx, dy (there won't be internal checks). This helps finding the
neighboring points for the query positions faster. Defalts to False.

"""

Expand All @@ -197,6 +209,7 @@ class Interpolator2D(AbstractInterpolator):
extrap: Union[bool, float, tuple]
period: Union[None, float, tuple]
axis: int
even_spacing: bool = eqx.field(static=True)

def __init__(
self,
Expand All @@ -206,6 +219,7 @@ def __init__(
method: str = "cubic",
extrap: Union[bool, float, tuple] = False,
period: Union[None, float, tuple] = None,
even_spacing: bool = False,
**kwargs,
):
x, y, f = map(asarray_inexact, (x, y, f))
Expand Down Expand Up @@ -233,6 +247,7 @@ def __init__(
self.method = method
self.extrap = extrap
self.period = period
self.even_spacing = even_spacing

if fx is None:
fx = approx_df(x, f, method, 0, **kwargs)
Expand Down Expand Up @@ -274,6 +289,7 @@ def __call__(
(dx, dy),
self.extrap,
self.period,
self.even_spacing,
**self.derivs,
)

Expand Down Expand Up @@ -316,6 +332,10 @@ class Interpolator3D(AbstractInterpolator):
periodicity of the function in x, y, z directions. None denotes no periodicity,
otherwise function is assumed to be periodic on the interval [0,period]. Use a
single value for the same in both directions.
even_spacing : bool
if True, user ensures that each array x, y, z is evenly spaced with constant
dx, dy and dz (there won't be internal checks). This helps finding the
neighboring points for the query positions faster. Defalts to False.

"""

Expand All @@ -328,6 +348,7 @@ class Interpolator3D(AbstractInterpolator):
extrap: Union[bool, float, tuple]
period: Union[None, float, tuple]
axis: int
even_spacing: bool = eqx.field(static=True)

def __init__(
self,
Expand All @@ -338,6 +359,7 @@ def __init__(
method: str = "cubic",
extrap: Union[bool, float, tuple] = False,
period: Union[None, float, tuple] = None,
even_spacing: bool = False,
**kwargs,
):
x, y, z, f = map(asarray_inexact, (x, y, z, f))
Expand Down Expand Up @@ -376,6 +398,7 @@ def __init__(
self.method = method
self.extrap = extrap
self.period = period
self.even_spacing = even_spacing

if fx is None:
fx = approx_df(x, f, method, 0, **kwargs)
Expand Down Expand Up @@ -437,11 +460,12 @@ def __call__(
(dx, dy, dz),
self.extrap,
self.period,
self.even_spacing,
**self.derivs,
)


@wrap_jit(static_argnames=["method"])
@wrap_jit(static_argnames=["method", "even_spacing"])
def interp1d(
xq: Real[ArrayLike, " Nq"],
x: Real[ArrayLike, " Nx"],
Expand All @@ -450,6 +474,7 @@ def interp1d(
derivative: int = 0,
extrap: Union[bool, float, tuple] = False,
period: Union[None, float] = None,
even_spacing: bool = False,
**kwargs,
) -> Inexact[Array, "Nq ..."]:
"""Interpolate a 1d function.
Expand Down Expand Up @@ -488,6 +513,10 @@ def interp1d(
period : float > 0, None
periodicity of the function. If given, function is assumed to be periodic
on the interval [0,period]. None denotes no periodicity
even_spacing : bool
if True, user ensures that x is evenly spaced with constant
dx (there won't be internal checks). This helps finding the
neighboring points for the query positions faster. Defalts to False.

Returns
-------
Expand Down Expand Up @@ -523,10 +552,16 @@ def interp1d(
xq, x, f, fx = _make_periodic(xq, x, period, axis, f, fx)
lowx = highx = True

# find the index
if not even_spacing:
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
else:
dx = x[1] - x[0]
i = jnp.clip(jnp.floor((xq - x[0]) / dx).astype(int) + 1, 1, len(x) - 1)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check if this still works for periodic inputs


if method == "nearest":

def derivative0_nearest():
i = jnp.argmin(jnp.abs(xq[:, np.newaxis] - x[np.newaxis]), axis=1)
return f[i]

def derivative1_nearest():
Expand All @@ -537,7 +572,6 @@ def derivative1_nearest():
elif method == "linear":

def derivative0_linear():
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
df = jnp.take(f, i, axis) - jnp.take(f, i - 1, axis)
dx = x[i] - x[i - 1]
dxi = jnp.where(dx == 0, 0, 1 / dx)
Expand All @@ -550,7 +584,6 @@ def derivative0_linear():
return fq

def derivative1_linear():
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
df = jnp.take(f, i, axis) - jnp.take(f, i - 1, axis)
dx = x[i] - x[i - 1]
dxi = jnp.where(dx == 0, 0, 1 / dx)
Expand All @@ -565,7 +598,6 @@ def derivative2_linear():

else:
assert method in CUBIC_METHODS
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
if fx is None:
fx = approx_df(x, f, method, axis, **kwargs)
assert fx.shape == f.shape
Expand All @@ -589,7 +621,7 @@ def derivative2_linear():
return fq.reshape(outshape)


@wrap_jit(static_argnames=["method"])
@wrap_jit(static_argnames=["method", "even_spacing"])
def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces
xq: Real[ArrayLike, " Nq"],
yq: Real[ArrayLike, " Nq"],
Expand All @@ -600,6 +632,7 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces
derivative: Union[int, tuple] = 0,
extrap: Union[bool, float, tuple] = False,
period: Union[None, float, tuple] = None,
even_spacing: bool = False,
**kwargs,
) -> Inexact[Array, "Nq ..."]:
"""Interpolate a 2d function.
Expand Down Expand Up @@ -644,6 +677,10 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces
periodicity of the function in x, y directions. None denotes no periodicity,
otherwise function is assumed to be periodic on the interval [0,period]. Use a
single value for the same in both directions.
even_spacing : bool
if True, user ensures that each array x, y is evenly spaced with constant
dx, dy (there won't be internal checks). This helps finding the
neighboring points for the query positions faster. Defalts to False.

Returns
-------
Expand Down Expand Up @@ -692,14 +729,22 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces
yq, y, f, fx, fy, fxy = _make_periodic(yq, y, periody, 1, f, fx, fy, fxy)
lowy = highy = True

# find the indices
if not even_spacing:
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1)
else:
dx = x[1] - x[0]
dy = y[1] - y[0]
i = jnp.clip(jnp.floor((xq - x[0]) / dx).astype(int) + 1, 1, len(x) - 1)
j = jnp.clip(jnp.floor((yq - y[0]) / dy).astype(int) + 1, 1, len(y) - 1)

if method == "nearest":

def derivative0():
# because of the regular spaced grid we know that the nearest point
# will be one of the 4 neighbors on the grid, so we first find those
# and then take the nearest one among them.
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1)
neighbors_x = jnp.array(
[[x[i], x[i - 1], x[i], x[i - 1]], [y[j], y[j], y[j - 1], y[j - 1]]]
)
Expand All @@ -719,9 +764,6 @@ def derivative1():
)

elif method == "linear":
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1)

f00 = f[i - 1, j - 1]
f01 = f[i - 1, j]
f10 = f[i, j - 1]
Expand Down Expand Up @@ -757,9 +799,6 @@ def derivative1():
fxy = approx_df(y, fx, method, 1, **kwargs)
assert fx.shape == fy.shape == fxy.shape == f.shape

i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1)

dx = x[i] - x[i - 1]
deltax = xq - x[i - 1]
dxi = jnp.where(dx == 0, 0, 1 / dx)
Expand Down Expand Up @@ -798,7 +837,7 @@ def derivative1():
return fq.reshape(outshape)


@wrap_jit(static_argnames=["method"])
@wrap_jit(static_argnames=["method", "even_spacing"])
def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces
xq: Real[ArrayLike, " Nq"],
yq: Real[ArrayLike, " Nq"],
Expand All @@ -811,6 +850,7 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces
derivative: Union[int, tuple] = 0,
extrap: Union[bool, float, tuple] = False,
period: Union[None, float, tuple] = None,
even_spacing: bool = False,
**kwargs,
) -> Inexact[Array, "Nq ..."]:
"""Interpolate a 3d function.
Expand Down Expand Up @@ -859,6 +899,10 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces
periodicity of the function in x, y, z directions. None denotes no periodicity,
otherwise function is assumed to be periodic on the interval [0,period]. Use a
single value for the same in all directions.
even_spacing : bool
if True, user ensures that each array x, y, z is evenly spaced with constant
dx, dy and dz (there won't be internal checks). This helps finding the
neighboring points for the query positions faster. Defalts to False.

Returns
-------
Expand Down Expand Up @@ -926,15 +970,25 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces
)
lowz = highz = True

# find the indices
if not even_spacing:
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1)
k = jnp.clip(jnp.searchsorted(z, zq, side="right"), 1, len(z) - 1)
else:
dx = x[1] - x[0]
dy = y[1] - y[0]
dz = z[1] - z[0]
i = jnp.clip(jnp.floor((xq - x[0]) / dx).astype(int) + 1, 1, len(x) - 1)
j = jnp.clip(jnp.floor((yq - y[0]) / dy).astype(int) + 1, 1, len(y) - 1)
k = jnp.clip(jnp.floor((zq - z[0]) / dz).astype(int) + 1, 1, len(z) - 1)

if method == "nearest":

def derivative0():
# because of the regular spaced grid we know that the nearest point
# will be one of the 8 neighbors on the grid, so we first find those
# and then take the nearest one among them.
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1)
k = jnp.clip(jnp.searchsorted(z, zq, side="right"), 1, len(z) - 1)
neighbors_x = jnp.array(
[
[x[i], x[i - 1], x[i], x[i - 1], x[i], x[i - 1], x[i], x[i - 1]],
Expand Down Expand Up @@ -969,10 +1023,6 @@ def derivative1():
)

elif method == "linear":
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1)
k = jnp.clip(jnp.searchsorted(z, zq, side="right"), 1, len(z) - 1)

f000 = f[i - 1, j - 1, k - 1]
f001 = f[i - 1, j - 1, k]
f010 = f[i - 1, j, k - 1]
Expand Down Expand Up @@ -1037,9 +1087,6 @@ def derivative1():
== fxyz.shape
== f.shape
)
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1)
k = jnp.clip(jnp.searchsorted(z, zq, side="right"), 1, len(z) - 1)

dx = x[i] - x[i - 1]
deltax = xq - x[i - 1]
Expand Down
Loading