Description
Hi there,
I've noticed an issue where the "cubic2" interpolation method causes a cache miss under `jax.jit`, and thus a silent recompile. I originally noticed it in `interp2d`, but while creating a minimal working example (MWE) I found that it happens for `interp1d` as well.
I noticed this after moving my code (which includes a few `interp2d(..., method="cubic2")` calls deep down) from a 2080 Ti GPU to an A100 GPU and seeing barely any speedup. This baffled me, but in hindsight it makes sense if the function is essentially being recompiled on the CPU on every call. Indeed, just changing the interpolation method to "cubic" stopped the cache miss from triggering and sped the code up by a factor of ~30.
I've confirmed that this is what's happening by profiling/tracing the code with Perfetto, where a cache miss can be seen in the trace. I'm not sure what is causing this, and I haven't had a chance to look into the interpax interpolation code.
This is similar to this JAX issue; however, when I tested with the flags `jax_log_compiles=True` and `jax_explain_cache_misses=True`, nothing showed up (a truly pathological silent error), and I'm still not entirely sure why.
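For reference, both flags can be enabled from Python with `jax.config.update`. A complementary check, sketched below under assumptions, counts how often JAX traces a jitted function by incrementing a Python counter inside it. Python side effects run only while JAX traces the function, so a count above 1 after repeated calls with identical inputs would point to retracing. This counter does not catch an XLA recompilation that happens without a retrace; `jax_log_compiles` is the flag meant to report compilations. `traced_func` and `trace_count` are hypothetical names, and `jnp.sin` stands in for the interpolation call.

```python
import jax
import jax.numpy as jnp

# Log every compilation and explain tracing-cache misses.
jax.config.update("jax_log_compiles", True)
jax.config.update("jax_explain_cache_misses", True)

trace_count = 0  # hypothetical counter, incremented only during tracing


@jax.jit
def traced_func(x):
    global trace_count
    # This Python statement runs only while JAX traces the function,
    # so it counts (re)traces, not calls.
    trace_count += 1
    return jnp.sin(x)  # stand-in for interp1d(..., method="cubic2")


x = jnp.linspace(0.0, 1.0, 10)
for _ in range(5):
    traced_func(x).block_until_ready()

print("Number of traces:", trace_count)
```

With identical input shapes and dtypes, the function should be traced once and then served from the cache for the remaining calls.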
Here is a minimal reproducible example.
```python
import jax
import jax.numpy as jnp
import interpax
from interpax import interp1d
from time import time

print("JAX version:", jax.__version__)
print("Interpax version:", interpax.__version__)

# creating some data for testing
xp = jnp.linspace(0, 2 * jnp.pi, 100)
xq = jnp.linspace(0, 2 * jnp.pi, 10000)
f = lambda x: jnp.sin(x)
fp = f(xp)
```

```
JAX version: 0.6.0
Interpax version: 0.3.10
```
First, the "cubic" method, which behaves as expected:
```python
# Testing "cubic"
@jax.jit
def cubic_func():
    fq = interp1d(xq, xp, fp, method="cubic")
    return fq

start = time()
cubic_func()  # compiling
print("Compilation time:", time() - start)

# tracing and timing the "cubic" function under jit
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    for i in range(5):
        start = time()
        cubic_func().block_until_ready()  # running under jit
        print(f"Iteration {i + 1} time:", time() - start)
```

```
Compilation time: 0.2931036949157715
2025-08-12 15:31:20.276223: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
Iteration 1 time: 0.00045752525329589844
Iteration 2 time: 0.00023293495178222656
Iteration 3 time: 0.00021767616271972656
Iteration 4 time: 0.0002753734588623047
Iteration 5 time: 0.00024366378784179688
2025-08-12 15:31:20.286202: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
```
Trace file you can open with Perfetto: cubic-trace.perfetto_trace.gz
Next, the same test with the "cubic2" method, which triggers a cache miss under `jax.jit`, and thus a silent recompile:
```python
# Testing "cubic2"
@jax.jit
def cubic2_func():
    fq = interp1d(xq, xp, fp, method="cubic2")
    return fq

# compiling "cubic2" function
start = time()
cubic2_func()  # compiling
print("Compilation time:", time() - start)

# tracing and timing the "cubic2" function under jit
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    for i in range(5):
        start = time()
        cubic2_func().block_until_ready()  # running under jit
        print(f"Iteration {i + 1} time:", time() - start)
```

```
Compilation time: 4.072752475738525
2025-08-12 15:31:35.507632: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
Iteration 1 time: 0.04135751724243164
Iteration 2 time: 0.0422360897064209
Iteration 3 time: 0.0421442985534668
Iteration 4 time: 0.041191816329956055
Iteration 5 time: 0.04430389404296875
2025-08-12 15:31:35.725075: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
```
Trace file you can open with Perfetto: cubic2-trace.perfetto_trace.gz
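A possible follow-up check, sketched here under assumptions rather than taken from the report: JAX's ahead-of-time API (`jax.jit(f).lower(*args).compile()`) builds an executable that is compiled exactly once, and calling that object never retraces or recompiles. If the "cubic2" per-iteration times stay at the same level when called through a precompiled executable, the slowdown would lie in execution rather than in recompilation. To keep the sketch self-contained, `jnp.interp` stands in for `interp1d(xq, xp, fp, method="cubic2")`, and `interp_func` is a hypothetical name. The arrays are passed as arguments instead of being captured as constants, which is a deliberate difference from the MWE.

```python
from time import perf_counter

import jax
import jax.numpy as jnp

xp = jnp.linspace(0, 2 * jnp.pi, 100)
xq = jnp.linspace(0, 2 * jnp.pi, 10000)
fp = jnp.sin(xp)


def interp_func(xq, xp, fp):
    # Stand-in: swap in interp1d(xq, xp, fp, method="cubic2") to test interpax.
    return jnp.interp(xq, xp, fp)


# Trace, lower and compile exactly once, ahead of time.
start = perf_counter()
compiled = jax.jit(interp_func).lower(xq, xp, fp).compile()
print("AOT compile time:", perf_counter() - start)

# Calling the compiled executable cannot trigger a retrace or recompile,
# so these timings measure execution only.
for i in range(5):
    start = perf_counter()
    out = compiled(xq, xp, fp).block_until_ready()
    print(f"Iteration {i + 1} time:", perf_counter() - start)
```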
Thanks,
Max