Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chainercv/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from chainercv.transforms.bbox.resize_bbox import resize_bbox # NOQA
from chainercv.transforms.bbox.translate_bbox import translate_bbox # NOQA
from chainercv.transforms.image.center_crop import center_crop # NOQA
from chainercv.transforms.image.color_jitter import color_jitter # NOQA
from chainercv.transforms.image.flip import flip # NOQA
from chainercv.transforms.image.pca_lighting import pca_lighting # NOQA
from chainercv.transforms.image.random_crop import random_crop # NOQA
Expand Down
97 changes: 97 additions & 0 deletions chainercv/transforms/image/color_jitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import numpy as np
import random


def _grayscale(img):
out = np.zeros_like(img)
out[:] = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]
return out


def _blend(img_a, img_b, alpha):
return alpha * img_a + (1 - alpha) * img_b


def _brightness(img, var):
alpha = 1 + np.random.uniform(-var, var)
return _blend(img, np.zeros_like(img), alpha), alpha


def _contrast(img, var):
gray = _grayscale(img)
gray.fill(gray[0].mean())

alpha = 1 + np.random.uniform(-var, var)
return _blend(img, gray, alpha), alpha


def _saturation(img, var):
gray = _grayscale(img)

alpha = 1 + np.random.uniform(-var, var)
return _blend(img, gray, alpha), alpha


def color_jitter(img, brightness_var=0.4, contrast_var=0.4,
saturation_var=0.4, return_param=False):
"""Data augmentation on brightness, contrast and saturation.

Args:
img (~numpy.ndarray): An image array to be augmented. This is in
CHW and RGB format.
brightness_var (float): Alpha for brightness is sampled from
:obj:`unif(-brightness_var, brightness_var)`. The default
value is 0.4.
contrast_var (float): Alpha for contrast is sampled from
:obj:`unif(-contrast_var, contrast_var)`. The default
value is 0.4.
saturation_var (float): Alpha for contrast is sampled from
:obj:`unif(-saturation_var, saturation_var)`. The default
value is 0.4.
return_param (bool): Returns parameters if :obj:`True`.

Returns:
~numpy.ndarray or (~numpy.ndarray, dict):

If :obj:`return_param = False`,
returns an color jittered image.

If :obj:`return_param = True`, returns a tuple of an array and a
dictionary :obj:`param`.
:obj:`param` is a dictionary of intermediate parameters whose
contents are listed below with key, value-type and the description
of the value.

* **order** (*list of strings*): List containing three strings: \
:obj:`'brightness'`, :obj:`'contrast'` and :obj:`'saturation'`. \
They are ordered according to the order in which the data \
augmentation functions are applied.
* **brightness_alpha** (*float*): Alpha used for brightness \
data augmentation.
* **contrast_alpha** (*float*): Alpha used for contrast \
data augmentation.
* **saturation_alpha** (*float*): Alpha used for saturation \
data augmentation.

"""
funcs = list()
if brightness_var > 0:
funcs.append(('brightness', lambda x: _brightness(x, brightness_var)))
if contrast_var > 0:
funcs.append(('contrast', lambda x: _contrast(x, contrast_var)))
if saturation_var > 0:
funcs.append(('saturation', lambda x: _saturation(x, saturation_var)))
random.shuffle(funcs)

params = {'order': [key for key, val in funcs],
'brightness_alpha': 1,
'contrast_alpha': 1,
'saturation_alpha': 1}
for key, func in funcs:
img, alpha = func(img)
params[key + '_alpha'] = alpha
img = np.minimum(np.maximum(img, 0), 255)
if return_param:
return img, params
else:
return img
4 changes: 4 additions & 0 deletions docs/source/reference/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ center_crop
~~~~~~~~~~~
.. autofunction:: center_crop

color_jitter
~~~~~~~~~~~~
.. autofunction:: color_jitter

flip
~~~~
.. autofunction:: flip
Expand Down
36 changes: 36 additions & 0 deletions tests/transforms_tests/image_tests/test_color_jitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

import numpy as np

from chainer import testing
from chainercv.transforms import color_jitter


class TestColorJitter(unittest.TestCase):

def test_color_jitter_run_data_augmentation(self):
img = 255 * np.random.uniform(size=(3, 48, 32)).astype(np.float32)

out, param = color_jitter(img, return_param=True)
self.assertEqual(out.shape, (3, 48, 32))
self.assertEqual(out.dtype, img.dtype)
self.assertLessEqual(np.max(img), 255)
self.assertGreaterEqual(np.min(img), 0)

self.assertEqual(
sorted(param['order']), ['brightness', 'contrast', 'saturation'])
self.assertIsInstance(param['brightness_alpha'], float)
self.assertIsInstance(param['contrast_alpha'], float)
self.assertIsInstance(param['saturation_alpha'], float)

def test_color_jitter_no_data_augmentation(self):
img = 255 * np.random.uniform(size=(3, 48, 32)).astype(np.float32)

out, param = color_jitter(img, 0, 0, 0, return_param=True)
np.testing.assert_equal(out, img)
self.assertEqual(param['brightness_alpha'], 1)
self.assertEqual(param['contrast_alpha'], 1)
self.assertEqual(param['saturation_alpha'], 1)


testing.run_module(__name__, __file__)