Skip to content

Conversation

@johannahaffner
Copy link
Contributor

@johannahaffner johannahaffner commented Dec 8, 2024

Hi Patrick,

do you remember the forward-mode adjoint you wrote for me all those many months ago?

I've been wanting to make this useful to others for a long time, here it is! I have some questions:

  1. I did not manage to write a test that mimics the behaviour of jax.grad(..., allow_int=True). jacfwd and jax.linearize do not support integer inputs, writing some custom thing that computes a JVP with respect to "unit pytrees" with mixed types and non-arrays is... a pain.
    This seems like an exceedingly rare use case given the difficulties, but maybe we should document somewhere that this option is not, in fact, available. (Or leave it be, since diffrax does not differ from jax in this respect.)
  2. I wrote a two-liner documentation thing - and frankly, I don't think I understand what magic is actually happening well enough to explain more!
  3. Anything else? Does forward-mode automatic differentiation have edge cases that we should test explicitly?

To serve documentation, I needed to make the version of mkdocs-autorefs explicit (to the second-to-last release or so), so I added that too. This also came up for optimistix a while ago.

Lastly, this is rebased on dev as of now - but I can imagine that your weekend was quite busy following Friday's release, so if you'd like to postpone this, I'm not in a rush! I will mark this as a draft, due to the questions above.

@johannahaffner johannahaffner marked this pull request as draft December 8, 2024 22:06
@patrick-kidger patrick-kidger deleted the branch patrick-kidger:dev December 9, 2024 09:04
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh no, I did not mean to close this :'D That happens automatically when the target branch (here dev) gets merged and removed.

If you can reopen this against main then I'd be very happy to take this PR! I agree we should really make this available to the world at large.

AbstractAdjoint as AbstractAdjoint,
BacksolveAdjoint as BacksolveAdjoint,
DirectAdjoint as DirectAdjoint,
ForwardAdjoint as ForwardAdjoint,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: we really shouldn't call this 'adjoint' as that word refers specifically to reverse-mode autodifferentiation. Maybe ForwardMode instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha see I should really dig deeper into the theory of this!

Should we then rename the section in the documentation? Maybe "Differentiating through a solve"? ForwardMode also inherits from AbstractAdjoint.

The easiest thing to do would be to simply add a comment to the documentation of ForwardMode, that it is not really an adjoint.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants