Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

- [Improves FFT interpolation](https://github.com/f0uriest/interpax/pull/116)
- The real FFT is now used where possible.
- Double the width of the Fourier spectrum is now preserved when interpolating to a less dense grid, at no additional cost.
- In the 2D upsampling case, the second transform is now padded only after computing the first transform. In the 2D downsampling case, the second transform is now truncated prior to computing the first transform. This reduces the size of the problem, so the computation is less expensive.
- Adds a number of classes that replicate most of the functionality of the
corresponding classes from scipy.interpolate :
- ``scipy.interpolate.PPoly`` -> ``interpax.PPoly``
Expand Down
101 changes: 65 additions & 36 deletions interpax/_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def fft_interp1d(
sx: Optional[Num[ArrayLike, " s"]] = None,
dx: float = 1.0,
) -> Inexact[Array, "n ... s"]:
"""Interpolation of a 1d periodic function via FFT.
"""Interpolation of a real-valued 1D periodic function via FFT.

Parameters
----------
Expand All @@ -33,18 +33,32 @@ def fft_interp1d(
Interpolated (and possibly shifted) data points
"""
f = asarray_inexact(f)
c = jnp.fft.ifft(f, axis=0)
nx = c.shape[0]
c = jnp.fft.rfft(f, axis=0, norm="forward")
return _fft_interp1d(c, f.shape[0], n, sx, dx)


def _fft_interp1d(c, nx, n, sx, dx):
if sx is not None:
tau = 2 * jnp.pi
sx = asarray_inexact(sx)
sx = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(nx)[:, None] * sx / dx)
sx = jnp.exp(1j * jnp.fft.rfftfreq(nx, dx / tau)[:, None] * sx)
c = (c[None].T * sx).T
c = jnp.moveaxis(c, 0, -1)
pad = ((n - nx) // 2, n - nx - (n - nx) // 2)
if nx % 2 != 0:
pad = pad[::-1]
c = jnp.fft.ifftshift(_pad_along_axis(jnp.fft.fftshift(c, axes=0), pad, axis=0))
return jnp.fft.fft(c, axis=0).real

if n >= nx:
return jnp.fft.irfft(c, n, axis=0, norm="forward")

if n < c.shape[0]:
c = c[:n]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't this truncation mean that we're still interpolating the band limited signal, not the full one?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c.shape[0] has size f.shape[0]//2 + 1. so our band limited signal can preserve double the width of the fourier spectrum than before

elif nx % 2 == 0:
c = c.at[-1].divide(2)
c = c.at[0].divide(2) * 2

x = jnp.linspace(0, 2 * jnp.pi, n, endpoint=False)
x = jnp.exp(1j * (c.shape[0] // 2) * x).reshape(n, *((1,) * (c.ndim - 1)))
Comment on lines +57 to +58
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain this part? Isn't this just multiplying the output by exp(i*n*x)?

Copy link
Author

@unalmis unalmis Sep 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we transform sum_n^N c_n e^(i n x) for n>=0 to e^iNx//2 sum_k c_n e^(ikx) where k now includes the relevant negative values (that result from pulling the highest frequency outside the sum)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


c = _fft_pad(c, n, 0)
return (jnp.fft.ifft(c, axis=0, norm="forward") * x).real


@wrap_jit(static_argnames=["n1", "n2"])
Expand All @@ -57,7 +71,7 @@ def fft_interp2d(
dx: float = 1.0,
dy: float = 1.0,
) -> Inexact[Array, "n1 n2 ... s"]:
"""Interpolation of a 2d periodic function via FFT.
"""Interpolation of a real-valued 2D periodic function via FFT.

Parameters
----------
Expand All @@ -77,45 +91,60 @@ def fft_interp2d(
Interpolated (and possibly shifted) data points
"""
f = asarray_inexact(f)
c = jnp.fft.ifft2(f, axes=(0, 1))
nx, ny = c.shape[:2]
c = jnp.fft.rfft2(f, axes=(0, 1), norm="forward")
return _fft_interp2d(c, *f.shape[:2], n1, n2, sx, sy, dx, dy)


def _fft_interp2d(c, nx, ny, n1, n2, sx, sy, dx, dy):
if (sx is not None) and (sy is not None):
tau = 2 * jnp.pi
sx = asarray_inexact(sx)
sy = asarray_inexact(sy)
sx = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(nx)[:, None] * sx / dx)
sy = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(ny)[:, None] * sy / dy)
c = (c[None].T * sx[None, :, :] * sy[:, None, :]).T
sx = jnp.exp(1j * jnp.fft.fftfreq(nx, dx / tau)[:, None] * sx)
sy = jnp.exp(1j * jnp.fft.rfftfreq(ny, dy / tau)[:, None] * sy)
c = (c[None].T * (sx[None] * sy[:, None])).T
c = jnp.moveaxis(c, 0, -1)
padx = ((n1 - nx) // 2, n1 - nx - (n1 - nx) // 2)
pady = ((n2 - ny) // 2, n2 - ny - (n2 - ny) // 2)
if nx % 2 != 0:
padx = padx[::-1]
if ny % 2 != 0:
pady = pady[::-1]

c = jnp.fft.ifftshift(
_pad_along_axis(jnp.fft.fftshift(c, axes=0), padx, axis=0), axes=0
)
c = jnp.fft.ifftshift(
_pad_along_axis(jnp.fft.fftshift(c, axes=1), pady, axis=1), axes=1
)
c = _fft_pad(jnp.fft.fftshift(c, 0), n1, 0)
if n2 >= ny:
return jnp.fft.irfft2(c, (n1, n2), axes=(0, 1), norm="forward")

if n2 < c.shape[1]:
c = c[:, :n2]
elif ny % 2 == 0:
c = c.at[:, -1].divide(2)
c = c.at[:, 0].divide(2) * 2

y = jnp.linspace(0, 2 * jnp.pi, n2, endpoint=False)
y = jnp.exp(1j * (c.shape[1] // 2) * y).reshape(1, n2, *((1,) * (c.ndim - 2)))

return jnp.fft.fft2(c, axes=(0, 1)).real
c = jnp.fft.ifft(c, axis=0, norm="forward")
c = _fft_pad(c, n2, 1)
return (jnp.fft.ifft(c, axis=1, norm="forward") * y).real


def _fft_pad(c_shift, n_out, axis):
n_in = c_shift.shape[axis]
p = n_out - n_in
p = (p // 2, p - p // 2)
if n_in % 2 != 0:
p = p[::-1]
return jnp.fft.ifftshift(_pad_along_axis(c_shift, p, axis), axis)


def _pad_along_axis(array: jax.Array, pad: tuple = (0, 0), axis: int = 0):
"""Pad with zeros or truncate a given dimension."""
array = jnp.moveaxis(array, axis, 0)
index = [slice(None)] * array.ndim
pad_width = [(0, 0)] * array.ndim
start = stop = None

if pad[0] < 0:
array = array[abs(pad[0]) :]
start = -pad[0]
pad = (0, pad[1])
if pad[1] < 0:
array = array[: -abs(pad[1])]
stop = pad[1]
pad = (pad[0], 0)

npad = [(0, 0)] * array.ndim
npad[0] = pad

array = jnp.pad(array, pad_width=npad, mode="constant", constant_values=0)
return jnp.moveaxis(array, 0, axis)
index[axis] = slice(start, stop)
pad_width[axis] = pad
return jnp.pad(array[tuple(index)], pad_width)
38 changes: 25 additions & 13 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,12 @@ def fun(x):
for i in [1, 2]:
f1[p][i] = fun(x[p][i])

sx = 0.2
for sp in ["o", "e"]: # source parity
fi = f1[sp][1]
fs = fun(x[sp][1] + 0.2)
fs = fun(x[sp][1] + sx)
np.testing.assert_allclose(
fs, fft_interp1d(fi, *fi.shape, sx=0.2, dx=np.diff(x[sp][1])[0]).squeeze()
fft_interp1d(fi, *fi.shape, sx, dx=x[sp][1][1] - x[sp][1][0]).squeeze(), fs
)
for ep in ["o", "e"]: # eval parity
for s in ["up", "down"]: # up or downsample
Expand All @@ -331,9 +332,11 @@ def fun(x):
else:
xs = 2
xe = 1
true = fun(x[ep][xe])
interp = fft_interp1d(f1[sp][xs], x[ep][xe].size)
np.testing.assert_allclose(true, interp, atol=1e-12, rtol=1e-12)
true = fun(x[ep][xe] + sx)
interp = fft_interp1d(
f1[sp][xs], x[ep][xe].size, sx, dx=x[sp][xs][1] - x[sp][xs][0]
).squeeze()
np.testing.assert_allclose(interp, true, atol=1e-12, rtol=1e-12)


@pytest.mark.unit
Expand Down Expand Up @@ -370,20 +373,22 @@ def fun2(x, y):
for j in [1, 2]:
f2[xp][yp][i][j] = fun2(x[xp][i], y[yp][j])

shiftx = 0.2
shifty = 0.3
for spx in ["o", "e"]: # source parity x
for spy in ["o", "e"]: # source parity y
fi = f2[spx][spy][1][1]
fs = fun2(x[spx][1] + 0.2, y[spy][1] + 0.3)
fs = fun2(x[spx][1] + shiftx, y[spy][1] + shifty)
np.testing.assert_allclose(
fs,
fft_interp2d(
fi,
*fi.shape,
sx=0.2,
sy=0.3,
shiftx,
shifty,
dx=np.diff(x[spx][1])[0],
dy=np.diff(y[spy][1])[0]
).squeeze(),
fs,
)
for epx in ["o", "e"]: # eval parity x
for epy in ["o", "e"]: # eval parity y
Expand All @@ -401,12 +406,19 @@ def fun2(x, y):
else:
ys = 2
ye = 1
true = fun2(x[epx][xe], y[epy][ye])

true = fun2(x[epx][xe] + shiftx, y[epy][ye] + shifty)
interp = fft_interp2d(
f2[spx][spy][xs][ys], x[epx][xe].size, y[epy][ye].size
)
f2[spx][spy][xs][ys],
x[epx][xe].size,
y[epy][ye].size,
shiftx,
shifty,
dx=x[spx][xs][1] - x[spx][xs][0],
dy=y[spy][ys][1] - y[spy][ys][0],
).squeeze()
np.testing.assert_allclose(
true, interp, atol=1e-12, rtol=1e-12
interp, true, atol=1e-12, rtol=1e-12
)


Expand Down