diff --git a/chainercv/datasets/__init__.py b/chainercv/datasets/__init__.py index a810d4b2db..c4db3fedf3 100644 --- a/chainercv/datasets/__init__.py +++ b/chainercv/datasets/__init__.py @@ -19,6 +19,13 @@ from chainercv.datasets.cub.cub_utils import cub_label_names # NOQA from chainercv.datasets.directory_parsing_label_dataset import directory_parsing_label_names # NOQA from chainercv.datasets.directory_parsing_label_dataset import DirectoryParsingLabelDataset # NOQA +from chainercv.datasets.kitti.kitti_bbox_dataset import KITTIBboxDataset # NOQA +from chainercv.datasets.kitti.kitti_utils import kitti_bbox_label_colors # NOQA +from chainercv.datasets.kitti.kitti_utils import kitti_bbox_label_names # NOQA +from chainercv.datasets.kitti.kitti_utils import kitti_date_lists # NOQA +from chainercv.datasets.kitti.kitti_utils import kitti_date_num_dicts # NOQA +from chainercv.datasets.kitti.kitti_utils import kitti_ignore_bbox_label_color # NOQA +from chainercv.datasets.kitti.parseTrackletXML import parseXML # NOQA from chainercv.datasets.mixup_soft_label_dataset import MixUpSoftLabelDataset # NOQA from chainercv.datasets.online_products.online_products_dataset import online_products_super_label_names # NOQA from chainercv.datasets.online_products.online_products_dataset import OnlineProductsDataset # NOQA diff --git a/chainercv/datasets/kitti/__init__.py b/chainercv/datasets/kitti/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/chainercv/datasets/kitti/kitti_bbox_dataset.py b/chainercv/datasets/kitti/kitti_bbox_dataset.py new file mode 100644 index 0000000000..655d10f901 --- /dev/null +++ b/chainercv/datasets/kitti/kitti_bbox_dataset.py @@ -0,0 +1,219 @@ +import os +import warnings + +from pkg_resources import get_distribution +from pkg_resources import parse_version + +import numpy as np +try: + import pykitti + pykitti_version = get_distribution('pykitti').version + if parse_version(pykitti_version) >= parse_version('0.3.0'): + # 
def _check_available():
    """Raise :class:`ValueError` unless a supported pykitti was imported.

    ``_available`` is set at import time; it is false both when pykitti is
    missing and when the installed version is older than 0.3.0.
    """
    if not _available:
        raise ValueError(
            'pykitti>=0.3.0 is not installed in your environment, '
            'so the dataset cannot be loaded. '
            'Please install pykitti to load dataset.\n\n'
            '$ pip install "pykitti>=0.3.0"')


