Skip to content

Remove asarray workaround?#193

Closed
bagibence wants to merge 6 commits intopatrick-kidger:devfrom
bagibence:remove_asarray
Closed

Remove asarray workaround?#193
bagibence wants to merge 6 commits intopatrick-kidger:devfrom
bagibence:remove_asarray

Conversation

@bagibence
Copy link
Contributor

I just noticed that the JAX issue optimistix._misc.asarray was working around seems to be fixed now.
jax-ml/jax#18020 is still open, but it might not be a problem because jax>=0.7.2 "represents constants in its internal jaxpr representation as a TypedNdArray".

I added a non-array leaf to the smoke_aux used in the tests and the example from @patrick-kidger's comment. They pass with jax>=0.7.2, but not with earlier versions.

This example from another comment doesn't pass, but not sure how essential that is for Optimistix:

f, consts = jax.closure_convert(jax.numpy.asarray, 1.0)
out = f(1.0, *consts)
assert type(out) is not float

JAX commits that might be relevant:
jax-ml/jax@b43662f
jax-ml/jax@00538d0
jax-ml/jax@cb4fcd8

I don't have a good understanding of JAX internals, so please feel free to ignore this if it's not helpful.

@bagibence
Copy link
Contributor Author

The CI fails because jax>=0.7.0 requires Python 3.11, so dropping support for 3.10 is another change here.

@johannahaffner
Copy link
Collaborator

johannahaffner commented Dec 10, 2025

Hi Bence!

Thank you for looking into this :) On your specific points:

  • Dropping compatibility with earlier versions of JAX and correspondingly earlier versions of Equinox is not a price I'm willing to pay to drop a small utility function. We definitely have users whose code still uses older versions, and there are cases where older versions are more performant for certain specialised tasks.
  • Dropping support for Python 3.10 is reasonable though - I think it makes sense to match what JAX is doing in that regard, plus we'd adhere to the "last three versions of Python" rule. Our code would continue to run on older versions to the extent it does today, assuming a compatible JAX version. We'd just no longer make the promise that we'll put in the maintenance work to make sure that it continues to do so. Could you edit this PR to do only this?

@bagibence
Copy link
Contributor Author

Of course, but I'll rather open a tiny PR for dropping 3.10. Do you want one for updating smoke_aux or just drop that?

Ignoring the need to support older JAX versions, would this be a valid solution? I wanted to replicate the setup done in minimise without private functions from Optimistix or copying the workaround. Do you see anything that might go wrong here?

@johannahaffner
Copy link
Collaborator

Ignoring the need to support older JAX versions, would this be a valid solution? I wanted to replicate the setup done in minimise without private functions from Optimistix or copying the workaround. Do you see anything that might go wrong here?

No, this should work. On older versions the JVP of jnp.asarray was the problem, not the function itself. If you're no longer running into the same bug then this is probably a good way to go :)

@johannahaffner
Copy link
Collaborator

I'll close this PR :)

johannahaffner pushed a commit to johannahaffner/sif2jax that referenced this pull request Jan 8, 2026
johannahaffner added a commit to johannahaffner/sif2jax that referenced this pull request Jan 9, 2026
… recent JAX (#47)

* assume that _asarray custom JVP workaround can be retired

* updating required JAX version to drop custom JVP on _asarray workaround for JAX issue #15676 which is no longer required as of 0.7.2 (patrick-kidger/optimistix#193 (comment))

* retire private _asarray

* update local path

* Fix dimensions for CVXQP1, CVXBQP1, and DIXMAANA1

- CVXQP1: Fix dimension from 100 to 10000 to match SIF parameter
- CVXBQP1: Fix dimension from 10000 to 100000 to match SIF parameter
- DIXMAANA1: Fix dimension from 3000 to 3 (M=1, so N=3*M=3) and correct expected objective value from 0.0 to 1.0 (constant term)

All three problems now pass their full test suites.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* Vectorize LISWET base class and add tuple-based caching

- Replace Python for-loops with vectorized JAX operations in coefficient computation
- Add __init__ method to pre-compute perturbation, t_values, and constraint coefficients
- Store pre-computed values as tuples (hashable) instead of JAX arrays (unhashable)
- Convert tuples to arrays with appropriate dtype inside objective/constraint functions
- Fix dtype promotion issues by using explicit float literals in arange calls

Performance improvements:
- LISWET9: Runtime tests now pass (previously 6.92x slower, now within 5x threshold)
- LISWET10: Runtime tests now pass (previously 7.40x slower, now within 5x threshold)
- LISWET11: Runtime tests now pass (previously 7.49x slower, now within 5x threshold)

All correctness tests pass for LISWET9, LISWET10, LISWET11.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* Fix field ordering in LISWET base class

Move pre-computed tuple fields after fields with defaults to avoid dataclass ordering issues.

* Vectorize HADAMARD and HADAMALS implementations

- Replaced nested for-loops with fully vectorized JAX operations
- Used searchsorted and index arithmetic for column-major upper triangle extraction
- Ensured consistent dtypes by matching searchsorted output dtype
- All tests pass including runtime tests (within 5x threshold)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* simplifications for HS problems

* Vectorize EXPLIN2 exponential sum for 25x gradient speedup

Replace Python for-loop with vectorized JAX operations in the exponential
term computation. This dramatically improves gradient performance:

- Gradient: 22.06x slower → 0.88x (now faster than Fortran!)
- Objective: 2.26x → 2.88x (still well within 5x target)

The vectorization eliminates loop overhead and allows JAX to generate
optimized computation graphs for efficient derivative computation.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* version bump

---------

Co-authored-by: Assistant <assistant@example.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants