Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def as_matrix(self):
return self.matrix

def transpose(self):
if symmetric_tag in self.tags:
if is_symmetric(self):
return self
return MatrixLinearOperator(self.matrix.T, transpose_tags(self.tags))

Expand Down Expand Up @@ -446,7 +446,7 @@ def concat_in(struct, subpytree):
return jnp.concatenate(matrix, axis=0)

def transpose(self):
if symmetric_tag in self.tags:
if is_symmetric(self):
return self

def _transpose(struct, subtree):
Expand Down Expand Up @@ -543,15 +543,21 @@ class JacobianLinearOperator(AbstractLinearOperator):
`MatrixLinearOperator(jax.jacfwd(fn)(x))`.

The Jacobian is not materialised; matrix-vector products, which are in fact
Jacobian-vector products, are computed using autodifferentiation, specifically
`jax.jvp`. Thus, `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
`jax.jvp(fn, (x,), (v,))`.

See also [`lineax.linearise`][], which caches the primal computation, i.e.
it returns `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`
Jacobian-vector products, are computed using autodifferentiation. By default
(or with `jac="fwd"`), `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
`jax.jvp(fn, (x,), (v,))`. For `jac="bwd"`, `jax.vjp` is combined with
`jax.linear_transpose`, which works even with functions
that only define a custom VJP (via `jax.custom_vjp`) and don't support
forward-mode differentiation.

See also [`lineax.materialise`][], which materialises the whole Jacobian in
memory.

!!! tip

