diff --git a/lineax/_operator.py b/lineax/_operator.py index f52953d..69be63d 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1131,11 +1131,15 @@ def mv(self, vector): return self.operator1.mv(self.operator2.mv(vector)) def as_matrix(self): - return jnp.matmul( - self.operator1.as_matrix(), - self.operator2.as_matrix(), - precision=lax.Precision.HIGHEST, # pyright: ignore - ) + if isinstance(self.operator1, IdentityLinearOperator): + return self.operator2.as_matrix() + if isinstance(self.operator2, IdentityLinearOperator): + return self.operator1.as_matrix() + _, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, self.operator1.in_structure()) + def mv_flat(v): + out = self.operator1.mv(unravel(v)) + return jfu.ravel_pytree(out)[0] + return jax.vmap(mv_flat, in_axes=1, out_axes=1)(self.operator2.as_matrix()) def transpose(self): return self.operator2.transpose() @ self.operator1.transpose() @@ -1909,7 +1913,13 @@ def _(operator): @materialise.register(ComposedLinearOperator) def _(operator): - return materialise(operator.operator1) @ materialise(operator.operator2) + op1 = materialise(operator.operator1) + op2 = materialise(operator.operator2) + if isinstance(op1, IdentityLinearOperator): + return op2 + if isinstance(op2, IdentityLinearOperator): + return op1 + return op1 @ op2 @diagonal.register(ComposedLinearOperator)