From d7dc4be5975e6f458b4eb8bb417d9aa1aa8bafdd Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 18 Jun 2025 00:18:09 +0100 Subject: [PATCH 1/5] Streamline tridiagonalisation of JacobianLinearOperator/FunctionLinearOperator using coloring methods --- lineax/_operator.py | 61 ++++++++++++++++++++++++++++++++++++++++-- tests/test_operator.py | 20 ++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 06782a1..d0e2f31 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1422,8 +1422,6 @@ def tridiagonal( @tridiagonal.register(MatrixLinearOperator) @tridiagonal.register(PyTreeLinearOperator) -@tridiagonal.register(JacobianLinearOperator) -@tridiagonal.register(FunctionLinearOperator) def _(operator): matrix = operator.as_matrix() assert matrix.ndim == 2 @@ -1433,6 +1431,65 @@ def _(operator): return diagonal, lower_diagonal, upper_diagonal +@tridiagonal.register(JacobianLinearOperator) +def _(operator): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + + eye_squashed = jnp.zeros((3, flat.size), dtype=flat.dtype) + for i in range(3): + eye_squashed = eye_squashed.at[i, i::3].set(1.0) + + if operator.jac == "fwd" or operator.jac is None: + colors_as_pytrees = jax.vmap(lambda x: operator.mv(unravel(x)))(eye_squashed) + elif operator.jac == "bwd": + fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) + _, vjp_fun = jax.vjp(fn, operator.x) + colors_as_pytrees = jax.vmap(lambda x: vjp_fun(unravel(x)))(eye_squashed) + else: + raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") + + colors_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(colors_as_pytrees) + + diag = jnp.zeros(flat.size, dtype=flat.dtype) + lower_diag = jnp.zeros(flat.size - 1, dtype=flat.dtype) + upper_diag = jnp.zeros(flat.size - 1, dtype=flat.dtype) + + for i in range(3): + diag = diag.at[i::3].set(colors_flat[i, i::3]) + lower_diag = lower_diag.at[i::3].set(colors_flat[i, i + 1 :: 3]) + upper_diag = upper_diag.at[i::3].set(colors_flat[(i + 1) % 3, i:-1:3]) + + return diag, lower_diag, upper_diag + + +@tridiagonal.register(FunctionLinearOperator) +def _(operator): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + + eye_squashed = jnp.zeros((3, flat.size), dtype=flat.dtype) + for i in range(3): + eye_squashed = eye_squashed.at[i, i::3].set(1.0) + + colors_as_pytrees = jax.vmap(lambda x: operator.fn(unravel(x)))(eye_squashed) + + colors_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(colors_as_pytrees) + + diag = jnp.zeros(flat.size, dtype=flat.dtype) + lower_diag = jnp.zeros(flat.size - 1, dtype=flat.dtype) + upper_diag = jnp.zeros(flat.size - 1, dtype=flat.dtype) + + for i in range(3): + diag = diag.at[i::3].set(colors_flat[i, i::3]) + lower_diag = lower_diag.at[i::3].set(colors_flat[i, i + 1 :: 3]) + upper_diag = upper_diag.at[i::3].set(colors_flat[(i + 1) % 3, i:-1:3]) + + return diag, lower_diag, upper_diag + + @tridiagonal.register(DiagonalLinearOperator) def _(operator): diag = diagonal(operator) diff --git a/tests/test_operator.py b/tests/test_operator.py index 26737c2..2ab1744 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -267,6 +267,26 @@ def test_is_negative_semidefinite(dtype, getkey): ) +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_tridiagonal(dtype, getkey): + matrix = jr.normal(getkey(), (5, 5), dtype=dtype) + matrix_diag = jnp.diag(matrix) + matrix_lower_diag = jnp.diag(matrix, k=-1) + matrix_upper_diag = jnp.diag(matrix, k=1) + tridiag_matrix = ( + jnp.diag(matrix_diag) + + jnp.diag(matrix_lower_diag, k=-1) + + jnp.diag(matrix_upper_diag, k=1) + ) + print(tridiag_matrix) + operators = _setup(getkey, tridiag_matrix) + for operator in operators: + diag, lower_diag, upper_diag = lx.tridiagonal(operator) + assert jnp.allclose(diag, matrix_diag) + assert jnp.allclose(lower_diag, matrix_lower_diag) + assert jnp.allclose(upper_diag, matrix_upper_diag) + + @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_tridiagonal(dtype, getkey): diag1 = jr.normal(getkey(), (5,), dtype=dtype) From 1cee5fa8630848af3cf8ff20a28c73665b3f7c6e Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 18 Jun 2025 11:24:07 +0100 Subject: [PATCH 2/5] add tests for JaobianLinearOperator with jac set to fwd/bwd --- tests/helpers.py | 30 +++++++++++++++++++++++++++++- tests/test_operator.py | 6 +++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index c902a66..d8914c8 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -199,13 +199,41 @@ def make_jac_operator(getkey, matrix, tags): a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) - fn_tmp = lambda x, _: a + b @ x + c @ x**2 + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac fn = lambda x, _: a + (b + diff) @ x + c @ x**2 return lx.JacobianLinearOperator(fn, x, None, tags) +@_operators_append +def make_jacfwd_operator(getkey, matrix, tags): + out_size, in_size = matrix.shape + x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) + a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) + b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) + diff = matrix - jac + fn = lambda x, _: a + (b + diff) @ x + c @ x**2 + return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd") + + +@_operators_append +def make_jacrev_operator(getkey, matrix, tags): + out_size, in_size = matrix.shape + x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) + a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) + b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) + diff = matrix - jac + fn = lambda x, _: a + (b + diff) @ x + c @ x**2 + return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd") + + @_operators_append def make_trivial_diagonal_operator(getkey, matrix, tags): assert tags == lx.diagonal_tag diff --git a/tests/test_operator.py b/tests/test_operator.py index 2ab1744..843aaca 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -23,6 +23,7 @@ from .helpers import ( make_identity_operator, + make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, @@ -45,6 +46,10 @@ def test_ops(make_operator, getkey, dtype): else: matrix = jr.normal(getkey(), (3, 3), dtype=dtype) tags = () + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) matrix1 = make_operator(getkey, matrix, tags) matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype)) scalar = jr.normal(getkey(), (), dtype=dtype) @@ -278,7 +283,6 @@ def test_tridiagonal(dtype, getkey): + jnp.diag(matrix_lower_diag, k=-1) + jnp.diag(matrix_upper_diag, k=1) ) - print(tridiag_matrix) operators = _setup(getkey, tridiag_matrix) for operator in operators: diag, lower_diag, upper_diag = lx.tridiagonal(operator) From b2532a3a7c17dc47dcf7d5eda426787567de303f Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 18 Jun 2025 11:25:03 +0100 Subject: [PATCH 3/5] fix jac=bwd + optimization --- lineax/_operator.py | 64 ++++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index d0e2f31..5a05606 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1433,59 +1433,57 @@ def _(operator): @tridiagonal.register(JacobianLinearOperator) def _(operator): - flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) - ) + with jax.ensure_compile_time_eval(): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + + coloring = jnp.arange(flat.size) % 3 - eye_squashed = jnp.zeros((3, flat.size), dtype=flat.dtype) - for i in range(3): - eye_squashed = eye_squashed.at[i, i::3].set(1.0) + basis = jnp.zeros((3, flat.size), dtype=flat.dtype) + for i in range(3): + basis = basis.at[i, i::3].set(1.0) if operator.jac == "fwd" or operator.jac is None: - colors_as_pytrees = jax.vmap(lambda x: operator.mv(unravel(x)))(eye_squashed) + compressed_jac = jax.vmap(lambda x: operator.mv(unravel(x)))(basis) + compressed_jac_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(compressed_jac) + lower_diag = compressed_jac_flat[(coloring[:-1], jnp.arange(1, flat.size))] + upper_diag = compressed_jac_flat[(coloring[1:], jnp.arange(flat.size - 1))] elif operator.jac == "bwd": fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) _, vjp_fun = jax.vjp(fn, operator.x) - colors_as_pytrees = jax.vmap(lambda x: vjp_fun(unravel(x)))(eye_squashed) + compressed_jac = jax.vmap(lambda x: vjp_fun(unravel(x)))(basis) + compressed_jac_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(compressed_jac) + upper_diag = compressed_jac_flat[(coloring[:-1], jnp.arange(1, flat.size))] + lower_diag = compressed_jac_flat[(coloring[1:], jnp.arange(flat.size - 1))] else: raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") - colors_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(colors_as_pytrees) - - diag = jnp.zeros(flat.size, dtype=flat.dtype) - lower_diag = jnp.zeros(flat.size - 1, dtype=flat.dtype) - upper_diag = jnp.zeros(flat.size - 1, dtype=flat.dtype) - - for i in range(3): - diag = diag.at[i::3].set(colors_flat[i, i::3]) - lower_diag = lower_diag.at[i::3].set(colors_flat[i, i + 1 :: 3]) - upper_diag = upper_diag.at[i::3].set(colors_flat[(i + 1) % 3, i:-1:3]) + diag = compressed_jac_flat[(coloring, jnp.arange(flat.size))] return diag, lower_diag, upper_diag @tridiagonal.register(FunctionLinearOperator) def _(operator): - flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) - ) + with jax.ensure_compile_time_eval(): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) - eye_squashed = jnp.zeros((3, flat.size), dtype=flat.dtype) - for i in range(3): - eye_squashed = eye_squashed.at[i, i::3].set(1.0) + coloring = jnp.arange(flat.size) % 3 - colors_as_pytrees = jax.vmap(lambda x: operator.fn(unravel(x)))(eye_squashed) + basis = jnp.zeros((3, flat.size), dtype=flat.dtype) + for i in range(3): + basis = basis.at[i, i::3].set(1.0) - colors_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(colors_as_pytrees) + compressed_jac = jax.vmap(lambda x: operator.fn(unravel(x)))(basis) - diag = jnp.zeros(flat.size, dtype=flat.dtype) - lower_diag = jnp.zeros(flat.size - 1, dtype=flat.dtype) - upper_diag = jnp.zeros(flat.size - 1, dtype=flat.dtype) + compressed_jac_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(compressed_jac) - for i in range(3): - diag = diag.at[i::3].set(colors_flat[i, i::3]) - lower_diag = lower_diag.at[i::3].set(colors_flat[i, i + 1 :: 3]) - upper_diag = upper_diag.at[i::3].set(colors_flat[(i + 1) % 3, i:-1:3]) + diag = compressed_jac_flat[(coloring, jnp.arange(flat.size))] + lower_diag = compressed_jac_flat[(coloring[:-1], jnp.arange(1, flat.size))] + upper_diag = compressed_jac_flat[(coloring[1:], jnp.arange(flat.size - 1))] return diag, lower_diag, upper_diag From d48d6a30298c6b3d9f0cf077127d49c0ce659c43 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 18 Jun 2025 12:03:45 +0100 Subject: [PATCH 4/5] skipped jacrev adjoint test for complex128 --- tests/test_adjoint.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 75fbbc1..5827efc 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -7,6 +7,7 @@ from .helpers import ( make_identity_operator, + make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, @@ -33,6 +34,10 @@ def test_adjoint(make_operator, dtype, getkey): tags = () in_size = 5 out_size = 3 + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) operator = make_operator(getkey, matrix, tags) v1, v2 = ( jr.normal(getkey(), (in_size,), dtype=dtype), From b2bca6f1ed0fdd9b236ba5e1e02aaeaa706c4b53 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 18 Jun 2025 14:44:48 +0100 Subject: [PATCH 5/5] skipped jacrev small_well_posed test for complex128 (possibly overkill: skipped 80 tests but only 4 failed) --- tests/test_well_posed.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py index e377ef4..c3c0358 100644 --- a/tests/test_well_posed.py +++ b/tests/test_well_posed.py @@ -20,6 +20,7 @@ from .helpers import ( construct_matrix, + make_jacrev_operator, ops, params, solvers, @@ -31,6 +32,10 @@ @pytest.mark.parametrize("ops", ops) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: