diff --git a/chainercv/utils/image/read_image.py b/chainercv/utils/image/read_image.py index b9c643742b..b60ab06430 100644 --- a/chainercv/utils/image/read_image.py +++ b/chainercv/utils/image/read_image.py @@ -1,5 +1,7 @@ from __future__ import division +import pathlib + import chainer import numpy as np from PIL import Image @@ -87,6 +89,24 @@ def _read_image_pil(file, dtype, color, alpha): return img.transpose((2, 0, 1)) +def _determine_backend(): + if chainer.config.cv_read_image_backend is None: + if _cv2_available: + return 'cv2' + else: + return 'PIL' + elif chainer.config.cv_read_image_backend == 'cv2': + if not _cv2_available: + raise ValueError('cv2 is not installed even though ' + 'chainer.config.cv_read_image_backend == \'cv2\'') + return 'cv2' + elif chainer.config.cv_read_image_backend == 'PIL': + return 'PIL' + else: + raise ValueError('chainer.config.cv_read_image_backend should be ' + 'either "cv2" or "PIL".') + + def read_image(file, dtype=np.float32, color=True, alpha=None): """Read an image from a file. @@ -101,7 +121,7 @@ def read_image(file, dtype=np.float32, color=True, alpha=None): and "PIL" is used when "cv2" is not installed. Args: - file (string or file-like object): A path of image file or + file (string or pathlib.Path or file-like object): A path of image file or a file-like object of image. dtype: The type of array. The default value is :obj:`~numpy.float32`. color (bool): This option determines the number of channels. @@ -121,18 +141,12 @@ def read_image(file, dtype=np.float32, color=True, alpha=None): Returns: ~numpy.ndarray: An image. """ - if chainer.config.cv_read_image_backend is None: - if _cv2_available: - return _read_image_cv2(file, dtype, color, alpha) - else: - return _read_image_pil(file, dtype, color, alpha) - elif chainer.config.cv_read_image_backend == 'cv2': - if not _cv2_available: - raise ValueError('cv2 is not installed even though ' - 'chainer.config.cv_read_image_backend == \'cv2\'') + backend = _determine_backend() + assert backend in ['cv2', 'PIL'], 'Encountered unexpected behavior' + if backend == 'cv2': + # opencv does not support pathlib.Path + if isinstance(file, pathlib.Path): + file = str(file) return _read_image_cv2(file, dtype, color, alpha) - elif chainer.config.cv_read_image_backend == 'PIL': + elif backend == 'PIL': return _read_image_pil(file, dtype, color, alpha) - else: - raise ValueError('chainer.config.cv_read_image_backend should be ' - 'either "cv2" or "PIL".') diff --git a/tests/utils_tests/image_tests/test_read_image.py b/tests/utils_tests/image_tests/test_read_image.py index c1e7352a85..4bdaa82aa1 100644 --- a/tests/utils_tests/image_tests/test_read_image.py +++ b/tests/utils_tests/image_tests/test_read_image.py @@ -1,3 +1,5 @@ +import pathlib + import numpy as np import tempfile import unittest @@ -28,7 +30,7 @@ def _write_rgba_image(rgba, file, format): def _create_parameters(): params = testing.product({ - 'file_obj': [False, True], + 'file_type': ['str', 'pathlib.Path', 'obj'], 'size': [(48, 32)], 'dtype': [np.float32, np.uint8, bool]}) no_color_params = testing.product({ @@ -56,7 +58,7 @@ def _create_parameters(): class TestReadImage(unittest.TestCase): def setUp(self): - if self.file_obj: + if self.file_type == 'obj': self.f = tempfile.TemporaryFile() self.file = self.f format = self.format @@ -67,6 +69,8 @@ def setUp(self): suffix = '.' + self.format self.f = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) self.file = self.f.name + if self.file_type == 'pathlib.Path': + self.file = pathlib.Path(self.file) format = None if self.alpha is None: @@ -82,7 +86,7 @@ def setUp(self): 0, 255, size=(4,) + self.size, dtype=np.uint8) _write_rgba_image(self.img, self.file, format=format) - if self.file_obj: + if self.file_type == 'obj': self.file.seek(0) def test_read_image_as_color(self): @@ -134,7 +138,7 @@ def test_read_image_raise_error_with_cv2(self): class TestReadImageDifferentBackends(unittest.TestCase): def setUp(self): - if self.file_obj: + if self.file_type == 'obj': self.f = tempfile.TemporaryFile() self.file = self.f format = self.format @@ -160,7 +164,7 @@ def setUp(self): 0, 255, size=(4,) + self.size, dtype=np.uint8) _write_rgba_image(self.img, self.file, format=format) - if self.file_obj: + if self.file_type == 'obj': self.file.seek(0) @unittest.skipUnless(_cv2_available, 'cv2 is not installed')