Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
2ac305b
Merge pull request #1 from patrick-kidger/main
jpbrodrick89 May 14, 2025
30e7736
Streamline diagonalisation by avoiding computation/instantiation of e…
jpbrodrick89 Jun 16, 2025
e8f474a
changed to using mv instead of jvp
jpbrodrick89 Jun 16, 2025
0e7703c
initialise test_diagonal operators with diagonal matrix
jpbrodrick89 Jun 16, 2025
5b80289
Merge branch 'patrick-kidger:main' into main
jpbrodrick89 Jun 16, 2025
3b7d805
remove redundant original diagonal overload for FunctionLinearOperator
jpbrodrick89 Jun 16, 2025
637e345
Merge branch 'main' into jpb/diagfromjac
jpbrodrick89 Jun 16, 2025
97f9e7f
unravel instead of map
jpbrodrick89 Jun 19, 2025
d48ba15
continue to extract diagonal if no diagonal tag
jpbrodrick89 Oct 24, 2025
40e2ad5
rmeove print statement in test_operator
jpbrodrick89 Oct 24, 2025
8e91b44
Replace allow_dependent_rows and allow_dependent_columns with assume_…
adconner Dec 5, 2025
a6bacf7
pyright fixes
patrick-kidger Dec 5, 2025
8d96e7c
Implement Normal, a solver applying an inner solver to the normal equ…
adconner Dec 5, 2025
71db7e0
Tidy up results. (#170)
patrick-kidger Dec 5, 2025
9d83182
Update infra (#182)
patrick-kidger Dec 5, 2025
71d1296
Decrease default value to prevent overflow in 32-bit.
Nov 14, 2025
60d654a
use maximum value of runtime dtype
Nov 15, 2025
3c8df62
limit condition numbers for sensitive, iterative solvers
Nov 25, 2025
fadd714
cond_cutoff -> 900 for GMRES, LSMR
Nov 25, 2025
895fa23
decrease condition number cutoff for LSMR
Nov 26, 2025
daed0a4
implement safeguard for int overflow in 32 bit
Nov 26, 2025
20ef068
add extra safeguard for complex dtypes
Nov 29, 2025
08ceabe
fix typo
Dec 13, 2025
9891748
reset condition number check
Dec 13, 2025
d297209
simplify dtype conversion: reuse function from our misc module
Dec 13, 2025
dddc513
Merge pull request #176 from johannahaffner/lsmr-init-overflow
johannahaffner Dec 13, 2025
9c1b9a9
Moved bool states to Static.
patrick-kidger Dec 22, 2025
c11bd7d
simplify jac=bwd behaviour
jpbrodrick89 Jan 27, 2026
b11788f
Merge branch 'dev' into jpb/diagfromjac
jpbrodrick89 Jan 27, 2026
1c58307
add PyTreeLinearOperator optimisations for mv and diagonal
jpbrodrick89 Jan 27, 2026
caeb4e8
add type annotation to _leaf_from_keypath
jpbrodrick89 Jan 27, 2026
5f9da15
Merge branch 'main' into jpb/diagfromjac
jpbrodrick89 Jan 28, 2026
0e23db9
add efficient Matrix mv path, and Jacobian/Function materialise path
jpbrodrick89 Jan 30, 2026
49077a6
add try sparse materialise helper (#3)
jpbrodrick89 Feb 2, 2026
c396ca2
Merge branch 'main' into jpb/diagfromjac
jpbrodrick89 Feb 2, 2026
045f6ae
unify JLO and FLO diagonal registrations
jpbrodrick89 Feb 2, 2026
72e0229
rename sparse variable -> maybe_sparse_op
jpbrodrick89 Feb 2, 2026
8fee6e4
move and shorten construct diagonal basis
jpbrodrick89 Feb 2, 2026
604a2b4
revert unecessary **2.0 -> more efficient **2
jpbrodrick89 Feb 2, 2026
6212763
comments on why we can't use try sparse materialise with JLO
jpbrodrick89 Feb 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 128 additions & 4 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ def __init__(
self.tags = _frozenset(tags)

def mv(self, vector):
maybe_sparse_op = _try_sparse_materialise(self)
if maybe_sparse_op is not self:
return maybe_sparse_op.mv(vector)
return jnp.matmul(self.matrix, vector, precision=lax.Precision.HIGHEST)

def as_matrix(self):
Expand Down Expand Up @@ -326,6 +329,14 @@ def __init__(self, value):
self.value = value


def _leaf_from_keypath(pytree: PyTree, keypath: jtu.KeyPath) -> Array:
"""Extract the leaf from a pytree at the given keypath."""
for path, leaf in jtu.tree_leaves_with_path(pytree):
if path == keypath:
return leaf
raise ValueError(f"Leaf not found at keypath {keypath}")


# The `{input,output}_structure`s have to be static because otherwise abstract
# evaluation rules will promote them to ShapedArrays.
class PyTreeLinearOperator(AbstractLinearOperator):
Expand Down Expand Up @@ -421,7 +432,11 @@ def mv(self, vector):
# vector has structure [tree(in), leaf(in)]
# self.out_structure() has structure [tree(out)]
# self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)]
# return has struture [tree(out), leaf(out)]
# return has structure [tree(out), leaf(out)]
maybe_sparse_op = _try_sparse_materialise(self)
if maybe_sparse_op is not self:
return maybe_sparse_op.mv(vector)

def matmul(_, matrix):
return _tree_matmul(matrix, vector)

Expand Down Expand Up @@ -1011,6 +1026,9 @@ def __check_init__(self):
raise ValueError("Incompatible linear operator structures")

def mv(self, vector):
maybe_sparse_op = _try_sparse_materialise(self)
if maybe_sparse_op is not self:
return maybe_sparse_op.mv(vector)
mv1 = self.operator1.mv(vector)
mv2 = self.operator2.mv(vector)
return (mv1**ω + mv2**ω).ω
Expand Down Expand Up @@ -1145,6 +1163,9 @@ def __check_init__(self):
raise ValueError("Incompatible linear operator structures")

def mv(self, vector):
maybe_sparse_op = _try_sparse_materialise(self)
if maybe_sparse_op is not self:
return maybe_sparse_op.mv(vector)
return self.operator1.mv(self.operator2.mv(vector))

def as_matrix(self):
Expand Down Expand Up @@ -1330,8 +1351,35 @@ def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
_default_not_implemented("materialise", operator)


def _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
"""Try to materialise to a sparse operator.

Returns a DiagonalLinearOperator if the operator is tagged as diagonal,
otherwise returns the original operator unchanged. The resulting operator
preserves the input/output structure of the original operator.

Note: This function silently strips aux and as such should not be called
on AuxLinearOperator or JacobianLinearOperatory directly.
"""
if is_diagonal(operator):
diag_flat = diagonal(operator)
_, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
diag_pytree = unravel(diag_flat)
return DiagonalLinearOperator(diag_pytree)
return operator


def _construct_diagonal_basis(structure: PyTree[jax.ShapeDtypeStruct]) -> PyTree[Array]:
"""Construct a PyTree of ones matching the given structure."""
return jtu.tree_map(lambda s: jnp.ones(s.shape, s.dtype), structure)


@materialise.register(MatrixLinearOperator)
@materialise.register(PyTreeLinearOperator)
def _(operator):
return _try_sparse_materialise(operator)


@materialise.register(IdentityLinearOperator)
@materialise.register(DiagonalLinearOperator)
@materialise.register(TridiagonalLinearOperator)
Expand All @@ -1342,6 +1390,21 @@ def _(operator):
@materialise.register(JacobianLinearOperator)
def _(operator):
fn = _NoAuxIn(operator.fn, operator.args)
# can't use try_sparse_materialise as strips aux
if is_diagonal(operator):
with jax.ensure_compile_time_eval():
basis = _construct_diagonal_basis(operator.in_structure())
if operator.jac == "bwd":
((_, aux), vjp_fn) = jax.vjp(fn, operator.x)
if aux is None:
aux_ct = None
else:
aux_ct = jtu.tree_map(jnp.zeros_like, aux)
(diag_as_pytree,) = vjp_fn((basis, aux_ct))
else: # "fwd" or None
(_, aux), (diag_as_pytree, _) = jax.jvp(fn, (operator.x,), (basis,))
out = DiagonalLinearOperator(diag_as_pytree)
return AuxLinearOperator(out, aux)
jac, aux = jacobian(
fn,
operator.in_size(),
Expand All @@ -1356,6 +1419,9 @@ def _(operator):

@materialise.register(FunctionLinearOperator)
def _(operator):
out = _try_sparse_materialise(operator)
if out is not operator:
return out
flat, unravel = strip_weak_dtype(
eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
)
Expand Down Expand Up @@ -1397,11 +1463,36 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]:


@diagonal.register(MatrixLinearOperator)
def _(operator):
return jnp.diag(operator.as_matrix())


@diagonal.register(PyTreeLinearOperator)
def _(operator):
if is_diagonal(operator):

def extract_diag(keypath, struct, subpytree):
block = _leaf_from_keypath(subpytree, keypath)
return jnp.diag(block.reshape(struct.size, struct.size))

diags = jtu.tree_map_with_path(
extract_diag, operator.out_structure(), operator.pytree
)
return jnp.concatenate(jtu.tree_leaves(diags))
else:
return jnp.diag(operator.as_matrix())


@diagonal.register(JacobianLinearOperator)
@diagonal.register(FunctionLinearOperator)
def _(operator):
return jnp.diag(operator.as_matrix())
if is_diagonal(operator):
with jax.ensure_compile_time_eval():
basis = _construct_diagonal_basis(operator.in_structure())
diag_as_pytree = operator.mv(basis)
diag, _ = jfu.ravel_pytree(diag_as_pytree)
return diag
return diagonal(materialise(operator))


@diagonal.register(DiagonalLinearOperator)
Expand Down Expand Up @@ -1863,9 +1954,27 @@ def _(operator, transform=transform):
def _(operator, transform=transform):
return transform(operator.operator) / operator.scalar

@transform.register(AuxLinearOperator) # pyright: ignore

# diagonal strips aux (returns array, not operator)
@diagonal.register(AuxLinearOperator)
def _(operator):
return diagonal(operator.operator)


# linearise and materialise preserve aux
for transform in (linearise, materialise):

@transform.register(AuxLinearOperator)
def _(operator, transform=transform):
return transform(operator.operator)
return AuxLinearOperator(transform(operator.operator), operator.aux)


@materialise.register(AddLinearOperator)
def _(operator):
out = _try_sparse_materialise(operator)
if out is not operator:
return out
return materialise(operator.operator1) + materialise(operator.operator2)


@linearise.register(TangentLinearOperator)
Expand Down Expand Up @@ -1934,16 +2043,31 @@ def _(operator):

@linearise.register(ComposedLinearOperator)
def _(operator):
# If the first operator has aux, preserve it on the result
if isinstance(operator.operator1, AuxLinearOperator):
aux = operator.operator1.aux
inner_composed = operator.operator1.operator @ operator.operator2
return AuxLinearOperator(linearise(inner_composed), aux)
return linearise(operator.operator1) @ linearise(operator.operator2)


@materialise.register(ComposedLinearOperator)
def _(operator):
# If the first operator has aux, preserve it on the result
if isinstance(operator.operator1, AuxLinearOperator):
aux = operator.operator1.aux
inner_composed = operator.operator1.operator @ operator.operator2
return AuxLinearOperator(materialise(inner_composed), aux)
out = _try_sparse_materialise(operator)
if out is not operator:
return out
return materialise(operator.operator1) @ materialise(operator.operator2)


@diagonal.register(ComposedLinearOperator)
def _(operator):
if is_diagonal(operator):
return diagonal(operator.operator1) * diagonal(operator.operator2)
return jnp.diag(operator.as_matrix())


Expand Down
4 changes: 2 additions & 2 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def make_jacfwd_operator(getkey, matrix, tags):
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
fn_tmp = lambda x, _: a + b @ x + c @ x**2
jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
diff = matrix - jac
fn = lambda x, _: a + (b + diff) @ x + c @ x**2
Expand All @@ -235,7 +235,7 @@ def make_jacrev_operator(getkey, matrix, tags):
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
fn_tmp = lambda x, _: a + b @ x + c @ x**2
jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
diff = matrix - jac

Expand Down
8 changes: 8 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,17 @@ def test_materialise_large(dtype, getkey):
def test_diagonal(dtype, getkey):
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
matrix_diag = jnp.diag(matrix)
# test we properly extract diagonal from a dense matrix when not tagged
operators = _setup(getkey, matrix)
for operator in operators:
assert jnp.allclose(lx.diagonal(operator), matrix_diag)
# test we properly extract diagonal from diagonal matrix when tagged
operators = _setup(getkey, jnp.diag(matrix_diag), lx.diagonal_tag)
for operator in operators:
if isinstance(operator, lx.IdentityLinearOperator):
assert jnp.allclose(lx.diagonal(operator), jnp.ones(3))
else:
assert jnp.allclose(lx.diagonal(operator), matrix_diag)


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
Expand Down
Loading