non-singular Laplace BIE solver, improved FFT interpolation for singular integral, and new magnetic field API #1360
Conversation
Codecov Report
❌ Patch coverage is …
Additional details and impacted files
@@ Coverage Diff @@
## master #1360 +/- ##
==========================================
+ Coverage 95.75% 95.77% +0.01%
==========================================
Files 102 104 +2
Lines 28344 28880 +536
==========================================
+ Hits 27142 27661 +519
- Misses 1202 1219 +17
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| --- | --- | --- | --- | --- |
| test_build_transform_fft_midres | -0.55 +/- 2.55 | -4.66e-03 +/- 2.17e-02 | 8.47e-01 +/- 1.7e-02 | 8.52e-01 +/- 1.4e-02 |
| test_build_transform_fft_highres | -1.35 +/- 1.91 | -1.52e-02 +/- 2.14e-02 | 1.11e+00 +/- 1.6e-02 | 1.12e+00 +/- 1.4e-02 |
| test_equilibrium_init_lowres | -5.67 +/- 4.64 | -3.44e-01 +/- 2.81e-01 | 5.71e+00 +/- 1.7e-01 | 6.06e+00 +/- 2.3e-01 |
| test_objective_compile_atf | -7.31 +/- 5.11 | -5.96e-01 +/- 4.16e-01 | 7.55e+00 +/- 1.7e-01 | 8.15e+00 +/- 3.8e-01 |
| test_objective_compute_atf | -1.94 +/- 15.07 | -4.42e-05 +/- 3.45e-04 | 2.24e-03 +/- 1.4e-04 | 2.29e-03 +/- 3.1e-04 |
| test_objective_jac_atf | -0.16 +/- 3.17 | -2.80e-03 +/- 5.54e-02 | 1.75e+00 +/- 3.6e-02 | 1.75e+00 +/- 4.2e-02 |
| test_perturb_1 | -1.93 +/- 2.90 | -2.98e-01 +/- 4.48e-01 | 1.52e+01 +/- 3.4e-01 | 1.54e+01 +/- 2.9e-01 |
| test_proximal_jac_atf | +0.50 +/- 1.93 | +2.71e-02 +/- 1.06e-01 | 5.50e+00 +/- 5.8e-02 | 5.47e+00 +/- 8.8e-02 |
| test_proximal_freeb_compute | -3.13 +/- 2.34 | -5.09e-03 +/- 3.80e-03 | 1.57e-01 +/- 2.3e-03 | 1.62e-01 +/- 3.0e-03 |
| test_solve_fixed_iter | +0.31 +/- 3.24 | +8.94e-02 +/- 9.23e-01 | 2.86e+01 +/- 8.1e-01 | 2.85e+01 +/- 4.4e-01 |
| test_objective_compute_ripple | -0.38 +/- 2.81 | -8.09e-04 +/- 6.01e-03 | 2.13e-01 +/- 4.6e-03 | 2.14e-01 +/- 3.9e-03 |
| test_objective_grad_ripple | -0.94 +/- 5.22 | -8.55e-03 +/- 4.77e-02 | 9.05e-01 +/- 9.2e-03 | 9.14e-01 +/- 4.7e-02 |
| test_build_transform_fft_lowres | -0.87 +/- 2.19 | -6.48e-03 +/- 1.63e-02 | 7.39e-01 +/- 1.1e-02 | 7.46e-01 +/- 1.2e-02 |
| test_equilibrium_init_medres | +3.08 +/- 4.09 | +1.88e-01 +/- 2.49e-01 | 6.28e+00 +/- 2.4e-01 | 6.09e+00 +/- 6.6e-02 |
| test_equilibrium_init_highres | +1.57 +/- 4.83 | +1.08e-01 +/- 3.34e-01 | 7.03e+00 +/- 3.2e-01 | 6.92e+00 +/- 1.1e-01 |
| test_objective_compile_dshape_current | -1.20 +/- 7.31 | -4.96e-02 +/- 3.02e-01 | 4.08e+00 +/- 1.3e-01 | 4.13e+00 +/- 2.7e-01 |
| test_objective_compute_dshape_current | +2.46 +/- 12.68 | +1.77e-05 +/- 9.14e-05 | 7.39e-04 +/- 7.9e-05 | 7.21e-04 +/- 4.5e-05 |
| test_objective_jac_dshape_current | +2.08 +/- 16.42 | +5.50e-04 +/- 4.34e-03 | 2.70e-02 +/- 2.8e-03 | 2.64e-02 +/- 3.3e-03 |
| test_perturb_2 | +1.96 +/- 3.40 | +3.58e-01 +/- 6.22e-01 | 1.86e+01 +/- 5.5e-01 | 1.83e+01 +/- 3.0e-01 |
| test_proximal_jac_atf_with_eq_update | +0.12 +/- 1.47 | +1.56e-02 +/- 1.95e-01 | 1.33e+01 +/- 1.3e-01 | 1.32e+01 +/- 1.5e-01 |
| **test_proximal_freeb_jac** | -14.83 +/- 2.48 | -7.27e-01 +/- 1.22e-01 | 4.18e+00 +/- 5.9e-02 | 4.90e+00 +/- 1.1e-01 |
| test_solve_fixed_iter_compiled | -1.58 +/- 2.36 | -1.27e-01 +/- 1.89e-01 | 7.89e+00 +/- 1.4e-01 | 8.02e+00 +/- 1.3e-01 |
| test_LinearConstraintProjection_build | -3.21 +/- 2.39 | -2.77e-01 +/- 2.07e-01 | 8.35e+00 +/- 1.6e-01 | 8.62e+00 +/- 1.3e-01 |
| test_objective_compute_ripple_bounce1d | -1.07 +/- 4.62 | -3.21e-03 +/- 1.39e-02 | 2.97e-01 +/- 1.0e-02 | 3.00e-01 +/- 9.6e-03 |
| test_objective_grad_ripple_bounce1d | -2.21 +/- 4.58 | -2.11e-02 +/- 4.38e-02 | 9.35e-01 +/- 2.8e-02 | 9.56e-01 +/- 3.3e-02 |

GitHub CI performance can be noisy. When evaluating the benchmarks, developers should take this into account.
Force-pushed from 8d9bbef to 45405b6
…and fix an unused kernel
@dpanici try using jaxopt's anderson acceleration
```python
_, err, i, _, _ = state
return (i < m) & (~is_converged(err, tol))

## run for m steps first with del2 then do anderson
```
Caution with some of these changes
You must use iteration methods that guarantee convergence for this problem; no heuristics. Aitken's del2 method is a heuristic accelerator. When I tested it alone, it did not converge for our integral system, as can be seen by choosing that option and running test_convergence_run_fixed_point, then generating the plots with test_convergence_plot_fixed_point.
- Standard fixed-point contraction guarantees convergence (a minimal sketch follows this list).
- The Anderson method is guaranteed to converge, which is why I recommended it.
- (If you know the eigenspectrum, a fully iterative Chebyshev method works, but the eigenspectrum of this operator is an open mathematics problem, as I discussed in the shared paper. So don't use Chebyshev iteration.)
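For reference, a minimal sketch of what the standard fixed-point contraction above looks like in JAX, written with `jax.lax.while_loop` so it is jit- and AD-friendly. The names `fixed_point_simple`, `f`, and `x0` are illustrative; this is not the PR's `_fixed_point` implementation.

```python
import jax
import jax.numpy as jnp

def fixed_point_simple(f, x0, tol=1e-8, maxiter=100):
    """Iterate x <- f(x) until the residual ||f(x) - x|| drops below tol."""

    def cond(state):
        _, err, i = state
        return (i < maxiter) & (err > tol)

    def body(state):
        x, _, i = state
        x_new = f(x)
        # sup-norm of the update as the convergence measure
        return x_new, jnp.max(jnp.abs(x_new - x)), i + 1

    init = (x0, jnp.array(jnp.inf, dtype=x0.dtype), 0)
    x, _, _ = jax.lax.while_loop(cond, body, init)
    return x

# e.g. the contraction x <- (x + 2/x)/2 converges to sqrt(2):
x_star = fixed_point_simple(lambda x: 0.5 * (x + 2.0 / x), jnp.array(1.5))
```

For a contraction with rate q < 1 this converges linearly; Anderson acceleration reuses the last few iterates to speed this up without sacrificing the convergence guarantee.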
Improvements that can be made
There are two main improvements that can be made in the fixed-point method. The first is reducing the number of iterations. We should see gains with Anderson because the function we iterate, f, calls singular integrals, which are expensive, so doing some extra work per iteration with Anderson to reduce calls to f should be worthwhile.
The second is that when we autodiff through this, the function _tangent_solve computes the inverse on the right-hand side of the implicit function theorem.

That inversion is likely why optimization was slow and memory-hungry with this implementation, even though computing the objective was fast. When we do a normal root solve, as with map_pest_coordinates, the Jacobian is 1x1, so its inverse is just the reciprocal. In this case the Jacobian is m x m, where m ~ 10^4 is the number of output points at which the singular integral is evaluated. We were using standard dense inversion, which grows as m^3, so the cost was ~10^12 operations!
The way to avoid this is to also use fixed-point iteration to compute that inverse, as discussed here. Note that this fixed-point iteration to compute our vector-Jacobian product doesn't involve any calls to singular integrals, so it should be very cheap. You can use Anderson here too. The jaxopt Anderson method already has this built in via the options implicit_diff=True and implicit_diff_solve=None; the latter controls how that matrix is inverted. I did not check whether its default is fixed-point iteration. https://jaxopt.github.io/stable/_autosummary/jaxopt.AndersonAcceleration.html#jaxopt.AndersonAcceleration
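A sketch of the suggested jaxopt usage (not code from this PR): here `f`, `phi0`, and `params` are placeholders for the singular-integral iteration map, the initial potential, and the quantities we want derivatives with respect to.

```python
import jax.numpy as jnp
from jaxopt import AndersonAcceleration

def f(phi, params):
    # stand-in for the real (expensive) iteration map phi <- f(phi; params)
    return 0.5 * phi + params

solver = AndersonAcceleration(
    fixed_point_fun=f,
    history_size=4,      # "m" in the Anderson literature
    beta=0.25,           # damping / mixing parameter
    tol=1e-8,
    maxiter=100,
    implicit_diff=True,  # differentiate via the implicit function theorem
)
phi0 = jnp.zeros(8)
params = jnp.ones(8)
phi_star, state = solver.run(phi0, params)
```

Note that `params` is passed as an argument to `run` rather than closed over in `f`; that is what lets `implicit_diff=True` differentiate the fixed point with respect to it.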
That might also be why the optimizer was stalling: accumulating error from inverting such a large matrix directly. I still don't trust JAX to do large-scale computations, due to the open bugs I reported with batching, so that might also have been an issue.
That comment is inaccurate actually, my bad: if you look at the code, I am using "simple" for the first few iterations. I did not have "del2" used for the iterations in any of the prior work on this PR, or on the discretization patch one either.
Thank you for checking this though, as these are good points. My next thing to try is jaxopt's version, as it incorporates both of these improvements and is likely a better implementation than my attempt at Anderson here.
It was not straightforward to get jaxopt working, as it doesn't seem to like us closing over all the optimizable params in our iteration function, and I am also not so familiar with this. Maybe it is doable; more likely we should attempt something like the linked JAX docs Kaya posted, but that example is not using custom_root, so it needs some rearranging to work with our setup, I think.
Also, while the fixed-point iteration for the Jacobian does not need any singular integrals, wouldn't we still require one of the Jacobians of f, either w.r.t. a or x, which would still need many singular integral evaluations? Like the "A" and "B" operators in the linked JAX docs, which are themselves derivatives of the function f.
Could it be as simple as a change like this (just replacing _tangent_solve with the below in the custom_root call of fixed_point)?
```python
import jax
import jax.numpy as jnp

def _tangent_solve_fixed_pt(g, y):
    # Iterate the affine map f(u) = y + u @ A,
    # where A is jax.jacfwd(fxn) w.r.t. x.
    # g here is fxn(x) - x, so jax.jacfwd(g) = A - I,
    # and therefore jax.jacfwd(g) + I = A.
    # Then we can use _fixed_point to find the fixed point of
    # f(u) = y + u @ A and return that?
    A = jax.jacfwd(g)(y)
    A = A + jnp.eye(A.shape[0], A.shape[1])

    def f(u):
        return y + u @ A

    u0 = y
    # _fixed_point and _is_converged are the helpers defined in this PR
    p, _, _, _, _ = _fixed_point(
        f,
        u0,
        tol=1e-8,
        maxiter=100,
        method="anderson",
        is_converged=_is_converged,
        m=4,
        beta=0.25,
    )
    return p
```
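One possible refinement of the sketch above (untested, not code from this PR, and reusing the same `_fixed_point` and `_is_converged` helpers): materializing `A = jax.jacfwd(g)(y) + I` costs m JVP evaluations and m² memory, but the iteration only ever needs the product `u @ A = u @ J_g + u`, which a single VJP per step can supply matrix-free.

```python
import jax

def _tangent_solve_fixed_pt_matfree(g, y):
    # g(x) = fxn(x) - x, so its Jacobian is J_g = A - I.
    # jax.vjp evaluates g once at y and stores the linearization.
    _, vjp_g = jax.vjp(g, y)

    def f(u):
        # u @ A = u @ (J_g + I) = vjp_g(u) + u, without ever forming A
        return y + vjp_g(u)[0] + u

    p, _, _, _, _ = _fixed_point(
        f, y, tol=1e-8, maxiter=100,
        method="anderson", is_converged=_is_converged, m=4, beta=0.25,
    )
    return p
```

This also speaks to the earlier question about Jacobians of f: only one linearization of g is needed, built by the single `jax.vjp` call; every subsequent iteration applies the stored cotangent map, so no further singular-integral evaluations occur inside the tangent solve.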
The above does not seem to solve either the speed or the non-convergence in vacuum; I have not looked into the memory use.
On this PR it seems to change the behavior of the optimization, but on the discretization patch PR it does not seem to impact the final solution much. I still need to understand why.
…on the grad(phi) part of test fails despite Phi error going to machine precision
All the solvers are thoroughly tested. See the shared paper for an explanation of this PR.
This is the only open-source, JAX AD-compatible BIE solver.
The robustness of the free-surface optimization is outside the scope of this pull request; see #1894 and #1208 for more details. There are plenty of tests that illustrate correctness, and tuning optimization hyperparameter defaults, which is problem-dependent anyway, should not delay merging this PR.
This PR includes my interpax pull requests, which significantly improve free-surface optimization (both old and new) and anything that involves singular integrals.
The quadrature in master has been observed to have unfortunately slow convergence for some test cases. It has been shown that the improved quadrature in the linked paper converges faster, especially for the non-singular kernel of this system. It can be added in a different PR.