Streamline tridiagonalisation of JacobianLinearOperator/FunctionLinearOperator using coloring method #165
Conversation
merge main into fork
…rOperator using coloring methods
I addressed the inefficiencies compared to sparsejac. We might want to consider adding attribution to sparsejac somewhere under their MIT license; I will update the original PR comment.
…l: skipped 80 tests but only 4 failed)
CI passes on my machine with 3.13 and on GH with 3.10. The only failure on GH with 3.12 is with GMRES, where I get this error:
Nice, I like this in the same way as #164. I think the same comments I have there apply here; otherwise this essentially LGTM. I'm happy to assume the presence of a vmap rule where necessary. I think if need be then the custom primitive can be provided a vmap rule that just does a for loop at that level, rather than attempting to hoist it out here.
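For what it's worth, a rule of that shape might look roughly like the following. This is a hypothetical toy primitive, not anything in this PR; I'm also assuming the `jax.extend.core.Primitive` spelling (older JAX exposes it as `jax.core.Primitive`).

```python
import jax
import jax.numpy as jnp
from jax.extend.core import Primitive  # jax.core.Primitive on older JAX
from jax.interpreters import batching

# Toy primitive standing in for one that has no batching rule by default.
my_prim = Primitive("my_prim")
my_prim.def_impl(lambda x: 2.0 * x)
my_prim.def_abstract_eval(lambda x: x)  # output aval matches input aval


def _my_prim_batching(batched_args, batch_dims):
    # vmap rule that just unrolls a Python for loop over the batch axis,
    # binding the primitive once per batch element.
    (x,), (bdim,) = batched_args, batch_dims
    x = jnp.moveaxis(x, bdim, 0)
    out = jnp.stack([my_prim.bind(xi) for xi in x])
    return out, 0


batching.primitive_batchers[my_prim] = _my_prim_batching

# jax.vmap(my_prim.bind)(jnp.arange(3.0)) now works via the loop above.
```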
Hi, just wanted to ask what the status is for this PR? We got a request for a new release, and I want to compile a To-Do list for what we want to have in before we do a release.
See #164 for the current blockers there, which apply here too. Thanks!
Thank you :) I'll familiarise myself with the state of the discussion. |

BREAKING CHANGE: Tridiagonals are no longer "extracted" from operators but rely on the promise of the `tridiagonal` tag being accurate and fulfilled. If this promise is broken, solutions may differ from user expectations.
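As an illustration of what breaking that promise looks like (a hypothetical sketch of my own, not from the PR; the function, size, and solver below are assumptions):

```python
import jax
import jax.numpy as jnp
import lineax as lx

# A linear function whose matrix is NOT tridiagonal (each entry is
# coupled to its mirror image), deliberately mis-tagged as tridiagonal.
fn = lambda x: x + 0.3 * x[::-1]

in_struct = jax.ShapeDtypeStruct((4,), jnp.float32)
op = lx.FunctionLinearOperator(fn, in_struct, tags=lx.tridiagonal_tag)

b = jnp.arange(1.0, 5.0, dtype=jnp.float32)
# The tag is now trusted rather than verified, so this solves against
# whatever tridiagonal the probing recovers, not against fn itself: the
# result silently differs from the true solution of fn(x) = b.
solution = lx.linear_solve(op, b, solver=lx.Tridiagonal())
```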
Preface

Very similar in spirit to #164, and partially addresses the problem discussed in #149. This turned out to be a lot simpler than I anticipated, and I thought it would be helpful to compare and contrast the two implementations and their performance impacts. In this case, the performance impact is significant and consistent for array sizes > 75, even when the JVP batching rule is efficient. However, this implementation does not support primitives for which no batching rule is available. As I am just vmapping over three vectors, perhaps I could write an explicit `for` loop (see the sketch below), but it is probably best to prioritise efficiency over flexibility here (or maybe we can do both, if we can query whether the rule exists from within a jitted function?). Some of the variable names and the overall spirit of the method are inspired directly by sparsejac, so we should probably add attribution (under its MIT license) somewhere.
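That explicit-loop fallback could look something like this (a minimal sketch under my own naming; `jvps_without_vmap`, `fn`, `x`, and `probes` are hypothetical, not part of this PR):

```python
import jax
import jax.numpy as jnp


def jvps_without_vmap(fn, x, probes):
    # Evaluate one JVP per probe vector in a plain Python loop, so no
    # batching rule is needed for any primitive inside `fn`. The loop is
    # unrolled during tracing, so it still works under jit.
    return jnp.stack([jax.jvp(fn, (x,), (v,))[1] for v in probes])
```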
Changes made
The single-dispatch `tridiagonal` function has been re-written for `JacobianLinearOperator`/`FunctionLinearOperator`, based on the observation underlying coloring methods that all elements of a tridiagonal matrix can be disentangled from just three carefully constructed vector pre/post-multiplications. I have gone with the simple case where each vector consists of ones in every third position and zeros otherwise. This means the full Jacobian need never be calculated.
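To make the observation concrete, here is a minimal standalone sketch of the idea (my own illustration, not the PR's actual code; `tridiagonal_via_coloring` is a hypothetical name, and I assume a square Jacobian acting on a 1-D array):

```python
import jax
import jax.numpy as jnp


def tridiagonal_via_coloring(fn, x):
    n = x.shape[0]
    idx = jnp.arange(n)
    # Three probe vectors: ones in every third position, zeros otherwise.
    probes = jnp.stack([(idx % 3 == k).astype(x.dtype) for k in range(3)])
    # Three JVPs in total, batched with vmap.
    _, jvps = jax.vmap(lambda v: jax.jvp(fn, (x,), (v,)))(probes)
    # Row i of a tridiagonal Jacobian has nonzeros only in columns
    # i-1, i, i+1, which fall in three distinct mod-3 colour classes, so
    # each entry can be read off from exactly one of the three outputs.
    diag = jvps[idx % 3, idx]                   # A[i, i]
    lower = jvps[(idx[1:] - 1) % 3, idx[1:]]    # A[i, i-1]
    upper = jvps[(idx[:-1] + 1) % 3, idx[:-1]]  # A[i, i+1]
    return diag, lower, upper
```

For a 1-D stencil such as `lambda u: 2 * u - jnp.pad(u[:-1], (1, 0)) - jnp.pad(u[1:], (0, 1))`, this recovers a diagonal of 2s and off-diagonals of -1s from three JVPs, without ever materialising the Jacobian.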
Performance Impact

Applying this to the problem discussed in #149, where we attempt to use `optimistix` to provide an implicit solve of a diffusion equation, and comparing this with straightforward `lineax.linear_solve`s with various operators (with `EQX_ON_ERROR=nan`):

Code
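The actual benchmark is in the collapsed block above; as a rough indication of the kind of call being timed, the `JacobianLinearOperator` path looks something like the following (a hedged sketch with a made-up stencil and size):

```python
import jax.numpy as jnp
import lineax as lx


def diffusion_rhs(u, args):
    # 1-D diffusion stencil; its Jacobian is tridiagonal.
    return 2 * u - jnp.pad(u[:-1], (1, 0)) - jnp.pad(u[1:], (0, 1))


u0 = jnp.ones(1000)
b = jnp.ones(1000)
op = lx.JacobianLinearOperator(diffusion_rhs, u0, tags=lx.tridiagonal_tag)
solution = lx.linear_solve(op, b, solver=lx.Tridiagonal())
```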
Comparison for various sizes of `rhs` (using `EQX_ON_ERROR=nan` and `jax_enable_x64` with all other environment variables at their defaults):

We see that this PR will lead to improvements in the Newton solve (which uses `FunctionLinearOperator` under the hood) and in `linear_solve` with `JacobianLinearOperator` for array sizes beyond 75. `JacobianLinearOperator` is now as efficient as `TridiagonalLinearOperator` (at least with a linear function). A performance gain of >15x is observed for array sizes of 1e3, and of greater than 400x for array sizes of 1e4. The performance hit for smaller array sizes is probably again due to threading, which I haven't controlled here (happy to look into it). As discussed in #149, the factor-of-2 performance hit of the Newton solve is likely due to the Cauchy termination test, which requires two steps.

I have not done testing for custom primitives yet; let me know whether I should add this.
Testing done
- `test_tridiagonal` test in `test_operators.py`
- `make_jacfwd_operator` and `make_jacrev_operator` in `helpers.py`, to ensure this continues to work when `jac=True`