8 changes: 4 additions & 4 deletions jaxopt/_src/backtracking_linesearch.py
@@ -39,7 +39,7 @@ class BacktrackingLineSearchState(NamedTuple):
params: Any
value: float
grad: Any # either initial or final for armijo or goldstein
value_init: float
value_init: float
grad_init: Any
error: float
done: bool
@@ -260,11 +260,11 @@ def update(

if self.condition in ["armijo", "goldstein"]:
# If we are done for the armijo or the goldstein conditions,
# we compute the final gradient (we had not computed it before since
# we compute the final gradient (we had not computed it before since
# these conditions did not require it)
new_grad = cond(done | failed,
self._compute_final_grad,
lambda *_: grad,
lambda *_: grad,
new_params, fun_args, fun_kwargs,
jit=self.jit)
maybe_additional_eval = jnp.asarray(done | failed, dtype=base.NUM_EVAL_DTYPE)
@@ -284,7 +284,7 @@ def update(
num_grad_eval=num_grad_eval)

return base.LineSearchStep(stepsize=new_stepsize, state=new_state)

def _compute_final_grad(self, params, fun_args, fun_kwargs):
return self._grad_with_aux(params, *fun_args, **fun_kwargs)[0]

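As an aside, a minimal sketch of the deferred-gradient pattern used in the hunk above: the gradient is only computed once the line search is done or has failed. The diff relies on jaxopt's internal cond wrapper with a jit flag; the plain jax.lax.cond below and the grad_fun callable are illustrative assumptions.

import jax

def grad_if_terminated(done_or_failed, grad_fun, params, current_grad):
  # Only pay for the gradient once the line search has terminated;
  # otherwise keep the gradient we already have (both branches must
  # return the same pytree structure).
  return jax.lax.cond(done_or_failed,
                      lambda p: grad_fun(p),
                      lambda p: current_grad,
                      params)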
138 changes: 89 additions & 49 deletions jaxopt/_src/broyden.py
@@ -14,6 +14,8 @@

"""Limited-memory Broyden method"""

import warnings

from functools import partial

from typing import Any
@@ -28,7 +30,7 @@
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch
from jaxopt._src.linesearch_util import _setup_linesearch, _init_stepsize
from jaxopt.tree_util import tree_map
from jaxopt.tree_util import tree_vdot
from jaxopt.tree_util import tree_add_scalar_mul
@@ -80,21 +82,21 @@ def inv_jacobian_product(pytree: Any,
Leaves contain v variables, i.e., `(x[k] - x[k-1])^T B / ((x[k] - x[k-1])^T B (g[k] - g[k-1]))`.
c_history: pytree with the same structure as `pytree`.
Leaves contain u variables, i.e., `(x[k] - x[k-1]) - B(g[k] - g[k-1])`.
gamma: scalar to use for the initial inverse jacobian approximation,
gamma: pytree with scalars to use for the initial inverse jacobian approximation,
i.e., `gamma * I`.
start: starting index in the circular buffer.
"""
fun = partial(inv_jacobian_product_leaf,
gamma=gamma,
start=start)
return tree_map(fun, pytree, d_history, c_history)
return tree_map(fun, pytree, d_history, c_history, gamma)

def inv_jacobian_rproduct(pytree: Any,
d_history: Any,
c_history: Any,
gamma: float = 1.0,
start: int = 0):
return inv_jacobian_product(pytree, c_history, d_history, jnp.conjugate(gamma), start)
gamma_conj = tree_map(jnp.conjugate, gamma)
return inv_jacobian_product(pytree, c_history, d_history, gamma_conj, start)


def init_history(pytree, history_size):
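For intuition, here is a minimal single-leaf sketch of the product computed above, assuming the low-rank representation B ≈ gamma * I + sum_k c[k] d[k]^T implied by the docstring; the dense (m, n) history layout and the function name are illustrative only, and the circular-buffer start index is ignored.

import jax.numpy as jnp

def inv_jacobian_product_dense(v, d_history, c_history, gamma):
  # v: (n,) vector; d_history, c_history: (m, n) stacked history rows.
  coeffs = d_history @ v                    # d[k]^T v for each k, shape (m,)
  return gamma * v + c_history.T @ coeffs   # gamma * v + sum_k (d[k]^T v) * c[k]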
@@ -115,7 +117,7 @@ class BroydenState(NamedTuple):
error: float
d_history: Any
c_history: Any
gamma: jnp.ndarray
gamma: Any
aux: Optional[Any] = None
failed_linesearch: bool = False

@@ -194,17 +196,19 @@ class Broyden(base.IterativeSolver):

stepsize: Union[float, Callable] = 0.0
linesearch: str = "backtracking"
linesearch_init: str = "increase"
stop_if_linesearch_fails: bool = False
condition: str = "wolfe"
condition: Any = None # deprecated in v0.8
maxls: int = 15
decrease_factor: float = 0.8
decrease_factor: Any = None # deprecated in v0.8
increase_factor: float = 1.5
max_stepsize: float = 1.0
# FIXME: should depend on whether float32 or float64 is used.
min_stepsize: float = 1e-6

history_size: int = None
gamma: float = 1.0
compute_gamma: bool = True

implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
@@ -244,18 +248,37 @@ def init_state(self,
iter_num=init_params.state.iter_num,
stepsize=init_params.state.stepsize,
)
# XXX: not computing the jacobian init approx
# when starting from an OptStep object
init_params = init_params.params
dtype = tree_single_dtype(init_params)
value, aux = self._value_with_aux(init_params, *args, **kwargs)
else:
dtype = tree_single_dtype(init_params)
value, aux = self._value_with_aux(init_params, *args, **kwargs)
if self.compute_gamma:
# we use scipy's formula:
# https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L569
# self.alpha = 0.5*max(norm(x0), 1) / normf0
normf0 = tree_map(jnp.linalg.norm, value)
normx0 = tree_map(jnp.linalg.norm, init_params)
clipped_normx0 = tree_map(lambda x: 0.5 * jnp.maximum(x, 1), normx0)
def safe_divide_by_zero(x, y):
# standard division of x by y,
# except that it returns 1 when y == 0 to avoid dividing by zero
return jnp.where(y == 0, 1, x / y)
gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0)
else:
gamma = self.gamma
# broadcast gamma to a pytree with the same structure as init_params
gamma = tree_map(lambda x: jnp.array(gamma), init_params)
state_kwargs = dict(
d_history=init_history(init_params, self.history_size),
c_history=init_history(init_params, self.history_size),
gamma=jnp.asarray(self.gamma, dtype=dtype),
gamma=gamma,
iter_num=jnp.asarray(0),
stepsize=jnp.asarray(self.max_stepsize, dtype=dtype),
)
value, aux = self._value_with_aux(init_params, *args, **kwargs)
return BroydenState(value=value,
error=jnp.asarray(jnp.inf),
**state_kwargs,
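A small worked example of the initial scaling computed above (leaf-wise, following scipy's heuristic gamma = 0.5 * max(||x0||, 1) / ||f(x0)||; the numbers are illustrative):

import jax.numpy as jnp

x0 = jnp.array([3.0, 4.0])   # ||x0|| = 5
f0 = jnp.array([0.0, 2.0])   # ||f(x0)|| = 2
gamma = 0.5 * jnp.maximum(jnp.linalg.norm(x0), 1.0) / jnp.linalg.norm(f0)
# gamma == 1.25; when ||f(x0)|| == 0 the code above falls back to gamma = 1.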
@@ -308,45 +331,29 @@ def update(self,

use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0
if use_linesearch:
if self.linesearch == "backtracking":
# we need to build the function used for the line search
# which is going to be the squared norm of the original function
# as in scipy https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L278
# we then need to check if the gradient can be obtained with jax
# and if not we can build it in the same fashion as scipy
# https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L285
def ls_fun_with_aux(params, *args, **kwargs):
f, aux = self._value_with_aux(params, *args, **kwargs)
norm_squared = tree_l2_norm(f, squared=True)
return norm_squared, (f, aux)
# here we need a check if the function is not smooth
ls_fun_with_aux_and_grad = jax.value_and_grad(ls_fun_with_aux, has_aux=True)
ls = BacktrackingLineSearch(fun=ls_fun_with_aux_and_grad,
value_and_grad=True,
maxiter=self.maxls,
decrease_factor=self.decrease_factor,
max_stepsize=self.max_stepsize,
condition=self.condition,
jit=self.jit,
unroll=self.unroll,
has_aux=True,
tol=1e-2)
init_stepsize = jnp.where(state.stepsize <= self.min_stepsize,
# If stepsize became too small, we restart it.
self.max_stepsize,
# Else, we increase a bit the previous one.
state.stepsize * self.increase_factor)
new_stepsize, ls_state = ls.run(init_stepsize,
params, value, None,
descent_direction,
fun_args=args, fun_kwargs=kwargs)
new_value, new_aux = ls_state.aux
new_params = ls_state.params
new_num_linesearch_iter = state.num_linesearch_iter + ls_state.iter_num
new_num_fun_eval = state.num_fun_eval + ls_state.num_fun_eval
failed_linesearch = ls_state.failed
else:
raise ValueError("Invalid name in 'linesearch' option.")
init_stepsize = _init_stepsize(
self.linesearch_init,
self.max_stepsize,
self.min_stepsize,
self.increase_factor,
state.stepsize,
)
new_stepsize, ls_state = self.run_ls(
init_stepsize,
params,
value=tree_l2_norm(value),
# in the case of Broyden, the residual value plays the role that the
# gradient plays in the optimization case.
grad=value,
descent_direction=descent_direction,
fun_args=args,
fun_kwargs=kwargs,
)
new_value, new_aux = ls_state.aux
new_params = ls_state.params
new_num_linesearch_iter = state.num_linesearch_iter + ls_state.iter_num
new_num_fun_eval = state.num_fun_eval + ls_state.num_fun_eval
failed_linesearch = ls_state.failed
else:
# without line search
if isinstance(self.stepsize, Callable):
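For reference, the "increase" initialization that replaces the removed inline logic above can be sketched as follows (this mirrors the deleted jnp.where expression; the exact signature of _init_stepsize in linesearch_util is not shown in this diff):

import jax.numpy as jnp

def increase_init_stepsize(prev_stepsize, max_stepsize, min_stepsize,
                           increase_factor):
  # Restart from max_stepsize if the previous step collapsed below
  # min_stepsize, otherwise grow the previous stepsize a bit.
  return jnp.where(prev_stepsize <= min_stepsize,
                   max_stepsize,
                   prev_stepsize * increase_factor)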
@@ -409,3 +416,36 @@ def __post_init__(self):

if self.history_size is None:
self.history_size = self.maxiter

# we need to build the function used for the line search
# which is going to be the squared norm of the original function
# as in scipy https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L278
# we then need to check if the gradient can be obtained with jax
# and if not we can build it in the same fashion as scipy
# https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L285
def ls_fun_with_aux(params, *args, **kwargs):
f, aux = self._value_with_aux(params, *args, **kwargs)
norm_squared = tree_l2_norm(f, squared=True)
return norm_squared, (f, aux)
# here we need a check if the function is not smooth
ls_fun_with_aux_and_grad = jax.value_and_grad(ls_fun_with_aux, has_aux=True)
self.linesearch_solver = _setup_linesearch(
linesearch=self.linesearch,
fun=ls_fun_with_aux_and_grad,
value_and_grad=True,
has_aux=True,
maxlsiter=self.maxls,
max_stepsize=self.max_stepsize,
jit=self.jit,
unroll=self.unroll,
verbose=self.verbose,
)
self.run_ls = self.linesearch_solver.run

# FIXME: to remove in future releases
if self.condition is not None:
warnings.warn("Argument condition is deprecated", DeprecationWarning)
if self.decrease_factor is not None:
warnings.warn(
"Argument decrease_factor is deprecated", DeprecationWarning
)
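End-to-end, a hedged usage sketch of the updated solver (option names taken from this diff; the top-level `from jaxopt import Broyden` import and the residual function are assumptions for illustration):

import jax.numpy as jnp
from jaxopt import Broyden  # assuming the top-level export

def F(x):
  # Residual whose root we seek: F(x) = 0.
  return jnp.tanh(x) - 0.5 * x

solver = Broyden(F, maxiter=100, tol=1e-6,
                 linesearch="backtracking",
                 linesearch_init="increase",
                 stop_if_linesearch_fails=True)
sol, state = solver.run(jnp.array([0.5, -0.8]))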
4 changes: 2 additions & 2 deletions tests/broyden_test.py
@@ -48,7 +48,7 @@ def g(x): # Another fixed point exists for x[0] : ~1.11
return jnp.sin(x[0]) * (x[0] ** 2) - x[0], x[1] ** 3 - x[1]
x0 = jnp.array([0.6, 0., -0.1]), jnp.array([[0.7], [0.5]])
tol = 1e-6
sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, gamma=-1).run(x0)
sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, stop_if_linesearch_fails=True).run(x0)
self.assertLess(state.error, tol)
g_sol_norm = tree_l2_norm(g(sol))
self.assertLess(g_sol_norm, tol)
@@ -134,7 +134,7 @@ def test_affine_contractive_mapping(self):
def g(x, M, b):
return M @ x + b - x
tol = 1e-6
fp = Broyden(g, maxiter=100, tol=tol, implicit_diff=True, gamma=-1)
fp = Broyden(g, maxiter=5000, tol=tol, implicit_diff=True)
x0 = jnp.zeros_like(b)
sol, state = fp.run(x0, M, b)
self.assertLess(state.error, tol)