diff --git a/.gitignore b/.gitignore
index 27c0588f..e81ca4c5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ site/
 .all_objects.cache
 .pymon
 .idea/
+.venv
diff --git a/diffrax/__init__.py b/diffrax/__init__.py
index 1ee6ea3e..3679cf85 100644
--- a/diffrax/__init__.py
+++ b/diffrax/__init__.py
@@ -4,6 +4,7 @@
     AbstractAdjoint as AbstractAdjoint,
     BacksolveAdjoint as BacksolveAdjoint,
     DirectAdjoint as DirectAdjoint,
+    ForwardAdjoint as ForwardAdjoint,
     ImplicitAdjoint as ImplicitAdjoint,
     RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
 )
diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py
index 4ff2dd2c..6b580d14 100644
--- a/diffrax/_adjoint.py
+++ b/diffrax/_adjoint.py
@@ -852,3 +852,36 @@ def loop(
         )
         final_state = _only_transpose_ys(final_state)
         return final_state, aux_stats
+
+
+class ForwardAdjoint(AbstractAdjoint):
+    """Differentiate through a differential equation solve during the forward pass.
+
+    This is a useful adjoint whenever a function has many more outputs than inputs -
+    for instance during parameter inference for ODE models with least-squares solvers
+    such as `optimistix.LevenbergMarquardt`.
+    """
+
+    def loop(
+        self,
+        *,
+        solver,
+        throw,
+        passed_solver_state,
+        passed_controller_state,
+        **kwargs,
+    ):
+        del throw, passed_solver_state, passed_controller_state
+        inner_while_loop = eqx.Partial(_inner_loop, kind="lax")
+        outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
+        # Support forward-mode autodiff.
+        # TODO: remove this hack once we can JVP through custom_vjps.
+        if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
+            solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none)
+        final_state = self._loop(
+            solver=solver,
+            inner_while_loop=inner_while_loop,
+            outer_while_loop=outer_while_loop,
+            **kwargs,
+        )
+        return final_state
diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md
index 992d57ab..7a76d019 100644
--- a/docs/api/adjoints.md
+++ b/docs/api/adjoints.md
@@ -44,6 +44,10 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax
     selection:
         members: false
 
+::: diffrax.ForwardAdjoint
+    selection:
+        members: false
+
 ---
 
 ::: diffrax.adjoint_rms_seminorm
diff --git a/docs/requirements.txt b/docs/requirements.txt
index e0fe8e61..0eafa922 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -3,6 +3,7 @@ mkdocs==1.3.0  # Main documentation generator.
 mkdocs-material==7.3.6  # Theme
 pymdown-extensions==9.4  # Markdown extensions e.g. to handle LaTeX.
 mkdocstrings==0.17.0  # Autogenerate documentation from docstrings.
+mkdocs-autorefs==1.0.1  # Automatically generate references to other pages.
 mknotebooks==0.7.1  # Turn Jupyter Lab notebooks into webpages.
 pytkdocs_tweaks==0.0.8  # Tweaks mkdocstrings to improve various aspects
 mkdocs_include_exclude_files==0.0.1  # Allow for customising which files get included
@@ -12,4 +13,4 @@ nbformat==5.4.0  # |
 pygments==2.14.0
 
 # Install latest version of our dependencies
-jax[cpu]
+jax[cpu]
\ No newline at end of file
diff --git a/test/test_adjoint.py b/test/test_adjoint.py
index 30ad7bc8..24ee96d1 100644
--- a/test/test_adjoint.py
+++ b/test/test_adjoint.py
@@ -74,8 +74,29 @@ def _run_inexact(inexact, saveat, adjoint):
         return _run(eqx.combine(inexact, static), saveat, adjoint)
 
     _run_grad = eqx.filter_jit(jax.grad(_run_inexact))
+    _run_fwd_grad = eqx.filter_jit(jax.jacfwd(_run_inexact))
     _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True))
 
+    # # TODO(jhaffner): jax.linearize can not handle integer input either (like jacfwd),
+    # # so we would have to use jvp directly and construct unit "vector" pytree(s) for
+    # # each leaf. This seems like an exceedingly rare use case, so maybe skip?
+    # @eqx.filter_jit
+    # def _run_fwd_grad_int(y0__args__term, saveat, adjoint):
+    #     # We want a "fwd_grad" that works with integer inputs. A way to get this is to
+    #     # linearise the function and then transpose. Here we filter out numeric types
+    #     # because we want to catch floats and ints, not just arrays.
+
+    #     def is_numeric(x):
+    #         return isinstance(x, (int, float, jax.numpy.ndarray))
+    #     dynamic, static = eqx.partition(y0__args__term, is_numeric)
+
+    #     def differentiable_run(dynamic):
+    #         return _run(eqx.combine(dynamic, static), saveat, adjoint)
+
+    #     _, lin_fn = jax.linearize(differentiable_run, dynamic)
+    #     grad = jax.linear_transpose(lin_fn, dynamic)(1.0)
+    #     return grad
+
     twice_inexact = jtu.tree_map(lambda *x: jnp.stack(x), inexact, inexact)
 
     @eqx.filter_jit
@@ -83,6 +104,11 @@ def _run_vmap_grad(twice_inexact, saveat, adjoint):
         f = jax.vmap(jax.grad(_run_inexact), in_axes=(0, None, None))
         return f(twice_inexact, saveat, adjoint)
 
+    @eqx.filter_jit
+    def _run_vmap_fwd_grad(twice_inexact, saveat, adjoint):
+        f = jax.vmap(jax.jacfwd(_run_inexact), in_axes=(0, None, None))
+        return f(twice_inexact, saveat, adjoint)
+
     # @eqx.filter_jit
     # def _run_vmap_finite_diff(twice_inexact, saveat, adjoint):
     #     @jax.vmap
@@ -102,6 +128,17 @@ def _run_impl(twice_inexact):
 
         return _run_impl(twice_inexact)
 
+    @eqx.filter_jit
+    def _run_fwd_grad_vmap(twice_inexact, saveat, adjoint):
+        @jax.jacfwd
+        def _run_impl(twice_inexact):
+            f = jax.vmap(_run_inexact, in_axes=(0, None, None))
+            out = f(twice_inexact, saveat, adjoint)
+            assert out.shape == (2,)
+            return jnp.sum(out)
+
+        return _run_impl(twice_inexact)
+
     # Yep, test that they're not implemented. We can remove these checks if we ever
     # do implement them.
     # Until that day comes, it's worth checking that things don't silently break.
@@ -136,9 +173,11 @@ def _convert_float0(x):
                 inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
             )
             backsolve_grads = _run_grad(inexact, saveat, diffrax.BacksolveAdjoint())
+            forward_grads = _run_fwd_grad(inexact, saveat, diffrax.ForwardAdjoint())
             assert tree_allclose(fd_grads, direct_grads[0])
             assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
             assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
+            assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
             direct_grads = _run_grad_int(
                 y0__args__term, saveat, diffrax.DirectAdjoint()
             )
@@ -149,6 +188,10 @@ def _convert_float0(x):
             backsolve_grads = _run_grad_int(
                 y0__args__term, saveat, diffrax.BacksolveAdjoint()
             )
+            # forward_grads = _run_fwd_grad_int(
+            #     y0__args__term, saveat, diffrax.ForwardAdjoint()
+            # )
+
             direct_grads = jtu.tree_map(_convert_float0, direct_grads)
             recursive_grads = jtu.tree_map(_convert_float0, recursive_grads)
             backsolve_grads = jtu.tree_map(_convert_float0, backsolve_grads)
@@ -166,9 +209,13 @@ def _convert_float0(x):
             backsolve_grads = _run_vmap_grad(
                 twice_inexact, saveat, diffrax.BacksolveAdjoint()
             )
+            forward_grads = _run_vmap_fwd_grad(
+                twice_inexact, saveat, diffrax.ForwardAdjoint()
+            )
             assert tree_allclose(fd_grads, direct_grads[0])
             assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
             assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
+            assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
             direct_grads = _run_grad_vmap(
                 twice_inexact, saveat, diffrax.DirectAdjoint()
             )
@@ -179,9 +226,13 @@ def _convert_float0(x):
             backsolve_grads = _run_grad_vmap(
                 twice_inexact, saveat, diffrax.BacksolveAdjoint()
             )
+            forward_grads = _run_fwd_grad_vmap(
+                twice_inexact, saveat, diffrax.ForwardAdjoint()
+            )
             assert tree_allclose(fd_grads, direct_grads[0])
             assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
             assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
+            assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
 
 
 def test_adjoint_seminorm():
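
Usage sketch (not part of the patch): a minimal, illustrative example of how the new adjoint composes with `jax.jacfwd`, mirroring what the added tests exercise. The ODE, solver, times and step size below are hypothetical placeholders rather than anything taken from the diff.

```python
import diffrax
import jax
import jax.numpy as jnp


def solve(y0):
    # Toy linear ODE dy/dt = -y; any term/solver combination is used the same way.
    term = diffrax.ODETerm(lambda t, y, args: -y)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=y0,
        # Differentiate through the solve in forward mode, rather than
        # reverse mode via one of the other adjoints.
        adjoint=diffrax.ForwardAdjoint(),
    )
    return sol.ys


# Jacobian of the solution with respect to the initial condition, computed with
# forward-mode autodiff (jacfwd), which is what ForwardAdjoint supports.
jac = jax.jacfwd(solve)(jnp.array([1.0, 2.0]))
```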
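And a sketch of the use case mentioned in the docstring, least-squares parameter inference: fitting one rate constant to many observations is exactly the many-outputs/few-inputs regime where forward mode is cheap. The decay model, synthetic data and the `optimistix.LevenbergMarquardt` / `optimistix.least_squares` usage here are assumptions for illustration, not part of this PR.

```python
import diffrax
import jax.numpy as jnp
import optimistix as optx

# Hypothetical setup: observations of exponential decay at twenty time points.
ts = jnp.linspace(0.0, 3.0, 20)
data = jnp.exp(-0.7 * ts)


def solve(k):
    # One scalar parameter in, twenty solution values out.
    term = diffrax.ODETerm(lambda t, y, args: -args * y)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0,
        t1=3.0,
        dt0=0.05,
        y0=1.0,
        args=k,
        saveat=diffrax.SaveAt(ts=ts),
        adjoint=diffrax.ForwardAdjoint(),  # residual Jacobians via JVPs
    )
    return sol.ys


def residuals(k, _):
    return solve(k) - data


solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
fit = optx.least_squares(residuals, solver, jnp.array(0.1))
# fit.value should be close to the true rate constant 0.7.
```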