-
Notifications
You must be signed in to change notification settings - Fork 30
Improve FFT interpolation (1/2) #116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9ec429d
706d48f
1f22d7f
3cda4d7
f9fa8c6
a6c2b64
bc9bf3a
720b606
0792831
052e564
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ def fft_interp1d( | |
| sx: Optional[Num[ArrayLike, " s"]] = None, | ||
| dx: float = 1.0, | ||
| ) -> Inexact[Array, "n ... s"]: | ||
| """Interpolation of a 1d periodic function via FFT. | ||
| """Interpolation of a real-valued 1D periodic function via FFT. | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -33,18 +33,32 @@ def fft_interp1d( | |
| Interpolated (and possibly shifted) data points | ||
| """ | ||
| f = asarray_inexact(f) | ||
| c = jnp.fft.ifft(f, axis=0) | ||
| nx = c.shape[0] | ||
| c = jnp.fft.rfft(f, axis=0, norm="forward") | ||
| return _fft_interp1d(c, f.shape[0], n, sx, dx) | ||
|
|
||
|
|
||
| def _fft_interp1d(c, nx, n, sx, dx): | ||
| if sx is not None: | ||
| tau = 2 * jnp.pi | ||
| sx = asarray_inexact(sx) | ||
| sx = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(nx)[:, None] * sx / dx) | ||
| sx = jnp.exp(1j * jnp.fft.rfftfreq(nx, dx / tau)[:, None] * sx) | ||
| c = (c[None].T * sx).T | ||
| c = jnp.moveaxis(c, 0, -1) | ||
| pad = ((n - nx) // 2, n - nx - (n - nx) // 2) | ||
| if nx % 2 != 0: | ||
| pad = pad[::-1] | ||
| c = jnp.fft.ifftshift(_pad_along_axis(jnp.fft.fftshift(c, axes=0), pad, axis=0)) | ||
| return jnp.fft.fft(c, axis=0).real | ||
|
|
||
| if n >= nx: | ||
| return jnp.fft.irfft(c, n, axis=0, norm="forward") | ||
|
|
||
| if n < c.shape[0]: | ||
| c = c[:n] | ||
| elif nx % 2 == 0: | ||
| c = c.at[-1].divide(2) | ||
| c = c.at[0].divide(2) * 2 | ||
|
|
||
| x = jnp.linspace(0, 2 * jnp.pi, n, endpoint=False) | ||
| x = jnp.exp(1j * (c.shape[0] // 2) * x).reshape(n, *((1,) * (c.ndim - 1))) | ||
|
Comment on lines
+57
to
+58
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you explain this part? Isn't this just multiplying the output by
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we transform sum_n^N c_n e^(i n x) for n>=0 to e^iNx//2 sum_k c_n e^(ikx) where k now includes the relevant negative values (that result from pulling the highest frequency outside the sum)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| c = _fft_pad(c, n, 0) | ||
| return (jnp.fft.ifft(c, axis=0, norm="forward") * x).real | ||
|
|
||
|
|
||
| @wrap_jit(static_argnames=["n1", "n2"]) | ||
|
|
@@ -57,7 +71,7 @@ def fft_interp2d( | |
| dx: float = 1.0, | ||
| dy: float = 1.0, | ||
| ) -> Inexact[Array, "n1 n2 ... s"]: | ||
| """Interpolation of a 2d periodic function via FFT. | ||
| """Interpolation of a real-valued 2D periodic function via FFT. | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -77,45 +91,60 @@ def fft_interp2d( | |
| Interpolated (and possibly shifted) data points | ||
| """ | ||
| f = asarray_inexact(f) | ||
| c = jnp.fft.ifft2(f, axes=(0, 1)) | ||
| nx, ny = c.shape[:2] | ||
| c = jnp.fft.rfft2(f, axes=(0, 1), norm="forward") | ||
| return _fft_interp2d(c, *f.shape[:2], n1, n2, sx, sy, dx, dy) | ||
|
|
||
|
|
||
| def _fft_interp2d(c, nx, ny, n1, n2, sx, sy, dx, dy): | ||
| if (sx is not None) and (sy is not None): | ||
| tau = 2 * jnp.pi | ||
| sx = asarray_inexact(sx) | ||
| sy = asarray_inexact(sy) | ||
| sx = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(nx)[:, None] * sx / dx) | ||
| sy = jnp.exp(-1j * 2 * jnp.pi * jnp.fft.fftfreq(ny)[:, None] * sy / dy) | ||
| c = (c[None].T * sx[None, :, :] * sy[:, None, :]).T | ||
| sx = jnp.exp(1j * jnp.fft.fftfreq(nx, dx / tau)[:, None] * sx) | ||
| sy = jnp.exp(1j * jnp.fft.rfftfreq(ny, dy / tau)[:, None] * sy) | ||
| c = (c[None].T * (sx[None] * sy[:, None])).T | ||
| c = jnp.moveaxis(c, 0, -1) | ||
| padx = ((n1 - nx) // 2, n1 - nx - (n1 - nx) // 2) | ||
| pady = ((n2 - ny) // 2, n2 - ny - (n2 - ny) // 2) | ||
| if nx % 2 != 0: | ||
| padx = padx[::-1] | ||
| if ny % 2 != 0: | ||
| pady = pady[::-1] | ||
|
|
||
| c = jnp.fft.ifftshift( | ||
| _pad_along_axis(jnp.fft.fftshift(c, axes=0), padx, axis=0), axes=0 | ||
| ) | ||
| c = jnp.fft.ifftshift( | ||
| _pad_along_axis(jnp.fft.fftshift(c, axes=1), pady, axis=1), axes=1 | ||
| ) | ||
| c = _fft_pad(jnp.fft.fftshift(c, 0), n1, 0) | ||
| if n2 >= ny: | ||
| return jnp.fft.irfft2(c, (n1, n2), axes=(0, 1), norm="forward") | ||
|
|
||
| if n2 < c.shape[1]: | ||
| c = c[:, :n2] | ||
| elif ny % 2 == 0: | ||
| c = c.at[:, -1].divide(2) | ||
| c = c.at[:, 0].divide(2) * 2 | ||
|
|
||
| y = jnp.linspace(0, 2 * jnp.pi, n2, endpoint=False) | ||
| y = jnp.exp(1j * (c.shape[1] // 2) * y).reshape(1, n2, *((1,) * (c.ndim - 2))) | ||
|
|
||
| return jnp.fft.fft2(c, axes=(0, 1)).real | ||
| c = jnp.fft.ifft(c, axis=0, norm="forward") | ||
| c = _fft_pad(c, n2, 1) | ||
| return (jnp.fft.ifft(c, axis=1, norm="forward") * y).real | ||
|
|
||
|
|
||
| def _fft_pad(c_shift, n_out, axis): | ||
| n_in = c_shift.shape[axis] | ||
| p = n_out - n_in | ||
| p = (p // 2, p - p // 2) | ||
| if n_in % 2 != 0: | ||
| p = p[::-1] | ||
| return jnp.fft.ifftshift(_pad_along_axis(c_shift, p, axis), axis) | ||
|
|
||
|
|
||
| def _pad_along_axis(array: jax.Array, pad: tuple = (0, 0), axis: int = 0): | ||
| """Pad with zeros or truncate a given dimension.""" | ||
| array = jnp.moveaxis(array, axis, 0) | ||
| index = [slice(None)] * array.ndim | ||
| pad_width = [(0, 0)] * array.ndim | ||
| start = stop = None | ||
|
|
||
| if pad[0] < 0: | ||
| array = array[abs(pad[0]) :] | ||
| start = -pad[0] | ||
| pad = (0, pad[1]) | ||
| if pad[1] < 0: | ||
| array = array[: -abs(pad[1])] | ||
| stop = pad[1] | ||
| pad = (pad[0], 0) | ||
|
|
||
| npad = [(0, 0)] * array.ndim | ||
| npad[0] = pad | ||
|
|
||
| array = jnp.pad(array, pad_width=npad, mode="constant", constant_values=0) | ||
| return jnp.moveaxis(array, 0, axis) | ||
| index[axis] = slice(start, stop) | ||
| pad_width[axis] = pad | ||
| return jnp.pad(array[tuple(index)], pad_width) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doesn't this truncation mean that we're still interpolating the band limited signal, not the full one?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
c.shape[0] has size f.shape[0]//2 + 1. so our band limited signal can preserve double the width of the fourier spectrum than before