diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f88bd1..fe53c92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,12 @@ Changelog ========= +- Adds ``"even_spacing"`` option to ``interpXd`` and ``InterpolatorXD``. If the user +sets this to True, the neighboring indices will be computed without a binary search. + + +v0.3.0 +------ - Adds a number of classes that replicate most of the functionality of the corresponding classes from scipy.interpolate : - ``scipy.interpolate.PPoly`` -> ``interpax.PPoly`` diff --git a/interpax/_spline.py b/interpax/_spline.py index 54b9c38..8bfa46c 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -78,6 +78,10 @@ class Interpolator1D(AbstractInterpolator): period : float > 0, None periodicity of the function. If given, function is assumed to be periodic on the interval [0,period]. None denotes no periodicity + even_spacing : bool + if True, user ensures that x is evenly spaced with constant + dx (there won't be internal checks). This helps finding the + neighboring points for the query positions faster. Defalts to False. """ @@ -88,6 +92,7 @@ class Interpolator1D(AbstractInterpolator): extrap: Union[bool, float, tuple] period: Union[None, float] axis: int + even_spacing: bool = eqx.field(static=True) def __init__( self, @@ -96,6 +101,7 @@ def __init__( method: str = "cubic", extrap: Union[bool, float, tuple] = False, period: Union[None, float] = None, + even_spacing: bool = False, **kwargs, ) -> None: x, f = map(asarray_inexact, (x, f)) @@ -115,6 +121,7 @@ def __init__( self.method = method self.extrap = extrap self.period = period # pyright: ignore + self.even_spacing = even_spacing if fx is None: fx = approx_df(x, f, method, axis, **kwargs) @@ -146,6 +153,7 @@ def __call__( dx, self.extrap, self.period, + self.even_spacing, **self.derivs, ) @@ -186,6 +194,10 @@ class Interpolator2D(AbstractInterpolator): periodicity of the function in x, y directions. None denotes no periodicity, otherwise function is assumed to be periodic on the interval [0,period]. Use a single value for the same in both directions. + even_spacing : bool + if True, user ensures that each array x, y is evenly spaced with constant + dx, dy (there won't be internal checks). This helps finding the + neighboring points for the query positions faster. Defalts to False. """ @@ -197,6 +209,7 @@ class Interpolator2D(AbstractInterpolator): extrap: Union[bool, float, tuple] period: Union[None, float, tuple] axis: int + even_spacing: bool = eqx.field(static=True) def __init__( self, @@ -206,6 +219,7 @@ def __init__( method: str = "cubic", extrap: Union[bool, float, tuple] = False, period: Union[None, float, tuple] = None, + even_spacing: bool = False, **kwargs, ): x, y, f = map(asarray_inexact, (x, y, f)) @@ -233,6 +247,7 @@ def __init__( self.method = method self.extrap = extrap self.period = period + self.even_spacing = even_spacing if fx is None: fx = approx_df(x, f, method, 0, **kwargs) @@ -274,6 +289,7 @@ def __call__( (dx, dy), self.extrap, self.period, + self.even_spacing, **self.derivs, ) @@ -316,6 +332,10 @@ class Interpolator3D(AbstractInterpolator): periodicity of the function in x, y, z directions. None denotes no periodicity, otherwise function is assumed to be periodic on the interval [0,period]. Use a single value for the same in both directions. + even_spacing : bool + if True, user ensures that each array x, y, z is evenly spaced with constant + dx, dy and dz (there won't be internal checks). This helps finding the + neighboring points for the query positions faster. Defalts to False. """ @@ -328,6 +348,7 @@ class Interpolator3D(AbstractInterpolator): extrap: Union[bool, float, tuple] period: Union[None, float, tuple] axis: int + even_spacing: bool = eqx.field(static=True) def __init__( self, @@ -338,6 +359,7 @@ def __init__( method: str = "cubic", extrap: Union[bool, float, tuple] = False, period: Union[None, float, tuple] = None, + even_spacing: bool = False, **kwargs, ): x, y, z, f = map(asarray_inexact, (x, y, z, f)) @@ -376,6 +398,7 @@ def __init__( self.method = method self.extrap = extrap self.period = period + self.even_spacing = even_spacing if fx is None: fx = approx_df(x, f, method, 0, **kwargs) @@ -437,11 +460,12 @@ def __call__( (dx, dy, dz), self.extrap, self.period, + self.even_spacing, **self.derivs, ) -@wrap_jit(static_argnames=["method"]) +@wrap_jit(static_argnames=["method", "even_spacing"]) def interp1d( xq: Real[ArrayLike, " Nq"], x: Real[ArrayLike, " Nx"], @@ -450,6 +474,7 @@ def interp1d( derivative: int = 0, extrap: Union[bool, float, tuple] = False, period: Union[None, float] = None, + even_spacing: bool = False, **kwargs, ) -> Inexact[Array, "Nq ..."]: """Interpolate a 1d function. @@ -488,6 +513,10 @@ def interp1d( period : float > 0, None periodicity of the function. If given, function is assumed to be periodic on the interval [0,period]. None denotes no periodicity + even_spacing : bool + if True, user ensures that x is evenly spaced with constant + dx (there won't be internal checks). This helps finding the + neighboring points for the query positions faster. Defalts to False. Returns ------- @@ -523,10 +552,16 @@ def interp1d( xq, x, f, fx = _make_periodic(xq, x, period, axis, f, fx) lowx = highx = True + # find the index + if not even_spacing: + i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) + else: + dx = x[1] - x[0] + i = jnp.clip(jnp.floor((xq - x[0]) / dx).astype(int) + 1, 1, len(x) - 1) + if method == "nearest": def derivative0_nearest(): - i = jnp.argmin(jnp.abs(xq[:, np.newaxis] - x[np.newaxis]), axis=1) return f[i] def derivative1_nearest(): @@ -537,7 +572,6 @@ def derivative1_nearest(): elif method == "linear": def derivative0_linear(): - i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) df = jnp.take(f, i, axis) - jnp.take(f, i - 1, axis) dx = x[i] - x[i - 1] dxi = jnp.where(dx == 0, 0, 1 / dx) @@ -550,7 +584,6 @@ def derivative0_linear(): return fq def derivative1_linear(): - i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) df = jnp.take(f, i, axis) - jnp.take(f, i - 1, axis) dx = x[i] - x[i - 1] dxi = jnp.where(dx == 0, 0, 1 / dx) @@ -565,7 +598,6 @@ def derivative2_linear(): else: assert method in CUBIC_METHODS - i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) if fx is None: fx = approx_df(x, f, method, axis, **kwargs) assert fx.shape == f.shape @@ -589,7 +621,7 @@ def derivative2_linear(): return fq.reshape(outshape) -@wrap_jit(static_argnames=["method"]) +@wrap_jit(static_argnames=["method", "even_spacing"]) def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces xq: Real[ArrayLike, " Nq"], yq: Real[ArrayLike, " Nq"], @@ -600,6 +632,7 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces derivative: Union[int, tuple] = 0, extrap: Union[bool, float, tuple] = False, period: Union[None, float, tuple] = None, + even_spacing: bool = False, **kwargs, ) -> Inexact[Array, "Nq ..."]: """Interpolate a 2d function. @@ -644,6 +677,10 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces periodicity of the function in x, y directions. None denotes no periodicity, otherwise function is assumed to be periodic on the interval [0,period]. Use a single value for the same in both directions. + even_spacing : bool + if True, user ensures that each array x, y is evenly spaced with constant + dx, dy (there won't be internal checks). This helps finding the + neighboring points for the query positions faster. Defalts to False. Returns ------- @@ -692,14 +729,22 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces yq, y, f, fx, fy, fxy = _make_periodic(yq, y, periody, 1, f, fx, fy, fxy) lowy = highy = True + # find the indices + if not even_spacing: + i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) + j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1) + else: + dx = x[1] - x[0] + dy = y[1] - y[0] + i = jnp.clip(jnp.floor((xq - x[0]) / dx).astype(int) + 1, 1, len(x) - 1) + j = jnp.clip(jnp.floor((yq - y[0]) / dy).astype(int) + 1, 1, len(y) - 1) + if method == "nearest": def derivative0(): # because of the regular spaced grid we know that the nearest point # will be one of the 4 neighbors on the grid, so we first find those # and then take the nearest one among them. - i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) - j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1) neighbors_x = jnp.array( [[x[i], x[i - 1], x[i], x[i - 1]], [y[j], y[j], y[j - 1], y[j - 1]]] ) @@ -719,9 +764,6 @@ def derivative1(): ) elif method == "linear": - i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) - j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1) - f00 = f[i - 1, j - 1] f01 = f[i - 1, j] f10 = f[i, j - 1] @@ -757,9 +799,6 @@ def derivative1(): fxy = approx_df(y, fx, method, 1, **kwargs) assert fx.shape == fy.shape == fxy.shape == f.shape - i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) - j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1) - dx = x[i] - x[i - 1] deltax = xq - x[i - 1] dxi = jnp.where(dx == 0, 0, 1 / dx) @@ -798,7 +837,7 @@ def derivative1(): return fq.reshape(outshape) -@wrap_jit(static_argnames=["method"]) +@wrap_jit(static_argnames=["method", "even_spacing"]) def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces xq: Real[ArrayLike, " Nq"], yq: Real[ArrayLike, " Nq"], @@ -811,6 +850,7 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces derivative: Union[int, tuple] = 0, extrap: Union[bool, float, tuple] = False, period: Union[None, float, tuple] = None, + even_spacing: bool = False, **kwargs, ) -> Inexact[Array, "Nq ..."]: """Interpolate a 3d function. @@ -859,6 +899,10 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces periodicity of the function in x, y, z directions. None denotes no periodicity, otherwise function is assumed to be periodic on the interval [0,period]. Use a single value for the same in all directions. + even_spacing : bool + if True, user ensures that each array x, y, z is evenly spaced with constant + dx, dy and dz (there won't be internal checks). This helps finding the + neighboring points for the query positions faster. Defalts to False. Returns ------- @@ -926,15 +970,25 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces ) lowz = highz = True + # find the indices + if not even_spacing: + i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) + j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1) + k = jnp.clip(jnp.searchsorted(z, zq, side="right"), 1, len(z) - 1) + else: + dx = x[1] - x[0] + dy = y[1] - y[0] + dz = z[1] - z[0] + i = jnp.clip(jnp.floor((xq - x[0]) / dx).astype(int) + 1, 1, len(x) - 1) + j = jnp.clip(jnp.floor((yq - y[0]) / dy).astype(int) + 1, 1, len(y) - 1) + k = jnp.clip(jnp.floor((zq - z[0]) / dz).astype(int) + 1, 1, len(z) - 1) + if method == "nearest": def derivative0(): # because of the regular spaced grid we know that the nearest point # will be one of the 8 neighbors on the grid, so we first find those # and then take the nearest one among them. - i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) - j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1) - k = jnp.clip(jnp.searchsorted(z, zq, side="right"), 1, len(z) - 1) neighbors_x = jnp.array( [ [x[i], x[i - 1], x[i], x[i - 1], x[i], x[i - 1], x[i], x[i - 1]], @@ -969,10 +1023,6 @@ def derivative1(): ) elif method == "linear": - i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) - j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1) - k = jnp.clip(jnp.searchsorted(z, zq, side="right"), 1, len(z) - 1) - f000 = f[i - 1, j - 1, k - 1] f001 = f[i - 1, j - 1, k] f010 = f[i - 1, j, k - 1] @@ -1037,9 +1087,6 @@ def derivative1(): == fxyz.shape == f.shape ) - i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) - j = jnp.clip(jnp.searchsorted(y, yq, side="right"), 1, len(y) - 1) - k = jnp.clip(jnp.searchsorted(z, zq, side="right"), 1, len(z) - 1) dx = x[i] - x[i - 1] deltax = xq - x[i - 1] diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index b378de2..56d6179 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -100,6 +100,38 @@ def test_interp1d_vector_valued(self): fq = interp1d(x, xp, fp, method="monotonic-0") np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-2) + @pytest.mark.unit + def test_interp1d_even_spaced(self): + """Test for interpolating vwith evenly spaced grid.""" + xp = np.linspace(0, 2 * np.pi, 100) + x = np.linspace(0, 2 * np.pi, 300)[10:-10] + f = lambda x: np.array([np.sin(x), np.cos(x)]) + fp = f(xp).T + + fq = interp1d(x, xp, fp, method="nearest", even_spacing=True) + np.testing.assert_allclose(fq, f(x).T, rtol=1e-2, atol=1e-1) + + fq = interp1d(x, xp, fp, method="linear", even_spacing=True) + np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-3) + + fq = interp1d(x, xp, fp, method="cubic", even_spacing=True) + np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5) + + fq = interp1d(x, xp, fp, method="cubic2", even_spacing=True) + np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5) + + fq = interp1d(x, xp, fp, method="cardinal", even_spacing=True) + np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5) + + fq = interp1d(x, xp, fp, method="catmull-rom", even_spacing=True) + np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5) + + fq = interp1d(x, xp, fp, method="monotonic", even_spacing=True) + np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-3) + + fq = interp1d(x, xp, fp, method="monotonic-0", even_spacing=True) + np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-2) + @pytest.mark.unit def test_interp1d_extrap_periodic(self): """Test extrapolation and periodic BC of 1d interpolation.""" @@ -225,6 +257,27 @@ def test_interp2d_vector_valued(self): fq = interp2d(x, y, xp, yp, fp, method="cubic") np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-5, atol=2e-3) + @pytest.mark.unit + def test_interp2d_even_spaced(self): + """Test for interpolating with evenly spaced grid.""" + xp = np.linspace(0, 3 * np.pi, 99) + yp = np.linspace(0, 2 * np.pi, 40) + x = np.linspace(0, 3 * np.pi, 200) + y = np.linspace(0, 2 * np.pi, 200) + xxp, yyp = np.meshgrid(xp, yp, indexing="ij") + + f = lambda x, y: np.array([np.sin(x) * np.cos(y), np.sin(x) + np.cos(y)]) + fp = f(xxp.T, yyp.T).T + + fq = interp2d(x, y, xp, yp, fp, method="nearest", even_spacing=True) + np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-2, atol=1.2e-1) + + fq = interp2d(x, y, xp, yp, fp, method="linear", even_spacing=True) + np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-3, atol=1e-2) + + fq = interp2d(x, y, xp, yp, fp, method="cubic", even_spacing=True) + np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-5, atol=2e-3) + class TestInterp3D: """Tests for interp3d function.""" @@ -314,6 +367,29 @@ def test_interp3d_vector_valued(self): fq = interp3d(x, y, z, xp, yp, zp, fp, method="cubic") np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-5, atol=5e-3) + @pytest.mark.unit + def test_interp3d_even_spaced(self): + """Test for interpolating with evenly spaced grid.""" + x = np.linspace(0, np.pi, 1000) + y = np.linspace(0, 2 * np.pi, 1000) + z = np.linspace(0, 3, 1000) + xp = np.linspace(0, np.pi, 20) + yp = np.linspace(0, 2 * np.pi, 30) + zp = np.linspace(0, 3, 25) + xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij") + + f = lambda x, y, z: np.array([np.sin(x) * np.cos(y) * z**2, 0.1 * (x + y - z)]) + fp = f(xxp.T, yyp.T, zzp.T).T + + fq = interp3d(x, y, z, xp, yp, zp, fp, method="nearest", even_spacing=True) + np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-2, atol=1) + + fq = interp3d(x, y, z, xp, yp, zp, fp, method="linear", even_spacing=True) + np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-3, atol=1e-1) + + fq = interp3d(x, y, z, xp, yp, zp, fp, method="cubic", even_spacing=True) + np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-5, atol=5e-3) + @pytest.mark.unit @pytest.mark.parametrize(