Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 26 additions & 28 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,18 @@ class AbstractLinearOperator(eqx.Module):
"""

def __check_init__(self):
if is_symmetric(self):
if (
is_symmetric(self)
or is_positive_semidefinite(self)
or is_negative_semidefinite(self)
):
# In particular, we check that dtypes match.
in_structure = self.in_structure()
out_structure = self.out_structure()
# `is` check to handle the possibility of a tracer.
if eqx.tree_equal(in_structure, out_structure) is not True:
raise ValueError(
"Symmetric matrices must have matching input and output "
"Symmetric/Hermitian matrices must have matching input and output "
f"structures. Got input structure {in_structure} and output "
f"structure {out_structure}."
)
Expand Down Expand Up @@ -1259,7 +1263,7 @@ def _(operator):
_, vjp_fn, aux = jax.vjp(fn, operator.x, has_aux=True)
if is_symmetric(operator):
# For symmetric: J = J.T, so vjp directly gives J @ v
lin = _Unwrap(vjp_fn())
lin = _Unwrap(vjp_fn)
else:
# Transpose the VJP to get J @ v from J.T @ v
lin = _Unwrap(
Expand Down Expand Up @@ -1510,20 +1514,27 @@ def is_symmetric(operator: AbstractLinearOperator) -> bool:
_default_not_implemented("is_symmetric", operator)


def _has_real_dtype(operator) -> bool:
"""Check if all dtypes in an operator's structure are real (not complex)."""
leaves = jtu.tree_leaves(operator.in_structure())
return all(jnp.issubdtype(leaf.dtype, jnp.floating) for leaf in leaves)


@is_symmetric.register(MatrixLinearOperator)
@is_symmetric.register(PyTreeLinearOperator)
@is_symmetric.register(JacobianLinearOperator)
@is_symmetric.register(FunctionLinearOperator)
def _(operator):
return any(
tag in operator.tags
for tag in (
symmetric_tag,
positive_semidefinite_tag,
negative_semidefinite_tag,
diagonal_tag,
)
)
# Symmetric (A = A^T) if explicitly tagged symmetric or diagonal
if symmetric_tag in operator.tags or diagonal_tag in operator.tags:
return True
# PSD/NSD implies symmetric only for real dtypes; for complex, it's Hermitian
if (
positive_semidefinite_tag in operator.tags
or negative_semidefinite_tag in operator.tags
):
return _has_real_dtype(operator)
return False


@is_symmetric.register(IdentityLinearOperator)
Expand Down Expand Up @@ -1771,7 +1782,7 @@ def _(operator):

@is_positive_semidefinite.register(IdentityLinearOperator)
def _(operator):
return True
return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True


@is_positive_semidefinite.register(DiagonalLinearOperator)
Expand Down Expand Up @@ -1964,6 +1975,8 @@ def _(operator):
is_lower_triangular,
is_upper_triangular,
is_tridiagonal,
is_positive_semidefinite,
is_negative_semidefinite,
):

@check.register(TangentLinearOperator)
Expand Down Expand Up @@ -1999,21 +2012,6 @@ def _(operator):
return False


for check in (is_positive_semidefinite, is_negative_semidefinite):

@check.register(TangentLinearOperator)
def _(operator):
# Should be unreachable: TangentLinearOperator is used for a narrow set of
# operations only (mv; transpose) inside the JVP rule linear_solve_p.
raise NotImplementedError(
"Please open a GitHub issue: https://github.com/google/lineax"
)

@check.register(AuxLinearOperator)
def _(operator, check=check):
return check(operator.operator)


class _ScalarSign(enum.Enum):
positive = enum.auto()
negative = enum.auto()
Expand Down
2 changes: 0 additions & 2 deletions tests/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,4 @@
if file.is_file() and file.name.startswith("test"):
out = subprocess.run(f"pytest {file}", shell=True).returncode
running_out = max(running_out, out)
if out != 0:
break
sys.exit(running_out)
Loading