diff --git a/README.md b/README.md index 6e3c2f7..92682a2 100644 --- a/README.md +++ b/README.md @@ -16,22 +16,29 @@ This repository provides a PyTorch implementation of our method from the [paper] ``` ## Installation + Check out the repo and set up the conda environment: + ```bash conda env create -f environment.yaml ``` Activate the new environment: + ```bash conda activate spoco ``` ## Training + This implementation uses `DistributedDataParallel` training. In order to restrict the number of GPUs used for training use `CUDA_VISIBLE_DEVICES`, e.g. `CUDA_VISIBLE_DEVICES=0 python spoco_train.py ...` will execute training on `GPU:0`. -### CVPPP dataset -We used A1 subset of the [CVPPP2017_LSC challenge](https://competitions.codalab.org/competitions/18405) for training. In order to train with 10% of randomly selected objects, run: +### CVPPP dataset + +We used the A1 subset of the [CVPPP2017_LSC challenge](https://competitions.codalab.org/competitions/18405) for training. In +order to train with 10% of randomly selected objects, run: + ```bash python spoco_train.py \ --spoco \ @@ -54,6 +61,7 @@ python spoco_train.py \ ``` `CVPPP_ROOT_DIR` is assumed to have the following subdirectories: + ``` - train: - A1: @@ -73,11 +81,16 @@ python spoco_train.py \ - ... ``` -Since the CVPPP dataset consist of only `training` and `testing` subdirectories, one has to create the train/val split manually using the `training` subdir. + +Since the CVPPP dataset consists of only `training` and `testing` subdirectories, one has to create the train/val split +manually using the `training` subdir. ### Cityscapes Dataset -Download the images `leftImg8bit_trainvaltest.zip` and the labels `gtFine_trainvaltest.zip` from the [Cityscapes website](https://www.cityscapes-dataset.com/downloads) + +Download the images `leftImg8bit_trainvaltest.zip` and the labels `gtFine_trainvaltest.zip` from +the [Cityscapes website](https://www.cityscapes-dataset.com/downloads) and extract them into the `CITYSCAPES_ROOT_DIR` of your choice, so it has the following structure: + ``` - gtFine: - train @@ -91,20 +104,27 @@ and extract them into the `CITYSCAPES_ROOT_DIR` of your choice, so it has the fo ``` Create random samplings of each class using the [cityscapesampler.py](spoco/datasets/cityscapesampler.py) script: + ```bash python spoco/datasets/cityscapesampler.py --base_dir CITYSCAPES_ROOT_DIR --class_names person rider car truck bus train motorcycle bicycle ``` -this will randomly sample 10%, 20%, ..., 90% of objects from the specified class(es) and save the results in dedicated directories, + +this will randomly sample 10%, 20%, ..., 90% of objects from the specified class(es) and save the results in dedicated +directories, e.g. `CITYSCAPES_ROOT_DIR/gtFine/train/darmstadt/car/0.4` will contain random 40% of objects of class `car`. -One can also sample from all of the objects (people, riders, cars, trucks, buses, trains, motorcycles, bicycles) collectively by simply: +One can also sample from all of the objects (people, riders, cars, trucks, buses, trains, motorcycles, bicycles) +collectively by simply: + ```bash python spoco/datasets/cityscapesampler.py --base_dir CITYSCAPES_ROOT_DIR ``` + this will randomly sample 10%, 20%, ..., 90% of **all** objects and save the results in dedicated directories, e.g. `CITYSCAPES_ROOT_DIR/gtFine/train/darmstadt/all/0.4` will contain random 40% of all objects.
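For intuition, the object-level subsampling used throughout the repo boils down to keeping a random fraction of the ground-truth instance ids and setting the remaining objects to background. A minimal sketch of that idea (not the actual `cityscapesampler.py` implementation, which additionally writes out the per-class/per-ratio directories described above):

```python
import numpy as np

def sample_instances(label, ratio, seed=47):
    """Keep a random `ratio` of the instance ids in `label` and zero out the rest."""
    rs = np.random.RandomState(seed)
    ids = np.unique(label)
    ids = ids[ids != 0]  # drop the background id (0)
    rs.shuffle(ids)
    keep = ids[:max(1, round(ratio * len(ids)))]
    return np.where(np.isin(label, keep), label, 0)

# toy ground-truth mask with 4 instances; keep ~40% of them
gt = np.array([[1, 1, 2, 2], [3, 3, 4, 4]])
print(sample_instances(gt, 0.4))
```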
In order to train with 40% of randomly selected objects of class `car`, run: + ```bash python spoco_train.py \ --spoco \ @@ -129,10 +149,13 @@ python spoco_train.py \ --log-after-iters 500 --max-num-iterations 90000 ``` -In order to train with a random 40% of all ground truth objects, just remove the `--things-class` argument from the command above. +In order to train with a random 40% of all ground truth objects, just remove the `--things-class` argument from the +command above. ## Prediction + Given a model trained on the CVPPP dataset, run the prediction using the following command: + ```bash python spoco_predict.py \ --spoco \ @@ -143,13 +166,17 @@ python spoco_predict.py \ --model-feature-maps 16 32 64 128 256 512 \ --output-dir OUTPUT_DIR ``` + Results will be saved in the given `OUTPUT_DIR` directory. For each test input image `plantXXX_rgb.png` the following 3 output files will be saved in the `OUTPUT_DIR`: -* `plantXXX_rgb_predictions.h5` - HDF5 file with datasets `/raw` (input image), `/embeddings1` (output from the `f` embedding network), `/embeddings2` (output from the `g` momentum contrast network) + +* `plantXXX_rgb_predictions.h5` - HDF5 file with datasets `/raw` (input image), `/embeddings1` (output from the `f` + embedding network), `/embeddings2` (output from the `g` momentum contrast network) * `plantXXX_rgb_predictions_1.png` - output from the `f` embedding network PCA-projected into the RGB-space * `plantXXX_rgb_predictions_2.png` - output from the `g` momentum contrast network PCA-projected into the RGB-space -And similarly for the Cityscapes dataset +And similarly for the Cityscapes dataset: + ```bash python spoco_predict.py \ --spoco \ @@ -163,8 +190,11 @@ python spoco_predict.py \ ``` ## Clustering + To produce the final segmentation one needs to cluster the embeddings with an algorithm of choice. Supported -algoritms: mean-shift, HDBSCAN and Consistency Clustering (as described in the paper). E.g. to cluster CVPPP with HDBSCAN, run: +algorithms: mean-shift, HDBSCAN and Consistency Clustering (as described in the paper). E.g. to cluster CVPPP with +HDBSCAN, run: + ```bash python cluster_predictions.py \ --ds-name cvppp \ @@ -175,7 +205,9 @@ python cluster_predictions.py \ Where `PREDICTION_DIR` is the directory where h5 files containing network predictions are stored. Resulting segmentation will be saved as a separate dataset (named `segmentation`) inside each of the H5 prediction files. -In order to cluster the Cityscapes predictions and extract the instances of class `car` and compute the segmentation scores on the validation set: +In order to cluster the Cityscapes predictions, extract the instances of class `car` and compute the segmentation +scores on the validation set, run: + ```bash python cluster_predictions.py \ --ds-name cityscapes \ @@ -185,5 +217,63 @@ python cluster_predictions.py \ --things-class car \ --clustering msplus --delta-var 0.5 --delta-dist 2.0 ``` + Where `SEM_PREDICTION_DIR` is the directory containing the semantic segmentation predictions for your validation images. We used pre-trained DeepLabv3 model from [here](https://github.com/VainF/DeepLabV3Plus-Pytorch). + +## Training and inference on MitoEM dataset + +Download the MitoEM-R dataset from https://mitoem.grand-challenge.org and split the h5 file containing 500 slices +into training and validation sets: the training file should be named `train.h5` and contain 400 slices, and the +validation file should be named `val.h5` and contain the remaining 100 slices.
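The 400/100 split itself is left to the user. A minimal h5py sketch of one way to do it, assuming the downloaded volume is a single HDF5 file with `raw` and `label` datasets (the default internal paths expected by the MitoEM data loader); the source filename `mitoem_r.h5` is just a placeholder:

```python
import h5py

def split_mitoem(src_path, out_dir, n_train=400):
    """Write the first n_train slices to train.h5 and the remaining slices to val.h5."""
    with h5py.File(src_path, 'r') as f:
        raw, label = f['raw'][:], f['label'][:]

    splits = [('train.h5', slice(0, n_train)), ('val.h5', slice(n_train, None))]
    for name, sl in splits:
        with h5py.File(f'{out_dir}/{name}', 'w') as out:
            out.create_dataset('raw', data=raw[sl], compression='gzip')
            out.create_dataset('label', data=label[sl], compression='gzip')

split_mitoem('mitoem_r.h5', 'MITOEM_ROOT_DIR')
```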
+ +Then create the random 1%, 5%, 10% samplings of instances using the [mitoemsampler.py](spoco/datasets/mitoemsampler.py) +script: + +```bash +python spoco/datasets/mitoemsampler.py --dataset_dir MITOEM_ROOT_DIR --instance_ratios 0.01 0.05 0.1 +``` + +this will create the following additional datasets inside the `MITOEM_ROOT_DIR/train.h5`: + +``` +- label_0.01 +- label_0.05 +- label_0.1 +``` + +### Training on MitoEM + +In order to train with 1% of randomly selected instances, run: + +```bash +python spoco_train.py \ + --spoco \ + --ds-name mitoem \ + --ds-path MITOEM_ROOT_DIR \ + --patch-shape 1 512 512 \ + --stride-shape 1 512 512 \ + --instance-ratio 0.01 \ + --batch-size 16 \ + --model-name UNet2D \ + --model-in-channels 1 \ + --model-feature-maps 16 32 64 128 256 512 \ + --learning-rate 0.0002 \ + --weight-decay 0.00001 \ + --cos \ + --loss-delta-var 0.5 \ + --loss-delta-dist 2.0 \ + --loss-unlabeled-push 1.0 \ + --loss-instance-weight 1.0 \ + --loss-consistency-weight 1.0 \ + --kernel-threshold 0.5 \ + --checkpoint-dir CHECKPOINT_DIR \ + --log-after-iters 500 \ + --max-num-iterations 100000 +``` + +### Prediction on MitoEM + +The prediction script converts the embeddings to affinities using the formula defined in the paper (see eq. 12 in +Appendix 4). +TODO diff --git a/environment.yaml b/environment.yaml index a5aa0b4..1f14784 100644 --- a/environment.yaml +++ b/environment.yaml @@ -6,7 +6,6 @@ channels: - conda-forge dependencies: - - python 3.8 - tqdm - pytorch - torchvision @@ -17,4 +16,3 @@ dependencies: - scikit-learn - pyyaml - hdbscan - - pytest diff --git a/spoco/datasets/mitoemsampler.py b/spoco/datasets/mitoemsampler.py new file mode 100644 index 0000000..ab33220 --- /dev/null +++ b/spoco/datasets/mitoemsampler.py @@ -0,0 +1,64 @@ +import argparse +from pathlib import Path + +import h5py +import numpy as np + + +def mitoem_sample_instances(label, instance_ratio, random_state): + """ + Sample a fraction of ground truth objects from the label dataset. + + Args: + label: np.array, label dataset + instance_ratio: float, fraction of ground truth objects to sample + random_state: instance of np.random.RandomState + + Returns: + np.array, sampled label dataset + """ + label_img = np.copy(label) + unique_ids = np.unique(label)[1:] + random_state.shuffle(unique_ids) + # pick instance_ratio objects + num_objects = round(instance_ratio * len(unique_ids)) + assert num_objects > 0, 'No objects to sample' + print(f'Sampled {num_objects} out of {len(unique_ids)} objects. 
Instance ratio: {instance_ratio}') + # create a set of object ids left for training + sampled_ids = set(unique_ids[:num_objects]) + for id in unique_ids: + if id not in sampled_ids: + label_img[label_img == id] = 0 + return label_img + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset_dir', type=str, help='MitoEM dir containing train.h5 and val.h5 files', + required=True) + parser.add_argument('--instance_ratios', nargs="+", type=float, + help='fraction of ground truth objects to sample.', required=True) + + args = parser.parse_args() + + # load label dataset from the train.h5 file + train_file = Path(args.dataset_dir) / 'train.h5' + assert train_file.exists(), f'{train_file} does not exist' + + with h5py.File(train_file, 'r+') as f: + label = f['label'][:] + + for instance_ratio in args.instance_ratios: + assert 0.0 <= instance_ratio <= 1.0, 'Instance ratio must be in [0, 1]' + + ir = float(instance_ratio) + rs = np.random.RandomState(47) + print(f'Sampling {ir * 100}% of mitoEM instances') + + label_sampled = mitoem_sample_instances(label, ir, rs) + + dataset_name = f'label_{instance_ratio}' + if dataset_name in f: + del f[dataset_name] + # save the sampled label dataset + f.create_dataset(dataset_name, data=label_sampled, compression='gzip') diff --git a/spoco/datasets/utils.py b/spoco/datasets/utils.py index b0adb6c..5013598 100644 --- a/spoco/datasets/utils.py +++ b/spoco/datasets/utils.py @@ -1,10 +1,12 @@ import collections +from pathlib import Path import torch from torch.utils.data import DataLoader from spoco.datasets.cityscapes import CityscapesDataset from spoco.datasets.cvppp import CVPPP2017Dataset +from spoco.datasets.volumetric import VolumetricH5Dataset def create_train_val_loaders(args): @@ -25,6 +27,21 @@ def create_train_val_loaders(args): train_dataset = CityscapesDataset(args.ds_path, phase='train', class_name=args.things_class, spoco=args.spoco, instance_ratio=args.instance_ratio) val_dataset = CityscapesDataset(args.ds_path, phase='val', class_name=args.things_class, spoco=args.spoco) + elif args.ds_name == 'mitoem': + ds_path = Path(args.ds_path) + train_file = ds_path / 'train.h5' + val_file = ds_path / 'val.h5' + assert train_file.exists(), f'Training file {train_file} does not exist' + assert val_file.exists(), f'Validation file {val_file} does not exist' + assert len(args.patch_shape) == 3, 'Patch shape must be a 3D tuple' + assert len(args.stride_shape) == 3, 'Stride shape must be a 3D tuple' + assert args.patch_shape[0] == 1, 'Patch shape must have a depth of 1: only 2D patches are supported' + assert args.stride_shape[0] == 1, 'Stride shape must have a depth of 1: only 2D patches are supported' + train_dataset = VolumetricH5Dataset(train_file, phase='train', patch_shape=args.patch_shape, + stride_shape=args.stride_shape, spoco=args.spoco, + instance_ratio=args.instance_ratio) + val_dataset = VolumetricH5Dataset(val_file, phase='val', patch_shape=args.patch_shape, + stride_shape=args.stride_shape, spoco=args.spoco) else: raise RuntimeError(f'Unsupported dataset: {args.ds_name}') @@ -51,6 +68,16 @@ def create_test_loader(args): test_dataset = CVPPP2017Dataset(args.ds_path, phase='test', spoco=args.spoco) elif args.ds_name == 'cityscapes': test_dataset = CityscapesDataset(args.ds_path, phase='test', class_name=None, spoco=args.spoco) + elif args.ds_name == 'mitoem': + ds_path = Path(args.ds_path) + test_file = ds_path / 'val.h5' + assert test_file.exists(), f'Test file {test_file} does not exist' + assert 
len(args.patch_shape) == 3, 'Patch shape must be a 3D tuple' + assert len(args.stride_shape) == 3, 'Stride shape must be a 3D tuple' + assert args.patch_shape[0] == 1, 'Patch shape must have a depth of 1: only 2D patches are supported' + assert args.stride_shape[0] == 1, 'Stride shape must have a depth of 1: only 2D patches are supported' + test_dataset = VolumetricH5Dataset(test_file, phase='test', patch_shape=args.patch_shape, + stride_shape=args.stride_shape, spoco=args.spoco) else: raise RuntimeError(f'Unsupported dataset {args.ds_name}') diff --git a/spoco/datasets/volumetric.py b/spoco/datasets/volumetric.py new file mode 100644 index 0000000..9dc1340 --- /dev/null +++ b/spoco/datasets/volumetric.py @@ -0,0 +1,255 @@ +import random + +import h5py +import numpy as np +import torch +from torch.utils.data import Dataset +from torchvision.transforms import transforms + +from spoco.transforms import Relabel, GaussianBlurNp, Standardize, RandomFlip + +LABEL_TRANSFORM = transforms.Compose( + [ + RandomFlip(), + Relabel(run_cc=False), + transforms.ToTensor() + ] +) + +TEST_LABEL_TRANSFORM = transforms.Compose( + [ + Relabel(run_cc=False), + transforms.ToTensor() + ] +) + +EXTENDED_TRANSFORM = transforms.Compose( + [ + GaussianBlurNp(execution_probability=1.0), + ] +) + + +def _build_slices(patch_shape, stride_shape, raw_dataset, label_dataset): + slice_builder = FilterSliceBuilder(raw_dataset, label_dataset, patch_shape, stride_shape) + return slice_builder.raw_slices, slice_builder.label_slices + + +class VolumetricH5Dataset(Dataset): + """ + Implementation of torch.utils.data.Dataset backed by the HDF5 files, which iterates over the raw and label datasets + patch by patch with a given stride. + + Args: + file_path (str): path to H5 file containing raw data and label data + phase (str): 'train' for training, 'val' for validation, 'test' for testing + patch_shape (tuple[int, int, int]): shape of the patches to be extracted + stride_shape (tuple[int, int, int]): shape of the stride when extracting patches + raw_internal_path (str): H5 internal path to the raw dataset, default is 'raw' + label_internal_path (str): H5 internal path to the label dataset, default is 'label' + instance_ratio (float): ratio of instances to be used for self-training + spoco (bool): if True, the dataset is used for spoco training + """ + + def __init__(self, file_path, phase, patch_shape, stride_shape, + raw_internal_path='raw', label_internal_path='label', + instance_ratio=None, spoco=False): + assert phase in ['train', 'val', 'test'] + + self.phase = phase + self.file_path = file_path + self.raw_internal_path = raw_internal_path + self.label_internal_path = label_internal_path + self.instance_ratio = instance_ratio + self.spoco = spoco + + if phase == 'test': + with h5py.File(file_path, 'r') as f: + self.raw = f[raw_internal_path][:] + self.label = None + else: + with (h5py.File(file_path, 'r') as f): + self.raw = f[raw_internal_path][:] + if phase == 'train' and instance_ratio is not None: + label_internal_path = f'label_{instance_ratio}' + assert label_internal_path in f, (f"Label dataset {label_internal_path} " + f"with instance ratio {instance_ratio} not found") + self.label = f[label_internal_path][:] + # make sure the raw and label datasets have the same shape + assert self.raw.shape[-3:] == self.label.shape[-3:], "Raw and label datasets must have the same shape" + + self.raw_slices, self.label_slices = _build_slices(patch_shape, stride_shape, self.raw, self.label) + + with h5py.File(file_path, 'r') as f: + 
raw = f[raw_internal_path][:] + raw_mean, raw_std = raw.mean(), raw.std() + + self.train_raw_transform = transforms.Compose( + [ + RandomFlip(), + Standardize(mean=raw_mean, std=raw_std), + transforms.ToTensor() + ] + ) + + self.raw_transform = transforms.Compose( + [ + Standardize(mean=raw_mean, std=raw_std), + transforms.ToTensor() + ] + ) + + self.patch_count = len(self.raw_slices) + + def __getitem__(self, idx): + if idx >= len(self): + raise StopIteration + + raw_idx = self.raw_slices[idx] + + raw_img = self.raw[raw_idx][0] + if self.phase == 'train': + seed = np.random.randint(np.iinfo('int32').max) + random.seed(seed) + torch.manual_seed(seed) + raw_patch_transformed = self.train_raw_transform(raw_img) + random.seed(seed) + torch.manual_seed(seed) + label_idx = self.label_slices[idx] + label_img = self.label[label_idx][0] + label_patch_transformed = LABEL_TRANSFORM(label_img) + # remove channel dim + label_patch_transformed = label_patch_transformed[0] + if self.spoco: + raw_patch_transformed1 = EXTENDED_TRANSFORM(raw_patch_transformed) + return raw_patch_transformed, raw_patch_transformed1, label_patch_transformed + else: + return raw_patch_transformed, label_patch_transformed + elif self.phase == 'val': + raw_patch_transformed = self.raw_transform(raw_img) + label_idx = self.label_slices[idx] + label_img = self.label[label_idx][0] + label_patch_transformed = TEST_LABEL_TRANSFORM(label_img) + # remove channel dim + label_patch_transformed = label_patch_transformed[0] + if self.spoco: + return raw_patch_transformed, raw_patch_transformed, label_patch_transformed + return raw_patch_transformed, label_patch_transformed + else: + raw_patch = self.raw_transform(raw_img) + if self.spoco: + return raw_patch, raw_patch, raw_idx + return raw_patch, raw_idx + + def __len__(self): + return self.patch_count + + +class SliceBuilder: + """ + Builds the position of the patches in a given raw/label ndarray based on the patch and stride shape. + + Args: + raw_dataset (ndarray): raw data + label_dataset (ndarray): ground truth labels + patch_shape (tuple): the shape of the patch DxHxW + stride_shape (tuple): the shape of the stride DxHxW + """ + + def __init__(self, raw_dataset, label_dataset, patch_shape, stride_shape): + patch_shape = tuple(patch_shape) + stride_shape = tuple(stride_shape) + + self._raw_slices = self._build_slices(raw_dataset, patch_shape, stride_shape) + if label_dataset is None: + self._label_slices = None + else: + self._label_slices = self._build_slices(label_dataset, patch_shape, stride_shape) + assert len(self._raw_slices) == len(self._label_slices), 'Raw and label slices must have the same length' + + @property + def raw_slices(self): + return self._raw_slices + + @property + def label_slices(self): + return self._label_slices + + @staticmethod + def _build_slices(dataset, patch_shape, stride_shape): + """Iterates over a given n-dim dataset patch-by-patch with a given stride + and builds an array of slice positions. + + Returns: + list of slices, i.e. + [(slice, slice, slice, slice), ...] if len(shape) == 4 + [(slice, slice, slice), ...] 
if len(shape) == 3 + """ + slices = [] + if dataset.ndim == 4: + in_channels, i_z, i_y, i_x = dataset.shape + else: + i_z, i_y, i_x = dataset.shape + + k_z, k_y, k_x = patch_shape + s_z, s_y, s_x = stride_shape + z_steps = SliceBuilder._gen_indices(i_z, k_z, s_z) + for z in z_steps: + y_steps = SliceBuilder._gen_indices(i_y, k_y, s_y) + for y in y_steps: + x_steps = SliceBuilder._gen_indices(i_x, k_x, s_x) + for x in x_steps: + slice_idx = ( + slice(z, z + k_z), + slice(y, y + k_y), + slice(x, x + k_x), + ) + if dataset.ndim == 4: + slice_idx = (slice(0, in_channels),) + slice_idx + slices.append(slice_idx) + return slices + + @staticmethod + def _gen_indices(i, k, s): + assert i >= k, 'Sample size has to be bigger than the patch size' + for j in range(0, i - k + 1, s): + yield j + if j + k < i: + yield i - k + + @staticmethod + def _check_patch_shape(patch_shape): + assert len(patch_shape) == 3, 'patch_shape must be a 3D tuple' + assert patch_shape[1] >= 64 and patch_shape[2] >= 64, 'Height and Width must be greater or equal 64' + + +class FilterSliceBuilder(SliceBuilder): + """ + Filter patches containing more than `1 - threshold` of ignore_index label + """ + + def __init__(self, raw_dataset, label_dataset, patch_shape, stride_shape, ignore_index=None, + threshold=0.05, slack_acceptance=0.01): + super().__init__(raw_dataset, label_dataset, patch_shape, stride_shape) + if label_dataset is None: + return + + rand_state = np.random.RandomState(47) + + def ignore_predicate(raw_label_idx): + label_idx = raw_label_idx[1] + patch = label_dataset[label_idx] + if ignore_index is not None: + patch = np.copy(patch) + patch[patch == ignore_index] = 0 + non_ignore_counts = np.count_nonzero(patch != 0) + non_ignore_counts = non_ignore_counts / patch.size + return non_ignore_counts > threshold or rand_state.rand() < slack_acceptance + + zipped_slices = zip(self.raw_slices, self.label_slices) + # ignore slices containing too much ignore_index + filtered_slices = list(filter(ignore_predicate, zipped_slices)) + # unzip and save slices + raw_slices, label_slices = zip(*filtered_slices) + self._raw_slices = list(raw_slices) + self._label_slices = list(label_slices) diff --git a/spoco/transforms.py b/spoco/transforms.py index 21c327a..84c4967 100644 --- a/spoco/transforms.py +++ b/spoco/transforms.py @@ -4,8 +4,10 @@ import torch import torchvision.transforms.functional as F from PIL import ImageFilter +from PIL import Image from scipy.ndimage import rotate, map_coordinates, gaussian_filter from skimage import measure +from skimage.filters import gaussian class RandomFlip: @@ -16,9 +18,7 @@ class RandomFlip: otherwise the models won't converge. """ - def __init__(self, random_state, axis_prob=0.5, channelwise=False, **kwargs): - assert random_state is not None, 'RandomState cannot be None' - self.random_state = random_state + def __init__(self, axis_prob=0.5, channelwise=False, **kwargs): self.axis_prob = axis_prob self.channelwise = channelwise @@ -30,7 +30,7 @@ def __call__(self, m): axes = range(m.ndim) for axis in axes: - if self.random_state.uniform() > self.axis_prob: + if random.random() > self.axis_prob: if self.channelwise: channels = [np.flip(m[c], axis) for c in range(m.shape[0])] m = np.stack(channels, axis=0) @@ -264,7 +264,7 @@ class Standardize: Apply Z-score normalization to a given input tensor, i.e. re-scaling the values to be 0-mean and 1-std. 
""" - def __init__(self, eps=1e-10, mean=None, std=None, channelwise=False, **kwargs): + def __init__(self, mean=None, std=None, eps=1e-10, channelwise=False, **kwargs): if mean is not None or std is not None: assert mean is not None and std is not None self.mean = mean @@ -287,7 +287,8 @@ def __call__(self, m): mean = np.mean(m) std = np.std(m) - return (m - mean) / np.clip(std, a_min=self.eps, a_max=None) + result = (m - mean) / np.clip(std, a_min=self.eps, a_max=None) + return result.astype(np.float32) class PercentileNormalizer: @@ -457,3 +458,18 @@ def __call__(self, x): sigma = random.uniform(self.sigma[0], self.sigma[1]) x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) return x + + +class GaussianBlurNp: + """Applies multi-dimensional gaussian filter to the input numpy array.""" + + def __init__(self, sigma=[0.5, 2.0], execution_probability=1.0, **kwargs): + self.sigma = sigma + self.execution_probability = execution_probability + + def __call__(self, x): + if random.random() < self.execution_probability: + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = gaussian(x, sigma=sigma) + return x + return x diff --git a/spoco_train.py b/spoco_train.py index 4579e06..7904064 100644 --- a/spoco_train.py +++ b/spoco_train.py @@ -17,12 +17,15 @@ parser.add_argument('--ds-name', type=str, default='cvppp', choices=SUPPORTED_DATASETS, help=f'Name of the dataset from: {SUPPORTED_DATASETS}') parser.add_argument('--ds-path', type=str, required=True, help='Path to the dataset root directory') -parser.add_argument('--things-class', type=str, help='Cityscapes instance class. If None, train with all things classes', +parser.add_argument('--things-class', type=str, + help='Cityscapes instance class. If None, train with all things classes', default=None) parser.add_argument('--instance-ratio', type=float, default=None, help='ratio of ground truth instances that should be taken for training') parser.add_argument('--batch-size', type=int, default=4) parser.add_argument('--num-workers', type=int, default=4) +parser.add_argument('--patch-shape', type=int, nargs="+", help="Patch shape for training", default=[1, 512, 512]) +parser.add_argument('--stride-shape', type=int, nargs="+", help="Stride shape for training", default=[1, 512, 512]) # model config parser.add_argument('--model-name', type=str, default="UNet2D", help="UNet2D or UNet3D") @@ -58,8 +61,10 @@ parser.add_argument('--cos', action='store_true', default=False, help="Use cosine learning rate scheduler") # trainer config +parser.add_argument('--debug', action='store_true', help='Use single GPU instead of DDP', default=False) parser.add_argument('--spoco', action='store_true', default=False, help="Indicate SPOCO training with consistency loss") -parser.add_argument('--save-all-checkpoints', action='store_true', default=False, help="Save checkpoint after every epoch") +parser.add_argument('--save-all-checkpoints', action='store_true', default=False, + help="Save checkpoint after every epoch") parser.add_argument('--checkpoint-dir', type=str, required=True, help="Model and tensorboard logs directory") parser.add_argument('--log-after-iters', type=int, required=True, help="Number of iterations between tensorboard logging") @@ -123,8 +128,12 @@ def main(): torch.backends.cudnn.deterministic = True print('Using CuDNN deterministic setting. 
This may slow down the training!') - nprocs = torch.cuda.device_count() - mp.spawn(train, args=(args,), nprocs=nprocs) + if args.debug: + # debug on a single GPU + train(0, args) + else: + nprocs = torch.cuda.device_count() + mp.spawn(train, args=(args,), nprocs=nprocs) if __name__ == '__main__':
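As a quick way to sanity-check the new MitoEM pipeline without spawning DDP training, the dataset can be instantiated directly. A hypothetical smoke test, assuming `mitoemsampler.py` has already added a `label_0.01` dataset to `train.h5` and using a placeholder path, mirroring the loader arguments used for `--ds-name mitoem`:

```python
from spoco.datasets.volumetric import VolumetricH5Dataset

ds = VolumetricH5Dataset('MITOEM_ROOT_DIR/train.h5', phase='train',
                         patch_shape=(1, 512, 512), stride_shape=(1, 512, 512),
                         instance_ratio=0.01, spoco=True)

# In the train phase with spoco=True each item is two augmented views of the raw patch plus the sparse labels.
raw_view1, raw_view2, labels = ds[0]
print(len(ds), raw_view1.shape, labels.shape)
```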