Support "diagonal" primitives with no/slow JVP batch rule#164
Support "diagonal" primitives with no/slow JVP batch rule#164jpbrodrick89 wants to merge 40 commits intopatrick-kidger:mainfrom
merge main into fork
…ntire Jacobian matrix
I think I prefer this implementation (latest commit) with …

I finally realised why … Would it be helpful to re-write the PR message with updated benchmarks and prose in light of this?
That reasoning was completely wrong again: the derivative is not evaluated at the tangent vectors, just multiplied by them. I think the reduced FLOPs due to matrix multiplication is only more noticeable in the example above because … However, it is easy to fool:

```python
def myfunc(x):
    halfway = len(x) // 2
    return jnp.concatenate([jnp.sin(x[:halfway]), jnp.cos(x[halfway:])])
```

Here we see an O(n) speedup exceeding 1E4 for array sizes of 4E4, which is huge. In general I see no significant adverse impact of this and some very pronounced positive impacts in realistic use cases.
patrick-kidger left a comment
Nice, I really like seeing a colouring approach like this!
I will note however that `lx.diagonal` is documented as "Extracts the diagonal from a linear operator, and returns a vector", which is meant to include extracting the diagonal from nondiagonal operators. I think the implementations you have here should check with `is_diagonal` to determine whether to dispatch to the new or the old implementation.
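For illustration, a rough sketch of what that dispatch might look like (`diagonal_with_fallback` is a hypothetical helper, not lineax API):

```python
import jax
import jax.numpy as jnp
import lineax as lx


def diagonal_with_fallback(operator):
    # Hypothetical helper, just to illustrate the suggested dispatch.
    if lx.is_diagonal(operator):
        # Trust the tag: a single mv against a pytree of ones recovers the
        # diagonal, with no materialisation and no batching rule needed.
        ones = jax.tree_util.tree_map(
            lambda s: jnp.ones(s.shape, s.dtype), operator.in_structure()
        )
        out = operator.mv(ones)
        leaves = jax.tree_util.tree_leaves(out)
        return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    else:
        # Old behaviour: extract the diagonal from the materialised matrix.
        return jnp.diag(operator.as_matrix())
```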
lineax/_operator.py (outdated)

```
@@ -1363,10 +1363,40 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]:

@diagonal.register(MatrixLinearOperator)
@diagonal.register(PyTreeLinearOperator)
```
Side note, I think we could make the PyTreeLinearOperator case a little more efficient by calling `jnp.diag` on each 'diagonal leaf' and concatenating those together. (You don't have to do that here, just spotting it.)
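A toy illustration of the 'diagonal leaf' idea (my own sketch, assuming the input and output pytree structures match so the on-diagonal blocks are square):

```python
import jax.numpy as jnp

# Only the on-diagonal blocks of the pytree-of-blocks contribute to the matrix
# diagonal, so we can take jnp.diag per block and concatenate, never building
# the full matrix.
block_a = jnp.array([[1.0, 2.0], [3.0, 4.0]])   # input leaf 0 -> output leaf 0
block_b = jnp.arange(9.0).reshape(3, 3)          # input leaf 1 -> output leaf 1

diag = jnp.concatenate([jnp.diag(block_a), jnp.diag(block_b)])
# Same result as jnp.diag of the full block matrix, at a fraction of the cost.
```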
lineax/_operator.py (outdated)

```python
elif operator.jac == "bwd":
    fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args))
    _, vjp_fun = jax.vjp(fn, operator.x)
    diag_as_pytree = vjp_fun(unravel(basis))
```
I think this will fail for operators with different input and output structures. (They might still be mathematically square, just split up into a pytree in different ways / with different dtypes.) This needs to be a basis formed from operator.out_structure().
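A rough sketch of the fix being suggested (`diag_via_vjp` is a hypothetical name): the vector pulled back through `jax.vjp` is a cotangent, so the ones basis has to be built from the output structure rather than the input structure.

```python
import jax
import jax.numpy as jnp


def diag_via_vjp(fn, x, out_structure):
    # Sketch only: the cotangent passed to vjp_fun must match the operator's
    # *output* structure, hence the ones basis is built from out_structure.
    _, vjp_fun = jax.vjp(fn, x)
    ones_out = jax.tree_util.tree_map(
        lambda s: jnp.ones(s.shape, s.dtype), out_structure
    )
    (diag_as_pytree,) = vjp_fun(ones_out)
    leaves = jax.tree_util.tree_leaves(diag_as_pytree)
    return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
```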
I just added a test for mismatched input and output structure, and it looks like these are not allowed to be diagonal in the current version of lineax. Shall we leave relaxing that assumption and fixing this for future work? This implementation should work fine as long as the assumption holds.
```
FAILED tests/test_operator.py::test_is_symmetric[float64] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=float64), ShapeDtypeStruct(shape=(), dtype=float64)) and output structure (ShapeDtypeStruct(shape=(), dty...
FAILED tests/test_operator.py::test_is_symmetric[complex128] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=complex128), ShapeDtypeStruct(shape=(), dtype=complex128)) and output structure (ShapeDtypeStruct(shape=(...
FAILED tests/test_operator.py::test_is_diagonal[float64] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=float64), ShapeDtypeStruct(shape=(), dtype=float64)) and output structure (ShapeDtypeStruct(shape=(), dty...
FAILED tests/test_operator.py::test_is_diagonal[complex128] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=complex128), ShapeDtypeStruct(shape=(), dtype=complex128)) and output structure (ShapeDtypeStruct(shape=(...
```
Never mind, these are PyTreeLinearOperators, but I need to address this for JacobianLinearOperators.
I actually got exactly the same error with JacobianLinearOperator. If this is not expected, I can push the test I added to test/helpers to investigate further:
```python
@_operators_append
def make_nontrivial_jac_operator(getkey, matrix, tags):
    # makes a Jacobian linear operator from matrix with
    # input structure {"array": (in_size - 1,), "scalar": ()}
    # output structure ((), (out_size - 1,))
    out_size, in_size = matrix.shape
    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    fn_tmp = lambda x, _: a + b @ x + c @ x**2
    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
    diff = matrix - jac

    def fn(x, args):
        x_flat = jnp.concatenate([x["array"], x["scalar"][jnp.newaxis]])
        y_flat = a + (b + diff) @ x_flat + c @ x_flat**2
        y = [y_flat[0], y_flat[1:]]
        return y

    return lx.JacobianLinearOperator(fn, {"array": x[:-1], "scalar": x[-1]}, None, tags)
```
Happy to address these during the week, just wanted to check: I did a quick trawl through the code and I don't think …
However, they're actually a standalone public API themselves, which could be used independently of the solvers. :)
Sorry for abandoning this for so long as the day job took over. I returned as it was found to provide orders of magnitude impact on a problem I was working on (root finding over multiple interpolations), even for small array sizes (200–2000). I have addressed the main point of retaining extraction of the diagonal when the diagonal tag is missing. However, your other two comments seem at odds with each other: either we ensure input/output structures match (which is actually the case, see above), enabling us to extract diagonal leaves for …
Sorry for the long delay getting back to you, some personal life things took over for a while. So, now to actually answer your question: good point. I imagine we could probably do the 'diagonal leaf' approach when the structures match, and go for the more expensive approach when they don't?
…full_rank (patrick-kidger#158)

The two functions allow_dependent_{rows,columns} together did the job of answering if the solver accepts full rank matrices for the purposes of the jvp. Allowing them to be implemented separately created some issues:

1) Invalid states were representable. E.g. what does it mean that dependent columns are allowed for square matrices if dependent rows are not? What does it mean that dependent rows are not allowed for matrices with more rows than columns?
2) As the functions accept operator as input, a custom solver could in principle decide its answer based on the operator's dynamic value rather than only JAX compilation-static information regarding it, as in all the lineax-defined solvers. This would prevent JAX compilation and jit.

Both issues are addressed by asking the solver to report only if it assumes the input is numerically full rank. If this assumption is exactly violated, its behavior is allowed to be undefined, and it is allowed to error, produce NaN values, and produce invalid values.
No worries, hope you're managing alright. Sorry to miss you at DiffSys, but looking forward to catching up with Johanna! Just to check I understand correctly: shall we go forward with the status quo that Diagonal (as for Symmetric) operators must ALWAYS have matching input and output structures, and continue to raise ValueErrors when this is violated? Therefore, we do not need to touch JacobianLinearOperator but just need to adopt the diagonal leaf approach for PyTreeLinearOperators and we're done?
…verflow
Decrease default value to prevent overflow in 32-bit. There seem to be some spurious downstream failures in Diffrax with JAX 0.8.2 otherwise. Probably JAX has started promoting these to tracers on some unusual codepath.
Yup, I think that'd be reasonable! Let me know when this PR is ready and we'll get this merged :)
* Add sparse materialisation helper and efficient diagonal paths

This PR introduces a _try_sparse_materialise helper and optimizes diagonal operator handling throughout lineax. Key changes:
- Add _try_sparse_materialise() that converts diagonal-tagged operators to DiagonalLinearOperator, preserving pytree structure via unravel
- Add efficient diagonal() for JLO/FLO using single JVP/VJP with ones basis
- Add efficient diagonal() for Composed: diag(A @ B) = diag(A) * diag(B)
- Simplify mv() for MLO, PTLO, Add, Composed to use _try_sparse_materialise
- Apply early sparse materialisation in materialise() registrations

Aux handling:
- Fix bug: linearise/materialise now preserve aux on AuxLinearOperator
- Preserve aux from first operator in Composed (output comes from op1)
- Inner aux in Add children silently stripped (unclear semantics - may warrant guards in future)

Co-authored-by: jpbrodrick89 <jpbrodrick89@users.noreply.github.com>
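As a quick sanity check of the composed-operator identity mentioned in that commit (it relies on both factors actually being diagonal, which is exactly what the tag promises; toy arrays, not lineax code):

```python
import jax.numpy as jnp

a = jnp.diag(jnp.array([1.0, 2.0, 3.0]))
b = jnp.diag(jnp.array([4.0, 5.0, 6.0]))
# diag(A @ B) == diag(A) * diag(B) holds because both factors are diagonal.
assert jnp.allclose(jnp.diag(a @ b), jnp.diag(a) * jnp.diag(b))
```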
Note this is largely ready to go and includes #195 (but not #196); happy to get those in first and merge this in after for a cleaner diff. Main change I've made is the … Also happy to be targeting a dev branch if you prefer, as these are not just "fixes"; one just doesn't exist right now.
**BREAKING CHANGE**: Diagonals are no longer "extracted" from operators but rely on the promise of the `diagonal` tag being accurate and fulfilled. If this promise is broken, solutions may differ from user expectations.

### Preface
At present, both `JacobianLinearOperator` and `FunctionLinearOperator` require full materialisation even if provided with a `diagonal` tag. This seems self-evidently expensive (in practice it certainly can be, but more often is not, see below) and requires the underlying function (which could potentially be a custom primitive) to have a batching rule. As tags are currently considered to be a "promise" and are unchecked with no guarantee of behaviour, there are some shortcuts we can take.

### Changes made
The proposal here is to use the observation that the diagonal of a diagonal matrix can be obtained by pre- or post-multiplying it by a vector of ones, and thereby re-write the single-dispatch `diagonal` method for `JacobianLinearOperator` and `FunctionLinearOperator` so that the `as_matrix()` method is not required. For `JacobianLinearOperator`, either `jax.jvp` or `jax.vjp` will be called depending on the `jac` keyword (forward-mode should always be more efficient, but meeting the user's expectation will avoid issues if forward-mode is not supported, such as when a `custom_vjp` is used). For `FunctionLinearOperator`, we can just use `self.mv`. However, if the matrix is not actually diagonal this identity will not hold and results may be unexpected due to contributions from off-diagonals.
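To spell the observation out: for a genuinely diagonal Jacobian J, multiplying by a vector of ones on the right (forward mode) or on the left (reverse mode) both give diag(J), so a single `jax.jvp` or `jax.vjp` suffices. A toy illustration of the idea (not the exact lineax implementation):

```python
import jax
import jax.numpy as jnp

def fn(x):
    return jnp.sin(x)          # elementwise, so the Jacobian is diagonal

x = jnp.linspace(0.0, 1.0, 5)
ones = jnp.ones_like(x)

# Forward mode: J @ ones == diag(J) when J is diagonal.
_, diag_fwd = jax.jvp(fn, (x,), (ones,))

# Reverse mode: ones^T J == diag(J) when J is diagonal.
_, vjp_fun = jax.vjp(fn, x)
(diag_bwd,) = vjp_fun(ones)

assert jnp.allclose(diag_fwd, jnp.cos(x))
assert jnp.allclose(diag_bwd, jnp.cos(x))
```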
I considered using `operator.transpose().mv` instead of writing out `vjp`, but if the matrix is tagged as `symmetric` then this would end up calling `jacrev` instead of `vjp`.

### Why is this helpful?
When using `lineax` directly one can of course just define a `DiagonalOperator` instead of a more general `JacobianLinearOperator`, but this is not always possible. For example, when using `optimistix`, the operator is instantiated within the optimisation routine and the only way to inform the optimiser about the underlying structure of the matrix is through `tags`. Therefore, if the function being optimised is a primitive (e.g. an FFI) with a JVP rule that does not support batching, a user is stuck. If a slow batching rule, such as `vmap_method="sequential"`, is used, the current approach is also painfully slow for large matrix sizes.
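To make the `optimistix` scenario concrete, a minimal sketch (assuming `optimistix.root_find` accepts lineax tags via its `tags` argument; `fn` here is just a plain stand-in for the FFI-backed elementwise primitive described above):

```python
import jax.numpy as jnp
import lineax as lx
import optimistix as optx


def fn(y, args):
    # Stand-in for an FFI/pure_callback primitive whose JVP only batches
    # sequentially; it is elementwise, so its Jacobian really is diagonal.
    return jnp.sin(y) - 0.5


solver = optx.Newton(rtol=1e-8, atol=1e-8)
y0 = jnp.linspace(0.1, 1.0, 1000)
# Tags are the only way to tell the optimiser about the Jacobian's structure here.
sol = optx.root_find(fn, solver, y0, tags=frozenset({lx.diagonal_tag}))
```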
### Performance impact

I had initially hoped this would have a minor positive impact on performance across the board, but as ever I have massively underestimated the power of XLA. In practice, whether this PR seems to improve performance (e.g. for a `linear_solve` or an `optimistix.root_find`) of a pure JAX function appears to fluctuate with array size. By playing around with different `XLA_FLAGS` and other environment variables, my best guess is that this is mostly due to threading; a `vmap` applied to a `jnp.eye` is threaded much more aggressively, meaning that the apparent time complexity appears to be of lower order than a more direct approach. However, when I tried to eliminate threading, this PR still seems to have an 8–10% negative impact on performance for array sizes > 100 on an `optimistix.root_find`.

**Pure `jax` comparison: using `jvp` when attempting to enforce single-threadedness is about 14µs faster**
It seems self-evident that the second function should be more efficient; however, with the new `thunk` runtime on my Mac, `from_eye` runs faster than `direct` (referred to as `wrapped` in the diagram below; you can ignore `unwrapped` and `vmap`, which perform similarly) for array sizes > ~1.5E4:

*[benchmark plot]*

Disabling the `thunk` runtime (with `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`, which is reported to run faster in some circumstances) decreases the gap between the two by slightly slowing down the `eye` implementation and accelerating the `direct` approach:

*[benchmark plot]*

Going further and following all the suggestions in github.com/jax-ml/jax/discussions/22739 to limit to one thread/core, we can see the `direct` approach is now consistently about 14µs faster:

*[benchmark plot]*

**`linear_solve` significantly faster (often >2x) for array sizes <2E4 using thunk runtime, but runs about 6–10% slower for large array sizes when disabling and attempting to enforce a single thread**
**Code tested**

Using the standard thunk runtime and `EQX_ON_ERROR=nan` we see significant speedup for array sizes < 1E4:

*[benchmark plot]*

Enforcing single-thread, the performance between the old and the new approaches is very similar but tracks at about 6–10% slower for larger array sizes:

*[benchmark plot]*

(Note that `DiagonalOperator` is actually slower somehow.)

**Similar behaviour is observed with `optimistix.root_find` (but with more modest gains, and some hits for larger array sizes)**
I compared performance for a multi-root find of the `sin` function (with `EQX_ON_ERROR=nan`).

Default settings (`jax` 0.6.1, `main` vs this branch of `lineax`) with and without the standard `thunk` runtime:

*[benchmark plot]*

In both runtimes this PR improves/maintains performance by up to a factor of 2 for arrays of size up to 1E4, at which point it becomes slightly slower than the current version (by ~8%).

However, limiting to one thread as best as I can, most of the noise is eliminated and the two have very similar performance (the change tracking about 6% slower), except for an array size of 20 where the proposed change is faster:

*[benchmark plot]*

**Much more substantial performance improvement (8x or higher) is observed for primitives that only support `sequential` batching rules**
This is a very contrived example, but based on very real use cases we have over at tesseract-core and tesseract-jax. I have defined a new primitive version of `sin` with a `jvp` rule that batches sequentially and is therefore slow and doesn't benefit from compilation/threading in the same way:

**Code for primitives**
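(The actual code lives in the collapsed block above; purely as an illustration, here is one hedged way such a primitive could be set up, using `jax.pure_callback` with `vmap_method="sequential"` plus a `custom_jvp`. This is my sketch of the setup described, not necessarily the code used in the benchmarks.)

```python
import jax
import jax.numpy as jnp
import numpy as np


def _seq_callback(np_fn, x):
    # Mimic an FFI call whose only batching rule is vmap_method="sequential".
    return jax.pure_callback(
        np_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x, vmap_method="sequential"
    )


@jax.custom_jvp
def sin_p(x):
    return _seq_callback(np.sin, x)


@sin_p.defjvp
def _sin_p_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # The tangent rule also goes through the sequential callback, so vmapping
    # the JVP (as materialising the Jacobian does) is painfully slow.
    return sin_p(x), _seq_callback(np.cos, x) * t
```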
I then ran the same tests as before but with `sin_p` instead of `jnp.sin`, and we can see the time complexity of the current version is almost quadratic for array sizes greater than 100 (as one would naively expect for a dense Jacobian), meaning that speedups range from a factor of 2 (array size of 20) to a factor of 8 (array size of 5000) and higher:

*[benchmark plot]*

~~Running `benchmarks/solver_speed.py` shows a negligible improvement in the single `Diagonal` solve but a 50% faster batch solve; this could of course be down to noise as the solve is only timed once.~~ (This uses `lx.Diagonal` so it is not relevant and probably just a fluke.)

### Testing done
- `test_diagonal` updated such that operators are actually initialised with diagonal matrices
- Root find (`Newton` and `Bisection`) of a scalar function with no batching rule, taking gradients through the root solve (not possible previously); this tests both `JacobianLinearOperator` and `FunctionLinearOperator` in action
- `diagonal` from `JacobianLinearOperator` with `jac="bwd"`

Happy to perform any further requested testing you see fit/necessary. I appreciate I haven't managed to test reverse-mode especially extensively.
### Next steps

In a future PR, I would like to do something similar for other structures (e.g. tridiagonal); this should address the large O(n) discrepancy observed in #149 (but not the O(0) discrepancy). I believe this will be a much more consistent and meaningful gain than observed here. This PR should be a lot easier to grok, and a good place to reason about the concept and discuss framework/design choices (although maybe not the performance impact :) ) before building out further.
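As a teaser for the tridiagonal case, a hedged sketch (my own illustration, not part of this PR) of the colouring/probing idea: three matrix-vector products against "striped" ones vectors recover all three bands, provided the operator really is tridiagonal.

```python
import jax.numpy as jnp

def tridiagonal_via_colouring(mv, n, dtype=jnp.float32):
    # Within one stripe (colour), no two selected columns touch the same row of a
    # tridiagonal matrix, so each output entry picks out exactly one band entry.
    cols = jnp.arange(n)
    rows = jnp.arange(n)
    outs = jnp.stack([mv((cols % 3 == c).astype(dtype)) for c in range(3)])
    main = outs[rows % 3, rows]                    # A[i, i]
    super_ = outs[(rows[:-1] + 1) % 3, rows[:-1]]  # A[i, i + 1]
    sub = outs[(rows[1:] - 1) % 3, rows[1:]]       # A[i, i - 1]
    return sub, main, super_

# Quick check against a dense tridiagonal matrix:
a = (jnp.diag(jnp.arange(1.0, 6.0))
     + jnp.diag(jnp.full(4, 0.5), 1)
     + jnp.diag(jnp.full(4, -0.5), -1))
sub, main, sup = tridiagonal_via_colouring(lambda v: a @ v, 5)
assert jnp.allclose(main, jnp.diag(a))
assert jnp.allclose(sup, jnp.diag(a, 1))
assert jnp.allclose(sub, jnp.diag(a, -1))
```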