Skip to content
This repository was archived by the owner on May 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion datasets/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
# LICENSE file in the root directory of this source tree.

import numpy as np

from torch.utils.data import Subset
from datasets import base, cifar10, mnist, imagenet
from foundations.hparams import DatasetHparams
from platforms.platform import get_platform


registered_datasets = {'cifar10': cifar10, 'mnist': mnist, 'imagenet': imagenet}


Expand All @@ -31,6 +32,14 @@ def get(dataset_hparams: DatasetHparams, train: bool = True):
if train and dataset_hparams.random_labels_fraction is not None:
dataset.randomize_labels(seed=seed, fraction=dataset_hparams.random_labels_fraction)

if dataset_hparams.subset_start is not None or dataset_hparams.subset_stride != 1 or dataset_hparams.subset_end is not None:
if dataset_hparams.subsample_fraction is not None:
raise ValueError("Cannot have both subsample_fraction and subset_[start,end,stride]")
subset_start = 0 if dataset_hparams.subset_start is None else dataset_hparams.subset_start
subset_end = len(dataset) if dataset_hparams.subset_end is None else dataset_hparams.subset_end
subset_stride = 1 if dataset_hparams.subset_stride is None else dataset_hparams.subset_stride
dataset = Subset(dataset, np.arange(subset_start, subset_end, subset_stride))

if train and dataset_hparams.subsample_fraction is not None:
dataset.subsample(seed=seed, fraction=dataset_hparams.subsample_fraction)

Expand Down
25 changes: 25 additions & 0 deletions foundations/hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ def create_from_args(cls, args: argparse.Namespace, prefix: str = None) -> 'Hpar

return cls(**d)

@classmethod
def create_from_dict(cls, d: dict, prefix: str = None) -> 'Hparams':
    """Build an instance of this hparams class from a dictionary of values.

    Args:
        d: Mapping from key to value. With no prefix, keys are the field names themselves.
            With a prefix, the value of field `name` is read from the key `{prefix}_{name}`
            and all other keys are ignored. The dictionary is never modified.
        prefix: Optional prefix that namespaces this class's keys inside `d`.

    Returns:
        An instance of `cls` built from the values found in `d`; fields without a
        corresponding key keep their dataclass defaults.

    Raises:
        TypeError: If no prefix is given and `d` contains a key that the constructor
            of `cls` does not accept.
    """
    # Without a prefix every key is expected to name a field, so start from a copy of the
    # input: unrecognized keys still reach the constructor and fail loudly, and the caller's
    # dictionary is left untouched. With a prefix, only the matching keys are collected and
    # they are stored under the bare field name that the constructor expects.
    kwargs = dict(d) if prefix is None else {}

    for field in fields(cls):
        if field.name.startswith('_'):
            continue

        key = field.name if prefix is None else f'{prefix}_{field.name}'
        if key not in d:
            continue
        value = d[key]

        if value is None:
            # Many fields use None as their "unset" default; int(None)/float(None) would raise.
            pass

        # Cast to the appropriate type.
        # NOTE(review): bool('False') is True, so string-encoded booleans are not parsed here.
        elif field.type in (bool, float, int):
            value = field.type(value)

        # Nested hparams: the nested dictionary is keyed by the nested class's own field names.
        elif isinstance(field.type, type) and issubclass(field.type, Hparams) and isinstance(value, dict):
            value = field.type.create_from_dict(value)

        kwargs[field.name] = value

    return cls(**kwargs)


@property
def display(self):
nondefault_fields = [f for f in fields(self)
Expand Down Expand Up @@ -109,6 +128,9 @@ class DatasetHparams(Hparams):
do_not_augment: bool = False
transformation_seed: int = None
subsample_fraction: float = None
subset_start: int = None
subset_end: int = None
subset_stride: int = 1
random_labels_fraction: float = None
unsupervised_labels: str = None
blur_factor: int = None
Expand All @@ -121,6 +143,9 @@ class DatasetHparams(Hparams):
_transformation_seed: str = 'The random seed that controls dataset transformations like ' \
'random labels, subsampling, and unsupervised labels.'
_subsample_fraction: str = 'Subsample the training set, retaining the specified fraction: float in (0, 1]'
_subset_start: str = 'If set, use a Subset with indices range(subset_start, subset_end, subset_stride)'
_subset_end: str = 'If set, use a Subset with indices range(subset_start, subset_end, subset_stride)'
_subset_stride: str = 'Stride of subset indices (default 1)'
_random_labels_fraction: str = 'Apply random labels to a fraction of the training set: float in (0, 1]'
_unsupervised_labels: str = 'Replace the standard labels with alternative, unsupervised labels. Example: rotation'
_blur_factor: str = 'Blur the training set by downsampling and then upsampling by this multiple.'
Expand Down