diff --git a/Wrappers/Python/cil/optimisation/algorithms/ProxSkip.py b/Wrappers/Python/cil/optimisation/algorithms/ProxSkip.py
new file mode 100644
index 000000000..43da3b14b
--- /dev/null
+++ b/Wrappers/Python/cil/optimisation/algorithms/ProxSkip.py
@@ -0,0 +1,112 @@
+from cil.optimisation.algorithms import Algorithm
+import numpy as np
+import logging
+from warnings import warn
+
+
+class ProxSkip(Algorithm):
+
+    r"""Proximal Skip (ProxSkip) algorithm, see Mishchenko et al., "ProxSkip: Yes! Local Gradient Steps Provably Lead to Communication Acceleration! Finally!", ICML 2022.
+
+    Parameters
+    ----------
+    initial : DataContainer
+        Initial point for the ProxSkip algorithm.
+    f : Function
+        A smooth function with Lipschitz continuous gradient.
+    g : Function
+        A convex function with a "simple" proximal.
+    prob : positive :obj:`float`
+        Probability of applying the proximal step at each iteration. If :code:`prob=1`, the proximal step is applied in every iteration.
+    step_size : positive :obj:`float`
+        Step size for the ProxSkip algorithm. Typical choices are 1./L for strongly convex f and up to 2./L for convex f, where L is the Lipschitz constant of the gradient of f.
+
+    """
+
+    def __init__(self, initial, f, g, step_size, prob, seed=None, **kwargs):
+        """ Set up of the algorithm
+        """
+
+        super(ProxSkip, self).__init__(**kwargs)
+
+        self.f = f  # smooth function
+        self.g = g  # proximable function
+        self.step_size = step_size
+        self.prob = prob
+        self.rng = np.random.default_rng(seed=seed)
+        self.thetas = []  # history of the proximal-step coin flips
+
+        if self.prob <= 0 or self.prob > 1:
+            raise ValueError("prob must be a probability in (0, 1]")
+        if self.prob == 1:
+            warn("If prob=1, ProxSkip is equivalent to ISTA/PGD. Please use ISTA/PGD to avoid computing updates of the control variate that are not used.")
+
+        self.set_up(initial, f, g, step_size, prob, **kwargs)
+
+    def set_up(self, initial, f, g, step_size, prob, **kwargs):
+
+        logging.info("{} setting up".format(self.__class__.__name__, ))
+
+        # TODO: better to use different initials for x and h.
+        self.initial = initial
+        self.x = initial.copy()
+        self.xhat_new = initial.copy()
+        self.x_new = initial.copy()
+        self.ht = initial.copy()
+
+        self.configured = True
+
+        logging.info("{} configured".format(self.__class__.__name__, ))
+
+    def update(self):
+        r""" Performs a single iteration of the ProxSkip algorithm
+
+        .. math:: \hat{x}_{k+1} = x_{k} - \gamma (\nabla f(x_{k}) - h_{k})
+
+        With probability :math:`p` (:code:`prob`) the proximal step is applied,
+
+        .. math:: x_{k+1} = \text{prox}_{(\gamma/p) g}(\hat{x}_{k+1} - (\gamma/p) h_{k}), \quad h_{k+1} = h_{k} + (p/\gamma)(x_{k+1} - \hat{x}_{k+1}),
+
+        otherwise :math:`x_{k+1} = \hat{x}_{k+1}` and :math:`h_{k+1} = h_{k}`.
+        """
+
+        # gradient step with the control variate: xhat_new = x - step_size*(grad f(x) - ht)
+        self.f.gradient(self.x, out=self.xhat_new)
+        self.xhat_new -= self.ht
+        self.x.sapyb(1., self.xhat_new, -self.step_size, out=self.xhat_new)
+
+        theta = self.rng.random() < self.prob
+        # convention: use the proximal step in the first iteration
+        if self.iteration == 0:
+            theta = True
+        self.thetas.append(theta)
+
+        if theta:
+            # proximal step is used and the control variate ht is updated
+            self.g.proximal(self.xhat_new - (self.step_size/self.prob)*self.ht, self.step_size/self.prob, out=self.x_new)
+            self.ht.sapyb(1., (self.x_new - self.xhat_new), (self.prob/self.step_size), out=self.ht)
+        else:
+            # proximal step is skipped and ht is left unchanged
+            self.x_new.fill(self.xhat_new)
+
+    def _update_previous_solution(self):
+        """ Swaps the references to the current and previous solution, following :func:`~Algorithm.update_previous_solution` of the base class :class:`Algorithm`.
+        """
+        tmp = self.x_new
+        self.x_new = self.x
+        self.x = tmp
+
+    def get_output(self):
+        " Returns the current solution. "
+        return self.x
+
+    def update_objective(self):
+        """ Updates the objective
+
+        .. math:: f(x) + g(x)
+
+        """
+        fun_g = self.g(self.x)
+        fun_f = self.f(self.x)
+        p1 = fun_f + fun_g
+        self.loss.append(p1)
diff --git a/Wrappers/Python/cil/optimisation/algorithms/__init__.py b/Wrappers/Python/cil/optimisation/algorithms/__init__.py
index 0ca481bea..9237b2b91 100644
--- a/Wrappers/Python/cil/optimisation/algorithms/__init__.py
+++ b/Wrappers/Python/cil/optimisation/algorithms/__init__.py
@@ -22,10 +22,11 @@
 from .GD import GD
 from .FISTA import FISTA
 from .FISTA import ISTA
+from .ProxSkip import ProxSkip
 from .FISTA import ISTA as PGD
 from .APGD import APGD
 from .PDHG import PDHG
 from .ADMM import LADMM
 from .SPDHG import SPDHG
 from .PD3O import PD3O
-from .LSQR import LSQR
\ No newline at end of file
+from .LSQR import LSQR
diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py
index 7908debf6..34d24660d 100644
--- a/Wrappers/Python/test/test_algorithms.py
+++ b/Wrappers/Python/test/test_algorithms.py
@@ -38,7 +38,7 @@
 from cil.optimisation.functions import MixedL21Norm, BlockFunction, L1Norm, KullbackLeibler, IndicatorBox, LeastSquares, ZeroFunction, L2NormSquared, OperatorCompositionFunction, TotalVariation, SGFunction, SVRGFunction, SAGAFunction, SAGFunction, LSVRGFunction, ScaledFunction
 
-from cil.optimisation.algorithms import Algorithm, GD, CGLS, SIRT, FISTA, ISTA, SPDHG, PDHG, LADMM, PD3O, PGD, APGD , LSQR
+from cil.optimisation.algorithms import Algorithm, GD, CGLS, SIRT, FISTA, ISTA, SPDHG, PDHG, LADMM, PD3O, PGD, APGD, LSQR, ProxSkip
 
 from scipy.optimize import minimize, rosen
 
@@ -340,8 +340,7 @@ def test_provable_convergence(self):
         with self.assertRaises(NotImplementedError):
             alg.is_provably_convergent()
 
-
-
+
 class TestFISTA(CCPiTestClass):
 
@@ -533,6 +532,102 @@ def get_step_size(self, algorithm):
         self.assertEqual(alg.step_size, 0.99/2)
         self.assertEqual(alg.step_size, 0.99/2)
 
+
+class TestProxSkip(CCPiTestClass):
+
+    def setUp(self):
+
+        np.random.seed(10)
+        n = 50
+        m = 500
+
+        A = np.random.uniform(0, 1, (m, n)).astype('float32')
+        b = (A.dot(np.random.randn(n)) + 0.1 *
+             np.random.randn(m)).astype('float32')
+
+        self.Aop = MatrixOperator(A)
+        self.bop = VectorData(b)
+
+        self.f = LeastSquares(self.Aop, b=self.bop, c=0.5)
+        self.g = 0.5 * L1Norm()
+        self.step_size = 1.99/self.f.L
+
+        self.ig = self.Aop.domain
+
+        self.initial = self.ig.allocate()
+
+    def tearDown(self):
+        pass
+
+    def test_signature(self):
+
+        # check that the required arguments (initial, f, g, step_size and prob) are enforced
+        with np.testing.assert_raises(TypeError):
+            proxskip = ProxSkip(initial=self.initial, f=self.f, g=self.g, step_size=self.step_size)
+
+        # negative prob
+        with np.testing.assert_raises(ValueError):
+            proxskip = ProxSkip(initial=self.initial, f=self.f, g=self.g, step_size=self.step_size, prob=-0.1)
+
+        # zero prob
+        with np.testing.assert_raises(ValueError):
+            proxskip = ProxSkip(initial=self.initial, f=self.f, g=self.g, step_size=self.step_size, prob=0.)
+
+    def test_coin_flip(self):
+
+        seed = 10
+        num_it = 100
+        prob = 0.2
+
+        proxskip1 = ProxSkip(initial=self.initial, f=self.f, g=self.g,
+                             step_size=self.step_size, prob=prob, seed=seed)
+        proxskip1.run(num_it, verbose=0)
+
+        rng = np.random.default_rng(seed)
+
+        thetas1 = []
+        for k in range(num_it):
+            tmp = rng.random() < prob
+            theta = True if k == 0 else tmp
+            thetas1.append(theta)
+
+        assert np.array_equal(proxskip1.thetas, thetas1)
+
+    def test_seeds(self):
+
+        # same seeds
+        proxskip1 = ProxSkip(initial=self.initial, f=self.f, g=self.g, step_size=self.step_size, prob=0.1, seed=10)
+        proxskip1.run(100, verbose=0)
+
+        proxskip2 = ProxSkip(initial=self.initial, f=self.f, g=self.g, step_size=self.step_size, prob=0.1, seed=10)
+        proxskip2.run(100, verbose=0)
+
+        np.testing.assert_allclose(proxskip2.thetas, proxskip1.thetas)
+
+        # different seeds
+        proxskip2 = ProxSkip(initial=self.initial, f=self.f, g=self.g, step_size=self.step_size,
+                             prob=0.1, seed=20)
+        proxskip2.run(100, verbose=0)
+
+        assert not np.array_equal(proxskip2.thetas, proxskip1.thetas)
+
+    def test_ista_vs_proxskip(self):
+
+        prox = ProxSkip(initial=self.initial, f=self.f,
+                        g=self.g, step_size=self.step_size, prob=0.1)
+        prox.run(2000, verbose=0)
+
+        ista = ISTA(initial=self.initial, f=self.f,
+                    g=self.g, step_size=self.step_size)
+        ista.run(1000, verbose=0)
+
+        np.testing.assert_allclose(ista.objective[-1], prox.objective[-1], atol=1e-3)
+        np.testing.assert_allclose(
+            prox.solution.array, ista.solution.array, atol=1e-3)
+
+
 class testISTA(CCPiTestClass):
 
     def setUp(self):
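
For reviewers, here is a minimal usage sketch of the new algorithm. It mirrors the `TestProxSkip` fixture above (same problem sizes, step size, and `prob`), and assumes CIL is installed with this branch; all classes used are the existing CIL APIs already imported in the tests.

```python
import numpy as np

from cil.framework import VectorData
from cil.optimisation.operators import MatrixOperator
from cil.optimisation.functions import LeastSquares, L1Norm
from cil.optimisation.algorithms import ProxSkip

# Small random least-squares + L1 problem, mirroring the TestProxSkip fixture
np.random.seed(10)
A = np.random.uniform(0, 1, (500, 50)).astype('float32')
b = (A.dot(np.random.randn(50)) + 0.1 * np.random.randn(500)).astype('float32')

Aop = MatrixOperator(A)
bop = VectorData(b)

f = LeastSquares(Aop, b=bop, c=0.5)   # smooth part with Lipschitz gradient
g = 0.5 * L1Norm()                    # "simple" proximable part

# prob is the probability of applying the proximal step in a given iteration,
# so the prox of g is evaluated on roughly prob * n_iterations iterations
proxskip = ProxSkip(initial=Aop.domain.allocate(), f=f, g=g,
                    step_size=1.99 / f.L, prob=0.1, seed=10)
proxskip.run(2000, verbose=0)

print(proxskip.objective[-1])   # final value of f(x) + g(x)
print(proxskip.thetas[:10])     # which of the first iterations used the prox
```

With `prob=0.1`, `test_ista_vs_proxskip` above checks that this run matches plain ISTA to `atol=1e-3` in both objective and solution, while evaluating the proximal operator on only about a tenth of the iterations.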