From abe175e59dad07bb9220d0c015f73f07d77c5673 Mon Sep 17 00:00:00 2001 From: Luke Joseph Lozenski Date: Fri, 30 Jan 2026 10:46:24 +0100 Subject: [PATCH 1/2] Added Adaptive LADMM and unit test --- .../cil/optimisation/algorithms/ADMM.py | 47 +++++++++++++++---- .../Python/test/test_algorithm_convergence.py | 41 +++++++++++++++- 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/Wrappers/Python/cil/optimisation/algorithms/ADMM.py b/Wrappers/Python/cil/optimisation/algorithms/ADMM.py index 318597d8a6..4b0a2bdba2 100644 --- a/Wrappers/Python/cil/optimisation/algorithms/ADMM.py +++ b/Wrappers/Python/cil/optimisation/algorithms/ADMM.py @@ -82,23 +82,26 @@ class LADMM(Algorithm): def __init__(self, f=None, g=None, operator=None, \ - tau = None, sigma = 1., + tau = None, sigma = 0.9, rho = 1., mode = 'constant', initial = None, **kwargs): """Initialisation of the algorithm.""" super(LADMM, self).__init__(**kwargs) - self.set_up(f = f, g = g, operator = operator, tau = tau,\ + self.set_up(f = f, g = g, operator = operator, rho = rho, mode = mode, tau = tau,\ sigma = sigma, initial=initial) - def set_up(self, f, g, operator, tau = None, sigma=1., initial=None): + def set_up(self, f, g, operator, rho, mode, tau = None, sigma=1., initial=None): """Set up of the algorithm.""" log.info("%s setting up", self.__class__.__name__) self.f = f self.g = g self.operator = operator + + self.rho = rho + self.mode = mode self.tau = tau self.sigma = sigma @@ -125,26 +128,54 @@ def set_up(self, f, g, operator, tau = None, sigma=1., initial=None): def update(self): """Performs a single iteration of the LADMM algorithm""" - self.tmp_dir += self.u + + self.tmp_dir += self.u/self.rho self.tmp_dir -= self.z self.operator.adjoint(self.tmp_dir, out = self.tmp_adj) + if self.mode == 'adaptive': + self.x0 = self.x.copy() self.x.sapyb(1, self.tmp_adj, -(self.tau/self.sigma), out=self.x) # apply proximal of f - tmp = self.f.proximal(self.x, self.tau) + tmp = self.f.proximal(self.x, self.tau/self.rho) self.operator.direct(tmp, out=self.tmp_dir) # store the result in x + self.x.fill(tmp) del tmp - self.u += self.tmp_dir + self.u += self.rho*self.tmp_dir # apply proximal of g - self.g.proximal(self.u, self.sigma, out = self.z) + if self.mode == 'adaptive': + self.z0 = self.z.copy() + self.g.proximal(self.u/self.rho, self.sigma/self.rho, out = self.z) # update - self.u -= self.z + self.u -= self.rho*self.z + if self.mode == 'adaptive': + if self.iteration %5 == 1: + num2 = 3*((self.tmp_dir - self.z).norm())**2 + + num2 += 2*(1/self.sigma - 1)*((self.z - 2*self.z0+self.zneg).norm())**2 + num2 += (1/self.tau)*((self.x - 2*self.x0+self.xneg).norm())**2 + num2 -= (self.operator.direct(self.x - 2*self.x0+self.xneg).norm())**2 + + denom2 = 2*(1/self.sigma - 1)*((self.z - self.z0).norm())**2 + denom2 += (1/self.tau)*((self.x - self.x0).norm())**2 + denom2 -= (self.operator.direct(self.x - self.x0).norm())**2 + + if num2 > 0 and denom2 > 0: + self.rho *= (num2/denom2)**(1/2) + elif num2 == 0 and denom2 > 0: + self.rho /= 10 + elif num2 > 0 and denom2 == 0: + self.rho *= 10 + + self.xneg = self.x0.copy() + self.zneg = self.z0.copy() + def update_objective(self): """Update the objective function value""" diff --git a/Wrappers/Python/test/test_algorithm_convergence.py b/Wrappers/Python/test/test_algorithm_convergence.py index 2fdd88717e..771c945008 100644 --- a/Wrappers/Python/test/test_algorithm_convergence.py +++ b/Wrappers/Python/test/test_algorithm_convergence.py @@ -1,5 +1,5 @@ -from cil.optimisation.algorithms import SPDHG, PDHG, LSQR, FISTA, APGD, GD, PD3O +from cil.optimisation.algorithms import SPDHG, PDHG, LSQR, FISTA, APGD, GD, PD3O, LADMM from cil.optimisation.functions import L2NormSquared, IndicatorBox, BlockFunction, ZeroFunction, KullbackLeibler, OperatorCompositionFunction, LeastSquares, TotalVariation, MixedL21Norm from cil.optimisation.operators import BlockOperator, IdentityOperator, MatrixOperator, GradientOperator from cil.optimisation.utilities import Sampler, BarzilaiBorweinStepSizeRule @@ -164,6 +164,45 @@ def test_FISTA_Denoising(self): rmse = (fista.get_output() - data).norm() / data.as_array().size self.assertLess(rmse, 4.2e-4) + def test_Adaptive_LADMM(self): + data = dataexample.SHAPES.get() + ig = data.geometry + ag = ig + # Create Noisy data with Gaussian noise + snr = 0.1 + std = snr*(data**2).mean()**(1/2) + noisy_data = applynoise.gaussian(data, var = std) + + alpha = 1 + K = alpha*GradientOperator(ig) + G = MixedL21Norm() + F = L2NormSquared(b = noisy_data) + num_iters = 100 + + admm_0 = LADMM(f = F, g = G, operator = K, rho = 1e0, mode = 'adaptive') + admm_0.run(num_iters) + admm_1 = LADMM(f = F, g = G, operator = K, rho = 1e1, mode = 'adaptive') + admm_1.run(num_iters) + admm_n1 = LADMM(f = F, g = G, operator = K, rho = 1e-1, mode = 'adaptive') + admm_n1.run(num_iters) + admm_2 = LADMM(f = F, g = G, operator = K, rho = 1e2, mode = 'adaptive') + admm_2.run(num_iters) + admm_n2 = LADMM(f = F, g = G, operator = K, rho = 1e-2, mode = 'adaptive') + admm_n2.run(num_iters) + + + rmse_0 = (admm_0.get_output() - data).norm() / data.norm() + rmse_1 = (admm_1.get_output() - data).norm() / data.norm() + rmse_n1 = (admm_n1.get_output() - data).norm() / data.norm() + rmse_2 = (admm_2.get_output() - data).norm() / data.norm() + rmse_n2 = (admm_n2.get_output() - data).norm() / data.norm() + + self.assertLess(rmse_0, 0.16) + self.assertLess(rmse_1, 0.16) + self.assertLess(rmse_n1, 0.16) + self.assertLess(rmse_2, 0.16) + self.assertLess(rmse_n2, 0.16) + def test_APGD(self): ig = ImageGeometry(41, 43, 47) initial = ig.allocate(0) From d6ac387f56fd44b99be1a87a284af4b9408aad34 Mon Sep 17 00:00:00 2001 From: Luke Joseph Lozenski Date: Fri, 30 Jan 2026 14:00:23 +0100 Subject: [PATCH 2/2] Fixed small math error --- Wrappers/Python/cil/optimisation/algorithms/ADMM.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Wrappers/Python/cil/optimisation/algorithms/ADMM.py b/Wrappers/Python/cil/optimisation/algorithms/ADMM.py index 4b0a2bdba2..fd5f7cb83a 100644 --- a/Wrappers/Python/cil/optimisation/algorithms/ADMM.py +++ b/Wrappers/Python/cil/optimisation/algorithms/ADMM.py @@ -162,7 +162,7 @@ def update(self): num2 += (1/self.tau)*((self.x - 2*self.x0+self.xneg).norm())**2 num2 -= (self.operator.direct(self.x - 2*self.x0+self.xneg).norm())**2 - denom2 = 2*(1/self.sigma - 1)*((self.z - self.z0).norm())**2 + denom2 = (2/self.sigma - 1)*((self.z - self.z0).norm())**2 denom2 += (1/self.tau)*((self.x - self.x0).norm())**2 denom2 -= (self.operator.direct(self.x - self.x0).norm())**2