Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions Wrappers/Python/cil/optimisation/functions/BM3DFunction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import numpy as np
from cil.optimisation.functions import Function
import warnings

try:
from bm3d import bm3d, BM3DStages
_HAS_BM3D = True
ALL_STAGES = BM3DStages.ALL_STAGES
except ImportError:
bm3d = None
BM3DStages = None
ALL_STAGES = None
_HAS_BM3D = False

warnings.warn(
"Optional dependency 'bm3d' is not installed. Install via `pip install bm3d.",
RuntimeWarning,
stacklevel=2,
)


class BM3DFunction(Function):

r"""
Plug-and-Play (PnP) BM3D prior.

This class is meant to be used in proximal-gradient
schemes (PnP-ISTA / PnP-FISTA), where the proximal operator is replaced
by a BM3D denoiser:
\[
\operatorname{prox}_{\tau g}(x) \approx D_\sigma(x),
\]
with ``sigma`` interpreted as the assumed noise standard deviation in the image domain.

Notes
-----
* The function value ``g(x)`` is not defined for PnP; therefore
``__call__`` returns ``0.0``.
* Optionally enforces non-negativity by projecting the denoised output
onto ``\{x \ge 0\}``.

Parameters
----------
sigma : float
BM3D noise standard deviation (same units as the image). Must be > 0.

profile : str, default="np"
BM3D profile passed to ``bm3d`` (speed/quality trade-off). Available
profiles are ``('np', 'refilter', 'vn', 'vn_old', 'high', 'deb')

stage_arg : BM3DStages or np.ndarray, default=BM3DStages.ALL_STAGES
Controls which BM3D stage(s) are executed, or provides a pilot image:
- ``BM3DStages.ALL_STAGES``: hard-thresholding + Wiener filtering.
- ``BM3DStages.HARD_THRESHOLDING``: hard-thresholding only.
- ``np.ndarray``: a pilot estimate of the noise-free image (used by BM3D).

positivity : bool, default=True
If ``True``, clip the denoised image to be non-negative.

Note
----------
Reference: Dabov, K. and Foi, A. and Katkovnik, V. and Egiazarian, K., 2007. Image Denoising by Sparse 3-D Transform-Domain Collaborative Filtering. IEEE Transactions on Image Processing. http://dx.doi.org/10.1109/TIP.2007.901238.

"""


def __init__(self, sigma, profile="np", stage_arg=ALL_STAGES,
positivity=True):

self.sigma = sigma
if self.sigma<=0:
raise ValueError("Need a positive value for sigma")
self.profile = profile
self.stage_arg = stage_arg
self.positivity = positivity
self._warned_call = False

super(BM3DFunction, self).__init__(L=None)

def __call__(self, x):
if not self._warned_call:
warnings.warn(
"BM3DFunction does not define objective value; returning 0.0.",
RuntimeWarning,
stacklevel=2,
)
self._warned_call = True
return 0.0


def _denoise(self, znp: np.ndarray) -> np.ndarray:
z = np.asarray(znp, dtype=np.float32)
return bm3d(z, sigma_psd=self.sigma, profile=self.profile,
stage_arg=self.stage_arg).astype(np.float32)

def proximal(self, x, tau=1., out=None):

## TODO asarray for SIRF?
z = x.array.astype(np.float32, copy=False)
den_bm3d_np = self._denoise(z)


## TODO maybe we need a more general constraint?
if self.positivity:
np.maximum(den_bm3d_np, 0.0, out=den_bm3d_np)

if out is None:
out = x * 0.0
out.fill(den_bm3d_np)

return out
1 change: 1 addition & 0 deletions Wrappers/Python/cil/optimisation/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@
from .SVRGFunction import SVRGFunction, LSVRGFunction
from .SAGFunction import SAGFunction, SAGAFunction
from .AbsFunction import FunctionOfAbs
from .BM3DFunction import BM3DFunction

35 changes: 35 additions & 0 deletions Wrappers/Python/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, L1Sparsity, FunctionOfAbs

from cil.optimisation.functions import BlockFunction
from cil.utilities import dataexample, noise

import numpy
import scipy.special
Expand All @@ -54,6 +55,13 @@
import numba
from numbers import Number

try:
from bm3d import bm3d, BM3DStages
from cil.optimisation.functions import BM3DFunction
_HAS_BM3D = True
except Exception:
_HAS_BM3D = False

initialise_tests()

if has_ccpi_regularisation:
Expand All @@ -66,6 +74,7 @@
from cil.optimisation.functions.MixedL21Norm import _proximal_step_numba, _proximal_step_numpy



class TestFunction(CCPiTestClass):

def test_Function(self):
Expand Down Expand Up @@ -2226,3 +2235,29 @@ def test_convex_conjugate_not_implemented(self):
self.assertEqual(self.abs_function.convex_conjugate(self.data_real32), 0.)


class TestBM3D(unittest.TestCase):

def setUp(self):
pass

@unittest.skipUnless(_HAS_BM3D, "Optional dependency 'bm3d'.")
def test_sigma_positive(self):
with self.assertRaises(ValueError):
BM3DFunction(sigma=0.0)
with self.assertRaises(ValueError):
BM3DFunction(sigma=-1.0)

@unittest.skipUnless(_HAS_BM3D, "Optional dependency 'bm3d'.")
def test_proximal_(self):

data = dataexample.SHAPES.get()

G = BM3DFunction(sigma=0.1, positivity=False)
G_prox = G.proximal(data, tau=1.0)
G_denoise = G._denoise(data.array)

np.testing.assert_array_almost_equal(G_denoise, G_prox.array, decimal=4)




2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ gpu = [
[dependency-groups]
test = [
#"ccpi-regulariser=24.0.1", # [not osx] # missing from PyPI
"cvxpy",
"cvxpy", "bm3d",
"matplotlib-base>=3.3",
"packaging",
"scikit-image",
Expand Down
1 change: 1 addition & 0 deletions recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ test:

commands:
- pip install unittest-parametrize
- pip install bm3d
- python -m unittest discover -v -s Wrappers/Python/test

{% set ipp_version = '2021.12' %}
Expand Down
1 change: 1 addition & 0 deletions scripts/requirements-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ dependencies:
- pip
- pip:
- unittest-parametrize
- bm3d
Loading