From fe6ce1da06b9a62eefb0b830503489543f84429a Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 2 Feb 2026 11:54:24 +0000 Subject: [PATCH] ensure materialise preserves aux --- lineax/_operator.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index b430a2f..c916427 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1863,9 +1863,19 @@ def _(operator, transform=transform): def _(operator, transform=transform): return transform(operator.operator) / operator.scalar - @transform.register(AuxLinearOperator) # pyright: ignore + +# diagonal strips aux (returns array, not operator) +@diagonal.register(AuxLinearOperator) +def _(operator): + return diagonal(operator.operator) + + +# linearise and materialise preserve aux +for transform in (linearise, materialise): + + @transform.register(AuxLinearOperator) def _(operator, transform=transform): - return transform(operator.operator) + return AuxLinearOperator(transform(operator.operator), operator.aux) @linearise.register(TangentLinearOperator) @@ -1934,11 +1944,21 @@ def _(operator): @linearise.register(ComposedLinearOperator) def _(operator): + # If the first operator has aux, preserve it on the result + if isinstance(operator.operator1, AuxLinearOperator): + aux = operator.operator1.aux + inner_composed = operator.operator1.operator @ operator.operator2 + return AuxLinearOperator(linearise(inner_composed), aux) return linearise(operator.operator1) @ linearise(operator.operator2) @materialise.register(ComposedLinearOperator) def _(operator): + # If the first operator has aux, preserve it on the result + if isinstance(operator.operator1, AuxLinearOperator): + aux = operator.operator1.aux + inner_composed = operator.operator1.operator @ operator.operator2 + return AuxLinearOperator(materialise(inner_composed), aux) return materialise(operator.operator1) @ materialise(operator.operator2)