class KITTIBboxDataset(GetterDataset):

    """Bounding box dataset for `KITTI dataset`_.

    .. _`KITTI dataset`: http://www.cvlibs.net/datasets/kitti/raw_data.php

    Args:
        data_dir (string): Path to the root of the KITTI raw data. If this is
            :obj:`auto`, this class will automatically download data for you
            under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/kitti`.
        date ({'2011_09_26', '2011_09_28', '2011_09_29',
               '2011_09_30', '2011_10_03'}):
            Recording date. It also selects the calibration files.
        drive_num (string): Four digit drive number (e.g. :obj:`'0001'`).
            It must be one of the drives listed for :obj:`date` in
            :obj:`kitti_date_num_dicts`.
        sync (bool): If :obj:`True`, the synced archive is downloaded,
            otherwise the unsynced (``extract``) archive.
        is_left (bool): If :obj:`True`, the left color camera (cam2) is
            used, otherwise the right color camera (cam3).
        tracklet (bool): If :obj:`True`, 3D tracklet annotations are loaded
            and projected to 2D boxes. Only honoured for
            :obj:`date == '2011_09_26'`; it is forced to :obj:`False` for
            every other date.

    This dataset returns the following data.

    .. csv-table::
        :header: name, shape, format

        :obj:`img`, ":math:`(C, H, W)`", "values as provided by pykitti"
        :obj:`bbox` [#kitti_bbox_1]_, ":math:`(R, 4)`", \
        ":math:`(y_{min}, x_{min}, y_{max}, x_{max})`, :obj:`float32`"
        :obj:`label` [#kitti_bbox_1]_, ":math:`(R,)`", \
        ":math:`[0, \#class - 1]`, :obj:`int32`"

    .. [#kitti_bbox_1] If tracklets are not loaded, :obj:`bbox` and \
        :obj:`label` are empty for every frame.

    Please see more detail in the Fig. 6 of the summary paper:
    Andreas Geiger, Philip Lenz, Christoph Stiller and Raquel Urtasun.
    Vision meets Robotics: The KITTI Dataset. IJRR, 2013.

    """

    def __init__(self, data_dir='auto', date='', drive_num='',
                 sync=True, is_left=True, tracklet=False):
        super(KITTIBboxDataset, self).__init__()

        _check_available()

        self.sync = sync
        self.is_left = is_left

        if date not in kitti_date_lists:
            raise ValueError(
                '\'date\' argument must be one of the {} values.'
                .format(kitti_date_lists))

        if drive_num not in kitti_date_num_dicts[date]:
            raise ValueError(
                '\'drive_num\' argument must be one of the {} values.'
                .format(kitti_date_num_dicts[date]))

        # Tracklet annotations are only requested for the 2011_09_26 drives.
        self.tracklet = bool(tracklet) and date == '2011_09_26'

        if data_dir == 'auto':
            fetch = get_kitti_sync_data if sync else get_kitti_nosync_data
            data_dir = fetch(
                os.path.join('pfnet', 'chainercv', 'kitti'),
                date, drive_num, self.tracklet)

        # The drive folders live below ``<data_dir>/<date>/``.
        if not (os.path.exists(data_dir) and
                os.path.exists(os.path.join(data_dir, date))):
            raise ValueError(
                'kitti dataset does not exist at the expected location. '
                'Please download it from http://www.cvlibs.net/datasets/kitti/ '
                'Then place directory at {}.'
                .format(os.path.join(data_dir, date + '_drive_' + drive_num)))

        # NOTE(review): pykitti is not told whether the synced or the
        # unsynced recording should be read; confirm that ``sync=False``
        # really loads the ``extract`` data.
        self.dataset = pykitti.raw(
            data_dir, date, drive_num, frames=None, imformat='cv2')

        # Select calibration and image stream of the requested camera.
        if self.is_left:
            self.cur_rotation_matrix = self.dataset.calib.R_rect_20
            self.cur_position_matrix = self.dataset.calib.P_rect_20
            frames = self.dataset.cam2
        else:
            self.cur_rotation_matrix = self.dataset.calib.R_rect_30
            self.cur_position_matrix = self.dataset.calib.P_rect_30
            # The right camera must read cam3; cam2 is the left camera.
            frames = self.dataset.cam3
        self.imgs = self._load_images(frames)

        # get object info(type/area/bbox/...)
        if self.tracklet:
            self.tracklets = get_kitti_tracklets(data_dir, date, drive_num)
        else:
            self.tracklets = None

        self.bboxes, self.labels = get_kitti_label(
            self.tracklets, self.dataset.calib,
            self.cur_rotation_matrix, self.cur_position_matrix,
            len(self.imgs))

        self.add_getter('img', self._get_image)
        self.add_getter(['bbox', 'label'], self._get_annotations)
        self.keys = ('img', 'bbox', 'label')

    @staticmethod
    def _load_images(frames):
        """Convert every frame yielded by pykitti to a numpy array.

        All frames are read eagerly, so memory grows with the drive length.
        """
        imgs = []
        for frame in frames:
            data = np.asarray(frame)
            if data.ndim > 2:
                # Reverse the channel axis (kept from the original code).
                # NOTE(review): with RGB input this yields BGR; confirm
                # which channel order the dataset is meant to return.
                data = data[:, :, ::-1]
            imgs.append(data)
        return imgs

    def __len__(self):
        return len(self.imgs)

    def _get_image(self, i):
        """Return the ``i``-th image in channel-first layout."""
        img = self.imgs[i]
        if img.ndim == 2:
            # reshape (H, W) -> (1, H, W)
            return img[np.newaxis]
        # transpose (H, W, C) -> (C, H, W)
        return img.transpose((2, 0, 1))

    def _get_annotations(self, i):
        """Return ``(bbox, label)`` arrays of the ``i``-th frame.

        Frames without annotation yield arrays of shape ``(0, 4)`` and
        ``(0,)`` so that downstream code always receives ndarrays.
        """
        bbox = self.bboxes[i]
        label = self.labels[i]

        if len(bbox) == 0:
            np_bbox = np.zeros((0, 4), dtype=np.float32)
        else:
            np_bbox = np.array(bbox, dtype=np.float32)

        if len(label) == 0:
            np_label = np.zeros(0, dtype=np.int32)
        else:
            np_label = np.array(label, dtype=np.int32)

        return np_bbox, np_label
