From bcec7cc1b3ac7d28793a51ea547e097ada4f113a Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Tue, 12 Dec 2017 00:29:13 +0900 Subject: [PATCH 1/9] Add FLIC dataset --- chainercv/datasets/__init__.py | 2 + chainercv/datasets/flic/__init__.py | 0 .../datasets/flic/flic_keypoint_dataset.py | 128 ++++++++++++++++++ chainercv/datasets/flic/flic_utils.py | 68 ++++++++++ 4 files changed, 198 insertions(+) create mode 100644 chainercv/datasets/flic/__init__.py create mode 100644 chainercv/datasets/flic/flic_keypoint_dataset.py create mode 100644 chainercv/datasets/flic/flic_utils.py diff --git a/chainercv/datasets/__init__.py b/chainercv/datasets/__init__.py index fc0a0a271f..32d35ae741 100644 --- a/chainercv/datasets/__init__.py +++ b/chainercv/datasets/__init__.py @@ -15,6 +15,8 @@ from chainercv.datasets.cub.cub_utils import cub_label_names # NOQA from chainercv.datasets.directory_parsing_label_dataset import directory_parsing_label_names # NOQA from chainercv.datasets.directory_parsing_label_dataset import DirectoryParsingLabelDataset # NOQA +from chainercv.datasets.flic.flic_keypoint_dataset import FLICKeypointDataset # NOQA +from chainercv.datasets.flic.flic_utils import flic_joint_label_names # NOQA from chainercv.datasets.online_products.online_products_dataset import online_products_super_label_names # NOQA from chainercv.datasets.online_products.online_products_dataset import OnlineProductsDataset # NOQA from chainercv.datasets.transform_dataset import TransformDataset # NOQA diff --git a/chainercv/datasets/flic/__init__.py b/chainercv/datasets/flic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/chainercv/datasets/flic/flic_keypoint_dataset.py b/chainercv/datasets/flic/flic_keypoint_dataset.py new file mode 100644 index 0000000000..c63521581a --- /dev/null +++ b/chainercv/datasets/flic/flic_keypoint_dataset.py @@ -0,0 +1,128 @@ +import collections +import glob +import os + +import numpy as np + +import chainer +from chainercv import utils +from chainercv.datasets.flic import flic_utils + +try: + from scipy.io import loadmat + _scipy_available = True +except (ImportError, TypeError): + _scipy_available = False + + +class FLICKeypointDataset(chainer.dataset.DatasetMixin): + + """`Frames Labaled in Cinema (FLIC)`_ dataset with annotated keypoints. + + .. _`Frames Labaled in Cinema (FLIC)`: + https://bensapp.github.io/flic-dataset.html + + An index corresponds to each image. + + When queried by an index, this dataset returns the corresponding + :obj:`img, keypoint`, which is a tuple of an image and keypoints + that indicates visible keypoints in the image. + The data type of the two elements are :obj:`float32, float32`. + + The keypoints are packed into a two dimensional array of shape + :math:`(K, 2)`, where :math:`K` is the number of keypoints. + Note that :math:`K=29` in FLIC dataset. Also note that not all + keypoints are visible in an image. When a keypoint is not visible, + the values stored for that keypoint are :obj:`~numpy.nan`. The second axis + corresponds to the :math:`y` and :math:`x` coordinates of the + keypoints in the image. + + The torso bounding box is a one-dimensional array of shape :math:`(4,)`. + The elements of the bounding box corresponds to + :math:`(y_{min}, x_{min}, y_{max}, x_{max})`, where the four attributes are + coordinates of the top left and the bottom right vertices. + This information can optionally be retrieved from the dataset + by setting :obj:`return_torsobox = True`. + + Args: + data_dir (string): Path to the root of the training data. If this is + :obj:`auto`, this class will automatically download data for you + under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`. + split ({'train', 'test'}): Select from dataset splits used in + the FLIC dataset. + return_torsobox (bool): If :obj:`True`, this returns a bounding box + around the torso. The default value is :obj:`False`. + + """ + + def __init__(self, data_dir='auto', split='train', return_torsobox=False): + super(FLICKeypointDataset, self).__init__() + if split not in ['train', 'test']: + raise ValueError( + '\'split\' argment should be eighter \'train\' or \'test\'.') + + if not _scipy_available: + raise ImportError( + 'scipy is needed to extract labales from the .mat file.' + 'Please install scipy:\n\n' + '\t$pip install scipy\n\n') + + if data_dir == 'auto': + data_dir = flic_utils.get_flic() + + img_paths = {os.path.basename(fn): fn for fn in glob.glob( + os.path.join(data_dir, 'FLIC-full', 'images', '*.jpg'))} + + label_keys = [ + 'poselet_hit_idx', + 'moviename', + 'coords', + 'filepath', + 'imgdims', + 'currframe', + 'torsobox', + 'istrain', + 'istest', + 'isbad', + 'isunchecked', + ] + labels = loadmat(os.path.join(data_dir, 'FLIC-full', 'examples.mat')) + + self.img_paths = [] + self.keypoints = [] + self.torsoboxes = [] + self.return_torsobox = return_torsobox + + for label in labels['examples'][0]: + label = {label_keys[i]: val for i, val in enumerate(label)} + if int(label['isbad']) == 1 or int(label['isunchecked']) == 1: + continue + if ((split == 'train' and int(label['istrain']) == 0) + or (split == 'test' and int(label['istest']) == 0)): + continue + + self.img_paths.append(img_paths[label['filepath'][0]]) + self.keypoints.append(label['coords'].T[:, ::-1]) + if return_torsobox: + self.torsoboxes.append(label['torsobox'][0, [1, 0, 3, 2]]) + + def get_example(self, i): + """Returns the i-th example. + + Args: + i (int): The index of the example. + + Returns: + tuple of an image and keypoints. + The image is in CHW format and its color channel is ordered in + RGB. + If :obj:`return_torsobox = True`, + a bounding box is appended to the returned value. + + """ + img = utils.read_image(self.img_paths[i]) + keypoint = np.array(self.keypoints[i], dtype=np.float32) + if self.return_torsobox: + return img, keypoint, self.torsoboxes[i] + else: + return img, keypoint diff --git a/chainercv/datasets/flic/flic_utils.py b/chainercv/datasets/flic/flic_utils.py new file mode 100644 index 0000000000..8d62edd5dc --- /dev/null +++ b/chainercv/datasets/flic/flic_utils.py @@ -0,0 +1,68 @@ +import os +import shutil + +from chainer.dataset import download +from chainercv import utils + +root = 'pfnet/chainercv/flic' + +urls = [ + 'http://vision.grasp.upenn.edu/video/FLIC-full.zip', + + 'http://cims.nyu.edu/~tompson/data/tr_plus_indices.mat', + +] + +flic_joint_label_names = [ + 'lsho', + 'lelb', + 'lwri', + 'rsho', + 'relb', + 'rwri', + 'lhip', + 'lkne', + 'lank', + 'rhip', + 'rkne', + 'rank', + 'leye', + 'reye', + 'lear', + 'rear', + 'nose', + 'msho', + 'mhip', + 'mear', + 'mtorso', + 'mluarm', + 'mruarm', + 'mllarm', + 'mrlarm', + 'mluleg', + 'mruleg', + 'mllleg', + 'mrlleg' +] + + +def get_flic(): + data_root = download.get_dataset_directory(root) + + if not os.path.exists(os.path.join(data_root, + 'FLIC-full')): + download_file_path = utils.cached_download(urls[0]) + ext = os.path.splitext(urls[0])[1] + utils.extractall(download_file_path, + data_root, + ext) + + if not os.path.exists(os.path.join(data_root, + 'tr_plus_indices.mat')): + download_file_path = utils.cached_download(urls[1]) + shutil.copy(download_file_path, + + os.path.join(data_root, + 'tr_plus_indices.mat')) + + return data_root From c72a0ac600aefe899148ae8c3076f54f96c0241f Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Tue, 12 Dec 2017 00:33:14 +0900 Subject: [PATCH 2/9] Add __len__ --- chainercv/datasets/flic/flic_keypoint_dataset.py | 3 +++ chainercv/datasets/flic/flic_utils.py | 8 -------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/chainercv/datasets/flic/flic_keypoint_dataset.py b/chainercv/datasets/flic/flic_keypoint_dataset.py index c63521581a..0a0ff59599 100644 --- a/chainercv/datasets/flic/flic_keypoint_dataset.py +++ b/chainercv/datasets/flic/flic_keypoint_dataset.py @@ -106,6 +106,9 @@ def __init__(self, data_dir='auto', split='train', return_torsobox=False): if return_torsobox: self.torsoboxes.append(label['torsobox'][0, [1, 0, 3, 2]]) + def __len__(self): + return len(self.img_paths) + def get_example(self, i): """Returns the i-th example. diff --git a/chainercv/datasets/flic/flic_utils.py b/chainercv/datasets/flic/flic_utils.py index 8d62edd5dc..fbce1c50c0 100644 --- a/chainercv/datasets/flic/flic_utils.py +++ b/chainercv/datasets/flic/flic_utils.py @@ -57,12 +57,4 @@ def get_flic(): data_root, ext) - if not os.path.exists(os.path.join(data_root, - 'tr_plus_indices.mat')): - download_file_path = utils.cached_download(urls[1]) - shutil.copy(download_file_path, - - os.path.join(data_root, - 'tr_plus_indices.mat')) - return data_root From 0a9d46bccdf66a81b8e9786194f064e767595aa5 Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Tue, 12 Dec 2017 00:44:20 +0900 Subject: [PATCH 3/9] Add skip_bad and skip_unchecked arguments --- chainercv/datasets/flic/flic_keypoint_dataset.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/chainercv/datasets/flic/flic_keypoint_dataset.py b/chainercv/datasets/flic/flic_keypoint_dataset.py index 0a0ff59599..ba9018dd89 100644 --- a/chainercv/datasets/flic/flic_keypoint_dataset.py +++ b/chainercv/datasets/flic/flic_keypoint_dataset.py @@ -52,10 +52,15 @@ class FLICKeypointDataset(chainer.dataset.DatasetMixin): the FLIC dataset. return_torsobox (bool): If :obj:`True`, this returns a bounding box around the torso. The default value is :obj:`False`. + skip_bad (bool): If :obj:`True`, the data which have :obj:`isbad = 1` + will be ignored. The default is :obj:`True`. + skip_unchecked (bool): If :obj:`True`, the data which have + :obj:`isunchecked = 1` will be ignored. The default is :obj:`True`. """ - def __init__(self, data_dir='auto', split='train', return_torsobox=False): + def __init__(self, data_dir='auto', split='train', return_torsobox=False, + skip_bad=True, skip_unchecked=True): super(FLICKeypointDataset, self).__init__() if split not in ['train', 'test']: raise ValueError( @@ -95,7 +100,9 @@ def __init__(self, data_dir='auto', split='train', return_torsobox=False): for label in labels['examples'][0]: label = {label_keys[i]: val for i, val in enumerate(label)} - if int(label['isbad']) == 1 or int(label['isunchecked']) == 1: + if skip_bad and int(label['isbad']) == 1: + continue + if skip_unchecked and int(label['isunchecked']) == 1: continue if ((split == 'train' and int(label['istrain']) == 0) or (split == 'test' and int(label['istest']) == 0)): From 6921c7290cea5676b262b92f70640be1c28d67fa Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Tue, 12 Dec 2017 22:39:31 +0900 Subject: [PATCH 4/9] Update flic_utils.py --- chainercv/datasets/flic/flic_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chainercv/datasets/flic/flic_utils.py b/chainercv/datasets/flic/flic_utils.py index fbce1c50c0..0c863066f1 100644 --- a/chainercv/datasets/flic/flic_utils.py +++ b/chainercv/datasets/flic/flic_utils.py @@ -1,5 +1,4 @@ import os -import shutil from chainer.dataset import download from chainercv import utils From aa6d40d6125714aa0e06744f941aaa7d4e6008cf Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Wed, 13 Dec 2017 11:32:15 +0900 Subject: [PATCH 5/9] Fix styles --- .../datasets/flic/flic_keypoint_dataset.py | 9 +-- chainercv/datasets/imagenet/__init__.py | 0 .../imagenet/imagenet_label_dataset.py | 73 +++++++++++++++++++ chainercv/datasets/imagenet/imagenet_utils.py | 22 ++++++ chainercv/datasets/lsp/__init__.py | 0 .../datasets/lsp/lsp_keypoint_dataset.py | 0 6 files changed, 99 insertions(+), 5 deletions(-) create mode 100644 chainercv/datasets/imagenet/__init__.py create mode 100644 chainercv/datasets/imagenet/imagenet_label_dataset.py create mode 100644 chainercv/datasets/imagenet/imagenet_utils.py create mode 100644 chainercv/datasets/lsp/__init__.py create mode 100644 chainercv/datasets/lsp/lsp_keypoint_dataset.py diff --git a/chainercv/datasets/flic/flic_keypoint_dataset.py b/chainercv/datasets/flic/flic_keypoint_dataset.py index ba9018dd89..f2be12a531 100644 --- a/chainercv/datasets/flic/flic_keypoint_dataset.py +++ b/chainercv/datasets/flic/flic_keypoint_dataset.py @@ -1,12 +1,11 @@ -import collections import glob import os import numpy as np import chainer -from chainercv import utils from chainercv.datasets.flic import flic_utils +from chainercv import utils try: from scipy.io import loadmat @@ -93,9 +92,9 @@ def __init__(self, data_dir='auto', split='train', return_torsobox=False, ] labels = loadmat(os.path.join(data_dir, 'FLIC-full', 'examples.mat')) - self.img_paths = [] - self.keypoints = [] - self.torsoboxes = [] + self.img_paths = list() + self.keypoints = list() + self.torsoboxes = list() self.return_torsobox = return_torsobox for label in labels['examples'][0]: diff --git a/chainercv/datasets/imagenet/__init__.py b/chainercv/datasets/imagenet/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/chainercv/datasets/imagenet/imagenet_label_dataset.py b/chainercv/datasets/imagenet/imagenet_label_dataset.py new file mode 100644 index 0000000000..bbc75ffb06 --- /dev/null +++ b/chainercv/datasets/imagenet/imagenet_label_dataset.py @@ -0,0 +1,73 @@ +import numpy as np +import os + +import chainer +from chainercv import utils +from chainercv.datasets.imagenet import imagenet_utils + + +class ImageNetLabelDataset(chainer.dataset.DatasetMixin): + + """`ImageNet`_ dataset with annotated class labels. + + .. _`ImageNet`: + http://image-net.org/challenges/LSVRC/2015/download-images-3j16.php + + When queried by an index, this dataset returns a corresponding + :obj:`img, label`, a tuple of an image and class id. + The image is in RGB and CHW format. + The class id is between 0 and 999. + + Args: + data_dir (string): Path to the root of the training data. If this is + :obj:`auto`, this class will automatically download data for you + under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/imagenet`. + + """ + + def __init__(self, data_dir='auto'): + super(ImageNetLabelDataset, self).__init__() + if data_dir == 'auto': + + + image_class_labels_file = os.path.join( + self.data_dir, 'image_class_labels.txt') + labels = [int(d_label.split()[1]) - 1 for + d_label in open(image_class_labels_file)] + self._labels = np.array(labels, dtype=np.int32) + + def get_example(self, i): + """Returns the i-th example. + + Args: + i (int): The index of the example. + + Returns: + tuple of an image and its label. + The image is in CHW format and its color channel is ordered in + RGB. + If :obj:`return_bb = True`, + a bounding box is appended to the returned value. + If :obj:`return_mask = True`, + a probability map is appended to the returned value. + + """ + img = utils.read_image( + os.path.join(self.data_dir, 'images', self.paths[i]), + color=True) + label = self._labels[i] + + if not self.return_prob_map: + if self.return_bb: + return img, label, self.bbs[i] + else: + return img, label + + prob_map = utils.read_image(self.prob_map_paths[i], + dtype=np.uint8, color=False) + prob_map = prob_map.astype(np.float32) / 255 # [0, 255] -> [0, 1] + prob_map = prob_map[0] # (1, H, W) --> (H, W) + if self.return_bb: + return img, label, self.bbs[i], prob_map + else: + return img, label, prob_map diff --git a/chainercv/datasets/imagenet/imagenet_utils.py b/chainercv/datasets/imagenet/imagenet_utils.py new file mode 100644 index 0000000000..cdd54ed44a --- /dev/null +++ b/chainercv/datasets/imagenet/imagenet_utils.py @@ -0,0 +1,22 @@ +from chainer.dataset import download + +from chainercv import utils + + +urls = [ + 'cls_loc': 'http://image-net.org/image/ILSVRC2015/' + 'ILSVRC2015_CLS-LOC.tar.gz', + 'det': 'http://image-net.org/image/ILSVRC2015/ILSVRC2015_DET.tar.gz', + 'det_test': 'http://image-net.org/image/ILSVRC2015/' + 'ILSVRC2015_DET_test.tar.gz', + 'det_test_new': 'http://image-net.org/image/ILSVRC2015/' + 'ILSVRC2015_DET_test_new.tar.gz' + 'info': 'http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz' +] + +def get_imagenet(): + download_file_path = utils.cached_download(url) + ext = os.path.splitext(url)[1] + utils.extractall(download_file_path, data_root, ext) + return base_path + diff --git a/chainercv/datasets/lsp/__init__.py b/chainercv/datasets/lsp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/chainercv/datasets/lsp/lsp_keypoint_dataset.py b/chainercv/datasets/lsp/lsp_keypoint_dataset.py new file mode 100644 index 0000000000..e69de29bb2 From 708cc1d18fa3fde99979c40b5204d34ce54172c1 Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Wed, 13 Dec 2017 12:09:13 +0900 Subject: [PATCH 6/9] Remove unrelated files --- chainercv/datasets/imagenet/__init__.py | 0 .../imagenet/imagenet_label_dataset.py | 73 ------------------- chainercv/datasets/imagenet/imagenet_utils.py | 22 ------ 3 files changed, 95 deletions(-) delete mode 100644 chainercv/datasets/imagenet/__init__.py delete mode 100644 chainercv/datasets/imagenet/imagenet_label_dataset.py delete mode 100644 chainercv/datasets/imagenet/imagenet_utils.py diff --git a/chainercv/datasets/imagenet/__init__.py b/chainercv/datasets/imagenet/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/chainercv/datasets/imagenet/imagenet_label_dataset.py b/chainercv/datasets/imagenet/imagenet_label_dataset.py deleted file mode 100644 index bbc75ffb06..0000000000 --- a/chainercv/datasets/imagenet/imagenet_label_dataset.py +++ /dev/null @@ -1,73 +0,0 @@ -import numpy as np -import os - -import chainer -from chainercv import utils -from chainercv.datasets.imagenet import imagenet_utils - - -class ImageNetLabelDataset(chainer.dataset.DatasetMixin): - - """`ImageNet`_ dataset with annotated class labels. - - .. _`ImageNet`: - http://image-net.org/challenges/LSVRC/2015/download-images-3j16.php - - When queried by an index, this dataset returns a corresponding - :obj:`img, label`, a tuple of an image and class id. - The image is in RGB and CHW format. - The class id is between 0 and 999. - - Args: - data_dir (string): Path to the root of the training data. If this is - :obj:`auto`, this class will automatically download data for you - under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/imagenet`. - - """ - - def __init__(self, data_dir='auto'): - super(ImageNetLabelDataset, self).__init__() - if data_dir == 'auto': - - - image_class_labels_file = os.path.join( - self.data_dir, 'image_class_labels.txt') - labels = [int(d_label.split()[1]) - 1 for - d_label in open(image_class_labels_file)] - self._labels = np.array(labels, dtype=np.int32) - - def get_example(self, i): - """Returns the i-th example. - - Args: - i (int): The index of the example. - - Returns: - tuple of an image and its label. - The image is in CHW format and its color channel is ordered in - RGB. - If :obj:`return_bb = True`, - a bounding box is appended to the returned value. - If :obj:`return_mask = True`, - a probability map is appended to the returned value. - - """ - img = utils.read_image( - os.path.join(self.data_dir, 'images', self.paths[i]), - color=True) - label = self._labels[i] - - if not self.return_prob_map: - if self.return_bb: - return img, label, self.bbs[i] - else: - return img, label - - prob_map = utils.read_image(self.prob_map_paths[i], - dtype=np.uint8, color=False) - prob_map = prob_map.astype(np.float32) / 255 # [0, 255] -> [0, 1] - prob_map = prob_map[0] # (1, H, W) --> (H, W) - if self.return_bb: - return img, label, self.bbs[i], prob_map - else: - return img, label, prob_map diff --git a/chainercv/datasets/imagenet/imagenet_utils.py b/chainercv/datasets/imagenet/imagenet_utils.py deleted file mode 100644 index cdd54ed44a..0000000000 --- a/chainercv/datasets/imagenet/imagenet_utils.py +++ /dev/null @@ -1,22 +0,0 @@ -from chainer.dataset import download - -from chainercv import utils - - -urls = [ - 'cls_loc': 'http://image-net.org/image/ILSVRC2015/' - 'ILSVRC2015_CLS-LOC.tar.gz', - 'det': 'http://image-net.org/image/ILSVRC2015/ILSVRC2015_DET.tar.gz', - 'det_test': 'http://image-net.org/image/ILSVRC2015/' - 'ILSVRC2015_DET_test.tar.gz', - 'det_test_new': 'http://image-net.org/image/ILSVRC2015/' - 'ILSVRC2015_DET_test_new.tar.gz' - 'info': 'http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz' -] - -def get_imagenet(): - download_file_path = utils.cached_download(url) - ext = os.path.splitext(url)[1] - utils.extractall(download_file_path, data_root, ext) - return base_path - From 562fafd92130775211a0b672f61780fa87c22c64 Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Thu, 28 Dec 2017 00:41:35 +0900 Subject: [PATCH 7/9] Remove lsp files --- chainercv/datasets/lsp/__init__.py | 0 chainercv/datasets/lsp/lsp_keypoint_dataset.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 chainercv/datasets/lsp/__init__.py delete mode 100644 chainercv/datasets/lsp/lsp_keypoint_dataset.py diff --git a/chainercv/datasets/lsp/__init__.py b/chainercv/datasets/lsp/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/chainercv/datasets/lsp/lsp_keypoint_dataset.py b/chainercv/datasets/lsp/lsp_keypoint_dataset.py deleted file mode 100644 index e69de29bb2..0000000000 From ce8dbcd72e5b9cd65b4ed03b20275f139626f4f0 Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Thu, 28 Dec 2017 00:50:35 +0900 Subject: [PATCH 8/9] Reflect reviews --- chainercv/datasets/__init__.py | 2 +- .../datasets/flic/flic_keypoint_dataset.py | 27 ++++++++++--------- chainercv/datasets/flic/flic_utils.py | 24 ++++++----------- 3 files changed, 23 insertions(+), 30 deletions(-) diff --git a/chainercv/datasets/__init__.py b/chainercv/datasets/__init__.py index 32d35ae741..1193802bb4 100644 --- a/chainercv/datasets/__init__.py +++ b/chainercv/datasets/__init__.py @@ -16,7 +16,7 @@ from chainercv.datasets.directory_parsing_label_dataset import directory_parsing_label_names # NOQA from chainercv.datasets.directory_parsing_label_dataset import DirectoryParsingLabelDataset # NOQA from chainercv.datasets.flic.flic_keypoint_dataset import FLICKeypointDataset # NOQA -from chainercv.datasets.flic.flic_utils import flic_joint_label_names # NOQA +from chainercv.datasets.flic.flic_utils import flic_joint_names # NOQA from chainercv.datasets.online_products.online_products_dataset import online_products_super_label_names # NOQA from chainercv.datasets.online_products.online_products_dataset import OnlineProductsDataset # NOQA from chainercv.datasets.transform_dataset import TransformDataset # NOQA diff --git a/chainercv/datasets/flic/flic_keypoint_dataset.py b/chainercv/datasets/flic/flic_keypoint_dataset.py index f2be12a531..2d41e3f0e8 100644 --- a/chainercv/datasets/flic/flic_keypoint_dataset.py +++ b/chainercv/datasets/flic/flic_keypoint_dataset.py @@ -46,20 +46,21 @@ class FLICKeypointDataset(chainer.dataset.DatasetMixin): Args: data_dir (string): Path to the root of the training data. If this is :obj:`auto`, this class will automatically download data for you - under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`. + under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/flic/FLIC-full`. split ({'train', 'test'}): Select from dataset splits used in the FLIC dataset. return_torsobox (bool): If :obj:`True`, this returns a bounding box around the torso. The default value is :obj:`False`. - skip_bad (bool): If :obj:`True`, the data which have :obj:`isbad = 1` - will be ignored. The default is :obj:`True`. - skip_unchecked (bool): If :obj:`True`, the data which have - :obj:`isunchecked = 1` will be ignored. The default is :obj:`True`. + use_bad (bool): If :obj:`False`, the data which have :obj:`isbad = 1` + will be ignored. The default is :obj:`False`. + use_unchecked (bool): If :obj:`False`, the data which have + :obj:`isunchecked = 1` will be ignored. The default is + :obj:`False`. """ def __init__(self, data_dir='auto', split='train', return_torsobox=False, - skip_bad=True, skip_unchecked=True): + use_bad=False, use_unchecked=False): super(FLICKeypointDataset, self).__init__() if split not in ['train', 'test']: raise ValueError( @@ -75,9 +76,9 @@ def __init__(self, data_dir='auto', split='train', return_torsobox=False, data_dir = flic_utils.get_flic() img_paths = {os.path.basename(fn): fn for fn in glob.glob( - os.path.join(data_dir, 'FLIC-full', 'images', '*.jpg'))} + os.path.join(data_dir, 'images', '*.jpg'))} - label_keys = [ + label_annos = [ 'poselet_hit_idx', 'moviename', 'coords', @@ -90,18 +91,18 @@ def __init__(self, data_dir='auto', split='train', return_torsobox=False, 'isbad', 'isunchecked', ] - labels = loadmat(os.path.join(data_dir, 'FLIC-full', 'examples.mat')) + annos = loadmat(os.path.join(data_dir, 'examples.mat')) self.img_paths = list() self.keypoints = list() self.torsoboxes = list() self.return_torsobox = return_torsobox - for label in labels['examples'][0]: - label = {label_keys[i]: val for i, val in enumerate(label)} - if skip_bad and int(label['isbad']) == 1: + for label in annos['examples'][0]: + label = {label_annos[i]: val for i, val in enumerate(label)} + if not use_bad and int(label['isbad']) == 1: continue - if skip_unchecked and int(label['isunchecked']) == 1: + if not use_unchecked and int(label['isunchecked']) == 1: continue if ((split == 'train' and int(label['istrain']) == 0) or (split == 'test' and int(label['istest']) == 0)): diff --git a/chainercv/datasets/flic/flic_utils.py b/chainercv/datasets/flic/flic_utils.py index 0c863066f1..cdd389cd41 100644 --- a/chainercv/datasets/flic/flic_utils.py +++ b/chainercv/datasets/flic/flic_utils.py @@ -5,14 +5,9 @@ root = 'pfnet/chainercv/flic' -urls = [ - 'http://vision.grasp.upenn.edu/video/FLIC-full.zip', +url = 'http://vision.grasp.upenn.edu/video/FLIC-full.zip' - 'http://cims.nyu.edu/~tompson/data/tr_plus_indices.mat', - -] - -flic_joint_label_names = [ +flic_joint_names = [ 'lsho', 'lelb', 'lwri', @@ -47,13 +42,10 @@ def get_flic(): data_root = download.get_dataset_directory(root) + dataset_dir = os.path.join(data_root, 'FLIC-full') + if not os.path.exists(dataset_dir): + download_file_path = utils.cached_download(url) + ext = os.path.splitext(url)[1] + utils.extractall(download_file_path, data_root, ext) - if not os.path.exists(os.path.join(data_root, - 'FLIC-full')): - download_file_path = utils.cached_download(urls[0]) - ext = os.path.splitext(urls[0])[1] - utils.extractall(download_file_path, - data_root, - ext) - - return data_root + return dataset_dir From b1e7afd104da399c06afca09b39ff7aa15483917 Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Thu, 28 Dec 2017 00:52:02 +0900 Subject: [PATCH 9/9] Add to reference --- docs/source/reference/datasets.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/reference/datasets.rst b/docs/source/reference/datasets.rst index 853d77bc0e..d82a1a28ae 100644 --- a/docs/source/reference/datasets.rst +++ b/docs/source/reference/datasets.rst @@ -54,6 +54,12 @@ CUBKeypointDataset ~~~~~~~~~~~~~~~~~~~ .. autoclass:: CUBKeypointDataset +FLIC +---- + +FLICKeypointDataset +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FLICKeypointDataset OnlineProducts --------------