Skip to content
34 changes: 27 additions & 7 deletions jaxopt/_src/broyden.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,21 @@ def inv_jacobian_product(pytree: Any,
Leaves contain v variables, i.e., `(x[k] - x[k-1])^T B / ((g[k] - g[k-1])^T (x[k] - x[k-1])^T B)`.
c_history: pytree with the same structure as `pytree`.
Leaves contain u variables, i.e., `(x[k] - x[k-1]) - B(g[k] - g[k-1])`.
gamma: scalar to use for the initial inverse jacobian approximation,
gamma: pytree with scalars to use for the initial inverse jacobian approximation,
i.e., `gamma * I`.
start: starting index in the circular buffer.
"""
fun = partial(inv_jacobian_product_leaf,
gamma=gamma,
start=start)
return tree_map(fun, pytree, d_history, c_history)
return tree_map(fun, pytree, d_history, c_history, gamma)

def inv_jacobian_rproduct(pytree: Any,
d_history: Any,
c_history: Any,
gamma: float = 1.0,
start: int = 0):
return inv_jacobian_product(pytree, c_history, d_history, jnp.conjugate(gamma), start)
gamma_conj = tree_map(jnp.conjugate, gamma)
return inv_jacobian_product(pytree, c_history, d_history, gamma_conj, start)


def init_history(pytree, history_size):
Expand All @@ -115,7 +115,7 @@ class BroydenState(NamedTuple):
error: float
d_history: Any
c_history: Any
gamma: jnp.ndarray
gamma: Any
aux: Optional[Any] = None
failed_linesearch: bool = False

Expand Down Expand Up @@ -205,6 +205,7 @@ class Broyden(base.IterativeSolver):

history_size: int = None
gamma: float = 1.0
compute_gamma: bool = True

implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
Expand Down Expand Up @@ -244,18 +245,37 @@ def init_state(self,
iter_num=init_params.state.iter_num,
stepsize=init_params.state.stepsize,
)
# XXX: not computing the jacobian init approx
# when starting from an OptStep object
init_params = init_params.params
dtype = tree_single_dtype(init_params)
value, aux = self._value_with_aux(init_params, *args, **kwargs)
else:
dtype = tree_single_dtype(init_params)
value, aux = self._value_with_aux(init_params, *args, **kwargs)
if self.compute_gamma:
# we use scipy's formula:
# https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L569
# self.alpha = 0.5*max(norm(x0), 1) / normf0
normf0 = tree_map(jnp.linalg.norm, value)
normx0 = tree_map(jnp.linalg.norm, init_params)
clipped_normx0 = tree_map(lambda x: 0.5 * jnp.maximum(x, 1), normx0)
def safe_divide_by_zero(x, y):
# a classical division of x by y
# when y == 0 then return 1
return jnp.where(y == 0, 1, x / y)
gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0)
else:
gamma = self.gamma
# repeat gamma as a pytree of the shape of init_params
gamma = tree_map(lambda x: jnp.array(gamma), init_params)
state_kwargs = dict(
d_history=init_history(init_params, self.history_size),
c_history=init_history(init_params, self.history_size),
gamma=jnp.asarray(self.gamma, dtype=dtype),
gamma=gamma,
iter_num=jnp.asarray(0),
stepsize=jnp.asarray(self.max_stepsize, dtype=dtype),
)
value, aux = self._value_with_aux(init_params, *args, **kwargs)
return BroydenState(value=value,
error=jnp.asarray(jnp.inf),
**state_kwargs,
Expand Down
4 changes: 2 additions & 2 deletions tests/broyden_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def g(x): # Another fixed point exists for x[0] : ~1.11
return jnp.sin(x[0]) * (x[0] ** 2) - x[0], x[1] ** 3 - x[1]
x0 = jnp.array([0.6, 0., -0.1]), jnp.array([[0.7], [0.5]])
tol = 1e-6
sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, gamma=-1).run(x0)
sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit).run(x0)
self.assertLess(state.error, tol)
g_sol_norm = tree_l2_norm(g(sol))
self.assertLess(g_sol_norm, tol)
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_affine_contractive_mapping(self):
def g(x, M, b):
return M @ x + b - x
tol = 1e-6
fp = Broyden(g, maxiter=100, tol=tol, implicit_diff=True, gamma=-1)
fp = Broyden(g, maxiter=5000, tol=tol, implicit_diff=True)
x0 = jnp.zeros_like(b)
sol, state = fp.run(x0, M, b)
self.assertLess(state.error, tol)
Expand Down