Hi!
This is a follow-up on #181. The use case is to evaluate derivatives (e.g. the gradient or Hessian) of some loss function L(y, theta), where y comes from integrating a gradient system.
In code, building on your snippet from #181:
import jax
import diffrax

def L(y, theta):  # some loss
    ...

def dLdy(t, y, args):  # vector field for the gradient system
    return -jax.grad(L)(y, args)

def solve_y(y0, theta):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(dLdy),
        y0=y0,
        args=theta,
        ...
    )
    return sol.ys

def dLdtheta(y, theta):
    return jax.grad(L, argnums=1)(y, theta)
Given these, I could just solve for y and then take the gradient with respect to theta, like so:
y_sol = solve_y(y0, theta)
theta_grad = dLdtheta(y_sol, theta)
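Written out (assuming I have the chain rule right), I believe this two-step version only computes the partial derivative of L with respect to theta at fixed y_sol, whereas what I actually want is the total derivative, which also picks up the path through y_sol:
dL/dtheta = ∂L/∂theta + (∂L/∂y) · (dy_sol/dtheta)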
However, this ignores the dependence of y_sol on theta that arises through the integration of the gradient system. So ideally I would like to take the gradient of a loss that solves for y internally:
def equilibrated_L(y0, theta):  # equilibrated loss
    y_sol = solve_y(y0, theta)
    ...
    return L(y_sol, theta)

def dLdtheta(y0, theta):
    return jax.grad(equilibrated_L, argnums=1)(y0, theta)

theta_grad = dLdtheta(y0, theta)
But with this approach I get a TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
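For context, since the end goal also includes second-order derivatives, I would eventually want something like the line below (same names as above, argnums=1 just as an example). As far as I understand, jax.hessian composes forward mode over reverse mode (jacfwd of jacrev), so maybe that is where the jvp in the error comes from, but I may be wrong about that.
# hypothetical end goal: second-order derivatives of the equilibrated loss
# (jax.hessian is jacfwd of jacrev, so it requires forward-mode support)
theta_hess = jax.hessian(equilibrated_L, argnums=1)(y0, theta)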
Hope all of that makes sense. Maybe I am missing something. For example, I wonder whether this could be a use case for an adjoint method?
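In case it helps clarify the question, this is the kind of change I imagine, though I am only guessing that something like diffrax.DirectAdjoint (which, if I read the docs right, supports both forward- and reverse-mode autodiff) is the right thing to reach for; the solver and time window below are just placeholders for illustration:
def solve_y(y0, theta):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(dLdy),
        diffrax.Tsit5(),             # placeholder solver, just for illustration
        t0=0.0, t1=100.0, dt0=0.1,   # placeholder integration window
        y0=y0,
        args=theta,
        # guess: swap the default adjoint for one that also supports forward mode?
        adjoint=diffrax.DirectAdjoint(),
    )
    return sol.ys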
Thanks!