Conversation
JAX issue #156765 seems to be fixed and using jnp.asarray passes Optimistix's tests.
|
The CI fails because |
|
Hi Bence! Thank you for looking into this :) On your specific points:
|
|
Of course, but I'll rather open a tiny PR for dropping 3.10. Do you want one for updating Ignoring the need to support older JAX versions, would this be a valid solution? I wanted to replicate the setup done in |
No, this should work. On older versions the JVP of |
|
I'll close this PR :) |
…nd for JAX issue #15676 which is no longer required as of 0.7.2 (patrick-kidger/optimistix#193 (comment))
… recent JAX (#47) * assume that _asarray custom JVP workaround can be retired * updating required JAX version to drop custom JVP on _asarray workaround for JAX issue #15676 which is no longer required as of 0.7.2 (patrick-kidger/optimistix#193 (comment)) * retire private _asarray * update local path * Fix dimensions for CVXQP1, CVXBQP1, and DIXMAANA1 - CVXQP1: Fix dimension from 100 to 10000 to match SIF parameter - CVXBQP1: Fix dimension from 10000 to 100000 to match SIF parameter - DIXMAANA1: Fix dimension from 3000 to 3 (M=1, so N=3*M=3) and correct expected objective value from 0.0 to 1.0 (constant term) All three problems now pass their full test suites. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> * Vectorize LISWET base class and add tuple-based caching - Replace Python for-loops with vectorized JAX operations in coefficient computation - Add __init__ method to pre-compute perturbation, t_values, and constraint coefficients - Store pre-computed values as tuples (hashable) instead of JAX arrays (unhashable) - Convert tuples to arrays with appropriate dtype inside objective/constraint functions - Fix dtype promotion issues by using explicit float literals in arange calls Performance improvements: - LISWET9: Runtime tests now pass (previously 6.92x slower, now within 5x threshold) - LISWET10: Runtime tests now pass (previously 7.40x slower, now within 5x threshold) - LISWET11: Runtime tests now pass (previously 7.49x slower, now within 5x threshold) All correctness tests pass for LISWET9, LISWET10, LISWET11. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> * Fix field ordering in LISWET base class Move pre-computed tuple fields after fields with defaults to avoid dataclass ordering issues. * Vectorize HADAMARD and HADAMALS implementations - Replaced nested for-loops with fully vectorized JAX operations - Used searchsorted and index arithmetic for column-major upper triangle extraction - Ensured consistent dtypes by matching searchsorted output dtype - All tests pass including runtime tests (within 5x threshold) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> * simplifications for HS problems * Vectorize EXPLIN2 exponential sum for 25x gradient speedup Replace Python for-loop with vectorized JAX operations in the exponential term computation. This dramatically improves gradient performance: - Gradient: 22.06x slower → 0.88x (now faster than Fortran!) - Objective: 2.26x → 2.88x (still well within 5x target) The vectorization eliminates loop overhead and allows JAX to generate optimized computation graphs for efficient derivative computation. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> * version bump --------- Co-authored-by: Assistant <assistant@example.com> Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
I just noticed that the JAX issue
optimistix._misc.asarraywas working around seems to be fixed now.jax-ml/jax#18020 is still open, but it might not be a problem because
jax>=0.7.2"represents constants in its internal jaxpr representation as aTypedNdArray".I added a non-array leaf to the
smoke_auxused in the tests and the example from @patrick-kidger's comment. They pass withjax>=0.7.2, but not with earlier versions.This example from another comment doesn't pass, but not sure how essential that is for Optimistix:
JAX commits that might be relevant:
jax-ml/jax@b43662f
jax-ml/jax@00538d0
jax-ml/jax@cb4fcd8
I don't have a good understanding of JAX internals, so please feel free to ignore this if it's not helpful.