diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index bd54222d..53267b21 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -80,21 +80,21 @@ def inv_jacobian_product(pytree: Any, Leaves contain v variables, i.e., `(x[k] - x[k-1])^T B / ((g[k] - g[k-1])^T (x[k] - x[k-1])^T B)`. c_history: pytree with the same structure as `pytree`. Leaves contain u variables, i.e., `(x[k] - x[k-1]) - B(g[k] - g[k-1])`. - gamma: scalar to use for the initial inverse jacobian approximation, + gamma: pytree with scalars to use for the initial inverse jacobian approximation, i.e., `gamma * I`. start: starting index in the circular buffer. """ fun = partial(inv_jacobian_product_leaf, - gamma=gamma, start=start) - return tree_map(fun, pytree, d_history, c_history) + return tree_map(fun, pytree, d_history, c_history, gamma) def inv_jacobian_rproduct(pytree: Any, d_history: Any, c_history: Any, gamma: float = 1.0, start: int = 0): - return inv_jacobian_product(pytree, c_history, d_history, jnp.conjugate(gamma), start) + gamma_conj = tree_map(jnp.conjugate, gamma) + return inv_jacobian_product(pytree, c_history, d_history, gamma_conj, start) def init_history(pytree, history_size): @@ -115,7 +115,7 @@ class BroydenState(NamedTuple): error: float d_history: Any c_history: Any - gamma: jnp.ndarray + gamma: Any aux: Optional[Any] = None failed_linesearch: bool = False @@ -205,6 +205,7 @@ class Broyden(base.IterativeSolver): history_size: int = None gamma: float = 1.0 + compute_gamma: bool = True implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None @@ -244,18 +245,37 @@ def init_state(self, iter_num=init_params.state.iter_num, stepsize=init_params.state.stepsize, ) + # XXX: not computing the jacobian init approx + # when starting from an OptStep object init_params = init_params.params dtype = tree_single_dtype(init_params) + value, aux = self._value_with_aux(init_params, *args, **kwargs) else: dtype = tree_single_dtype(init_params) + value, aux = self._value_with_aux(init_params, *args, **kwargs) + if self.compute_gamma: + # we use scipy's formula: + # https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L569 + # self.alpha = 0.5*max(norm(x0), 1) / normf0 + normf0 = tree_map(jnp.linalg.norm, value) + normx0 = tree_map(jnp.linalg.norm, init_params) + clipped_normx0 = tree_map(lambda x: 0.5 * jnp.maximum(x, 1), normx0) + def safe_divide_by_zero(x, y): + # a classical division of x by y + # when y == 0 then return 1 + return jnp.where(y == 0, 1, x / y) + gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0) + else: + gamma = self.gamma + # repeat gamma as a pytree of the shape of init_params + gamma = tree_map(lambda x: jnp.array(gamma), init_params) state_kwargs = dict( d_history=init_history(init_params, self.history_size), c_history=init_history(init_params, self.history_size), - gamma=jnp.asarray(self.gamma, dtype=dtype), + gamma=gamma, iter_num=jnp.asarray(0), stepsize=jnp.asarray(self.max_stepsize, dtype=dtype), ) - value, aux = self._value_with_aux(init_params, *args, **kwargs) return BroydenState(value=value, error=jnp.asarray(jnp.inf), **state_kwargs, diff --git a/tests/broyden_test.py b/tests/broyden_test.py index 60503a55..e14196a3 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -48,7 +48,7 @@ def g(x): # Another fixed point exists for x[0] : ~1.11 return jnp.sin(x[0]) * (x[0] ** 2) - x[0], x[1] ** 3 - x[1] x0 = jnp.array([0.6, 0., -0.1]), jnp.array([[0.7], [0.5]]) tol = 1e-6 - sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, gamma=-1).run(x0) + sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit).run(x0) self.assertLess(state.error, tol) g_sol_norm = tree_l2_norm(g(sol)) self.assertLess(g_sol_norm, tol) @@ -134,7 +134,7 @@ def test_affine_contractive_mapping(self): def g(x, M, b): return M @ x + b - x tol = 1e-6 - fp = Broyden(g, maxiter=100, tol=tol, implicit_diff=True, gamma=-1) + fp = Broyden(g, maxiter=5000, tol=tol, implicit_diff=True) x0 = jnp.zeros_like(b) sol, state = fp.run(x0, M, b) self.assertLess(state.error, tol)