diff --git a/lineax/_operator.py b/lineax/_operator.py index b430a2f..150d90d 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -86,14 +86,18 @@ class AbstractLinearOperator(eqx.Module): """ def __check_init__(self): - if is_symmetric(self): + if ( + is_symmetric(self) + or is_positive_semidefinite(self) + or is_negative_semidefinite(self) + ): # In particular, we check that dtypes match. in_structure = self.in_structure() out_structure = self.out_structure() # `is` check to handle the possibility of a tracer. if eqx.tree_equal(in_structure, out_structure) is not True: raise ValueError( - "Symmetric matrices must have matching input and output " + "Symmetric/Hermitian matrices must have matching input and output " f"structures. Got input structure {in_structure} and output " f"structure {out_structure}." ) @@ -1259,7 +1263,7 @@ def _(operator): _, vjp_fn, aux = jax.vjp(fn, operator.x, has_aux=True) if is_symmetric(operator): # For symmetric: J = J.T, so vjp directly gives J @ v - lin = _Unwrap(vjp_fn()) + lin = _Unwrap(vjp_fn) else: # Transpose the VJP to get J @ v from J.T @ v lin = _Unwrap( @@ -1510,20 +1514,27 @@ def is_symmetric(operator: AbstractLinearOperator) -> bool: _default_not_implemented("is_symmetric", operator) +def _has_real_dtype(operator) -> bool: + """Check if all dtypes in an operator's structure are real (not complex).""" + leaves = jtu.tree_leaves(operator.in_structure()) + return all(jnp.issubdtype(leaf.dtype, jnp.floating) for leaf in leaves) + + @is_symmetric.register(MatrixLinearOperator) @is_symmetric.register(PyTreeLinearOperator) @is_symmetric.register(JacobianLinearOperator) @is_symmetric.register(FunctionLinearOperator) def _(operator): - return any( - tag in operator.tags - for tag in ( - symmetric_tag, - positive_semidefinite_tag, - negative_semidefinite_tag, - diagonal_tag, - ) - ) + # Symmetric (A = A^T) if explicitly tagged symmetric or diagonal + if symmetric_tag in operator.tags or diagonal_tag in operator.tags: + return True + # PSD/NSD implies symmetric only for real dtypes; for complex, it's Hermitian + if ( + positive_semidefinite_tag in operator.tags + or negative_semidefinite_tag in operator.tags + ): + return _has_real_dtype(operator) + return False @is_symmetric.register(IdentityLinearOperator) @@ -1771,7 +1782,7 @@ def _(operator): @is_positive_semidefinite.register(IdentityLinearOperator) def _(operator): - return True + return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True @is_positive_semidefinite.register(DiagonalLinearOperator) @@ -1964,6 +1975,8 @@ def _(operator): is_lower_triangular, is_upper_triangular, is_tridiagonal, + is_positive_semidefinite, + is_negative_semidefinite, ): @check.register(TangentLinearOperator) @@ -1999,21 +2012,6 @@ def _(operator): return False -for check in (is_positive_semidefinite, is_negative_semidefinite): - - @check.register(TangentLinearOperator) - def _(operator): - # Should be unreachable: TangentLinearOperator is used for a narrow set of - # operations only (mv; transpose) inside the JVP rule linear_solve_p. - raise NotImplementedError( - "Please open a GitHub issue: https://github.com/google/lineax" - ) - - @check.register(AuxLinearOperator) - def _(operator, check=check): - return check(operator.operator) - - class _ScalarSign(enum.Enum): positive = enum.auto() negative = enum.auto() diff --git a/tests/__main__.py b/tests/__main__.py index 7c5ccb2..5edf5e2 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -26,6 +26,4 @@ if file.is_file() and file.name.startswith("test"): out = subprocess.run(f"pytest {file}", shell=True).returncode running_out = max(running_out, out) - if out != 0: - break sys.exit(running_out)