From 4bb4f110db510657e810c12f9f4e9c4a4ec5b548 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 10 Sep 2023 18:29:49 +0200 Subject: [PATCH 1/6] added the computed initialization of the jac approx in Broyden --- jaxopt/_src/broyden.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index dc7fe79b..de911960 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -80,14 +80,13 @@ def inv_jacobian_product(pytree: Any, Leaves contain v variables, i.e., `(x[k] - x[k-1])^T B / ((g[k] - g[k-1])^T (x[k] - x[k-1])^T B)`. c_history: pytree with the same structure as `pytree`. Leaves contain u variables, i.e., `(x[k] - x[k-1]) - B(g[k] - g[k-1])`. - gamma: scalar to use for the initial inverse jacobian approximation, + gamma: pytree with scalars to use for the initial inverse jacobian approximation, i.e., `gamma * I`. start: starting index in the circular buffer. """ fun = partial(inv_jacobian_product_leaf, - gamma=gamma, start=start) - return tree_map(fun, pytree, d_history, c_history) + return tree_map(fun, pytree, d_history, c_history, gamma) def inv_jacobian_rproduct(pytree: Any, d_history: Any, @@ -115,7 +114,7 @@ class BroydenState(NamedTuple): error: float d_history: Any c_history: Any - gamma: jnp.ndarray + gamma: Any aux: Optional[Any] = None failed_linesearch: bool = False @@ -205,6 +204,7 @@ class Broyden(base.IterativeSolver): history_size: int = None gamma: float = 1.0 + compute_gamma: bool = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None @@ -244,24 +244,44 @@ def init_state(self, iter_num=init_params.state.iter_num, stepsize=init_params.state.stepsize, ) + # XXX: not computing the jacobian init approx + # when starting from an OptStep object init_params = init_params.params dtype = tree_single_dtype(init_params) + value, aux = self._value_with_aux(init_params, *args, **kwargs) else: dtype = tree_single_dtype(init_params) + value, aux = self._value_with_aux(init_params, *args, **kwargs) + if self.compute_gamma: + # we use scipy's formula: + # https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L569 + # self.alpha = 0.5*max(norm(x0), 1) / normf0 + normf0 = tree_map(jnp.linalg.norm, value) + normx0 = tree_map(jnp.linalg.norm, init_params) + clipped_normx0 = tree_map(lambda x: 0.5 * jnp.max(x, 1), normx0) + def safe_divide_by_zero(x, y): + # a classical division of x by x + # when y == 0 then return 1 + return jnp.where(y == 0, 1, x / y) + gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0) + return gamma + else: + gamma = self.gamma + # repeat gamma as a pytre of the shape of init_params + gamma = tree_map(lambda x: jnp.array(gamma), init_params) state_kwargs = dict( d_history=init_history(init_params, self.history_size), c_history=init_history(init_params, self.history_size), - gamma=jnp.asarray(self.gamma, dtype=dtype), + gamma=gamma, iter_num=jnp.asarray(0), stepsize=jnp.asarray(self.max_stepsize, dtype=dtype), ) - value, aux = self._value_with_aux(init_params, *args, **kwargs) return BroydenState(value=value, error=jnp.asarray(jnp.inf), **state_kwargs, aux=aux, failed_linesearch=jnp.asarray(False), - num_fun_eval=jnp.array(1, base.NUM_EVAL_DTYPE), + num_fun_eval=jnp.array(1, base.NUM_EVAL_DTYPE), num_linesearch_iter=jnp.array(0, base.NUM_EVAL_DTYPE) ) From c6b20e698dd10cfa2c0ba2b78cd6beae84e09058 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 10 Sep 2023 18:33:41 +0200 Subject: [PATCH 2/6] corrected right product gamma --- jaxopt/_src/broyden.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index de911960..c9877451 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -93,7 +93,8 @@ def inv_jacobian_rproduct(pytree: Any, c_history: Any, gamma: float = 1.0, start: int = 0): - return inv_jacobian_product(pytree, c_history, d_history, jnp.conjugate(gamma), start) + gamma_conj = tree_map(jnp.conjugate, gamma) + return inv_jacobian_product(pytree, c_history, d_history, gamma_conj, start) def init_history(pytree, history_size): From e2d4ec687f8d32b9f7e8f7af39a17584243d5dd0 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 10 Sep 2023 18:45:45 +0200 Subject: [PATCH 3/6] few corrections to gamma computation --- jaxopt/_src/broyden.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index c9877451..dbaaa597 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -205,7 +205,7 @@ class Broyden(base.IterativeSolver): history_size: int = None gamma: float = 1.0 - compute_gamma: bool = False + compute_gamma: bool = True implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None @@ -259,13 +259,12 @@ def init_state(self, # self.alpha = 0.5*max(norm(x0), 1) / normf0 normf0 = tree_map(jnp.linalg.norm, value) normx0 = tree_map(jnp.linalg.norm, init_params) - clipped_normx0 = tree_map(lambda x: 0.5 * jnp.max(x, 1), normx0) + clipped_normx0 = tree_map(lambda x: - 0.5 * jnp.maximum(x, 1), normx0) def safe_divide_by_zero(x, y): # a classical division of x by x # when y == 0 then return 1 return jnp.where(y == 0, 1, x / y) gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0) - return gamma else: gamma = self.gamma # repeat gamma as a pytre of the shape of init_params From 44646a307c2a8b18bf538146821a2b99eb8ca7e7 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 10 Sep 2023 18:46:06 +0200 Subject: [PATCH 4/6] increased tolerance for a broyden test --- tests/broyden_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/broyden_test.py b/tests/broyden_test.py index 60503a55..5e24b6bd 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -133,7 +133,7 @@ def test_affine_contractive_mapping(self): b = jax.random.uniform(subkey, shape=(n,)) def g(x, M, b): return M @ x + b - x - tol = 1e-6 + tol = 5e-6 fp = Broyden(g, maxiter=100, tol=tol, implicit_diff=True, gamma=-1) x0 = jnp.zeros_like(b) sol, state = fp.run(x0, M, b) From 23529d5dfd11e7d25ea60e8c85e5b5a7d89a9743 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 3 Dec 2023 16:40:32 +0100 Subject: [PATCH 5/6] typos correction + test correction => needed more iterations --- jaxopt/_src/broyden.py | 6 +++--- tests/broyden_test.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index c6d42258..53267b21 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -259,15 +259,15 @@ def init_state(self, # self.alpha = 0.5*max(norm(x0), 1) / normf0 normf0 = tree_map(jnp.linalg.norm, value) normx0 = tree_map(jnp.linalg.norm, init_params) - clipped_normx0 = tree_map(lambda x: - 0.5 * jnp.maximum(x, 1), normx0) + clipped_normx0 = tree_map(lambda x: 0.5 * jnp.maximum(x, 1), normx0) def safe_divide_by_zero(x, y): - # a classical division of x by x + # a classical division of x by y # when y == 0 then return 1 return jnp.where(y == 0, 1, x / y) gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0) else: gamma = self.gamma - # repeat gamma as a pytre of the shape of init_params + # repeat gamma as a pytree of the shape of init_params gamma = tree_map(lambda x: jnp.array(gamma), init_params) state_kwargs = dict( d_history=init_history(init_params, self.history_size), diff --git a/tests/broyden_test.py b/tests/broyden_test.py index 5e24b6bd..ac42bd5e 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -133,8 +133,8 @@ def test_affine_contractive_mapping(self): b = jax.random.uniform(subkey, shape=(n,)) def g(x, M, b): return M @ x + b - x - tol = 5e-6 - fp = Broyden(g, maxiter=100, tol=tol, implicit_diff=True, gamma=-1) + tol = 1e-6 + fp = Broyden(g, maxiter=5000, tol=tol, implicit_diff=True) x0 = jnp.zeros_like(b) sol, state = fp.run(x0, M, b) self.assertLess(state.error, tol) From 6b3672239e441b451941c1f71835d957c4e6a699 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 3 Dec 2023 16:54:50 +0100 Subject: [PATCH 6/6] removed gamma setting in broyden test, since smart init is available --- tests/broyden_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/broyden_test.py b/tests/broyden_test.py index ac42bd5e..e14196a3 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -48,7 +48,7 @@ def g(x): # Another fixed point exists for x[0] : ~1.11 return jnp.sin(x[0]) * (x[0] ** 2) - x[0], x[1] ** 3 - x[1] x0 = jnp.array([0.6, 0., -0.1]), jnp.array([[0.7], [0.5]]) tol = 1e-6 - sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, gamma=-1).run(x0) + sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit).run(x0) self.assertLess(state.error, tol) g_sol_norm = tree_l2_norm(g(sol)) self.assertLess(g_sol_norm, tol)