diff --git a/lineax/_misc.py b/lineax/_misc.py index 3cc5117..4deb170 100644 --- a/lineax/_misc.py +++ b/lineax/_misc.py @@ -19,6 +19,7 @@ import jax.core import jax.numpy as jnp import jax.tree_util as jtu +from jax import ShapeDtypeStruct from jaxtyping import Array, ArrayLike, Bool, PyTree # pyright:ignore @@ -110,3 +111,48 @@ def structure_equal(x, y) -> bool: x = strip_weak_dtype(jax.eval_shape(lambda: x)) y = strip_weak_dtype(jax.eval_shape(lambda: y)) return eqx.tree_equal(x, y) is True + + +def is_complex_structure(structure): + """Whether the leaves of `structure` promote to a complex dtype.""" + with jax.numpy_dtype_promotion("standard"): + return jnp.isdtype( + jnp.result_type(*jtu.tree_leaves(structure)), + "complex floating", + ) + + +def complex_to_real_structure(in_structure): + """Map each complex leaf to a real struct with a trailing axis of size 2.""" + return jtu.tree_map( + lambda x: ShapeDtypeStruct( + tuple(x.shape) + (2,), complex_to_real_dtype(x.dtype) + ) + if jnp.isdtype(x.dtype, "complex floating") + else x, + in_structure, + ) + + +def complex_to_real_tree(x, in_structure): + """Stack real/imag parts on a new last axis for complex-structured leaves.""" + with jax.numpy_dtype_promotion("standard"): + return jtu.tree_map( + lambda x, struct: jnp.stack([x.real, x.imag], axis=-1) + if jnp.isdtype(struct.dtype, "complex floating") + else x, + x, + in_structure, + ) + + +def real_to_complex_tree(x, in_structure): + """Inverse of `complex_to_real_tree`: rebuild complex leaves from last axis.""" + with jax.numpy_dtype_promotion("standard"): + return jtu.tree_map( + lambda x, struct: x[..., 0] + 1.0j * x[..., 1] + if jnp.isdtype(struct.dtype, "complex floating") + else x, + x, + in_structure, + ) diff --git a/lineax/_operator.py b/lineax/_operator.py index 53d6ccc..d93cbbb 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -39,10 +39,13 @@ from ._custom_types import sentinel from ._misc import ( + complex_to_real_structure, default_floating_dtype, inexact_asarray, + is_complex_structure, jacobian, NoneAux, + real_to_complex_tree, strip_weak_dtype, ) from ._tags import ( @@ -1322,11 +1325,26 @@ def _(operator): @materialise.register(FunctionLinearOperator) def _(operator): + if is_complex_structure(operator.in_structure()) and not 
is_complex_structure( + operator.out_structure() + ): + # We'll use R^2->R representation for C->R function. + in_structure = complex_to_real_structure(operator.in_structure()) + + map_to_original = lambda x: real_to_complex_tree( + x, + operator.in_structure(), + ) + else: + map_to_original = lambda x: x + in_structure = operator.in_structure() flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + eqx.filter_eval_shape(jfu.ravel_pytree, in_structure) ) + fn = lambda x: operator.fn(map_to_original(unravel(x))) eye = jnp.eye(flat.size, dtype=flat.dtype) - jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye) + + jac = jax.vmap(fn, out_axes=-1)(eye) def batch_unravel(x): assert x.ndim > 0 diff --git a/lineax/_solver/misc.py b/lineax/_solver/misc.py index b7e1a09..8bc982a 100644 --- a/lineax/_solver/misc.py +++ b/lineax/_solver/misc.py @@ -22,7 +22,14 @@ import numpy as np from jaxtyping import Array, PyTree, Shaped -from .._misc import strip_weak_dtype, structure_equal +from .._misc import ( + complex_to_real_structure, + complex_to_real_tree, + is_complex_structure, + real_to_complex_tree, + strip_weak_dtype, + structure_equal, +) from .._operator import ( AbstractLinearOperator, IdentityLinearOperator, @@ -81,11 +88,14 @@ def ravel_vector( pytree: PyTree[Array], packed_structures: PackedStructures ) -> Shaped[Array, " size"]: leaves, treedef = packed_structures.value - out_structure, _ = jtu.tree_unflatten(treedef, leaves) + out_structure, in_structure = jtu.tree_unflatten(treedef, leaves) # `is` in case `tree_equal` returns a Tracer. 
if not structure_equal(pytree, out_structure): raise ValueError("pytree does not match out_structure") # not using `ravel_pytree` as that doesn't come with guarantees about order + + if is_complex_structure(out_structure) and not is_complex_structure(in_structure): + pytree = complex_to_real_tree(pytree, out_structure) leaves = jtu.tree_leaves(pytree) dtype = jnp.result_type(*leaves) return jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves]) @@ -95,15 +105,24 @@ def unravel_solution( solution: Shaped[Array, " size"], packed_structures: PackedStructures ) -> PyTree[Array]: leaves, treedef = packed_structures.value - _, in_structure = jtu.tree_unflatten(treedef, leaves) - leaves, treedef = jtu.tree_flatten(in_structure) + out_structure, in_structure = jtu.tree_unflatten(treedef, leaves) + complex_real = is_complex_structure(in_structure) and not is_complex_structure( + out_structure + ) + if complex_real: + leaves, treedef = jtu.tree_flatten(complex_to_real_structure(in_structure)) + else: + leaves, treedef = jtu.tree_flatten(in_structure) sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]]) split = jnp.split(solution, sizes) assert len(split) == len(leaves) with warnings.catch_warnings(): warnings.simplefilter("ignore") # ignore complex-to-real cast warning shaped = [x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)] - return jtu.tree_unflatten(treedef, shaped) + if complex_real: + return real_to_complex_tree(jtu.tree_unflatten(treedef, shaped), in_structure) + else: + return jtu.tree_unflatten(treedef, shaped) def transpose_packed_structures( diff --git a/tests/helpers.py b/tests/helpers.py index bb2d396..e3dd25e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -254,6 +254,13 @@ def make_composed_operator(getkey, matrix, tags): return lx.TaggedLinearOperator(operator1 @ operator2, tags) +def make_real_function_operator(getkey, matrix, tags): + fn = lambda x: (matrix @ x).real + _, in_size = matrix.shape + in_struct = 
jax.ShapeDtypeStruct((in_size,), matrix.dtype) + return lx.FunctionLinearOperator(fn, in_struct, tags) + + # Slightly sketchy approach to finite differences, in that this is pulled out of # Numerical Recipes. # I also don't know of a handling of the JVP case off the top of my head -- although diff --git a/tests/test_jvp.py b/tests/test_jvp.py index 592e530..8478574 100644 --- a/tests/test_jvp.py +++ b/tests/test_jvp.py @@ -15,10 +15,12 @@ import functools as ft import equinox as eqx +import jax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest +from lineax._misc import complex_to_real_dtype from .helpers import ( construct_matrix, @@ -27,6 +29,7 @@ has_tag, make_jac_operator, make_matrix_operator, + make_real_function_operator, solvers_tags_pseudoinverse, tree_allclose, ) @@ -118,3 +121,123 @@ def test_jvp( assert tree_allclose(matrix @ t_vec_out, matrix @ t_expected_vec_out, rtol=1e-3) assert tree_allclose(t_op_out, t_expected_op_out, rtol=1e-3) assert tree_allclose(t_op_vec_out, t_expected_op_vec_out, rtol=1e-3) + + +@pytest.mark.parametrize( + "solver, tags, pseudoinverse", + [stp for stp in solvers_tags_pseudoinverse if stp[-1]], +) # only pseudoinverse +@pytest.mark.parametrize("use_state", (True, False)) +@pytest.mark.parametrize( + "make_matrix", + ( + construct_matrix, + construct_singular_matrix, + ), +) +def test_jvp_c_to_r(getkey, solver, tags, pseudoinverse, use_state, make_matrix): + t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None + + matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=jnp.complex128) + + out_size, in_size = matrix.shape + out_dtype = complex_to_real_dtype(matrix.dtype) + vec = jr.normal(getkey(), (out_size,), dtype=out_dtype) + t_vec = jr.normal(getkey(), (out_size,), dtype=out_dtype) + + if has_tag(tags, lx.unit_diagonal_tag): + # For all the other tags, A + εB with A, B \in {matrices satisfying the tag} + # still satisfies the tag itself. + # This is the exception. 
+ t_matrix = t_matrix.at[jnp.arange(3), jnp.arange(3)].set(0) + + make_op = ft.partial(make_real_function_operator, getkey) + operator, t_operator = eqx.filter_jvp(make_op, (matrix, tags), (t_matrix, t_tags)) + + if use_state: + state = solver.init(operator, options={}) + linear_solve = ft.partial(lx.linear_solve, state=state) + else: + linear_solve = lx.linear_solve + + solve_vec_only = lambda v: linear_solve(operator, v, solver).value + vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,)) + + solve_op_only = lambda op: linear_solve(op, vec, solver).value + solve_op_vec = lambda op, v: linear_solve(op, v, solver).value + + op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,)) + op_vec_out, t_op_vec_out = eqx.filter_jvp( + solve_op_vec, + (operator, vec), + (t_operator, t_vec), + ) + (expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp( + lambda op: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec + ), # pyright: ignore + (matrix,), + (t_matrix,), + ) + (expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp( + lambda op, v: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v + ), + (matrix, vec), + (t_matrix, t_vec), # pyright: ignore + ) + + # Work around JAX issue #14868. 
+ if jnp.any(jnp.isnan(t_expected_op_out)): + _, (t_expected_op_out, *_) = finite_difference_jvp( + lambda op: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec + ), # pyright: ignore + (matrix,), + (t_matrix,), + ) + if jnp.any(jnp.isnan(t_expected_op_vec_out)): + _, (t_expected_op_vec_out, *_) = finite_difference_jvp( + lambda op, v: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v + ), + (matrix, vec), + (t_matrix, t_vec), # pyright: ignore + ) + + real_mat = jnp.concatenate([jnp.real(matrix), -jnp.imag(matrix)], axis=-1) + pinv_matrix = jnp.linalg.pinv(real_mat) # pyright: ignore + expected_vec_out = pinv_matrix @ vec + with jax.numpy_dtype_promotion("standard"): + expected_complex_vec_out = ( + expected_vec_out[:in_size] + 1.0j * expected_vec_out[in_size:] + ) + expected_complex_op_out = ( + expected_op_out[:in_size] + 1.0j * expected_op_out[in_size:] + ) + expected_complex_op_vec_out = ( + expected_op_vec_out[:in_size] + 1.0j * expected_op_vec_out[in_size:] + ) + + assert tree_allclose(vec_out, expected_complex_vec_out) + assert tree_allclose(op_out, expected_complex_op_out) + assert tree_allclose(op_vec_out, expected_complex_op_vec_out) + + t_expected_vec_out = pinv_matrix @ t_vec + + with jax.numpy_dtype_promotion("standard"): + t_expected_complex_vec_out = ( + t_expected_vec_out[:in_size] + 1.0j * t_expected_vec_out[in_size:] + ) + t_expected_complex_op_out = ( + t_expected_op_out[:in_size] + 1.0j * t_expected_op_out[in_size:] + ) + + t_expected_complex_op_vec_out = ( + t_expected_op_vec_out[:in_size] + 1.0j * t_expected_op_vec_out[in_size:] + ) + assert tree_allclose( + matrix @ t_vec_out, matrix @ t_expected_complex_vec_out, rtol=1e-3 + ) + assert tree_allclose(t_op_out, t_expected_complex_op_out, rtol=1e-3) + assert tree_allclose(t_op_vec_out, t_expected_complex_op_vec_out, rtol=1e-3) diff --git a/tests/test_operator.py b/tests/test_operator.py index 2a10135..aa8140d 100644 --- 
a/tests/test_operator.py +++ b/tests/test_operator.py @@ -20,6 +20,7 @@ import jax.random as jr import lineax as lx import pytest +from lineax._misc import complex_to_real_dtype from .helpers import ( make_diagonal_operator, @@ -321,6 +322,28 @@ def test_materialise_function_linear_operator(dtype, getkey): assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct +def test_materialise_function_real_linear_operator(getkey): + dtype = jnp.complex128 + x = ( + jr.normal(getkey(), (5, 9), dtype=dtype), + jr.normal(getkey(), (3,), dtype=dtype), + ) + input_structure = jax.eval_shape(lambda: x) + fn = lambda x: {"a": jnp.broadcast_to(jnp.sum(x[0]).real, (1, 2))} + output_structure = jax.eval_shape(fn, input_structure) + operator = lx.FunctionLinearOperator(fn, input_structure) + materialised_operator = lx.materialise(operator) + assert materialised_operator.out_structure() == output_structure + assert isinstance(materialised_operator, lx.PyTreeLinearOperator) + expected_struct = { + "a": ( + jax.ShapeDtypeStruct((1, 2, 5, 9, 2), complex_to_real_dtype(dtype)), + jax.ShapeDtypeStruct((1, 2, 3, 2), complex_to_real_dtype(dtype)), + ) + } + assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct + + @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_pytree_transpose(dtype, getkey): out_struct = jax.eval_shape(