
Support "diagonal" primitives with no/slow JVP batch rule #164

Open
jpbrodrick89 wants to merge 40 commits into patrick-kidger:main from jpbrodrick89:jpb/diagfromjac

Conversation

@jpbrodrick89
Contributor

@jpbrodrick89 jpbrodrick89 commented Jun 17, 2025

BREAKING CHANGE: Diagonals are no longer "extracted" from operators but rely on the promise of the diagonal tag being accurate and fulfilled. If this promise is broken, solutions may differ from user expectations.

Preface

At present, both JacobianLinearOperator and FunctionLinearOperator require full materialisation even if provided with a diagonal tag. This seems self-evidently expensive (in practice it certainly can be, but more often is not; see below) and requires the underlying function (which could potentially be a custom primitive) to have a batching rule. Since tags are currently treated as an unchecked "promise" with no guarantee of behaviour, there are some shortcuts we can take.

Changes made

The proposal here is to use the observation that the diagonal of a diagonal matrix can be obtained by pre/post-multiplying it by a vector of ones, and thereby rewrite the single-dispatch diagonal method for JacobianLinearOperator and FunctionLinearOperator so that the as_matrix() method is not required. For JacobianLinearOperator, either jax.jvp or jax.vjp will be called depending on the jac keyword (forward-mode should always be more efficient, but meeting the user's expectation avoids issues if forward-mode is not supported, such as when a custom_vjp is used). For FunctionLinearOperator, we can just use self.mv.
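
As a minimal sketch of the identity being exploited (not the lineax implementation itself; f is a stand-in elementwise function whose Jacobian happens to be diagonal):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)  # elementwise, so the Jacobian is diag(cos(x))

x = jnp.arange(1.0, 4.0)
ones = jnp.ones_like(x)

# Forward mode: J @ ones equals the diagonal when J is diagonal.
_, diag_fwd = jax.jvp(f, (x,), (ones,))

# Reverse mode: ones^T @ J gives the same vector for a diagonal J.
_, vjp_fn = jax.vjp(f, x)
(diag_bwd,) = vjp_fn(ones)

assert jnp.allclose(diag_fwd, jnp.cos(x))
assert jnp.allclose(diag_bwd, jnp.cos(x))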

However, if the matrix is not actually diagonal this identity will not hold and results may be unexpected due to contributions from off-diagonals.

I considered using operator.transpose().mv instead of writing out vjp but if the matrix is tagged as symmetric then this would end up calling jacrev instead of vjp.

Why is this helpful?

When using lineax directly one can of course just define a DiagonalOperator instead of a more general JacobianLinearOperator, but this is not always possible. For example, when using optimistix, the operator is instantiated within the optimisation routine and the only way to inform the optimiser about the underlying structure of the matrix is through tags. Therefore, if the function being optimised is a primitive (e.g. an FFI) with a JVP rule that does not support batching, a user is stuck. If a slow batching rule, such as vmap_method="sequential", is used, the current approach is also painfully slow for large matrix sizes.

Performance impact

I had initially hoped this would have a minor positive impact on performance across the board, but as ever I have massively underestimated the power of XLA. In practice, whether this PR improves performance (e.g. for a linear_solve or an optimistix.root_find) of a pure jax function appears to fluctuate with array size. By playing around with different XLA_FLAGS and other environment variables, my best guess is that this is mostly due to threading: a vmap applied to a jnp.eye is threaded much more aggressively, so its apparent time complexity is of lower order than that of a more direct approach. However, when I tried to eliminate threading, this PR still seems to have an 8–10% negative impact on performance for array sizes > 100 on an optimistix.root_find.

Pure `jax` comparison: using `jvp` when attempting to enforce single-threadedness is about 14µs faster.
import jax
import jax.numpy as jnp

@jax.jit
def from_eye(x):
    eye = jnp.eye(len(x), dtype=x.dtype)
    jac = jax.vmap(lambda t: jnp.cos(x) * t, out_axes=-1)(eye)
    return jnp.diag(jac)

@jax.jit
def direct(x):
    return jnp.cos(x)
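
For reference, a hypothetical timing harness along these lines (not the exact benchmark script) would compile first and block on the result so that async dispatch doesn't distort the numbers:

import timeit

x = jnp.linspace(0.0, 1.0, 10_000)  # example size; the plots sweep this
from_eye(x).block_until_ready()  # trigger compilation before timing
direct(x).block_until_ready()
print(timeit.timeit(lambda: from_eye(x).block_until_ready(), number=100))
print(timeit.timeit(lambda: direct(x).block_until_ready(), number=100))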

It seems self-evident that the second function should be more efficient; however, with the new thunk runtime on my Mac, from_eye runs faster than direct for array sizes > ~1.5E4 (from_eye is referred to as wrapped in the diagram below; you can ignore unwrapped and vmap, which have similar performance):
image

Disabling the thunk runtime (with XLA_FLAGS=--xla_cpu_use_thunk_runtime=false, which is reported to run faster in some circumstances) decreases the gap between the two by slightly slowing down the eye implementation and accelerating the direct approach:
image

Going further and following all the suggestions in github.com/jax-ml/jax/discussions/22739 to limit execution to one thread/core, we can see the direct approach is now consistently about 14µs faster:
image

`linear_solve` is significantly faster (often >2x) for array sizes <2E4 using the thunk runtime, but runs about 6–10% slower for large array sizes when disabling it and attempting to enforce a single thread
Code tested
import jax
import jax.numpy as jnp
import lineax

solver = lineax.Diagonal(well_posed=True)

def double(x, args):
    return (2.0 * jnp.ones_like(x)) * x

@jax.jit
def jac_op_to_sol(rhs):
    op = lineax.JacobianLinearOperator(double, rhs, tags=frozenset({lineax.diagonal_tag}))
    return lineax.linear_solve(op, rhs, solver, throw=False)

@jax.jit
def func_op_to_sol(rhs):
    op = lineax.FunctionLinearOperator(lambda x: double(x, None), jax.eval_shape(lambda: rhs), tags=frozenset({lineax.diagonal_tag}))
    return lineax.linear_solve(op, rhs, solver, throw=False)
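
A quick hypothetical sanity check of the two solves above (the operator represents 2x = rhs, so the solution should be rhs / 2):

rhs = jnp.linspace(1.0, 2.0, 1000)  # example right-hand side; the benchmarks sweep the size
assert jnp.allclose(jac_op_to_sol(rhs).value, rhs / 2.0)
assert jnp.allclose(func_op_to_sol(rhs).value, rhs / 2.0)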

Using the standard thunk runtime and EQX_ON_ERROR=nan, we see a significant speedup for array sizes < 1E4:
image

Enforcing a single thread, the performance of the old and new approaches is very similar, with the new one tracking about 6–10% slower for larger array sizes:
image

(Note that DiagonalOperator is actually slower somehow.)

Similar behaviour is observed with `optimistix.root_find` (but with more modest gains, and some hits for larger array sizes)

I compared performance for a multi-root find of the sin function (with EQX_ON_ERROR=nan):

import jax.numpy as jnp
import lineax as lx
import optimistix as optx

n = 2000  # example array size; the benchmarks sweep this
x = jnp.repeat(jnp.linspace(-10., 10., num=20), n // 20)
newton = optx.Newton(rtol=1e-8, atol=1e-8, linear_solver=lx.Diagonal())
optx.root_find(lambda x, args: jnp.sin(x), newton, x, max_steps=8, tags=frozenset({lx.diagonal_tag}), throw=False).value

Default settings (jax 0.6.1, main vs this branch of lineax), with and without the standard thunk runtime:
image

In both runtimes this PR improves/maintains performance, by up to a factor of 2, for arrays of size up to 1E4, at which point it becomes slightly slower than the current version (by ~8%).

However, limiting to one thread as best I can, most of the noise is eliminated and the two have very similar performance (the change tracking about 6% slower), except for an array size of 20, where the proposed change is faster:
image

Much more substantial performance improvement (8x or higher) is observed for primitives that only support `sequential` batching rules

This is a very contrived example, but it is based on very real use cases we have over at tesseract-core and tesseract-jax. I have defined a new primitive version of sin with a JVP rule that batches sequentially and is therefore slow and doesn't benefit from compilation/threading in the same way:

Code for primitives
import numpy as onp

import jax
import jax.numpy as jnp
from jax.interpreters import ad, batching, mlir
from jax.core import ShapedArray
from jax.extend import core

from jax._src.lib.mlir.dialects import hlo
from jax._src.ffi import ffi_batching_rule

from functools import partial

jax.config.update("jax_enable_x64", True)

sin_p = core.Primitive("mysin")
cos_mult_p = core.Primitive("cos_mult")

def sin_prim(x):
    return sin_p.bind(x)

def cos_mult_prim(*args):
    p = args[0]
    return cos_mult_p.bind(*args,
                           vmap_method="sequential",
                           result_avals=(ShapedArray(p.shape, p.dtype), ))

def cos_mult_impl(*args, vmap_method="sequential", result_avals=None):
    del vmap_method, result_avals
    return onp.cos(args[0]) * args[1]

def cos_mult_abstract_eval(*args, vmap_method="sequential", result_avals=None):
    del vmap_method, result_avals
    return ShapedArray(args[0].shape, args[0].dtype)

def sin_jvp(p, t):
    return sin_prim(p[0]), cos_mult_prim(p[0], t[0])

def sin_lowering(ctx, x):
    return [ hlo.SineOp(x).result ]

def cos_mult_lowering(ctx, p, t, *, vmap_method="sequential", result_avals=None):
    del vmap_method, result_avals
    return [ hlo.MulOp(hlo.CosineOp(p), t).result ]

def cos_mult_batch(args, axes, vmap_method="sequential", result_avals=None):
    as_tup, out_dims =  ffi_batching_rule(cos_mult_p, args, axes, 
                            vmap_method="sequential",
                            result_avals=(ShapedArray(args[0].shape, args[0].dtype),))
    return jnp.stack(list(as_tup)), out_dims[0]

sin_p.def_impl(onp.sin)
cos_mult_p.def_impl(cos_mult_impl)
sin_p.def_abstract_eval(lambda x: ShapedArray(x.shape, x.dtype))
cos_mult_p.def_abstract_eval(cos_mult_abstract_eval)
mlir.register_lowering(sin_p, sin_lowering)
mlir.register_lowering(cos_mult_p, cos_mult_lowering)
ad.primitive_jvps[sin_p] = sin_jvp
# If the `cos_mult_p` batching rule isn't present this won't run on the current version
# but will with this PR
batching.primitive_batchers[cos_mult_p] = cos_mult_batch
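
For reference, a hypothetical driver mirroring the earlier root-find benchmark, with the custom primitive substituted for jnp.sin (a sketch, not the exact script used):

import lineax as lx
import optimistix as optx

n = 2000  # example array size; the benchmarks sweep this
x0 = jnp.repeat(jnp.linspace(-10.0, 10.0, num=20), n // 20)
newton = optx.Newton(rtol=1e-8, atol=1e-8, linear_solver=lx.Diagonal())
sol = optx.root_find(lambda x, args: sin_prim(x), newton, x0, max_steps=8,
                     tags=frozenset({lx.diagonal_tag}), throw=False)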

I then ran the same tests as before but with the sin_prim primitive instead of jnp.sin. We can see the time complexity of the current version is almost quadratic for array sizes greater than 100 (as one would naively expect for a dense Jacobian), meaning that speedups range from a factor of 2 (array size of 20) to a factor of 8 (array size of 5000) and higher:
image

Running benchmarks/solver_speed.py shows a negligible improvement in the single Diagonal solve but a 50% faster batch solve; this could of course be down to noise, as the solve is only timed once. (This uses lx.Diagonal so it is not really relevant and probably just a fluke.)

Testing done

  • CI passes after modifying test_diagonal such that operators are actually initialised with diagonal matrices
  • Can find the root (using both Newton and Bisection) of a scalar function with no batching rule and take gradients through the root solve (not possible previously); this tests both JacobianLinearOperator and FunctionLinearOperator in action
  • Can obtain diagonal from JacobianLinearOperator with jac="bwd"

Happy to perform any further testing you see fit/necessary. I appreciate I haven't managed to test reverse-mode especially extensively.

Next steps

In a future PR, I would like to do something similar for other structures (e.g. tridiagonal); this should address the large O(n) discrepancy observed in #149 (but not the O(0) discrepancy). I believe this will be a much more consistent and meaningful gain than observed here. This PR should be a lot easier to grok, and a good place to reason about the concept and discuss framework/design choices (although maybe not the performance impact :) ) before building out further.

@jpbrodrick89 jpbrodrick89 changed the title Support primitives with no/slow JVP batch rule Support "diagonal" primitives with no/slow JVP batch rule Jun 17, 2025
@jpbrodrick89
Contributor Author

I think I prefer this implementation (latest commit) with unravel and ones rather than map and ones_like; it seems more consistent with what is done elsewhere in the codebase (and in my tridiagonal PR) and probably more efficient for complex PyTrees. Performance is essentially identical for my test with a single 1D array.

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Jun 24, 2025

I finally realised why jacobian didn't run consistently slower. This is because the computation was dominated by "transcendental" evaluations rather than MACs/pure FLOPs. The cost of a transcendental evaluation depends on where it is evaluated, and for sine and cosine, evaluating at the zeros of the identity matrix is very, very cheap, meaning that actually computing them did not add much cost and the more aggressive threading of jacobian often outweighed this. If I instead use a transcendental function that is ever so slightly less trivial to evaluate at 0—jnp.exp(-(x - 1.0)**2 / 2.0)—I get much more consistent and convincing results (typically 1.5–2.5x faster for n>1E3 with default jax settings):

image

image

Would it be helpful to re-write the PR message with updated benchmarks and prose in light of this?

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Jun 24, 2025

That reasoning was completely wrong again: the derivative is not evaluated at the tangent vectors, just multiplied by them. I think the reduced FLOPs from avoiding the matrix multiplication are only more noticeable in the example above because jnp.exp is a bit cheaper than jnp.sin. In general, jax.jacobian is very, very efficient for unary functions.

However, it is easy to fool jit by breaking up the function, as it then needs to keep track of which function is applied to which index. This is actually not a very far-fetched scenario if one is solving two independent ODEs, for example.

def myfunc(x):
    halfway = len(x) // 2
    return jnp.concatenate([jnp.sin(x[:halfway]), jnp.cos(x[halfway:])])
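
For what it's worth, a hedged comparison of the two diagonal extractions for this function (the dense path roughly mirrors what the current implementation materialises):

import jax
import jax.numpy as jnp

x = jnp.linspace(-1.0, 1.0, 1000)
dense_diag = jnp.diag(jax.jacfwd(myfunc)(x))                # materialises an n x n Jacobian
_, cheap_diag = jax.jvp(myfunc, (x,), (jnp.ones_like(x),))  # single JVP with a ones vector
assert jnp.allclose(dense_diag, cheap_diag)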

image

Here we see a speedup exceeding 1E4 for array sizes of 4E4, which is huge.

In general I see no significant adverse impact of this and some very pronounced positive impacts in realistic use cases.

Owner

@patrick-kidger patrick-kidger left a comment


Nice, I really like seeing a colouring approach like this!

I will note however that lx.diagonal is documented as "Extracts the diagonal from a linear operator, and returns a vector", which is meant to include extracting the diagonal from nondiagonal operators. I think the implementations you have here should check with is_diagonal to determine whether to dispatch to the new or the old implementation.

@@ -1363,10 +1363,40 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]:

@diagonal.register(MatrixLinearOperator)
@diagonal.register(PyTreeLinearOperator)
Owner


Side note, I think we could make the PyTreeLinearOperator case a little more efficient by calling jnp.diag on each 'diagonal leaf' and concatenating those together. (You don't have to do that here, just spotting it.)

elif operator.jac == "bwd":
    fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args))
    _, vjp_fun = jax.vjp(fn, operator.x)
    diag_as_pytree = vjp_fun(unravel(basis))
Owner


I think this will fail for operators with different input and output structures. (They might still be mathematically square, just split up into a pytree in different ways / with different dtypes.) This needs to be a basis formed from operator.out_structure().

Contributor Author


I just added a test for mismatched input and output structure, and it looks like these are not allowed to be diagonal in the current version of lineax. Shall we leave relaxing that assumption and fixing this for future work? This implementation should work fine as long as the assumption holds.

FAILED tests/test_operator.py::test_is_symmetric[float64] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=float64), ShapeDtypeStruct(shape=(), dtype=float64)) and output structure (ShapeDtypeStruct(shape=(), dty...
FAILED tests/test_operator.py::test_is_symmetric[complex128] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=complex128), ShapeDtypeStruct(shape=(), dtype=complex128)) and output structure (ShapeDtypeStruct(shape=(...
FAILED tests/test_operator.py::test_is_diagonal[float64] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=float64), ShapeDtypeStruct(shape=(), dtype=float64)) and output structure (ShapeDtypeStruct(shape=(), dty...
FAILED tests/test_operator.py::test_is_diagonal[complex128] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=complex128), ShapeDtypeStruct(shape=(), dtype=complex128)) and output structure (ShapeDtypeStruct(shape=(...

Contributor Author


Never mind, these are PyTreeLinearOperators, but I still need to address this for JacobianLinearOperators.

Contributor Author


I actually got exactly the same error with JacobianLinearOperator; if this is not expected I can push the test I added to test/helpers to investigate further.

@_operators_append
def make_nontrivial_jac_operator(getkey, matrix, tags):
    # makes a Jacobian linear operator from matrix with
    # input structure {"array": (in_size - 1,), "scalar": ()}
    # output structure ((), (out_size - 1,))
    out_size, in_size = matrix.shape
    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    fn_tmp = lambda x, _: a + b @ x + c @ x**2
    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
    diff = matrix - jac

    def fn(x, args):
        x_flat = jnp.concatenate([x["array"], x["scalar"][jnp.newaxis]])
        y_flat = a + (b + diff) @ x_flat + c @ x_flat**2
        y = [y_flat[0], y_flat[1:]]
        return y

    return lx.JacobianLinearOperator(fn, {"array": x[:-1], "scalar": x[-1]}, None, tags)

@jpbrodrick89
Contributor Author

Happy to address these during the week, just wanted to check: I did a quick trawl through the code and I don't think diagonal or tridiagonal is ever used by operators that don't have the corresponding tag, so we have the option to just change the docstring instead, if you like? I could certainly imagine situations where the current documented usage could be convenient when testing out the stability of various differencing schemes (e.g. extracting a tridiagonal part to treat implicitly and treating the rest explicitly), but in some cases there should be a more efficient way to express the operator in two parts.

@patrick-kidger
Owner

Happy to address these during the week, just wanted to check: I did a quick trawl through the code and I don't think diagonal or tridiagonal is ever used by operators that don't have the corresponding tag, so we have the option to just change the docstring instead, if you like?

However, they're actually a standalone public API themselves, which could be used independently of the solvers. :)

@jpbrodrick89
Contributor Author

Sorry for abandoning this for so long as the day job took over. I returned because this was found to provide an orders-of-magnitude impact on a problem I was working on (root finding over multiple interpolations), even for small array sizes (200–2000). I have addressed the main point of retaining extraction of the diagonal when the diagonal tag is missing. However, your other two comments seem at odds with each other: either we ensure input/output structures match (which is actually the case, see above), enabling us to extract diagonal leaves for PyTreeLinearOperator and leaving the JacobianLinearOperator vjp implementation as it is, or we relax this assumption and are no longer able to extract diagonal leaves, as there is not guaranteed to be such an object.

@patrick-kidger
Owner

Sorry for the long delay getting back to you, some personal life things took over for a while.

So, now to actually answer your question: good point.

I imagine we could probably do the 'diagonal leaf' approach when the structures match, and go for the more expensive approach when they don't?
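
(A hedged sketch of that 'diagonal leaf' approach for matching structures, using hypothetical square diagonal-block leaves:)

import jax.numpy as jnp

# Hypothetical pytree of square 'diagonal leaves'.
blocks = [jnp.array([[1.0, 0.0], [0.0, 2.0]]), jnp.array([[3.0]])]
diag = jnp.concatenate([jnp.diag(b) for b in blocks])  # -> [1., 2., 3.]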

adconner and others added 5 commits December 5, 2025 01:10
…full_rank (patrick-kidger#158)

The two functions allow_dependent_{rows,columns} together did the job of
answering if the solver accepts full rank matrices for the purposes of
the jvp. Allowing them to be implemented separately created some issues:
1) Invalid states were representable. Eg. What does it mean that
   dependent columns are allowed for square matrices if dependent rows
   are not? What does it mean that dependent rows are not allowed for
   matrices with more rows than columns?
2) As the functions accept operator as input, a custom solver could in
   principle decide its answer based on operator's dynamic value rather
   than only jax compilation static information regarding it, as in all
   the lineax defined solvers. This would prevent jax compilation and
   jit.

Both issues are addressed by asking the solver to report only if it
assumes the input is numerically full rank. If this assumption is
exactly violated, its behavior is allowed to be undefined, and is
allowed to error, produce NaN values, and produce invalid values.
@jpbrodrick89
Contributor Author

No worries, hope you're managing alright. Sorry to miss you at DiffSys, but looking forward to catching up with Johanna!

Just to check I understand correctly, shall we go forward with the status quo that Diagonal (as for Symmetric) operators must ALWAYS have matching input and output structures, and continue to raise ValueErrors when this is violated?

Therefore, we do not need to touch JacobianLinearOperator but just need to adopt the diagonal leaf approach for PyTreeLinearOperators and we're done?

Johanna Haffner and others added 8 commits December 13, 2025 12:16
…verflow

Decrease default value to prevent overflow in 32-bit.
There seem to be some spurious downstream failures in Diffrax with JAX 0.8.2 otherwise. Probably JAX has started promoting these to tracers on some unusual codepath.
@patrick-kidger
Owner

patrick-kidger commented Dec 22, 2025

Just to check I understand correctly, shall we go forward with the status quo that Diagonal (as for Symmetric) operators must ALWAYS have matching input and output structures, and continue to raise ValueErrors when this is violated?

Therefore, we do not need to touch JacobianLinearOperator but just need to adopt the diagonal leaf approach for PyTreeLinearOperators and we're done?

Yup, I think that'd be reasonable! Let me know when this PR is ready and we'll get this merged :)

jpbrodrick89 and others added 8 commits January 30, 2026 00:33
* Add sparse materialisation helper and efficient diagonal paths

This PR introduces _try_sparse_materialise helper and optimizes diagonal
operator handling throughout lineax.

Key changes:
- Add _try_sparse_materialise() that converts diagonal-tagged operators to
  DiagonalLinearOperator, preserving pytree structure via unravel
- Add efficient diagonal() for JLO/FLO using single JVP/VJP with ones basis
- Add efficient diagonal() for Composed: diag(A @ B) = diag(A) * diag(B)
- Simplify mv() for MLO, PTLO, Add, Composed to use _try_sparse_materialise
- Apply early sparse materialisation in materialise() registrations

Aux handling:
- Fix bug: linearise/materialise now preserve aux on AuxLinearOperator
- Preserve aux from first operator in Composed (output comes from op1)
- Inner aux in Add children silently stripped (unclear semantics - may
  warrant guards in future)

---------

Co-authored-by: jpbrodrick89 <jpbrodrick89@users.noreply.github.com>
@jpbrodrick89
Contributor Author

Note this is largely ready to go and includes #195 (but not #196); happy to get those in first and merge this in after for a cleaner diff.

The main change I've made is the _try_sparse_materialise helper, which converts an operator to a DiagonalLinearOperator if it's diagonal, and which can be used in mv or materialise; this should make the code more maintainable and extensible (I shouldn't need to touch materialise in the tridiagonal PR except for JacobianLinearOperator). This is a bit of a breaking change though, as materialise no longer just returns self for as many operators, but instead materialises the "most efficient" representation of the operator. Happy to hear any thoughts. :-)
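
(Roughly, the idea behind the helper is along these lines—a sketch using lineax's public names, not the actual implementation:)

import lineax as lx

def try_sparse_materialise_sketch(operator):
    # If the operator promises it is diagonal, re-express it as a
    # DiagonalLinearOperator; otherwise leave it untouched.
    if lx.is_diagonal(operator):
        return lx.DiagonalLinearOperator(lx.diagonal(operator))
    return operator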

Also happy to target a dev branch if you prefer, as these are not just "fixes"; one just doesn't exist right now.
