diff --git a/lineax/_operator.py b/lineax/_operator.py
index a83eeff..c0ee7c9 100644
--- a/lineax/_operator.py
+++ b/lineax/_operator.py
@@ -267,7 +267,7 @@ def as_matrix(self):
         return self.matrix
 
     def transpose(self):
-        if symmetric_tag in self.tags:
+        if is_symmetric(self):
             return self
         return MatrixLinearOperator(self.matrix.T, transpose_tags(self.tags))
@@ -446,7 +446,7 @@ def concat_in(struct, subpytree):
         return jnp.concatenate(matrix, axis=0)
 
     def transpose(self):
-        if symmetric_tag in self.tags:
+        if is_symmetric(self):
             return self
 
         def _transpose(struct, subtree):
@@ -543,15 +543,21 @@ class JacobianLinearOperator(AbstractLinearOperator):
     `MatrixLinearOperator(jax.jacfwd(fn)(x))`.
 
     The Jacobian is not materialised; matrix-vector products, which are in fact
-    Jacobian-vector products, are computed using autodifferentiation, specifically
-    `jax.jvp`. Thus, `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
-    `jax.jvp(fn, (x,), (v,))`.
-
-    See also [`lineax.linearise`][], which caches the primal computation, i.e.
-    it returns `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`
+    Jacobian-vector products, are computed using autodifferentiation. By default
+    (or with `jac="fwd"`), `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
+    `jax.jvp(fn, (x,), (v,))`. For `jac="bwd"`, `jax.vjp` is combined with
+    `jax.linear_transpose`, which works even for functions that only define a
+    custom VJP (via `jax.custom_vjp`) and so do not support forward-mode
+    differentiation.
 
     See also [`lineax.materialise`][], which materialises the whole Jacobian in
     memory.
+
+    !!! tip
+
+        For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache
+        the primal computation, e.g. for `jac="fwd"` or `jac=None` it returns
+        `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`
     """
 
     fn: Callable[
@@ -618,10 +624,18 @@ def mv(self, vector):
         if self.jac == "fwd" or self.jac is None:
             _, out = jax.jvp(fn, (self.x,), (vector,))
         elif self.jac == "bwd":
-            jac = jax.jacrev(fn)(self.x)
-            out = PyTreeLinearOperator(jac, output_structure=self.out_structure()).mv(
-                vector
-            )
+            # Use VJP + linear_transpose rather than materialising the full Jacobian.
+            # This works even for custom_vjp functions that have no JVP rule.
+            _, vjp_fn = jax.vjp(fn, self.x)
+            if is_symmetric(self):
+                # For symmetric operators J = J^T, so the VJP directly gives J @ v.
+                (out,) = vjp_fn(vector)
+            else:
+                # Otherwise, transpose the VJP to recover J @ v from v -> J^T @ v.
+                transpose_vjp = jax.linear_transpose(
+                    lambda g: vjp_fn(g)[0], self.out_structure()
+                )
+                (out,) = transpose_vjp(vector)
         else:
             raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.")
         return out
@@ -630,7 +644,7 @@ def as_matrix(self):
         return materialise(self).as_matrix()
 
     def transpose(self):
-        if symmetric_tag in self.tags:
+        if is_symmetric(self):
             return self
         fn = _NoAuxOut(_NoAuxIn(self.fn, self.args))
         # Works because vjpfn is a PyTree
@@ -697,7 +711,7 @@ def as_matrix(self):
         return materialise(self).as_matrix()
 
     def transpose(self):
-        if symmetric_tag in self.tags:
+        if is_symmetric(self):
             return self
         transpose_fn = jax.linear_transpose(self.fn, self.in_structure())
@@ -1238,8 +1252,21 @@ def _(operator):
 @linearise.register(JacobianLinearOperator)
 def _(operator):
     fn = _NoAuxIn(operator.fn, operator.args)
-    (_, aux), lin = jax.linearize(fn, operator.x)
-    lin = _NoAuxOut(lin)
+    if operator.jac == "bwd":
+        # For backward mode, use VJP + linear_transpose. This works even with
+        # custom_vjp functions that don't support forward-mode AD.
+        _, vjp_fn, aux = jax.vjp(fn, operator.x, has_aux=True)
+        if is_symmetric(operator):
+            # For symmetric operators J = J^T, so the VJP directly computes J @ v.
+            lin = _Unwrap(vjp_fn)
+        else:
+            # Transpose the VJP to recover J @ v from v -> J^T @ v.
+            lin = _Unwrap(
+                jax.linear_transpose(lambda g: vjp_fn(g)[0], operator.out_structure())
+            )
+    else:  # "fwd" or None
+        (_, aux), lin = jax.linearize(fn, operator.x)
+        lin = _NoAuxOut(lin)
     out = FunctionLinearOperator(lin, operator.in_structure(), operator.tags)
     return AuxLinearOperator(out, aux)
diff --git a/tests/helpers.py b/tests/helpers.py
index 640f0e0..73c1633 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -209,6 +209,56 @@ def make_jac_operator(getkey, matrix, tags):
     return lx.JacobianLinearOperator(fn, x, None, tags)
 
 
+@_operators_append
+def make_jacfwd_operator(getkey, matrix, tags):
+    out_size, in_size = matrix.shape
+    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
+    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
+    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
+    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
+    fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
+    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
+    diff = matrix - jac
+    fn = lambda x, _: a + (b + diff) @ x + c @ x**2
+    return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd")
+
+
+@_operators_append
+def make_jacrev_operator(getkey, matrix, tags):
+    """JacobianLinearOperator with jac='bwd', built from a custom_vjp function.
+
+    Using custom_vjp means forward-mode autodiff is NOT available, which checks
+    that jac='bwd' works correctly without relying on a JVP rule.
+    """
+    out_size, in_size = matrix.shape
+    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
+    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
+    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
+    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
+    fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
+    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
+    diff = matrix - jac
+
+    # Use custom_vjp to define a function that only supports reverse-mode autodiff.
+    @jax.custom_vjp
+    def custom_fn(x):
+        return a + (b + diff) @ x + c @ x**2
+
+    def custom_fn_fwd(x):
+        return custom_fn(x), x
+
+    def custom_fn_bwd(x, g):
+        # The Jacobian is J = (b + diff) + 2 * c * diag(x).
+        # The VJP therefore returns
+        # J.T @ g = (b + diff).T @ g + 2 * x * (c.T @ g).
+        return ((b + diff).T @ g + 2 * (c.T @ g) * x,)
+
+    custom_fn.defvjp(custom_fn_fwd, custom_fn_bwd)
+
+    fn = lambda x, _: custom_fn(x)
+    return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd")
+
+
 @_operators_append
 def make_trivial_diagonal_operator(getkey, matrix, tags):
     assert tags == lx.diagonal_tag
diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py
index 8e693b1..5c0ff2f 100644
--- a/tests/test_adjoint.py
+++ b/tests/test_adjoint.py
@@ -7,6 +7,7 @@
 from .helpers import (
     make_identity_operator,
+    make_jacrev_operator,
     make_operators,
     make_tridiagonal_operator,
     make_trivial_diagonal_operator,
@@ -33,6 +34,9 @@ def test_adjoint(make_operator, dtype, getkey):
         tags = ()
         in_size = 5
         out_size = 3
+    if make_operator is make_jacrev_operator and dtype is jnp.complex128:
+        # JacobianLinearOperator does not support complex dtypes when jac="bwd".
+        return
     operator = make_operator(getkey, matrix, tags)
     v1, v2 = (
         jr.normal(getkey(), (in_size,), dtype=dtype),
diff --git a/tests/test_operator.py b/tests/test_operator.py
index 26737c2..dae9562 100644
--- a/tests/test_operator.py
+++ b/tests/test_operator.py
@@ -23,6 +23,7 @@
 from .helpers import (
     make_identity_operator,
+    make_jacrev_operator,
     make_operators,
     make_tridiagonal_operator,
     make_trivial_diagonal_operator,
@@ -45,6 +46,9 @@ def test_ops(make_operator, getkey, dtype):
     else:
         matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
         tags = ()
+    if make_operator is make_jacrev_operator and dtype is jnp.complex128:
+        # JacobianLinearOperator does not support complex dtypes when jac="bwd".
+        return
     matrix1 = make_operator(getkey, matrix, tags)
     matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype))
     scalar = jr.normal(getkey(), (), dtype=dtype)
@@ -137,9 +141,22 @@ def _assert_except_diag(cond_fun, operators, flip_cond):
 
 @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
 def test_linearise(dtype, getkey):
-    operators = _setup(getkey, jr.normal(getkey(), (3, 3), dtype=dtype))
+    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
+    operators = list(_setup(getkey, matrix))
+    vec = jr.normal(getkey(), (3,), dtype=dtype)
     for operator in operators:
-        lx.linearise(operator)
+        # Skip jac="bwd" operators with complex dtype: complex is unsupported there.
+        if (
+            isinstance(operator, lx.JacobianLinearOperator)
+            and operator.jac == "bwd"
+            and dtype is jnp.complex128
+        ):
+            continue
+        linearised = lx.linearise(operator)
+        # Actually evaluate the linearised operator to check that it works.
+        result = linearised.mv(vec)
+        expected = operator.mv(vec)
+        assert tree_allclose(result, expected)
 
 
 @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
@@ -283,7 +300,12 @@ def test_is_tridiagonal(dtype, getkey):
 @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
 def test_tangent_as_matrix(dtype, getkey):
     def _list_setup(matrix):
-        return list(_setup(getkey, matrix))
+        # Exclude jac="bwd" operators: they use custom_vjp, which has no JVP rule.
+        return [
+            op
+            for op in _setup(getkey, matrix)
+            if not (isinstance(op, lx.JacobianLinearOperator) and op.jac == "bwd")
+        ]
 
     matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
     t_matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
@@ -421,25 +443,31 @@ def test_zero_pytree_as_matrix(dtype):
 
 
 def test_jacrev_operator():
+    # Check that the custom_vjp rule is respected: the custom backward multiplies
+    # by 3, not by the true derivative of the forward pass (which would be 2), so
+    # lineax must use the custom VJP rather than differentiating the forward pass.
     @jax.custom_vjp
     def f(x, _):
-        return dict(foo=x["bar"] + 2)
+        return dict(foo=x["bar"] * 2)  # forward: multiply by 2
 
     def f_fwd(x, _):
         return f(x, None), None
 
     def f_bwd(_, g):
+        # Custom backward: multiply by 3 (not the true derivative, 2).
+        # It must be linear in g for linear_transpose to work correctly.
+        return dict(bar=g["foo"] * 3), None
 
     f.defvjp(f_fwd, f_bwd)
 
     x = dict(bar=jnp.arange(2.0))
     rev_op = lx.JacobianLinearOperator(f, x, jac="bwd")
-    as_matrix = jnp.array([[6.0, 5.0], [5.0, 6.0]])
+    # The Jacobian is 3*I (from the custom backward), not 2*I (the true derivative).
+    as_matrix = jnp.array([[3.0, 0.0], [0.0, 3.0]])
     assert tree_allclose(rev_op.as_matrix(), as_matrix)
 
-    y = dict(bar=jnp.arange(2.0) + 1)
-    true_out = dict(foo=jnp.array([16.0, 17.0]))
+    y = dict(bar=jnp.arange(2.0) + 1)  # y = [1, 2]
+    true_out = dict(foo=jnp.array([3.0, 6.0]))  # 3*I @ [1, 2] = [3, 6]
     for op in (rev_op, lx.materialise(rev_op)):
         out = op.mv(y)
         assert tree_allclose(out, true_out)
diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py
index 631ae68..4686ce8 100644
--- a/tests/test_well_posed.py
+++ b/tests/test_well_posed.py
@@ -20,6 +20,7 @@
 from .helpers import (
     construct_matrix,
+    make_jacrev_operator,
     ops,
     params,
     solvers,
@@ -31,6 +32,9 @@
 @pytest.mark.parametrize("ops", ops)
 @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
 def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):
+    if make_operator is make_jacrev_operator and dtype is jnp.complex128:
+        # JacobianLinearOperator does not support complex dtypes when jac="bwd".
+        return
     if jax.config.jax_enable_x64:  # pyright: ignore
         tol = 1e-10
     else:
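
Note (illustration only, not part of the patch above): the sketch below shows the core trick that `jac="bwd"` now relies on, namely recovering a Jacobian-vector product `J @ v` from a function that only defines a reverse-mode rule, by composing `jax.vjp` (which gives `g -> J^T @ g`) with `jax.linear_transpose`. The function `f` and the values of `x` and `v` here are made up for illustration.

import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    # Forward pass; no JVP rule is ever defined for this function.
    return 2.0 * jnp.tanh(x)


def f_fwd(x):
    return f(x), x


def f_bwd(x, g):
    # Reverse-mode rule only: returns J^T @ g (here J is diagonal).
    return (2.0 * (1.0 - jnp.tanh(x) ** 2) * g,)


f.defvjp(f_fwd, f_bwd)

x = jnp.array([0.1, 0.2, 0.3])
v = jnp.array([1.0, -1.0, 0.5])

# jax.jvp(f, (x,), (v,)) would fail here, because f only has a custom VJP.
_, vjp_fn = jax.vjp(f, x)           # g -> J^T @ g
out_struct = jax.eval_shape(f, x)   # structure of f's output
jv_fn = jax.linear_transpose(lambda g: vjp_fn(g)[0], out_struct)
(jv,) = jv_fn(v)                    # J @ v, obtained by transposing the VJP

expected = 2.0 * (1.0 - jnp.tanh(x) ** 2) * v  # J is diagonal, so J @ v elementwise
assert jnp.allclose(jv, expected)

With the patch applied, this is essentially what `lx.JacobianLinearOperator(lambda y, _: f(y), x, jac="bwd").mv(v)` computes internally, modulo the symmetric-operator shortcut and pytree handling.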