From 35433a4bd32a4e3d3d10167f56283cfea5b12385 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Fri, 30 Jan 2026 01:16:32 +0000 Subject: [PATCH 1/3] fix derived tag check rules --- lineax/_operator.py | 133 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 119 insertions(+), 14 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index a83eeff..b223858 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1942,14 +1942,35 @@ def _(operator): def _(operator, check=check): return check(operator.primal) + @check.register(AuxLinearOperator) + def _(operator, check=check): + return check(operator.operator) + + +# Scaling/negating preserves these structural properties +for check in ( + is_symmetric, + is_diagonal, + is_lower_triangular, + is_upper_triangular, + is_tridiagonal, +): + @check.register(MulLinearOperator) @check.register(NegLinearOperator) @check.register(DivLinearOperator) - @check.register(AuxLinearOperator) def _(operator, check=check): return check(operator.operator) +# has_unit_diagonal is NOT preserved by scaling or negation +@has_unit_diagonal.register(MulLinearOperator) +@has_unit_diagonal.register(NegLinearOperator) +@has_unit_diagonal.register(DivLinearOperator) +def _(operator): + return False + + for check in (is_positive_semidefinite, is_negative_semidefinite): @check.register(TangentLinearOperator) @@ -1960,20 +1981,83 @@ def _(operator): "Please open a GitHub issue: https://github.com/google/lineax" ) - @check.register(MulLinearOperator) - @check.register(DivLinearOperator) - def _(operator): - return False # play it safe, no way to tell. - - @check.register(NegLinearOperator) - def _(operator, check=check): - return not check(operator.operator) - @check.register(AuxLinearOperator) def _(operator, check=check): return check(operator.operator) +def _scalar_sign(scalar) -> int | None: + """Returns 1 if positive, -1 if negative, 0 if zero, None if unknown (traced).""" + try: + if scalar > 0: + return 1 + elif scalar < 0: + return -1 + else: + return 0 + except Exception: + return None + + +# PSD/NSD for MulLinearOperator: depends on sign of scalar +# Zero scalar gives zero matrix which is both PSD and NSD +@is_positive_semidefinite.register(MulLinearOperator) +def _(operator): + sign = _scalar_sign(operator.scalar) + if sign == 1: + return is_positive_semidefinite(operator.operator) + elif sign == -1: + return is_negative_semidefinite(operator.operator) + elif sign == 0: + return True # zero matrix is PSD + return False + + +@is_negative_semidefinite.register(MulLinearOperator) +def _(operator): + sign = _scalar_sign(operator.scalar) + if sign == 1: + return is_negative_semidefinite(operator.operator) + elif sign == -1: + return is_positive_semidefinite(operator.operator) + elif sign == 0: + return True # zero matrix is NSD + return False + + +# PSD/NSD for DivLinearOperator: depends on sign of scalar +# Zero scalar is division by zero - return False (conservative) +@is_positive_semidefinite.register(DivLinearOperator) +def _(operator): + sign = _scalar_sign(operator.scalar) + if sign == 1: + return is_positive_semidefinite(operator.operator) + elif sign == -1: + return is_negative_semidefinite(operator.operator) + return False + + +@is_negative_semidefinite.register(DivLinearOperator) +def _(operator): + sign = _scalar_sign(operator.scalar) + if sign == 1: + return is_negative_semidefinite(operator.operator) + elif sign == -1: + return is_positive_semidefinite(operator.operator) + return False + + +# PSD/NSD for NegLinearOperator: negation swaps PSD <-> NSD +@is_positive_semidefinite.register(NegLinearOperator) +def _(operator): + return is_negative_semidefinite(operator.operator) + + +@is_negative_semidefinite.register(NegLinearOperator) +def _(operator): + return is_positive_semidefinite(operator.operator) + + for check, tag in ( (is_symmetric, symmetric_tag), (is_diagonal, diagonal_tag), @@ -2010,14 +2094,11 @@ def _(operator): return False +# These properties ARE preserved under composition for check in ( - is_symmetric, is_diagonal, is_lower_triangular, is_upper_triangular, - is_positive_semidefinite, - is_negative_semidefinite, - is_tridiagonal, ): @check.register(ComposedLinearOperator) @@ -2025,6 +2106,30 @@ def _(operator, check=check): return check(operator.operator1) and check(operator.operator2) +# is_symmetric: A@B is symmetric only if A and B commute. Diagonal matrices commute. +@is_symmetric.register(ComposedLinearOperator) +def _(operator): + return is_diagonal(operator.operator1) and is_diagonal(operator.operator2) + + +# is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but +# tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal +@is_tridiagonal.register(ComposedLinearOperator) +def _(operator): + if is_diagonal(operator.operator1): + return is_tridiagonal(operator.operator2) + if is_diagonal(operator.operator2): + return is_tridiagonal(operator.operator1) + return False + + +# PSD/NSD: not preserved under composition in general. +@is_positive_semidefinite.register(ComposedLinearOperator) +@is_negative_semidefinite.register(ComposedLinearOperator) +def _(operator): + return False + + @has_unit_diagonal.register(ComposedLinearOperator) def _(operator): a = is_diagonal(operator) From 376c6ee1dafc99ca1f3daf8d2b4e8a2d7eb3a6a8 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Fri, 30 Jan 2026 01:19:02 +0000 Subject: [PATCH 2/3] simpler docstring for _scalar_sign --- lineax/_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index b223858..4458220 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1987,7 +1987,7 @@ def _(operator, check=check): def _scalar_sign(scalar) -> int | None: - """Returns 1 if positive, -1 if negative, 0 if zero, None if unknown (traced).""" + """Returns scalar sign if known at trace time otherwise None.""" try: if scalar > 0: return 1 From 1b3568d168044ca74fbfcc747cc9c411a3cf2639 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Sat, 31 Jan 2026 01:18:57 +0000 Subject: [PATCH 3/3] incorporate review comments on scalar sign --- lineax/_operator.py | 45 +++++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 4458220..0bada90 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -13,6 +13,7 @@ # limitations under the License. import abc +import enum import functools as ft import math import warnings @@ -1986,17 +1987,25 @@ def _(operator, check=check): return check(operator.operator) -def _scalar_sign(scalar) -> int | None: - """Returns scalar sign if known at trace time otherwise None.""" - try: +class _ScalarSign(enum.Enum): + positive = enum.auto() + negative = enum.auto() + zero = enum.auto() + unknown = enum.auto() + + +def _scalar_sign(scalar) -> _ScalarSign: + """Returns the sign of a scalar, or unknown for JAX tracers.""" + if isinstance(scalar, (int, float, np.ndarray, np.generic)): + scalar = float(scalar) if scalar > 0: - return 1 + return _ScalarSign.positive elif scalar < 0: - return -1 + return _ScalarSign.negative else: - return 0 - except Exception: - return None + return _ScalarSign.zero + else: + return _ScalarSign.unknown # PSD/NSD for MulLinearOperator: depends on sign of scalar @@ -2004,11 +2013,11 @@ def _scalar_sign(scalar) -> int | None: @is_positive_semidefinite.register(MulLinearOperator) def _(operator): sign = _scalar_sign(operator.scalar) - if sign == 1: + if sign is _ScalarSign.positive: return is_positive_semidefinite(operator.operator) - elif sign == -1: + elif sign is _ScalarSign.negative: return is_negative_semidefinite(operator.operator) - elif sign == 0: + elif sign is _ScalarSign.zero: return True # zero matrix is PSD return False @@ -2016,11 +2025,11 @@ def _(operator): @is_negative_semidefinite.register(MulLinearOperator) def _(operator): sign = _scalar_sign(operator.scalar) - if sign == 1: + if sign is _ScalarSign.positive: return is_negative_semidefinite(operator.operator) - elif sign == -1: + elif sign is _ScalarSign.negative: return is_positive_semidefinite(operator.operator) - elif sign == 0: + elif sign is _ScalarSign.zero: return True # zero matrix is NSD return False @@ -2030,9 +2039,9 @@ def _(operator): @is_positive_semidefinite.register(DivLinearOperator) def _(operator): sign = _scalar_sign(operator.scalar) - if sign == 1: + if sign is _ScalarSign.positive: return is_positive_semidefinite(operator.operator) - elif sign == -1: + elif sign is _ScalarSign.negative: return is_negative_semidefinite(operator.operator) return False @@ -2040,9 +2049,9 @@ def _(operator): @is_negative_semidefinite.register(DivLinearOperator) def _(operator): sign = _scalar_sign(operator.scalar) - if sign == 1: + if sign is _ScalarSign.positive: return is_negative_semidefinite(operator.operator) - elif sign == -1: + elif sign is _ScalarSign.negative: return is_positive_semidefinite(operator.operator) return False