Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lineax/_solver/bicgstab.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jaxtyping import Array, PyTree

from .._norm import max_norm, tree_dot
from .._operator import AbstractLinearOperator, conj
from .._operator import AbstractLinearOperator, conj, linearise
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import preconditioner_and_y0
Expand Down Expand Up @@ -73,7 +73,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
"`BiCGstab(..., normal=False)` may only be used for linear solves with "
"square matrices."
)
return operator
return linearise(operator)

def compute(
self, state: _BiCGStabState, vector: PyTree[Array], options: dict[str, Any]
Expand Down
8 changes: 2 additions & 6 deletions lineax/_solver/gmres.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@

from .._misc import structure_equal
from .._norm import max_norm, two_norm
from .._operator import (
AbstractLinearOperator,
conj,
MatrixLinearOperator,
)
from .._operator import AbstractLinearOperator, conj, linearise, MatrixLinearOperator
from .._solution import RESULTS
from .._solve import AbstractLinearSolver, linear_solve
from .misc import preconditioner_and_y0
Expand Down Expand Up @@ -86,7 +82,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
"`GMRES(..., normal=False)` may only be used for linear solves with "
"square matrices."
)
return operator
return linearise(operator)

#
# This differs from `jax.scipy.sparse.linalg.gmres` in a few ways:
Expand Down
4 changes: 2 additions & 2 deletions lineax/_solver/lsmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

from .._misc import complex_to_real_dtype
from .._norm import two_norm
from .._operator import AbstractLinearOperator, conj
from .._operator import AbstractLinearOperator, conj, linearise
from .._solution import RESULTS
from .._solve import AbstractLinearSolver

Expand Down Expand Up @@ -89,7 +89,7 @@ def __check_init__(self):
)

def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
return operator
return linearise(operator)

def compute(
self,
Expand Down
7 changes: 2 additions & 5 deletions lineax/_solver/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,15 @@
from jaxtyping import Array, PyTree, Shaped

from .._misc import strip_weak_dtype, structure_equal
from .._operator import (
AbstractLinearOperator,
IdentityLinearOperator,
)
from .._operator import AbstractLinearOperator, IdentityLinearOperator, linearise


def preconditioner_and_y0(
operator: AbstractLinearOperator, vector: PyTree[Array], options: dict[str, Any]
):
structure = operator.in_structure()
try:
preconditioner = options["preconditioner"]
preconditioner = linearise(options["preconditioner"])
except KeyError:
preconditioner = IdentityLinearOperator(structure)
else:
Expand Down
19 changes: 10 additions & 9 deletions lineax/_solver/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
import equinox.internal as eqxi
from jaxtyping import Array, PyTree

from .._operator import (
conj,
TaggedLinearOperator,
)
from .._operator import conj, linearise, TaggedLinearOperator
from .._solution import RESULTS
from .._solve import AbstractLinearOperator, AbstractLinearSolver
from .._tags import positive_semidefinite_tag
Expand All @@ -36,6 +33,7 @@ def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool):
inner_options = copy(options)
del options
if preconditioner is not None:
preconditioner = linearise(preconditioner)
if tall:
inner_options["preconditioner"] = TaggedLinearOperator(
preconditioner @ conj(preconditioner.transpose()),
Expand All @@ -46,8 +44,8 @@ def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool):
conj(preconditioner.transpose()) @ preconditioner,
positive_semidefinite_tag,
)
if preconditioner is not None and y0 is not None and not tall:
inner_options["y0"] = conj(preconditioner.transpose()).mv(y0)
if y0 is not None:
inner_options["y0"] = conj(preconditioner.transpose()).mv(y0)
return inner_options


Expand Down Expand Up @@ -105,14 +103,17 @@ class Normal(

def init(self, operator, options):
tall = operator.out_size() >= operator.in_size()
# we are apply repeated mv's when constructing normal matrix
# these cannot be parallelised so more efficient to linearise first
lin_op = linearise(operator)
if tall:
inner_operator = conj(operator.transpose()) @ operator
inner_operator = conj(lin_op.transpose()) @ lin_op
else:
inner_operator = operator @ conj(operator.transpose())
inner_operator = lin_op @ conj(lin_op.transpose())
inner_operator = TaggedLinearOperator(inner_operator, positive_semidefinite_tag)
inner_options = normal_preconditioner_and_y0(options, tall)
inner_state = self.inner_solver.init(inner_operator, inner_options)
operator_conj_transpose = conj(operator.transpose())
operator_conj_transpose = conj(lin_op.transpose())
return inner_state, eqxi.Static(tall), operator_conj_transpose, inner_options

def compute(
Expand Down
Loading