diff --git a/.gitignore b/.gitignore
index 27c0588f..daf54d8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ site/
 .all_objects.cache
 .pymon
 .idea/
+.venv/
diff --git a/diffrax/__init__.py b/diffrax/__init__.py
index 1ee6ea3e..42073a10 100644
--- a/diffrax/__init__.py
+++ b/diffrax/__init__.py
@@ -4,6 +4,7 @@
     AbstractAdjoint as AbstractAdjoint,
     BacksolveAdjoint as BacksolveAdjoint,
     DirectAdjoint as DirectAdjoint,
+    ForwardMode as ForwardMode,
     ImplicitAdjoint as ImplicitAdjoint,
     RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
 )
diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py
index 4ff2dd2c..46338ea2 100644
--- a/diffrax/_adjoint.py
+++ b/diffrax/_adjoint.py
@@ -184,7 +184,9 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint):
     !!! info
 
         Note that this cannot be forward-mode autodifferentiated. (E.g. using
-        `jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if that is something you need.
+        `jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if you need both forward-mode
+        and reverse-mode autodifferentiation, and [`diffrax.ForwardMode`][] if you need
+        only forward-mode autodifferentiation.
 
     ??? cite "References"
 
@@ -333,6 +335,8 @@ class DirectAdjoint(AbstractAdjoint):
 
     So unless you need forward-mode autodifferentiation then
     [`diffrax.RecursiveCheckpointAdjoint`][] should be preferred.
+    If you need only forward-mode autodifferentiation, then [`diffrax.ForwardMode`][] is
+    more efficient.
     """
 
     def loop(
@@ -852,3 +856,40 @@
         )
         final_state = _only_transpose_ys(final_state)
         return final_state, aux_stats
+
+
+class ForwardMode(AbstractAdjoint):
+    """Supports forward-mode automatic differentiation through a differential equation
+    solve. This works by propagating the derivatives during the forward pass - that is,
+    during the ODE solve, instead of solving the adjoint equations afterwards.
+    (So this is really a different way of quantifying the sensitivity of the output to
+    the input, even if its interface is that of an adjoint for convenience.)
+
+    This is useful when we have many more outputs than inputs to a function - for
+    instance during parameter inference for ODE models with least-squares solvers such
+    as `optimistix.LevenbergMarquardt`, which operate on the residuals.
+    """
+
+    def loop(
+        self,
+        *,
+        solver,
+        throw,
+        passed_solver_state,
+        passed_controller_state,
+        **kwargs,
+    ):
+        del throw, passed_solver_state, passed_controller_state
+        inner_while_loop = eqx.Partial(_inner_loop, kind="lax")
+        outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
+        # Support forward-mode autodiff.
+        # TODO: remove this hack once we can JVP through custom_vjps.
+        if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
+            solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none)
+        final_state = self._loop(
+            solver=solver,
+            inner_while_loop=inner_while_loop,
+            outer_while_loop=outer_while_loop,
+            **kwargs,
+        )
+        return final_state
diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md
index 992d57ab..5d4ede63 100644
--- a/docs/api/adjoints.md
+++ b/docs/api/adjoints.md
@@ -44,6 +44,10 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax
     selection:
         members: false
 
+::: diffrax.ForwardMode
+    selection:
+        members: false
+
 ---
 
 ::: diffrax.adjoint_rms_seminorm
diff --git a/docs/requirements.txt b/docs/requirements.txt
index b033b3e3..ed743cfd 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -14,4 +14,4 @@ mkdocs-autorefs==1.0.1
 mkdocs-material-extensions==1.3.1
 
 # Install latest version of our dependencies
-jax[cpu]
+jax[cpu]
\ No newline at end of file
diff --git a/test/test_adjoint.py b/test/test_adjoint.py
index 30ad7bc8..12e6ee27 100644
--- a/test/test_adjoint.py
+++ b/test/test_adjoint.py
@@ -74,6 +74,7 @@ def _run_inexact(inexact, saveat, adjoint):
         return _run(eqx.combine(inexact, static), saveat, adjoint)
 
     _run_grad = eqx.filter_jit(jax.grad(_run_inexact))
+    _run_fwd_grad = eqx.filter_jit(jax.jacfwd(_run_inexact))
     _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True))
 
     twice_inexact = jtu.tree_map(lambda *x: jnp.stack(x), inexact, inexact)
@@ -83,6 +84,11 @@ def _run_vmap_grad(twice_inexact, saveat, adjoint):
         f = jax.vmap(jax.grad(_run_inexact), in_axes=(0, None, None))
         return f(twice_inexact, saveat, adjoint)
 
+    @eqx.filter_jit
+    def _run_vmap_fwd_grad(twice_inexact, saveat, adjoint):
+        f = jax.vmap(jax.jacfwd(_run_inexact), in_axes=(0, None, None))
+        return f(twice_inexact, saveat, adjoint)
+
     # @eqx.filter_jit
     # def _run_vmap_finite_diff(twice_inexact, saveat, adjoint):
     #     @jax.vmap
@@ -102,6 +108,17 @@ def _run_impl(twice_inexact):
 
         return _run_impl(twice_inexact)
 
+    @eqx.filter_jit
+    def _run_fwd_grad_vmap(twice_inexact, saveat, adjoint):
+        @jax.jacfwd
+        def _run_impl(twice_inexact):
+            f = jax.vmap(_run_inexact, in_axes=(0, None, None))
+            out = f(twice_inexact, saveat, adjoint)
+            assert out.shape == (2,)
+            return jnp.sum(out)
+
+        return _run_impl(twice_inexact)
+
     # Yep, test that they're not implemented. We can remove these checks if we ever
     # do implement them.
     # Until that day comes, it's worth checking that things don't silently break.
@@ -136,10 +153,16 @@ def _convert_float0(x):
         inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
     )
     backsolve_grads = _run_grad(inexact, saveat, diffrax.BacksolveAdjoint())
+    forward_grads = _run_fwd_grad(inexact, saveat, diffrax.ForwardMode())
     assert tree_allclose(fd_grads, direct_grads[0])
     assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
     assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
+    assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
 
+    # Test support for integer inputs (jax.grad(..., allow_int=True)). There
+    # is no corresponding option for jax.jacfwd or jax.linearize, and a
+    # workaround (jvp with custom "unit pytrees" for mixed array and
+    # non-array inputs?) is not implemented and tested here.
     direct_grads = _run_grad_int(
         y0__args__term, saveat, diffrax.DirectAdjoint()
     )
@@ -166,9 +189,13 @@ def _convert_float0(x):
     backsolve_grads = _run_vmap_grad(
         twice_inexact, saveat, diffrax.BacksolveAdjoint()
     )
+    forward_grads = _run_vmap_fwd_grad(
+        twice_inexact, saveat, diffrax.ForwardMode()
+    )
     assert tree_allclose(fd_grads, direct_grads[0])
     assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
     assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
+    assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
     direct_grads = _run_grad_vmap(
         twice_inexact, saveat, diffrax.DirectAdjoint()
     )
@@ -179,9 +206,13 @@ def _convert_float0(x):
     backsolve_grads = _run_grad_vmap(
         twice_inexact, saveat, diffrax.BacksolveAdjoint()
    )
+    forward_grads = _run_fwd_grad_vmap(
+        twice_inexact, saveat, diffrax.ForwardMode()
+    )
     assert tree_allclose(fd_grads, direct_grads[0])
     assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
     assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
+    assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
 
 
 def test_adjoint_seminorm():
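For reviewers, a minimal usage sketch of the new `diffrax.ForwardMode` adjoint (not part of the diff). The toy vector field, parameter value, and step-size choices below are illustrative assumptions; the pattern mirrors the new tests, which differentiate through the solve with `jax.jacfwd`.

```python
# Hypothetical example: forward-mode sensitivities of an ODE solution with ForwardMode.
# The ODE dy/dt = -theta * y and all numerical settings are illustrative only.
import diffrax
import jax
import jax.numpy as jnp


def solve(theta):
    # `theta` is the (scalar) input we differentiate with respect to.
    term = diffrax.ODETerm(lambda t, y, args: -args * y)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=jnp.array([1.0, 2.0]),
        args=theta,
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 10)),
        adjoint=diffrax.ForwardMode(),  # propagate derivatives during the solve itself
    )
    return sol.ys  # shape (10, 2): many outputs from a single scalar input


# Forward-mode Jacobian d(ys)/d(theta). This is the many-outputs/few-inputs regime
# (e.g. least-squares residuals) where ForwardMode is the efficient choice.
jacobian = jax.jacfwd(solve)(1.0)
print(jacobian.shape)  # (10, 2)
```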