Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 128 additions & 14 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import abc
import enum
import functools as ft
import math
import warnings
Expand Down Expand Up @@ -1942,14 +1943,35 @@ def _(operator):
def _(operator, check=check):
return check(operator.primal)

@check.register(AuxLinearOperator)
def _(operator, check=check):
return check(operator.operator)


# Scaling/negating preserves these structural properties
for check in (
is_symmetric,
is_diagonal,
is_lower_triangular,
is_upper_triangular,
is_tridiagonal,
):

@check.register(MulLinearOperator)
@check.register(NegLinearOperator)
@check.register(DivLinearOperator)
@check.register(AuxLinearOperator)
def _(operator, check=check):
return check(operator.operator)


# has_unit_diagonal is NOT preserved by scaling or negation
@has_unit_diagonal.register(MulLinearOperator)
@has_unit_diagonal.register(NegLinearOperator)
@has_unit_diagonal.register(DivLinearOperator)
def _(operator):
return False


for check in (is_positive_semidefinite, is_negative_semidefinite):

@check.register(TangentLinearOperator)
Expand All @@ -1960,20 +1982,91 @@ def _(operator):
"Please open a GitHub issue: https://github.com/google/lineax"
)

@check.register(MulLinearOperator)
@check.register(DivLinearOperator)
def _(operator):
return False # play it safe, no way to tell.

@check.register(NegLinearOperator)
def _(operator, check=check):
return not check(operator.operator)

@check.register(AuxLinearOperator)
def _(operator, check=check):
return check(operator.operator)


class _ScalarSign(enum.Enum):
positive = enum.auto()
negative = enum.auto()
zero = enum.auto()
unknown = enum.auto()


def _scalar_sign(scalar) -> _ScalarSign:
"""Returns the sign of a scalar, or unknown for JAX tracers."""
if isinstance(scalar, (int, float, np.ndarray, np.generic)):
scalar = float(scalar)
if scalar > 0:
return _ScalarSign.positive
elif scalar < 0:
return _ScalarSign.negative
else:
return _ScalarSign.zero
else:
return _ScalarSign.unknown


# PSD/NSD for MulLinearOperator: depends on sign of scalar
# Zero scalar gives zero matrix which is both PSD and NSD
@is_positive_semidefinite.register(MulLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_positive_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_negative_semidefinite(operator.operator)
elif sign is _ScalarSign.zero:
return True # zero matrix is PSD
return False


@is_negative_semidefinite.register(MulLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_negative_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_positive_semidefinite(operator.operator)
elif sign is _ScalarSign.zero:
return True # zero matrix is NSD
return False


# PSD/NSD for DivLinearOperator: depends on sign of scalar
# Zero scalar is division by zero - return False (conservative)
@is_positive_semidefinite.register(DivLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_positive_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_negative_semidefinite(operator.operator)
return False


@is_negative_semidefinite.register(DivLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_negative_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_positive_semidefinite(operator.operator)
return False


# PSD/NSD for NegLinearOperator: negation swaps PSD <-> NSD
@is_positive_semidefinite.register(NegLinearOperator)
def _(operator):
return is_negative_semidefinite(operator.operator)


@is_negative_semidefinite.register(NegLinearOperator)
def _(operator):
return is_positive_semidefinite(operator.operator)


for check, tag in (
(is_symmetric, symmetric_tag),
(is_diagonal, diagonal_tag),
Expand Down Expand Up @@ -2010,21 +2103,42 @@ def _(operator):
return False


# These properties ARE preserved under composition
for check in (
is_symmetric,
is_diagonal,
is_lower_triangular,
is_upper_triangular,
is_positive_semidefinite,
is_negative_semidefinite,
is_tridiagonal,
):

@check.register(ComposedLinearOperator)
def _(operator, check=check):
return check(operator.operator1) and check(operator.operator2)


# is_symmetric: A@B is symmetric only if A and B commute. Diagonal matrices commute.
@is_symmetric.register(ComposedLinearOperator)
def _(operator):
return is_diagonal(operator.operator1) and is_diagonal(operator.operator2)


# is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but
# tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal
@is_tridiagonal.register(ComposedLinearOperator)
def _(operator):
if is_diagonal(operator.operator1):
return is_tridiagonal(operator.operator2)
if is_diagonal(operator.operator2):
return is_tridiagonal(operator.operator1)
return False


# PSD/NSD: not preserved under composition in general.
@is_positive_semidefinite.register(ComposedLinearOperator)
@is_negative_semidefinite.register(ComposedLinearOperator)
def _(operator):
return False


@has_unit_diagonal.register(ComposedLinearOperator)
def _(operator):
a = is_diagonal(operator)
Expand Down
Loading