From 30e773668236d5dad204804b995f0647b2624793 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 16 Jun 2025 10:33:22 +0100 Subject: [PATCH 01/33] Streamline diagonalisation by avoiding computation/instantiation of entire Jacobian matrix --- lineax/_operator.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 82dd9fc..ce80365 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1363,12 +1363,44 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]: @diagonal.register(MatrixLinearOperator) @diagonal.register(PyTreeLinearOperator) -@diagonal.register(JacobianLinearOperator) @diagonal.register(FunctionLinearOperator) def _(operator): return jnp.diag(operator.as_matrix()) +@diagonal.register(JacobianLinearOperator) +def _(operator): + fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) + + if operator.jac is None: + # Diagonal matrices are square therefore fwd should be more effiicient + jac_fwd = True + elif operator.jac == "fwd": + jac_fwd = True + elif operator.jac == "bwd": + jac_fwd = False + else: + raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") + + if jac_fwd: + _, diag_as_pytree = jax.jvp( + fn, (operator.x,), (jax.tree.map(lambda x: jnp.ones_like(x), operator.x),) + ) + else: + _, vjp_fun = jax.vjp(fn, operator.x) + diag_as_pytree = vjp_fun(jax.tree.map(lambda x: jnp.ones_like(x), operator.x)) + + return jfu.ravel_pytree(diag_as_pytree)[0] + + +@diagonal.register(FunctionLinearOperator) +def _(operator): + diag_as_pytree = operator.fn( + jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), operator.in_structure()) + ) + return jfu.ravel_pytree(diag_as_pytree)[0] + + @diagonal.register(DiagonalLinearOperator) def _(operator): diagonal, _ = jfu.ravel_pytree(operator.diagonal) From e8f474a4041264573907389eb2c7e257eb7944a2 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 16 Jun 2025 11:15:55 +0100 Subject: [PATCH 02/33] changed to using mv instead of jvp --- lineax/_operator.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index ce80365..02e879d 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1370,25 +1370,16 @@ def _(operator): @diagonal.register(JacobianLinearOperator) def _(operator): - fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) - - if operator.jac is None: - # Diagonal matrices are square therefore fwd should be more effiicient - jac_fwd = True - elif operator.jac == "fwd": - jac_fwd = True - elif operator.jac == "bwd": - jac_fwd = False - else: - raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") - - if jac_fwd: - _, diag_as_pytree = jax.jvp( - fn, (operator.x,), (jax.tree.map(lambda x: jnp.ones_like(x), operator.x),) + if operator.jac == "fwd" or operator.jac is None: + diag_as_pytree = operator.mv( + jax.tree.map(lambda x: jnp.ones_like(x), operator.x) ) - else: + elif operator.jac == "bwd": + fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) _, vjp_fun = jax.vjp(fn, operator.x) diag_as_pytree = vjp_fun(jax.tree.map(lambda x: jnp.ones_like(x), operator.x)) + else: + raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") return jfu.ravel_pytree(diag_as_pytree)[0] From 0e7703c08c4c0a47413fdc0d3fee4c0c03d3e536 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 16 Jun 2025 13:49:00 +0100 Subject: [PATCH 03/33] initialise test_diagonal operators with diagonal matrix --- tests/test_operator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_operator.py b/tests/test_operator.py index 347068f..f77b2c1 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -160,8 +160,10 @@ def test_materialise_large(dtype, getkey): def test_diagonal(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) matrix_diag = jnp.diag(matrix) - operators = _setup(getkey, matrix) + operators = _setup(getkey, jnp.diag(matrix_diag)) for operator in operators: + print(operator) + print(lx.diagonal(operator)) assert jnp.allclose(lx.diagonal(operator), matrix_diag) From 3b7d8057e4c4aa9304c85374ab7217970a27f806 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 16 Jun 2025 15:22:40 +0100 Subject: [PATCH 04/33] remove redundant original diagonal overload for FunctionLinearOperator --- lineax/_operator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 02e879d..28f9c1c 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1363,7 +1363,6 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]: @diagonal.register(MatrixLinearOperator) @diagonal.register(PyTreeLinearOperator) -@diagonal.register(FunctionLinearOperator) def _(operator): return jnp.diag(operator.as_matrix()) From 97f9e7fe9ee400eecf3fc609671f1874ae8c446a Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Thu, 19 Jun 2025 09:42:37 +0100 Subject: [PATCH 05/33] unravel instead of map --- lineax/_operator.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 53790b9..c2a05e9 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1369,14 +1369,18 @@ def _(operator): @diagonal.register(JacobianLinearOperator) def _(operator): - if operator.jac == "fwd" or operator.jac is None: - diag_as_pytree = operator.mv( - jax.tree.map(lambda x: jnp.ones_like(x), operator.x) + with jax.ensure_compile_time_eval(): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) ) + basis = jnp.ones(flat.size, dtype=flat.dtype) + + if operator.jac == "fwd" or operator.jac is None: + diag_as_pytree = operator.mv(unravel(basis)) elif operator.jac == "bwd": fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) _, vjp_fun = jax.vjp(fn, operator.x) - diag_as_pytree = vjp_fun(jax.tree.map(lambda x: jnp.ones_like(x), operator.x)) + diag_as_pytree = vjp_fun(unravel(basis)) else: raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") @@ -1385,9 +1389,13 @@ def _(operator): @diagonal.register(FunctionLinearOperator) def _(operator): - diag_as_pytree = operator.fn( - jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), operator.in_structure()) - ) + with jax.ensure_compile_time_eval(): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + basis = jnp.ones(flat.size, dtype=flat.dtype) + + diag_as_pytree = operator.fn(unravel(basis)) return jfu.ravel_pytree(diag_as_pytree)[0] From d48ba150df083ba1231e9efcb6ef5e3bc85143c4 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Fri, 24 Oct 2025 10:19:58 +0100 Subject: [PATCH 06/33] continue to extract diagonal if no diagonal tag --- lineax/_operator.py | 50 +++++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index c2a05e9..274b25f 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1369,34 +1369,40 @@ def _(operator): @diagonal.register(JacobianLinearOperator) def _(operator): - with jax.ensure_compile_time_eval(): - flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) - ) - basis = jnp.ones(flat.size, dtype=flat.dtype) - - if operator.jac == "fwd" or operator.jac is None: - diag_as_pytree = operator.mv(unravel(basis)) - elif operator.jac == "bwd": - fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) - _, vjp_fun = jax.vjp(fn, operator.x) - diag_as_pytree = vjp_fun(unravel(basis)) - else: - raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") + if is_diagonal(operator): + with jax.ensure_compile_time_eval(): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + basis = jnp.ones(flat.size, dtype=flat.dtype) + + if operator.jac == "fwd" or operator.jac is None: + diag_as_pytree = operator.mv(unravel(basis)) + elif operator.jac == "bwd": + fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) + _, vjp_fun = jax.vjp(fn, operator.x) + diag_as_pytree = vjp_fun(unravel(basis)) + else: + raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") - return jfu.ravel_pytree(diag_as_pytree)[0] + return jfu.ravel_pytree(diag_as_pytree)[0] + else: + return jnp.diag(operator.as_matrix()) @diagonal.register(FunctionLinearOperator) def _(operator): - with jax.ensure_compile_time_eval(): - flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) - ) - basis = jnp.ones(flat.size, dtype=flat.dtype) + if is_diagonal(operator): + with jax.ensure_compile_time_eval(): + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + basis = jnp.ones(flat.size, dtype=flat.dtype) - diag_as_pytree = operator.fn(unravel(basis)) - return jfu.ravel_pytree(diag_as_pytree)[0] + diag_as_pytree = operator.fn(unravel(basis)) + return jfu.ravel_pytree(diag_as_pytree)[0] + else: + return jnp.diag(operator.as_matrix()) @diagonal.register(DiagonalLinearOperator) From 40e2ad566d85cfc633d86e2890455fcdf1fd3ea1 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Fri, 24 Oct 2025 10:39:25 +0100 Subject: [PATCH 07/33] rmeove print statement in test_operator --- tests/test_operator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_operator.py b/tests/test_operator.py index 8eb4528..cec1f36 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -162,8 +162,6 @@ def test_diagonal(dtype, getkey): matrix_diag = jnp.diag(matrix) operators = _setup(getkey, jnp.diag(matrix_diag)) for operator in operators: - print(operator) - print(lx.diagonal(operator)) assert jnp.allclose(lx.diagonal(operator), matrix_diag) From 8e91b446aaa1a2a154a922b8644e3e4aa1f624e1 Mon Sep 17 00:00:00 2001 From: Austin Conner Date: Thu, 4 Dec 2025 19:10:55 -0500 Subject: [PATCH 08/33] Replace allow_dependent_rows and allow_dependent_columns with assume_full_rank (#158) The two functions allow_dependent_{rows,columns} together did the job of answering if the solver accepts full rank matrices for the purposes of the jvp. Allowing them to be implemented separately created some issues: 1) Invalid states were representable. Eg. What does it mean that dependent columns are allowed for square matrices if dependent rows are not? What does it mean that dependent rows are not allowed for matrices with more rows than columns? 2) As the functions accept operator as input, a custom solver could in principle decide its answer based on operator's dynamic value rather than only jax compilation static information regarding it, as in all the lineax defined solvers. This would prevent jax compilation and jit. Both issues are addressed by asking the solver to report only if it assumes the input is numerically full rank. If this assumption is exactly violated, its behavior is allowed to be undefined, and is allowed to error, produce NaN values, and produce invalid values. --- docs/api/solvers.md | 4 +- lineax/_solve.py | 87 ++++++++++------------------------- lineax/_solver/bicgstab.py | 7 +-- lineax/_solver/cg.py | 18 ++------ lineax/_solver/cholesky.py | 7 +-- lineax/_solver/diagonal.py | 7 +-- lineax/_solver/gmres.py | 7 +-- lineax/_solver/lsmr.py | 7 +-- lineax/_solver/lu.py | 7 +-- lineax/_solver/qr.py | 18 +------- lineax/_solver/svd.py | 7 +-- lineax/_solver/triangular.py | 7 +-- lineax/_solver/tridiagonal.py | 7 +-- 13 files changed, 51 insertions(+), 139 deletions(-) diff --git a/docs/api/solvers.md b/docs/api/solvers.md index e2ef2ea..65cdaeb 100644 --- a/docs/api/solvers.md +++ b/docs/api/solvers.md @@ -9,9 +9,9 @@ If you're not sure what to use, then pick [`lineax.AutoLinearSolver`][] and it w members: - init - compute - - allow_dependent_columns - - allow_dependent_rows - transpose + - conj + - assume_full_rank ::: lineax.AutoLinearSolver options: diff --git a/lineax/_solve.py b/lineax/_solve.py index bbe7fdf..7f2b4a8 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -198,16 +198,17 @@ def _linear_solve_jvp(primals, tangents): # -A'x term vec = (-(t_operator.mv(solution) ** ω)).ω vecs.append(vec) - allow_dependent_rows = solver.allow_dependent_rows(operator) - allow_dependent_columns = solver.allow_dependent_columns(operator) - if allow_dependent_rows or allow_dependent_columns: + rows, columns = operator.out_size(), operator.in_size() + assume_independent_rows = solver.assume_full_rank() and rows <= columns + assume_independent_columns = solver.assume_full_rank() and columns <= rows + if not assume_independent_rows or not assume_independent_columns: operator_conj_transpose = conj(operator).transpose() t_operator_conj_transpose = conj(t_operator).transpose() state_conj, options_conj = solver.conj(state, options) state_conj_transpose, options_conj_transpose = solver.transpose( state_conj, options_conj ) - if allow_dependent_rows: + if not assume_independent_rows: lst_sqr_diff = (vector**ω - operator.mv(solution) ** ω).ω tmp = t_operator_conj_transpose.mv(lst_sqr_diff) # pyright: ignore tmp, _, _ = eqxi.filter_primitive_bind( @@ -221,7 +222,7 @@ def _linear_solve_jvp(primals, tangents): ) vecs.append(tmp) - if allow_dependent_columns: + if not assume_independent_columns: tmp1, _, _ = eqxi.filter_primitive_bind( linear_solve_p, operator_conj_transpose, # pyright: ignore @@ -392,56 +393,6 @@ def compute( taken. """ - @abc.abstractmethod - def allow_dependent_columns(self, operator: AbstractLinearOperator) -> bool: - """Does this method ever produce non-NaN outputs for operators with linearly - dependent columns? (Even if only sometimes.) - - If `True` then a more expensive backward pass is needed, to account for the - extra generality. - - If you do not need to autodifferentiate through a custom linear solver then you - simply define this method as - ```python - class MyLinearSolver(AbstractLinearsolver): - def allow_dependent_columns(self, operator): - raise NotImplementedError - ``` - - **Arguments:** - - - `operator`: a linear operator. - - **Returns:** - - Either `True` or `False`. - """ - - @abc.abstractmethod - def allow_dependent_rows(self, operator: AbstractLinearOperator) -> bool: - """Does this method ever produce non-NaN outputs for operators with - linearly dependent rows? (Even if only sometimes) - - If `True` then a more expensive backward pass is needed, to account for the - extra generality. - - If you do not need to autodifferentiate through a custom linear solver then you - simply define this method as - ```python - class MyLinearSolver(AbstractLinearsolver): - def allow_dependent_rows(self, operator): - raise NotImplementedError - ``` - - **Arguments:** - - - `operator`: a linear operator. - - **Returns:** - - Either `True` or `False`. - """ - @abc.abstractmethod def transpose( self, state: _SolverState, options: dict[str, Any] @@ -499,6 +450,23 @@ def conj( - The options for the conjugated operator. """ + @abc.abstractmethod + def assume_full_rank(self) -> bool: + """Does this solver assume that all operators are full rank? + + When `False`, a more expensive backward pass is needed to account for + the extra generality. In a custom linear solver, it is always safe to + return False. + + **Arguments:** + + Nothing. + + **Returns:** + + Either `True` or `False`. + """ + _qr_token = eqxi.str2jax("qr_token") _diagonal_token = eqxi.str2jax("diagonal_token") @@ -661,13 +629,8 @@ def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]): conj_state = (token, conj_state) return conj_state, conj_options - def allow_dependent_columns(self, operator: AbstractLinearOperator) -> bool: - token = self._select_solver(operator) - return _lookup(token).allow_dependent_columns(operator) - - def allow_dependent_rows(self, operator: AbstractLinearOperator) -> bool: - token = self._select_solver(operator) - return _lookup(token).allow_dependent_rows(operator) + def assume_full_rank(self): + return self.well_posed is not False AutoLinearSolver.__init__.__doc__ = """**Arguments:** diff --git a/lineax/_solver/bicgstab.py b/lineax/_solver/bicgstab.py index 6e0f248..e701aeb 100644 --- a/lineax/_solver/bicgstab.py +++ b/lineax/_solver/bicgstab.py @@ -218,11 +218,8 @@ def conj(self, state: _BiCGStabState, options: dict[str, Any]): operator = state return conj(operator), conj_options - def allow_dependent_columns(self, operator): - return False - - def allow_dependent_rows(self, operator): - return False + def assume_full_rank(self): + return True BiCGStab.__init__.__doc__ = r"""**Arguments:** diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index 3ba020d..44b4a4f 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -284,11 +284,8 @@ class CG(_AbstractCG): _normal: ClassVar[bool] = False - def allow_dependent_columns(self, operator): - return False - - def allow_dependent_rows(self, operator): - return False + def assume_full_rank(self): + return True class NormalCG(_AbstractCG): @@ -321,15 +318,8 @@ class NormalCG(_AbstractCG): _normal: ClassVar[bool] = True - def allow_dependent_columns(self, operator): - rows = operator.out_size() - columns = operator.in_size() - return columns > rows - - def allow_dependent_rows(self, operator): - rows = operator.out_size() - columns = operator.in_size() - return rows > columns + def assume_full_rank(self): + return True CG.__init__.__doc__ = r"""**Arguments:** diff --git a/lineax/_solver/cholesky.py b/lineax/_solver/cholesky.py index 3daba0c..a785780 100644 --- a/lineax/_solver/cholesky.py +++ b/lineax/_solver/cholesky.py @@ -85,11 +85,8 @@ def conj(self, state: _CholeskyState, options: dict[str, Any]): factor, is_nsd = state return (factor.conj(), is_nsd), options - def allow_dependent_columns(self, operator): - return False - - def allow_dependent_rows(self, operator): - return False + def assume_full_rank(self): + return True Cholesky.__init__.__doc__ = """**Arguments:** diff --git a/lineax/_solver/diagonal.py b/lineax/_solver/diagonal.py index 9eb5f17..6e92295 100644 --- a/lineax/_solver/diagonal.py +++ b/lineax/_solver/diagonal.py @@ -101,11 +101,8 @@ def conj(self, state: _DiagonalState, options: dict[str, Any]): conj_state = conj_diag, packed_structures return conj_state, conj_options - def allow_dependent_columns(self, operator): - return not self.well_posed - - def allow_dependent_rows(self, operator): - return not self.well_posed + def assume_full_rank(self): + return self.well_posed Diagonal.__init__.__doc__ = """**Arguments**: diff --git a/lineax/_solver/gmres.py b/lineax/_solver/gmres.py index a964f8c..edf0f83 100644 --- a/lineax/_solver/gmres.py +++ b/lineax/_solver/gmres.py @@ -429,11 +429,8 @@ def conj(self, state: _GMRESState, options: dict[str, Any]): operator = state return conj(operator), conj_options - def allow_dependent_columns(self, operator): - return False - - def allow_dependent_rows(self, operator): - return False + def assume_full_rank(self): + return True GMRES.__init__.__doc__ = r"""**Arguments:** diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index 90b6d57..a6ef8a9 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -407,11 +407,8 @@ def conj(self, state: _LSMRState, options: dict[str, Any]): conj_options = {} return conj(operator), conj_options - def allow_dependent_rows(self, operator): - return True - - def allow_dependent_columns(self, operator): - return True + def assume_full_rank(self): + return False LSMR.__init__.__doc__ = r"""**Arguments:** diff --git a/lineax/_solver/lu.py b/lineax/_solver/lu.py index b48015c..d118991 100644 --- a/lineax/_solver/lu.py +++ b/lineax/_solver/lu.py @@ -84,11 +84,8 @@ def conj( conj_options = {} return conj_state, conj_options - def allow_dependent_columns(self, operator): - return False - - def allow_dependent_rows(self, operator): - return False + def assume_full_rank(self): + return True LU.__init__.__doc__ = """**Arguments:** diff --git a/lineax/_solver/qr.py b/lineax/_solver/qr.py index 0d48ed0..2551e95 100644 --- a/lineax/_solver/qr.py +++ b/lineax/_solver/qr.py @@ -100,22 +100,8 @@ def conj(self, state: _QRState, options: dict[str, Any]): conj_options = {} return conj_state, conj_options - def allow_dependent_columns(self, operator): - rows = operator.out_size() - columns = operator.in_size() - # We're able to pull an efficiency trick here. - # - # As we don't use a rank-revealing implementation, then we always require that - # the operator have full rank. - # - # So if we have columns <= rows, then we know that all our columns are linearly - # independent. We can return `False` and get a computationally cheaper jvp rule. - return columns > rows - - def allow_dependent_rows(self, operator): - rows = operator.out_size() - columns = operator.in_size() - return rows > columns + def assume_full_rank(self): + return True QR.__init__.__doc__ = """**Arguments:** diff --git a/lineax/_solver/svd.py b/lineax/_solver/svd.py index a361116..a64b720 100644 --- a/lineax/_solver/svd.py +++ b/lineax/_solver/svd.py @@ -92,11 +92,8 @@ def conj(self, state: _SVDState, options: dict[str, Any]): conj_options = {} return conj_state, conj_options - def allow_dependent_columns(self, operator): - return True - - def allow_dependent_rows(self, operator): - return True + def assume_full_rank(self): + return False SVD.__init__.__doc__ = """**Arguments**: diff --git a/lineax/_solver/triangular.py b/lineax/_solver/triangular.py index 58cc349..f926295 100644 --- a/lineax/_solver/triangular.py +++ b/lineax/_solver/triangular.py @@ -103,11 +103,8 @@ def conj(self, state: _TriangularState, options: dict[str, Any]): conj_options = {} return conj_state, conj_options - def allow_dependent_columns(self, operator): - return False - - def allow_dependent_rows(self, operator): - return False + def assume_full_rank(self): + return True Triangular.__init__.__doc__ = """**Arguments:** diff --git a/lineax/_solver/tridiagonal.py b/lineax/_solver/tridiagonal.py index 2fa4304..7d83eb7 100644 --- a/lineax/_solver/tridiagonal.py +++ b/lineax/_solver/tridiagonal.py @@ -84,11 +84,8 @@ def conj(self, state: _TridiagonalState, options: dict[str, Any]): conj_state = (conj_diagonals, packed_structures) return conj_state, options - def allow_dependent_columns(self, operator): - return False - - def allow_dependent_rows(self, operator): - return False + def assume_full_rank(self): + return True Tridiagonal.__init__.__doc__ = """**Arguments:** From a6bacf7b44e6f1b28725e23fdc9da78b2e2028a7 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 5 Dec 2025 01:17:56 +0100 Subject: [PATCH 09/33] pyright fixes --- lineax/_operator.py | 2 +- lineax/_solver/cg.py | 4 +++- lineax/_solver/diagonal.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index f52953d..7d83693 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1819,7 +1819,7 @@ def _(operator): @transform.register(AddLinearOperator) # pyright: ignore def _(operator, transform=transform): - return transform(operator.operator1) + transform(operator.operator2) + return transform(operator.operator1) + transform(operator.operator2) # pyright: ignore @transform.register(MulLinearOperator) def _(operator, transform=transform): diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index 44b4a4f..16d6ac5 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -188,7 +188,9 @@ def body_fun(value): inner_prod = tree_dot(mat_p, p) alpha = gamma / inner_prod alpha = tree_where( - jnp.abs(inner_prod) > 100 * rcond * jnp.abs(gamma), alpha, jnp.nan + jnp.abs(inner_prod) > 100 * rcond * jnp.abs(gamma), # pyright: ignore + alpha, + jnp.nan, # pyright: ignore ) diff = (alpha * p**ω).ω y = (y**ω + diff**ω).ω diff --git a/lineax/_solver/diagonal.py b/lineax/_solver/diagonal.py index 6e92295..334e96e 100644 --- a/lineax/_solver/diagonal.py +++ b/lineax/_solver/diagonal.py @@ -77,7 +77,7 @@ def compute( (size,) = diag.shape rcond = resolve_rcond(self.rcond, size, size, diag.dtype) abs_diag = jnp.abs(diag) - diag = jnp.where(abs_diag > rcond * jnp.max(abs_diag), diag, jnp.inf) + diag = jnp.where(abs_diag > rcond * jnp.max(abs_diag), diag, jnp.inf) # pyright: ignore solution = vector / diag solution = unravel_solution(solution, packed_structures) return solution, RESULTS.successful, {} From 8d96e7c1cbc04488fbfc578f23ca404e34f8ccf6 Mon Sep 17 00:00:00 2001 From: Austin Conner Date: Thu, 4 Dec 2025 21:05:17 -0500 Subject: [PATCH 10/33] Implement Normal, a solver applying an inner solver to the normal equations (#159) --- docs/api/solvers.md | 16 +-- docs/examples/no_materialisation.ipynb | 2 +- lineax/__init__.py | 1 + lineax/_solver/__init__.py | 1 + lineax/_solver/cg.py | 181 ++++++------------------- lineax/_solver/cholesky.py | 2 +- lineax/_solver/normal.py | 176 ++++++++++++++++++++++++ tests/helpers.py | 8 +- tests/test_adjoint.py | 2 +- tests/test_singular.py | 82 ++++++----- tests/test_solve.py | 2 +- tests/test_well_posed.py | 2 +- 12 files changed, 291 insertions(+), 184 deletions(-) create mode 100644 lineax/_solver/normal.py diff --git a/docs/api/solvers.md b/docs/api/solvers.md index 65cdaeb..623ca16 100644 --- a/docs/api/solvers.md +++ b/docs/api/solvers.md @@ -42,9 +42,16 @@ These are capable of solving ill-posed linear problems. members: - __init__ +--- + +::: lineax.Normal + options: + members: + - __init__ + !!! info - In addition to these, `lineax.Diagonal(well_posed=False)` and [`lineax.NormalCG`][] (below) also support ill-posed problems. + In addition to these, `lineax.Diagonal(well_posed=False)` (below) also supports ill-posed problems. ## Structure-exploiting solvers @@ -95,13 +102,6 @@ These solvers use only matrix-vector products, and do not require instantiating --- -::: lineax.NormalCG - options: - members: - - __init__ - ---- - ::: lineax.BiCGStab options: members: diff --git a/docs/examples/no_materialisation.ipynb b/docs/examples/no_materialisation.ipynb index ec49136..2a35464 100644 --- a/docs/examples/no_materialisation.ipynb +++ b/docs/examples/no_materialisation.ipynb @@ -54,7 +54,7 @@ "y = jnp.array([1.0, 2.0, 3.0])\n", "operator = lx.JacobianLinearOperator(f, y, args=None)\n", "vector = f(y, args=None)\n", - "solver = lx.NormalCG(rtol=1e-6, atol=1e-6)\n", + "solver = lx.Normal(lx.CG(rtol=1e-6, atol=1e-6))\n", "solution = lx.linear_solve(operator, vector, solver)" ] }, diff --git a/lineax/__init__.py b/lineax/__init__.py index bb98cf8..0367041 100644 --- a/lineax/__init__.py +++ b/lineax/__init__.py @@ -60,6 +60,7 @@ GMRES as GMRES, LSMR as LSMR, LU as LU, + Normal as Normal, NormalCG as NormalCG, QR as QR, SVD as SVD, diff --git a/lineax/_solver/__init__.py b/lineax/_solver/__init__.py index a294565..2cee02c 100644 --- a/lineax/_solver/__init__.py +++ b/lineax/_solver/__init__.py @@ -19,6 +19,7 @@ from .gmres import GMRES as GMRES from .lsmr import LSMR as LSMR from .lu import LU as LU +from .normal import Normal as Normal from .qr import QR as QR from .svd import SVD as SVD from .triangular import Triangular as Triangular diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index 16d6ac5..ddc8c13 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, ClassVar, TYPE_CHECKING, TypeAlias +from typing import Any, TypeAlias import equinox.internal as eqxi import jax @@ -23,12 +23,6 @@ from equinox.internal import ω from jaxtyping import Array, PyTree, Scalar - -if TYPE_CHECKING: - from typing import ClassVar as AbstractClassVar -else: - from equinox.internal import AbstractClassVar - from .._misc import resolve_rcond, structure_equal, tree_where from .._norm import max_norm, tree_dot from .._operator import ( @@ -41,6 +35,7 @@ from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import preconditioner_and_y0 +from .normal import Normal _CGState: TypeAlias = tuple[AbstractLinearOperator, bool] @@ -48,17 +43,34 @@ # TODO(kidger): this is pretty slow to compile. # - CG evaluates `operator.mv` three times. -# - Normal CG evaluates `operator.mv` seven (!) times. # Possibly this can be cheapened a bit somehow? -class _AbstractCG(AbstractLinearSolver[_CGState]): +class CG(AbstractLinearSolver[_CGState]): + """Conjugate gradient solver for linear systems. + + The operator should be positive or negative definite. + + Equivalent to `scipy.sparse.linalg.cg`. + + This supports the following `options` (as passed to + `lx.linear_solve(..., options=...)`). + + - `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][] + to be used as preconditioner. Defaults to + [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning, + so it is the preconditioned residual that is minimized, though the actual + termination criteria uses the un-preconditioned residual. + + - `y0`: The initial estimate of the solution to the linear system. Defaults to all + zeros. + + """ + rtol: float atol: float norm: Callable[[PyTree], Scalar] = max_norm stabilise_every: int | None = 10 max_steps: int | None = None - _normal: AbstractClassVar[bool] - def __check_init__(self): if isinstance(self.rtol, (int, float)) and self.rtol < 0: raise ValueError("Tolerances must be non-negative.") @@ -75,18 +87,18 @@ def __check_init__(self): def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): del options is_nsd = is_negative_semidefinite(operator) - if not self._normal: - if not structure_equal(operator.in_structure(), operator.out_structure()): - raise ValueError( - "`CG()` may only be used for linear solves with square matrices." - ) - if not (is_positive_semidefinite(operator) | is_nsd): - raise ValueError( - "`CG()` may only be used for positive " - "or negative definite linear operators" - ) - if is_nsd: - operator = -operator + if not structure_equal(operator.in_structure(), operator.out_structure()): + raise ValueError( + "`CG()` may only be used for linear solves with square matrices." + ) + if not (is_positive_semidefinite(operator) | is_nsd): + raise ValueError( + "`CG()` may only be used for positive " + "or negative definite linear operators" + ) + if is_nsd: + operator = -operator + operator = linearise(operator) return operator, is_nsd # This differs from jax.scipy.sparse.linalg.cg in: @@ -103,46 +115,16 @@ def compute( ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: operator, is_nsd = state preconditioner, y0 = preconditioner_and_y0(operator, vector, options) - if self._normal: - # Linearise if JacobianLinearOperator, to avoid computing the forward - # pass separately for mv and transpose_mv. - # This choice is "fast by default", even at the expense of memory. - # If a downstream user wants to avoid this then they can call - # ``` - # linear_solve( - # conj(operator.T) @ operator, operator.mv(b), solver=CG() - # ) - # ``` - # directly. - operator = linearise(operator) - preconditioner = linearise(preconditioner) - - _mv = operator.mv - _transpose_mv = conj(operator.transpose()).mv - _pmv = preconditioner.mv - _transpose_pmv = conj(preconditioner.transpose()).mv - - def mv(vector: PyTree) -> PyTree: - return _transpose_mv(_mv(vector)) - - def psolve(vector: PyTree) -> PyTree: - return _pmv(_transpose_pmv(vector)) - - vector = _transpose_mv(vector) - else: - if not is_positive_semidefinite(preconditioner): - raise ValueError("The preconditioner must be positive definite.") - mv = operator.mv - psolve = preconditioner.mv - + if not is_positive_semidefinite(preconditioner): + raise ValueError("The preconditioner must be positive definite.") leaves, _ = jtu.tree_flatten(vector) size = sum(leaf.size for leaf in leaves) if self.max_steps is None: max_steps = 10 * size # Copied from SciPy! else: max_steps = self.max_steps - r0 = (vector**ω - mv(y0) ** ω).ω - p0 = psolve(r0) + r0 = (vector**ω - operator.mv(y0) ** ω).ω + p0 = preconditioner.mv(r0) gamma0 = tree_dot(p0, r0) rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves)) initial_value = ( @@ -184,7 +166,7 @@ def cond_fun(value): def body_fun(value): _, y, r, p, gamma, step = value - mat_p = mv(p) + mat_p = operator.mv(p) inner_prod = tree_dot(mat_p, p) alpha = gamma / inner_prod alpha = tree_where( @@ -201,7 +183,7 @@ def body_fun(value): # We compute the residual the "expensive" way every now and again, so as to # correct numerical rounding errors. def stable_r(): - return (vector**ω - mv(y) ** ω).ω + return (vector**ω - operator.mv(y) ** ω).ω def cheap_r(): return (r**ω - alpha * mat_p**ω).ω @@ -215,7 +197,7 @@ def cheap_r(): stable_step = eqxi.nonbatchable(stable_step) r = lax.cond(stable_step, stable_r, cheap_r) - z = psolve(r) + z = preconditioner.mv(r) gamma_prev = gamma gamma = tree_dot(z, r) beta = gamma / gamma_prev @@ -239,7 +221,7 @@ def cheap_r(): RESULTS.successful, ) - if is_nsd and not self._normal: + if is_nsd: solution = -(solution**ω).ω stats = {"num_steps": num_steps, "max_steps": self.max_steps} return solution, result, stats @@ -260,66 +242,6 @@ def conj(self, state: _CGState, options: dict[str, Any]): conj_state = conj(psd_op), is_nsd return conj_state, conj_options - -class CG(_AbstractCG): - """Conjugate gradient solver for linear systems. - - The operator should be positive or negative definite. - - Equivalent to `scipy.sparse.linalg.cg`. - - This supports the following `options` (as passed to - `lx.linear_solve(..., options=...)`). - - - `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][] - to be used as preconditioner. Defaults to - [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning, - so it is the preconditioned residual that is minimized, though the actual - termination criteria uses the un-preconditioned residual. - - `y0`: The initial estimate of the solution to the linear system. Defaults to all - zeros. - - !!! info - - - """ - - _normal: ClassVar[bool] = False - - def assume_full_rank(self): - return True - - -class NormalCG(_AbstractCG): - """Conjugate gradient applied to the normal equations: - - `A^T A = A^T b` - - of a system of linear equations. Note that this squares the condition - number, so it is not recommended. This is a fast but potentially inaccurate - method, especially in 32 bit floating point precision. - - This can handle nonsquare operators provided they are full-rank. - - This supports the following `options` (as passed to - `lx.linear_solve(..., options=...)`). - - - `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as - preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. Note that - the preconditioner should approximate the inverse of `A`, not the inverse of - `A^T A`. This method uses left preconditioning, so it is the preconditioned - residual that is minimized, though the actual termination criteria uses - the un-preconditioned residual. - - `y0`: The initial estimate of the solution to the linear system. Defaults to all - zeros. - - !!! info - - - """ - - _normal: ClassVar[bool] = True - def assume_full_rank(self): return True @@ -341,19 +263,6 @@ def assume_full_rank(self): than this are required, then the solve is halted with a failure. """ -NormalCG.__init__.__doc__ = r"""**Arguments:** -- `rtol`: Relative tolerance for terminating solve. -- `atol`: Absolute tolerance for terminating solve. -- `norm`: The norm to use when computing whether the error falls within the tolerance. - Defaults to the max norm. -- `stabilise_every`: The conjugate gradient is an iterative method that produces - candidate solutions $x_1, x_2, \ldots$, and terminates once $r_i = \| Ax_i - b \|$ - is small enough. For computational efficiency, the values $r_i$ are computed using - other internal quantities, and not by directly evaluating the formula above. - However, this computation of $r_i$ is susceptible to drift due to limited - floating-point precision. Every `stabilise_every` steps, then $r_i$ is computed - directly using the formula above, in order to stabilise the computation. -- `max_steps`: The maximum number of iterations to run the solver for. If more steps - than this are required, then the solve is halted with a failure. -""" +def NormalCG(*args): + return Normal(CG(*args)) diff --git a/lineax/_solver/cholesky.py b/lineax/_solver/cholesky.py index a785780..58d3d29 100644 --- a/lineax/_solver/cholesky.py +++ b/lineax/_solver/cholesky.py @@ -57,7 +57,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): if is_nsd: matrix = -matrix factor, lower = jsp.linalg.cho_factor(matrix) - # Fix lower triangular for simplicity. + # Fix upper triangular for simplicity. assert lower is False return factor, is_nsd diff --git a/lineax/_solver/normal.py b/lineax/_solver/normal.py new file mode 100644 index 0000000..cacd837 --- /dev/null +++ b/lineax/_solver/normal.py @@ -0,0 +1,176 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import copy +from typing import Any, TypeVar + +from jaxtyping import Array, PyTree + +from .._operator import ( + conj, + TaggedLinearOperator, +) +from .._solution import RESULTS +from .._solve import AbstractLinearOperator, AbstractLinearSolver +from .._tags import positive_semidefinite_tag + + +_InnerSolverState = TypeVar("_InnerSolverState") + + +def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool): + preconditioner = options.get("preconditioner") + y0 = options.get("y0") + inner_options = copy(options) + del options + if preconditioner is not None: + if tall: + inner_options["preconditioner"] = TaggedLinearOperator( + preconditioner @ conj(preconditioner.transpose()), + positive_semidefinite_tag, + ) + else: + inner_options["preconditioner"] = TaggedLinearOperator( + conj(preconditioner.transpose()) @ preconditioner, + positive_semidefinite_tag, + ) + if preconditioner is not None and y0 is not None and not tall: + inner_options["y0"] = conj(preconditioner.transpose()).mv(y0) + return inner_options + + +class Normal( + AbstractLinearSolver[ + tuple[_InnerSolverState, bool, AbstractLinearOperator, dict[str, Any]] + ] +): + """Wrapper for an inner solver of positive (semi)definite systems. The + wrapped solver handles possibly nonsquare systems $Ax = b$ by applying the + inner solver to the normal equations + + $A^* A x = A^* b$ + + if $m \\ge n$, otherwise + + $A A^* y = b$, + + where $x = A^* y$. + + If the inner solver solves systems with positive definite $A$, the wrapped + solver solves systems with full rank $A$. + + If the inner solver solves systems with positive semidefinite $A$, the + wrapped solver solves systems with arbitrary, possibly rank deficient, $A$. + + Note that this squares the condition number, so applying this method to an + iterative inner solver may result in slow convergence and high sensitivity + to roundoff error. In this case it may be advantageous to choose an + appropriate preconditioner or initial solution guess for the problem. + + This wrapper adjusts the following `options` before passing to the inner + operator (as passed to `lx.linear_solve(..., options=...)`). + + - `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as + preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. This + should be an approximation of the (pseudo)inverse of $A$. When passed + to the inner solver, the preconditioner $M$ is replaced by $M M^*$ and + $M^* M$ in the first and second versions of the normal equations, + respectively. + + - `y0`: An initial estimate of the solution of the linear system $Ax = b$. + Defaults to all zeros. In the second version of the normal equations, + $y_0$ is replaced with $M^* y_0$, where $M$ is the given outer + preconditioner. + + !!! Info + + Good choices of inner solvers are the direct [`lineax.Cholesky`][] and + the iterative [`lineax.CG`][]. + + """ + + inner_solver: AbstractLinearSolver[_InnerSolverState] + + def init(self, operator, options): + tall = operator.out_size() >= operator.in_size() + if tall: + inner_operator = conj(operator.transpose()) @ operator + else: + inner_operator = operator @ conj(operator.transpose()) + inner_operator = TaggedLinearOperator(inner_operator, positive_semidefinite_tag) + inner_options = normal_preconditioner_and_y0(options, tall) + inner_state = self.inner_solver.init(inner_operator, inner_options) + operator_conj_transpose = conj(operator.transpose()) + return inner_state, tall, operator_conj_transpose, inner_options + + def compute( + self, + state: tuple[_InnerSolverState, bool, AbstractLinearOperator, dict[str, Any]], + vector: PyTree[Array], + options: dict[str, Any], + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + inner_state, tall, operator_conj_transpose, inner_options = state + del state, options + if tall: + vector = operator_conj_transpose.mv(vector) + solution, result, extra_stats = self.inner_solver.compute( + inner_state, vector, inner_options + ) + if not tall: + solution = operator_conj_transpose.mv(solution) + return solution, result, extra_stats + + def transpose( + self, + state: tuple[_InnerSolverState, bool, AbstractLinearOperator, dict[str, Any]], + options: dict[str, Any], + ): + inner_state, tall, operator_conj_transpose, inner_options = state + inner_state_conj, inner_options = self.inner_solver.conj( + inner_state, inner_options + ) + state_transpose = ( + inner_state_conj, + not tall, + operator_conj_transpose.transpose(), + inner_options, + ) + return state_transpose, options + + def conj( + self, + state: tuple[_InnerSolverState, bool, AbstractLinearOperator, dict[str, Any]], + options: dict[str, Any], + ): + inner_state, tall, operator_conj_transpose, inner_options = state + inner_state_conj, inner_options = self.inner_solver.conj( + inner_state, inner_options + ) + state_conj = ( + inner_state_conj, + tall, + conj(operator_conj_transpose), + inner_options, + ) + return state_conj, options + + def assume_full_rank(self): + return self.inner_solver.assume_full_rank() + + +Normal.__init__.__doc__ = """**Arguments:** + +- `inner_solver`: The solver to wrap. It should support solving positive + definite systems or positive semidefinite systems +""" diff --git a/tests/helpers.py b/tests/helpers.py index c427bc0..4462300 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -57,7 +57,7 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype, i): def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): - if isinstance(solver, lx.NormalCG): + if isinstance(solver, lx.Normal): cond_cutoff = math.sqrt(1000) else: cond_cutoff = 1000 @@ -109,13 +109,13 @@ def construct_poisson_matrix(size, dtype=jnp.float64): (lx.SVD(), (), True), (lx.BiCGStab(rtol=tol, atol=tol), (), False), (lx.GMRES(rtol=tol, atol=tol), (), False), - (lx.NormalCG(rtol=tol, atol=tol), (), False), (lx.CG(rtol=tol, atol=tol), lx.positive_semidefinite_tag, False), (lx.CG(rtol=tol, atol=tol), lx.negative_semidefinite_tag, False), - (lx.NormalCG(rtol=tol, atol=tol), lx.negative_semidefinite_tag, False), + (lx.Normal(lx.CG(rtol=tol, atol=tol)), (), False), + (lx.LSMR(atol=tol, rtol=tol), (), True), (lx.Cholesky(), lx.positive_semidefinite_tag, False), (lx.Cholesky(), lx.negative_semidefinite_tag, False), - (lx.LSMR(atol=tol, rtol=tol), (), True), + (lx.Normal(lx.Cholesky()), (), False), ] solvers_tags = [(a, b) for a, b, _ in solvers_tags_pseudoinverse] solvers = [a for a, _, _ in solvers_tags_pseudoinverse] diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 984b84d..8e693b1 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -84,7 +84,7 @@ def fn(y): # complicated, see gh #160 lx.GMRES(tol, tol, max_steps=4, restart=1), lx.BiCGStab(tol, tol, max_steps=3), - lx.NormalCG(tol, tol, max_steps=4), + lx.Normal(lx.CG(tol, tol, max_steps=4)), lx.CG(tol, tol, max_steps=3), ], ) diff --git a/tests/test_singular.py b/tests/test_singular.py index 26bd209..04c18ec 100644 --- a/tests/test_singular.py +++ b/tests/test_singular.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import functools as ft import equinox as eqx @@ -107,6 +106,8 @@ def test_gmres_stagnation_or_breakdown(getkey, dtype): lx.QR(), lx.SVD(), lx.LSMR(atol=tol, rtol=tol), + lx.Normal(lx.Cholesky()), + lx.Normal(lx.SVD()), ), ) def test_nonsquare_pytree_operator1(solver): @@ -128,6 +129,8 @@ def test_nonsquare_pytree_operator1(solver): lx.QR(), lx.SVD(), lx.LSMR(atol=tol, rtol=tol), + lx.Normal(lx.Cholesky()), + lx.Normal(lx.SVD()), ), ) def test_nonsquare_pytree_operator2(solver): @@ -142,11 +145,21 @@ def test_nonsquare_pytree_operator2(solver): assert tree_allclose(out, true_out) +@pytest.mark.parametrize( + "solver", + ( + lx.AutoLinearSolver(well_posed=None), + lx.QR(), + lx.SVD(), + lx.Normal(lx.Cholesky()), + lx.Normal(lx.SVD()), + ), +) @pytest.mark.parametrize("full_rank", (True, False)) @pytest.mark.parametrize("jvp", (False, True)) @pytest.mark.parametrize("wide", (False, True)) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) -def test_qr_nonsquare_mat_vec(full_rank, jvp, wide, dtype, getkey): +def test_nonsquare_mat_vec(solver, full_rank, jvp, wide, dtype, getkey): if wide: out_size = 3 in_size = 6 @@ -154,39 +167,50 @@ def test_qr_nonsquare_mat_vec(full_rank, jvp, wide, dtype, getkey): out_size = 6 in_size = 3 matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype) - if full_rank: - context = contextlib.nullcontext() - else: - context = pytest.raises(Exception) - if wide: - matrix = matrix.at[:, 2:].set(0) - else: - matrix = matrix.at[2:, :].set(0) + if not full_rank: + if solver.assume_full_rank(): + # There is nothing to test. + return + # nontrivial rank 2 sparsity pattern + matrix = matrix.at[1:, 1:].set(0) vector = jr.normal(getkey(), (out_size,), dtype=dtype) lx_solve = lambda mat, vec: lx.linear_solve( - lx.MatrixLinearOperator(mat), vec, lx.QR() + lx.MatrixLinearOperator(mat), vec, solver ).value jnp_solve = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0] # pyright: ignore if jvp: lx_solve = eqx.filter_jit(ft.partial(eqx.filter_jvp, lx_solve)) jnp_solve = eqx.filter_jit(ft.partial(finite_difference_jvp, jnp_solve)) t_matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype) + if not full_rank: + # t_matrix must be chosen tangent to the manifold of rank 2 + # matrices at matrix. A simple way to achieve this is to make the + # same restriction as we did to matrix + t_matrix = t_matrix.at[1:, 1:].set(0) t_vector = jr.normal(getkey(), (out_size,), dtype=dtype) args = ((matrix, vector), (t_matrix, t_vector)) else: args = (matrix, vector) - with context: - x = lx_solve(*args) # pyright: ignore - if full_rank: - true_x = jnp_solve(*args) - assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4) + x = lx_solve(*args) # pyright: ignore + true_x = jnp_solve(*args) + assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4) +@pytest.mark.parametrize( + "solver", + ( + lx.AutoLinearSolver(well_posed=None), + lx.QR(), + lx.SVD(), + lx.Normal(lx.Cholesky()), + lx.Normal(lx.SVD()), + ), +) @pytest.mark.parametrize("full_rank", (True, False)) @pytest.mark.parametrize("jvp", (False, True)) @pytest.mark.parametrize("wide", (False, True)) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) -def test_qr_nonsquare_vec(full_rank, jvp, wide, dtype, getkey): +def test_nonsquare_vec(solver, full_rank, jvp, wide, dtype, getkey): if wide: out_size = 3 in_size = 6 @@ -194,17 +218,15 @@ def test_qr_nonsquare_vec(full_rank, jvp, wide, dtype, getkey): out_size = 6 in_size = 3 matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype) - if full_rank: - context = contextlib.nullcontext() - else: - context = pytest.raises(Exception) - if wide: - matrix = matrix.at[:, 2:].set(0) - else: - matrix = matrix.at[2:, :].set(0) + if not full_rank: + if solver.assume_full_rank(): + # There is nothing to test. + return + # nontrivial rank 2 sparsity pattern + matrix = matrix.at[1:, 1:].set(0) vector = jr.normal(getkey(), (out_size,), dtype=dtype) lx_solve = lambda vec: lx.linear_solve( - lx.MatrixLinearOperator(matrix), vec, lx.QR() + lx.MatrixLinearOperator(matrix), vec, solver ).value jnp_solve = lambda vec: jnp.linalg.lstsq(matrix, vec)[0] # pyright: ignore if jvp: @@ -214,11 +236,9 @@ def test_qr_nonsquare_vec(full_rank, jvp, wide, dtype, getkey): args = ((vector,), (t_vector,)) else: args = (vector,) - with context: - x = lx_solve(*args) # pyright: ignore - if full_rank: - true_x = jnp_solve(*args) - assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4) + x = lx_solve(*args) # pyright: ignore + true_x = jnp_solve(*args) + assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4) _iterative_solvers = ( diff --git a/tests/test_solve.py b/tests/test_solve.py index 0b71429..0cb9f2f 100644 --- a/tests/test_solve.py +++ b/tests/test_solve.py @@ -175,7 +175,7 @@ def to_grad(x): "solver", ( lx.CG(0.0, 0.0, max_steps=2), - lx.NormalCG(0.0, 0.0, max_steps=2), + lx.Normal(lx.CG(0.0, 0.0, max_steps=2)), lx.BiCGStab(0.0, 0.0, max_steps=2), lx.GMRES(0.0, 0.0, max_steps=2), lx.LSMR(0.0, 0.0, max_steps=2), diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py index e377ef4..631ae68 100644 --- a/tests/test_well_posed.py +++ b/tests/test_well_posed.py @@ -53,7 +53,7 @@ def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): def test_pytree_wellposed(solver, getkey, dtype): if not isinstance( solver, - (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.CG, lx.NormalCG), + (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.CG), ): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 From 71db7e0cea4b77cf8d97f50b4ceb96dac23d1d6f Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 5 Dec 2025 09:43:38 +0000 Subject: [PATCH 11/33] Tidy up results. (#170) --- lineax/_solver/bicgstab.py | 10 +++++----- lineax/_solver/cg.py | 14 ++++++-------- lineax/_solver/gmres.py | 11 ++++++----- lineax/_solver/lsmr.py | 14 +++++++------- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/lineax/_solver/bicgstab.py b/lineax/_solver/bicgstab.py index e701aeb..7208745 100644 --- a/lineax/_solver/bicgstab.py +++ b/lineax/_solver/bicgstab.py @@ -187,14 +187,14 @@ def body_fun(carry): if self.max_steps is None: result = RESULTS.where( - (num_steps == max_steps), RESULTS.singular, RESULTS.successful + num_steps == max_steps, RESULTS.singular, RESULTS.successful ) - else: + elif has_scale: result = RESULTS.where( - (num_steps == self.max_steps), - RESULTS.max_steps_reached if has_scale else RESULTS.successful, - RESULTS.successful, + num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful ) + else: + result = RESULTS.successful # breakdown is only an issue if we did not converge breakdown = breakdown_occurred(omega, alpha, rho) & not_converged( residual, diff, solution diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index ddc8c13..2541586 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -208,18 +208,16 @@ def cheap_r(): cond_fun, body_fun, initial_value ) - if (self.max_steps is None) or (max_steps < self.max_steps): + if self.max_steps is None: result = RESULTS.where( - num_steps == max_steps, - RESULTS.singular, - RESULTS.successful, + num_steps == max_steps, RESULTS.singular, RESULTS.successful ) - else: + elif has_scale: result = RESULTS.where( - num_steps == max_steps, - RESULTS.max_steps_reached if has_scale else RESULTS.successful, - RESULTS.successful, + num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful ) + else: + result = RESULTS.successful if is_nsd: solution = -(solution**ω).ω diff --git a/lineax/_solver/gmres.py b/lineax/_solver/gmres.py index edf0f83..b18a80a 100644 --- a/lineax/_solver/gmres.py +++ b/lineax/_solver/gmres.py @@ -220,14 +220,15 @@ def body_fun(carry): if self.max_steps is None: result = RESULTS.where( - (num_steps == max_steps), RESULTS.singular, RESULTS.successful + num_steps == max_steps, RESULTS.singular, RESULTS.successful ) - else: + elif has_scale: result = RESULTS.where( - (num_steps == self.max_steps), - RESULTS.max_steps_reached if has_scale else RESULTS.successful, - RESULTS.successful, + num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful ) + else: + result = RESULTS.successful + result = RESULTS.where( stagnation_counter >= self.stagnation_iters, RESULTS.stagnation, result ) diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index a6ef8a9..7496e77 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -327,19 +327,19 @@ def beta_zero(alpha, beta, u, v): "cond_A": loop_state["condA"], "norm_x": self.norm(loop_state["x"]), } - if (self.max_steps is None) or (max_steps < self.max_steps): + + if self.max_steps is None: result = RESULTS.where( - loop_state["itn"] == max_steps, - RESULTS.singular, - RESULTS.successful, + loop_state["itn"] == max_steps, RESULTS.singular, RESULTS.successful ) - else: + elif has_scale: result = RESULTS.where( loop_state["itn"] == max_steps, - RESULTS.max_steps_reached if has_scale else RESULTS.successful, + RESULTS.max_steps_reached, RESULTS.successful, ) - + else: + result = RESULTS.successful result = RESULTS.where(loop_state["istop"] < 3, RESULTS.successful, result) result = RESULTS.where(loop_state["istop"] == 3, RESULTS.conlim, result) From 9d83182954b198e7306b8689e0ea2eeff2db3c54 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 5 Dec 2025 09:43:52 +0000 Subject: [PATCH 12/33] Update infra (#182) --- .github/workflows/release.yml | 6 +++--- .github/workflows/run_tests.yml | 29 +++++++++---------------- .gitignore | 1 + .pre-commit-config.yaml | 4 ++-- CONTRIBUTING.md | 38 ++++++--------------------------- lineax/_solver/gmres.py | 2 +- pyproject.toml | 32 +++++++++++++++++++++------ tests/requirements.txt | 5 ----- 8 files changed, 48 insertions(+), 69 deletions(-) delete mode 100644 tests/requirements.txt diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 21142eb..70c85a0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,14 +10,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Release - uses: patrick-kidger/action_update_python_project@v6 + uses: patrick-kidger/action_update_python_project@v8 with: python-version: "3.11" test-script: | cp -r ${{ github.workspace }}/tests ./tests cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml - python -m pip install -r ./tests/requirements.txt - python -m tests + uv sync --extra tests --no-install-project --inexact + uv run --no-sync pytest pypi-token: ${{ secrets.pypi_token }} github-user: patrick-kidger github-token: ${{ github.token }} diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 9d1d270..40a956d 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -1,17 +1,3 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - name: Run tests on: @@ -21,7 +7,8 @@ jobs: run-test: strategy: matrix: - python-version: [ "3.10", "3.12" ] + # must match the `language_version` in `.pre-commit-config.yaml` + python-version: [ 3.11 ] os: [ ubuntu-latest ] fail-fast: false runs-on: ${{ matrix.os }} @@ -37,12 +24,16 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -r ./tests/requirements.txt + python -m pip install '.[dev,docs,tests]' - name: Checks with pre-commit - uses: pre-commit/action@v3.0.1 + run: | + pre-commit run --all-files - name: Test with pytest run: | - python -m pip install . - python -m tests + pytest + + - name: Check that documentation can be built. + run: | + mkdocs build diff --git a/.gitignore b/.gitignore index e42ffd9..0c5b718 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ examples/data .pymon .idea .venv +uv.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f50acda..0888df4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: files: ^pyproject\.toml$ additional_dependencies: ["toml-sort==0.23.1"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.7 + rev: v0.13.0 hooks: - id: ruff-format types_or: [ python, pyi, jupyter ] @@ -30,7 +30,7 @@ repos: types_or: [ python, pyi, jupyter ] args: [ --fix ] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.330 + rev: v1.1.406 hooks: - id: pyright additional_dependencies: ["jax", "equinox", "pytest"] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3c0dbe8..fed9667 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,23 +8,15 @@ Contributions (pull requests) are very welcome! Here's how to get started. First fork the library on GitHub. -Then clone and install the library in development mode: +Then clone and install the library: ```bash git clone https://github.com/your-username-here/lineax.git cd lineax -pip install -e . +pip install -e '.[dev]' +pre-commit install # `pre-commit` is installed by `pip` on the previous line ``` -Then install the pre-commit hook: - -```bash -pip install pre-commit -pre-commit install -``` - -These hooks use Black to format the code, and ruff to lint it. - --- **If you're making changes to the code:** @@ -34,8 +26,8 @@ Now make your changes. Make sure to include additional tests if necessary. Next verify the tests all pass: ```bash -pip install -r tests/requirements.txt -python -m tests +pip install -e '.[tests]' +pytest # `pytest` is installed by `pip` on the previous line. ``` Then push your changes back to your fork of the repository: @@ -53,26 +45,8 @@ Finally, open a pull request on GitHub! Make your changes. You can then build the documentation by doing ```bash -pip install -r docs/requirements.txt +pip install -e '.[docs]' mkdocs serve ``` -Then doing `Control-C`, and running: -``` -mkdocs serve -``` -(So you run `mkdocs serve` twice.) You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser. - -## Contributor License Agreement - -Contributions to this project must be accompanied by a Contributor License -Agreement (CLA). You (or your employer) retain the copyright to your -contribution; this simply gives us permission to use and redistribute your -contributions as part of the project. Head over to - to see your current agreements on file or -to sign a new one. - -You generally only need to submit a CLA once, so if you've already submitted one -(even if it was for a different project), you probably don't need to do it -again. diff --git a/lineax/_solver/gmres.py b/lineax/_solver/gmres.py index b18a80a..ccf7e97 100644 --- a/lineax/_solver/gmres.py +++ b/lineax/_solver/gmres.py @@ -410,7 +410,7 @@ def _normalise( eps = jnp.finfo(norm.dtype).eps else: eps = jnp.astype(eps, norm.dtype) - breakdown = norm < eps + breakdown = norm < eps # pyright: ignore safe_norm = jnp.where(breakdown, jnp.inf, norm) with jax.numpy_dtype_promotion("standard"): x_normalised = (x**ω / safe_norm).ω diff --git a/pyproject.toml b/pyproject.toml index da5394e..aa6f053 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,15 +31,27 @@ urls = {repository = "https://github.com/google/lineax"} version = "0.0.8" [project.optional-dependencies] +dev = [ + "pre-commit" +] docs = [ - "hippogriffe==0.2.0", + "hippogriffe==0.2.2", + "griffe==1.7.3", "mkdocs==1.6.1", "mkdocs-include-exclude-files==0.1.0", - "mkdocs-ipynb==0.1.0", + "mkdocs-ipynb==0.1.1", "mkdocs-material==9.6.7", - "mkdocstrings[python]==0.28.3", + "mkdocstrings==0.28.3", + "mkdocstrings-python==1.16.8", "pymdown-extensions==10.14.3" ] +tests = [ + "beartype", + "equinox", + "pytest", + "pytest-xdist", + "jaxlib" +] [tool.hatch.build] include = ["lineax/*"] @@ -53,10 +65,6 @@ addopts = "--jaxtyping-packages=lineax,beartype.beartype(conf=beartype.BeartypeC [tool.ruff] extend-include = ["*.ipynb"] -fixable = ["I001", "F401", "UP"] -ignore = ["E402", "E721", "E731", "E741", "F722", "UP038"] -ignore-init-module-imports = true -select = ["E", "F", "I001", "UP"] src = [] [tool.ruff.isort] @@ -64,3 +72,13 @@ combine-as-imports = true extra-standard-library = ["typing_extensions"] lines-after-imports = 2 order-by-type = false + +[tool.ruff.lint] +fixable = ["I001", "F401", "UP"] +ignore = ["E402", "E721", "E731", "E741", "F722", "UP038"] +select = ["E", "F", "I001", "UP"] + +[tool.ruff.lint.flake8-import-conventions.extend-aliases] +"collections" = "co" +"functools" = "ft" +"itertools" = "it" diff --git a/tests/requirements.txt b/tests/requirements.txt deleted file mode 100644 index 61f8088..0000000 --- a/tests/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -beartype -equinox -pytest -pytest-xdist -jaxlib From 71d1296160e8a8927dd7cb487c5d9e215d320442 Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Fri, 14 Nov 2025 09:01:25 +0100 Subject: [PATCH 13/33] Decrease default value to prevent overflow in 32-bit. --- lineax/_solver/lsmr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index 7496e77..524a7b2 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -178,7 +178,7 @@ def beta_zero(beta, u): # variables for estimation of ||A|| and cond(A) normA2=alpha**2, maxrbar=0.0, - minrbar=1e100, + minrbar=1.0e38, condA=1.0, # variables for use in stopping rules istop=0, From 60d654acad7edc6fa7041f1307c1090621190681 Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Sat, 15 Nov 2025 18:32:22 +0100 Subject: [PATCH 14/33] use maximum value of runtime dtype --- lineax/_solver/lsmr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index 524a7b2..ad01a66 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -178,7 +178,7 @@ def beta_zero(beta, u): # variables for estimation of ||A|| and cond(A) normA2=alpha**2, maxrbar=0.0, - minrbar=1.0e38, + minrbar=jnp.finfo(dtype).max, condA=1.0, # variables for use in stopping rules istop=0, From 3c8df62a94c1bab49659a6a0dc53c1ea04a377aa Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Tue, 25 Nov 2025 21:53:14 +0100 Subject: [PATCH 15/33] limit condition numbers for sensitive, iterative solvers --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 4462300..54cd5f6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -57,7 +57,7 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype, i): def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): - if isinstance(solver, lx.Normal): + if isinstance(solver, lx.NormalCG | lx.LSMR | lx.GMRES): cond_cutoff = math.sqrt(1000) else: cond_cutoff = 1000 From fadd714a3bb05d361f029989996db2f942554bb6 Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Tue, 25 Nov 2025 23:22:32 +0100 Subject: [PATCH 16/33] cond_cutoff -> 900 for GMRES, LSMR --- tests/helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 54cd5f6..ff406d4 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -57,8 +57,10 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype, i): def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): - if isinstance(solver, lx.NormalCG | lx.LSMR | lx.GMRES): + if isinstance(solver, lx.NormalCG): cond_cutoff = math.sqrt(1000) + elif isinstance(solver, lx.GMRES | lx.LSMR): + cond_cutoff = 900 else: cond_cutoff = 1000 return tuple( From 895fa23cbbcb811fea512920cbb85bb162b0eba5 Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Wed, 26 Nov 2025 01:31:58 +0100 Subject: [PATCH 17/33] decrease condition number cutoff for LSMR --- tests/helpers.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index ff406d4..9e125d4 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -59,8 +59,17 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype, i): def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): if isinstance(solver, lx.NormalCG): cond_cutoff = math.sqrt(1000) - elif isinstance(solver, lx.GMRES | lx.LSMR): + # Comment (johannahaffner): Some of our iterative solvers do struggle with high + # condition numbers (as they are expected to do). The way this plays out can depend + # on the JAX version and on the platform the code runs - with a given seed and some + # new and optimised XLA kernel fusion, a test may fail. + # The cutoffs below are a little empirical, I've chosen them to be as high as + # possible to ensure that we are still testing a real scenario, without generating + # operators that the iterative solvers aren't built for. + elif isinstance(solver, lx.GMRES): cond_cutoff = 900 + elif isinstance(solver, lx.LSMR): + cond_cutoff = 800 else: cond_cutoff = 1000 return tuple( From daed0a430b7f85f3d655d2c4bf3c468252d7194c Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Wed, 26 Nov 2025 23:23:18 +0100 Subject: [PATCH 18/33] implement safeguard for int overflow in 32 bit --- lineax/_solver/lsmr.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index ad01a66..d48b885 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -112,7 +112,11 @@ def compute( # number of singular values min_dim = min([m, n]) if self.max_steps is None: - max_steps = min_dim * 10 # for consistency with other iterative solvers + int_dtype = jnp.dtype(f"int{operator.in_structure().dtype.itemsize * 8}") + if min_dim > (jnp.iinfo(int_dtype).max / 10): + max_steps = jnp.iinfo(int_dtype).max + else: + max_steps = min_dim * 10 # for consistency with other iterative solvers else: max_steps = self.max_steps From 20ef0688fb4f85c28227611723a64de7748afed2 Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Sat, 29 Nov 2025 18:55:25 +0100 Subject: [PATCH 19/33] add extra safeguard for complex dtypes --- lineax/_solver/lsmr.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index d48b885..8d0db79 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -108,11 +108,24 @@ def compute( and self.rtol == 0 ) + dtype = jnp.result_type( + *jtu.tree_leaves(vector), + *jtu.tree_leaves(x), + *jtu.tree_leaves(operator.in_structure()), + ) + m, n = operator.out_size(), operator.in_size() # number of singular values min_dim = min([m, n]) if self.max_steps is None: - int_dtype = jnp.dtype(f"int{operator.in_structure().dtype.itemsize * 8}") + # Set max_steps based on the minimum dimension + avoid numerical overflows + # https://github.com/patrick-kidger/lineax/issues/175 + # https://github.com/patrick-kidger/lineax/issues/177 + if jnp.issubdtype(dtype, jnp.complexfloating): + real_dtype = jnp.finfo(dtype).dtype + else: + real_dtype = dtype + int_dtype = jnp.dtype(f"int{real_dtype.itemsize * 8}") if min_dim > (jnp.iinfo(int_dtype).max / 10): max_steps = jnp.iinfo(int_dtype).max else: @@ -123,12 +136,6 @@ def compute( if x is None: x = jtu.tree_map(jnp.zeros_like, operator.in_structure()) - dtype = jnp.result_type( - *jtu.tree_leaves(vector), - *jtu.tree_leaves(x), - *jtu.tree_leaves(operator.in_structure()), - ) - b = vector u = (ω(b) - ω(operator.mv(x))).ω normb = self.norm(b) From 08ceabec6e8413584ee0cfbfd4ba0930ded330ef Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Sat, 13 Dec 2025 12:37:18 +0100 Subject: [PATCH 20/33] fix typo --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 9e125d4..aa99ee1 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -57,7 +57,7 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype, i): def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): - if isinstance(solver, lx.NormalCG): + if isinstance(solver, lx.Normal): cond_cutoff = math.sqrt(1000) # Comment (johannahaffner): Some of our iterative solvers do struggle with high # condition numbers (as they are expected to do). The way this plays out can depend From 9891748387e12bfa9607549fe1240818619581e9 Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Sat, 13 Dec 2025 14:51:15 +0100 Subject: [PATCH 21/33] reset condition number check --- tests/helpers.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index aa99ee1..4462300 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -59,17 +59,6 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype, i): def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): if isinstance(solver, lx.Normal): cond_cutoff = math.sqrt(1000) - # Comment (johannahaffner): Some of our iterative solvers do struggle with high - # condition numbers (as they are expected to do). The way this plays out can depend - # on the JAX version and on the platform the code runs - with a given seed and some - # new and optimised XLA kernel fusion, a test may fail. - # The cutoffs below are a little empirical, I've chosen them to be as high as - # possible to ensure that we are still testing a real scenario, without generating - # operators that the iterative solvers aren't built for. - elif isinstance(solver, lx.GMRES): - cond_cutoff = 900 - elif isinstance(solver, lx.LSMR): - cond_cutoff = 800 else: cond_cutoff = 1000 return tuple( From d29720947022ec761d2a8609de061950d6ec6e8b Mon Sep 17 00:00:00 2001 From: Johanna Haffner Date: Sat, 13 Dec 2025 15:51:59 +0100 Subject: [PATCH 22/33] simplify dtype conversion: reuse function from our misc module --- lineax/_solver/lsmr.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index 8d0db79..8a58f39 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -42,6 +42,7 @@ from equinox.internal import ω from jaxtyping import Array, PyTree +from .._misc import complex_to_real_dtype from .._norm import two_norm from .._operator import AbstractLinearOperator, conj from .._solution import RESULTS @@ -121,11 +122,7 @@ def compute( # Set max_steps based on the minimum dimension + avoid numerical overflows # https://github.com/patrick-kidger/lineax/issues/175 # https://github.com/patrick-kidger/lineax/issues/177 - if jnp.issubdtype(dtype, jnp.complexfloating): - real_dtype = jnp.finfo(dtype).dtype - else: - real_dtype = dtype - int_dtype = jnp.dtype(f"int{real_dtype.itemsize * 8}") + int_dtype = jnp.dtype(f"int{complex_to_real_dtype(dtype).itemsize * 8}") if min_dim > (jnp.iinfo(int_dtype).max / 10): max_steps = jnp.iinfo(int_dtype).max else: From 9c1b9a9f3e2a3ae55235a996fec357444d8a7d95 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 22 Dec 2025 22:21:49 +0100 Subject: [PATCH 23/33] Moved bool states to Static. There seem to be some spurious downstream failures in Diffrax with JAX 0.8.2 otherwise. Probably JAX has started promoting these to tracers on some unusual codepath. --- lineax/_operator.py | 6 ++++-- lineax/_solver/cg.py | 13 +++++++++---- lineax/_solver/cholesky.py | 14 ++++++++++---- lineax/_solver/lu.py | 18 ++++++++++++++---- lineax/_solver/normal.py | 20 ++++++++++++++------ lineax/_solver/qr.py | 12 +++++++++--- lineax/_solver/triangular.py | 18 +++++++++++++----- pyproject.toml | 14 +++++++------- tests/__main__.py | 2 ++ 9 files changed, 82 insertions(+), 35 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 7d83693..a83eeff 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -255,7 +255,7 @@ def __init__( raise ValueError( "`MatrixLinearOperator(matrix=...)` should be 2-dimensional." ) - if not jnp.issubdtype(matrix, jnp.inexact): + if not jnp.issubdtype(matrix.dtype, jnp.inexact): matrix = matrix.astype(jnp.float32) self.matrix = matrix self.tags = _frozenset(tags) @@ -397,7 +397,9 @@ def sub_get_structure(leaf): raise ValueError( "`pytree` and `output_structure` are not consistent" ) - return jax.ShapeDtypeStruct(shape=shape[ndim:], dtype=jnp.dtype(leaf)) + return jax.ShapeDtypeStruct( + shape=shape[ndim:], dtype=jnp.result_type(leaf) + ) return _Leaf(jtu.tree_map(sub_get_structure, subpytree)) diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index 2541586..159aaf5 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -38,7 +38,7 @@ from .normal import Normal -_CGState: TypeAlias = tuple[AbstractLinearOperator, bool] +_CGState: TypeAlias = tuple[AbstractLinearOperator, eqxi.Static] # TODO(kidger): this is pretty slow to compile. @@ -99,7 +99,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): if is_nsd: operator = -operator operator = linearise(operator) - return operator, is_nsd + return operator, eqxi.Static(is_nsd) # This differs from jax.scipy.sparse.linalg.cg in: # 1. Every few steps we calculate the residual directly, rather than by cheaply @@ -114,6 +114,7 @@ def compute( self, state: _CGState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: operator, is_nsd = state + is_nsd = is_nsd.value preconditioner, y0 = preconditioner_and_y0(operator, vector, options) if not is_positive_semidefinite(preconditioner): raise ValueError("The preconditioner must be positive definite.") @@ -224,7 +225,9 @@ def cheap_r(): stats = {"num_steps": num_steps, "max_steps": self.max_steps} return solution, result, stats - def transpose(self, state: _CGState, options: dict[str, Any]): + def transpose( + self, state: _CGState, options: dict[str, Any] + ) -> tuple[_CGState, dict[str, Any]]: transpose_options = {} if "preconditioner" in options: transpose_options["preconditioner"] = options["preconditioner"].transpose() @@ -232,7 +235,9 @@ def transpose(self, state: _CGState, options: dict[str, Any]): transpose_state = psd_op.transpose(), is_nsd return transpose_state, transpose_options - def conj(self, state: _CGState, options: dict[str, Any]): + def conj( + self, state: _CGState, options: dict[str, Any] + ) -> tuple[_CGState, dict[str, Any]]: conj_options = {} if "preconditioner" in options: conj_options["preconditioner"] = conj(options["preconditioner"]) diff --git a/lineax/_solver/cholesky.py b/lineax/_solver/cholesky.py index 58d3d29..852ab70 100644 --- a/lineax/_solver/cholesky.py +++ b/lineax/_solver/cholesky.py @@ -14,6 +14,7 @@ from typing import Any, TypeAlias +import equinox.internal as eqxi import jax.flatten_util as jfu import jax.scipy as jsp from jaxtyping import Array, PyTree @@ -27,7 +28,7 @@ from .._solve import AbstractLinearSolver -_CholeskyState: TypeAlias = tuple[Array, bool] +_CholeskyState: TypeAlias = tuple[Array, eqxi.Static] class Cholesky(AbstractLinearSolver[_CholeskyState]): @@ -59,12 +60,13 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): factor, lower = jsp.linalg.cho_factor(matrix) # Fix upper triangular for simplicity. assert lower is False - return factor, is_nsd + return factor, eqxi.Static(is_nsd) def compute( self, state: _CholeskyState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: factor, is_nsd = state + is_nsd = is_nsd.value del options # Cholesky => PSD => symmetric => (in_structure == out_structure) => # we don't need to use packed structures. @@ -75,12 +77,16 @@ def compute( solution = unflatten(solution) return solution, RESULTS.successful, {} - def transpose(self, state: _CholeskyState, options: dict[str, Any]): + def transpose( + self, state: _CholeskyState, options: dict[str, Any] + ) -> tuple[_CholeskyState, dict[str, Any]]: # Matrix is self-adjoint factor, is_nsd = state return (factor.conj(), is_nsd), options - def conj(self, state: _CholeskyState, options: dict[str, Any]): + def conj( + self, state: _CholeskyState, options: dict[str, Any] + ) -> tuple[_CholeskyState, dict[str, Any]]: # Matrix is self-adjoint factor, is_nsd = state return (factor.conj(), is_nsd), options diff --git a/lineax/_solver/lu.py b/lineax/_solver/lu.py index d118991..7283600 100644 --- a/lineax/_solver/lu.py +++ b/lineax/_solver/lu.py @@ -14,6 +14,7 @@ from typing import Any, TypeAlias +import equinox.internal as eqxi import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Array, PyTree @@ -30,7 +31,7 @@ ) -_LUState: TypeAlias = tuple[tuple[Array, Array], PackedStructures, bool] +_LUState: TypeAlias = tuple[tuple[Array, Array], PackedStructures, eqxi.Static] class LU(AbstractLinearSolver[_LUState]): @@ -50,13 +51,14 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): lu = operator.as_matrix(), jnp.arange(operator.in_size(), dtype=jnp.int32) else: lu = jsp.linalg.lu_factor(operator.as_matrix()) - return lu, packed_structures, False + return lu, packed_structures, eqxi.Static(False) def compute( self, state: _LUState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: del options lu_and_piv, packed_structures, transpose = state + transpose = transpose.value trans = 1 if transpose else 0 vector = ravel_vector(vector, packed_structures) solution = jsp.linalg.lu_solve(lu_and_piv, vector, trans=trans) @@ -70,7 +72,11 @@ def transpose( ): lu_and_piv, packed_structures, transpose = state transposed_packed_structures = transpose_packed_structures(packed_structures) - transpose_state = lu_and_piv, transposed_packed_structures, not transpose + transpose_state = ( + lu_and_piv, + transposed_packed_structures, + eqxi.Static(not transpose.value), + ) transpose_options = {} return transpose_state, transpose_options @@ -80,7 +86,11 @@ def conj( options: dict[str, Any], ): (lu, piv), packed_structures, transpose = state - conj_state = (lu.conj(), piv), packed_structures, not transpose + conj_state = ( + (lu.conj(), piv), + packed_structures, + eqxi.Static(not transpose.value), + ) conj_options = {} return conj_state, conj_options diff --git a/lineax/_solver/normal.py b/lineax/_solver/normal.py index cacd837..a2d4266 100644 --- a/lineax/_solver/normal.py +++ b/lineax/_solver/normal.py @@ -15,6 +15,7 @@ from copy import copy from typing import Any, TypeVar +import equinox.internal as eqxi from jaxtyping import Array, PyTree from .._operator import ( @@ -52,7 +53,7 @@ def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool): class Normal( AbstractLinearSolver[ - tuple[_InnerSolverState, bool, AbstractLinearOperator, dict[str, Any]] + tuple[_InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any]] ] ): """Wrapper for an inner solver of positive (semi)definite systems. The @@ -112,15 +113,18 @@ def init(self, operator, options): inner_options = normal_preconditioner_and_y0(options, tall) inner_state = self.inner_solver.init(inner_operator, inner_options) operator_conj_transpose = conj(operator.transpose()) - return inner_state, tall, operator_conj_transpose, inner_options + return inner_state, eqxi.Static(tall), operator_conj_transpose, inner_options def compute( self, - state: tuple[_InnerSolverState, bool, AbstractLinearOperator, dict[str, Any]], + state: tuple[ + _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any] + ], vector: PyTree[Array], options: dict[str, Any], ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: inner_state, tall, operator_conj_transpose, inner_options = state + tall = tall.value del state, options if tall: vector = operator_conj_transpose.mv(vector) @@ -133,7 +137,9 @@ def compute( def transpose( self, - state: tuple[_InnerSolverState, bool, AbstractLinearOperator, dict[str, Any]], + state: tuple[ + _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any] + ], options: dict[str, Any], ): inner_state, tall, operator_conj_transpose, inner_options = state @@ -142,7 +148,7 @@ def transpose( ) state_transpose = ( inner_state_conj, - not tall, + eqxi.Static(not tall.value), operator_conj_transpose.transpose(), inner_options, ) @@ -150,7 +156,9 @@ def transpose( def conj( self, - state: tuple[_InnerSolverState, bool, AbstractLinearOperator, dict[str, Any]], + state: tuple[ + _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any] + ], options: dict[str, Any], ): inner_state, tall, operator_conj_transpose, inner_options = state diff --git a/lineax/_solver/qr.py b/lineax/_solver/qr.py index 2551e95..4caa784 100644 --- a/lineax/_solver/qr.py +++ b/lineax/_solver/qr.py @@ -14,6 +14,7 @@ from typing import Any, TypeAlias +import equinox.internal as eqxi import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Array, PyTree @@ -29,7 +30,7 @@ ) -_QRState: TypeAlias = tuple[tuple[Array, Array], bool, PackedStructures] +_QRState: TypeAlias = tuple[tuple[Array, Array], eqxi.Static, PackedStructures] class QR(AbstractLinearSolver): @@ -59,7 +60,7 @@ def init(self, operator, options): matrix = matrix.T qr = jnp.linalg.qr(matrix, mode="reduced") # pyright: ignore packed_structures = pack_structures(operator) - return qr, transpose, packed_structures + return qr, eqxi.Static(transpose), packed_structures def compute( self, @@ -68,6 +69,7 @@ def compute( options: dict[str, Any], ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: (q, r), transpose, packed_structures = state + transpose = transpose.value del state, options vector = ravel_vector(vector, packed_structures) if transpose: @@ -86,7 +88,11 @@ def compute( def transpose(self, state: _QRState, options: dict[str, Any]): (q, r), transpose, structures = state transposed_packed_structures = transpose_packed_structures(structures) - transpose_state = (q, r), not transpose, transposed_packed_structures + transpose_state = ( + (q, r), + eqxi.Static(not transpose.value), + transposed_packed_structures, + ) transpose_options = {} return transpose_state, transpose_options diff --git a/lineax/_solver/triangular.py b/lineax/_solver/triangular.py index f926295..304f21a 100644 --- a/lineax/_solver/triangular.py +++ b/lineax/_solver/triangular.py @@ -14,6 +14,7 @@ from typing import Any, TypeAlias +import equinox.internal as eqxi import jax.scipy as jsp from jaxtyping import Array, PyTree @@ -34,7 +35,9 @@ ) -_TriangularState: TypeAlias = tuple[Array, bool, bool, PackedStructures, bool] +_TriangularState: TypeAlias = tuple[ + Array, eqxi.Static, eqxi.Static, PackedStructures, eqxi.Static +] class Triangular(AbstractLinearSolver[_TriangularState]): @@ -56,16 +59,19 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): ) return ( operator.as_matrix(), - is_lower_triangular(operator), - has_unit_diagonal(operator), + eqxi.Static(is_lower_triangular(operator)), + eqxi.Static(has_unit_diagonal(operator)), pack_structures(operator), - False, # transposed + eqxi.Static(False), # transposed ) def compute( self, state: _TriangularState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: matrix, lower, unit_diagonal, packed_structures, transpose = state + lower = lower.value + unit_diagonal = unit_diagonal.value + transpose = transpose.value del state, options vector = ravel_vector(vector, packed_structures) if transpose: @@ -79,6 +85,7 @@ def compute( return solution, RESULTS.successful, {} def transpose(self, state: _TriangularState, options: dict[str, Any]): + del options matrix, lower, unit_diagonal, packed_structures, transpose = state transposed_packed_structures = transpose_packed_structures(packed_structures) transpose_state = ( @@ -86,12 +93,13 @@ def transpose(self, state: _TriangularState, options: dict[str, Any]): lower, unit_diagonal, transposed_packed_structures, - not transpose, + eqxi.Static(not transpose.value), ) transpose_options = {} return transpose_state, transpose_options def conj(self, state: _TriangularState, options: dict[str, Any]): + del options matrix, lower, unit_diagonal, packed_structures, transpose = state conj_state = ( matrix.conj(), diff --git a/pyproject.toml b/pyproject.toml index aa6f053..389a2ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,18 +67,18 @@ addopts = "--jaxtyping-packages=lineax,beartype.beartype(conf=beartype.BeartypeC extend-include = ["*.ipynb"] src = [] -[tool.ruff.isort] -combine-as-imports = true -extra-standard-library = ["typing_extensions"] -lines-after-imports = 2 -order-by-type = false - [tool.ruff.lint] fixable = ["I001", "F401", "UP"] -ignore = ["E402", "E721", "E731", "E741", "F722", "UP038"] +ignore = ["E402", "E721", "E731", "E741", "F722"] select = ["E", "F", "I001", "UP"] [tool.ruff.lint.flake8-import-conventions.extend-aliases] "collections" = "co" "functools" = "ft" "itertools" = "it" + +[tool.ruff.lint.isort] +combine-as-imports = true +extra-standard-library = ["typing_extensions"] +lines-after-imports = 2 +order-by-type = false diff --git a/tests/__main__.py b/tests/__main__.py index 5edf5e2..7c5ccb2 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -26,4 +26,6 @@ if file.is_file() and file.name.startswith("test"): out = subprocess.run(f"pytest {file}", shell=True).returncode running_out = max(running_out, out) + if out != 0: + break sys.exit(running_out) From c11bd7d1484e750ee60b93c6257519c9c756221c Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Tue, 27 Jan 2026 14:22:19 +0000 Subject: [PATCH 24/33] simplify jac=bwd behaviour --- lineax/_operator.py | 4 +--- tests/helpers.py | 30 +++++++++++++++++++++++++++++- tests/test_adjoint.py | 5 +++++ tests/test_operator.py | 15 ++++++++++++++- tests/test_well_posed.py | 5 +++++ 5 files changed, 54 insertions(+), 5 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 274b25f..be162d2 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1379,9 +1379,7 @@ def _(operator): if operator.jac == "fwd" or operator.jac is None: diag_as_pytree = operator.mv(unravel(basis)) elif operator.jac == "bwd": - fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) - _, vjp_fun = jax.vjp(fn, operator.x) - diag_as_pytree = vjp_fun(unravel(basis)) + diag_as_pytree = operator.T.mv(unravel(basis)) else: raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") diff --git a/tests/helpers.py b/tests/helpers.py index c902a66..d8914c8 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -199,13 +199,41 @@ def make_jac_operator(getkey, matrix, tags): a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) - fn_tmp = lambda x, _: a + b @ x + c @ x**2 + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac fn = lambda x, _: a + (b + diff) @ x + c @ x**2 return lx.JacobianLinearOperator(fn, x, None, tags) +@_operators_append +def make_jacfwd_operator(getkey, matrix, tags): + out_size, in_size = matrix.shape + x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) + a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) + b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) + diff = matrix - jac + fn = lambda x, _: a + (b + diff) @ x + c @ x**2 + return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd") + + +@_operators_append +def make_jacrev_operator(getkey, matrix, tags): + out_size, in_size = matrix.shape + x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) + a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) + b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) + diff = matrix - jac + fn = lambda x, _: a + (b + diff) @ x + c @ x**2 + return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd") + + @_operators_append def make_trivial_diagonal_operator(getkey, matrix, tags): assert tags == lx.diagonal_tag diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 75fbbc1..5827efc 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -7,6 +7,7 @@ from .helpers import ( make_identity_operator, + make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, @@ -33,6 +34,10 @@ def test_adjoint(make_operator, dtype, getkey): tags = () in_size = 5 out_size = 3 + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) operator = make_operator(getkey, matrix, tags) v1, v2 = ( jr.normal(getkey(), (in_size,), dtype=dtype), diff --git a/tests/test_operator.py b/tests/test_operator.py index cec1f36..b0aaf97 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -23,6 +23,7 @@ from .helpers import ( make_identity_operator, + make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, @@ -45,6 +46,10 @@ def test_ops(make_operator, getkey, dtype): else: matrix = jr.normal(getkey(), (3, 3), dtype=dtype) tags = () + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) matrix1 = make_operator(getkey, matrix, tags) matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype)) scalar = jr.normal(getkey(), (), dtype=dtype) @@ -160,9 +165,17 @@ def test_materialise_large(dtype, getkey): def test_diagonal(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) matrix_diag = jnp.diag(matrix) - operators = _setup(getkey, jnp.diag(matrix_diag)) + # test we properly extract diagonal from a dense matrix when not tagged + operators = _setup(getkey, matrix) for operator in operators: assert jnp.allclose(lx.diagonal(operator), matrix_diag) + # test we properly extract diagonal from diagonal matrix when tagged + operators = _setup(getkey, jnp.diag(matrix_diag), lx.diagonal_tag) + for operator in operators: + if isinstance(operator, lx.IdentityLinearOperator): + assert jnp.allclose(lx.diagonal(operator), jnp.ones(3)) + else: + assert jnp.allclose(lx.diagonal(operator), matrix_diag) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py index e377ef4..c3c0358 100644 --- a/tests/test_well_posed.py +++ b/tests/test_well_posed.py @@ -20,6 +20,7 @@ from .helpers import ( construct_matrix, + make_jacrev_operator, ops, params, solvers, @@ -31,6 +32,10 @@ @pytest.mark.parametrize("ops", ops) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): + if make_operator is make_jacrev_operator and dtype is jnp.complex128: + pytest.skip( + 'JacobianLinearOperator does not support complex dtypes when jac="bwd"' + ) if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: From 1c58307cc582643da61fdba20fbb80b1dcdb0758 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Tue, 27 Jan 2026 16:32:15 +0000 Subject: [PATCH 25/33] add PyTreeLinearOperator optimisations for mv and diagonal --- lineax/_operator.py | 44 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index a468be2..8919731 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -325,6 +325,14 @@ def __init__(self, value): self.value = value +def _leaf_from_keypath(pytree: PyTree, keypath) -> Array: + """Extract the leaf from a pytree at the given keypath.""" + for path, leaf in jtu.tree_leaves_with_path(pytree): + if path == keypath: + return leaf + raise ValueError(f"Leaf not found at keypath {keypath}") + + # The `{input,output}_structure`s have to be static because otherwise abstract # evaluation rules will promote them to ShapedArrays. class PyTreeLinearOperator(AbstractLinearOperator): @@ -420,11 +428,22 @@ def mv(self, vector): # vector has structure [tree(in), leaf(in)] # self.out_structure() has structure [tree(out)] # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)] - # return has struture [tree(out), leaf(out)] - def matmul(_, matrix): - return _tree_matmul(matrix, vector) + # return has structure [tree(out), leaf(out)] + if diagonal_tag in self.tags: + # Efficient path: diagonal mv is just element-wise multiplication + def diag_mv(keypath, struct, subpytree): + block = _leaf_from_keypath(subpytree, keypath) + vec_leaf = _leaf_from_keypath(vector, keypath) + diag = jnp.diag(block.reshape(struct.size, struct.size)) + return (diag * vec_leaf.reshape(-1)).reshape(struct.shape) + + return jtu.tree_map_with_path(diag_mv, self.out_structure(), self.pytree) + else: + + def matmul(_, matrix): + return _tree_matmul(matrix, vector) - return jtu.tree_map(matmul, self.out_structure(), self.pytree) + return jtu.tree_map(matmul, self.out_structure(), self.pytree) def as_matrix(self): with jax.numpy_dtype_promotion("standard"): @@ -1369,11 +1388,26 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]: @diagonal.register(MatrixLinearOperator) -@diagonal.register(PyTreeLinearOperator) def _(operator): return jnp.diag(operator.as_matrix()) +@diagonal.register(PyTreeLinearOperator) +def _(operator): + if is_diagonal(operator): + # in_structure == out_structure guaranteed for diagonal operators + def extract_diag(keypath, struct, subpytree): + block = _leaf_from_keypath(subpytree, keypath) + return jnp.diag(block.reshape(struct.size, struct.size)) + + diags = jtu.tree_map_with_path( + extract_diag, operator.out_structure(), operator.pytree + ) + return jnp.concatenate(jtu.tree_leaves(diags)) + else: + return jnp.diag(operator.as_matrix()) + + @diagonal.register(JacobianLinearOperator) def _(operator): if is_diagonal(operator): From caeb4e8e79a7c15cbede8f004cd71b81fc0e9eeb Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Tue, 27 Jan 2026 17:10:44 +0000 Subject: [PATCH 26/33] add type annotation to _leaf_from_keypath --- lineax/_operator.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 8919731..a113496 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -325,7 +325,7 @@ def __init__(self, value): self.value = value -def _leaf_from_keypath(pytree: PyTree, keypath) -> Array: +def _leaf_from_keypath(pytree: PyTree, keypath: jtu.KeyPath) -> Array: """Extract the leaf from a pytree at the given keypath.""" for path, leaf in jtu.tree_leaves_with_path(pytree): if path == keypath: @@ -1420,7 +1420,12 @@ def _(operator): if operator.jac == "fwd" or operator.jac is None: diag_as_pytree = operator.mv(unravel(basis)) elif operator.jac == "bwd": - diag_as_pytree = operator.T.mv(unravel(basis)) + # Don't use operator.T.mv here: if the operator is symmetric, + # operator.T returns self, and self.mv with jac="bwd" computes + # the full Jacobian with jacrev. Direct VJP is more efficient. + fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) + _, vjp_fun = jax.vjp(fn, operator.x) + (diag_as_pytree,) = vjp_fun(unravel(basis)) else: raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") From 0e23db9de83604315ffe478a24f3d861986274f4 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Fri, 30 Jan 2026 00:33:51 +0000 Subject: [PATCH 27/33] add efficient Matrix mv path, and Jacobian/Function materialise path --- lineax/_operator.py | 123 +++++++++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 59 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index a113496..da7509a 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -261,6 +261,8 @@ def __init__( self.tags = _frozenset(tags) def mv(self, vector): + if diagonal_tag in self.tags: + return jnp.diagonal(self.matrix) * vector return jnp.matmul(self.matrix, vector, precision=lax.Precision.HIGHEST) def as_matrix(self): @@ -1217,6 +1219,21 @@ def _default_not_implemented(name: str, operator: AbstractLinearOperator) -> NoR raise NotImplementedError(msg) +def _construct_diagonal_basis(structure: PyTree[jax.ShapeDtypeStruct]) -> PyTree[Array]: + """Construct a PyTree of ones matching the given structure. + + For a diagonal linear operator, applying it to this basis yields the diagonal. + + Callers should wrap this in `jax.ensure_compile_time_eval()` to ensure the + basis is stored in the jaxpr rather than being reallocated at runtime. + """ + + def make_ones(struct): + return jnp.ones(struct.shape, struct.dtype) + + return jtu.tree_map(make_ones, structure) + + # linearise @@ -1333,35 +1350,58 @@ def _(operator): @materialise.register(JacobianLinearOperator) def _(operator): fn = _NoAuxIn(operator.fn, operator.args) - jac, aux = jacobian( - fn, - operator.in_size(), - operator.out_size(), - holomorphic=any(jnp.iscomplexobj(xi) for xi in jtu.tree_leaves(operator.x)), - has_aux=True, - jac=operator.jac, - )(operator.x) - out = PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) - return AuxLinearOperator(out, aux) + if is_diagonal(operator): + with jax.ensure_compile_time_eval(): + basis = _construct_diagonal_basis(operator.in_structure()) + + if operator.jac == "bwd": + ((_, aux), vjp_fn) = jax.vjp(fn, operator.x) + if aux is None: + aux_ct = None + else: + aux_ct = jtu.tree_map(jnp.zeros_like, aux) + (diag_as_pytree,) = vjp_fn((basis, aux_ct)) + else: # "fwd" or None + (_, aux), (diag_as_pytree, _) = jax.jvp(fn, (operator.x,), (basis,)) + + out = DiagonalLinearOperator(diag_as_pytree) + return AuxLinearOperator(out, aux) + else: + jac, aux = jacobian( + fn, + operator.in_size(), + operator.out_size(), + holomorphic=any(jnp.iscomplexobj(xi) for xi in jtu.tree_leaves(operator.x)), + has_aux=True, + jac=operator.jac, + )(operator.x) + out = PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) + return AuxLinearOperator(out, aux) @materialise.register(FunctionLinearOperator) def _(operator): - flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) - ) - eye = jnp.eye(flat.size, dtype=flat.dtype) - jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye) + if is_diagonal(operator): + with jax.ensure_compile_time_eval(): + basis = _construct_diagonal_basis(operator.in_structure()) + diag_as_pytree = operator.mv(basis) + return DiagonalLinearOperator(diag_as_pytree) + else: + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + eye = jnp.eye(flat.size, dtype=flat.dtype) + jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye) - def batch_unravel(x): - assert x.ndim > 0 - unravel_ = unravel - for _ in range(x.ndim - 1): - unravel_ = jax.vmap(unravel_) - return unravel_(x) + def batch_unravel(x): + assert x.ndim > 0 + unravel_ = unravel + for _ in range(x.ndim - 1): + unravel_ = jax.vmap(unravel_) + return unravel_(x) - jac = jtu.tree_map(batch_unravel, jac) - return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) + jac = jtu.tree_map(batch_unravel, jac) + return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) # diagonal @@ -1409,44 +1449,9 @@ def extract_diag(keypath, struct, subpytree): @diagonal.register(JacobianLinearOperator) -def _(operator): - if is_diagonal(operator): - with jax.ensure_compile_time_eval(): - flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) - ) - basis = jnp.ones(flat.size, dtype=flat.dtype) - - if operator.jac == "fwd" or operator.jac is None: - diag_as_pytree = operator.mv(unravel(basis)) - elif operator.jac == "bwd": - # Don't use operator.T.mv here: if the operator is symmetric, - # operator.T returns self, and self.mv with jac="bwd" computes - # the full Jacobian with jacrev. Direct VJP is more efficient. - fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) - _, vjp_fun = jax.vjp(fn, operator.x) - (diag_as_pytree,) = vjp_fun(unravel(basis)) - else: - raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") - - return jfu.ravel_pytree(diag_as_pytree)[0] - else: - return jnp.diag(operator.as_matrix()) - - @diagonal.register(FunctionLinearOperator) def _(operator): - if is_diagonal(operator): - with jax.ensure_compile_time_eval(): - flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) - ) - basis = jnp.ones(flat.size, dtype=flat.dtype) - - diag_as_pytree = operator.fn(unravel(basis)) - return jfu.ravel_pytree(diag_as_pytree)[0] - else: - return jnp.diag(operator.as_matrix()) + return diagonal(materialise(operator)) @diagonal.register(DiagonalLinearOperator) From 49077a65e294e5f5246b8ec9601a60de7642a0e8 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 2 Feb 2026 10:59:46 +0000 Subject: [PATCH 28/33] add try sparse materialise helper (#3) * Add sparse materialisation helper and efficient diagonal paths This PR introduces _try_sparse_materialise helper and optimizes diagonal operator handling throughout lineax. Key changes: - Add _try_sparse_materialise() that converts diagonal-tagged operators to DiagonalLinearOperator, preserving pytree structure via unravel - Add efficient diagonal() for JLO/FLO using single JVP/VJP with ones basis - Add efficient diagonal() for Composed: diag(A @ B) = diag(A) * diag(B) - Simplify mv() for MLO, PTLO, Add, Composed to use _try_sparse_materialise - Apply early sparse materialisation in materialise() registrations Aux handling: - Fix bug: linearise/materialise now preserve aux on AuxLinearOperator - Preserve aux from first operator in Composed (output comes from op1) - Inner aux in Add children silently stripped (unclear semantics - may warrant guards in future) --------- Co-authored-by: jpbrodrick89 --- lineax/_operator.py | 177 +++++++++++++++++++++++++++++++------------- 1 file changed, 124 insertions(+), 53 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index da7509a..6cce0c0 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -261,8 +261,9 @@ def __init__( self.tags = _frozenset(tags) def mv(self, vector): - if diagonal_tag in self.tags: - return jnp.diagonal(self.matrix) * vector + sparse = _try_sparse_materialise(self) + if sparse is not self: + return sparse.mv(vector) return jnp.matmul(self.matrix, vector, precision=lax.Precision.HIGHEST) def as_matrix(self): @@ -431,21 +432,14 @@ def mv(self, vector): # self.out_structure() has structure [tree(out)] # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)] # return has structure [tree(out), leaf(out)] - if diagonal_tag in self.tags: - # Efficient path: diagonal mv is just element-wise multiplication - def diag_mv(keypath, struct, subpytree): - block = _leaf_from_keypath(subpytree, keypath) - vec_leaf = _leaf_from_keypath(vector, keypath) - diag = jnp.diag(block.reshape(struct.size, struct.size)) - return (diag * vec_leaf.reshape(-1)).reshape(struct.shape) - - return jtu.tree_map_with_path(diag_mv, self.out_structure(), self.pytree) - else: + sparse = _try_sparse_materialise(self) + if sparse is not self: + return sparse.mv(vector) - def matmul(_, matrix): - return _tree_matmul(matrix, vector) + def matmul(_, matrix): + return _tree_matmul(matrix, vector) - return jtu.tree_map(matmul, self.out_structure(), self.pytree) + return jtu.tree_map(matmul, self.out_structure(), self.pytree) def as_matrix(self): with jax.numpy_dtype_promotion("standard"): @@ -1017,6 +1011,9 @@ def __check_init__(self): raise ValueError("Incompatible linear operator structures") def mv(self, vector): + sparse = _try_sparse_materialise(self) + if sparse is not self: + return sparse.mv(vector) mv1 = self.operator1.mv(vector) mv2 = self.operator2.mv(vector) return (mv1**ω + mv2**ω).ω @@ -1151,6 +1148,9 @@ def __check_init__(self): raise ValueError("Incompatible linear operator structures") def mv(self, vector): + sparse = _try_sparse_materialise(self) + if sparse is not self: + return sparse.mv(vector) return self.operator1.mv(self.operator2.mv(vector)) def as_matrix(self): @@ -1223,9 +1223,6 @@ def _construct_diagonal_basis(structure: PyTree[jax.ShapeDtypeStruct]) -> PyTree """Construct a PyTree of ones matching the given structure. For a diagonal linear operator, applying it to this basis yields the diagonal. - - Callers should wrap this in `jax.ensure_compile_time_eval()` to ensure the - basis is stored in the jaxpr rather than being reallocated at runtime. """ def make_ones(struct): @@ -1338,8 +1335,30 @@ def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator: _default_not_implemented("materialise", operator) +def _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator: + """Try to materialise to a sparse operator. + + Returns a DiagonalLinearOperator if the operator is tagged as diagonal, + otherwise returns the original operator unchanged. The resulting operator + preserves the input/output structure of the original operator. + + Note: This function should not be called on AuxLinearOperator directly - + use the materialise registration for AuxLinearOperator instead. + """ + if is_diagonal(operator): + diag_flat = diagonal(operator) + _, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + diag_pytree = unravel(diag_flat) + return DiagonalLinearOperator(diag_pytree) + return operator + + @materialise.register(MatrixLinearOperator) @materialise.register(PyTreeLinearOperator) +def _(operator): + return _try_sparse_materialise(operator) + + @materialise.register(IdentityLinearOperator) @materialise.register(DiagonalLinearOperator) @materialise.register(TridiagonalLinearOperator) @@ -1353,7 +1372,6 @@ def _(operator): if is_diagonal(operator): with jax.ensure_compile_time_eval(): basis = _construct_diagonal_basis(operator.in_structure()) - if operator.jac == "bwd": ((_, aux), vjp_fn) = jax.vjp(fn, operator.x) if aux is None: @@ -1363,45 +1381,40 @@ def _(operator): (diag_as_pytree,) = vjp_fn((basis, aux_ct)) else: # "fwd" or None (_, aux), (diag_as_pytree, _) = jax.jvp(fn, (operator.x,), (basis,)) - out = DiagonalLinearOperator(diag_as_pytree) return AuxLinearOperator(out, aux) - else: - jac, aux = jacobian( - fn, - operator.in_size(), - operator.out_size(), - holomorphic=any(jnp.iscomplexobj(xi) for xi in jtu.tree_leaves(operator.x)), - has_aux=True, - jac=operator.jac, - )(operator.x) - out = PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) - return AuxLinearOperator(out, aux) + jac, aux = jacobian( + fn, + operator.in_size(), + operator.out_size(), + holomorphic=any(jnp.iscomplexobj(xi) for xi in jtu.tree_leaves(operator.x)), + has_aux=True, + jac=operator.jac, + )(operator.x) + out = PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) + return AuxLinearOperator(out, aux) @materialise.register(FunctionLinearOperator) def _(operator): - if is_diagonal(operator): - with jax.ensure_compile_time_eval(): - basis = _construct_diagonal_basis(operator.in_structure()) - diag_as_pytree = operator.mv(basis) - return DiagonalLinearOperator(diag_as_pytree) - else: - flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) - ) - eye = jnp.eye(flat.size, dtype=flat.dtype) - jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye) + out = _try_sparse_materialise(operator) + if out is not operator: + return out + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + ) + eye = jnp.eye(flat.size, dtype=flat.dtype) + jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye) - def batch_unravel(x): - assert x.ndim > 0 - unravel_ = unravel - for _ in range(x.ndim - 1): - unravel_ = jax.vmap(unravel_) - return unravel_(x) + def batch_unravel(x): + assert x.ndim > 0 + unravel_ = unravel + for _ in range(x.ndim - 1): + unravel_ = jax.vmap(unravel_) + return unravel_(x) - jac = jtu.tree_map(batch_unravel, jac) - return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) + jac = jtu.tree_map(batch_unravel, jac) + return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) # diagonal @@ -1449,8 +1462,29 @@ def extract_diag(keypath, struct, subpytree): @diagonal.register(JacobianLinearOperator) +def _(operator): + if is_diagonal(operator): + fn = _NoAuxIn(operator.fn, operator.args) + with jax.ensure_compile_time_eval(): + basis = _construct_diagonal_basis(operator.in_structure()) + if operator.jac == "bwd": + (_, vjp_fn) = jax.vjp(fn, operator.x) + (diag_as_pytree,) = vjp_fn((basis, None)) + else: # "fwd" or None + _, (diag_as_pytree, _) = jax.jvp(fn, (operator.x,), (basis,)) + diag, _ = jfu.ravel_pytree(diag_as_pytree) + return diag + return diagonal(materialise(operator)) + + @diagonal.register(FunctionLinearOperator) def _(operator): + if is_diagonal(operator): + with jax.ensure_compile_time_eval(): + basis = _construct_diagonal_basis(operator.in_structure()) + diag_as_pytree = operator.mv(basis) + diag, _ = jfu.ravel_pytree(diag_as_pytree) + return diag return diagonal(materialise(operator)) @@ -1913,9 +1947,31 @@ def _(operator, transform=transform): def _(operator, transform=transform): return transform(operator.operator) / operator.scalar - @transform.register(AuxLinearOperator) # pyright: ignore - def _(operator, transform=transform): - return transform(operator.operator) + +# diagonal strips aux (returns array, not operator) +@diagonal.register(AuxLinearOperator) +def _(operator): + return diagonal(operator.operator) + + +# linearise and materialise preserve aux +@linearise.register(AuxLinearOperator) +def _(operator): + return AuxLinearOperator(linearise(operator.operator), operator.aux) + + +@materialise.register(AuxLinearOperator) +def _(operator): + return AuxLinearOperator(materialise(operator.operator), operator.aux) + + +# Override AddLinearOperator materialise to try sparse materialisation early +@materialise.register(AddLinearOperator) +def _(operator): + out = _try_sparse_materialise(operator) + if out is not operator: + return out + return materialise(operator.operator1) + materialise(operator.operator2) @linearise.register(TangentLinearOperator) @@ -1984,16 +2040,31 @@ def _(operator): @linearise.register(ComposedLinearOperator) def _(operator): + # If the first operator has aux, preserve it on the result + if isinstance(operator.operator1, AuxLinearOperator): + aux = operator.operator1.aux + inner_composed = operator.operator1.operator @ operator.operator2 + return AuxLinearOperator(linearise(inner_composed), aux) return linearise(operator.operator1) @ linearise(operator.operator2) @materialise.register(ComposedLinearOperator) def _(operator): + # If the first operator has aux, preserve it on the result + if isinstance(operator.operator1, AuxLinearOperator): + aux = operator.operator1.aux + inner_composed = operator.operator1.operator @ operator.operator2 + return AuxLinearOperator(materialise(inner_composed), aux) + out = _try_sparse_materialise(operator) + if out is not operator: + return out return materialise(operator.operator1) @ materialise(operator.operator2) @diagonal.register(ComposedLinearOperator) def _(operator): + if is_diagonal(operator): + return diagonal(operator.operator1) * diagonal(operator.operator2) return jnp.diag(operator.as_matrix()) From 045f6ae4563062c09aa8f041fcce582f27ec4816 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 2 Feb 2026 11:27:29 +0000 Subject: [PATCH 29/33] unify JLO and FLO diagonal registrations --- lineax/_operator.py | 33 +++++++-------------------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index b1d9095..a0ffad9 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1370,8 +1370,8 @@ def _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearO otherwise returns the original operator unchanged. The resulting operator preserves the input/output structure of the original operator. - Note: This function should not be called on AuxLinearOperator directly - - use the materialise registration for AuxLinearOperator instead. + Note: This function silently strips aux and as such should not be called + on AuxLinearOperator directly. """ if is_diagonal(operator): diag_flat = diagonal(operator) @@ -1476,7 +1476,7 @@ def _(operator): @diagonal.register(PyTreeLinearOperator) def _(operator): if is_diagonal(operator): - # in_structure == out_structure guaranteed for diagonal operators + def extract_diag(keypath, struct, subpytree): block = _leaf_from_keypath(subpytree, keypath) return jnp.diag(block.reshape(struct.size, struct.size)) @@ -1490,21 +1490,6 @@ def extract_diag(keypath, struct, subpytree): @diagonal.register(JacobianLinearOperator) -def _(operator): - if is_diagonal(operator): - fn = _NoAuxIn(operator.fn, operator.args) - with jax.ensure_compile_time_eval(): - basis = _construct_diagonal_basis(operator.in_structure()) - if operator.jac == "bwd": - (_, vjp_fn) = jax.vjp(fn, operator.x) - (diag_as_pytree,) = vjp_fn((basis, None)) - else: # "fwd" or None - _, (diag_as_pytree, _) = jax.jvp(fn, (operator.x,), (basis,)) - diag, _ = jfu.ravel_pytree(diag_as_pytree) - return diag - return diagonal(materialise(operator)) - - @diagonal.register(FunctionLinearOperator) def _(operator): if is_diagonal(operator): @@ -1983,17 +1968,13 @@ def _(operator): # linearise and materialise preserve aux -@linearise.register(AuxLinearOperator) -def _(operator): - return AuxLinearOperator(linearise(operator.operator), operator.aux) - +for transform in (linearise, materialise): -@materialise.register(AuxLinearOperator) -def _(operator): - return AuxLinearOperator(materialise(operator.operator), operator.aux) + @transform.register(AuxLinearOperator) + def _(operator, transform=transform): + return AuxLinearOperator(transform(operator.operator), operator.aux) -# Override AddLinearOperator materialise to try sparse materialisation early @materialise.register(AddLinearOperator) def _(operator): out = _try_sparse_materialise(operator) From 72e0229b827c1163bf0ed05a7f5d1ffbccfe4732 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 2 Feb 2026 11:35:21 +0000 Subject: [PATCH 30/33] rename sparse variable -> maybe_sparse_op --- lineax/_operator.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index a0ffad9..ce6ae4d 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -262,9 +262,9 @@ def __init__( self.tags = _frozenset(tags) def mv(self, vector): - sparse = _try_sparse_materialise(self) - if sparse is not self: - return sparse.mv(vector) + maybe_sparse_op = _try_sparse_materialise(self) + if maybe_sparse_op is not self: + return maybe_sparse_op.mv(vector) return jnp.matmul(self.matrix, vector, precision=lax.Precision.HIGHEST) def as_matrix(self): @@ -433,9 +433,9 @@ def mv(self, vector): # self.out_structure() has structure [tree(out)] # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)] # return has structure [tree(out), leaf(out)] - sparse = _try_sparse_materialise(self) - if sparse is not self: - return sparse.mv(vector) + maybe_sparse_op = _try_sparse_materialise(self) + if maybe_sparse_op is not self: + return maybe_sparse_op.mv(vector) def matmul(_, matrix): return _tree_matmul(matrix, vector) @@ -1026,9 +1026,9 @@ def __check_init__(self): raise ValueError("Incompatible linear operator structures") def mv(self, vector): - sparse = _try_sparse_materialise(self) - if sparse is not self: - return sparse.mv(vector) + maybe_sparse_op = _try_sparse_materialise(self) + if maybe_sparse_op is not self: + return maybe_sparse_op.mv(vector) mv1 = self.operator1.mv(vector) mv2 = self.operator2.mv(vector) return (mv1**ω + mv2**ω).ω @@ -1163,9 +1163,9 @@ def __check_init__(self): raise ValueError("Incompatible linear operator structures") def mv(self, vector): - sparse = _try_sparse_materialise(self) - if sparse is not self: - return sparse.mv(vector) + maybe_sparse_op = _try_sparse_materialise(self) + if maybe_sparse_op is not self: + return maybe_sparse_op.mv(vector) return self.operator1.mv(self.operator2.mv(vector)) def as_matrix(self): From 8fee6e44520df4f0e4cd3b4671158532a7749b86 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 2 Feb 2026 11:38:12 +0000 Subject: [PATCH 31/33] move and shorten construct diagonal basis --- lineax/_operator.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index ce6ae4d..e4468b2 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1234,18 +1234,6 @@ def _default_not_implemented(name: str, operator: AbstractLinearOperator) -> NoR raise NotImplementedError(msg) -def _construct_diagonal_basis(structure: PyTree[jax.ShapeDtypeStruct]) -> PyTree[Array]: - """Construct a PyTree of ones matching the given structure. - - For a diagonal linear operator, applying it to this basis yields the diagonal. - """ - - def make_ones(struct): - return jnp.ones(struct.shape, struct.dtype) - - return jtu.tree_map(make_ones, structure) - - # linearise @@ -1381,6 +1369,11 @@ def _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearO return operator +def _construct_diagonal_basis(structure: PyTree[jax.ShapeDtypeStruct]) -> PyTree[Array]: + """Construct a PyTree of ones matching the given structure.""" + return jtu.tree_map(lambda s: jnp.ones(s.shape, s.dtype), structure) + + @materialise.register(MatrixLinearOperator) @materialise.register(PyTreeLinearOperator) def _(operator): From 604a2b47d27c93b33da70d88c12f1accd0d3a136 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 2 Feb 2026 11:43:55 +0000 Subject: [PATCH 32/33] revert unecessary **2.0 -> more efficient **2 --- tests/helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 130f2cd..4259b85 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -202,7 +202,7 @@ def make_jac_operator(getkey, matrix, tags): a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) - fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + fn_tmp = lambda x, _: a + b @ x + c @ x**2 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac fn = lambda x, _: a + (b + diff) @ x + c @ x**2 @@ -216,7 +216,7 @@ def make_jacfwd_operator(getkey, matrix, tags): a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) - fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + fn_tmp = lambda x, _: a + b @ x + c @ x**2 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac fn = lambda x, _: a + (b + diff) @ x + c @ x**2 @@ -235,7 +235,7 @@ def make_jacrev_operator(getkey, matrix, tags): a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) - fn_tmp = lambda x, _: a + b @ x + c @ x**2.0 + fn_tmp = lambda x, _: a + b @ x + c @ x**2 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac From 6212763d3da1d719eebdf912832aa4948f0c5f86 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 2 Feb 2026 11:47:17 +0000 Subject: [PATCH 33/33] comments on why we can't use try sparse materialise with JLO --- lineax/_operator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index e4468b2..05dfba0 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1359,7 +1359,7 @@ def _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearO preserves the input/output structure of the original operator. Note: This function silently strips aux and as such should not be called - on AuxLinearOperator directly. + on AuxLinearOperator or JacobianLinearOperatory directly. """ if is_diagonal(operator): diag_flat = diagonal(operator) @@ -1390,6 +1390,7 @@ def _(operator): @materialise.register(JacobianLinearOperator) def _(operator): fn = _NoAuxIn(operator.fn, operator.args) + # can't use try_sparse_materialise as strips aux if is_diagonal(operator): with jax.ensure_compile_time_eval(): basis = _construct_diagonal_basis(operator.in_structure())