From eb69964a0f6f24aba8efce6ee6434ff252e9e576 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Tue, 3 Feb 2026 09:57:28 +0000 Subject: [PATCH] apply linearise in init for Normal and iterative solvers --- lineax/_solver/bicgstab.py | 4 ++-- lineax/_solver/gmres.py | 8 ++------ lineax/_solver/lsmr.py | 4 ++-- lineax/_solver/misc.py | 7 ++----- lineax/_solver/normal.py | 19 ++++++++++--------- 5 files changed, 18 insertions(+), 24 deletions(-) diff --git a/lineax/_solver/bicgstab.py b/lineax/_solver/bicgstab.py index 7208745..31cf241 100644 --- a/lineax/_solver/bicgstab.py +++ b/lineax/_solver/bicgstab.py @@ -23,7 +23,7 @@ from jaxtyping import Array, PyTree from .._norm import max_norm, tree_dot -from .._operator import AbstractLinearOperator, conj +from .._operator import AbstractLinearOperator, conj, linearise from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import preconditioner_and_y0 @@ -73,7 +73,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): "`BiCGstab(..., normal=False)` may only be used for linear solves with " "square matrices." ) - return operator + return linearise(operator) def compute( self, state: _BiCGStabState, vector: PyTree[Array], options: dict[str, Any] diff --git a/lineax/_solver/gmres.py b/lineax/_solver/gmres.py index ccf7e97..d5911a0 100644 --- a/lineax/_solver/gmres.py +++ b/lineax/_solver/gmres.py @@ -26,11 +26,7 @@ from .._misc import structure_equal from .._norm import max_norm, two_norm -from .._operator import ( - AbstractLinearOperator, - conj, - MatrixLinearOperator, -) +from .._operator import AbstractLinearOperator, conj, linearise, MatrixLinearOperator from .._solution import RESULTS from .._solve import AbstractLinearSolver, linear_solve from .misc import preconditioner_and_y0 @@ -86,7 +82,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): "`GMRES(..., normal=False)` may only be used for linear solves with " "square matrices." 
) - return operator + return linearise(operator) # # This differs from `jax.scipy.sparse.linalg.gmres` in a few ways: diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index 8a58f39..94303e7 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -44,7 +44,7 @@ from .._misc import complex_to_real_dtype from .._norm import two_norm -from .._operator import AbstractLinearOperator, conj +from .._operator import AbstractLinearOperator, conj, linearise from .._solution import RESULTS from .._solve import AbstractLinearSolver @@ -89,7 +89,7 @@ def __check_init__(self): ) def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): - return operator + return linearise(operator) def compute( self, diff --git a/lineax/_solver/misc.py b/lineax/_solver/misc.py index 55b64b6..d82744a 100644 --- a/lineax/_solver/misc.py +++ b/lineax/_solver/misc.py @@ -24,10 +24,7 @@ from jaxtyping import Array, PyTree, Shaped from .._misc import strip_weak_dtype, structure_equal -from .._operator import ( - AbstractLinearOperator, - IdentityLinearOperator, -) +from .._operator import AbstractLinearOperator, IdentityLinearOperator, linearise def preconditioner_and_y0( @@ -35,7 +32,7 @@ def preconditioner_and_y0( ): structure = operator.in_structure() try: - preconditioner = options["preconditioner"] + preconditioner = linearise(options["preconditioner"]) except KeyError: preconditioner = IdentityLinearOperator(structure) else: diff --git a/lineax/_solver/normal.py b/lineax/_solver/normal.py index a2d4266..6d7306f 100644 --- a/lineax/_solver/normal.py +++ b/lineax/_solver/normal.py @@ -18,10 +18,7 @@ import equinox.internal as eqxi from jaxtyping import Array, PyTree -from .._operator import ( - conj, - TaggedLinearOperator, -) +from .._operator import conj, linearise, TaggedLinearOperator from .._solution import RESULTS from .._solve import AbstractLinearOperator, AbstractLinearSolver from .._tags import positive_semidefinite_tag @@ -36,6 +33,7 @@ def 
normal_preconditioner_and_y0(options: dict[str, Any], tall: bool): inner_options = copy(options) del options if preconditioner is not None: + preconditioner = linearise(preconditioner) if tall: inner_options["preconditioner"] = TaggedLinearOperator( preconditioner @ conj(preconditioner.transpose()), @@ -46,8 +44,8 @@ def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool): conj(preconditioner.transpose()) @ preconditioner, positive_semidefinite_tag, ) - if preconditioner is not None and y0 is not None and not tall: - inner_options["y0"] = conj(preconditioner.transpose()).mv(y0) + if y0 is not None: + inner_options["y0"] = conj(preconditioner.transpose()).mv(y0) return inner_options @@ -105,14 +103,17 @@ class Normal( def init(self, operator, options): tall = operator.out_size() >= operator.in_size() + # we apply repeated mv's when constructing the normal matrix; + # these cannot be parallelised, so it is more efficient to linearise first + lin_op = linearise(operator) if tall: - inner_operator = conj(operator.transpose()) @ operator + inner_operator = conj(lin_op.transpose()) @ lin_op else: - inner_operator = operator @ conj(operator.transpose()) + inner_operator = lin_op @ conj(lin_op.transpose()) inner_operator = TaggedLinearOperator(inner_operator, positive_semidefinite_tag) inner_options = normal_preconditioner_and_y0(options, tall) inner_state = self.inner_solver.init(inner_operator, inner_options) - operator_conj_transpose = conj(operator.transpose()) + operator_conj_transpose = conj(lin_op.transpose()) return inner_state, eqxi.Static(tall), operator_conj_transpose, inner_options def compute(