def _fetch_and_extract(url, data_root):
    """Download ``url`` through the chainercv cache and unpack it."""
    archive_path = utils.cached_download(url)
    ext = os.path.splitext(url)[1]
    utils.extractall(archive_path, data_root, ext)


def _get_kitti_data(root, date, drive_num, tracklet, suffix):
    """Download one drive, its calibration and optionally its tracklets.

    Args:
        root (string): Dataset directory name passed to
            :func:`chainer.dataset.download.get_dataset_directory`.
        date (string): Recording date, e.g. ``'2011_09_26'``.
        drive_num (string): Drive number, e.g. ``'0001'``.
        tracklet (bool): Whether the tracklet archive is fetched as well.
        suffix (string): Archive suffix, ``'_sync'`` or ``'_extract'``.

    Returns:
        string: The directory the archives were extracted into.
    """
    data_root = download.get_dataset_directory(root)
    folder = date + '_drive_' + drive_num

    urls = [
        urljoin(url_base, folder + '/' + folder + suffix + '.zip'),
        urljoin(url_base, date + '_calib.zip'),
    ]
    if tracklet:
        urls.append(
            urljoin(url_base, folder + '/' + folder + '_tracklets.zip'))

    for url in urls:
        _fetch_and_extract(url, data_root)
    return data_root


def get_kitti_sync_data(root, date, drive_num, tracklet):
    """Download the synced recording of a drive and return its root."""
    return _get_kitti_data(root, date, drive_num, tracklet, '_sync')


def get_kitti_nosync_data(root, date, drive_num, tracklet):
    """Download the unsynced recording of a drive and return its root."""
    return _get_kitti_data(root, date, drive_num, tracklet, '_extract')


def get_kitti_tracklets(data_root, date, drive_num):
    """Parse ``tracklet_labels.xml`` of a drive.

    The file is looked up at
    ``<data_root>/<date>/<date>_drive_<drive_num>_sync/``.
    """
    folder = date + '_drive_' + drive_num + '_sync'
    tracklet_filepath = os.path.join(
        data_root, date, folder, 'tracklet_labels.xml')
    return xmlParser.parseXML(tracklet_filepath)


