diff --git a/Wrappers/Python/cil/optimisation/functions/BM3DFunction.py b/Wrappers/Python/cil/optimisation/functions/BM3DFunction.py new file mode 100644 index 0000000000..5a9dcb99b4 --- /dev/null +++ b/Wrappers/Python/cil/optimisation/functions/BM3DFunction.py @@ -0,0 +1,111 @@ +import numpy as np +from cil.optimisation.functions import Function +import warnings + +try: + from bm3d import bm3d, BM3DStages + _HAS_BM3D = True + ALL_STAGES = BM3DStages.ALL_STAGES +except ImportError: + bm3d = None + BM3DStages = None + ALL_STAGES = None + _HAS_BM3D = False + + warnings.warn( + "Optional dependency 'bm3d' is not installed. Install via `pip install bm3d.", + RuntimeWarning, + stacklevel=2, + ) + + +class BM3DFunction(Function): + + r""" + Plug-and-Play (PnP) BM3D prior. + + This class is meant to be used in proximal-gradient + schemes (PnP-ISTA / PnP-FISTA), where the proximal operator is replaced + by a BM3D denoiser: + \[ + \operatorname{prox}_{\tau g}(x) \approx D_\sigma(x), + \] + with ``sigma`` interpreted as the assumed noise standard deviation in the image domain. + + Notes + ----- + * The function value ``g(x)`` is not defined for PnP; therefore + ``__call__`` returns ``0.0``. + * Optionally enforces non-negativity by projecting the denoised output + onto ``\{x \ge 0\}``. + + Parameters + ---------- + sigma : float + BM3D noise standard deviation (same units as the image). Must be > 0. + + profile : str, default="np" + BM3D profile passed to ``bm3d`` (speed/quality trade-off). Available + profiles are ``('np', 'refilter', 'vn', 'vn_old', 'high', 'deb') + + stage_arg : BM3DStages or np.ndarray, default=BM3DStages.ALL_STAGES + Controls which BM3D stage(s) are executed, or provides a pilot image: + - ``BM3DStages.ALL_STAGES``: hard-thresholding + Wiener filtering. + - ``BM3DStages.HARD_THRESHOLDING``: hard-thresholding only. + - ``np.ndarray``: a pilot estimate of the noise-free image (used by BM3D). + + positivity : bool, default=True + If ``True``, clip the denoised image to be non-negative. + + Note + ---------- + Reference: Dabov, K. and Foi, A. and Katkovnik, V. and Egiazarian, K., 2007. Image Denoising by Sparse 3-D Transform-Domain Collaborative Filtering. IEEE Transactions on Image Processing. http://dx.doi.org/10.1109/TIP.2007.901238. + + """ + + + def __init__(self, sigma, profile="np", stage_arg=ALL_STAGES, + positivity=True): + + self.sigma = sigma + if self.sigma<=0: + raise ValueError("Need a positive value for sigma") + self.profile = profile + self.stage_arg = stage_arg + self.positivity = positivity + self._warned_call = False + + super(BM3DFunction, self).__init__(L=None) + + def __call__(self, x): + if not self._warned_call: + warnings.warn( + "BM3DFunction does not define objective value; returning 0.0.", + RuntimeWarning, + stacklevel=2, + ) + self._warned_call = True + return 0.0 + + + def _denoise(self, znp: np.ndarray) -> np.ndarray: + z = np.asarray(znp, dtype=np.float32) + return bm3d(z, sigma_psd=self.sigma, profile=self.profile, + stage_arg=self.stage_arg).astype(np.float32) + + def proximal(self, x, tau=1., out=None): + + ## TODO asarray for SIRF? + z = x.array.astype(np.float32, copy=False) + den_bm3d_np = self._denoise(z) + + + ## TODO maybe we need a more general constraint? + if self.positivity: + np.maximum(den_bm3d_np, 0.0, out=den_bm3d_np) + + if out is None: + out = x * 0.0 + out.fill(den_bm3d_np) + + return out \ No newline at end of file diff --git a/Wrappers/Python/cil/optimisation/functions/__init__.py b/Wrappers/Python/cil/optimisation/functions/__init__.py index 762c4d29ff..b78d366177 100644 --- a/Wrappers/Python/cil/optimisation/functions/__init__.py +++ b/Wrappers/Python/cil/optimisation/functions/__init__.py @@ -40,4 +40,5 @@ from .SVRGFunction import SVRGFunction, LSVRGFunction from .SAGFunction import SAGFunction, SAGAFunction from .AbsFunction import FunctionOfAbs +from .BM3DFunction import BM3DFunction diff --git a/Wrappers/Python/test/test_functions.py b/Wrappers/Python/test/test_functions.py index 61e65c03b9..09534bddd8 100644 --- a/Wrappers/Python/test/test_functions.py +++ b/Wrappers/Python/test/test_functions.py @@ -35,6 +35,7 @@ WeightedL2NormSquared, MixedL11Norm, ZeroFunction, L1Sparsity, FunctionOfAbs from cil.optimisation.functions import BlockFunction +from cil.utilities import dataexample, noise import numpy import scipy.special @@ -54,6 +55,13 @@ import numba from numbers import Number +try: + from bm3d import bm3d, BM3DStages + from cil.optimisation.functions import BM3DFunction + _HAS_BM3D = True +except Exception: + _HAS_BM3D = False + initialise_tests() if has_ccpi_regularisation: @@ -66,6 +74,7 @@ from cil.optimisation.functions.MixedL21Norm import _proximal_step_numba, _proximal_step_numpy + class TestFunction(CCPiTestClass): def test_Function(self): @@ -2226,3 +2235,29 @@ def test_convex_conjugate_not_implemented(self): self.assertEqual(self.abs_function.convex_conjugate(self.data_real32), 0.) +class TestBM3D(unittest.TestCase): + + def setUp(self): + pass + + @unittest.skipUnless(_HAS_BM3D, "Optional dependency 'bm3d'.") + def test_sigma_positive(self): + with self.assertRaises(ValueError): + BM3DFunction(sigma=0.0) + with self.assertRaises(ValueError): + BM3DFunction(sigma=-1.0) + + @unittest.skipUnless(_HAS_BM3D, "Optional dependency 'bm3d'.") + def test_proximal_(self): + + data = dataexample.SHAPES.get() + + G = BM3DFunction(sigma=0.1, positivity=False) + G_prox = G.proximal(data, tau=1.0) + G_denoise = G._denoise(data.array) + + np.testing.assert_array_almost_equal(G_denoise, G_prox.array, decimal=4) + + + + diff --git a/pyproject.toml b/pyproject.toml index 499858c6d8..aefeeb282a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ gpu = [ [dependency-groups] test = [ #"ccpi-regulariser=24.0.1", # [not osx] # missing from PyPI - "cvxpy", + "cvxpy", "bm3d", "matplotlib-base>=3.3", "packaging", "scikit-image", diff --git a/recipe/meta.yaml b/recipe/meta.yaml index ebecb34f75..7ec51ed2b0 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -30,6 +30,7 @@ test: commands: - pip install unittest-parametrize + - pip install bm3d - python -m unittest discover -v -s Wrappers/Python/test {% set ipp_version = '2021.12' %} diff --git a/scripts/requirements-test.yml b/scripts/requirements-test.yml index 496c19b3ad..c2aa3a3d9f 100644 --- a/scripts/requirements-test.yml +++ b/scripts/requirements-test.yml @@ -48,3 +48,4 @@ dependencies: - pip - pip: - unittest-parametrize + - bm3d