From 0d51444254452c2dec6b42626fd1d15572761336 Mon Sep 17 00:00:00 2001 From: aidancrilly Date: Tue, 28 May 2024 14:13:13 +0100 Subject: [PATCH 1/4] First pass at implementing Woodbury --- lineax/__init__.py | 2 + lineax/_operator.py | 143 +++++++++++++++++++++++++++++++++++++ lineax/_solve.py | 9 ++- lineax/_solver/Woodbury.py | 138 +++++++++++++++++++++++++++++++++++ lineax/_solver/__init__.py | 1 + tests/test_operator.py | 27 +++++++ 6 files changed, 319 insertions(+), 1 deletion(-) create mode 100644 lineax/_solver/Woodbury.py diff --git a/lineax/__init__.py b/lineax/__init__.py index f7c5ee9..1e16dd3 100644 --- a/lineax/__init__.py +++ b/lineax/__init__.py @@ -34,6 +34,7 @@ is_symmetric as is_symmetric, is_tridiagonal as is_tridiagonal, is_upper_triangular as is_upper_triangular, + is_Woodbury as is_Woodbury, JacobianLinearOperator as JacobianLinearOperator, linearise as linearise, materialise as materialise, @@ -45,6 +46,7 @@ TangentLinearOperator as TangentLinearOperator, tridiagonal as tridiagonal, TridiagonalLinearOperator as TridiagonalLinearOperator, + WoodburyLinearOperator as WoodburyLinearOperator, ) from ._solution import RESULTS as RESULTS, Solution as Solution from ._solve import ( diff --git a/lineax/_operator.py b/lineax/_operator.py index ad1e4f5..5a2d59d 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -906,6 +906,78 @@ def out_structure(self): return self.operator.out_structure() +class WoodburyLinearOperator(AbstractLinearOperator, strict=True): + """As [`lineax.MatrixLinearOperator`][], but for specifically a matrix + with A + U C V structure, such that the Woodbury identity can be used""" + + A: AbstractLinearOperator + C: Inexact[Array, " k k"] + U: Inexact[Array, " n k"] + V: Inexact[Array, " k n"] + UCV: Array + tags: frozenset[object] = eqx.field(static=True) + + def __init__( + self, + A: AbstractLinearOperator, + C: Inexact[Array, " k k"], + U: Inexact[Array, " n k"], + V: Inexact[Array, " k n"], + tags: Union[object, frozenset[object]] = (), + ): + """**Arguments:** + + Matrix of form A + U C V, such that the inverse can be computed + using Woodbury matrix identity + + - `A`: Linear operator, in/out shape (n,n) + - `C`: A rank-two JAX array. Shape (k,k) + - `U`: A rank-two JAX array. Shape (n,k) + - `V`: A rank-two JAX array. Shape (k,n) + + """ + self.A = A + self.C = inexact_asarray(C) + self.U = inexact_asarray(U) + self.V = inexact_asarray(V) + (N, M) = self.A.in_structure(), self.A.out_structure() + if N != M: + raise ValueError(f"expecting square operator for A, got {N} by {M}") + (K, L) = self.C.shape + if K != L: + raise ValueError(f"expecting square operator for C, got {K} by {L}") + N = N.shape[0] + if self.U.shape != (N, K): + raise ValueError("U does not have consistent shape with A and C") + if self.V.shape != (K, N): + raise ValueError("V does not have consistent shape with A and C") + self.UCV = self.U @ (self.C @ self.V) + self.tags = _frozenset(tags) + + def mv(self, vector): + Ax = self.A.mv(vector) + UCVx = self.UCV @ vector + return Ax + UCVx + + def as_matrix(self): + matrix = self.A.as_matrix() + self.UCV + return matrix + + def transpose(self): + return WoodburyLinearOperator( + self.A.transpose(), + jnp.transpose(self.C), + jnp.transpose(self.V), + jnp.transpose(self.U), + ) + + def in_structure(self): + return self.A.in_structure() + + def out_structure(self): + return self.A.out_structure() + + # # All operators below here are private to lineax. # @@ -1207,6 +1279,7 @@ def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator: @linearise.register(IdentityLinearOperator) @linearise.register(DiagonalLinearOperator) @linearise.register(TridiagonalLinearOperator) +@linearise.register(WoodburyLinearOperator) def _(operator): return operator @@ -1283,6 +1356,7 @@ def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator: @materialise.register(IdentityLinearOperator) @materialise.register(DiagonalLinearOperator) @materialise.register(TridiagonalLinearOperator) +@materialise.register(WoodburyLinearOperator) def _(operator): return operator @@ -1343,6 +1417,7 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]: @diagonal.register(MatrixLinearOperator) +@diagonal.register(WoodburyLinearOperator) @diagonal.register(PyTreeLinearOperator) @diagonal.register(JacobianLinearOperator) @diagonal.register(FunctionLinearOperator) @@ -1397,6 +1472,7 @@ def tridiagonal( @tridiagonal.register(MatrixLinearOperator) +@tridiagonal.register(WoodburyLinearOperator) @tridiagonal.register(PyTreeLinearOperator) @tridiagonal.register(JacobianLinearOperator) @tridiagonal.register(FunctionLinearOperator) @@ -1429,6 +1505,33 @@ def _(operator): return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal +@ft.singledispatch +def woodbury( + operator: AbstractLinearOperator, +) -> tuple[ + AbstractLinearOperator, + Shaped[Array, " k k"], + Shaped[Array, " n k"], + Shaped[Array, " k n"], +]: + """Extracts the A, C, U, V Woodbury structure, from a linear + operator. Returns one linear operators and three matrices. + **Arguments:** + - `operator`: a linear operator. + **Returns:** + A 4-tuple, consisting of + - A which is a linear operator + - C, U and V which are matrices + For all but the Woodbury operator this extraction is not possible + """ + _default_not_implemented("woodbury", operator) + + +@woodbury.register(WoodburyLinearOperator) +def _(operator): + return operator.A, operator.C, operator.U, operator.V + + # is_symmetric @@ -1451,6 +1554,7 @@ def is_symmetric(operator: AbstractLinearOperator) -> bool: @is_symmetric.register(MatrixLinearOperator) +@is_symmetric.register(WoodburyLinearOperator) @is_symmetric.register(PyTreeLinearOperator) @is_symmetric.register(JacobianLinearOperator) @is_symmetric.register(FunctionLinearOperator) @@ -1503,6 +1607,7 @@ def is_diagonal(operator: AbstractLinearOperator) -> bool: @is_diagonal.register(MatrixLinearOperator) +@is_diagonal.register(WoodburyLinearOperator) @is_diagonal.register(PyTreeLinearOperator) @is_diagonal.register(JacobianLinearOperator) @is_diagonal.register(FunctionLinearOperator) @@ -1543,6 +1648,7 @@ def is_tridiagonal(operator: AbstractLinearOperator) -> bool: @is_tridiagonal.register(MatrixLinearOperator) +@is_tridiagonal.register(WoodburyLinearOperator) @is_tridiagonal.register(PyTreeLinearOperator) @is_tridiagonal.register(JacobianLinearOperator) @is_tridiagonal.register(FunctionLinearOperator) @@ -1557,6 +1663,36 @@ def _(operator): return True +@ft.singledispatch +def is_Woodbury(operator: AbstractLinearOperator) -> bool: + """Returns whether an operator is marked as Woodbury. + See [the documentation on linear operator tags](../api/tags.md) for more + information. + **Arguments:** + - `operator`: a linear operator. + **Returns:** + Either `True` or `False.` + """ + _default_not_implemented("is_Woodbury", operator) + + +@is_Woodbury.register(WoodburyLinearOperator) +def _(operator): + return True + + +@is_Woodbury.register(MatrixLinearOperator) +@is_Woodbury.register(PyTreeLinearOperator) +@is_Woodbury.register(JacobianLinearOperator) +@is_Woodbury.register(FunctionLinearOperator) +@is_Woodbury.register(IdentityLinearOperator) +@is_Woodbury.register(DiagonalLinearOperator) +@is_Woodbury.register(TridiagonalLinearOperator) +@is_Woodbury.register(TaggedLinearOperator) # TODO : check this +def _(operator): + return False + + # has_unit_diagonal @@ -1579,6 +1715,7 @@ def has_unit_diagonal(operator: AbstractLinearOperator) -> bool: @has_unit_diagonal.register(MatrixLinearOperator) +@has_unit_diagonal.register(WoodburyLinearOperator) @has_unit_diagonal.register(PyTreeLinearOperator) @has_unit_diagonal.register(JacobianLinearOperator) @has_unit_diagonal.register(FunctionLinearOperator) @@ -1620,6 +1757,7 @@ def is_lower_triangular(operator: AbstractLinearOperator) -> bool: @is_lower_triangular.register(MatrixLinearOperator) +@is_lower_triangular.register(WoodburyLinearOperator) @is_lower_triangular.register(PyTreeLinearOperator) @is_lower_triangular.register(JacobianLinearOperator) @is_lower_triangular.register(FunctionLinearOperator) @@ -1660,6 +1798,7 @@ def is_upper_triangular(operator: AbstractLinearOperator) -> bool: @is_upper_triangular.register(MatrixLinearOperator) +@is_upper_triangular.register(WoodburyLinearOperator) @is_upper_triangular.register(PyTreeLinearOperator) @is_upper_triangular.register(JacobianLinearOperator) @is_upper_triangular.register(FunctionLinearOperator) @@ -1700,6 +1839,7 @@ def is_positive_semidefinite(operator: AbstractLinearOperator) -> bool: @is_positive_semidefinite.register(MatrixLinearOperator) +@is_positive_semidefinite.register(WoodburyLinearOperator) @is_positive_semidefinite.register(PyTreeLinearOperator) @is_positive_semidefinite.register(JacobianLinearOperator) @is_positive_semidefinite.register(FunctionLinearOperator) @@ -1741,6 +1881,7 @@ def is_negative_semidefinite(operator: AbstractLinearOperator) -> bool: @is_negative_semidefinite.register(MatrixLinearOperator) +@is_negative_semidefinite.register(WoodburyLinearOperator) @is_negative_semidefinite.register(PyTreeLinearOperator) @is_negative_semidefinite.register(JacobianLinearOperator) @is_negative_semidefinite.register(FunctionLinearOperator) @@ -1902,6 +2043,7 @@ def _(operator): is_lower_triangular, is_upper_triangular, is_tridiagonal, + is_Woodbury, ): @check.register(TangentLinearOperator) @@ -1964,6 +2106,7 @@ def _(operator, check=check, tag=tag): is_positive_semidefinite, is_negative_semidefinite, is_tridiagonal, + is_Woodbury, ): @check.register(AddLinearOperator) diff --git a/lineax/_solve.py b/lineax/_solve.py index 36fad0b..82fe750 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -40,6 +40,7 @@ is_positive_semidefinite, is_tridiagonal, is_upper_triangular, + is_Woodbury, linearise, TangentLinearOperator, ) @@ -268,7 +269,7 @@ def _linear_solve_transpose(inputs, cts_out): _assert_defined, (operator, state, options, solver), is_leaf=_is_undefined ) cts_solution = jtu.tree_map( - ft.partial(eqxi.materialise_zeros, allow_struct=True), + ft.partial(eqxi.materialise_zeros, allow_struct=True), # pyright: ignore operator.in_structure(), cts_solution, ) @@ -498,6 +499,7 @@ def conj( _cholesky_token = eqxi.str2jax("cholesky_token") _lu_token = eqxi.str2jax("lu_token") _svd_token = eqxi.str2jax("svd_token") +_woodbury_token = eqxi.str2jax("woodbury_token") # Ugly delayed import because we have the dependency chain @@ -518,6 +520,7 @@ def _lookup(token) -> AbstractLinearSolver: _cholesky_token: _solver.Cholesky(), # pyright: ignore _lu_token: _solver.LU(), # pyright: ignore _svd_token: _solver.SVD(), # pyright: ignore + _woodbury_token: _solver.Woodbury(), # pyright: ignore } return _lookup_dict[token] @@ -535,6 +538,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState], strict=True - If the operator is triangular, then use [`lineax.Triangular`][]. - If the matrix is positive or negative definite, then use [`lineax.Cholesky`][]. + - If the matrix has structure A + U C V, then use [`lineax.Woodbury`][]. - Else use [`lineax.LU`][]. This is a good choice if you want to be certain that an error is raised for @@ -554,6 +558,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState], strict=True - If the operator is triangular, then use [`lineax.Triangular`][]. - If the matrix is positive or negative definite, then use [`lineax.Cholesky`][]. + - If the matrix has structure A + U C V, then use [`lineax.Woodbury`][]. - Else, use [`lineax.LU`][]. This is a good choice if your primary concern is computational efficiency. It will @@ -582,6 +587,8 @@ def _select_solver(self, operator: AbstractLinearOperator): operator ): token = _cholesky_token + elif is_Woodbury(operator): + token = _woodbury_token else: token = _lu_token elif self.well_posed is False: diff --git a/lineax/_solver/Woodbury.py b/lineax/_solver/Woodbury.py new file mode 100644 index 0000000..c6e818d --- /dev/null +++ b/lineax/_solver/Woodbury.py @@ -0,0 +1,138 @@ +from typing import Any +from typing_extensions import TypeAlias + +import jax +import jax.numpy as jnp +from jaxtyping import Array, PyTree + +from .._operator import ( + AbstractLinearOperator, + is_Woodbury, + MatrixLinearOperator, + woodbury, +) +from .._solution import RESULTS +from .._solve import AbstractLinearSolver, AutoLinearSolver +from .misc import ( + pack_structures, + PackedStructures, + ravel_vector, + transpose_packed_structures, + unravel_solution, +) + + +_Woodbury_State: TypeAlias = tuple[ + tuple[Array, Array, Array], + tuple[AbstractLinearSolver, Any, AbstractLinearSolver, Any], + PackedStructures, +] + + +def compute_pushthrough(A_solver, A_state, C, U, V): + # Push through ( C^-1 + V A^-1 U) y = x + vmapped_solve = jax.vmap( + lambda x_vec: A_solver.compute(A_state, x_vec, {})[0], in_axes=1, out_axes=1 + ) + pushthrough_mat = jnp.linalg.inv(C) + V @ vmapped_solve(U) + pushthrough_op = MatrixLinearOperator(pushthrough_mat) + solver = AutoLinearSolver(well_posed=True).select_solver(pushthrough_op) + state = solver.init(pushthrough_op, {}) + return solver, state + + +class Woodbury(AbstractLinearSolver[_Woodbury_State]): + """Solving system using Woodbury matrix identity""" + + def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): + del options + if not is_Woodbury(operator): + raise ValueError( + "`Woodbury` may only be used for linear solves with A + U C V structure" + ) + else: + A, C, U, V = woodbury(operator) + if A.in_size() != A.out_size(): + raise ValueError("""A must be square""") + # Find correct solvers and init for A + A_solver = AutoLinearSolver(well_posed=True).select_solver(A) + A_state = A_solver.init(A, {}) + # Compute pushthrough operator + pt_solver, pt_state = compute_pushthrough(A_solver, A_state, C, U, V) + return ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + pack_structures(A), + ) + + def compute( + self, + state: _Woodbury_State, + vector, + options, + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + A_packed_structures, + ) = state + del state, options + vector = ravel_vector(vector, A_packed_structures) + + # Solution to A x = b + # [0] selects the solution vector + x_1 = A_solver.compute(A_state, vector, {})[0] + # Push through U ( C^-1 + V A^-1 U)^-1 V (A^-1 b) + # [0] selects the solution vector + x_pushthrough = U @ pt_solver.compute(pt_state, V @ x_1, {})[0] + # A^-1 on result of push through + # [0] selects the solution vector + x_2 = A_solver.compute(A_state, x_pushthrough, {})[0] + # See https://en.wikipedia.org/wiki/Woodbury_matrix_identity + solution = x_1 - x_2 + + solution = unravel_solution(solution, A_packed_structures) + return solution, RESULTS.successful, {} + + def transpose(self, state: _Woodbury_State, options: dict[str, Any]): + ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + A_packed_structures, + ) = state + transposed_packed_structures = transpose_packed_structures(A_packed_structures) + C = jnp.transpose(C) + U = jnp.transpose(V) + V = jnp.transpose(U) + A_state, _ = A_solver.transpose(A_state, {}) + pt_solver, pt_state = compute_pushthrough(A_solver, A_state, C, U, V) + transpose_state = ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + transposed_packed_structures, + ) + return transpose_state, options + + def conj(self, state: _Woodbury_State, options: dict[str, Any]): + ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + packed_structures, + ) = state + C = jnp.conj(C) + U = jnp.conj(U) + V = jnp.conj(V) + A_state, _ = A_solver.conj(A_state, {}) + pt_solver, pt_state = compute_pushthrough(A_solver, A_state, C, U, V) + conj_state = ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + packed_structures, + ) + return conj_state, options + + def allow_dependent_columns(self, operator): + return False + + def allow_dependent_rows(self, operator): + return False diff --git a/lineax/_solver/__init__.py b/lineax/_solver/__init__.py index 425fc40..bc631f8 100644 --- a/lineax/_solver/__init__.py +++ b/lineax/_solver/__init__.py @@ -22,3 +22,4 @@ from .svd import SVD as SVD from .triangular import Triangular as Triangular from .tridiagonal import Tridiagonal as Tridiagonal +from .Woodbury import Woodbury as Woodbury diff --git a/tests/test_operator.py b/tests/test_operator.py index e0e1e7a..a775104 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -141,6 +141,33 @@ def test_diagonal(dtype, getkey): assert jnp.allclose(lx.diagonal(operator), matrix_diag) +@pytest.mark.parametrize("dtype", (jnp.float64,)) +def test_Woodbury(dtype, getkey): + tol = 1e-4 + N = 20 + k = 2 + + A = jr.normal(getkey(), (N, N), dtype=dtype) + C = jr.normal(getkey(), (k, k), dtype=dtype) + U = jr.normal(getkey(), (N, k), dtype=dtype) + V = jr.normal(getkey(), (k, N), dtype=dtype) + + WB = lx.WoodburyLinearOperator(lx.MatrixLinearOperator(A), C, U, V) + + full_matrix = WB.as_matrix() + + true_x = jr.normal(getkey(), (N,)) + b = full_matrix @ true_x + b_WB = WB.mv(true_x) + + assert tree_allclose(b, b_WB, atol=tol, rtol=tol) + + WB_soln = lx.linear_solve(WB, b) + LU_soln = lx.linear_solve(lx.MatrixLinearOperator(full_matrix), b) + + assert tree_allclose(WB_soln.value, LU_soln.value, atol=tol, rtol=tol) + + @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_symmetric(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) From 60fb24f2620a41e7d7b1ec6706327e0d2fbb9278 Mon Sep 17 00:00:00 2001 From: aidancrilly Date: Tue, 28 May 2024 15:06:33 +0100 Subject: [PATCH 2/4] More inclusive test of Woodbury --- tests/test_operator.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_operator.py b/tests/test_operator.py index a775104..284007e 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -152,6 +152,7 @@ def test_Woodbury(dtype, getkey): U = jr.normal(getkey(), (N, k), dtype=dtype) V = jr.normal(getkey(), (k, N), dtype=dtype) + # Full matrix A WB = lx.WoodburyLinearOperator(lx.MatrixLinearOperator(A), C, U, V) full_matrix = WB.as_matrix() @@ -167,6 +168,43 @@ def test_Woodbury(dtype, getkey): assert tree_allclose(WB_soln.value, LU_soln.value, atol=tol, rtol=tol) + # Tridiagonal matrix A + diagonal = jnp.diagonal(A, offset=0) + upper_diagonal = jnp.diagonal(A, offset=1) + lower_diagonal = jnp.diagonal(A, offset=-1) + WB = lx.WoodburyLinearOperator( + lx.TridiagonalLinearOperator(diagonal, lower_diagonal, upper_diagonal), C, U, V + ) + + full_matrix = WB.as_matrix() + + true_x = jr.normal(getkey(), (N,)) + b = full_matrix @ true_x + b_WB = WB.mv(true_x) + + assert tree_allclose(b, b_WB, atol=tol, rtol=tol) + + WB_soln = lx.linear_solve(WB, b) + LU_soln = lx.linear_solve(lx.MatrixLinearOperator(full_matrix), b) + + assert tree_allclose(WB_soln.value, LU_soln.value, atol=tol, rtol=tol) + + # Diagonal matrix A + WB = lx.WoodburyLinearOperator(lx.DiagonalLinearOperator(diagonal), C, U, V) + + full_matrix = WB.as_matrix() + + true_x = jr.normal(getkey(), (N,)) + b = full_matrix @ true_x + b_WB = WB.mv(true_x) + + assert tree_allclose(b, b_WB, atol=tol, rtol=tol) + + WB_soln = lx.linear_solve(WB, b) + LU_soln = lx.linear_solve(lx.MatrixLinearOperator(full_matrix), b) + + assert tree_allclose(WB_soln.value, LU_soln.value, atol=tol, rtol=tol) + @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_symmetric(dtype, getkey): From 319fb2b3b93c50f466758b6d3ad725ab2bacc6df Mon Sep 17 00:00:00 2001 From: aidancrilly Date: Fri, 31 May 2024 19:35:53 +0100 Subject: [PATCH 3/4] Review response - easier corrections --- lineax/__init__.py | 1 - lineax/_operator.py | 82 +++++--------------------------------- lineax/_solve.py | 4 +- lineax/_solver/Woodbury.py | 35 +++++++++------- lineax/_solver/__init__.py | 2 +- 5 files changed, 34 insertions(+), 90 deletions(-) diff --git a/lineax/__init__.py b/lineax/__init__.py index 1e16dd3..b804142 100644 --- a/lineax/__init__.py +++ b/lineax/__init__.py @@ -34,7 +34,6 @@ is_symmetric as is_symmetric, is_tridiagonal as is_tridiagonal, is_upper_triangular as is_upper_triangular, - is_Woodbury as is_Woodbury, JacobianLinearOperator as JacobianLinearOperator, linearise as linearise, materialise as materialise, diff --git a/lineax/_operator.py b/lineax/_operator.py index 5a2d59d..585063d 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -914,8 +914,6 @@ class WoodburyLinearOperator(AbstractLinearOperator, strict=True): C: Inexact[Array, " k k"] U: Inexact[Array, " n k"] V: Inexact[Array, " k n"] - UCV: Array - tags: frozenset[object] = eqx.field(static=True) def __init__( self, @@ -923,7 +921,6 @@ def __init__( C: Inexact[Array, " k k"], U: Inexact[Array, " n k"], V: Inexact[Array, " k n"], - tags: Union[object, frozenset[object]] = (), ): """**Arguments:** @@ -941,7 +938,7 @@ def __init__( self.U = inexact_asarray(U) self.V = inexact_asarray(V) (N, M) = self.A.in_structure(), self.A.out_structure() - if N != M: + if not eqx.tree_equal(N, M): raise ValueError(f"expecting square operator for A, got {N} by {M}") (K, L) = self.C.shape if K != L: @@ -951,16 +948,14 @@ def __init__( raise ValueError("U does not have consistent shape with A and C") if self.V.shape != (K, N): raise ValueError("V does not have consistent shape with A and C") - self.UCV = self.U @ (self.C @ self.V) - self.tags = _frozenset(tags) def mv(self, vector): Ax = self.A.mv(vector) - UCVx = self.UCV @ vector + UCVx = self.U @ (self.C @ (self.V @ vector)) return Ax + UCVx def as_matrix(self): - matrix = self.A.as_matrix() + self.UCV + matrix = self.A.as_matrix() + self.U @ (self.C @ self.V) return matrix def transpose(self): @@ -1505,33 +1500,6 @@ def _(operator): return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal -@ft.singledispatch -def woodbury( - operator: AbstractLinearOperator, -) -> tuple[ - AbstractLinearOperator, - Shaped[Array, " k k"], - Shaped[Array, " n k"], - Shaped[Array, " k n"], -]: - """Extracts the A, C, U, V Woodbury structure, from a linear - operator. Returns one linear operators and three matrices. - **Arguments:** - - `operator`: a linear operator. - **Returns:** - A 4-tuple, consisting of - - A which is a linear operator - - C, U and V which are matrices - For all but the Woodbury operator this extraction is not possible - """ - _default_not_implemented("woodbury", operator) - - -@woodbury.register(WoodburyLinearOperator) -def _(operator): - return operator.A, operator.C, operator.U, operator.V - - # is_symmetric @@ -1554,7 +1522,6 @@ def is_symmetric(operator: AbstractLinearOperator) -> bool: @is_symmetric.register(MatrixLinearOperator) -@is_symmetric.register(WoodburyLinearOperator) @is_symmetric.register(PyTreeLinearOperator) @is_symmetric.register(JacobianLinearOperator) @is_symmetric.register(FunctionLinearOperator) @@ -1581,6 +1548,7 @@ def _(operator): @is_symmetric.register(TridiagonalLinearOperator) +@is_symmetric.register(WoodburyLinearOperator) def _(operator): return False @@ -1607,7 +1575,6 @@ def is_diagonal(operator: AbstractLinearOperator) -> bool: @is_diagonal.register(MatrixLinearOperator) -@is_diagonal.register(WoodburyLinearOperator) @is_diagonal.register(PyTreeLinearOperator) @is_diagonal.register(JacobianLinearOperator) @is_diagonal.register(FunctionLinearOperator) @@ -1621,6 +1588,7 @@ def _(operator): return True +@is_diagonal.register(WoodburyLinearOperator) @is_diagonal.register(TridiagonalLinearOperator) def _(operator): return False @@ -1648,7 +1616,6 @@ def is_tridiagonal(operator: AbstractLinearOperator) -> bool: @is_tridiagonal.register(MatrixLinearOperator) -@is_tridiagonal.register(WoodburyLinearOperator) @is_tridiagonal.register(PyTreeLinearOperator) @is_tridiagonal.register(JacobianLinearOperator) @is_tridiagonal.register(FunctionLinearOperator) @@ -1663,32 +1630,7 @@ def _(operator): return True -@ft.singledispatch -def is_Woodbury(operator: AbstractLinearOperator) -> bool: - """Returns whether an operator is marked as Woodbury. - See [the documentation on linear operator tags](../api/tags.md) for more - information. - **Arguments:** - - `operator`: a linear operator. - **Returns:** - Either `True` or `False.` - """ - _default_not_implemented("is_Woodbury", operator) - - -@is_Woodbury.register(WoodburyLinearOperator) -def _(operator): - return True - - -@is_Woodbury.register(MatrixLinearOperator) -@is_Woodbury.register(PyTreeLinearOperator) -@is_Woodbury.register(JacobianLinearOperator) -@is_Woodbury.register(FunctionLinearOperator) -@is_Woodbury.register(IdentityLinearOperator) -@is_Woodbury.register(DiagonalLinearOperator) -@is_Woodbury.register(TridiagonalLinearOperator) -@is_Woodbury.register(TaggedLinearOperator) # TODO : check this +@is_tridiagonal.register(WoodburyLinearOperator) def _(operator): return False @@ -1715,7 +1657,6 @@ def has_unit_diagonal(operator: AbstractLinearOperator) -> bool: @has_unit_diagonal.register(MatrixLinearOperator) -@has_unit_diagonal.register(WoodburyLinearOperator) @has_unit_diagonal.register(PyTreeLinearOperator) @has_unit_diagonal.register(JacobianLinearOperator) @has_unit_diagonal.register(FunctionLinearOperator) @@ -1728,6 +1669,7 @@ def _(operator): return True +@has_unit_diagonal.register(WoodburyLinearOperator) @has_unit_diagonal.register(DiagonalLinearOperator) @has_unit_diagonal.register(TridiagonalLinearOperator) def _(operator): @@ -1757,7 +1699,6 @@ def is_lower_triangular(operator: AbstractLinearOperator) -> bool: @is_lower_triangular.register(MatrixLinearOperator) -@is_lower_triangular.register(WoodburyLinearOperator) @is_lower_triangular.register(PyTreeLinearOperator) @is_lower_triangular.register(JacobianLinearOperator) @is_lower_triangular.register(FunctionLinearOperator) @@ -1771,6 +1712,7 @@ def _(operator): return True +@is_lower_triangular.register(WoodburyLinearOperator) @is_lower_triangular.register(TridiagonalLinearOperator) def _(operator): return False @@ -1798,7 +1740,6 @@ def is_upper_triangular(operator: AbstractLinearOperator) -> bool: @is_upper_triangular.register(MatrixLinearOperator) -@is_upper_triangular.register(WoodburyLinearOperator) @is_upper_triangular.register(PyTreeLinearOperator) @is_upper_triangular.register(JacobianLinearOperator) @is_upper_triangular.register(FunctionLinearOperator) @@ -1812,6 +1753,7 @@ def _(operator): return True +@is_upper_triangular.register(WoodburyLinearOperator) @is_upper_triangular.register(TridiagonalLinearOperator) def _(operator): return False @@ -1839,7 +1781,6 @@ def is_positive_semidefinite(operator: AbstractLinearOperator) -> bool: @is_positive_semidefinite.register(MatrixLinearOperator) -@is_positive_semidefinite.register(WoodburyLinearOperator) @is_positive_semidefinite.register(PyTreeLinearOperator) @is_positive_semidefinite.register(JacobianLinearOperator) @is_positive_semidefinite.register(FunctionLinearOperator) @@ -1852,6 +1793,7 @@ def _(operator): return True +@is_positive_semidefinite.register(WoodburyLinearOperator) @is_positive_semidefinite.register(DiagonalLinearOperator) @is_positive_semidefinite.register(TridiagonalLinearOperator) def _(operator): @@ -1881,7 +1823,6 @@ def is_negative_semidefinite(operator: AbstractLinearOperator) -> bool: @is_negative_semidefinite.register(MatrixLinearOperator) -@is_negative_semidefinite.register(WoodburyLinearOperator) @is_negative_semidefinite.register(PyTreeLinearOperator) @is_negative_semidefinite.register(JacobianLinearOperator) @is_negative_semidefinite.register(FunctionLinearOperator) @@ -1894,6 +1835,7 @@ def _(operator): return False +@is_negative_semidefinite.register(WoodburyLinearOperator) @is_negative_semidefinite.register(DiagonalLinearOperator) @is_negative_semidefinite.register(TridiagonalLinearOperator) def _(operator): @@ -2043,7 +1985,6 @@ def _(operator): is_lower_triangular, is_upper_triangular, is_tridiagonal, - is_Woodbury, ): @check.register(TangentLinearOperator) @@ -2106,7 +2047,6 @@ def _(operator, check=check, tag=tag): is_positive_semidefinite, is_negative_semidefinite, is_tridiagonal, - is_Woodbury, ): @check.register(AddLinearOperator) diff --git a/lineax/_solve.py b/lineax/_solve.py index 82fe750..f2fa197 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -40,9 +40,9 @@ is_positive_semidefinite, is_tridiagonal, is_upper_triangular, - is_Woodbury, linearise, TangentLinearOperator, + WoodburyLinearOperator, ) from ._solution import RESULTS, Solution @@ -587,7 +587,7 @@ def _select_solver(self, operator: AbstractLinearOperator): operator ): token = _cholesky_token - elif is_Woodbury(operator): + elif isinstance(operator, WoodburyLinearOperator): token = _woodbury_token else: token = _lu_token diff --git a/lineax/_solver/Woodbury.py b/lineax/_solver/Woodbury.py index c6e818d..e23c26c 100644 --- a/lineax/_solver/Woodbury.py +++ b/lineax/_solver/Woodbury.py @@ -7,9 +7,8 @@ from .._operator import ( AbstractLinearOperator, - is_Woodbury, MatrixLinearOperator, - woodbury, + WoodburyLinearOperator, ) from .._solution import RESULTS from .._solve import AbstractLinearSolver, AutoLinearSolver @@ -22,14 +21,16 @@ ) -_Woodbury_State: TypeAlias = tuple[ +_WoodburyState: TypeAlias = tuple[ tuple[Array, Array, Array], tuple[AbstractLinearSolver, Any, AbstractLinearSolver, Any], PackedStructures, ] -def compute_pushthrough(A_solver, A_state, C, U, V): +def _compute_pushthrough( + A_solver: AbstractLinearSolver, A_state: Any, C: Array, U: Array, V: Array +) -> tuple[AbstractLinearSolver, Any]: # Push through ( C^-1 + V A^-1 U) y = x vmapped_solve = jax.vmap( lambda x_vec: A_solver.compute(A_state, x_vec, {})[0], in_axes=1, out_axes=1 @@ -41,24 +42,28 @@ def compute_pushthrough(A_solver, A_state, C, U, V): return solver, state -class Woodbury(AbstractLinearSolver[_Woodbury_State]): +class Woodbury(AbstractLinearSolver[_WoodburyState]): """Solving system using Woodbury matrix identity""" - def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): + def init( + self, + operator: AbstractLinearOperator, + options: dict[str, Any], + A_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True), + ): del options - if not is_Woodbury(operator): + if not isinstance(operator, WoodburyLinearOperator): raise ValueError( "`Woodbury` may only be used for linear solves with A + U C V structure" ) else: - A, C, U, V = woodbury(operator) + A, C, U, V = operator.A, operator.C, operator.U, operator.V # pyright: ignore if A.in_size() != A.out_size(): raise ValueError("""A must be square""") # Find correct solvers and init for A - A_solver = AutoLinearSolver(well_posed=True).select_solver(A) A_state = A_solver.init(A, {}) # Compute pushthrough operator - pt_solver, pt_state = compute_pushthrough(A_solver, A_state, C, U, V) + pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V) return ( (C, U, V), (A_solver, A_state, pt_solver, pt_state), @@ -67,7 +72,7 @@ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): def compute( self, - state: _Woodbury_State, + state: _WoodburyState, vector, options, ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: @@ -94,7 +99,7 @@ def compute( solution = unravel_solution(solution, A_packed_structures) return solution, RESULTS.successful, {} - def transpose(self, state: _Woodbury_State, options: dict[str, Any]): + def transpose(self, state: _WoodburyState, options: dict[str, Any]): ( (C, U, V), (A_solver, A_state, pt_solver, pt_state), @@ -105,7 +110,7 @@ def transpose(self, state: _Woodbury_State, options: dict[str, Any]): U = jnp.transpose(V) V = jnp.transpose(U) A_state, _ = A_solver.transpose(A_state, {}) - pt_solver, pt_state = compute_pushthrough(A_solver, A_state, C, U, V) + pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V) transpose_state = ( (C, U, V), (A_solver, A_state, pt_solver, pt_state), @@ -113,7 +118,7 @@ def transpose(self, state: _Woodbury_State, options: dict[str, Any]): ) return transpose_state, options - def conj(self, state: _Woodbury_State, options: dict[str, Any]): + def conj(self, state: _WoodburyState, options: dict[str, Any]): ( (C, U, V), (A_solver, A_state, pt_solver, pt_state), @@ -123,7 +128,7 @@ def conj(self, state: _Woodbury_State, options: dict[str, Any]): U = jnp.conj(U) V = jnp.conj(V) A_state, _ = A_solver.conj(A_state, {}) - pt_solver, pt_state = compute_pushthrough(A_solver, A_state, C, U, V) + pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V) conj_state = ( (C, U, V), (A_solver, A_state, pt_solver, pt_state), diff --git a/lineax/_solver/__init__.py b/lineax/_solver/__init__.py index bc631f8..53d400e 100644 --- a/lineax/_solver/__init__.py +++ b/lineax/_solver/__init__.py @@ -22,4 +22,4 @@ from .svd import SVD as SVD from .triangular import Triangular as Triangular from .tridiagonal import Tridiagonal as Tridiagonal -from .Woodbury import Woodbury as Woodbury +from .woodbury import Woodbury as Woodbury From b55f8268a76a80bfabb5b153bba769624ff283b6 Mon Sep 17 00:00:00 2001 From: aidancrilly Date: Fri, 31 May 2024 19:38:54 +0100 Subject: [PATCH 4/4] Woodbury.py -> woodbury.py --- lineax/_solver/{Woodbury.py => woodbury.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename lineax/_solver/{Woodbury.py => woodbury.py} (100%) diff --git a/lineax/_solver/Woodbury.py b/lineax/_solver/woodbury.py similarity index 100% rename from lineax/_solver/Woodbury.py rename to lineax/_solver/woodbury.py