
Custom JVP/VJP implementation #111

@quattro


Hi! Thanks for developing an excellent library.

My group has been using it to perform numerical integration for some inference tasks. So far, we have been deriving gradients by hand and running separate passes to compute the objective and the expected gradients. This has worked well, but it does not take advantage of JAX's most powerful asset: automatic differentiation.
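For context, hand-derived gradients of an integral objective typically rest on differentiating under the integral sign (the Leibniz integral rule). Writing a generic objective as $F(\theta) = \int_{a(\theta)}^{b(\theta)} f(x, \theta)\,dx$ (our notation, assuming $f$ is smooth enough to exchange differentiation and integration):

```latex
\frac{dF}{d\theta}
  = \int_{a(\theta)}^{b(\theta)} \frac{\partial f}{\partial \theta}(x, \theta)\,dx
  + f\bigl(b(\theta), \theta\bigr)\,\frac{db}{d\theta}
  - f\bigl(a(\theta), \theta\bigr)\,\frac{da}{d\theta}
```

A custom JVP for an integrator would apply the same identity to tangents, so the derivative is itself an integral that a quadrature rule can evaluate.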

Naively applying autodiff, however, results in exorbitant memory costs, even with as few as N = 100 samples. This precludes us from using `vmap` and similar transformations to handle multiple observations.

Forward-mode autodiff would go a long way toward reducing these memory issues, so I was curious whether you plan to add a custom JVP implementation at some point soon. I realize this may be challenging for some of the more advanced integrators, given their dynamic integration points, but I wanted to ask.
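As a rough illustration of what such a rule could look like, here is a minimal sketch of a custom JVP for a fixed-node integrator. Everything in it is an assumption for illustration: `integrand`, `integrate`, `_quadrature`, and the 32-point Gauss-Legendre rule are hypothetical and not the library's API. Fixed nodes sidestep the dynamic-integration-point problem; the tangent is computed by differentiating under the integral sign (plus the Leibniz boundary terms), so the derivative is itself a quadrature rather than a trace through the integrator's internals.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Fixed Gauss-Legendre nodes and weights on [-1, 1] (32 points is an arbitrary choice).
_nodes, _weights = np.polynomial.legendre.leggauss(32)
NODES = jnp.asarray(_nodes)
WEIGHTS = jnp.asarray(_weights)


def integrand(x, theta):
    # Hypothetical integrand f(x, theta); stands in for the real model.
    return jnp.exp(-theta * x**2)


def _quadrature(g, a, b):
    # Map the fixed nodes from [-1, 1] onto [a, b] and apply the weights.
    half = 0.5 * (b - a)
    x = 0.5 * (b + a) + half * NODES
    return half * jnp.sum(WEIGHTS * g(x))


@jax.custom_jvp
def integrate(theta, a, b):
    """Approximate the integral of integrand(x, theta) over [a, b] (hypothetical helper)."""
    return _quadrature(lambda x: integrand(x, theta), a, b)


@integrate.defjvp
def integrate_jvp(primals, tangents):
    theta, a, b = primals
    dtheta, da, db = tangents
    primal_out = integrate(theta, a, b)
    # Differentiate under the integral sign: the theta contribution is the
    # integral of the integrand's JVP, evaluated with the same fixed rule.
    d_theta_term = _quadrature(
        lambda x: jax.jvp(lambda t: integrand(x, t), (theta,), (dtheta,))[1], a, b
    )
    # Leibniz rule for the limits: +f(b) * db and -f(a) * da.
    tangent_out = d_theta_term + integrand(b, theta) * db - integrand(a, theta) * da
    return primal_out, tangent_out
```

Because the tangent is linear in `(dtheta, da, db)`, JAX can also transpose this rule for `jax.grad`, while `jax.jacfwd` keeps everything in forward mode. Batching over observations would then look like `jax.vmap(jax.jacfwd(integrate), in_axes=(0, None, None))(thetas, 0.0, 1.0)`.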
