diff --git a/lineax/_operator.py b/lineax/_operator.py index b430a2f..05dfba0 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -262,6 +262,9 @@ def __init__( self.tags = _frozenset(tags) def mv(self, vector): + maybe_sparse_op = _try_sparse_materialise(self) + if maybe_sparse_op is not self: + return maybe_sparse_op.mv(vector) return jnp.matmul(self.matrix, vector, precision=lax.Precision.HIGHEST) def as_matrix(self): @@ -326,6 +329,14 @@ def __init__(self, value): self.value = value +def _leaf_from_keypath(pytree: PyTree, keypath: jtu.KeyPath) -> Array: + """Extract the leaf from a pytree at the given keypath.""" + for path, leaf in jtu.tree_leaves_with_path(pytree): + if path == keypath: + return leaf + raise ValueError(f"Leaf not found at keypath {keypath}") + + # The `{input,output}_structure`s have to be static because otherwise abstract # evaluation rules will promote them to ShapedArrays. class PyTreeLinearOperator(AbstractLinearOperator): @@ -421,7 +432,11 @@ def mv(self, vector): # vector has structure [tree(in), leaf(in)] # self.out_structure() has structure [tree(out)] # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)] - # return has struture [tree(out), leaf(out)] + # return has structure [tree(out), leaf(out)] + maybe_sparse_op = _try_sparse_materialise(self) + if maybe_sparse_op is not self: + return maybe_sparse_op.mv(vector) + def matmul(_, matrix): return _tree_matmul(matrix, vector) @@ -1011,6 +1026,9 @@ def __check_init__(self): raise ValueError("Incompatible linear operator structures") def mv(self, vector): + maybe_sparse_op = _try_sparse_materialise(self) + if maybe_sparse_op is not self: + return maybe_sparse_op.mv(vector) mv1 = self.operator1.mv(vector) mv2 = self.operator2.mv(vector) return (mv1**ω + mv2**ω).ω @@ -1145,6 +1163,9 @@ def __check_init__(self): raise ValueError("Incompatible linear operator structures") def mv(self, vector): + maybe_sparse_op = _try_sparse_materialise(self) + if maybe_sparse_op is not self: + return maybe_sparse_op.mv(vector) return self.operator1.mv(self.operator2.mv(vector)) def as_matrix(self): @@ -1330,8 +1351,35 @@ def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator: _default_not_implemented("materialise", operator) +def _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator: + """Try to materialise to a sparse operator. + + Returns a DiagonalLinearOperator if the operator is tagged as diagonal, + otherwise returns the original operator unchanged. The resulting operator + preserves the input/output structure of the original operator. + + Note: This function silently strips aux and as such should not be called + on AuxLinearOperator or JacobianLinearOperatory directly. + """ + if is_diagonal(operator): + diag_flat = diagonal(operator) + _, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + diag_pytree = unravel(diag_flat) + return DiagonalLinearOperator(diag_pytree) + return operator + + +def _construct_diagonal_basis(structure: PyTree[jax.ShapeDtypeStruct]) -> PyTree[Array]: + """Construct a PyTree of ones matching the given structure.""" + return jtu.tree_map(lambda s: jnp.ones(s.shape, s.dtype), structure) + + @materialise.register(MatrixLinearOperator) @materialise.register(PyTreeLinearOperator) +def _(operator): + return _try_sparse_materialise(operator) + + @materialise.register(IdentityLinearOperator) @materialise.register(DiagonalLinearOperator) @materialise.register(TridiagonalLinearOperator) @@ -1342,6 +1390,21 @@ def _(operator): @materialise.register(JacobianLinearOperator) def _(operator): fn = _NoAuxIn(operator.fn, operator.args) + # can't use try_sparse_materialise as strips aux + if is_diagonal(operator): + with jax.ensure_compile_time_eval(): + basis = _construct_diagonal_basis(operator.in_structure()) + if operator.jac == "bwd": + ((_, aux), vjp_fn) = jax.vjp(fn, operator.x) + if aux is None: + aux_ct = None + else: + aux_ct = jtu.tree_map(jnp.zeros_like, aux) + (diag_as_pytree,) = vjp_fn((basis, aux_ct)) + else: # "fwd" or None + (_, aux), (diag_as_pytree, _) = jax.jvp(fn, (operator.x,), (basis,)) + out = DiagonalLinearOperator(diag_as_pytree) + return AuxLinearOperator(out, aux) jac, aux = jacobian( fn, operator.in_size(), @@ -1356,6 +1419,9 @@ def _(operator): @materialise.register(FunctionLinearOperator) def _(operator): + out = _try_sparse_materialise(operator) + if out is not operator: + return out flat, unravel = strip_weak_dtype( eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) ) @@ -1397,11 +1463,36 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]: @diagonal.register(MatrixLinearOperator) +def _(operator): + return jnp.diag(operator.as_matrix()) + + @diagonal.register(PyTreeLinearOperator) +def _(operator): + if is_diagonal(operator): + + def extract_diag(keypath, struct, subpytree): + block = _leaf_from_keypath(subpytree, keypath) + return jnp.diag(block.reshape(struct.size, struct.size)) + + diags = jtu.tree_map_with_path( + extract_diag, operator.out_structure(), operator.pytree + ) + return jnp.concatenate(jtu.tree_leaves(diags)) + else: + return jnp.diag(operator.as_matrix()) + + @diagonal.register(JacobianLinearOperator) @diagonal.register(FunctionLinearOperator) def _(operator): - return jnp.diag(operator.as_matrix()) + if is_diagonal(operator): + with jax.ensure_compile_time_eval(): + basis = _construct_diagonal_basis(operator.in_structure()) + diag_as_pytree = operator.mv(basis) + diag, _ = jfu.ravel_pytree(diag_as_pytree) + return diag + return diagonal(materialise(operator)) @diagonal.register(DiagonalLinearOperator) @@ -1863,9 +1954,27 @@ def _(operator, transform=transform): def _(operator, transform=transform): return transform(operator.operator) / operator.scalar - @transform.register(AuxLinearOperator) # pyright: ignore + +# diagonal strips aux (returns array, not operator) +@diagonal.register(AuxLinearOperator) +def _(operator): + return diagonal(operator.operator) + + +# linearise and materialise preserve aux +for transform in (linearise, materialise): + + @transform.register(AuxLinearOperator) def _(operator, transform=transform): - return transform(operator.operator) + return AuxLinearOperator(transform(operator.operator), operator.aux) + + +@materialise.register(AddLinearOperator) +def _(operator): + out = _try_sparse_materialise(operator) + if out is not operator: + return out + return materialise(operator.operator1) + materialise(operator.operator2) @linearise.register(TangentLinearOperator) @@ -1934,16 +2043,31 @@ def _(operator): @linearise.register(ComposedLinearOperator) def _(operator): + # If the first operator has aux, preserve it on the result + if isinstance(operator.operator1, AuxLinearOperator): + aux = operator.operator1.aux + inner_composed = operator.operator1.operator @ operator.operator2 + return AuxLinearOperator(linearise(inner_composed), aux) return linearise(operator.operator1) @ linearise(operator.operator2) @materialise.register(ComposedLinearOperator) def _(operator): + # If the first operator has aux, preserve it on the result + if isinstance(operator.operator1, AuxLinearOperator): + aux = operator.operator1.aux + inner_composed = operator.operator1.operator @ operator.operator2 + return AuxLinearOperator(materialise(inner_composed), aux) + out = _try_sparse_materialise(operator) + if out is not operator: + return out return materialise(operator.operator1) @ materialise(operator.operator2) @diagonal.register(ComposedLinearOperator) def _(operator): + if is_diagonal(operator): + return diagonal(operator.operator1) * diagonal(operator.operator2) return jnp.diag(operator.as_matrix()) diff --git a/tests/helpers.py b/tests/helpers.py index 73c1633..4259b85 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -216,7 +216,7 @@ def make_jacfwd_operator(getkey, matrix, tags): a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) - fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + fn_tmp = lambda x, _: a + b @ x + c @ x**2 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac fn = lambda x, _: a + (b + diff) @ x + c @ x**2 @@ -235,7 +235,7 @@ def make_jacrev_operator(getkey, matrix, tags): a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) - fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + fn_tmp = lambda x, _: a + b @ x + c @ x**2 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac diff --git a/tests/test_operator.py b/tests/test_operator.py index dae9562..eb08162 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -177,9 +177,17 @@ def test_materialise_large(dtype, getkey): def test_diagonal(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) matrix_diag = jnp.diag(matrix) + # test we properly extract diagonal from a dense matrix when not tagged operators = _setup(getkey, matrix) for operator in operators: assert jnp.allclose(lx.diagonal(operator), matrix_diag) + # test we properly extract diagonal from diagonal matrix when tagged + operators = _setup(getkey, jnp.diag(matrix_diag), lx.diagonal_tag) + for operator in operators: + if isinstance(operator, lx.IdentityLinearOperator): + assert jnp.allclose(lx.diagonal(operator), jnp.ones(3)) + else: + assert jnp.allclose(lx.diagonal(operator), matrix_diag) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))