From da6af45a6ebac1c34c35983fb459c5ba5beb140d Mon Sep 17 00:00:00 2001 From: jpbrodrick89 Date: Sat, 31 Jan 2026 11:57:26 +0000 Subject: [PATCH 1/2] Optimize ComposedLinearOperator materialisation and as_matrix - Add identity shortcuts: Identity @ X returns X, X @ Identity returns X - Use vmap over operator1.mv instead of matmul for as_matrix, enabling efficient composition when operator1 has O(N) mv (e.g., Diagonal) https://claude.ai/code/session_0143xm3Fot5bh7Zy3GfkP6bD --- lineax/_operator.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index f52953d..69be63d 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1131,11 +1131,15 @@ def mv(self, vector): return self.operator1.mv(self.operator2.mv(vector)) def as_matrix(self): - return jnp.matmul( - self.operator1.as_matrix(), - self.operator2.as_matrix(), - precision=lax.Precision.HIGHEST, # pyright: ignore - ) + if isinstance(self.operator1, IdentityLinearOperator): + return self.operator2.as_matrix() + if isinstance(self.operator2, IdentityLinearOperator): + return self.operator1.as_matrix() + _, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, self.operator1.in_structure()) + def mv_flat(v): + out = self.operator1.mv(unravel(v)) + return jfu.ravel_pytree(out)[0] + return jax.vmap(mv_flat, in_axes=1, out_axes=1)(self.operator2.as_matrix()) def transpose(self): return self.operator2.transpose() @ self.operator1.transpose() @@ -1909,7 +1913,13 @@ def _(operator): @materialise.register(ComposedLinearOperator) def _(operator): - return materialise(operator.operator1) @ materialise(operator.operator2) + op1 = materialise(operator.operator1) + op2 = materialise(operator.operator2) + if isinstance(op1, IdentityLinearOperator): + return op2 + if isinstance(op2, IdentityLinearOperator): + return op1 + return op1 @ op2 @diagonal.register(ComposedLinearOperator) From b13d829bdaa461242278ee8914e5e652c5d3d33d Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 2 Feb 2026 12:14:46 +0000 Subject: [PATCH 2/2] run precommit --- lineax/_operator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index d01f1f1..fc05c23 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1152,10 +1152,14 @@ def as_matrix(self): return self.operator2.as_matrix() if isinstance(self.operator2, IdentityLinearOperator): return self.operator1.as_matrix() - _, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, self.operator1.in_structure()) + _, unravel = eqx.filter_eval_shape( + jfu.ravel_pytree, self.operator1.in_structure() + ) + def mv_flat(v): out = self.operator1.mv(unravel(v)) return jfu.ravel_pytree(out)[0] + return jax.vmap(mv_flat, in_axes=1, out_axes=1)(self.operator2.as_matrix()) def transpose(self):