Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions chainercv/utils/image/read_image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import division

import pathlib

import chainer
import numpy as np
from PIL import Image
Expand Down Expand Up @@ -87,6 +89,24 @@ def _read_image_pil(file, dtype, color, alpha):
return img.transpose((2, 0, 1))


def _determine_backend():
if chainer.config.cv_read_image_backend is None:
if _cv2_available:
return 'cv2'
else:
return 'PIL'
elif chainer.config.cv_read_image_backend == 'cv2':
if not _cv2_available:
raise ValueError('cv2 is not installed even though '
'chainer.config.cv_read_image_backend == \'cv2\'')
return 'cv2'
elif chainer.config.cv_read_image_backend == 'PIL':
return 'PIL'
else:
raise ValueError('chainer.config.cv_read_image_backend should be '
'either "cv2" or "PIL".')


def read_image(file, dtype=np.float32, color=True, alpha=None):
"""Read an image from a file.

Expand All @@ -101,7 +121,7 @@ def read_image(file, dtype=np.float32, color=True, alpha=None):
and "PIL" is used when "cv2" is not installed.

Args:
file (string or file-like object): A path of image file or
file (string or pathlib.Path or file-like object): A path of image file or
a file-like object of image.
dtype: The type of array. The default value is :obj:`~numpy.float32`.
color (bool): This option determines the number of channels.
Expand All @@ -121,18 +141,12 @@ def read_image(file, dtype=np.float32, color=True, alpha=None):
Returns:
~numpy.ndarray: An image.
"""
if chainer.config.cv_read_image_backend is None:
if _cv2_available:
return _read_image_cv2(file, dtype, color, alpha)
else:
return _read_image_pil(file, dtype, color, alpha)
elif chainer.config.cv_read_image_backend == 'cv2':
if not _cv2_available:
raise ValueError('cv2 is not installed even though '
'chainer.config.cv_read_image_backend == \'cv2\'')
backend = _determine_backend()
assert backend in ['cv2', 'PIL'], 'Encountered unexpected behavior'
if backend == 'cv2':
# opencv does not support pathlib.Path
if isinstance(file, pathlib.Path):
file = str(file)
return _read_image_cv2(file, dtype, color, alpha)
elif chainer.config.cv_read_image_backend == 'PIL':
elif backend == 'PIL':
return _read_image_pil(file, dtype, color, alpha)
else:
raise ValueError('chainer.config.cv_read_image_backend should be '
'either "cv2" or "PIL".')
14 changes: 9 additions & 5 deletions tests/utils_tests/image_tests/test_read_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pathlib

import numpy as np
import tempfile
import unittest
Expand Down Expand Up @@ -28,7 +30,7 @@ def _write_rgba_image(rgba, file, format):

def _create_parameters():
params = testing.product({
'file_obj': [False, True],
'file_type': ['str', 'pathlib.Path', 'obj'],
'size': [(48, 32)],
'dtype': [np.float32, np.uint8, bool]})
no_color_params = testing.product({
Expand Down Expand Up @@ -56,7 +58,7 @@ def _create_parameters():
class TestReadImage(unittest.TestCase):

def setUp(self):
if self.file_obj:
if self.file_type == 'obj':
self.f = tempfile.TemporaryFile()
self.file = self.f
format = self.format
Expand All @@ -67,6 +69,8 @@ def setUp(self):
suffix = '.' + self.format
self.f = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
self.file = self.f.name
if self.file_type == 'pathlib.Path':
self.file = pathlib.Path(self.file)
format = None

if self.alpha is None:
Expand All @@ -82,7 +86,7 @@ def setUp(self):
0, 255, size=(4,) + self.size, dtype=np.uint8)
_write_rgba_image(self.img, self.file, format=format)

if self.file_obj:
if self.file_type == 'obj':
self.file.seek(0)

def test_read_image_as_color(self):
Expand Down Expand Up @@ -134,7 +138,7 @@ def test_read_image_raise_error_with_cv2(self):
class TestReadImageDifferentBackends(unittest.TestCase):

def setUp(self):
if self.file_obj:
if self.file_type == 'obj':
self.f = tempfile.TemporaryFile()
self.file = self.f
format = self.format
Expand All @@ -160,7 +164,7 @@ def setUp(self):
0, 255, size=(4,) + self.size, dtype=np.uint8)
_write_rgba_image(self.img, self.file, format=format)

if self.file_obj:
if self.file_type == 'obj':
self.file.seek(0)

@unittest.skipUnless(_cv2_available, 'cv2 is not installed')
Expand Down