def get_kitti_label(tracklets, calib,
                    cur_rotation_matrix, cur_position_matrix,
                    framelength, img_size=(375, 1242)):
    """Project 3D tracklet boxes to per-frame 2D boxes and labels.

    Args:
        tracklets (list or None): Tracklet objects as returned by
            ``parseXML``. :obj:`None` yields empty annotations.
        calib: Calibration object providing ``T_cam2_velo``.
        cur_rotation_matrix: Rectifying rotation of the selected camera.
        cur_position_matrix: Projection matrix of the selected camera.
        framelength (int): Number of frames of the drive.
        img_size (tuple of ints): ``(height, width)`` the boxes are clipped
            to. The default keeps the previously hard-coded 375 x 1242.

    Returns:
        tuple of lists: ``(bboxes, labels)``. Both have ``framelength``
        entries; ``bboxes[i]`` is a list of float32 arrays
        ``(ymin, xmin, ymax, xmax)`` and ``labels[i]`` the matching list of
        indices into ``kitti_bbox_label_names``.

    Raises:
        ValueError: If a pose has a rotation other than yaw, or an object
            type is not listed in ``kitti_bbox_label_names``.
    """
    # One independent list per frame.
    bboxes = [[] for _ in range(framelength)]
    labels = [[] for _ in range(framelength)]
    if tracklets is None:
        return bboxes, labels

    img_h, img_w = img_size
    visible = (xmlParser.TRUNC_IN_IMAGE, xmlParser.TRUNC_TRUNCATED)

    for tracklet in tracklets:
        # this part is inspired by kitti object development kit
        # matlab code: computeBox3D
        h, w, lg = tracklet.size
        half_l, half_w = lg / 2, w / 2
        # in velodyne coordinates around zero point and without orientation
        tracklet_box = np.array([
            [-half_l, -half_l, half_l, half_l,
             -half_l, -half_l, half_l, half_l],
            [half_w, -half_w, -half_w, half_w,
             half_w, -half_w, -half_w, half_w],
            [0.0, 0.0, 0.0, 0.0, h, h, h, h]])

        for (translation, rotation, _state, _occlusion, truncation,
             _amt_occlusion, _amt_borders, frame) in tracklet:
            # Only objects that are (partly) inside the image are kept.
            if truncation not in visible:
                continue

            # Raise explicitly: an assert would vanish under ``python -O``.
            if np.abs(rotation[:2]).sum() != 0:
                raise ValueError('object rotations other than yaw given!')

            # re-create 3D bounding box in velodyne coordinate system
            yaw = rotation[2]
            rot_mat = np.array([
                [np.cos(yaw), -np.sin(yaw), 0.0],
                [np.sin(yaw), np.cos(yaw), 0.0],
                [0.0, 0.0, 1.0]])
            corners_velo = np.dot(
                rot_mat, tracklet_box) + np.tile(translation, (8, 1)).T

            pt3d = np.vstack((corners_velo, np.ones(8)))
            pt2d = project_velo_points_in_img(
                pt3d, calib.T_cam2_velo,
                cur_rotation_matrix, cur_position_matrix)
            if pt2d.shape[1] == 0:
                # Every corner is behind the camera: nothing to draw.
                continue

            xmin, xmax = np.clip(
                (pt2d[0].min(), pt2d[0].max()), 0.0, img_w)
            ymin, ymax = np.clip(
                (pt2d[1].min(), pt2d[1].max()), 0.0, img_h)

            bboxes[frame].append(
                np.array((ymin, xmin, ymax, xmax), dtype=np.float32))
            # NOTE(review): ``index`` raises ValueError for object types
            # missing from ``kitti_bbox_label_names``; confirm the names
            # match the strings used in the tracklet files.
            labels[frame].append(
                kitti_bbox_label_names.index(tracklet.objectType))

    return bboxes, labels


def project_velo_points_in_img(pts3d, transform_cam_velo,
                               rotaion_matrix, position_matrix):
    """Project 3D points into the 2D image.

    Args:
        pts3d: Homogeneous points as a 4xN numpy array.
        transform_cam_velo: Transform applied to ``pts3d`` first.
        rotaion_matrix: Rotation applied after ``transform_cam_velo``.
        position_matrix: Projection matrix applied last.

    Returns:
        A 3xM array with the projections of the M points that are in
        front of the camera, normalised by their third coordinate.
    """
    # 3D points in camera reference frame.
    pts3d_cam = rotaion_matrix.dot(transform_cam_velo.dot(pts3d))

    # Before projecting, keep only points with z > 0
    # (points that are in front of the camera).
    in_front = pts3d_cam[2, :] > 0
    pts2d_cam = position_matrix.dot(pts3d_cam[:, in_front])

    return pts2d_cam / pts2d_cam[2, :]
