diff --git a/lineax/_operator.py b/lineax/_operator.py index b430a2f..fc05c23 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1148,12 +1148,20 @@ def mv(self, vector): return self.operator1.mv(self.operator2.mv(vector)) def as_matrix(self): - return jnp.matmul( - self.operator1.as_matrix(), - self.operator2.as_matrix(), - precision=lax.Precision.HIGHEST, # pyright: ignore + if isinstance(self.operator1, IdentityLinearOperator): + return self.operator2.as_matrix() + if isinstance(self.operator2, IdentityLinearOperator): + return self.operator1.as_matrix() + _, unravel = eqx.filter_eval_shape( + jfu.ravel_pytree, self.operator1.in_structure() ) + def mv_flat(v): + out = self.operator1.mv(unravel(v)) + return jfu.ravel_pytree(out)[0] + + return jax.vmap(mv_flat, in_axes=1, out_axes=1)(self.operator2.as_matrix()) + def transpose(self): return self.operator2.transpose() @ self.operator1.transpose() @@ -1939,7 +1947,13 @@ def _(operator): @materialise.register(ComposedLinearOperator) def _(operator): - return materialise(operator.operator1) @ materialise(operator.operator2) + op1 = materialise(operator.operator1) + op2 = materialise(operator.operator2) + if isinstance(op1, IdentityLinearOperator): + return op2 + if isinstance(op2, IdentityLinearOperator): + return op1 + return op1 @ op2 @diagonal.register(ComposedLinearOperator)