How to evaluate derivatives of diffrax-solved, equilibrated functions? #432

@francesco-innocenti

Hi!

This is a follow-up on #181. The use case is to evaluate the derivatives (e.g. gradient, Hessian) of some loss function $\mathcal{L}$ with respect to some variable $\theta$, at the point where that loss is at gradient equilibrium with respect to some other variable $y$, i.e. where $\partial \mathcal{L}/\partial y \approx 0$. Mathematically this would be something like

$\LARGE{\frac{\partial \mathcal{L}(y; \theta)}{\partial \theta}|_{\frac{\partial \mathcal{L}}{\partial y}\approx 0}}$
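
One way to read the expression above is as a total derivative. Writing $y^*(\theta)$ for the equilibrium reached from a given initial condition (this notation is introduced here for illustration and does not come from #181), the chain rule would give

$\frac{d}{d\theta}\mathcal{L}(y^*(\theta); \theta) = \frac{\partial \mathcal{L}}{\partial \theta} + \frac{\partial \mathcal{L}}{\partial y}\frac{\partial y^*}{\partial \theta}$

with both partial derivatives of $\mathcal{L}$ evaluated at $y = y^*(\theta)$. The second term is weighted by $\partial \mathcal{L}/\partial y$, which is only approximately zero after a finite integration time, and the dependence of $y^*$ on $\theta$ generally does not drop out of second derivatives such as the Hessian.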

In code, building on your snippet from #181:

import jax
import diffrax

def L(y, theta):  # some loss
    ...

def dLdy(t, y, args):  # vector field for the gradient system dy/dt = -dL/dy
    return -jax.grad(L)(y, args)  # args carries theta

def solve_y(y0, theta):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(dLdy),
        y0=y0,
        args=theta,
        ...
    )
    return sol.ys

def dLdtheta(y, theta):  # partial gradient wrt theta at fixed y
    return jax.grad(L, argnums=1)(y, theta)

Given these, I could just solve for y and then take the gradient with respect to theta, like so:

y_sol = solve_y(y0, theta)
theta_grad = dLdtheta(y_sol, theta)

However, this ignores the dependence of y on theta that arises while integrating the gradient system. So ideally I would like to take the gradient of a loss that itself solves for y internally:

def equilibrated_L(y0, theta):  # equilibrated loss
    y_sol = solve_y(y0, theta)
    ...
    return L(y_sol, theta)

def dLdtheta(y0, theta):  # gradient wrt theta through the solve
    return jax.grad(equilibrated_L, argnums=1)(y0, theta)

theta_grad = dLdtheta(y0, theta)

But with this approach I get the following error: TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
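
For context, this message is raised by JAX itself whenever forward-mode differentiation is applied directly to a function wrapped in `jax.custom_vjp`. The sketch below reproduces it without diffrax, using a hypothetical function `f` with a hand-written reverse-mode rule. Whether this is exactly the path taken inside the solve above is an assumption.

```python
import jax


@jax.custom_vjp
def f(x):
    # Hypothetical function; only a reverse-mode (VJP) rule is registered below.
    return x ** 2


def f_fwd(x):
    # Forward pass: return the output and the residuals needed by the backward pass.
    return x ** 2, x


def f_bwd(x, g):
    # Backward pass: cotangent g times df/dx = 2x.
    return (2.0 * x * g,)


f.defvjp(f_fwd, f_bwd)

# Reverse mode goes through the custom rule and works.
reverse_grad = jax.grad(f)(3.0)

# Forward mode has no rule to use, so JAX raises a TypeError.
try:
    jax.jvp(f, (3.0,), (1.0,))
    forward_error = None
except TypeError as err:
    forward_error = str(err)

print(reverse_grad)
print(forward_error)
```

So the error points to a forward-mode transformation (for example `jax.jvp` or `jax.jacfwd`) reaching a `custom_vjp` function somewhere in the call stack, rather than to the loss itself.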

Hope all of that makes sense. Maybe I am missing something. For example, I wonder whether this could be a use case for an adjoint method?
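
For concreteness, the two approaches above can be compared on a toy problem. This is a minimal sketch under stated assumptions: the quadratic loss is hypothetical (chosen so that the equilibrium is $y = \theta$), and a fixed-step Euler loop written with `jax.lax.scan` stands in for `diffrax.diffeqsolve`, so it says nothing about diffrax's own adjoint behaviour.

```python
import jax
import jax.numpy as jnp


def L(y, theta):
    # Hypothetical toy loss; for fixed theta it is minimised at y = theta.
    return 0.5 * jnp.sum((y - theta) ** 2) + 0.5 * jnp.sum(theta ** 2)


def solve_y(y0, theta, step=0.1, n_steps=200):
    # Fixed-step Euler integration of dy/dt = -dL/dy.
    # Stand-in for diffrax.diffeqsolve, not the library's method.
    def body(y, _):
        return y - step * jax.grad(L)(y, theta), None

    y_final, _ = jax.lax.scan(body, y0, None, length=n_steps)
    return y_final


def equilibrated_L(y0, theta):
    # Loss evaluated at the equilibrium reached from y0.
    return L(solve_y(y0, theta), theta)


y0 = jnp.zeros(3)
theta = jnp.array([1.0, -2.0, 0.5])

# Approach 1: solve first, then differentiate L at fixed y.
y_sol = solve_y(y0, theta)
partial_grad = jax.grad(L, argnums=1)(y_sol, theta)
partial_hess = jax.hessian(L, argnums=1)(y_sol, theta)

# Approach 2: differentiate through the solve.
total_grad = jax.grad(equilibrated_L, argnums=1)(y0, theta)
total_hess = jax.hessian(equilibrated_L, argnums=1)(y0, theta)

print(partial_grad, total_grad)  # both approximately theta
print(partial_hess)              # approximately 2 * identity
print(total_hess)                # approximately identity
```

In this toy case the two gradients agree at equilibrium, while the Hessians differ, because the fixed-y version never sees how the equilibrium moves with theta.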

Thanks!