'0034', '0035', '0037', + '0038', '0039', '0043', '0045', '0047', '0053', '0054', + '0057', '0065', '0066', '0068', '0070', '0071', '0075', + '0077', '0078', '0080', '0082', '0086', '0087', '0089', + '0090', '0094', '0095', '0096', '0098', '0100', '0102', + '0103', '0104', '0106', '0108', '0110', '0113', '0117', + '0119', '0121', '0122', '0125', '0126', '0128', '0132', + '0134', '0135', '0136', '0138', '0141', '0143', '0145', + '0146', '0149', '0153', '0154', '0155', '0156', '0160', + '0161', '0162', '0165', '0166', '0167', '0168', '0171', + '0174', '0177', '0179', '0183', '0184', '0185', '0186', + '0187', '0191', '0192', '0195', '0198', '0199', '0201', + '0204', '0205', '0208', '0209', '0214', '0216', '0220', + '0222'], + # calibration date_num 0108(not nse) + '2011_09_29': ['0004', '0026', '0071'], + # calibration date_num 0072(not nse) + '2011_09_30': ['0016', '0018', '0020', '0027', '0028', '0033', '0034'], + # calibration date_num 0058(not nse) + '2011_10_03': ['0027', '0034', '0042', '0047'], +} diff --git a/chainercv/datasets/kitti/parseTrackletXML.py b/chainercv/datasets/kitti/parseTrackletXML.py new file mode 100644 index 0000000000..85fd4a5e1b --- /dev/null +++ b/chainercv/datasets/kitti/parseTrackletXML.py @@ -0,0 +1,340 @@ +"""Parse XML files containing tracklet info for kitti data base. + (http://cvlibs.net/datasets/kitti/raw_data.php) + + No guarantees that this code is correct, usage is at your own risk! + + created by Christian Herdtweck, + Max Planck Institute for Biological Cybernetics + (christian.herdtweck@tuebingen.mpg.de) + + requires numpy! 
STATE_UNSET = 0
STATE_INTERP = 1
STATE_LABELED = 2
stateFromText = {'0': STATE_UNSET, '1': STATE_INTERP, '2': STATE_LABELED}

OCC_UNSET = 255  # -1 as uint8
OCC_VISIBLE = 0
OCC_PARTLY = 1
OCC_FULLY = 2
occFromText = {'-1': OCC_UNSET, '0': OCC_VISIBLE,
               '1': OCC_PARTLY, '2': OCC_FULLY}

TRUNC_UNSET = 255  # -1 as uint8, but in xml files the value '99' is used!
TRUNC_IN_IMAGE = 0
TRUNC_TRUNCATED = 1
TRUNC_OUT_IMAGE = 2
TRUNC_BEHIND_IMAGE = 3
truncFromText = {'99': TRUNC_UNSET, '0': TRUNC_IN_IMAGE, '1': TRUNC_TRUNCATED,
                 '2': TRUNC_OUT_IMAGE, '3': TRUNC_BEHIND_IMAGE}

# xml tag -> index into ``Tracklet.size``
_SIZE_FIELDS = {'h': 0, 'w': 1, 'l': 2}
# xml tag -> (Tracklet attribute, column) for per-pose float values
_POSE_FLOAT_FIELDS = {
    'tx': ('trans', 0), 'ty': ('trans', 1), 'tz': ('trans', 2),
    'rx': ('rots', 0), 'ry': ('rots', 1), 'rz': ('rots', 2),
}
# Same, for the optional "amt" values that only some files contain.
_POSE_AMT_FIELDS = {
    'amt_occlusion': ('amt_occs', 0),
    'amt_occlusion_kf': ('amt_occs', 1),
    'amt_border_l': ('amt_borders', 0),
    'amt_border_r': ('amt_borders', 1),
    'amt_border_kf': ('amt_borders', 2),
}


class Tracklet(object):
    r"""Representation of an annotated object track.

    Tracklets are created in function parseXML
    and can most conveniently be used as follows:

    for trackletObj in parseXML(tracklet_filepath):
        for translation, rotation, state, occlusion, \
                truncation, amt_occlusion, amt_borders, \
                absoluteFrameNumber in trackletObj:
            ... your code here ...

    absoluteFrameNumber is in range [firstFrame, firstFrame+nFrames[
    amt_occlusion and amt_borders could be None

    You can of course also directly access the fields
    objectType (string), size (len-3 ndarray), firstFrame/nFrames (int),
    trans/rots (nFrames x 3 float ndarrays),
    states/truncs (len-nFrames uint8 ndarrays),
    occs (nFrames x 2 uint8 ndarray),
    and for some tracklets amt_occs (nFrames x 2 float ndarray)
    and amt_borders (nFrames x 3 float ndarray).
    The last two are None if the xml file
    did not include these fields in poses.
    """

    objectType = None
    size = None  # len-3 float array: (height, width, length)
    firstFrame = None
    trans = None  # n x 3 float array (x,y,z)
    rots = None  # n x 3 float array (x,y,z)
    states = None  # len-n uint8 array of states
    occs = None  # n x 2 uint8 array (occlusion, occlusion_kf)
    truncs = None  # len-n uint8 array of truncation
    # None or (n x 2) float array (amt_occlusion, amt_occlusion_kf)
    amt_occs = None
    amt_borders = None  # None (n x 3) float array (amt_border_l / _r / _kf)
    nFrames = None

    def __init__(self):
        """Create Tracklet with no info set."""
        self.size = np.full(3, np.nan, dtype=float)

    def __str__(self):
        """Return a short human-readable description of the tracklet."""
        return '[Tracklet over {0} frames for {1}]'.format(
            self.nFrames, self.objectType)

    def __iter__(self):
        """Return an iterator over the per-frame data.

        Each item is the tuple ``(translation, rotation, state, occlusion,
        truncation, amt_occlusion, amt_borders, absoluteFrameNumber)``.
        The two ``amt_*`` entries are None when the xml file did not
        provide them. A tracklet without poses yields nothing.
        """
        if self.nFrames is None:
            return iter(())

        if self.amt_occs is None:
            amt_occs = itertools.repeat(None)
            amt_borders = itertools.repeat(None)
        else:
            amt_occs = self.amt_occs
            amt_borders = self.amt_borders

        # ``iter`` is required on Python 2, where ``zip`` returns a list.
        return iter(zip(
            self.trans, self.rots, self.states,
            self.occs, self.truncs,
            amt_occs, amt_borders,
            range(self.firstFrame, self.firstFrame + self.nFrames)))


