# chainercv/transforms/image/color_jitter.py
#
# This module is exported from chainercv/transforms/__init__.py with
#     from chainercv.transforms.image.color_jitter import color_jitter  # NOQA
# and listed in docs/source/reference/transforms.rst (between
# ``center_crop`` and ``flip``) with
#     color_jitter
#     ~~~~~~~~~~~~
#     .. autofunction:: color_jitter
import numpy as np
import random


def _grayscale(img):
    """Return a gray version of a CHW RGB image.

    Every channel of the result holds the same weighted sum of the R, G and
    B planes of :obj:`img`. The result has the shape and dtype of
    :obj:`img`.
    """
    out = np.zeros_like(img)
    # The 2D weighted sum is broadcast over all three channels.
    out[:] = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]
    return out


def _blend(img_a, img_b, alpha):
    """Return ``alpha * img_a + (1 - alpha) * img_b``."""
    return alpha * img_a + (1 - alpha) * img_b


def _brightness(img, var):
    """Blend :obj:`img` with a black image.

    Returns a tuple of the blended image and the alpha that was used, where
    alpha is ``1 + unif(-var, var)``.
    """
    alpha = 1 + np.random.uniform(-var, var)
    return _blend(img, np.zeros_like(img), alpha), alpha


def _contrast(img, var):
    """Blend :obj:`img` with a constant image of its mean gray level.

    Returns a tuple of the blended image and the alpha that was used, where
    alpha is ``1 + unif(-var, var)``.
    """
    gray = _grayscale(img)
    # All channels of ``gray`` are identical, so the mean of the first
    # channel is the mean gray level of the whole image.
    gray.fill(gray[0].mean())

    alpha = 1 + np.random.uniform(-var, var)
    return _blend(img, gray, alpha), alpha


def _saturation(img, var):
    """Blend :obj:`img` with its gray version.

    Returns a tuple of the blended image and the alpha that was used, where
    alpha is ``1 + unif(-var, var)``.
    """
    gray = _grayscale(img)

    alpha = 1 + np.random.uniform(-var, var)
    return _blend(img, gray, alpha), alpha


def color_jitter(img, brightness_var=0.4, contrast_var=0.4,
                 saturation_var=0.4, return_param=False):
    """Data augmentation on brightness, contrast and saturation.

    Each enabled augmentation blends the image with a reference image
    (black, mean gray level and gray version respectively) using a randomly
    sampled alpha. The enabled augmentations are applied in a random order
    and the result is clipped to :math:`[0, 255]`.

    An augmentation is enabled only when its :obj:`*_var` argument is
    greater than zero.

    Args:
        img (~numpy.ndarray): An image array to be augmented. This is in
            CHW and RGB format.
        brightness_var (float): Alpha for brightness is sampled from
            :obj:`1 + unif(-brightness_var, brightness_var)`. The default
            value is 0.4.
        contrast_var (float): Alpha for contrast is sampled from
            :obj:`1 + unif(-contrast_var, contrast_var)`. The default
            value is 0.4.
        saturation_var (float): Alpha for saturation is sampled from
            :obj:`1 + unif(-saturation_var, saturation_var)`. The default
            value is 0.4.
        return_param (bool): Returns parameters if :obj:`True`.

    Returns:
        ~numpy.ndarray or (~numpy.ndarray, dict):

        If :obj:`return_param = False`,
        returns a color jittered image.

        If :obj:`return_param = True`, returns a tuple of an array and a
        dictionary :obj:`param`.
        :obj:`param` is a dictionary of intermediate parameters whose
        contents are listed below with key, value-type and the description
        of the value.

        * **order** (*list of strings*): List containing the names of the
          augmentations that were applied, out of :obj:`'brightness'`,
          :obj:`'contrast'` and :obj:`'saturation'`. They are ordered
          according to the order in which the data augmentation functions
          are applied. Disabled augmentations are not listed.
        * **brightness_alpha** (*float*): Alpha used for brightness
          data augmentation. This is 1 if the augmentation is disabled.
        * **contrast_alpha** (*float*): Alpha used for contrast
          data augmentation. This is 1 if the augmentation is disabled.
        * **saturation_alpha** (*float*): Alpha used for saturation
          data augmentation. This is 1 if the augmentation is disabled.

    """
    candidates = (
        ('brightness', _brightness, brightness_var),
        ('contrast', _contrast, contrast_var),
        ('saturation', _saturation, saturation_var),
    )
    funcs = [(key, func, var) for key, func, var in candidates if var > 0]
    # The order comes from the stdlib ``random`` module while the alphas
    # come from ``numpy.random``; both need seeding for reproducibility.
    random.shuffle(funcs)

    # Alphas default to float 1.0 (an identity blend) so that the value
    # type is the same whether or not an augmentation is enabled.
    params = {'order': [key for key, _, _ in funcs],
              'brightness_alpha': 1.0,
              'contrast_alpha': 1.0,
              'saturation_alpha': 1.0}
    for key, func, var in funcs:
        img, alpha = func(img, var)
        params[key + '_alpha'] = alpha
    img = np.clip(img, 0, 255)
    if return_param:
        return img, params
    else:
        return img
# tests/transforms_tests/image_tests/test_color_jitter.py
import unittest

import numpy as np

from chainer import testing
from chainercv.transforms import color_jitter


class TestColorJitter(unittest.TestCase):
    """Tests for :func:`chainercv.transforms.color_jitter`."""

    def test_color_jitter_run_data_augmentation(self):
        """With default arguments all three augmentations are applied."""
        img = 255 * np.random.uniform(size=(3, 48, 32)).astype(np.float32)

        out, param = color_jitter(img, return_param=True)
        self.assertEqual(out.shape, (3, 48, 32))
        self.assertEqual(out.dtype, img.dtype)
        # The range must be checked on the output: the input is within
        # [0, 255] by construction, so checking it would verify nothing
        # about the clipping done by color_jitter.
        self.assertLessEqual(np.max(out), 255)
        self.assertGreaterEqual(np.min(out), 0)

        self.assertEqual(
            sorted(param['order']), ['brightness', 'contrast', 'saturation'])
        self.assertIsInstance(param['brightness_alpha'], float)
        self.assertIsInstance(param['contrast_alpha'], float)
        self.assertIsInstance(param['saturation_alpha'], float)

    def test_color_jitter_no_data_augmentation(self):
        """With all variances set to zero the image is returned unchanged."""
        img = 255 * np.random.uniform(size=(3, 48, 32)).astype(np.float32)

        out, param = color_jitter(img, 0, 0, 0, return_param=True)
        np.testing.assert_equal(out, img)
        self.assertEqual(param['brightness_alpha'], 1)
        self.assertEqual(param['contrast_alpha'], 1)
        self.assertEqual(param['saturation_alpha'], 1)


testing.run_module(__name__, __file__)