diff --git a/lineax/_operator.py b/lineax/_operator.py index 06782a1..5a05606 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1422,8 +1422,6 @@ def tridiagonal( @tridiagonal.register(MatrixLinearOperator) @tridiagonal.register(PyTreeLinearOperator) -@tridiagonal.register(JacobianLinearOperator) -@tridiagonal.register(FunctionLinearOperator) def _(operator): matrix = operator.as_matrix() assert matrix.ndim == 2 @@ -1433,6 +1431,63 @@ def _(operator): return diagonal, lower_diagonal, upper_diagonal +@tridiagonal.register(JacobianLinearOperator) +def _(operator): + with jax.ensure_compile_time_eval(): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + + coloring = jnp.arange(flat.size) % 3 + + basis = jnp.zeros((3, flat.size), dtype=flat.dtype) + for i in range(3): + basis = basis.at[i, i::3].set(1.0) + + if operator.jac == "fwd" or operator.jac is None: + compressed_jac = jax.vmap(lambda x: operator.mv(unravel(x)))(basis) + compressed_jac_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(compressed_jac) + lower_diag = compressed_jac_flat[(coloring[:-1], jnp.arange(1, flat.size))] + upper_diag = compressed_jac_flat[(coloring[1:], jnp.arange(flat.size - 1))] + elif operator.jac == "bwd": + fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) + _, vjp_fun = jax.vjp(fn, operator.x) + compressed_jac = jax.vmap(lambda x: vjp_fun(unravel(x)))(basis) + compressed_jac_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(compressed_jac) + upper_diag = compressed_jac_flat[(coloring[:-1], jnp.arange(1, flat.size))] + lower_diag = compressed_jac_flat[(coloring[1:], jnp.arange(flat.size - 1))] + else: + raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") + + diag = compressed_jac_flat[(coloring, jnp.arange(flat.size))] + + return diag, lower_diag, upper_diag + + +@tridiagonal.register(FunctionLinearOperator) +def _(operator): + with jax.ensure_compile_time_eval(): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + + coloring = jnp.arange(flat.size) % 3 + + basis = jnp.zeros((3, flat.size), dtype=flat.dtype) + for i in range(3): + basis = basis.at[i, i::3].set(1.0) + + compressed_jac = jax.vmap(lambda x: operator.fn(unravel(x)))(basis) + + compressed_jac_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(compressed_jac) + + diag = compressed_jac_flat[(coloring, jnp.arange(flat.size))] + lower_diag = compressed_jac_flat[(coloring[:-1], jnp.arange(1, flat.size))] + upper_diag = compressed_jac_flat[(coloring[1:], jnp.arange(flat.size - 1))] + + return diag, lower_diag, upper_diag + + @tridiagonal.register(DiagonalLinearOperator) def _(operator): diag = diagonal(operator) diff --git a/tests/helpers.py b/tests/helpers.py index c902a66..d8914c8 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -199,13 +199,41 @@ def make_jac_operator(getkey, matrix, tags): a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) - fn_tmp = lambda x, _: a + b @ x + c @ x**2 + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac fn = lambda x, _: a + (b + diff) @ x + c @ x**2 return lx.JacobianLinearOperator(fn, x, None, tags) +@_operators_append +def make_jacfwd_operator(getkey, matrix, tags): + out_size, in_size = matrix.shape + x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) + a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) + b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) + diff = matrix - jac + fn = lambda x, _: a + (b + diff) @ x + c @ x**2 + return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd") + + +@_operators_append +def make_jacrev_operator(getkey, matrix, tags): + out_size, in_size = matrix.shape + x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) + a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) + b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) + diff = matrix - jac + fn = lambda x, _: a + (b + diff) @ x + c @ x**2 + return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd") + + @_operators_append def make_trivial_diagonal_operator(getkey, matrix, tags): assert tags == lx.diagonal_tag diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 75fbbc1..5827efc 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -7,6 +7,7 @@ from .helpers import ( make_identity_operator, + make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, @@ -33,6 +34,10 @@ def test_adjoint(make_operator, dtype, getkey): tags = () in_size = 5 out_size = 3 + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) operator = make_operator(getkey, matrix, tags) v1, v2 = ( jr.normal(getkey(), (in_size,), dtype=dtype), diff --git a/tests/test_operator.py b/tests/test_operator.py index 26737c2..843aaca 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -23,6 +23,7 @@ from .helpers import ( make_identity_operator, + make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, @@ -45,6 +46,10 @@ def test_ops(make_operator, getkey, dtype): else: matrix = jr.normal(getkey(), (3, 3), dtype=dtype) tags = () + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) matrix1 = make_operator(getkey, matrix, tags) matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype)) scalar = jr.normal(getkey(), (), dtype=dtype) @@ -267,6 +272,25 @@ def test_is_negative_semidefinite(dtype, getkey): ) +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_tridiagonal(dtype, getkey): + matrix = jr.normal(getkey(), (5, 5), dtype=dtype) + matrix_diag = jnp.diag(matrix) + matrix_lower_diag = jnp.diag(matrix, k=-1) + matrix_upper_diag = jnp.diag(matrix, k=1) + tridiag_matrix = ( + jnp.diag(matrix_diag) + + jnp.diag(matrix_lower_diag, k=-1) + + jnp.diag(matrix_upper_diag, k=1) + ) + operators = _setup(getkey, tridiag_matrix) + for operator in operators: + diag, lower_diag, upper_diag = lx.tridiagonal(operator) + assert jnp.allclose(diag, matrix_diag) + assert jnp.allclose(lower_diag, matrix_lower_diag) + assert jnp.allclose(upper_diag, matrix_upper_diag) + + @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_tridiagonal(dtype, getkey): diag1 = jr.normal(getkey(), (5,), dtype=dtype) diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py index e377ef4..c3c0358 100644 --- a/tests/test_well_posed.py +++ b/tests/test_well_posed.py @@ -20,6 +20,7 @@ from .helpers import ( construct_matrix, + make_jacrev_operator, ops, params, solvers, @@ -31,6 +32,10 @@ @pytest.mark.parametrize("ops", ops) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: