diff --git a/Wrappers/Python/cil/optimisation/algorithms/ADMM.py b/Wrappers/Python/cil/optimisation/algorithms/ADMM.py
index 318597d8a..fd5f7cb83 100644
--- a/Wrappers/Python/cil/optimisation/algorithms/ADMM.py
+++ b/Wrappers/Python/cil/optimisation/algorithms/ADMM.py
@@ -82,23 +82,26 @@ class LADMM(Algorithm):
 
     def __init__(self, f=None, g=None, operator=None, \
-                 tau = None, sigma = 1.,
+                 tau = None, sigma = 0.9, rho = 1., mode = 'constant',
                  initial = None, **kwargs):
         """Initialisation of the algorithm."""
         super(LADMM, self).__init__(**kwargs)
 
-        self.set_up(f = f, g = g, operator = operator, tau = tau,\
+        self.set_up(f = f, g = g, operator = operator, rho = rho, mode = mode, tau = tau,\
             sigma = sigma, initial=initial)
 
-    def set_up(self, f, g, operator, tau = None, sigma=1., initial=None):
+    def set_up(self, f, g, operator, rho, mode, tau = None, sigma=1., initial=None):
         """Set up of the algorithm."""
         log.info("%s setting up", self.__class__.__name__)
 
         self.f = f
         self.g = g
         self.operator = operator
+
+        self.rho = rho
+        self.mode = mode
 
         self.tau = tau
         self.sigma = sigma
@@ -125,26 +128,54 @@ def set_up(self, f, g, operator, tau = None, sigma=1., initial=None):
     def update(self):
         """Performs a single iteration of the LADMM algorithm"""
 
-        self.tmp_dir += self.u
+
+        self.tmp_dir += self.u/self.rho
         self.tmp_dir -= self.z
         self.operator.adjoint(self.tmp_dir, out = self.tmp_adj)
 
+        if self.mode == 'adaptive':
+            self.x0 = self.x.copy()
         self.x.sapyb(1, self.tmp_adj, -(self.tau/self.sigma), out=self.x)
 
         # apply proximal of f
-        tmp = self.f.proximal(self.x, self.tau)
+        tmp = self.f.proximal(self.x, self.tau/self.rho)
         self.operator.direct(tmp, out=self.tmp_dir)
 
         # store the result in x
+        self.x.fill(tmp)
         del tmp
 
-        self.u += self.tmp_dir
+        self.u += self.rho*self.tmp_dir
 
         # apply proximal of g
-        self.g.proximal(self.u, self.sigma, out = self.z)
+        if self.mode == 'adaptive':
+            self.z0 = self.z.copy()
+        self.g.proximal(self.u/self.rho, self.sigma/self.rho, out = self.z)
 
         # update
-        self.u -= self.z
+        self.u -= self.rho*self.z
 
+        if self.mode == 'adaptive':
+            if self.iteration % 5 == 1:
+                num2 = 3*((self.tmp_dir - self.z).norm())**2
+
+                num2 += 2*(1/self.sigma - 1)*((self.z - 2*self.z0 + self.zneg).norm())**2
+                num2 += (1/self.tau)*((self.x - 2*self.x0 + self.xneg).norm())**2
+                num2 -= (self.operator.direct(self.x - 2*self.x0 + self.xneg).norm())**2
+
+                denom2 = (2/self.sigma - 1)*((self.z - self.z0).norm())**2
+                denom2 += (1/self.tau)*((self.x - self.x0).norm())**2
+                denom2 -= (self.operator.direct(self.x - self.x0).norm())**2
+
+                if num2 > 0 and denom2 > 0:
+                    self.rho *= (num2/denom2)**(1/2)
+                elif num2 == 0 and denom2 > 0:
+                    self.rho /= 10
+                elif num2 > 0 and denom2 == 0:
+                    self.rho *= 10
+
+            self.xneg = self.x0.copy()
+            self.zneg = self.z0.copy()
+
     def update_objective(self):
         """Update the objective function value"""
diff --git a/Wrappers/Python/test/test_algorithm_convergence.py b/Wrappers/Python/test/test_algorithm_convergence.py
index 2fdd88717..771c94500 100644
--- a/Wrappers/Python/test/test_algorithm_convergence.py
+++ b/Wrappers/Python/test/test_algorithm_convergence.py
@@ -1,5 +1,5 @@
-from cil.optimisation.algorithms import SPDHG, PDHG, LSQR, FISTA, APGD, GD, PD3O
+from cil.optimisation.algorithms import SPDHG, PDHG, LSQR, FISTA, APGD, GD, PD3O, LADMM
 from cil.optimisation.functions import L2NormSquared, IndicatorBox, BlockFunction, ZeroFunction, KullbackLeibler, OperatorCompositionFunction, LeastSquares, TotalVariation, MixedL21Norm
 from cil.optimisation.operators import BlockOperator, IdentityOperator, MatrixOperator, GradientOperator
 from cil.optimisation.utilities import Sampler, BarzilaiBorweinStepSizeRule
@@ -164,6 +164,45 @@ def test_FISTA_Denoising(self):
         rmse = (fista.get_output() - data).norm() / data.as_array().size
         self.assertLess(rmse, 4.2e-4)
 
+    def test_Adaptive_LADMM(self):
+        data = dataexample.SHAPES.get()
+        ig = data.geometry
+        ag = ig
+        # create noisy data with Gaussian noise
+        snr = 0.1
+        std = snr*(data**2).mean()**(1/2)
+        noisy_data = applynoise.gaussian(data, var = std)
+
+        alpha = 1
+        K = alpha*GradientOperator(ig)
+        G = MixedL21Norm()
+        F = L2NormSquared(b = noisy_data)
+        num_iters = 100
+
+        admm_0 = LADMM(f = F, g = G, operator = K, rho = 1e0, mode = 'adaptive')
+        admm_0.run(num_iters)
+        admm_1 = LADMM(f = F, g = G, operator = K, rho = 1e1, mode = 'adaptive')
+        admm_1.run(num_iters)
+        admm_n1 = LADMM(f = F, g = G, operator = K, rho = 1e-1, mode = 'adaptive')
+        admm_n1.run(num_iters)
+        admm_2 = LADMM(f = F, g = G, operator = K, rho = 1e2, mode = 'adaptive')
+        admm_2.run(num_iters)
+        admm_n2 = LADMM(f = F, g = G, operator = K, rho = 1e-2, mode = 'adaptive')
+        admm_n2.run(num_iters)
+
+
+        rmse_0 = (admm_0.get_output() - data).norm() / data.norm()
+        rmse_1 = (admm_1.get_output() - data).norm() / data.norm()
+        rmse_n1 = (admm_n1.get_output() - data).norm() / data.norm()
+        rmse_2 = (admm_2.get_output() - data).norm() / data.norm()
+        rmse_n2 = (admm_n2.get_output() - data).norm() / data.norm()
+
+        self.assertLess(rmse_0, 0.16)
+        self.assertLess(rmse_1, 0.16)
+        self.assertLess(rmse_n1, 0.16)
+        self.assertLess(rmse_2, 0.16)
+        self.assertLess(rmse_n2, 0.16)
+
     def test_APGD(self):
         ig = ImageGeometry(41, 43, 47)
         initial = ig.allocate(0)
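
Note on the adaptive scheme (a reading of the update() code above, not wording taken from the patch): writing K for the operator, (x, z) for the current iterates, (x0, z0) for the previous ones and (xneg, zneg) for the ones before that, every fifth iteration the penalty is rescaled as

    rho <- rho * sqrt(num2 / denom2)

where

    num2   = 3*||K x - z||^2
           + 2*(1/sigma - 1)*||z - 2 x0... z0 + zneg||^2 form, i.e.
             2*(1/sigma - 1)*||z - 2 z0 + zneg||^2
           + (1/tau)*||x - 2 x0 + xneg||^2
           - ||K (x - 2 x0 + xneg)||^2

    denom2 = (2/sigma - 1)*||z - z0||^2
           + (1/tau)*||x - x0||^2
           - ||K (x - x0)||^2

When exactly one of the two quantities is zero, rho is instead moved by a fixed factor of 10 in the corresponding direction; when both vanish, or when either is negative (possible, because of the subtracted K-terms), rho is left unchanged.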
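
A minimal usage sketch of the new interface, condensed from test_Adaptive_LADMM above. The problem (TV denoising of the SHAPES phantom), the noise model and the parameter values are the test's own; the import paths for dataexample and applynoise sit outside the hunk shown and are assumed here, so treat them as a guess at the usual CIL layout rather than part of this patch:

    from cil.optimisation.algorithms import LADMM
    from cil.optimisation.functions import L2NormSquared, MixedL21Norm
    from cil.optimisation.operators import GradientOperator
    from cil.utilities import dataexample
    from cil.utilities import noise as applynoise  # assumed location, as in the test module

    # ground-truth phantom and a noisy observation of it (noise model as in the test)
    data = dataexample.SHAPES.get()
    std = 0.1 * (data ** 2).mean() ** 0.5
    noisy_data = applynoise.gaussian(data, var=std)

    # min_x f(x) + g(Kx): quadratic data fit f, TV regulariser as MixedL21Norm
    # composed with the gradient operator K
    K = GradientOperator(data.geometry)
    F = L2NormSquared(b=noisy_data)
    G = MixedL21Norm()

    # mode='adaptive' rescales rho during the run, so the reconstruction should be
    # largely insensitive to the initial rho (the test sweeps rho from 1e-2 to 1e2)
    admm = LADMM(f=F, g=G, operator=K, rho=1.0, mode='adaptive')
    admm.run(100)
    recon = admm.get_output()

The test asserts a relative RMSE below 0.16 for every starting rho in that four-decade sweep, which is presumably the point of the adaptive mode: with mode='constant' the result would depend on choosing rho well for the problem at hand.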