fix linearise for JacobianLinearOperator with jac=bwd and use linear_transpose in mv #191
Conversation
|
I'm not sure the best way to handle Note that if you call jax.vjp instead of a jax.jvp you get a much more helpful tailor-made error message: |
patrick-kidger
left a comment
Nice! This looks really good to me.
Indeed, very happy to rebase the other PR on top of this one.
|
Awesome, merged! 🎉 And on |
It turns out `lx.linearise` fails for a `JacobianLinearOperator` with `jac="bwd"` if the function has a custom vjp, which prevents forward-mode autodiff. This was missed in the tests because just calling `lx.linearise` is fine: we can create a jaxpr representing the failure, but an error is only raised on evaluation. I updated the tests so they would fail and corrected `linearise` to use `jax.linear_transpose` (it turns out you can use this even with a custom vjp!).
I also used `linear_transpose` in the `mv` so we're not computing `jacrev` each time. However, for most use cases a user should always call `lx.linearise` on such a `JacobianLinearOperator` to avoid recomputing the primal (unlike using jvp with fwd mode, `linear_transpose` has no memory advantages, so there is no reason not to cache it for reuse unless you know you're only going to need a single `mv`).
Users who used custom vjps that are nonlinear or affine in their cotangents will now get JAX errors. This is expected, as such custom vjps are fundamentally incorrect. I had to correct one of the tests that used an affine custom vjp to use a linear one.
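For readers following along, here is a minimal plain-JAX sketch of the idea, not the code in this PR. It shows forward mode failing on a custom vjp, while transposing the vjp function with `jax.linear_transpose` still recovers the Jacobian-vector product. The function `f`, the matrix `A`, and the names `f_fwd`, `f_bwd` and `mv_fn` are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: f(x) = A @ sin(x), with a hand-written (linear) vjp.
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])


@jax.custom_vjp
def f(x):
    return A @ jnp.sin(x)


def f_fwd(x):
    # Save cos(x) as the residual needed by the backward pass.
    return A @ jnp.sin(x), jnp.cos(x)


def f_bwd(cos_x, ct):
    # J^T ct, where J = A @ diag(cos(x)). Linear in the cotangent ct.
    return (cos_x * (A.T @ ct),)


f.defvjp(f_fwd, f_bwd)

x = jnp.array([0.1, 0.2, 0.3])
v = jnp.array([1.0, 2.0, 3.0])

# Forward mode is blocked by the custom vjp.
try:
    jax.jvp(f, (x,), (v,))
    jvp_failed = False
except TypeError:
    jvp_failed = True

# Reverse mode works: vjp_fn maps ct -> (J^T ct,).
y, vjp_fn = jax.vjp(f, x)

# Transposing the (linear) vjp function gives v -> J v.
# The second argument is only used for its shape and dtype.
mv_fn = jax.linear_transpose(lambda ct: vjp_fn(ct)[0], jnp.zeros_like(y))

# mv_fn can be reused for many vectors without re-running the primal.
(Jv,) = mv_fn(v)
```

Here `vjp_fn` and `mv_fn` are built once and can be applied to many vectors, which is the caching benefit that calling `lx.linearise` is meant to provide.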
These improvements should simplify the coloring method PRs I have open for `JacobianLinearOperator`, as `operator.T.mv` will now work with custom vjps if symmetric (previously this failed and we had to write out the backward mode more verbosely).