def _allocate_pose_arrays(track, n_frames):
    """Create the per-frame arrays of ``track`` filled with "unset" values.

    Float arrays start as NaN. The uint8 arrays start with the matching
    ``*_UNSET`` constant, because NaN cannot be stored in uint8.
    """
    track.nFrames = n_frames
    track.trans = np.full((n_frames, 3), np.nan, dtype=float)
    track.rots = np.full((n_frames, 3), np.nan, dtype=float)
    track.states = np.full(n_frames, STATE_UNSET, dtype=np.uint8)
    track.occs = np.full((n_frames, 2), OCC_UNSET, dtype=np.uint8)
    track.truncs = np.full(n_frames, TRUNC_UNSET, dtype=np.uint8)
    track.amt_occs = np.full((n_frames, 2), np.nan, dtype=float)
    track.amt_borders = np.full((n_frames, 3), np.nan, dtype=float)


def _parse_pose(track, frame_idx, pose):
    """Store one pose ``<item>`` in row ``frame_idx`` of ``track``.

    Returns True if the pose carried any of the optional "amt" fields.
    """
    has_amt = False
    for field in pose:
        tag, text = field.tag, field.text
        if tag in _POSE_FLOAT_FIELDS:
            attr, column = _POSE_FLOAT_FIELDS[tag]
            getattr(track, attr)[frame_idx, column] = float(text)
        elif tag in _POSE_AMT_FIELDS:
            attr, column = _POSE_AMT_FIELDS[tag]
            getattr(track, attr)[frame_idx, column] = float(text)
            has_amt = True
        elif tag == 'state':
            track.states[frame_idx] = stateFromText[text]
        elif tag == 'occlusion':
            track.occs[frame_idx, 0] = occFromText[text]
        elif tag == 'occlusion_kf':
            track.occs[frame_idx, 1] = occFromText[text]
        elif tag == 'truncation':
            track.truncs[frame_idx] = truncFromText[text]
        else:
            raise ValueError(
                'unexpected tag in poses item: {0}!'.format(tag))
    return has_amt


