diff --git a/lineax/_operator.py b/lineax/_operator.py index a83eeff..0bada90 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -13,6 +13,7 @@ # limitations under the License. import abc +import enum import functools as ft import math import warnings @@ -1942,14 +1943,35 @@ def _(operator): def _(operator, check=check): return check(operator.primal) + @check.register(AuxLinearOperator) + def _(operator, check=check): + return check(operator.operator) + + +# Scaling/negating preserves these structural properties +for check in ( + is_symmetric, + is_diagonal, + is_lower_triangular, + is_upper_triangular, + is_tridiagonal, +): + @check.register(MulLinearOperator) @check.register(NegLinearOperator) @check.register(DivLinearOperator) - @check.register(AuxLinearOperator) def _(operator, check=check): return check(operator.operator) +# has_unit_diagonal is NOT preserved by scaling or negation +@has_unit_diagonal.register(MulLinearOperator) +@has_unit_diagonal.register(NegLinearOperator) +@has_unit_diagonal.register(DivLinearOperator) +def _(operator): + return False + + for check in (is_positive_semidefinite, is_negative_semidefinite): @check.register(TangentLinearOperator) @@ -1960,20 +1982,91 @@ def _(operator): "Please open a GitHub issue: https://github.com/google/lineax" ) - @check.register(MulLinearOperator) - @check.register(DivLinearOperator) - def _(operator): - return False # play it safe, no way to tell. - - @check.register(NegLinearOperator) - def _(operator, check=check): - return not check(operator.operator) - @check.register(AuxLinearOperator) def _(operator, check=check): return check(operator.operator) +class _ScalarSign(enum.Enum): + positive = enum.auto() + negative = enum.auto() + zero = enum.auto() + unknown = enum.auto() + + +def _scalar_sign(scalar) -> _ScalarSign: + """Returns the sign of a scalar, or unknown for JAX tracers.""" + if isinstance(scalar, (int, float, np.ndarray, np.generic)): + scalar = float(scalar) + if scalar > 0: + return _ScalarSign.positive + elif scalar < 0: + return _ScalarSign.negative + else: + return _ScalarSign.zero + else: + return _ScalarSign.unknown + + +# PSD/NSD for MulLinearOperator: depends on sign of scalar +# Zero scalar gives zero matrix which is both PSD and NSD +@is_positive_semidefinite.register(MulLinearOperator) +def _(operator): + sign = _scalar_sign(operator.scalar) + if sign is _ScalarSign.positive: + return is_positive_semidefinite(operator.operator) + elif sign is _ScalarSign.negative: + return is_negative_semidefinite(operator.operator) + elif sign is _ScalarSign.zero: + return True # zero matrix is PSD + return False + + +@is_negative_semidefinite.register(MulLinearOperator) +def _(operator): + sign = _scalar_sign(operator.scalar) + if sign is _ScalarSign.positive: + return is_negative_semidefinite(operator.operator) + elif sign is _ScalarSign.negative: + return is_positive_semidefinite(operator.operator) + elif sign is _ScalarSign.zero: + return True # zero matrix is NSD + return False + + +# PSD/NSD for DivLinearOperator: depends on sign of scalar +# Zero scalar is division by zero - return False (conservative) +@is_positive_semidefinite.register(DivLinearOperator) +def _(operator): + sign = _scalar_sign(operator.scalar) + if sign is _ScalarSign.positive: + return is_positive_semidefinite(operator.operator) + elif sign is _ScalarSign.negative: + return is_negative_semidefinite(operator.operator) + return False + + +@is_negative_semidefinite.register(DivLinearOperator) +def _(operator): + sign = _scalar_sign(operator.scalar) + if sign is _ScalarSign.positive: + return is_negative_semidefinite(operator.operator) + elif sign is _ScalarSign.negative: + return is_positive_semidefinite(operator.operator) + return False + + +# PSD/NSD for NegLinearOperator: negation swaps PSD <-> NSD +@is_positive_semidefinite.register(NegLinearOperator) +def _(operator): + return is_negative_semidefinite(operator.operator) + + +@is_negative_semidefinite.register(NegLinearOperator) +def _(operator): + return is_positive_semidefinite(operator.operator) + + for check, tag in ( (is_symmetric, symmetric_tag), (is_diagonal, diagonal_tag), @@ -2010,14 +2103,11 @@ def _(operator): return False +# These properties ARE preserved under composition for check in ( - is_symmetric, is_diagonal, is_lower_triangular, is_upper_triangular, - is_positive_semidefinite, - is_negative_semidefinite, - is_tridiagonal, ): @check.register(ComposedLinearOperator) @@ -2025,6 +2115,30 @@ def _(operator, check=check): return check(operator.operator1) and check(operator.operator2) +# is_symmetric: A@B is symmetric only if A and B commute. Diagonal matrices commute. +@is_symmetric.register(ComposedLinearOperator) +def _(operator): + return is_diagonal(operator.operator1) and is_diagonal(operator.operator2) + + +# is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but +# tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal +@is_tridiagonal.register(ComposedLinearOperator) +def _(operator): + if is_diagonal(operator.operator1): + return is_tridiagonal(operator.operator2) + if is_diagonal(operator.operator2): + return is_tridiagonal(operator.operator1) + return False + + +# PSD/NSD: not preserved under composition in general. +@is_positive_semidefinite.register(ComposedLinearOperator) +@is_negative_semidefinite.register(ComposedLinearOperator) +def _(operator): + return False + + @has_unit_diagonal.register(ComposedLinearOperator) def _(operator): a = is_diagonal(operator)