Description
Hi devs, this looks like a really nice library. I've been looking for a JAX-native root-finding method that supports vmap for some time. Currently I am using an external call to scipy.optimize.root together with the multiprocessing library, which is quite slow.
However, root finding with the Newton method in this library runs slower than that approach. I suspect this is because the Jacobian needs to be calculated at each iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?
For reference, this is my MWE in case I am not doing things efficiently:
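from jax import jit, jacfwd, vmap, random
import optimistix as optx

def fn(y, b):
    # Residual whose root in y is y = b (a double root).
    return (y - b)**2

M = 1024
key = random.PRNGKey(42)
key, key_ = random.split(key, 2)
y = random.normal(key, (M,))   # initial guesses
b = random.normal(key_, (M,))  # per-problem parameters, passed as args

# `solver` was not defined in the snippet as posted; Newton is what the text
# above describes, and the tolerances here are assumed values.
solver = optx.Newton(rtol=1e-6, atol=1e-6)
sol = optx.root_find(vmap(fn), solver, y, b)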
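To make the question concrete, here is a minimal plain-JAX sketch of a Newton iteration that uses a hand-written derivative instead of autodiff, with vmap applied around the whole scalar solve. It does not use Optimistix and is not its API: `analytic_jac`, `newton_scalar`, `batched_newton`, and the fixed step count of 60 are all hypothetical names and values chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random


def fn(y, b):
    # Same residual as in the MWE: the root in y is y = b.
    return (y - b) ** 2


def analytic_jac(y, b):
    # Hand-written derivative of fn with respect to y (hypothetical helper).
    return 2.0 * (y - b)


def newton_scalar(y0, b, num_steps=60):
    # Fixed-step Newton iteration for one scalar problem.
    # num_steps=60 is an assumed value, not a tuned one.
    def step(_, y):
        f = fn(y, b)
        df = analytic_jac(y, b)
        # Skip the update when the derivative is exactly zero (at the root).
        safe_df = jnp.where(df == 0.0, 1.0, df)
        return y - jnp.where(df == 0.0, 0.0, f / safe_df)

    return jax.lax.fori_loop(0, num_steps, step, y0)


# vmap over independent scalar problems, so each one only ever sees a
# scalar derivative rather than one Jacobian for the whole batch.
batched_newton = jax.jit(jax.vmap(newton_scalar))

M = 1024
key = random.PRNGKey(42)
key, key_ = random.split(key, 2)
y = random.normal(key, (M,))
b = random.normal(key_, (M,))

roots = batched_newton(y, b)
```

Because `(y - b)**2` has a double root, each Newton step only halves the error, so a fixed number of steps is used here rather than a convergence test.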