diff --git a/datasets/registry.py b/datasets/registry.py index 2dd8551d..16a2a6dc 100644 --- a/datasets/registry.py +++ b/datasets/registry.py @@ -4,11 +4,12 @@ # LICENSE file in the root directory of this source tree. import numpy as np - +from torch.utils.data import Subset from datasets import base, cifar10, mnist, imagenet from foundations.hparams import DatasetHparams from platforms.platform import get_platform + registered_datasets = {'cifar10': cifar10, 'mnist': mnist, 'imagenet': imagenet} @@ -31,6 +32,14 @@ def get(dataset_hparams: DatasetHparams, train: bool = True): if train and dataset_hparams.random_labels_fraction is not None: dataset.randomize_labels(seed=seed, fraction=dataset_hparams.random_labels_fraction) + if dataset_hparams.subset_start is not None or dataset_hparams.subset_stride != 1 or dataset_hparams.subset_end is not None: + if dataset_hparams.subsample_fraction is not None: + raise ValueError("Cannot have both subsample_fraction and subset_[start,end,stride]") + subset_start = 0 if dataset_hparams.subset_start is None else dataset_hparams.subset_start + subset_end = len(dataset) if dataset_hparams.subset_end is None else dataset_hparams.subset_end + subset_stride = 1 if dataset_hparams.subset_stride is None else dataset_hparams.subset_stride + dataset = Subset(dataset, np.arange(subset_start, subset_end, subset_stride)) + if train and dataset_hparams.subsample_fraction is not None: dataset.subsample(seed=seed, fraction=dataset_hparams.subsample_fraction) diff --git a/foundations/hparams.py b/foundations/hparams.py index e513c533..2d9d0d49 100644 --- a/foundations/hparams.py +++ b/foundations/hparams.py @@ -81,6 +81,25 @@ def create_from_args(cls, args: argparse.Namespace, prefix: str = None) -> 'Hpar return cls(**d) + @classmethod + def create_from_dict(cls, d: dict, prefix: str = None) -> 'Hparams': + for field in fields(cls): + if field.name.startswith('_'): continue + + key = f'{field.name}' if prefix is None else f'{prefix}_{field.name}' + + if key in d: + # Cast to the appropriate type + if field.type in [bool, float, int]: + d[key] = field.type(d[key]) + + # Nested hparams. + elif isinstance(field.type, type) and issubclass(field.type, Hparams): + d[key] = field.type.create_from_dict(d[key], prefix=key if prefix else '') + + return cls(**d) + + @property def display(self): nondefault_fields = [f for f in fields(self) @@ -109,6 +128,9 @@ class DatasetHparams(Hparams): do_not_augment: bool = False transformation_seed: int = None subsample_fraction: float = None + subset_start: int = None + subset_end: int = None + subset_stride: int = 1 random_labels_fraction: float = None unsupervised_labels: str = None blur_factor: int = None @@ -121,6 +143,9 @@ class DatasetHparams(Hparams): _transformation_seed: str = 'The random seed that controls dataset transformations like ' \ 'random labels, subsampling, and unsupervised labels.' _subsample_fraction: str = 'Subsample the training set, retaining the specified fraction: float in (0, 1]' + _subset_start: str = 'If set, use a Subset with indices range(subset_start, subset_end, subset_stride)' + _subset_end: str = 'If set, use a Subset with indices range(subset_start, subset_end, subset_stride)' + _subset_stride: str = 'Stride of subset indices (default 1)' _random_labels_fraction: str = 'Apply random labels to a fraction of the training set: float in (0, 1]' _unsupervised_labels: str = 'Replace the standard labels with alternative, unsupervised labels. Example: rotation' _blur_factor: str = 'Blur the training set by downsampling and then upsampling by this multiple.'