Description
Hi! Thanks for developing an excellent library.
My group has been using it to perform numerical integration for some inferential tasks. So far we have been deriving gradients by hand and running multiple passes to compute the objective and the expected gradients. This has worked well, but it fails to make use of JAX's most powerful asset: autodiff.
Naively running autodiff through the integration results in exorbitant memory costs (even with as few as N = 100 samples), which prevents us from using vmap and similar transformations to handle multiple observations.
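For context, here is a minimal sketch of the forward-mode route on a fixed-node rule. It assumes a 100-node Gauss-Legendre rule (built with NumPy's `leggauss`) in place of the library's integrators, and a hypothetical integrand `integrand(x, theta, y)`. `jax.jacfwd` is vmapped over a batch of observations. Forward mode keeps no residuals for a backward pass. It costs one JVP per parameter instead, so it pays off when the parameter vector is small.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Fixed Gauss-Legendre nodes/weights on [-1, 1] (assumption: stands in for the library's integrator).
_nodes, _weights = np.polynomial.legendre.leggauss(100)
NODES = jnp.asarray(_nodes)
WEIGHTS = jnp.asarray(_weights)


def integrand(x, theta, y):
    # Hypothetical integrand: Gaussian-shaped kernel with parameters theta, modulated by observation y.
    mu = theta[0]
    sigma = jnp.exp(theta[1])
    return jnp.exp(-0.5 * ((x - mu) / sigma) ** 2) * (1.0 + y * x)


def integral(theta, y, a=-1.0, b=1.0):
    # Map the nodes from [-1, 1] to [a, b] and apply the quadrature rule.
    half = 0.5 * (b - a)
    x = half * NODES + 0.5 * (a + b)
    return half * jnp.sum(WEIGHTS * integrand(x, theta, y))


# Forward-mode derivative w.r.t. theta: one JVP per parameter (two here).
grad_fwd = jax.jacfwd(integral, argnums=0)

# Vectorize over a batch of observations y; theta is shared across the batch.
batched_grad = jax.vmap(grad_fwd, in_axes=(None, 0))

theta = jnp.array([0.1, -0.5])
ys = jnp.linspace(-1.0, 1.0, 8)
grads = batched_grad(theta, ys)  # shape (8, 2)
```

Whether this carries over to the adaptive integrators depends on how they are implemented. The sketch only covers a fixed-node rule.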
Forward-mode autodiff would go a long way toward reducing these memory issues, so I was curious whether you plan to add a custom JVP implementation at some point soon. I realize that with dynamic integration points this may be challenging for some of the more advanced integrators, but I was curious.
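A custom JVP for an integral with fixed limits can rely on differentiating under the integral sign. The tangent of $I(\theta) = \int_a^b f(x, \theta)\,dx$ in direction $\dot\theta$ is $\int_a^b \partial_\theta f(x, \theta)\,\dot\theta\,dx$, which is itself an integral. The sketch below registers that rule with `jax.custom_jvp`. It again assumes a fixed Gauss-Legendre rule and a hypothetical integrand `f`. `integrate` and `quad` are illustrative names, not part of the library.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Fixed Gauss-Legendre nodes/weights on [-1, 1] (assumption: stands in for the library's integrator).
_nodes, _weights = np.polynomial.legendre.leggauss(64)
NODES = jnp.asarray(_nodes)
WEIGHTS = jnp.asarray(_weights)


def quad(f, theta, a, b):
    """Fixed-node quadrature of x -> f(x, theta) over [a, b]."""
    half = 0.5 * (b - a)
    x = half * NODES + 0.5 * (a + b)
    vals = jax.vmap(lambda xi: f(xi, theta))(x)
    return half * jnp.sum(WEIGHTS * vals)


# f, a and b are non-differentiable; only theta carries tangents.
@partial(jax.custom_jvp, nondiff_argnums=(0, 2, 3))
def integrate(f, theta, a, b):
    return quad(f, theta, a, b)


@integrate.defjvp
def integrate_jvp(f, a, b, primals, tangents):
    (theta,) = primals
    (theta_dot,) = tangents
    primal_out = quad(f, theta, a, b)

    # Directional derivative of the integrand in theta, evaluated pointwise in x.
    def df(x, th):
        return jax.jvp(lambda t: f(x, t), (th,), (theta_dot,))[1]

    # Differentiating under the integral sign: the tangent is a second integral.
    tangent_out = quad(df, theta, a, b)
    return primal_out, tangent_out


def f(x, theta):
    # Hypothetical integrand with a two-element parameter vector.
    return jnp.exp(-theta[0] * x ** 2) * jnp.cos(theta[1] * x)


theta = jnp.array([1.0, 2.0])
value = integrate(f, theta, -1.0, 1.0)
jac = jax.jacfwd(lambda th: integrate(f, th, -1.0, 1.0))(theta)  # shape (2,)
```

Because the tangent is computed by a separate integration call, an adaptive integrator could in principle choose its own nodes for it. This sketch simply reuses the fixed nodes. Limits that depend on the parameters would add boundary terms from the Leibniz rule, which are omitted here.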