1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ site/
.all_objects.cache
.pymon
.idea/
.venv/
1 change: 1 addition & 0 deletions diffrax/__init__.py
@@ -4,6 +4,7 @@
AbstractAdjoint as AbstractAdjoint,
BacksolveAdjoint as BacksolveAdjoint,
DirectAdjoint as DirectAdjoint,
ForwardMode as ForwardMode,
ImplicitAdjoint as ImplicitAdjoint,
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
)
43 changes: 42 additions & 1 deletion diffrax/_adjoint.py
@@ -184,7 +184,9 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint):
!!! info

Note that this cannot be forward-mode autodifferentiated. (E.g. using
`jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if that is something you need.
`jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if you need both forward-mode
and reverse-mode autodifferentiation, and [`diffrax.ForwardMode`][] if you need
only forward-mode autodifferentiation.

??? cite "References"

@@ -333,6 +335,8 @@ class DirectAdjoint(AbstractAdjoint):

So unless you need forward-mode autodifferentiation then
[`diffrax.RecursiveCheckpointAdjoint`][] should be preferred.
If you need only forward-mode autodifferentiation, then [`diffrax.ForwardMode`][] is
more efficient.
"""

def loop(
@@ -852,3 +856,40 @@ def loop(
)
final_state = _only_transpose_ys(final_state)
return final_state, aux_stats


class ForwardMode(AbstractAdjoint):
"""Supports forward-mode automatic differentiation through a differential equation
solve. This works by propagating the derivatives during the forward pass - that is,
during the ODE solve - instead of solving the adjoint equations afterwards.
(So this is really a different way of quantifying the sensitivity of the output to
the input, even if its interface is that of an adjoint for convenience.)

This is useful when a function has many more outputs than inputs - for
instance during parameter inference for ODE models with least-squares solvers such
as `optimistix.LevenbergMarquardt`, which operate on the residuals.
"""

def loop(
self,
*,
solver,
throw,
passed_solver_state,
passed_controller_state,
**kwargs,
):
del throw, passed_solver_state, passed_controller_state
inner_while_loop = eqx.Partial(_inner_loop, kind="lax")
outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none)
final_state = self._loop(
solver=solver,
inner_while_loop=inner_while_loop,
outer_while_loop=outer_while_loop,
**kwargs,
)
return final_state
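
A minimal usage sketch for the new ForwardMode adjoint follows (not part of this diff; the exponential-decay vector field, the `solve` function, and the rate parameter are made up for illustration). It shows the situation the docstring describes: one input parameter, many saved outputs, so `jax.jacfwd` of the solve is the efficient direction to differentiate in.

import diffrax
import jax
import jax.numpy as jnp


def solve(rate):
    # Hypothetical scalar-parameter ODE: dy/dt = -rate * y.
    term = diffrax.ODETerm(lambda t, y, args: -args * y)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=jnp.array(1.0),
        args=rate,
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 100)),
        adjoint=diffrax.ForwardMode(),
    )
    return sol.ys  # 100 saved outputs from a single input parameter


sensitivities = jax.jacfwd(solve)(jnp.array(0.5))  # shape (100,)

These per-output sensitivities are what a residual-based least-squares solver such as `optimistix.LevenbergMarquardt` consumes during parameter inference, which is the use case named in the docstring.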
4 changes: 4 additions & 0 deletions docs/api/adjoints.md
@@ -44,6 +44,10 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax
selection:
members: false

::: diffrax.ForwardMode
selection:
members: false

---

::: diffrax.adjoint_rms_seminorm
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -14,4 +14,4 @@ mkdocs-autorefs==1.0.1
mkdocs-material-extensions==1.3.1

# Install latest version of our dependencies
jax[cpu]
jax[cpu]
31 changes: 31 additions & 0 deletions test/test_adjoint.py
@@ -74,6 +74,7 @@ def _run_inexact(inexact, saveat, adjoint):
return _run(eqx.combine(inexact, static), saveat, adjoint)

_run_grad = eqx.filter_jit(jax.grad(_run_inexact))
_run_fwd_grad = eqx.filter_jit(jax.jacfwd(_run_inexact))
_run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True))

twice_inexact = jtu.tree_map(lambda *x: jnp.stack(x), inexact, inexact)
@@ -83,6 +84,11 @@ def _run_vmap_grad(twice_inexact, saveat, adjoint):
f = jax.vmap(jax.grad(_run_inexact), in_axes=(0, None, None))
return f(twice_inexact, saveat, adjoint)

@eqx.filter_jit
def _run_vmap_fwd_grad(twice_inexact, saveat, adjoint):
f = jax.vmap(jax.jacfwd(_run_inexact), in_axes=(0, None, None))
return f(twice_inexact, saveat, adjoint)

# @eqx.filter_jit
# def _run_vmap_finite_diff(twice_inexact, saveat, adjoint):
# @jax.vmap
@@ -102,6 +108,17 @@ def _run_impl(twice_inexact):

return _run_impl(twice_inexact)

@eqx.filter_jit
def _run_fwd_grad_vmap(twice_inexact, saveat, adjoint):
@jax.jacfwd
def _run_impl(twice_inexact):
f = jax.vmap(_run_inexact, in_axes=(0, None, None))
out = f(twice_inexact, saveat, adjoint)
assert out.shape == (2,)
return jnp.sum(out)

return _run_impl(twice_inexact)

# Yep, test that they're not implemented. We can remove these checks if we ever
# do implement them.
# Until that day comes, it's worth checking that things don't silently break.
@@ -136,10 +153,16 @@ def _convert_float0(x):
inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
)
backsolve_grads = _run_grad(inexact, saveat, diffrax.BacksolveAdjoint())
forward_grads = _run_fwd_grad(inexact, saveat, diffrax.ForwardMode())
assert tree_allclose(fd_grads, direct_grads[0])
assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
assert tree_allclose(direct_grads, forward_grads, atol=1e-5)

# Test support for integer inputs (jax.grad(..., allow_int=True)). There
# is no corresponding option for jax.jacfwd or jax.linearize, and a
# workaround (jvp with custom "unit pytrees" for mixed array and
# non-array inputs?) is not implemented or tested here.
direct_grads = _run_grad_int(
y0__args__term, saveat, diffrax.DirectAdjoint()
)
@@ -166,9 +189,13 @@ def _convert_float0(x):
backsolve_grads = _run_vmap_grad(
twice_inexact, saveat, diffrax.BacksolveAdjoint()
)
forward_grads = _run_vmap_fwd_grad(
twice_inexact, saveat, diffrax.ForwardMode()
)
assert tree_allclose(fd_grads, direct_grads[0])
assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
assert tree_allclose(direct_grads, forward_grads, atol=1e-5)

direct_grads = _run_grad_vmap(
twice_inexact, saveat, diffrax.DirectAdjoint()
@@ -179,9 +206,13 @@ def _convert_float0(x):
backsolve_grads = _run_grad_vmap(
twice_inexact, saveat, diffrax.BacksolveAdjoint()
)
forward_grads = _run_fwd_grad_vmap(
twice_inexact, saveat, diffrax.ForwardMode()
)
assert tree_allclose(fd_grads, direct_grads[0])
assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
assert tree_allclose(direct_grads, forward_grads, atol=1e-5)


def test_adjoint_seminorm():
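
For reference, a standalone sketch of the two compositions the new tests exercise - `jax.vmap` of `jax.jacfwd` (as in _run_vmap_fwd_grad) and `jax.jacfwd` of a vmapped-and-summed function (as in _run_fwd_grad_vmap) - using a made-up function f in place of _run_inexact; the two should agree:

import jax
import jax.numpy as jnp


def f(x):
    # Stand-in for _run_inexact: maps an array to a scalar.
    return jnp.sum(jnp.sin(x))


xs = jnp.stack([jnp.arange(3.0), jnp.arange(3.0) + 1.0])  # batch of two inputs

per_example = jax.vmap(jax.jacfwd(f))(xs)                   # jacfwd inside vmap
summed = jax.jacfwd(lambda b: jnp.sum(jax.vmap(f)(b)))(xs)  # vmap inside jacfwd
assert jnp.allclose(per_example, summed)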