From aae3448b946b0bb83f1cd758274b6694d56fe8ab Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 28 Jan 2026 01:18:01 +0000 Subject: [PATCH 1/7] add failing jacrev custom vjp linearise tests --- tests/helpers.py | 50 ++++++++++++++++++++++++++++++++++++++++ tests/test_adjoint.py | 5 ++++ tests/test_operator.py | 19 ++++++++++++--- tests/test_well_posed.py | 5 ++++ 4 files changed, 76 insertions(+), 3 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 640f0e0..73c1633 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -209,6 +209,56 @@ def make_jac_operator(getkey, matrix, tags): return lx.JacobianLinearOperator(fn, x, None, tags) +@_operators_append +def make_jacfwd_operator(getkey, matrix, tags): + out_size, in_size = matrix.shape + x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) + a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) + b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) + diff = matrix - jac + fn = lambda x, _: a + (b + diff) @ x + c @ x**2 + return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd") + + +@_operators_append +def make_jacrev_operator(getkey, matrix, tags): + """JacobianLinearOperator with jac='bwd' using a custom_vjp function. + + This uses custom_vjp so that forward-mode autodiff is NOT available, + which tests that jac='bwd' works correctly without relying on JVP. + """ + out_size, in_size = matrix.shape + x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) + a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) + b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) + diff = matrix - jac + + # Use custom_vjp to define a function that only has reverse-mode autodiff + @jax.custom_vjp + def custom_fn(x): + return a + (b + diff) @ x + c @ x**2 + + def custom_fn_fwd(x): + return custom_fn(x), x + + def custom_fn_bwd(x, g): + # Jacobian is: (b + diff) + 2 * c * x + # VJP is: g @ J = g @ ((b + diff) + 2 * c * x) + # So J.T @ g = + return ((b + diff).T @ g + 2 * (c.T @ g) * x,) + + custom_fn.defvjp(custom_fn_fwd, custom_fn_bwd) + + fn = lambda x, _: custom_fn(x) + return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd") + + @_operators_append def make_trivial_diagonal_operator(getkey, matrix, tags): assert tags == lx.diagonal_tag diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 8e693b1..d4d3efd 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -7,6 +7,7 @@ from .helpers import ( make_identity_operator, + make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, @@ -33,6 +34,10 @@ def test_adjoint(make_operator, dtype, getkey): tags = () in_size = 5 out_size = 3 + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) operator = make_operator(getkey, matrix, tags) v1, v2 = ( jr.normal(getkey(), (in_size,), dtype=dtype), diff --git a/tests/test_operator.py b/tests/test_operator.py index 26737c2..1c1e64a 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -23,6 +23,7 @@ from .helpers import ( make_identity_operator, + make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, @@ -45,6 +46,10 @@ def test_ops(make_operator, getkey, dtype): else: matrix = jr.normal(getkey(), (3, 3), dtype=dtype) tags = () + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) matrix1 = make_operator(getkey, matrix, tags) matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype)) scalar = jr.normal(getkey(), (), dtype=dtype) @@ -137,9 +142,15 @@ def _assert_except_diag(cond_fun, operators, flip_cond): @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_linearise(dtype, getkey): - operators = _setup(getkey, jr.normal(getkey(), (3, 3), dtype=dtype)) + matrix = jr.normal(getkey(), (3, 3), dtype=dtype) + operators = _setup(getkey, matrix) + vec = jr.normal(getkey(), (3,), dtype=dtype) for operator in operators: - lx.linearise(operator) + linearised = lx.linearise(operator) + # Actually evaluate the linearised operator to ensure it works + result = linearised.mv(vec) + expected = operator.mv(vec) + assert tree_allclose(result, expected) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) @@ -440,7 +451,7 @@ def f_bwd(_, g): y = dict(bar=jnp.arange(2.0) + 1) true_out = dict(foo=jnp.array([16.0, 17.0])) - for op in (rev_op, lx.materialise(rev_op)): + for op in (rev_op, lx.linearise(rev_op), lx.materialise(rev_op)): out = op.mv(y) assert tree_allclose(out, true_out) @@ -449,3 +460,5 @@ def f_bwd(_, g): fwd_op.mv(y) with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): lx.materialise(fwd_op) + with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): + lx.linearise(fwd_op).mv(y) diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py index 631ae68..8c8b7ec 100644 --- a/tests/test_well_posed.py +++ b/tests/test_well_posed.py @@ -20,6 +20,7 @@ from .helpers import ( construct_matrix, + make_jacrev_operator, ops, params, solvers, @@ -31,6 +32,10 @@ @pytest.mark.parametrize("ops", ops) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: From 03f3302e1bf585304bf9119e704490aca260e0fc Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 28 Jan 2026 01:42:07 +0000 Subject: [PATCH 2/7] fix broken JacobianLinearOperator linearise --- lineax/_operator.py | 31 +++++++++++++++++++++++++++---- tests/test_operator.py | 13 +++++++++---- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index a83eeff..39d9d36 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1238,10 +1238,33 @@ def _(operator): @linearise.register(JacobianLinearOperator) def _(operator): fn = _NoAuxIn(operator.fn, operator.args) - (_, aux), lin = jax.linearize(fn, operator.x) - lin = _NoAuxOut(lin) - out = FunctionLinearOperator(lin, operator.in_structure(), operator.tags) - return AuxLinearOperator(out, aux) + if operator.jac == "bwd": + # For backward mode, use VJP + linear_transpose. + # This works with custom_vjp functions that don't support forward-mode. + _, vjp_fn, aux = jax.vjp(fn, operator.x, has_aux=True) + if symmetric_tag in operator.tags: + # For symmetric: J = J.T, so vjp directly gives J @ v + out = FunctionLinearOperator( + _Unwrap(vjp_fn), operator.in_structure(), operator.tags + ) + else: + # Transpose the VJP to get J @ v from J.T @ v + transpose_vjp = jax.linear_transpose( + lambda g: vjp_fn(g)[0], operator.out_structure() + ) + + def mv_fn(v): + (out,) = transpose_vjp(v) + return out + + out = FunctionLinearOperator(mv_fn, operator.in_structure(), operator.tags) + return AuxLinearOperator(out, aux) + else: + # Original implementation for fwd/None + (_, aux), lin = jax.linearize(fn, operator.x) + lin = _NoAuxOut(lin) + out = FunctionLinearOperator(lin, operator.in_structure(), operator.tags) + return AuxLinearOperator(out, aux) # materialise diff --git a/tests/test_operator.py b/tests/test_operator.py index 1c1e64a..24d2ae1 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -143,9 +143,16 @@ def _assert_except_diag(cond_fun, operators, flip_cond): @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_linearise(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) - operators = _setup(getkey, matrix) + operators = list(_setup(getkey, matrix)) vec = jr.normal(getkey(), (3,), dtype=dtype) for operator in operators: + # Skip jacrev operators with complex dtype (jacrev doesn't support complex) + if ( + isinstance(operator, lx.JacobianLinearOperator) + and operator.jac == "bwd" + and dtype is jnp.complex128 + ): + continue linearised = lx.linearise(operator) # Actually evaluate the linearised operator to ensure it works result = linearised.mv(vec) @@ -451,7 +458,7 @@ def f_bwd(_, g): y = dict(bar=jnp.arange(2.0) + 1) true_out = dict(foo=jnp.array([16.0, 17.0])) - for op in (rev_op, lx.linearise(rev_op), lx.materialise(rev_op)): + for op in (rev_op, lx.materialise(rev_op)): out = op.mv(y) assert tree_allclose(out, true_out) @@ -460,5 +467,3 @@ def f_bwd(_, g): fwd_op.mv(y) with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): lx.materialise(fwd_op) - with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): - lx.linearise(fwd_op).mv(y) From ee3f6b4618c3a2371280c5e7fadf655797533207 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 28 Jan 2026 16:18:30 +0000 Subject: [PATCH 3/7] optimise mv --- lineax/_operator.py | 33 +++++++++++++++++++++++---------- tests/test_operator.py | 16 +++++++++++----- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 39d9d36..53f8dce 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -543,15 +543,20 @@ class JacobianLinearOperator(AbstractLinearOperator): `MatrixLinearOperator(jax.jacfwd(fn)(x))`. The Jacobian is not materialised; matrix-vector products, which are in fact - Jacobian-vector products, are computed using autodifferentiation, specifically - `jax.jvp`. Thus, `JacobianLinearOperator(fn, x).mv(v)` is equivalent to - `jax.jvp(fn, (x,), (v,))`. - - See also [`lineax.linearise`][], which caches the primal computation, i.e. - it returns `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)` + Jacobian-vector products, are computed using autodifferentiation. By default + (or with `jac="fwd"`), this uses `jax.jvp`. With `jac="bwd"`, this uses + `jax.vjp` combined with `jax.linear_transpose`, which works even with functions + that only define a custom VJP (via `jax.custom_vjp`) and don't support + forward-mode differentiation. See also [`lineax.materialise`][], which materialises the whole Jacobian in memory. + + !!! tip + + For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache + the primal computation. This is especially beneficial with `jac="bwd"` + as the primal computation affects the entire backward pass. """ fn: Callable[ @@ -618,10 +623,18 @@ def mv(self, vector): if self.jac == "fwd" or self.jac is None: _, out = jax.jvp(fn, (self.x,), (vector,)) elif self.jac == "bwd": - jac = jax.jacrev(fn)(self.x) - out = PyTreeLinearOperator(jac, output_structure=self.out_structure()).mv( - vector - ) + # Use VJP + linear_transpose instead of materializing full Jacobian. + # This works even for custom_vjp functions that don't have JVP rules. + _, vjp_fn = jax.vjp(fn, self.x) + if symmetric_tag in self.tags: + # For symmetric operators, J = J.T, so vjp directly gives J @ v + (out,) = vjp_fn(vector) + else: + # For non-symmetric, transpose the VJP to get J @ v from J.T @ v + transpose_vjp = jax.linear_transpose( + lambda g: vjp_fn(g)[0], self.out_structure() + ) + (out,) = transpose_vjp(vector) else: raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.") return out diff --git a/tests/test_operator.py b/tests/test_operator.py index 24d2ae1..f5d891b 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -439,25 +439,31 @@ def test_zero_pytree_as_matrix(dtype): def test_jacrev_operator(): + # Test that custom_vjp is respected. The custom backward multiplies by 3 + # instead of the true derivative (which would be 2). + # This tests that lineax uses the custom_vjp, not the true derivative. @jax.custom_vjp def f(x, _): - return dict(foo=x["bar"] + 2) + return dict(foo=x["bar"] * 2) # forward: multiply by 2 def f_fwd(x, _): return f(x, None), None def f_bwd(_, g): - return dict(bar=g["foo"] + 5), None + # Custom backward: multiply by 3 (not the true derivative 2) + # This must be linear in g for linear_transpose to work correctly. + return dict(bar=g["foo"] * 3), None f.defvjp(f_fwd, f_bwd) x = dict(bar=jnp.arange(2.0)) rev_op = lx.JacobianLinearOperator(f, x, jac="bwd") - as_matrix = jnp.array([[6.0, 5.0], [5.0, 6.0]]) + # Jacobian is 3*I (from custom backward, not 2*I from true derivative) + as_matrix = jnp.array([[3.0, 0.0], [0.0, 3.0]]) assert tree_allclose(rev_op.as_matrix(), as_matrix) - y = dict(bar=jnp.arange(2.0) + 1) - true_out = dict(foo=jnp.array([16.0, 17.0])) + y = dict(bar=jnp.arange(2.0) + 1) # y = [1, 2] + true_out = dict(foo=jnp.array([3.0, 6.0])) # 3*I @ [1, 2] = [3, 6] for op in (rev_op, lx.materialise(rev_op)): out = op.mv(y) assert tree_allclose(out, true_out) From fdaad63a1ae5f3a8c995369331cd5fbcc80952a0 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Sat, 31 Jan 2026 01:25:56 +0000 Subject: [PATCH 4/7] return early rather than skip jacrev complex tests --- tests/test_adjoint.py | 5 ++--- tests/test_operator.py | 5 ++--- tests/test_well_posed.py | 5 ++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index d4d3efd..5c0ff2f 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -35,9 +35,8 @@ def test_adjoint(make_operator, dtype, getkey): in_size = 5 out_size = 3 if make_operator is make_jacrev_operator and dtype is jnp.complex128: - pytest.skip( - 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' - ) + # JacobianLinearOperator does not support complex dtypes when jac="bwd" + return operator = make_operator(getkey, matrix, tags) v1, v2 = ( jr.normal(getkey(), (in_size,), dtype=dtype), diff --git a/tests/test_operator.py b/tests/test_operator.py index f5d891b..307bb00 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -47,9 +47,8 @@ def test_ops(make_operator, getkey, dtype): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) tags = () if make_operator is make_jacrev_operator and dtype is jnp.complex128: - pytest.skip( - 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' - ) + # JacobianLinearOperator does not support complex dtypes when jac="bwd" + return matrix1 = make_operator(getkey, matrix, tags) matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype)) scalar = jr.normal(getkey(), (), dtype=dtype) diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py index 8c8b7ec..4686ce8 100644 --- a/tests/test_well_posed.py +++ b/tests/test_well_posed.py @@ -33,9 +33,8 @@ @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): if make_operator is make_jacrev_operator and dtype is jnp.complex128: - pytest.skip( - 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' - ) + # JacobianLinearOperator does not support complex dtypes when jac="bwd" + return if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: From 46338cefae580bf7d7a3da4a551d8d1531657731 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Sat, 31 Jan 2026 01:33:02 +0000 Subject: [PATCH 5/7] symmetric_tag in self.tags -> is_symmetric(self) --- lineax/_operator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 53f8dce..ad4f89a 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -267,7 +267,7 @@ def as_matrix(self): return self.matrix def transpose(self): - if symmetric_tag in self.tags: + if is_symmetric(self): return self return MatrixLinearOperator(self.matrix.T, transpose_tags(self.tags)) @@ -446,7 +446,7 @@ def concat_in(struct, subpytree): return jnp.concatenate(matrix, axis=0) def transpose(self): - if symmetric_tag in self.tags: + if is_symmetric(self): return self def _transpose(struct, subtree): @@ -626,7 +626,7 @@ def mv(self, vector): # Use VJP + linear_transpose instead of materializing full Jacobian. # This works even for custom_vjp functions that don't have JVP rules. _, vjp_fn = jax.vjp(fn, self.x) - if symmetric_tag in self.tags: + if is_symmetric(self): # For symmetric operators, J = J.T, so vjp directly gives J @ v (out,) = vjp_fn(vector) else: @@ -643,7 +643,7 @@ def as_matrix(self): return materialise(self).as_matrix() def transpose(self): - if symmetric_tag in self.tags: + if is_symmetric(self): return self fn = _NoAuxOut(_NoAuxIn(self.fn, self.args)) # Works because vjpfn is a PyTree @@ -710,7 +710,7 @@ def as_matrix(self): return materialise(self).as_matrix() def transpose(self): - if symmetric_tag in self.tags: + if is_symmetric(self): return self transpose_fn = jax.linear_transpose(self.fn, self.in_structure()) @@ -1255,7 +1255,7 @@ def _(operator): # For backward mode, use VJP + linear_transpose. # This works with custom_vjp functions that don't support forward-mode. _, vjp_fn, aux = jax.vjp(fn, operator.x, has_aux=True) - if symmetric_tag in operator.tags: + if is_symmetric(operator): # For symmetric: J = J.T, so vjp directly gives J @ v out = FunctionLinearOperator( _Unwrap(vjp_fn), operator.in_structure(), operator.tags From 5d6c1f4bf9c61e7732031b3ee536254e591749f4 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Sat, 31 Jan 2026 01:56:17 +0000 Subject: [PATCH 6/7] use _Unwrap and keep more examples from old docs --- lineax/_operator.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index ad4f89a..c0ee7c9 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -544,8 +544,9 @@ class JacobianLinearOperator(AbstractLinearOperator): The Jacobian is not materialised; matrix-vector products, which are in fact Jacobian-vector products, are computed using autodifferentiation. By default - (or with `jac="fwd"`), this uses `jax.jvp`. With `jac="bwd"`, this uses - `jax.vjp` combined with `jax.linear_transpose`, which works even with functions + (or with `jac="fwd"`), `JacobianLinearOperator(fn, x).mv(v)` is equivalent to + `jax.jvp(fn, (x,), (v,))`. For `jac="bwd"`, `jax.vjp` is combined with + `jax.linear_transpose`, which works even with functions that only define a custom VJP (via `jax.custom_vjp`) and don't support forward-mode differentiation. @@ -555,8 +556,8 @@ class JacobianLinearOperator(AbstractLinearOperator): !!! tip For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache - the primal computation. This is especially beneficial with `jac="bwd"` - as the primal computation affects the entire backward pass. + the primal computation, e.g. for `jac="fwd"/None` it returns + `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)` """ fn: Callable[ @@ -1253,31 +1254,21 @@ def _(operator): fn = _NoAuxIn(operator.fn, operator.args) if operator.jac == "bwd": # For backward mode, use VJP + linear_transpose. - # This works with custom_vjp functions that don't support forward-mode. + # This works even with custom_vjp functions that don't support forward-mode AD. _, vjp_fn, aux = jax.vjp(fn, operator.x, has_aux=True) if is_symmetric(operator): # For symmetric: J = J.T, so vjp directly gives J @ v - out = FunctionLinearOperator( - _Unwrap(vjp_fn), operator.in_structure(), operator.tags - ) + lin = _Unwrap(vjp_fn()) else: # Transpose the VJP to get J @ v from J.T @ v - transpose_vjp = jax.linear_transpose( - lambda g: vjp_fn(g)[0], operator.out_structure() + lin = _Unwrap( + jax.linear_transpose(lambda g: vjp_fn(g)[0], operator.out_structure()) ) - - def mv_fn(v): - (out,) = transpose_vjp(v) - return out - - out = FunctionLinearOperator(mv_fn, operator.in_structure(), operator.tags) - return AuxLinearOperator(out, aux) - else: - # Original implementation for fwd/None + else: # "fwd" or None (_, aux), lin = jax.linearize(fn, operator.x) lin = _NoAuxOut(lin) - out = FunctionLinearOperator(lin, operator.in_structure(), operator.tags) - return AuxLinearOperator(out, aux) + out = FunctionLinearOperator(lin, operator.in_structure(), operator.tags) + return AuxLinearOperator(out, aux) # materialise From c3ca96d3dc6403460bbdb60414982a8f6c0d69cd Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Sat, 31 Jan 2026 02:03:10 +0000 Subject: [PATCH 7/7] skip jacrev for test_tangent_as_matrix --- tests/test_operator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_operator.py b/tests/test_operator.py index 307bb00..dae9562 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -300,7 +300,12 @@ def test_is_tridiagonal(dtype, getkey): @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_tangent_as_matrix(dtype, getkey): def _list_setup(matrix): - return list(_setup(getkey, matrix)) + # Exclude jacrev operator: jac="bwd" uses custom_vjp which doesn't support JVP + return [ + op + for op in _setup(getkey, matrix) + if not (isinstance(op, lx.JacobianLinearOperator) and op.jac == "bwd") + ] matrix = jr.normal(getkey(), (3, 3), dtype=dtype) t_matrix = jr.normal(getkey(), (3, 3), dtype=dtype)