Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,8 +1422,6 @@ def tridiagonal(

@tridiagonal.register(MatrixLinearOperator)
@tridiagonal.register(PyTreeLinearOperator)
@tridiagonal.register(JacobianLinearOperator)
@tridiagonal.register(FunctionLinearOperator)
def _(operator):
matrix = operator.as_matrix()
assert matrix.ndim == 2
Expand All @@ -1433,6 +1431,63 @@ def _(operator):
return diagonal, lower_diagonal, upper_diagonal


@tridiagonal.register(JacobianLinearOperator)
def _(operator):
with jax.ensure_compile_time_eval():
flat, unravel = strip_weak_dtype(
eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
)

coloring = jnp.arange(flat.size) % 3

basis = jnp.zeros((3, flat.size), dtype=flat.dtype)
for i in range(3):
basis = basis.at[i, i::3].set(1.0)

if operator.jac == "fwd" or operator.jac is None:
compressed_jac = jax.vmap(lambda x: operator.mv(unravel(x)))(basis)
compressed_jac_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(compressed_jac)
lower_diag = compressed_jac_flat[(coloring[:-1], jnp.arange(1, flat.size))]
upper_diag = compressed_jac_flat[(coloring[1:], jnp.arange(flat.size - 1))]
elif operator.jac == "bwd":
fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args))
_, vjp_fun = jax.vjp(fn, operator.x)
compressed_jac = jax.vmap(lambda x: vjp_fun(unravel(x)))(basis)
compressed_jac_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(compressed_jac)
upper_diag = compressed_jac_flat[(coloring[:-1], jnp.arange(1, flat.size))]
lower_diag = compressed_jac_flat[(coloring[1:], jnp.arange(flat.size - 1))]
else:
raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.")

diag = compressed_jac_flat[(coloring, jnp.arange(flat.size))]

return diag, lower_diag, upper_diag


@tridiagonal.register(FunctionLinearOperator)
def _(operator):
with jax.ensure_compile_time_eval():
flat, unravel = strip_weak_dtype(
eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
)

coloring = jnp.arange(flat.size) % 3

basis = jnp.zeros((3, flat.size), dtype=flat.dtype)
for i in range(3):
basis = basis.at[i, i::3].set(1.0)

compressed_jac = jax.vmap(lambda x: operator.fn(unravel(x)))(basis)

compressed_jac_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(compressed_jac)

diag = compressed_jac_flat[(coloring, jnp.arange(flat.size))]
lower_diag = compressed_jac_flat[(coloring[:-1], jnp.arange(1, flat.size))]
upper_diag = compressed_jac_flat[(coloring[1:], jnp.arange(flat.size - 1))]

return diag, lower_diag, upper_diag


@tridiagonal.register(DiagonalLinearOperator)
def _(operator):
diag = diagonal(operator)
Expand Down
30 changes: 29 additions & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,41 @@ def make_jac_operator(getkey, matrix, tags):
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
fn_tmp = lambda x, _: a + b @ x + c @ x**2
fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
diff = matrix - jac
fn = lambda x, _: a + (b + diff) @ x + c @ x**2
return lx.JacobianLinearOperator(fn, x, None, tags)


@_operators_append
def make_jacfwd_operator(getkey, matrix, tags):
out_size, in_size = matrix.shape
x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
diff = matrix - jac
fn = lambda x, _: a + (b + diff) @ x + c @ x**2
return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd")


@_operators_append
def make_jacrev_operator(getkey, matrix, tags):
out_size, in_size = matrix.shape
x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
diff = matrix - jac
fn = lambda x, _: a + (b + diff) @ x + c @ x**2
return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd")


@_operators_append
def make_trivial_diagonal_operator(getkey, matrix, tags):
assert tags == lx.diagonal_tag
Expand Down
5 changes: 5 additions & 0 deletions tests/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .helpers import (
make_identity_operator,
make_jacrev_operator,
make_operators,
make_tridiagonal_operator,
make_trivial_diagonal_operator,
Expand All @@ -33,6 +34,10 @@ def test_adjoint(make_operator, dtype, getkey):
tags = ()
in_size = 5
out_size = 3
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
pytest.skip(
'JacobianLinearOperator does not support complex dtypes when jac="bwd"'
)
operator = make_operator(getkey, matrix, tags)
v1, v2 = (
jr.normal(getkey(), (in_size,), dtype=dtype),
Expand Down
24 changes: 24 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .helpers import (
make_identity_operator,
make_jacrev_operator,
make_operators,
make_tridiagonal_operator,
make_trivial_diagonal_operator,
Expand All @@ -45,6 +46,10 @@ def test_ops(make_operator, getkey, dtype):
else:
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
tags = ()
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
pytest.skip(
'JacobianLinearOperator does not support complex dtypes when jac="bwd"'
)
matrix1 = make_operator(getkey, matrix, tags)
matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype))
scalar = jr.normal(getkey(), (), dtype=dtype)
Expand Down Expand Up @@ -267,6 +272,25 @@ def test_is_negative_semidefinite(dtype, getkey):
)


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_tridiagonal(dtype, getkey):
matrix = jr.normal(getkey(), (5, 5), dtype=dtype)
matrix_diag = jnp.diag(matrix)
matrix_lower_diag = jnp.diag(matrix, k=-1)
matrix_upper_diag = jnp.diag(matrix, k=1)
tridiag_matrix = (
jnp.diag(matrix_diag)
+ jnp.diag(matrix_lower_diag, k=-1)
+ jnp.diag(matrix_upper_diag, k=1)
)
operators = _setup(getkey, tridiag_matrix)
for operator in operators:
diag, lower_diag, upper_diag = lx.tridiagonal(operator)
assert jnp.allclose(diag, matrix_diag)
assert jnp.allclose(lower_diag, matrix_lower_diag)
assert jnp.allclose(upper_diag, matrix_upper_diag)


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_tridiagonal(dtype, getkey):
diag1 = jr.normal(getkey(), (5,), dtype=dtype)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_well_posed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from .helpers import (
construct_matrix,
make_jacrev_operator,
ops,
params,
solvers,
Expand All @@ -31,6 +32,10 @@
@pytest.mark.parametrize("ops", ops)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
pytest.skip(
'JacobianLinearOperator does not support complex dtypes when jac="bwd"'
)
if jax.config.jax_enable_x64: # pyright: ignore
tol = 1e-10
else:
Expand Down
Loading