def _parse_tracklet(tracklet_element, tracklet_idx):
    """Convert one tracklet ``<item>`` element into a Tracklet object."""
    track = Tracklet()
    is_finished = False
    has_amt = False
    frame_idx = None  # stays None until the pose count has been read

    for info in tracklet_element:
        if is_finished:
            raise ValueError('more info on element after finished!')
        tag = info.tag
        if tag == 'objectType':
            track.objectType = info.text
        elif tag in _SIZE_FIELDS:
            track.size[_SIZE_FIELDS[tag]] = float(info.text)
        elif tag == 'first_frame':
            track.firstFrame = int(info.text)
        elif tag == 'poses':
            # this info is the possibly long list of poses
            for pose in info:
                if pose.tag == 'count':
                    # this should come before the pose items
                    if track.nFrames is not None:
                        raise ValueError(
                            'there are several pose lists '
                            'for a single track!')
                    if frame_idx is not None:
                        raise ValueError('?!')
                    _allocate_pose_arrays(track, int(pose.text))
                    frame_idx = 0
                elif pose.tag == 'item_version':
                    continue
                elif pose.tag == 'item':
                    if frame_idx is None:
                        raise ValueError(
                            'pose item came before number of poses!')
                    if _parse_pose(track, frame_idx, pose):
                        has_amt = True
                    frame_idx += 1
                else:
                    raise ValueError(
                        'unexpected pose info: {0}!'.format(pose.tag))
        elif tag == 'finished':
            is_finished = True
        else:
            raise ValueError(
                'unexpected tag in tracklets: {0}!'.format(tag))

    # some final consistency checks on new tracklet
    if not is_finished:
        warn('tracklet {0} was not finished!'.format(tracklet_idx))
    if track.nFrames is None:
        warn('tracklet {0} contains no information!'.format(tracklet_idx))
    else:
        if frame_idx != track.nFrames:
            warn('tracklet {0} is supposed to have {1} frames, '
                 'but parser found {2}!'.format(
                     tracklet_idx, track.nFrames, frame_idx))
        # Only checked when poses exist; ``rots`` is None otherwise.
        if np.abs(track.rots[:, :2]).sum() > 1e-16:
            warn('track contains rotation other than yaw!')

    # if amt_occs / amt_borders are not set, set them to None
    if not has_amt:
        track.amt_occs = None
        track.amt_borders = None
    return track


def parseXML(tracklet_filepath):
    r"""Parse tracklet xml file and convert it to a list of Tracklet objects.

    :param tracklet_filepath: name of a tracklet xml file
    :returns: list of Tracklet objects read from xml file
    :raises ValueError: if the file has no ``tracklets`` element or
        contains unexpected tags
    """
    element_tree = ElementTree()
    # Passing the path lets the xml parser open the file itself.
    element_tree.parse(tracklet_filepath)

    tracklets_element = element_tree.find('tracklets')
    if tracklets_element is None:
        raise ValueError(
            'no tracklets element found in {0}'.format(tracklet_filepath))

    tracklets = []
    declared_count = None
    for tracklet_element in tracklets_element:
        tag = tracklet_element.tag
        if tag == 'count':
            declared_count = int(tracklet_element.text)
        elif tag == 'item_version':
            continue
        elif tag == 'item':
            tracklets.append(
                _parse_tracklet(tracklet_element, len(tracklets)))
        else:
            raise ValueError('unexpected tracklet info')

    # final consistency check
    if len(tracklets) != declared_count:
        warn('according to xml information the file has {0} tracklets, '
             'but parser found {1}!'.format(declared_count, len(tracklets)))

    return tracklets
def _case(drive_num, sync=True, is_left=True):
    """Build one parameter set for a 2011_09_26 drive with tracklets."""
    return {
        'date': '2011_09_26',
        'drive_num': drive_num,
        'sync': sync,
        'is_left': is_left,
        'tracklet': True,
    }


# Configurations that were marked "Test NG" and are therefore not exercised:
# tracklet=False (2011_09_26/0001, 2011_09_28/0001), drives whose frames
# partly lack bbox/label data (2011_09_26/0009, 2011_09_26/0017) and drives
# without tracklet data (2011_10_03/0047, 2011_09_28/0016, 2011_09_28/0053).
@testing.parameterize(
    # category : City
    _case('0001'),
    _case('0001', sync=False),
    _case('0001', is_left=False),
    _case('0056'),
    _case('0057'),
    # category : Residential
    _case('0064'),
    # category : Road
    _case('0032'),
    _case('0052'),
)
class TestKITTIBboxDataset(unittest.TestCase):
    """Check that KITTIBboxDataset satisfies the bbox dataset contract."""

    def setUp(self):
        # The attributes used below are injected by ``testing.parameterize``.
        params = {
            'date': self.date,
            'drive_num': self.drive_num,
            'sync': self.sync,
            'is_left': self.is_left,
            'tracklet': self.tracklet,
        }
        self.dataset = KITTIBboxDataset(**params)

    @attr.slow
    def test_kitti_bbox_dataset(self):
        n_fg_class = len(kitti_bbox_label_names)
        assert_is_bbox_dataset(self.dataset, n_fg_class, n_example=10)


testing.run_module(__name__, __file__)