Skip to content

Conversation

@danieldjohnson
Copy link
Collaborator

Only slicing with integers or None are supported. Previously, other values such as scalar JAX arrays were passed through, but this can cause confusing behavior and expose bugs because JAX does not generally allow putting JAX arrays in static PyTree metadata. Casting them immediately forces the values to be concrete at trace time and prevents confusing bugs later.

Fixes #117

Only slicing with integers or None are supported. Previously,
other values such as scalar JAX arrays were passed through,
but this can cause confusing behavior and expose bugs because
JAX does not generally allow putting JAX arrays in static PyTree
metadata. Casting them immediately forces the values to be concrete
at trace time and prevents confusing bugs later.

Fixes #117
@danieldjohnson danieldjohnson merged commit d44ea6c into main Jun 22, 2025
2 checks passed
@danieldjohnson danieldjohnson deleted the push-vlwyvqmnqmuz branch June 22, 2025 04:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BUG: potential tracer leak in _jitted_nmapped_getitem?

1 participant