
"cubic2" method causes cache miss under jit #115

@maxecharles


Hi there,

I've noticed an issue with the "cubic2" interpolation method: it causes a cache miss under jit, and thus a silent recompile. I originally noticed it in interp2d, but while creating an MWE I found that it happens with interp1d as well.

I noticed this when I moved my code (which makes a few interp2d(... method="cubic2") calls deep down) from a 2080Ti GPU to an A100 GPU and saw barely any speedup. This baffled me at first, but in hindsight it makes sense if the function is effectively being recompiled, a CPU-bound step, on every call. Indeed, just changing the interpolation method to "cubic" stopped the cache miss from triggering and sped the code up by a factor of ~30.

I've confirmed that this is what's happening by profiling/tracing the code with Perfetto, where the cache miss is visible in the trace. I'm not sure what is causing it, and haven't yet had a chance to look into the interpax interpolation code.

This is similar to this JAX issue; however, when I tested with the flags jax_log_compiles=True and jax_explain_cache_misses=True, nothing showed up (a truly pathological silent error), and I'm still not entirely sure why.
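For anyone trying to reproduce this, here is a minimal sketch of how those two flags can be switched on via `jax.config.update`, plus a Python-side trace counter as an independent check. The counter relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so it detects retracing only; a miss served from a deeper cache, or a slowdown that isn't compilation at all, would not show up in it. `interp_like` is a hypothetical stand-in, not interpax code.

```python
import jax
import jax.numpy as jnp

# Log every compilation and explain jit cache misses (both flags exist in JAX 0.6.0).
jax.config.update("jax_log_compiles", True)
jax.config.update("jax_explain_cache_misses", True)

trace_count = 0  # incremented only while JAX traces the function below


@jax.jit
def interp_like(x):
    # Hypothetical stand-in for an interp1d(...) call.
    # This Python side effect runs at trace time only, so it counts
    # retraces, not executions of the compiled kernel.
    global trace_count
    trace_count += 1
    return jnp.sin(x) ** 2


x = jnp.linspace(0.0, 2.0 * jnp.pi, 100)
for _ in range(5):
    interp_like(x).block_until_ready()

print("traces:", trace_count)  # stays at 1 if every later call hits the cache
```

Swapping the body for the real `interp1d(..., method="cubic2")` call would show whether the miss involves retracing or happens further down the stack.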

Here is a minimal reproducible example.

import jax
import jax.numpy as jnp
import interpax
from interpax import interp1d
from time import time

print("JAX version:", jax.__version__)
print("Interpax version:", interpax.__version__)

# creating some data for testing
xp = jnp.linspace(0, 2 * jnp.pi, 100)
xq = jnp.linspace(0, 2 * jnp.pi, 10000)
f = lambda x: jnp.sin(x)
fp = f(xp)
> JAX version: 0.6.0
> Interpax version: 0.3.10

First, the "cubic" method, which behaves as expected.

# Testing "cubic"
@jax.jit
def cubic_func():
    fq = interp1d(xq, xp, fp, method="cubic")
    return fq


start = time()
cubic_func()  # compiling
print("Compilation time:", time() - start)

# tracing and timing the "cubic" function under jit
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    for i in range(5):
        start = time()
        cubic_func().block_until_ready()  # running under jit
        print(f"Iteration {i + 1} time:", time() - start)
> Compilation time: 0.2931036949157715
> 2025-08-12 15:31:20.276223: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
> Iteration 1 time: 0.00045752525329589844
> Iteration 2 time: 0.00023293495178222656
> Iteration 3 time: 0.00021767616271972656
> Iteration 4 time: 0.0002753734588623047
> Iteration 5 time: 0.00024366378784179688
> 2025-08-12 15:31:20.286202: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
> Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz
> 127.0.0.1 - - [12/Aug/2025 15:31:31] code 501, message Unsupported method ('OPTIONS')
> 127.0.0.1 - - [12/Aug/2025 15:31:31] "OPTIONS /status HTTP/1.1" 501 -
> 127.0.0.1 - - [12/Aug/2025 15:31:31] code 501, message Unsupported method ('OPTIONS')
> 127.0.0.1 - - [12/Aug/2025 15:31:31] "OPTIONS /perfetto_trace.json.gz HTTP/1.1" 501 -
> 127.0.0.1 - - [12/Aug/2025 15:31:31] code 404, message File not found
> 127.0.0.1 - - [12/Aug/2025 15:31:31] "POST /status HTTP/1.1" 404 -
> 127.0.0.1 - - [12/Aug/2025 15:31:31] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -

Trace file you can open with Perfetto: cubic-trace.perfetto_trace.gz

Next, the same test with the "cubic2" method, which causes a cache miss under jit, and thus a silent recompile.

# Testing "cubic2"
@jax.jit
def cubic2_func():
    fq = interp1d(xq, xp, fp, method="cubic2")
    return fq


# compiling "cubic2" function
start = time()
cubic2_func()  # compiling
print("Compilation time:", time() - start)

# tracing and timing the "cubic2" function under jit
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    for i in range(5):
        start = time()
        cubic2_func().block_until_ready()  # running under jit
        print(f"Iteration {i + 1} time:", time() - start)
> Compilation time: 4.072752475738525
> 2025-08-12 15:31:35.507632: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
> Iteration 1 time: 0.04135751724243164
> Iteration 2 time: 0.0422360897064209
> Iteration 3 time: 0.0421442985534668
> Iteration 4 time: 0.041191816329956055
> Iteration 5 time: 0.04430389404296875
> 2025-08-12 15:31:35.725075: E external/xla/xla/python/profiler/internal/python_hooks.cc:412] Can't import tensorflow.python.profiler.trace
> Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz
> 127.0.0.1 - - [12/Aug/2025 15:32:01] code 501, message Unsupported method ('OPTIONS')
> 127.0.0.1 - - [12/Aug/2025 15:32:01] "OPTIONS /status HTTP/1.1" 501 -
> 127.0.0.1 - - [12/Aug/2025 15:32:01] code 501, message Unsupported method ('OPTIONS')
> 127.0.0.1 - - [12/Aug/2025 15:32:01] "OPTIONS /perfetto_trace.json.gz HTTP/1.1" 501 -
> 127.0.0.1 - - [12/Aug/2025 15:32:01] code 404, message File not found
> 127.0.0.1 - - [12/Aug/2025 15:32:01] "POST /status HTTP/1.1" 404 -
> 127.0.0.1 - - [12/Aug/2025 15:32:01] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -

Trace file you can open with Perfetto: cubic2-trace.perfetto_trace.gz
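One possible way to separate a recompile from a genuinely slow kernel is to compile ahead of time with `jax.jit(...).lower(...).compile()` and time only the resulting executable, which runs without any jit cache lookup or retracing. If the AOT-compiled "cubic2" call were still ~40 ms per iteration, the cost would be in execution rather than in a cache miss; if it dropped to the "cubic" range, recompilation would be the likelier culprit. This is a sketch under assumptions: `stand_in` is hypothetical and uses `jnp.interp` in place of the interpax call, so substitute the real `interp1d(..., method="cubic2")` (for the zero-argument `cubic2_func` above, `cubic2_func.lower().compile()` would be the equivalent).

```python
import time

import jax
import jax.numpy as jnp

xp = jnp.linspace(0.0, 2.0 * jnp.pi, 100)
xq = jnp.linspace(0.0, 2.0 * jnp.pi, 10000)


def stand_in(xq, xp):
    # Hypothetical placeholder for interp1d(xq, xp, fp, method="cubic2").
    return jnp.interp(xq, xp, jnp.sin(xp))


# Lower and compile ahead of time; calling `compiled` runs the executable
# directly, so these calls cannot trigger retracing or recompilation.
compiled = jax.jit(stand_in).lower(xq, xp).compile()

timings = []
for _ in range(5):
    start = time.perf_counter()
    compiled(xq, xp).block_until_ready()
    timings.append(time.perf_counter() - start)

print("AOT per-call times:", timings)
```

`time.perf_counter` is used instead of `time.time` because it is a monotonic, higher-resolution clock intended for measuring short intervals.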

Thanks,
Max
