fix linearise for JacobianLinearOperator with jac=bwd and use linear_transpose in mv #191

Merged
patrick-kidger merged 7 commits into patrick-kidger:main from
jpbrodrick89:jpb/jacbwd
Jan 31, 2026

@jpbrodrick89
Contributor

It turns out lx.linearise fails for a JacobianLinearOperator with jac="bwd" if the function has a custom vjp, which prevents forward-mode autodiff. The tests missed this because just CALLING lx.linearise is fine: we can create a jaxpr representing the failure, but the error is only raised on EVALUATION. I updated the tests so that they fail without the fix, and corrected linearise to use jax.linear_transpose (it turns out you can use this even with a custom vjp!).
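For readers who want to see the idea outside lineax, here is a minimal plain-JAX sketch of it. It is not the code from this PR: the function `f`, its rules `f_fwd`/`f_bwd`, and the inputs are made up for illustration. It takes the vjp of a custom_vjp function and transposes that vjp with jax.linear_transpose to recover a Jacobian-vector product without forward-mode autodiff.

```python
import jax
import jax.numpy as jnp


# Hypothetical function with a custom vjp (illustrative only).
@jax.custom_vjp
def f(x):
    return jnp.sin(x)


def f_fwd(x):
    # Save cos(x) as the residual needed by the backward rule.
    return jnp.sin(x), jnp.cos(x)


def f_bwd(cos_x, ct):
    # Linear in the cotangent `ct`, as a correct vjp rule must be.
    return (cos_x * ct,)


f.defvjp(f_fwd, f_bwd)

x = jnp.array([0.1, 0.2, 0.3])
v = jnp.array([1.0, 2.0, 3.0])

# Forward mode through a custom_vjp function is expected to raise.
try:
    jax.jvp(f, (x,), (v,))
    jvp_failed = False
except Exception:
    jvp_failed = True

# Reverse mode works: vjp_fn maps a cotangent ct to (J^T ct,).
y, vjp_fn = jax.vjp(f, x)

# Transposing the (linear) vjp gives a map from a tangent to (J v,).
# `y` is only used as an example of the cotangent's shape and dtype.
jvp_via_transpose = jax.linear_transpose(vjp_fn, y)
(jv,) = jvp_via_transpose((v,))
```

Here `jv` should agree with the analytic product `jnp.cos(x) * v`, since the Jacobian of an elementwise sine is diagonal.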

I also used linear_transpose in mv so that we're not computing jacrev each time. However, in most use cases a user should ALWAYS want to call lx.linearise on such a JacobianLinearOperator to avoid recomputing the primal. Unlike jvp in fwd mode, linear_transpose has no memory advantage, so there is no reason not to cache it for reuse unless you know you will only need a single mv.
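The caching argument can be sketched in plain JAX as follows. This is an illustration under assumptions, not lineax's implementation: `make_cached_mv`, `fn`, and the `calls` counter are hypothetical names. The primal is evaluated once inside jax.vjp, and every later matrix-vector product reuses the transposed vjp.

```python
import jax
import jax.numpy as jnp

# Counts how often the primal function body runs (illustrative only).
calls = {"primal": 0}


def fn(x):
    calls["primal"] += 1
    return jnp.tanh(x) * x


def make_cached_mv(fn, x):
    # Evaluate the primal once and keep the linear vjp function.
    y, vjp_fn = jax.vjp(fn, x)
    # Transpose the vjp once; the result maps a tangent to (J v,).
    transposed = jax.linear_transpose(vjp_fn, y)

    def mv(v):
        (out,) = transposed((v,))
        return out

    return mv


x = jnp.array([0.5, -1.0, 2.0])
mv = make_cached_mv(fn, x)

# Three matrix-vector products against the basis vectors.
vs = [jnp.eye(3)[i] for i in range(3)]
results = [mv(v) for v in vs]
primal_calls_after_mvs = calls["primal"]
```

After the three products, `primal_calls_after_mvs` should still be 1, whereas rebuilding the Jacobian with jax.jacrev on each call would re-run `fn` every time.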

Users whose custom vjps are nonlinear or affine in their cotangents will now get JAX errors. This is expected, as such custom vjps are fundamentally incorrect. I had to change one of the tests that used an affine custom vjp to use a linear one.
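As a rough illustration of the nonlinear case only (the affine case is not checked here), the sketch below defines a deliberately incorrect backward rule and tries to transpose the resulting vjp. The names `g`, `g_fwd`, and `g_bwd` are hypothetical, and the exact exception type is an assumption, so the code catches a generic Exception.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def g(x):
    return 2.0 * x


def g_fwd(x):
    # No residuals are needed for this toy rule.
    return 2.0 * x, None


def g_bwd(_, ct):
    # Deliberately wrong: sin(ct) is nonlinear in the cotangent.
    return (jnp.sin(ct),)


g.defvjp(g_fwd, g_bwd)

x = jnp.array([1.0, 2.0])
y, vjp_fn = jax.vjp(g, x)

# Transposition needs every operation on the cotangent to be linear,
# so this is expected to fail for the rule above.
try:
    jax.linear_transpose(vjp_fn, y)((x,))
    transpose_failed = False
except Exception as err:
    transpose_failed = True
    error_type = type(err).__name__
```

Calling `vjp_fn` directly would still run without complaint, which is why such a rule could go unnoticed before transposition was involved.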

These improvements should simplify the coloring-method PRs I have open for JacobianLinearOperator, as operator.T.mv will now work with custom vjps if symmetric (previously this failed and we had to write out the backward mode more verbosely).

@jpbrodrick89
Contributor Author

I'm not sure of the best way to handle the test_tangent_as_matrix failure for make_jacrev_operator. Intuitively, we should just skip it, as jvp shouldn't be expected to work when we have a custom vjp, but the actual reason seems more subtle: the test tries to compute a jvp of the operator construction (in the case of make_jac_operate, this function ALSO constructs a function with a custom_vjp that depends on matrix and is used to construct the operator). The test actually complains about variables being closed over. The problem, essentially, is that the custom_vjp depends on a closed-over variable that cannot be known until trace time, when the value of matrix is provided. I'm not completely clear about the purpose of the test, but my gut feeling is that this is still an intuitive limitation of custom vjps and we should feel comfortable skipping it. Let me know if you think we need to replace it with anything, or whether we should create the custom vjp at module level.

Note that if you call jax.vjp instead of jax.jvp, you get a much more helpful tailor-made error message:

  jax._src.interpreters.ad.CustomVJPException: Detected differentiation of a
  custom_vjp function with respect to a closed-over value. That isn't supported
  because the custom VJP rule only specifies how to differentiate the custom_vjp
  function with respect to explicit input parameters. Try passing the closed-over
  value into the custom_vjp function as an argument, and adapting the custom_vjp
  fwd and bwd rules.
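For reference, the remedy that error message suggests looks roughly like the sketch below: the value that would otherwise be closed over (here `matrix`) is passed to the custom_vjp function as an explicit argument, and the bwd rule returns a cotangent for it. The function `matvec` and its rules are hypothetical and are not taken from the lineax test helpers.

```python
import jax
import jax.numpy as jnp


# `matrix` is an explicit argument rather than a closed-over value.
@jax.custom_vjp
def matvec(matrix, x):
    return matrix @ x


def matvec_fwd(matrix, x):
    # Keep both inputs as residuals for the backward rule.
    return matrix @ x, (matrix, x)


def matvec_bwd(residuals, ct):
    matrix, x = residuals
    # One cotangent per explicit input: d/d(matrix) and d/d(x).
    return (jnp.outer(ct, x), matrix.T @ ct)


matvec.defvjp(matvec_fwd, matvec_bwd)


def loss(matrix, x):
    return jnp.sum(matvec(matrix, x))


matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
x = jnp.array([0.5, -1.5])

# Differentiating with respect to `matrix` now goes through the bwd rule.
grad_matrix = jax.grad(loss)(matrix, x)
```

Whether this restructuring is worthwhile for the test in question is a separate matter; the sketch only shows the pattern the message refers to.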

@jpbrodrick89
Contributor Author

Also, I think it would be cleaner if we merged this before #164; then I can update #164 accordingly. I understand if you disagree, though.

Owner

@patrick-kidger patrick-kidger left a comment

Nice! This looks really good to me.

Indeed, very happy to rebase the other PR on top of this one.

@patrick-kidger patrick-kidger merged commit c718ff8 into patrick-kidger:main Jan 31, 2026
1 check failed
@patrick-kidger
Owner

Awesome, merged! 🎉

And on test_tangent_as_matrix, yup I think skipping this should be fine. As you note it's not really defined (I think regardless of whether the custom_vjp actually occurs at the module level).
