1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ site/
.all_objects.cache
.pymon
.idea/
.venv
1 change: 1 addition & 0 deletions diffrax/__init__.py
@@ -4,6 +4,7 @@
AbstractAdjoint as AbstractAdjoint,
BacksolveAdjoint as BacksolveAdjoint,
DirectAdjoint as DirectAdjoint,
ForwardAdjoint as ForwardAdjoint,
Owner:
Nit: we really shouldn't call this 'adjoint' as that word refers specifically to reverse-mode autodifferentiation. Maybe ForwardMode instead?

Contributor Author:
Haha, see, I should really dig deeper into the theory of this!

Should we then rename the section in the documentation? Maybe "Differentiating through a solve"? ForwardMode also inherits from AbstractAdjoint.

The easiest thing to do would be to simply add a comment to the documentation of ForwardMode, that it is not really an adjoint.

ImplicitAdjoint as ImplicitAdjoint,
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
)
33 changes: 33 additions & 0 deletions diffrax/_adjoint.py
@@ -852,3 +852,36 @@ def loop(
)
final_state = _only_transpose_ys(final_state)
return final_state, aux_stats


class ForwardAdjoint(AbstractAdjoint):
"""Differentiate through a differential equation solve during the forward pass.

This is a useful adjoint to use whenever we have many more outputs than inputs to a
function - for instance during parameter inference for ODE models with least-squares
    based solvers such as `optimistix.LevenbergMarquardt`.
"""

def loop(
self,
*,
solver,
throw,
passed_solver_state,
passed_controller_state,
**kwargs,
):
del throw, passed_solver_state, passed_controller_state
inner_while_loop = eqx.Partial(_inner_loop, kind="lax")
outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none)
final_state = self._loop(
solver=solver,
inner_while_loop=inner_while_loop,
outer_while_loop=outer_while_loop,
**kwargs,
)
return final_state
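
The docstring above suggests the intended use: pair the new adjoint with `jax.jacfwd` when a solve has far more outputs than inputs. A minimal sketch of that usage, assuming the API added in this PR (the toy ODE, `Tsit5` solver, and step size here are illustrative, not taken from the diff):

```python
import diffrax
import jax
import jax.numpy as jnp


def solve(y0):
    # Toy linear ODE dy/dt = -y: one scalar input, 100 saved outputs.
    term = diffrax.ODETerm(lambda t, y, args: -y)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=y0,
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 100)),
        adjoint=diffrax.ForwardAdjoint(),  # forward-mode: one JVP pass, no reverse sweep
    )
    return sol.ys


# Forward-mode Jacobian d(ys)/d(y0); cheap here because inputs (1) << outputs (100).
jac = jax.jacfwd(solve)(jnp.array(1.0))
```

This mirrors what the new tests below do with `jax.jacfwd(_run_inexact)`.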
4 changes: 4 additions & 0 deletions docs/api/adjoints.md
@@ -44,6 +44,10 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax
selection:
members: false

::: diffrax.ForwardAdjoint
selection:
members: false

---

::: diffrax.adjoint_rms_seminorm
3 changes: 2 additions & 1 deletion docs/requirements.txt
@@ -3,6 +3,7 @@ mkdocs==1.3.0 # Main documentation generator.
mkdocs-material==7.3.6 # Theme
pymdown-extensions==9.4 # Markdown extensions e.g. to handle LaTeX.
mkdocstrings==0.17.0 # Autogenerate documentation from docstrings.
mkdocs-autorefs==1.0.1 # Automatically generate references to other pages.
mknotebooks==0.7.1 # Turn Jupyter Lab notebooks into webpages.
pytkdocs_tweaks==0.0.8 # Tweaks mkdocstrings to improve various aspects
mkdocs_include_exclude_files==0.0.1 # Allow for customising which files get included
@@ -12,4 +13,4 @@ nbformat==5.4.0 # |
pygments==2.14.0

# Install latest version of our dependencies
jax[cpu]
jax[cpu]
51 changes: 51 additions & 0 deletions test/test_adjoint.py
@@ -74,15 +74,41 @@ def _run_inexact(inexact, saveat, adjoint):
return _run(eqx.combine(inexact, static), saveat, adjoint)

_run_grad = eqx.filter_jit(jax.grad(_run_inexact))
_run_fwd_grad = eqx.filter_jit(jax.jacfwd(_run_inexact))
_run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True))

# # TODO(jhaffner): jax.linearize cannot handle integer input either (like jacfwd),
# # so we would have to use jvp directly and construct unit "vector" pytree(s) for
# # each leaf. This seems like an exceedingly rare use case, so maybe skip?
# @eqx.filter_jit
# def _run_fwd_grad_int(y0__args__term, saveat, adjoint):
# # We want a "fwd_grad" that works with integer inputs. A way to get this is to
# # linearise the function and then transpose. Here we filter out numeric types
# # because we want to catch floats and ints, not just arrays.

# def is_numeric(x):
# return isinstance(x, (int, float, jax.numpy.ndarray))
# dynamic, static = eqx.partition(y0__args__term, is_numeric)

# def differentiable_run(dynamic):
# return _run(eqx.combine(dynamic, static), saveat, adjoint)

# _, lin_fn = jax.linearize(differentiable_run, dynamic)
# grad = jax.linear_transpose(lin_fn, dynamic)(1.0)
# return grad
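
# # For reference, a minimal sketch of that linearise-then-transpose recipe on a
# # toy scalar-output function (illustrative only; not part of the diff or used
# # by the tests):
# #
# #     def f(x):
# #         return jnp.sum(x ** 2)
# #
# #     x = jnp.arange(3.0)
# #     _, lin_fn = jax.linearize(f, x)  # linearise f at x (a JVP)
# #     (grad,) = jax.linear_transpose(lin_fn, x)(1.0)  # transpose the JVP into a VJP
# #     # grad now matches jax.grad(f)(x)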

twice_inexact = jtu.tree_map(lambda *x: jnp.stack(x), inexact, inexact)

@eqx.filter_jit
def _run_vmap_grad(twice_inexact, saveat, adjoint):
f = jax.vmap(jax.grad(_run_inexact), in_axes=(0, None, None))
return f(twice_inexact, saveat, adjoint)

@eqx.filter_jit
def _run_vmap_fwd_grad(twice_inexact, saveat, adjoint):
f = jax.vmap(jax.jacfwd(_run_inexact), in_axes=(0, None, None))
return f(twice_inexact, saveat, adjoint)

# @eqx.filter_jit
# def _run_vmap_finite_diff(twice_inexact, saveat, adjoint):
# @jax.vmap
@@ -102,6 +128,17 @@ def _run_impl(twice_inexact):

return _run_impl(twice_inexact)

@eqx.filter_jit
def _run_fwd_grad_vmap(twice_inexact, saveat, adjoint):
@jax.jacfwd
def _run_impl(twice_inexact):
f = jax.vmap(_run_inexact, in_axes=(0, None, None))
out = f(twice_inexact, saveat, adjoint)
assert out.shape == (2,)
return jnp.sum(out)

return _run_impl(twice_inexact)

# Yep, test that they're not implemented. We can remove these checks if we ever
# do implement them.
# Until that day comes, it's worth checking that things don't silently break.
@@ -136,9 +173,11 @@ def _convert_float0(x):
inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
)
backsolve_grads = _run_grad(inexact, saveat, diffrax.BacksolveAdjoint())
forward_grads = _run_fwd_grad(inexact, saveat, diffrax.ForwardAdjoint())
assert tree_allclose(fd_grads, direct_grads[0])
assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
assert tree_allclose(direct_grads, forward_grads, atol=1e-5)

direct_grads = _run_grad_int(
y0__args__term, saveat, diffrax.DirectAdjoint()
@@ -149,6 +188,10 @@ def _convert_float0(x):
backsolve_grads = _run_grad_int(
y0__args__term, saveat, diffrax.BacksolveAdjoint()
)
# forward_grads = _run_fwd_grad_int(
# y0__args__term, saveat, diffrax.ForwardAdjoint()
# )

direct_grads = jtu.tree_map(_convert_float0, direct_grads)
recursive_grads = jtu.tree_map(_convert_float0, recursive_grads)
backsolve_grads = jtu.tree_map(_convert_float0, backsolve_grads)
@@ -166,9 +209,13 @@ def _convert_float0(x):
backsolve_grads = _run_vmap_grad(
twice_inexact, saveat, diffrax.BacksolveAdjoint()
)
forward_grads = _run_vmap_fwd_grad(
twice_inexact, saveat, diffrax.ForwardAdjoint()
)
assert tree_allclose(fd_grads, direct_grads[0])
assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
assert tree_allclose(direct_grads, forward_grads, atol=1e-5)

direct_grads = _run_grad_vmap(
twice_inexact, saveat, diffrax.DirectAdjoint()
@@ -179,9 +226,13 @@ def _convert_float0(x):
backsolve_grads = _run_grad_vmap(
twice_inexact, saveat, diffrax.BacksolveAdjoint()
)
forward_grads = _run_fwd_grad_vmap(
twice_inexact, saveat, diffrax.ForwardAdjoint()
)
assert tree_allclose(fd_grads, direct_grads[0])
assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
assert tree_allclose(direct_grads, forward_grads, atol=1e-5)


def test_adjoint_seminorm():