diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f88bd1..30254a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ Changelog ========= +- [Improves FFT interpolation](https://github.com/f0uriest/interpax/pull/116) + - The real FFT is now used where possible. + - Double the width of the Fourier spectrum is now preserved when interpolating to a less dense grid, at no additional cost. + - In the 2D upsampling case, the second transform is now padded only after computing the first transform. In the 2D downsampling case, the second transform is now truncated prior to computing the first transform. This reduces the size of the problem, so the computation is less expensive. - Adds a number of classes that replicate most of the functionality of the corresponding classes from scipy.interpolate : - ``scipy.interpolate.PPoly`` -> ``interpax.PPoly`` diff --git a/interpax/_fourier.py b/interpax/_fourier.py index a7b4158..4ead2d4 100644 --- a/interpax/_fourier.py +++ b/interpax/_fourier.py @@ -14,7 +14,7 @@ def fft_interp1d( sx: Optional[Num[ArrayLike, " s"]] = None, dx: float = 1.0, ) -> Inexact[Array, "n ... s"]: - """Interpolation of a 1d periodic function via FFT. + """Interpolation of a real-valued 1D periodic function via FFT. Parameters ---------- @@ -33,18 +33,32 @@ def fft_interp1d( Interpolated (and possibly shifted) data points """ f = asarray_inexact(f) - c = jnp.fft.ifft(f, axis=0) - nx = c.shape[0] + c = jnp.fft.rfft(f, axis=0, norm="forward") + return _fft_interp1d(c, f.shape[0], n, sx, dx) + + +def _fft_interp1d(c, nx, n, sx, dx): if sx is not None: + tau = 2 * jnp.pi sx = asarray_inexact(sx) - sx = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(nx)[:, None] * sx / dx) + sx = jnp.exp(1j * jnp.fft.rfftfreq(nx, dx / tau)[:, None] * sx) c = (c[None].T * sx).T c = jnp.moveaxis(c, 0, -1) - pad = ((n - nx) // 2, n - nx - (n - nx) // 2) - if nx % 2 != 0: - pad = pad[::-1] - c = jnp.fft.ifftshift(_pad_along_axis(jnp.fft.fftshift(c, axes=0), pad, axis=0)) - return jnp.fft.fft(c, axis=0).real + + if n >= nx: + return jnp.fft.irfft(c, n, axis=0, norm="forward") + + if n < c.shape[0]: + c = c[:n] + elif nx % 2 == 0: + c = c.at[-1].divide(2) + c = c.at[0].divide(2) * 2 + + x = jnp.linspace(0, 2 * jnp.pi, n, endpoint=False) + x = jnp.exp(1j * (c.shape[0] // 2) * x).reshape(n, *((1,) * (c.ndim - 1))) + + c = _fft_pad(c, n, 0) + return (jnp.fft.ifft(c, axis=0, norm="forward") * x).real @wrap_jit(static_argnames=["n1", "n2"]) @@ -57,7 +71,7 @@ def fft_interp2d( dx: float = 1.0, dy: float = 1.0, ) -> Inexact[Array, "n1 n2 ... s"]: - """Interpolation of a 2d periodic function via FFT. + """Interpolation of a real-valued 2D periodic function via FFT. Parameters ---------- @@ -77,45 +91,60 @@ def fft_interp2d( Interpolated (and possibly shifted) data points """ f = asarray_inexact(f) - c = jnp.fft.ifft2(f, axes=(0, 1)) - nx, ny = c.shape[:2] + c = jnp.fft.rfft2(f, axes=(0, 1), norm="forward") + return _fft_interp2d(c, *f.shape[:2], n1, n2, sx, sy, dx, dy) + + +def _fft_interp2d(c, nx, ny, n1, n2, sx, sy, dx, dy): if (sx is not None) and (sy is not None): + tau = 2 * jnp.pi sx = asarray_inexact(sx) sy = asarray_inexact(sy) - sx = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(nx)[:, None] * sx / dx) - sy = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(ny)[:, None] * sy / dy) - c = (c[None].T * sx[None, :, :] * sy[:, None, :]).T + sx = jnp.exp(1j * jnp.fft.fftfreq(nx, dx / tau)[:, None] * sx) + sy = jnp.exp(1j * jnp.fft.rfftfreq(ny, dy / tau)[:, None] * sy) + c = (c[None].T * (sx[None] * sy[:, None])).T c = jnp.moveaxis(c, 0, -1) - padx = ((n1 - nx) // 2, n1 - nx - (n1 - nx) // 2) - pady = ((n2 - ny) // 2, n2 - ny - (n2 - ny) // 2) - if nx % 2 != 0: - padx = padx[::-1] - if ny % 2 != 0: - pady = pady[::-1] - c = jnp.fft.ifftshift( - _pad_along_axis(jnp.fft.fftshift(c, axes=0), padx, axis=0), axes=0 - ) - c = jnp.fft.ifftshift( - _pad_along_axis(jnp.fft.fftshift(c, axes=1), pady, axis=1), axes=1 - ) + c = _fft_pad(jnp.fft.fftshift(c, 0), n1, 0) + if n2 >= ny: + return jnp.fft.irfft2(c, (n1, n2), axes=(0, 1), norm="forward") + + if n2 < c.shape[1]: + c = c[:, :n2] + elif ny % 2 == 0: + c = c.at[:, -1].divide(2) + c = c.at[:, 0].divide(2) * 2 + + y = jnp.linspace(0, 2 * jnp.pi, n2, endpoint=False) + y = jnp.exp(1j * (c.shape[1] // 2) * y).reshape(1, n2, *((1,) * (c.ndim - 2))) - return jnp.fft.fft2(c, axes=(0, 1)).real + c = jnp.fft.ifft(c, axis=0, norm="forward") + c = _fft_pad(c, n2, 1) + return (jnp.fft.ifft(c, axis=1, norm="forward") * y).real + + +def _fft_pad(c_shift, n_out, axis): + n_in = c_shift.shape[axis] + p = n_out - n_in + p = (p // 2, p - p // 2) + if n_in % 2 != 0: + p = p[::-1] + return jnp.fft.ifftshift(_pad_along_axis(c_shift, p, axis), axis) def _pad_along_axis(array: jax.Array, pad: tuple = (0, 0), axis: int = 0): """Pad with zeros or truncate a given dimension.""" - array = jnp.moveaxis(array, axis, 0) + index = [slice(None)] * array.ndim + pad_width = [(0, 0)] * array.ndim + start = stop = None if pad[0] < 0: - array = array[abs(pad[0]) :] + start = -pad[0] pad = (0, pad[1]) if pad[1] < 0: - array = array[: -abs(pad[1])] + stop = pad[1] pad = (pad[0], 0) - npad = [(0, 0)] * array.ndim - npad[0] = pad - - array = jnp.pad(array, pad_width=npad, mode="constant", constant_values=0) - return jnp.moveaxis(array, 0, axis) + index[axis] = slice(start, stop) + pad_width[axis] = pad + return jnp.pad(array[tuple(index)], pad_width) diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index f2760a6..c9184fa 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -317,11 +317,12 @@ def fun(x): for i in [1, 2]: f1[p][i] = fun(x[p][i]) + sx = 0.2 for sp in ["o", "e"]: # source parity fi = f1[sp][1] - fs = fun(x[sp][1] + 0.2) + fs = fun(x[sp][1] + sx) np.testing.assert_allclose( - fs, fft_interp1d(fi, *fi.shape, sx=0.2, dx=np.diff(x[sp][1])[0]).squeeze() + fft_interp1d(fi, *fi.shape, sx, dx=x[sp][1][1] - x[sp][1][0]).squeeze(), fs ) for ep in ["o", "e"]: # eval parity for s in ["up", "down"]: # up or downsample @@ -331,9 +332,11 @@ def fun(x): else: xs = 2 xe = 1 - true = fun(x[ep][xe]) - interp = fft_interp1d(f1[sp][xs], x[ep][xe].size) - np.testing.assert_allclose(true, interp, atol=1e-12, rtol=1e-12) + true = fun(x[ep][xe] + sx) + interp = fft_interp1d( + f1[sp][xs], x[ep][xe].size, sx, dx=x[sp][xs][1] - x[sp][xs][0] + ).squeeze() + np.testing.assert_allclose(interp, true, atol=1e-12, rtol=1e-12) @pytest.mark.unit @@ -370,20 +373,22 @@ def fun2(x, y): for j in [1, 2]: f2[xp][yp][i][j] = fun2(x[xp][i], y[yp][j]) + shiftx = 0.2 + shifty = 0.3 for spx in ["o", "e"]: # source parity x for spy in ["o", "e"]: # source parity y fi = f2[spx][spy][1][1] - fs = fun2(x[spx][1] + 0.2, y[spy][1] + 0.3) + fs = fun2(x[spx][1] + shiftx, y[spy][1] + shifty) np.testing.assert_allclose( - fs, fft_interp2d( fi, *fi.shape, - sx=0.2, - sy=0.3, + shiftx, + shifty, dx=np.diff(x[spx][1])[0], dy=np.diff(y[spy][1])[0] ).squeeze(), + fs, ) for epx in ["o", "e"]: # eval parity x for epy in ["o", "e"]: # eval parity y @@ -401,12 +406,19 @@ def fun2(x, y): else: ys = 2 ye = 1 - true = fun2(x[epx][xe], y[epy][ye]) + + true = fun2(x[epx][xe] + shiftx, y[epy][ye] + shifty) interp = fft_interp2d( - f2[spx][spy][xs][ys], x[epx][xe].size, y[epy][ye].size - ) + f2[spx][spy][xs][ys], + x[epx][xe].size, + y[epy][ye].size, + shiftx, + shifty, + dx=x[spx][xs][1] - x[spx][xs][0], + dy=y[spy][ys][1] - y[spy][ys][0], + ).squeeze() np.testing.assert_allclose( - true, interp, atol=1e-12, rtol=1e-12 + interp, true, atol=1e-12, rtol=1e-12 )