
Including user-defined Jacobian #17

@Justin-Tan


Hi devs, this looks like a really nice library. I've been looking for a JAX-native root-finding method that supports vmap for some time. Currently I make an external call to scipy.optimize.root together with the multiprocessing library, which is quite slow.

However, root finding with this library's Newton solver is slower than that approach. I suspect this is because the Jacobian has to be recomputed at every iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?
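One general JAX mechanism for supplying a hand-written derivative is jax.custom_jvp: autodiff then calls your rule instead of differentiating through the function body. The minimal sketch below attaches the analytic derivative of the MWE's fn. It assumes that optimistix's Newton solver obtains its Jacobian through JAX autodiff, so that the rule would be picked up; even then, it changes how each Jacobian evaluation is computed, not how many evaluations the solver performs.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def fn(y, b):
    # Residual from the MWE; its root is y = b.
    return (y - b) ** 2


@fn.defjvp
def fn_jvp(primals, tangents):
    y, b = primals
    y_dot, b_dot = tangents  # b_dot is an array of zeros when b is not differentiated
    # Analytic partial derivatives: df/dy = 2(y - b), df/db = -2(y - b).
    dfdy = 2.0 * (y - b)
    return fn(y, b), dfdy * y_dot - dfdy * b_dot


# Autodiff now uses the rule above instead of tracing through fn's body.
y = jnp.array([1.0, 2.0, -0.5])
b = jnp.array([0.5, -1.0, -0.5])
jac = jax.jacfwd(jax.vmap(fn))(y, b)  # shape (3, 3); nonzero only on the diagonal
```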

For reference, here is my MWE, in case I am doing something inefficiently:

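```python
from jax import vmap, random
import optimistix as optx

def fn(y, b):
    return (y - b) ** 2

M = 1024
key = random.PRNGKey(42)
key, key_ = random.split(key, 2)

y = random.normal(key, (M,))
b = random.normal(key_, (M,))

# Newton solver (tolerances here are illustrative)
solver = optx.Newton(rtol=1e-5, atol=1e-5)
# vmap(fn) is solved as a single R^M -> R^M problem; b is passed through as args
sol = optx.root_find(vmap(fn), solver, y, args=b)
```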
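Because fn acts elementwise, the Jacobian of vmap(fn) is diagonal, so each Newton step reduces to the pointwise update y - f(y)/f'(y). If a solver instead treats vmap(fn) as a general function from R^M to R^M, it may build and factor a dense M × M Jacobian at every step, which could account for part of the slowdown (an assumption, not verified here). The following is a minimal pure-JAX sketch of a batched Newton iteration with a hand-written derivative; it is not an optimistix API, and the names dfn_dy and batched_newton as well as the tolerances are hypothetical.

```python
import jax
import jax.numpy as jnp

RTOL, ATOL, MAX_STEPS = 1e-6, 1e-6, 200  # hypothetical tolerances for this sketch


def fn(y, b):
    # Residual from the MWE, applied elementwise; its root is y = b.
    return (y - b) ** 2


def dfn_dy(y, b):
    # Hypothetical helper: hand-written analytic derivative of fn w.r.t. y.
    return 2.0 * (y - b)


@jax.jit
def batched_newton(y0, b):
    # fn acts elementwise, so the batched Jacobian is diagonal and each
    # Newton update is a pointwise division rather than an M x M solve.
    def cond(state):
        y, step, n = state
        not_converged = jnp.any(jnp.abs(step) > ATOL + RTOL * jnp.abs(y))
        return not_converged & (n < MAX_STEPS)

    def body(state):
        y, _, n = state
        f = fn(y, b)
        df = dfn_dy(y, b)
        # Where the derivative is exactly zero (y == b here), take no step.
        step = jnp.where(df == 0, 0.0, f / jnp.where(df == 0, 1.0, df))
        return y - step, step, n + 1

    init = (y0, jnp.full_like(y0, jnp.inf), jnp.asarray(0, dtype=jnp.int32))
    y, _, n_steps = jax.lax.while_loop(cond, body, init)
    return y, n_steps


# Usage on data shaped like the MWE.
key_y, key_b = jax.random.split(jax.random.PRNGKey(42))
y0 = jax.random.normal(key_y, (1024,))
b = jax.random.normal(key_b, (1024,))
y, n_steps = batched_newton(y0, b)
```

Note that (y - b)**2 has a double root at y = b, so Newton converges only linearly on this MWE (the error halves at each step); on a simple root the same iteration would converge quadratically.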
