Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/constrained.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ The following operators are available.
Projections always have two arguments: the input to be projected and the
parameters of the convex set.

Note that a retraction is also provided, that allows to retrieve
an arbitrary point lying in the intersection of convex sets.

.. autosummary::
:toctree: _autosummary

jaxopt.projection.alternating_projections

Mirror descent
--------------

Expand Down
61 changes: 61 additions & 0 deletions jaxopt/_src/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,80 @@
from functools import partial
from typing import Any
from typing import Callable
from typing import List
from typing import Tuple

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

from jaxopt._src.fixed_point_iteration import FixedPointIteration
from jaxopt._src.bisection import Bisection
from jaxopt._src.eq_qp import EqualityConstrainedQP
from jaxopt._src.lbfgs import LBFGS
from jaxopt._src.osqp import OSQP, BoxOSQP
from jaxopt._src import tree_util


def alternating_projections(initial_guess: Any,
projections: List,
hyperparams: List,
**fixed_point_params) -> Any:
"""Alternating projections algorithm.

This algorithm returns a point in the intersection of convex sets
by projecting onto each set in turn.

If the sets are not convex, or if their intersection is empty,
this algorithm may not converge.

If the sets are convex and their intersection is non empty,
the algorithm converges to a point `p*` in the intersection of the sets.
However this point `p*` is not necessarily the closest to the initial guess,
i.e alternating_projections is not a valid projection itself.

If the inittial guess lies in the intersection of the sets, then
the algorithm converges to this point. Hence this algorithm is a retraction.
If the initial guess lies outside the intersection, and if the intersection
contains more than one point, then the algorithm converges to an arbitrary
point in the intersection.

Implicit differentiation will measure the sensitivity of `p*`
to perturbations in the `hyperparams`, but not to perturbations
in the initial guess.

Args:
projections: a sequence of projections, each of which is a function that
with signature ``x, hyperparams -> x``.
hyperparams: a list of hyperparameters for each projection, each being a
pytree.
**fixed_point_params: parameters for the fixed point solver.
Returns:
A Pytree lying in the intersection of the sets.

References:
Escalante, R. and Raydan, M., 2011. Alternating projection methods.
Society for Industrial and Applied Mathematics.
"""
assert len(projections) == len(hyperparams)

def composed_projections(x, hyperparams):
for proj, hparam in zip(projections, hyperparams):
x = proj(x, hparam)
return x

if 'maxiter' not in fixed_point_params:
fixed_point_params["maxiter"] = 100
if 'tol' not in fixed_point_params:
fixed_point_params["tol"] = 1e-5

# look for a fixed point of this operator
solver = FixedPointIteration(fixed_point_fun=composed_projections,
**fixed_point_params)
fixed_point = solver.run(initial_guess, hyperparams).params
return fixed_point


def projection_non_negative(x: Any, hyperparams=None) -> Any:
r"""Projection onto the non-negative orthant:

Expand Down
1 change: 1 addition & 0 deletions jaxopt/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.projection import alternating_projections
from jaxopt._src.projection import projection_non_negative
from jaxopt._src.projection import projection_box
from jaxopt._src.projection import projection_hypercube
Expand Down
26 changes: 26 additions & 0 deletions tests/projection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,32 @@ def test_projection_birkhoff(self):
solution1 = projection.projection_birkhoff(doubly_stochastic_matrix)
self.assertArraysAllClose(doubly_stochastic_matrix, solution1)

def test_alternating_projections(self):
# x1 + x2 = 1
x = jnp.array([-2.0, 1.0, 3.0])
a = jnp.array([ 1.0, 1.0, 0.])
b = jnp.array(1.0)

# l2 ball of radius 1.5
radius = jnp.array(1.5)

def retract_on_disk_intercept(b):
# The intersection of a ball with an hyperplane is a disk.
retract_on_disk = [projection.projection_l2_ball,
projection.projection_hyperplane]
hyper_params = [radius, (a, b)]
in_disk = projection.alternating_projections(x, retract_on_disk, hyper_params)

return in_disk

in_disk = retract_on_disk_intercept(b)
atol = 1e-5
self.assertLessEqual(jnp.linalg.norm(in_disk), radius + atol)
self.assertArraysAllClose(jnp.dot(a, in_disk), jnp.array(b), atol=atol)

# test that there is no error.
unused_jac = jax.jacrev(retract_on_disk_intercept)(b)

def test_projection_sparse_simplex(self):
def top_k(x, k):
"""Preserve the top-k entries of the vector x and put -inf values elsewhere.
Expand Down