From fc8fcd8a7ba778d2b9d50549093a8ba6a3dbd059 Mon Sep 17 00:00:00 2001
From: Johanna Haffner
Date: Fri, 20 Dec 2024 12:21:29 +0000
Subject: [PATCH 1/7] add .venv

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 27c0588f..daf54d8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ site/
 .all_objects.cache
 .pymon
 .idea/
+.venv/

From 2b52b7cc22e2f99d0e207e3104ac096310415fca Mon Sep 17 00:00:00 2001
From: Johanna Haffner
Date: Sun, 8 Dec 2024 22:49:30 +0100
Subject: [PATCH 2/7] add code, tests and documentation for ForwardAdjoint

---
 diffrax/__init__.py  |  1 +
 diffrax/_adjoint.py  | 33 ++++++++++++++++++++++++++++
 docs/api/adjoints.md |  4 ++++
 test/test_adjoint.py | 51 ++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 89 insertions(+)

diff --git a/diffrax/__init__.py b/diffrax/__init__.py
index 1ee6ea3e..3679cf85 100644
--- a/diffrax/__init__.py
+++ b/diffrax/__init__.py
@@ -4,6 +4,7 @@
     AbstractAdjoint as AbstractAdjoint,
     BacksolveAdjoint as BacksolveAdjoint,
     DirectAdjoint as DirectAdjoint,
+    ForwardAdjoint as ForwardAdjoint,
     ImplicitAdjoint as ImplicitAdjoint,
     RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
 )
diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py
index 4ff2dd2c..6b580d14 100644
--- a/diffrax/_adjoint.py
+++ b/diffrax/_adjoint.py
@@ -852,3 +852,36 @@ def loop(
         )
         final_state = _only_transpose_ys(final_state)
         return final_state, aux_stats
+
+
+class ForwardAdjoint(AbstractAdjoint):
+    """Differentiate through a differential equation solve during the forward pass.
+
+    This is a useful adjoint to use whenever we have many more outputs than inputs to a
+    function - for instance during parameter inference for ODE models with least-squares
+    based solvers such as `optimistix.LevenbergMarquardt`.
+    """
+
+    def loop(
+        self,
+        *,
+        solver,
+        throw,
+        passed_solver_state,
+        passed_controller_state,
+        **kwargs,
+    ):
+        del throw, passed_solver_state, passed_controller_state
+        inner_while_loop = eqx.Partial(_inner_loop, kind="lax")
+        outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
+        # Support forward-mode autodiff.
+        # TODO: remove this hack once we can JVP through custom_vjps.
+        if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
+            solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none)
+        final_state = self._loop(
+            solver=solver,
+            inner_while_loop=inner_while_loop,
+            outer_while_loop=outer_while_loop,
+            **kwargs,
+        )
+        return final_state
diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md
index 992d57ab..7a76d019 100644
--- a/docs/api/adjoints.md
+++ b/docs/api/adjoints.md
@@ -44,6 +44,10 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax
     selection:
         members: false
 
+::: diffrax.ForwardAdjoint
+    selection:
+        members: false
+
 ---
 
 ::: diffrax.adjoint_rms_seminorm
diff --git a/test/test_adjoint.py b/test/test_adjoint.py
index 30ad7bc8..24ee96d1 100644
--- a/test/test_adjoint.py
+++ b/test/test_adjoint.py
@@ -74,8 +74,29 @@ def _run_inexact(inexact, saveat, adjoint):
         return _run(eqx.combine(inexact, static), saveat, adjoint)
 
     _run_grad = eqx.filter_jit(jax.grad(_run_inexact))
+    _run_fwd_grad = eqx.filter_jit(jax.jacfwd(_run_inexact))
     _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True))
 
+    # # TODO(jhaffner): jax.linearize can not handle integer input either (like jacfwd),
+    # # so we would have to use jvp directly and construct unit "vector" pytree(s) for
+    # # each leaf. This seems like an exceedingly rare use case, so maybe skip?
+    # @eqx.filter_jit
+    # def _run_fwd_grad_int(y0__args__term, saveat, adjoint):
+    #     # We want a "fwd_grad" that works with integer inputs. A way to get this is to
+    #     # linearise the function and then transpose. Here we filter out numeric types
+    #     # because we want to catch floats and ints, not just arrays.
+
+    #     def is_numeric(x):
+    #         return isinstance(x, (int, float, jax.numpy.ndarray))
+    #     dynamic, static = eqx.partition(y0__args__term, is_numeric)
+
+    #     def differentiable_run(dynamic):
+    #         return _run(eqx.combine(dynamic, static), saveat, adjoint)
+
+    #     _, lin_fn = jax.linearize(differentiable_run, dynamic)
+    #     grad = jax.linear_transpose(lin_fn, dynamic)(1.0)
+    #     return grad
+
     twice_inexact = jtu.tree_map(lambda *x: jnp.stack(x), inexact, inexact)
 
     @eqx.filter_jit
@@ -83,6 +104,11 @@ def _run_vmap_grad(twice_inexact, saveat, adjoint):
         f = jax.vmap(jax.grad(_run_inexact), in_axes=(0, None, None))
         return f(twice_inexact, saveat, adjoint)
 
+    @eqx.filter_jit
+    def _run_vmap_fwd_grad(twice_inexact, saveat, adjoint):
+        f = jax.vmap(jax.jacfwd(_run_inexact), in_axes=(0, None, None))
+        return f(twice_inexact, saveat, adjoint)
+
     # @eqx.filter_jit
     # def _run_vmap_finite_diff(twice_inexact, saveat, adjoint):
     #     @jax.vmap
@@ -102,6 +128,17 @@ def _run_impl(twice_inexact):
 
         return _run_impl(twice_inexact)
 
+    @eqx.filter_jit
+    def _run_fwd_grad_vmap(twice_inexact, saveat, adjoint):
+        @jax.jacfwd
+        def _run_impl(twice_inexact):
+            f = jax.vmap(_run_inexact, in_axes=(0, None, None))
+            out = f(twice_inexact, saveat, adjoint)
+            assert out.shape == (2,)
+            return jnp.sum(out)
+
+        return _run_impl(twice_inexact)
+
     # Yep, test that they're not implemented. We can remove these checks if we ever
     # do implement them.
     # Until that day comes, it's worth checking that things don't silently break.
@@ -136,9 +173,11 @@ def _convert_float0(x):
             inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
         )
         backsolve_grads = _run_grad(inexact, saveat, diffrax.BacksolveAdjoint())
+        forward_grads = _run_fwd_grad(inexact, saveat, diffrax.ForwardAdjoint())
         assert tree_allclose(fd_grads, direct_grads[0])
         assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
         assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
+        assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
 
         direct_grads = _run_grad_int(
             y0__args__term, saveat, diffrax.DirectAdjoint()
@@ -149,6 +188,10 @@ def _convert_float0(x):
         backsolve_grads = _run_grad_int(
             y0__args__term, saveat, diffrax.BacksolveAdjoint()
         )
+        # forward_grads = _run_fwd_grad_int(
+        #     y0__args__term, saveat, diffrax.ForwardAdjoint()
+        # )
+
         direct_grads = jtu.tree_map(_convert_float0, direct_grads)
         recursive_grads = jtu.tree_map(_convert_float0, recursive_grads)
         backsolve_grads = jtu.tree_map(_convert_float0, backsolve_grads)
@@ -166,9 +209,13 @@ def _convert_float0(x):
         backsolve_grads = _run_vmap_grad(
             twice_inexact, saveat, diffrax.BacksolveAdjoint()
         )
+        forward_grads = _run_vmap_fwd_grad(
+            twice_inexact, saveat, diffrax.ForwardAdjoint()
+        )
         assert tree_allclose(fd_grads, direct_grads[0])
         assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
         assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
+        assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
 
         direct_grads = _run_grad_vmap(
             twice_inexact, saveat, diffrax.DirectAdjoint()
@@ -179,9 +226,13 @@ def _convert_float0(x):
         backsolve_grads = _run_grad_vmap(
             twice_inexact, saveat, diffrax.BacksolveAdjoint()
         )
+        forward_grads = _run_fwd_grad_vmap(
+            twice_inexact, saveat, diffrax.ForwardAdjoint()
+        )
         assert tree_allclose(fd_grads, direct_grads[0])
         assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
         assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
+        assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
 
 
 def test_adjoint_seminorm():

From 69f602bda8af0e1932e1b03ba640e054ff3f0338 Mon Sep 17 00:00:00 2001
From: Johanna Haffner
Date: Sun, 8 Dec 2024 22:50:11 +0100
Subject: [PATCH 3/7] make version of mkdocs-autorefs explicit
 (https://github.com/patrick-kidger/optimistix/pull/91, but for diffrax)

---
 docs/requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index b033b3e3..44c17201 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -3,6 +3,7 @@ mkdocs==1.3.0  # Main documentation generator.
 mkdocs-material==7.3.6  # Theme
 pymdown-extensions==9.4  # Markdown extensions e.g. to handle LaTeX.
 mkdocstrings==0.17.0  # Autogenerate documentation from docstrings.
+mkdocs-autorefs==1.0.1  # Automatically generate references to other pages.
 mknotebooks==0.7.1  # Turn Jupyter Lab notebooks into webpages.
 pytkdocs_tweaks==0.0.8  # Tweaks mkdocstrings to improve various aspects
 mkdocs_include_exclude_files==0.0.1  # Allow for customising which files get included
@@ -14,4 +15,4 @@ mkdocs-autorefs==1.0.1
 mkdocs-material-extensions==1.3.1
 
 # Install latest version of our dependencies
-jax[cpu]
+jax[cpu]
\ No newline at end of file

From 59a969b10bc8273ddede8d01e04274ead8a03256 Mon Sep 17 00:00:00 2001
From: Johanna Haffner
Date: Mon, 9 Dec 2024 11:37:49 +0100
Subject: [PATCH 4/7] rename, add documentation, explicate lack of test
 coverage for integer-input case.

---
 diffrax/__init__.py  |  2 +-
 diffrax/_adjoint.py  | 10 ++++++----
 test/test_adjoint.py | 34 +++++++---------------------------
 3 files changed, 14 insertions(+), 32 deletions(-)

diff --git a/diffrax/__init__.py b/diffrax/__init__.py
index 3679cf85..42073a10 100644
--- a/diffrax/__init__.py
+++ b/diffrax/__init__.py
@@ -4,7 +4,7 @@
     AbstractAdjoint as AbstractAdjoint,
     BacksolveAdjoint as BacksolveAdjoint,
     DirectAdjoint as DirectAdjoint,
-    ForwardAdjoint as ForwardAdjoint,
+    ForwardMode as ForwardMode,
     ImplicitAdjoint as ImplicitAdjoint,
     RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
 )
diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py
index 6b580d14..9c084ecc 100644
--- a/diffrax/_adjoint.py
+++ b/diffrax/_adjoint.py
@@ -854,12 +854,14 @@ def loop(
         return final_state, aux_stats
 
 
-class ForwardAdjoint(AbstractAdjoint):
+class ForwardMode(AbstractAdjoint):
     """Differentiate through a differential equation solve during the forward pass.
+    (So it is not really an adjoint - it is a different way of quantifying the
+    sensitivity of the output to the input.)
 
-    This is a useful adjoint to use whenever we have many more outputs than inputs to a
-    function - for instance during parameter inference for ODE models with least-squares
-    based solvers such as `optimistix.LevenbergMarquardt`.
+    ForwardMode is useful when we have many more outputs than inputs to a function - for
+    instance during parameter inference for ODE models with least-squares solvers such
+    as `optimistix.LevenbergMarquardt`, which operate on the residuals.
     """
 
     def loop(
diff --git a/test/test_adjoint.py b/test/test_adjoint.py
index 24ee96d1..12e6ee27 100644
--- a/test/test_adjoint.py
+++ b/test/test_adjoint.py
@@ -77,26 +77,6 @@ def _run_inexact(inexact, saveat, adjoint):
     _run_fwd_grad = eqx.filter_jit(jax.jacfwd(_run_inexact))
     _run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True))
 
-    # # TODO(jhaffner): jax.linearize can not handle integer input either (like jacfwd),
-    # # so we would have to use jvp directly and construct unit "vector" pytree(s) for
-    # # each leaf. This seems like an exceedingly rare use case, so maybe skip?
-    # @eqx.filter_jit
-    # def _run_fwd_grad_int(y0__args__term, saveat, adjoint):
-    #     # We want a "fwd_grad" that works with integer inputs. A way to get this is to
-    #     # linearise the function and then transpose. Here we filter out numeric types
-    #     # because we want to catch floats and ints, not just arrays.
-
-    #     def is_numeric(x):
-    #         return isinstance(x, (int, float, jax.numpy.ndarray))
-    #     dynamic, static = eqx.partition(y0__args__term, is_numeric)
-
-    #     def differentiable_run(dynamic):
-    #         return _run(eqx.combine(dynamic, static), saveat, adjoint)
-
-    #     _, lin_fn = jax.linearize(differentiable_run, dynamic)
-    #     grad = jax.linear_transpose(lin_fn, dynamic)(1.0)
-    #     return grad
-
     twice_inexact = jtu.tree_map(lambda *x: jnp.stack(x), inexact, inexact)
 
     @eqx.filter_jit
@@ -173,12 +153,16 @@ def _convert_float0(x):
             inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
         )
         backsolve_grads = _run_grad(inexact, saveat, diffrax.BacksolveAdjoint())
-        forward_grads = _run_fwd_grad(inexact, saveat, diffrax.ForwardAdjoint())
+        forward_grads = _run_fwd_grad(inexact, saveat, diffrax.ForwardMode())
         assert tree_allclose(fd_grads, direct_grads[0])
         assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
         assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
         assert tree_allclose(direct_grads, forward_grads, atol=1e-5)
 
+        # Test support for integer inputs (jax.grad(..., allow_int=True)). There
+        # is no corresponding option for jax.jacfwd or jax.linearize, and a
+        # workaround (jvp with custom "unit pytrees" for mixed array and
+        # non-array inputs?) is not implemented and tested here.
         direct_grads = _run_grad_int(
             y0__args__term, saveat, diffrax.DirectAdjoint()
         )
@@ -188,10 +172,6 @@ def _convert_float0(x):
         backsolve_grads = _run_grad_int(
             y0__args__term, saveat, diffrax.BacksolveAdjoint()
         )
-        # forward_grads = _run_fwd_grad_int(
-        #     y0__args__term, saveat, diffrax.ForwardAdjoint()
-        # )
-
         direct_grads = jtu.tree_map(_convert_float0, direct_grads)
         recursive_grads = jtu.tree_map(_convert_float0, recursive_grads)
         backsolve_grads = jtu.tree_map(_convert_float0, backsolve_grads)
@@ -210,7 +190,7 @@ def _convert_float0(x):
             twice_inexact, saveat, diffrax.BacksolveAdjoint()
         )
         forward_grads = _run_vmap_fwd_grad(
-            twice_inexact, saveat, diffrax.ForwardAdjoint()
+            twice_inexact, saveat, diffrax.ForwardMode()
         )
         assert tree_allclose(fd_grads, direct_grads[0])
         assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
@@ -227,7 +207,7 @@ def _convert_float0(x):
             twice_inexact, saveat, diffrax.BacksolveAdjoint()
         )
         forward_grads = _run_fwd_grad_vmap(
-            twice_inexact, saveat, diffrax.ForwardAdjoint()
+            twice_inexact, saveat, diffrax.ForwardMode()
         )
         assert tree_allclose(fd_grads, direct_grads[0])
         assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)

From 5c6ee640a3bf2ea7889c607bc5a1a59f618d3a0b Mon Sep 17 00:00:00 2001
From: Johanna Haffner
Date: Mon, 9 Dec 2024 11:39:28 +0100
Subject: [PATCH 5/7] rename import of ForwardMode

---
 docs/api/adjoints.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md
index 7a76d019..5d4ede63 100644
--- a/docs/api/adjoints.md
+++ b/docs/api/adjoints.md
@@ -44,7 +44,7 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax
     selection:
         members: false
 
-::: diffrax.ForwardAdjoint
+::: diffrax.ForwardMode
     selection:
         members: false
 

From 08907c782f9006ffd9f4dad08faeec9a2948412d Mon Sep 17 00:00:00 2001
From: Johanna Haffner
Date: Fri, 20 Dec 2024 12:05:08 +0000
Subject: [PATCH 6/7] fix duplicate

---
 docs/requirements.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index 44c17201..ed743cfd 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -3,7 +3,6 @@ mkdocs==1.3.0  # Main documentation generator.
 mkdocs-material==7.3.6  # Theme
 pymdown-extensions==9.4  # Markdown extensions e.g. to handle LaTeX.
 mkdocstrings==0.17.0  # Autogenerate documentation from docstrings.
-mkdocs-autorefs==1.0.1  # Automatically generate references to other pages.
 mknotebooks==0.7.1  # Turn Jupyter Lab notebooks into webpages.
 pytkdocs_tweaks==0.0.8  # Tweaks mkdocstrings to improve various aspects
 mkdocs_include_exclude_files==0.0.1  # Allow for customising which files get included

From e45552209228e5cabed6732726caaf245df98101 Mon Sep 17 00:00:00 2001
From: Johanna Haffner
Date: Fri, 20 Dec 2024 12:08:23 +0000
Subject: [PATCH 7/7] Make docstring of ForwardMode more precise, add
 references to it where forward-mode autodiff is mentioned in the other
 adjoints

---
 diffrax/_adjoint.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py
index 9c084ecc..46338ea2 100644
--- a/diffrax/_adjoint.py
+++ b/diffrax/_adjoint.py
@@ -184,7 +184,9 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint):
     !!! info
 
         Note that this cannot be forward-mode autodifferentiated. (E.g. using
-        `jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if that is something you need.
+        `jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if you need both forward-mode
+        and reverse-mode autodifferentiation, and [`diffrax.ForwardMode`][] if you need
+        only forward-mode autodifferentiation.
 
     ??? cite "References"
 
@@ -333,6 +335,8 @@ class DirectAdjoint(AbstractAdjoint):
 
     So unless you need forward-mode autodifferentiation then
     [`diffrax.RecursiveCheckpointAdjoint`][] should be preferred.
+    If you need only forward-mode autodifferentiation, then [`diffrax.ForwardMode`][] is
+    more efficient.
     """
 
     def loop(
@@ -855,11 +859,13 @@ def loop(
 
 
 class ForwardMode(AbstractAdjoint):
-    """Differentiate through a differential equation solve during the forward pass.
-    (So it is not really an adjoint - it is a different way of quantifying the
-    sensitivity of the output to the input.)
+    """Supports forward-mode automatic differentiation through a differential equation
+    solve. This works by propagating the derivatives during the forward pass - that is,
+    during the ODE solve, instead of solving the adjoint equations afterwards.
+    (So this is really a different way of quantifying the sensitivity of the output to
+    the input, even if its interface is that of an adjoint for convenience.)
 
-    ForwardMode is useful when we have many more outputs than inputs to a function - for
+    This is useful when we have many more outputs than inputs to a function - for
     instance during parameter inference for ODE models with least-squares solvers such
     as `optimistix.LevenbergMarquardt`, which operate on the residuals.
     """
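
(Illustrative usage, not part of the patch series above.) The new ForwardMode adjoint
is meant to be passed to diffeqsolve so that the whole solve supports forward-mode
autodiff, e.g. via jax.jacfwd, as exercised in the tests of patches 2 and 4. The vector
field, solver choice and shapes below are made up purely for the sake of a small,
self-contained sketch; only adjoint=diffrax.ForwardMode() relies on the code added by
this series.

    import diffrax
    import jax
    import jax.numpy as jnp


    def solve(theta):
        # dy/dt = -theta * y: one scalar input, many saved outputs, which is the
        # regime where forward-mode autodiff is cheaper than a reverse-mode adjoint.
        term = diffrax.ODETerm(lambda t, y, args: -args * y)
        sol = diffrax.diffeqsolve(
            term,
            diffrax.Tsit5(),
            t0=0.0,
            t1=1.0,
            dt0=0.1,
            y0=jnp.ones((100,)),
            args=theta,
            saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 10)),
            adjoint=diffrax.ForwardMode(),  # the adjoint added in this patch series
        )
        return sol.ys


    # ForwardMode makes the solve JVP-compatible, so jax.jacfwd works directly.
    jacobian = jax.jacfwd(solve)(jnp.array(0.5))

The least-squares use case mentioned in the docstring follows the same pattern: the
residual function contains a diffeqsolve(..., adjoint=diffrax.ForwardMode()) call, and
a solver such as optimistix.LevenbergMarquardt then evaluates the Jacobian of those
residuals in forward mode.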