Thanks for opening this! I wonder if this should be a thing we check every step, or just at initialisation? I'm guessing that checking every step may be moderately expensive. If we just do it at initialisation then I think that would also be something that we can just do for every solver. WDYT?
On this topic, this choice is essentially just to help out authors of custom solvers. A non-successful result is always intended to indicate a failure case, and this way it's not possible to forget to check it.

I would say it's definitely needed on initialization, but also on every (n-th) iteration to catch divergence and stop early. I'm not sure how expensive the check is, but it should be linear in the number of parameters. A cheaper but less explicit alternative could be hiding the divergence check in

    y_diff_norm = norm((ω(y_diff).call(jnp.abs) / y_scale**ω).ω)
    f_diff_norm = norm((ω(f_diff).call(jnp.abs) / f_scale**ω).ω)
    diverged = jnp.invert(jnp.isfinite(y_diff_norm) & jnp.isfinite(f_diff_norm))

and returning
A downside of this could be that if
Update: added a link to an implementation of
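The finiteness check in the snippet above can be sketched without the `ω` syntax, using only plain JAX tree utilities. This is a rough illustration, not the library's implementation: `max_norm`, `scaled_norm` and `has_diverged` are hypothetical helper names, and the max norm is an assumed stand-in for whatever `norm` the solver is configured with.

```python
import jax
import jax.numpy as jnp


def max_norm(tree):
    """Largest absolute value over all leaves of a pytree (NaN propagates)."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.max(jnp.stack([jnp.max(jnp.abs(leaf)) for leaf in leaves]))


def scaled_norm(diff, scale):
    """Norm of |diff| / scale, computed leaf by leaf."""
    scaled = jax.tree_util.tree_map(lambda d, s: jnp.abs(d) / s, diff, scale)
    return max_norm(scaled)


def has_diverged(y_diff, f_diff, y_scale, f_scale):
    """True if either scaled difference norm is NaN or infinite."""
    y_diff_norm = scaled_norm(y_diff, y_scale)
    f_diff_norm = scaled_norm(f_diff, f_scale)
    return jnp.invert(jnp.isfinite(y_diff_norm) & jnp.isfinite(f_diff_norm))
```

Because the norms are already needed for the convergence test, the only extra work in this sketch is two scalar `jnp.isfinite` calls per step.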
Issue:
Please correct me if this is intentional, but I noticed that the termination/continuation in `_iterate` might have a bug.

I ran into an issue where non-finite values in the initial parameters lead to an optimization running until `max_steps`.

Currently, solvers initialize `result=RESULTS.successful` and change it if there is a failure of some kind (e.g. parameter divergence).

I would assume that in those failure cases optimization would stop. However, because `_iterate` continues if `jnp.invert(terminate) | (result != RESULTS.successful)`, it keeps running until `max_steps` is reached.

Minimal example reproducing the problem:
Interestingly, with JAXopt I didn't run into this because their continuation criterion was false if `state.error` was NaN -- given by something like `two_norm(inf - inf)` in these cases.

Changes:
- `_iterate` to continue if `jnp.invert(terminate | (result != RESULTS.successful))`, so stop if the solver says so or not successful.
- `_iterate` instead (see todo).
- `newton_chord` to only stop after 2 iterations are done and its test.

Additional comment:
While I think the current solution works, for me it is unintuitive that solvers have a `terminate` method, but that does not actually fully determine if they stop or not -- as `_iterate` looks at both `terminate` and `results` returned by the solver.

My ideal solution would be that `_iterate` stops only based on the value of `terminate` and the solver has all control over that. But I think that would require updating each solver's `terminate` method, so I wanted to run this by you before doing that in case I'm wrong here.
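To make the difference between the two continuation conditions concrete, here is a toy loop built on `jax.lax.while_loop`. It is only a sketch under stated assumptions: `SUCCESSFUL`, `DIVERGED`, `toy_step` and `run` are hypothetical stand-ins and do not mirror the actual `_iterate` code. The `fixed` flag switches between the old and the proposed condition.

```python
import jax
import jax.numpy as jnp

SUCCESSFUL = 0  # hypothetical stand-in for RESULTS.successful
DIVERGED = 1  # hypothetical stand-in for a failure code


def toy_step(y):
    """One damped step towards zero; returns (new_y, terminate, result)."""
    new_y = 0.5 * y
    finite = jnp.isfinite(new_y)
    terminate = finite & (jnp.abs(new_y) < 1e-3)
    result = jnp.where(finite, SUCCESSFUL, DIVERGED).astype(jnp.int32)
    return new_y, terminate, result


def run(y0, max_steps, fixed):
    """Iterate toy_step and return (number of steps taken, final result code)."""

    def cond_fun(carry):
        _, terminate, result, num_steps = carry
        failed = result != SUCCESSFUL
        if fixed:
            # Proposed: stop if the solver terminates or reports a failure.
            keep_going = jnp.invert(terminate | failed)
        else:
            # Old: a failure forces the loop to continue.
            keep_going = jnp.invert(terminate) | failed
        return keep_going & (num_steps < max_steps)

    def body_fun(carry):
        y, _, _, num_steps = carry
        new_y, terminate, result = toy_step(y)
        return new_y, terminate, result, num_steps + 1

    init = (
        jnp.asarray(y0, dtype=jnp.float32),
        jnp.array(False),
        jnp.array(SUCCESSFUL, dtype=jnp.int32),
        jnp.array(0, dtype=jnp.int32),
    )
    _, _, result, num_steps = jax.lax.while_loop(cond_fun, body_fun, init)
    return int(num_steps), int(result)
```

With a NaN starting point, `run(float("nan"), 10, fixed=False)` uses up all 10 steps, while `fixed=True` stops after the first step that reports a failure. For a finite starting point both variants behave the same.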