For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache
the primal computation. For example, with `jac="fwd"` or `jac=None` it
returns the equivalent of
`_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`.
"""

fn: Callable[
Expand Down Expand Up @@ -618,10 +624,18 @@ def mv(self, vector):
if self.jac == "fwd" or self.jac is None:
_, out = jax.jvp(fn, (self.x,), (vector,))
elif self.jac == "bwd":
jac = jax.jacrev(fn)(self.x)
out = PyTreeLinearOperator(jac, output_structure=self.out_structure()).mv(
vector
)
# Use VJP + linear_transpose instead of materializing full Jacobian.
# This works even for custom_vjp functions that don't have JVP rules.
_, vjp_fn = jax.vjp(fn, self.x)
if is_symmetric(self):
# For symmetric operators, J = J.T, so vjp directly gives J @ v
(out,) = vjp_fn(vector)
else:
# For non-symmetric, transpose the VJP to get J @ v from J.T @ v
transpose_vjp = jax.linear_transpose(
lambda g: vjp_fn(g)[0], self.out_structure()
)
(out,) = transpose_vjp(vector)
else:
raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.")
return out
Expand All @@ -630,7 +644,7 @@ def as_matrix(self):
return materialise(self).as_matrix()

def transpose(self):
if symmetric_tag in self.tags:
if is_symmetric(self):
return self
fn = _NoAuxOut(_NoAuxIn(self.fn, self.args))
# Works because vjpfn is a PyTree
Expand Down Expand Up @@ -697,7 +711,7 @@ def as_matrix(self):
return materialise(self).as_matrix()

def transpose(self):
if symmetric_tag in self.tags:
if is_symmetric(self):
return self
transpose_fn = jax.linear_transpose(self.fn, self.in_structure())

Expand Down Expand Up @@ -1238,8 +1252,21 @@ def _(operator):
@linearise.register(JacobianLinearOperator)
def _(operator):
fn = _NoAuxIn(operator.fn, operator.args)
(_, aux), lin = jax.linearize(fn, operator.x)
lin = _NoAuxOut(lin)
if operator.jac == "bwd":
# For backward mode, use VJP + linear_transpose.
# This works even with custom_vjp functions that don't support forward-mode AD.
_, vjp_fn, aux = jax.vjp(fn, operator.x, has_aux=True)
if is_symmetric(operator):
# For symmetric: J = J.T, so vjp directly gives J @ v
lin = _Unwrap(vjp_fn())
else:
# Transpose the VJP to get J @ v from J.T @ v
lin = _Unwrap(
jax.linear_transpose(lambda g: vjp_fn(g)[0], operator.out_structure())
)
else: # "fwd" or None
(_, aux), lin = jax.linearize(fn, operator.x)
lin = _NoAuxOut(lin)
out = FunctionLinearOperator(lin, operator.in_structure(), operator.tags)
return AuxLinearOperator(out, aux)

Expand Down
50 changes: 50 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,56 @@ def make_jac_operator(getkey, matrix, tags):
return lx.JacobianLinearOperator(fn, x, None, tags)


@_operators_append
def make_jacfwd_operator(getkey, matrix, tags):
    """Build a `jac="fwd"` `JacobianLinearOperator` whose Jacobian at `x` is `matrix`.

    A random quadratic map `a + b @ x + c @ x**2` is drawn, its Jacobian at the
    sampled point `x` is computed, and the linear coefficient is then shifted by
    `matrix - jacobian` so that the final function's Jacobian at `x` equals
    `matrix`.
    """
    n_out, n_in = matrix.shape
    dtype = matrix.dtype

    # The keys are consumed in the order x, a, b, c.
    x = jr.normal(getkey(), (n_in,), dtype=dtype)
    a = jr.normal(getkey(), (n_out,), dtype=dtype)
    b = jr.normal(getkey(), (n_out, n_in), dtype=dtype)
    c = jr.normal(getkey(), (n_out, n_in), dtype=dtype)

    def unshifted(x, _):
        return a + b @ x + c @ x**2.0

    # `holomorphic` is switched on only for complex inputs.
    jac_at_x = jax.jacfwd(unshifted, holomorphic=jnp.iscomplexobj(x))(x, None)
    shift = matrix - jac_at_x

    def fn(x, _):
        # Adding `shift` to the linear term moves the Jacobian at `x` onto `matrix`.
        return a + (b + shift) @ x + c @ x**2

    return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd")


@_operators_append
def make_jacrev_operator(getkey, matrix, tags):
    """Build a `jac="bwd"` `JacobianLinearOperator` backed by a `custom_vjp` function.

    The wrapped function only defines a custom VJP, so forward-mode autodiff is
    NOT available for it. This checks that `jac="bwd"` works without relying on
    JVP. As with the forward-mode helper, the linear coefficient is shifted so
    that the Jacobian at the sampled point `x` equals `matrix`.
    """
    n_out, n_in = matrix.shape
    dtype = matrix.dtype

    # The keys are consumed in the order x, a, b, c.
    x = jr.normal(getkey(), (n_in,), dtype=dtype)
    a = jr.normal(getkey(), (n_out,), dtype=dtype)
    b = jr.normal(getkey(), (n_out, n_in), dtype=dtype)
    c = jr.normal(getkey(), (n_out, n_in), dtype=dtype)

    def unshifted(x, _):
        return a + b @ x + c @ x**2.0

    jac_at_x = jax.jacfwd(unshifted, holomorphic=jnp.iscomplexobj(x))(x, None)
    shift = matrix - jac_at_x

    # Only a reverse-mode rule is registered for this function.
    @jax.custom_vjp
    def custom_fn(x):
        return a + (b + shift) @ x + c @ x**2

    def custom_fn_fwd(x):
        # The primal input is the only residual the backward pass needs.
        return custom_fn(x), x

    def custom_fn_bwd(x_res, g):
        # Jacobian:  J = (b + shift) + 2 * c * x   (x broadcast across the rows of c)
        # Cotangent: J.T @ g = (b + shift).T @ g + 2 * (c.T @ g) * x
        return ((b + shift).T @ g + 2 * (c.T @ g) * x_res,)

    custom_fn.defvjp(custom_fn_fwd, custom_fn_bwd)

    def fn(x, _):
        return custom_fn(x)

    return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd")


@_operators_append
def make_trivial_diagonal_operator(getkey, matrix, tags):
assert tags == lx.diagonal_tag
Expand Down
4 changes: 4 additions & 0 deletions tests/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .helpers import (
make_identity_operator,
make_jacrev_operator,
make_operators,
make_tridiagonal_operator,
make_trivial_diagonal_operator,
Expand All @@ -33,6 +34,9 @@ def test_adjoint(make_operator, dtype, getkey):
tags = ()
in_size = 5
out_size = 3
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
# JacobianLinearOperator does not support complex dtypes when jac="bwd"
return
operator = make_operator(getkey, matrix, tags)
v1, v2 = (
jr.normal(getkey(), (in_size,), dtype=dtype),
Expand Down
44 changes: 36 additions & 8 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .helpers import (
make_identity_operator,
make_jacrev_operator,
make_operators,
make_tridiagonal_operator,
make_trivial_diagonal_operator,
Expand All @@ -45,6 +46,9 @@ def test_ops(make_operator, getkey, dtype):
else:
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
tags = ()
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
# JacobianLinearOperator does not support complex dtypes when jac="bwd"
return
matrix1 = make_operator(getkey, matrix, tags)
matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype))
scalar = jr.normal(getkey(), (), dtype=dtype)
Expand Down Expand Up @@ -137,9 +141,22 @@ def _assert_except_diag(cond_fun, operators, flip_cond):

@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_linearise(dtype, getkey):
operators = _setup(getkey, jr.normal(getkey(), (3, 3), dtype=dtype))
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
operators = list(_setup(getkey, matrix))
vec = jr.normal(getkey(), (3,), dtype=dtype)
for operator in operators:
lx.linearise(operator)
# Skip jacrev operators with complex dtype (jacrev doesn't support complex)
if (
isinstance(operator, lx.JacobianLinearOperator)
and operator.jac == "bwd"
and dtype is jnp.complex128
):
continue
linearised = lx.linearise(operator)
# Actually evaluate the linearised operator to ensure it works
result = linearised.mv(vec)
expected = operator.mv(vec)
assert tree_allclose(result, expected)


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
Expand Down Expand Up @@ -283,7 +300,12 @@ def test_is_tridiagonal(dtype, getkey):
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_tangent_as_matrix(dtype, getkey):
def _list_setup(matrix):
return list(_setup(getkey, matrix))
# Exclude jacrev operator: jac="bwd" uses custom_vjp which doesn't support JVP
return [
op
for op in _setup(getkey, matrix)
if not (isinstance(op, lx.JacobianLinearOperator) and op.jac == "bwd")
]

matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
t_matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
Expand Down Expand Up @@ -421,25 +443,31 @@ def test_zero_pytree_as_matrix(dtype):


def test_jacrev_operator():
# Test that custom_vjp is respected. The custom backward multiplies by 3
# instead of the true derivative (which would be 2).
# This tests that lineax uses the custom_vjp, not the true derivative.
@jax.custom_vjp
def f(x, _):
return dict(foo=x["bar"] + 2)
return dict(foo=x["bar"] * 2) # forward: multiply by 2

def f_fwd(x, _):
return f(x, None), None

def f_bwd(_, g):
return dict(bar=g["foo"] + 5), None
# Custom backward: multiply by 3 (not the true derivative 2)
# This must be linear in g for linear_transpose to work correctly.
return dict(bar=g["foo"] * 3), None

f.defvjp(f_fwd, f_bwd)

x = dict(bar=jnp.arange(2.0))
rev_op = lx.JacobianLinearOperator(f, x, jac="bwd")
as_matrix = jnp.array([[6.0, 5.0], [5.0, 6.0]])
# Jacobian is 3*I (from custom backward, not 2*I from true derivative)
as_matrix = jnp.array([[3.0, 0.0], [0.0, 3.0]])
assert tree_allclose(rev_op.as_matrix(), as_matrix)

y = dict(bar=jnp.arange(2.0) + 1)
true_out = dict(foo=jnp.array([16.0, 17.0]))
y = dict(bar=jnp.arange(2.0) + 1) # y = [1, 2]
true_out = dict(foo=jnp.array([3.0, 6.0])) # 3*I @ [1, 2] = [3, 6]
for op in (rev_op, lx.materialise(rev_op)):
out = op.mv(y)
assert tree_allclose(out, true_out)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_well_posed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from .helpers import (
construct_matrix,
make_jacrev_operator,
ops,
params,
solvers,
Expand All @@ -31,6 +32,9 @@
@pytest.mark.parametrize("ops", ops)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
# JacobianLinearOperator does not support complex dtypes when jac="bwd"
return
if jax.config.jax_enable_x64: # pyright: ignore
tol = 1e-10
else:
Expand Down
Loading