From 9ec429dbb721c5a5fda6594e5fae6edcd5ac37aa Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 15 Aug 2025 23:12:39 -0500 Subject: [PATCH 1/9] FFT interpolation now uses real FFTs since the Fourier transform is Hermitian --- interpax/_fourier.py | 45 +++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/interpax/_fourier.py b/interpax/_fourier.py index a7b4158..9c6819c 100644 --- a/interpax/_fourier.py +++ b/interpax/_fourier.py @@ -14,7 +14,7 @@ def fft_interp1d( sx: Optional[Num[ArrayLike, " s"]] = None, dx: float = 1.0, ) -> Inexact[Array, "n ... s"]: - """Interpolation of a 1d periodic function via FFT. + """Interpolation of a real-valued 1D periodic function via FFT. Parameters ---------- @@ -33,18 +33,15 @@ def fft_interp1d( Interpolated (and possibly shifted) data points """ f = asarray_inexact(f) - c = jnp.fft.ifft(f, axis=0) - nx = c.shape[0] + c = jnp.fft.rfft(f, axis=0, norm="forward") + nx = f.shape[0] if sx is not None: + tau = 2 * jnp.pi sx = asarray_inexact(sx) - sx = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(nx)[:, None] * sx / dx) + sx = jnp.exp(1j * jnp.fft.rfftfreq(nx, dx / tau)[:, None] * sx) c = (c[None].T * sx).T c = jnp.moveaxis(c, 0, -1) - pad = ((n - nx) // 2, n - nx - (n - nx) // 2) - if nx % 2 != 0: - pad = pad[::-1] - c = jnp.fft.ifftshift(_pad_along_axis(jnp.fft.fftshift(c, axes=0), pad, axis=0)) - return jnp.fft.fft(c, axis=0).real + return jnp.fft.irfft(c, n, axis=0, norm="forward") @wrap_jit(static_argnames=["n1", "n2"]) @@ -57,7 +54,7 @@ def fft_interp2d( dx: float = 1.0, dy: float = 1.0, ) -> Inexact[Array, "n1 n2 ... s"]: - """Interpolation of a 2d periodic function via FFT. + """Interpolation of a real-valued 2D periodic function via FFT. Parameters ---------- @@ -77,30 +74,26 @@ def fft_interp2d( Interpolated (and possibly shifted) data points """ f = asarray_inexact(f) - c = jnp.fft.ifft2(f, axes=(0, 1)) - nx, ny = c.shape[:2] + c = jnp.fft.rfft2(f, axes=(0, 1), norm="forward") + nx, ny = f.shape[:2] if (sx is not None) and (sy is not None): + tau = 2 * jnp.pi sx = asarray_inexact(sx) sy = asarray_inexact(sy) - sx = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(nx)[:, None] * sx / dx) - sy = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(ny)[:, None] * sy / dy) - c = (c[None].T * sx[None, :, :] * sy[:, None, :]).T + sx = jnp.exp(1j * jnp.fft.fftfreq(nx, dx / tau)[:, None] * sx) + sy = jnp.exp(1j * jnp.fft.rfftfreq(ny, dy / tau)[:, None] * sy) + c = (c[None].T * (sx[None] * sy[:, None])).T c = jnp.moveaxis(c, 0, -1) - padx = ((n1 - nx) // 2, n1 - nx - (n1 - nx) // 2) - pady = ((n2 - ny) // 2, n2 - ny - (n2 - ny) // 2) - if nx % 2 != 0: - padx = padx[::-1] - if ny % 2 != 0: - pady = pady[::-1] + pad = n1 - nx + pad = (pad // 2, pad - pad // 2) + if nx % 2 != 0: + pad = pad[::-1] c = jnp.fft.ifftshift( - _pad_along_axis(jnp.fft.fftshift(c, axes=0), padx, axis=0), axes=0 - ) - c = jnp.fft.ifftshift( - _pad_along_axis(jnp.fft.fftshift(c, axes=1), pady, axis=1), axes=1 + _pad_along_axis(jnp.fft.fftshift(c, axes=0), pad, axis=0), axes=0 ) - return jnp.fft.fft2(c, axes=(0, 1)).real + return jnp.fft.irfft2(c, (n1, n2), axes=(0, 1), norm="forward") def _pad_along_axis(array: jax.Array, pad: tuple = (0, 0), axis: int = 0): From 706d48fd0841b2fba6e723812bcd90fa22590b3c Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 15 Aug 2025 23:44:27 -0500 Subject: [PATCH 2/9] Cosmetic change to make pad_along_axis simpler --- interpax/_fourier.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/interpax/_fourier.py b/interpax/_fourier.py index 9c6819c..abe574c 100644 --- a/interpax/_fourier.py +++ b/interpax/_fourier.py @@ -98,17 +98,18 @@ def fft_interp2d( def _pad_along_axis(array: jax.Array, pad: tuple = (0, 0), axis: int = 0): """Pad with zeros or truncate a given dimension.""" - array = jnp.moveaxis(array, axis, 0) + index = [slice(None)] * array.ndim + start = stop = None if pad[0] < 0: - array = array[abs(pad[0]) :] + start = abs(pad[0]) pad = (0, pad[1]) if pad[1] < 0: - array = array[: -abs(pad[1])] + stop = -abs(pad[1]) pad = (pad[0], 0) - npad = [(0, 0)] * array.ndim - npad[0] = pad + index[axis] = slice(start, stop) + pad_width = [(0, 0)] * array.ndim + pad_width[axis] = pad - array = jnp.pad(array, pad_width=npad, mode="constant", constant_values=0) - return jnp.moveaxis(array, 0, axis) + return jnp.pad(array[tuple(index)], pad_width) From 1f22d7f912a8dab19756647b1e767e5461e0f880 Mon Sep 17 00:00:00 2001 From: unalmis Date: Sat, 16 Aug 2025 17:31:20 -0500 Subject: [PATCH 3/9] Clean up --- interpax/_fourier.py | 32 ++++++++++++++++++-------------- tests/test_interpolate.py | 8 ++++---- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/interpax/_fourier.py b/interpax/_fourier.py index abe574c..24eb31f 100644 --- a/interpax/_fourier.py +++ b/interpax/_fourier.py @@ -33,14 +33,16 @@ def fft_interp1d( Interpolated (and possibly shifted) data points """ f = asarray_inexact(f) - c = jnp.fft.rfft(f, axis=0, norm="forward") nx = f.shape[0] + c = jnp.fft.rfft(f, axis=0, norm="forward") + if sx is not None: tau = 2 * jnp.pi sx = asarray_inexact(sx) sx = jnp.exp(1j * jnp.fft.rfftfreq(nx, dx / tau)[:, None] * sx) c = (c[None].T * sx).T c = jnp.moveaxis(c, 0, -1) + return jnp.fft.irfft(c, n, axis=0, norm="forward") @@ -75,7 +77,10 @@ def fft_interp2d( """ f = asarray_inexact(f) c = jnp.fft.rfft2(f, axes=(0, 1), norm="forward") - nx, ny = f.shape[:2] + return _fft_interp2d(c, *f.shape[:2], n1, n2, sx, sy, dx, dy) + + +def _fft_interp2d(c, nx, ny, n1, n2, sx, sy, dx, dy): if (sx is not None) and (sy is not None): tau = 2 * jnp.pi sx = asarray_inexact(sx) @@ -85,31 +90,30 @@ def fft_interp2d( c = (c[None].T * (sx[None] * sy[:, None])).T c = jnp.moveaxis(c, 0, -1) - pad = n1 - nx - pad = (pad // 2, pad - pad // 2) - if nx % 2 != 0: - pad = pad[::-1] - c = jnp.fft.ifftshift( - _pad_along_axis(jnp.fft.fftshift(c, axes=0), pad, axis=0), axes=0 - ) + return jnp.fft.irfft2(_fft_pad(c, n1, nx), (n1, n2), axes=(0, 1), norm="forward") + - return jnp.fft.irfft2(c, (n1, n2), axes=(0, 1), norm="forward") +def _fft_pad(c, n_out, n_in, axis=0): + p = n_out - n_in + p = (p // 2, p - p // 2) + if n_in % 2 != 0: + p = p[::-1] + return jnp.fft.ifftshift(_pad_along_axis(jnp.fft.fftshift(c, axis), p, axis), axis) def _pad_along_axis(array: jax.Array, pad: tuple = (0, 0), axis: int = 0): """Pad with zeros or truncate a given dimension.""" index = [slice(None)] * array.ndim + pad_width = [(0, 0)] * array.ndim start = stop = None if pad[0] < 0: - start = abs(pad[0]) + start = -pad[0] pad = (0, pad[1]) if pad[1] < 0: - stop = -abs(pad[1]) + stop = pad[1] pad = (pad[0], 0) index[axis] = slice(start, stop) - pad_width = [(0, 0)] * array.ndim pad_width[axis] = pad - return jnp.pad(array[tuple(index)], pad_width) diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index f2760a6..a3106a6 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -321,7 +321,7 @@ def fun(x): fi = f1[sp][1] fs = fun(x[sp][1] + 0.2) np.testing.assert_allclose( - fs, fft_interp1d(fi, *fi.shape, sx=0.2, dx=np.diff(x[sp][1])[0]).squeeze() + fft_interp1d(fi, *fi.shape, sx=0.2, dx=np.diff(x[sp][1])[0]).squeeze(), fs ) for ep in ["o", "e"]: # eval parity for s in ["up", "down"]: # up or downsample @@ -333,7 +333,7 @@ def fun(x): xe = 1 true = fun(x[ep][xe]) interp = fft_interp1d(f1[sp][xs], x[ep][xe].size) - np.testing.assert_allclose(true, interp, atol=1e-12, rtol=1e-12) + np.testing.assert_allclose(interp, true, atol=1e-12, rtol=1e-12) @pytest.mark.unit @@ -375,7 +375,6 @@ def fun2(x, y): fi = f2[spx][spy][1][1] fs = fun2(x[spx][1] + 0.2, y[spy][1] + 0.3) np.testing.assert_allclose( - fs, fft_interp2d( fi, *fi.shape, @@ -384,6 +383,7 @@ def fun2(x, y): dx=np.diff(x[spx][1])[0], dy=np.diff(y[spy][1])[0] ).squeeze(), + fs, ) for epx in ["o", "e"]: # eval parity x for epy in ["o", "e"]: # eval parity y @@ -406,7 +406,7 @@ def fun2(x, y): f2[spx][spy][xs][ys], x[epx][xe].size, y[epy][ye].size ) np.testing.assert_allclose( - true, interp, atol=1e-12, rtol=1e-12 + interp, true, atol=1e-12, rtol=1e-12 ) From 3cda4d7bc7f4eea987266d48bea818032dc152a4 Mon Sep 17 00:00:00 2001 From: unalmis Date: Sat, 16 Aug 2025 19:23:19 -0500 Subject: [PATCH 4/9] fixing tests --- tests/test_interpolate.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index a3106a6..c9184fa 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -317,11 +317,12 @@ def fun(x): for i in [1, 2]: f1[p][i] = fun(x[p][i]) + sx = 0.2 for sp in ["o", "e"]: # source parity fi = f1[sp][1] - fs = fun(x[sp][1] + 0.2) + fs = fun(x[sp][1] + sx) np.testing.assert_allclose( - fft_interp1d(fi, *fi.shape, sx=0.2, dx=np.diff(x[sp][1])[0]).squeeze(), fs + fft_interp1d(fi, *fi.shape, sx, dx=x[sp][1][1] - x[sp][1][0]).squeeze(), fs ) for ep in ["o", "e"]: # eval parity for s in ["up", "down"]: # up or downsample @@ -331,8 +332,10 @@ def fun(x): else: xs = 2 xe = 1 - true = fun(x[ep][xe]) - interp = fft_interp1d(f1[sp][xs], x[ep][xe].size) + true = fun(x[ep][xe] + sx) + interp = fft_interp1d( + f1[sp][xs], x[ep][xe].size, sx, dx=x[sp][xs][1] - x[sp][xs][0] + ).squeeze() np.testing.assert_allclose(interp, true, atol=1e-12, rtol=1e-12) @@ -370,16 +373,18 @@ def fun2(x, y): for j in [1, 2]: f2[xp][yp][i][j] = fun2(x[xp][i], y[yp][j]) + shiftx = 0.2 + shifty = 0.3 for spx in ["o", "e"]: # source parity x for spy in ["o", "e"]: # source parity y fi = f2[spx][spy][1][1] - fs = fun2(x[spx][1] + 0.2, y[spy][1] + 0.3) + fs = fun2(x[spx][1] + shiftx, y[spy][1] + shifty) np.testing.assert_allclose( fft_interp2d( fi, *fi.shape, - sx=0.2, - sy=0.3, + shiftx, + shifty, dx=np.diff(x[spx][1])[0], dy=np.diff(y[spy][1])[0] ).squeeze(), @@ -401,10 +406,17 @@ def fun2(x, y): else: ys = 2 ye = 1 - true = fun2(x[epx][xe], y[epy][ye]) + + true = fun2(x[epx][xe] + shiftx, y[epy][ye] + shifty) interp = fft_interp2d( - f2[spx][spy][xs][ys], x[epx][xe].size, y[epy][ye].size - ) + f2[spx][spy][xs][ys], + x[epx][xe].size, + y[epy][ye].size, + shiftx, + shifty, + dx=x[spx][xs][1] - x[spx][xs][0], + dy=y[spy][ys][1] - y[spy][ys][0], + ).squeeze() np.testing.assert_allclose( interp, true, atol=1e-12, rtol=1e-12 ) From f9fa8c6bbafa7b6e0dd3df4724ccee668ff7564e Mon Sep 17 00:00:00 2001 From: unalmis Date: Sun, 17 Aug 2025 01:52:30 -0500 Subject: [PATCH 5/9] Upgrades FFT interpolation to preserve double the width of the frequency spectrum with negligible additional computation. --- CHANGELOG.md | 3 +++ interpax/_fourier.py | 38 +++++++++++++++++++++++++++++++++----- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f88bd1..bf5a763 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ corresponding classes from scipy.interpolate : functions. - Method ``"monotonic"`` now works in 2D and 3D, where it will preserve monotonicity with respect to each coordinate individually. +- Upgrades FFT interpolation to use the real fast Fourier transform. +- Upgrades FFT interpolation to preserve double the width of the frequency spectrum +with negligible additional computation. v0.2.4 diff --git a/interpax/_fourier.py b/interpax/_fourier.py index 24eb31f..a38ad4d 100644 --- a/interpax/_fourier.py +++ b/interpax/_fourier.py @@ -33,9 +33,11 @@ def fft_interp1d( Interpolated (and possibly shifted) data points """ f = asarray_inexact(f) - nx = f.shape[0] c = jnp.fft.rfft(f, axis=0, norm="forward") + return _fft_interp1d(c, f.shape[0], n, sx, dx) + +def _fft_interp1d(c, nx, n, sx, dx): if sx is not None: tau = 2 * jnp.pi sx = asarray_inexact(sx) @@ -43,7 +45,19 @@ def fft_interp1d( c = (c[None].T * sx).T c = jnp.moveaxis(c, 0, -1) - return jnp.fft.irfft(c, n, axis=0, norm="forward") + if n >= nx: + return jnp.fft.irfft(c, n, axis=0, norm="forward") + + if n < c.shape[0]: + c = c[:n] + elif nx % 2 == 0: + c = c.at[-1].divide(2) + c = c.at[0].divide(2) * 2 + + x = jnp.linspace(0, dx * nx, n, endpoint=False) + x = jnp.exp(1j * (c.shape[0] // 2) * x).reshape(n, *((1,) * (c.ndim - 1))) + c = _fft_pad(c, n, 0) + return (jnp.fft.ifft(c, axis=0, norm="forward") * x).real @wrap_jit(static_argnames=["n1", "n2"]) @@ -90,15 +104,29 @@ def _fft_interp2d(c, nx, ny, n1, n2, sx, sy, dx, dy): c = (c[None].T * (sx[None] * sy[:, None])).T c = jnp.moveaxis(c, 0, -1) - return jnp.fft.irfft2(_fft_pad(c, n1, nx), (n1, n2), axes=(0, 1), norm="forward") + c = _fft_pad(jnp.fft.fftshift(c, 0), n1, 0) + if n2 >= ny: + return jnp.fft.irfft2(c, (n1, n2), axes=(0, 1), norm="forward") + + if n2 < c.shape[1]: + c = c[:, :n2] + elif ny % 2 == 0: + c = c.at[:, -1].divide(2) + c = c.at[:, 0].divide(2) * 2 + + y = jnp.linspace(0, dy * ny, n2, endpoint=False) + y = jnp.exp(1j * (c.shape[1] // 2) * y).reshape(1, n2, *((1,) * (c.ndim - 2))) + c = _fft_pad(c, n2, 1) + return (jnp.fft.ifft2(c, axes=(0, 1), norm="forward") * y).real -def _fft_pad(c, n_out, n_in, axis=0): +def _fft_pad(c_shift, n_out, axis): + n_in = c_shift.shape[axis] p = n_out - n_in p = (p // 2, p - p // 2) if n_in % 2 != 0: p = p[::-1] - return jnp.fft.ifftshift(_pad_along_axis(jnp.fft.fftshift(c, axis), p, axis), axis) + return jnp.fft.ifftshift(_pad_along_axis(c_shift, p, axis), axis) def _pad_along_axis(array: jax.Array, pad: tuple = (0, 0), axis: int = 0): From a6c2b647cbd674845e962d629412c7150b2e545f Mon Sep 17 00:00:00 2001 From: unalmis Date: Sun, 17 Aug 2025 23:46:36 -0500 Subject: [PATCH 6/9] In the 2D upsampling (downsampling) case, the padding for the second transform is now applied after (before) the first transform is completed. This reduces the size of the problem, so the computation is less expensive. --- CHANGELOG.md | 2 +- interpax/_fourier.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf5a763..6cbb177 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ functions. with respect to each coordinate individually. - Upgrades FFT interpolation to use the real fast Fourier transform. - Upgrades FFT interpolation to preserve double the width of the frequency spectrum -with negligible additional computation. +with no additional computation. v0.2.4 diff --git a/interpax/_fourier.py b/interpax/_fourier.py index a38ad4d..d6902f7 100644 --- a/interpax/_fourier.py +++ b/interpax/_fourier.py @@ -56,6 +56,7 @@ def _fft_interp1d(c, nx, n, sx, dx): x = jnp.linspace(0, dx * nx, n, endpoint=False) x = jnp.exp(1j * (c.shape[0] // 2) * x).reshape(n, *((1,) * (c.ndim - 1))) + c = _fft_pad(c, n, 0) return (jnp.fft.ifft(c, axis=0, norm="forward") * x).real @@ -116,8 +117,10 @@ def _fft_interp2d(c, nx, ny, n1, n2, sx, sy, dx, dy): y = jnp.linspace(0, dy * ny, n2, endpoint=False) y = jnp.exp(1j * (c.shape[1] // 2) * y).reshape(1, n2, *((1,) * (c.ndim - 2))) + + c = jnp.fft.ifft(c, axis=0, norm="forward") c = _fft_pad(c, n2, 1) - return (jnp.fft.ifft2(c, axes=(0, 1), norm="forward") * y).real + return (jnp.fft.ifft(c, axis=1, norm="forward") * y).real def _fft_pad(c_shift, n_out, axis): From bc9bf3a44fde207ba875942f9bd56d5815e48d45 Mon Sep 17 00:00:00 2001 From: unalmis Date: Mon, 18 Aug 2025 20:07:53 -0500 Subject: [PATCH 7/9] update changelog --- CHANGELOG.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cbb177..e67ba2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,9 +12,10 @@ corresponding classes from scipy.interpolate : functions. - Method ``"monotonic"`` now works in 2D and 3D, where it will preserve monotonicity with respect to each coordinate individually. -- Upgrades FFT interpolation to use the real fast Fourier transform. -- Upgrades FFT interpolation to preserve double the width of the frequency spectrum -with no additional computation. +- [Improves FFT interpolation](https://github.com/f0uriest/interpax/pull/116) + - The real FFT is now used where possible. + - Double the width of the Fourier spectrum is now preserved when interpolating to a less dense grid, at no additional cost. + - In the 2D upsampling case, the second transform is now padded only after computing the first transform. In the 2D downsampling case, the second transform is now truncated prior to computing the first transform. This reduces the size of the problem, so the computation is less expensive. v0.2.4 From 079283186176e2c65a84e12bfb30eec8c9dcae0e Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Sun, 14 Sep 2025 08:30:50 -0700 Subject: [PATCH 8/9] Fix bug --- interpax/_fourier.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/interpax/_fourier.py b/interpax/_fourier.py index d6902f7..4ead2d4 100644 --- a/interpax/_fourier.py +++ b/interpax/_fourier.py @@ -54,7 +54,7 @@ def _fft_interp1d(c, nx, n, sx, dx): c = c.at[-1].divide(2) c = c.at[0].divide(2) * 2 - x = jnp.linspace(0, dx * nx, n, endpoint=False) + x = jnp.linspace(0, 2 * jnp.pi, n, endpoint=False) x = jnp.exp(1j * (c.shape[0] // 2) * x).reshape(n, *((1,) * (c.ndim - 1))) c = _fft_pad(c, n, 0) @@ -115,7 +115,7 @@ def _fft_interp2d(c, nx, ny, n1, n2, sx, sy, dx, dy): c = c.at[:, -1].divide(2) c = c.at[:, 0].divide(2) * 2 - y = jnp.linspace(0, dy * ny, n2, endpoint=False) + y = jnp.linspace(0, 2 * jnp.pi, n2, endpoint=False) y = jnp.exp(1j * (c.shape[1] // 2) * y).reshape(1, n2, *((1,) * (c.ndim - 2))) c = jnp.fft.ifft(c, axis=0, norm="forward") From 052e56461763baf5f9afae64e53e174d3bfb2bbd Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Mon, 15 Sep 2025 09:17:35 -0700 Subject: [PATCH 9/9] Update CHANGELOG.md --- CHANGELOG.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e67ba2b..30254a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ Changelog ========= +- [Improves FFT interpolation](https://github.com/f0uriest/interpax/pull/116) + - The real FFT is now used where possible. + - Double the width of the Fourier spectrum is now preserved when interpolating to a less dense grid, at no additional cost. + - In the 2D upsampling case, the second transform is now padded only after computing the first transform. In the 2D downsampling case, the second transform is now truncated prior to computing the first transform. This reduces the size of the problem, so the computation is less expensive. - Adds a number of classes that replicate most of the functionality of the corresponding classes from scipy.interpolate : - ``scipy.interpolate.PPoly`` -> ``interpax.PPoly`` @@ -12,10 +16,6 @@ corresponding classes from scipy.interpolate : functions. - Method ``"monotonic"`` now works in 2D and 3D, where it will preserve monotonicity with respect to each coordinate individually. -- [Improves FFT interpolation](https://github.com/f0uriest/interpax/pull/116) - - The real FFT is now used where possible. - - Double the width of the Fourier spectrum is now preserved when interpolating to a less dense grid, at no additional cost. - - In the 2D upsampling case, the second transform is now padded only after computing the first transform. In the 2D downsampling case, the second transform is now truncated prior to computing the first transform. This reduces the size of the problem, so the computation is less expensive. v0.2.4