diff --git a/examples/gold_ball_ptycho_spawn.py b/examples/gold_ball_ptycho_spawn.py new file mode 100644 index 00000000..ae13c9e9 --- /dev/null +++ b/examples/gold_ball_ptycho_spawn.py @@ -0,0 +1,91 @@ +import cdtools +from matplotlib import pyplot as plt +import torch as t + +# If you're noticing that the multi-GPU job is hanging (especially with 100% +# GPU use across all participating devices), you might want to try disabling +# the environment variable NCCL_P2P_DISABLE. +import os +os.environ['NCCL_P2P_DISABLE'] = str(int(True)) + + +# The entire reconstruction script needs to be wrapped in a function +def reconstruct(rank, world_size): + + # In the multigpu setup, we need to explicitly define the rank and + # world_size. The master address and master port should also be + # defined if we don't specify the `init_method` parameter. + cdtools.tools.multigpu.setup(rank=rank, + world_size=world_size, + master_addr='localhost', + master_port='6666') + + filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi' + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) + + pad = 10 + dataset.pad(pad) + dataset.inspect() + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=3, + probe_support_radius=50, + propagation_distance=2e-6, + units='um', + probe_fourier_crop=pad + ) + model.translation_offsets.data += 0.7 * \ + t.randn_like(model.translation_offsets) + model.weights.requires_grad = False + + # We need to manually define the rank parameter for the model, or else + # all plots will be duplicated by the number of GPUs used. + model.rank = rank + + device = 'cuda' + model.to(device=device) + dataset.get_as(device=device) + + # Rank and world_size also needs to be explicitly defined here + recon = cdtools.reconstructors.AdamReconstructor(model, + dataset, + rank=rank, + world_size=world_size) + + with model.save_on_exception( + 'example_reconstructions/gold_balls_earlyexit.h5', dataset): + + for loss in recon.optimize(20, lr=0.005, batch_size=50): + if rank == 0: + print(model.report()) + if model.epoch % 10 == 0: + model.inspect(dataset) + + for loss in recon.optimize(50, lr=0.002, batch_size=100, + schedule=True): + if rank == 0: + print(model.report()) + + if model.epoch % 10 == 0: + model.inspect(dataset) + + for loss in recon.optimize(100, lr=0.001, batch_size=100, + schedule=True): + if rank == 0: + print(model.report()) + if model.epoch % 10 == 0: + model.inspect(dataset) + + cdtools.tools.multigpu.cleanup() + + model.tidy_probes() + model.save_to_h5('example_reconstructions/gold_balls.h5', dataset) + model.inspect(dataset) + model.compare(dataset) + plt.show() + + +if __name__ == '__main__': + # Specify the number of GPUs we want to use, then spawn the multi-GPU job + ngpus = 2 + t.multiprocessing.spawn(reconstruct, args=(ngpus,), nprocs=ngpus) diff --git a/examples/gold_ball_ptycho_speedtest.py b/examples/gold_ball_ptycho_speedtest.py new file mode 100644 index 00000000..7f7a4248 --- /dev/null +++ b/examples/gold_ball_ptycho_speedtest.py @@ -0,0 +1,95 @@ +import cdtools +import torch as t + +# If you're noticing that the multi-GPU job is hanging (especially with 100% +# GPU use across all participating devices), you might want to try disabling +# the environment variable NCCL_P2P_DISABLE. +import os +os.environ['NCCL_P2P_DISABLE'] = str(int(True)) + + +# For running a speed test, we need to add an additional `conn` parameter +# that the speed test uses to send loss-versus-time curves to the +# speed test function. 
+def reconstruct(rank, world_size, conn): + + # We define the setup in the same manner as we did in the spawn example + # (speed test uses spawn) + cdtools.tools.multigpu.setup(rank=rank, + world_size=world_size, + master_addr='localhost', + master_port='6666') + + filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi' + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) + + pad = 10 + dataset.pad(pad) + dataset.inspect() + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=3, + probe_support_radius=50, + propagation_distance=2e-6, + units='um', + probe_fourier_crop=pad + ) + model.translation_offsets.data += 0.7 * \ + t.randn_like(model.translation_offsets) + model.weights.requires_grad = False + + # Unless you're plotting data within the reconstruct function, setting + # the rank parameter is not necessary. + # model.rank = rank + + device = 'cuda' + model.to(device=device) + dataset.get_as(device=device) + + # Rank and world_size also needs to be explicitly defined here + recon = cdtools.reconstructors.AdamReconstructor(model, + dataset, + rank=rank, + world_size=world_size) + + with model.save_on_exception( + 'example_reconstructions/gold_balls_earlyexit.h5', dataset): + + # It is recommended to comment out/remove all plotting-related methods + # for the speed test. + for loss in recon.optimize(20, lr=0.005, batch_size=50): + if rank == 0: + print(model.report()) + + for loss in recon.optimize(50, lr=0.002, batch_size=100): + if rank == 0: + print(model.report()) + + for loss in recon.optimize(100, lr=0.001, batch_size=100, + schedule=True): + if rank == 0: + print(model.report()) + + # We now use the conn parameter to send the loss-versus-time data to the + # main process running the speed test. + conn.send((model.loss_times, model.loss_history)) + + # And, as always, we need to clean up at the end. + cdtools.tools.multigpu.cleanup() + + +if __name__ == '__main__': + # We call the run_speed_test function instead of calling + # t.multiprocessing.spawn. We specify the number of runs + # we want to perform per GPU count along with how many + # GPU counts we want to test. + # + # Here, we test both 1 and 2 GPUs using 3 runs for both. + # The show_plots will report the mean +- standard deviation + # of loss-versus-time/epoch curves across the 3 runs. + # The plot will also show the mean +- standard deviation of + # the GPU-dependent speed ups across the 3 runs. + cdtools.tools.multigpu.run_speed_test(reconstruct, + gpu_counts=(1, 2), + runs=3, + show_plot=True) diff --git a/examples/gold_ball_ptycho_torchrun.py b/examples/gold_ball_ptycho_torchrun.py new file mode 100644 index 00000000..834cb845 --- /dev/null +++ b/examples/gold_ball_ptycho_torchrun.py @@ -0,0 +1,70 @@ +import cdtools +from matplotlib import pyplot as plt +import torch as t + +# At the beginning of the script we need to setup the multi-GPU job +# by initializing the process group and sycnronizing the RNG seed +# across all participating GPUs. +cdtools.tools.multigpu.setup() + +# To avoid redundant print statements, we first grab the GPU "rank" +# (an ID number between 0 and max number of GPUs minus 1). 
+rank = cdtools.tools.multigpu.get_rank() + +filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi' +dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) + +pad = 10 +dataset.pad(pad) +dataset.inspect() +model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=3, + probe_support_radius=50, + propagation_distance=2e-6, + units='um', + probe_fourier_crop=pad +) +model.translation_offsets.data += 0.7 * t.randn_like(model.translation_offsets) +model.weights.requires_grad = False +device = 'cuda' +model.to(device=device) +dataset.get_as(device=device) + +recon = cdtools.reconstructors.AdamReconstructor(model, dataset) + +with model.save_on_exception( + 'example_reconstructions/gold_balls_earlyexit.h5', dataset): + + for loss in recon.optimize(20, lr=0.005, batch_size=50): + # We ensure that only the GPU with rank of 0 runs print statement. + if rank == 0: + print(model.report()) + + # But we don't need to do rank checking for any plotting- or saving- + # related methods; this checking is handled internernally. + if model.epoch % 10 == 0: + model.inspect(dataset) + + for loss in recon.optimize(50, lr=0.002, batch_size=100): + if rank == 0: + print(model.report()) + + if model.epoch % 10 == 0: + model.inspect(dataset) + + for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True): + if rank == 0: + print(model.report()) + if model.epoch % 10 == 0: + model.inspect(dataset) + +# After the reconstruction is completed, we need to cleanup things by +# destroying the process group. +cdtools.tools.multigpu.cleanup() + +model.tidy_probes() +model.save_to_h5('example_reconstructions/gold_balls.h5', dataset) +model.inspect(dataset) +model.compare(dataset) +plt.show() diff --git a/src/cdtools/datasets/base.py b/src/cdtools/datasets/base.py index 3f8ec8c3..cbe7386d 100644 --- a/src/cdtools/datasets/base.py +++ b/src/cdtools/datasets/base.py @@ -17,7 +17,7 @@ from copy import copy import h5py import pathlib -from cdtools.tools import data as cdtdata +from cdtools.tools import data as cdtdata, multigpu from torch.utils import data as torchdata __all__ = ['CDataset'] @@ -92,6 +92,11 @@ def __init__( self.get_as(device='cpu') + # This is a flag related to multi-GPU operation which prevents + # saving/plotting functions from being executed on GPUs outside of + # rank 0 + self.rank = multigpu.get_rank() + def to(self, *args, **kwargs): """Sends the relevant data to the given device and dtype diff --git a/src/cdtools/datasets/ptycho_2d_dataset.py b/src/cdtools/datasets/ptycho_2d_dataset.py index adfe866e..887b8e03 100644 --- a/src/cdtools/datasets/ptycho_2d_dataset.py +++ b/src/cdtools/datasets/ptycho_2d_dataset.py @@ -154,6 +154,8 @@ def from_cxi(cls, cxi_file, cut_zeros=True, load_patterns=True): # Generate a base dataset dataset = CDataset.from_cxi(cxi_file) + + # Mutate the class to this subclass (BasicPtychoDataset) dataset.__class__ = cls @@ -198,7 +200,11 @@ def to_cxi(self, cxi_file): cxi_file : str, pathlib.Path, or h5py.File The .cxi file to write to """ - + # FOR MULTI-GPU: Dont run this block of code if it isn't + # called by the rank 0 GPU + if self.rank != 0: + return + # If a bare string is passed if isinstance(cxi_file, str) or isinstance(cxi_file, pathlib.Path): with cdtdata.create_cxi(cxi_file) as f: @@ -231,7 +237,10 @@ def inspect( can display a base-10 log plot of the detector readout at each position. 
""" - + # FOR MULTI-GPU: Dont run this block of code if it isn't + # called by the rank 0 GPU + if self.rank != 0: + return def get_images(idx): inputs, output = self[idx] @@ -292,6 +301,11 @@ def plot_mean_pattern(self, log_offset=1): level. """ + # FOR MULTI-GPU: Dont run this block of code if it isn't + # called by the rank 0 GPU + if self.rank != 0: + return + mean_pattern, bins, ssnr = analysis.calc_spectral_info(self) cmap_label = f'Log Base 10 of Intensity + {log_offset}' title = 'Scaled mean diffraction pattern' diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index df347b63..e237bbef 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -40,6 +40,7 @@ from scipy import io from contextlib import contextmanager from cdtools.tools.data import nested_dict_to_h5, h5_to_nested_dict, nested_dict_to_numpy, nested_dict_to_torch +from cdtools.tools import multigpu from cdtools.reconstructors import AdamReconstructor, LBFGSReconstructor, SGDReconstructor from cdtools.datasets import CDataset from typing import List, Union, Tuple @@ -65,6 +66,16 @@ def __init__(self): self.training_history = '' self.epoch = 0 + # This is a flag related to multi-GPU operation which prevents + # saving/plotting functions from being executed on GPUs outside of + # rank 0. + self.rank = multigpu.get_rank() + + # Keep track of the time each loss history point was taken relative to + # the initialization of this model. + self.INITIAL_TIME = time.time() + self.loss_times = [] + def from_dataset(self, dataset): raise NotImplementedError() @@ -197,7 +208,9 @@ def save_to_h5(self, filename, *args): *args Accepts any additional args that model.save_results needs, for this model """ - return nested_dict_to_h5(filename, self.save_results(*args)) + # FOR MULTI-GPU: Only run this method if it's called by the rank 0 GPU + if self.rank == 0: + return nested_dict_to_h5(filename, self.save_results(*args)) @contextmanager @@ -219,11 +232,16 @@ def save_on_exit(self, filename, *args, exception_filename=None): """ try: yield - self.save_to_h5(filename, *args) + + # FOR MULTI-GPU: Only run this method if it's called by the rank 0 GPU + if self.rank == 0: + self.save_to_h5(filename, *args) except: - if exception_filename is None: - exception_filename = filename - self.save_to_h5(exception_filename, *args) + # FOR MULTI-GPU: Only run this method if it's called by the rank 0 GPU + if self.rank == 0: + if exception_filename is None: + exception_filename = filename + self.save_to_h5(exception_filename, *args) raise @contextmanager @@ -245,9 +263,11 @@ def save_on_exception(self, filename, *args): try: yield except: - self.save_to_h5(filename, *args) - print('Intermediate results saved under name:') - print(filename, flush=True) + # FOR MULTI-GPU: Only run this method if it's called by the rank 0 GPU + if self.rank == 0: + self.save_to_h5(filename, *args) + print('Intermediate results saved under name:') + print(filename, flush=True) raise @@ -270,6 +290,11 @@ def skip_computation(self): return False def save_checkpoint(self, *args, checkpoint_file=None): + # FOR MULTI-GPU: Dont run this block of code if it isn't + # called by the rank 0 GPU + if self.rank != 0: + return + checkpoint = self.save_results(*args) if (hasattr(self, 'current_optimizer') and self.current_optimizer is not None): @@ -332,7 +357,9 @@ def Adam_optimize( subset: Union[int, List[int]] = None, regularization_factor: Union[float, List[float]] = None, thread=True, - calculation_width=10 + calculation_width=10, + rank=None, + 
world_size=None ): """ Runs a round of reconstruction using the Adam optimizer from @@ -373,14 +400,29 @@ def Adam_optimize( Default 10, how many translations to pass through at once for each round of gradient accumulation. Does not affect the result, only the calculation speed. - + rank : int + Optional, GPU rank assigned during multi-GPU operations. If this + parameter is None, it will be redefined based on the `RANK` + environment variable. If this environment variable doesn't exist, + single-GPU operation will be assumed and a rank of 0 will + automatically be assigned. + world_size : int + Optional, the number of participating GPUs during multi-GPU + operations. If this parameter is None, it will be redefined based on + the `WORLD_SIZE` environment variable. If this environment variable + doesn't exist,single-GPU operation will be assumed and a world_size of + 1 will automatically be assigned. """ + self.rank = rank if rank is not None else multigpu.get_rank() + reconstructor = AdamReconstructor( model=self, dataset=dataset, subset=subset, + rank=self.rank, + world_size=world_size ) - + # Run some reconstructions return reconstructor.optimize( iterations=iterations, @@ -578,6 +620,11 @@ def inspect(self, dataset=None, update=True): Whether to update existing plots or plot new ones """ + # FOR MULTI-GPU: Dont run this block of code if it isn't + # called by the rank 0 GPU + if self.rank != 0: + return + # We find or create all the figures first_update = False if update and hasattr(self, 'figs') and self.figs: @@ -660,7 +707,11 @@ def save_figures(self, prefix='', extension='.pdf'): extention : strategy Default is .eps, the file extension to save with. """ - + # FOR MULTI-GPU: Dont run this block of code if it isn't + # called by the rank 0 GPU + if self.rank != 0: + return + if hasattr(self, 'figs') and self.figs: figs = self.figs else: @@ -687,6 +738,10 @@ def compare(self, dataset, logarithmic=False): logarithmic : bool, default: False Whether to plot the diffraction on a logarithmic scale """ + # FOR MULTI-GPU: Dont run this block of code if it isn't + # called by the rank 0 GPU + if self.rank != 0: + return fig, axes = plt.subplots(1,3,figsize=(12,5.3)) fig.tight_layout(rect=[0.02, 0.09, 0.98, 0.96]) diff --git a/src/cdtools/reconstructors/adam.py b/src/cdtools/reconstructors/adam.py index e298bdb1..780075b3 100644 --- a/src/cdtools/reconstructors/adam.py +++ b/src/cdtools/reconstructors/adam.py @@ -33,9 +33,18 @@ class AdamReconstructor(Reconstructor): The dataset to reconstruct against. subset : list(int) or int Optional, a pattern index or list of pattern indices to use. - schedule : bool - Optional, create a learning rate scheduler - (torch.optim.lr_scheduler._LRScheduler). + rank : int + Optional, GPU rank assigned during multi-GPU operations. If this + parameter is None, it will be redefined based on the `RANK` + environment variable. If this environment variable doesn't exist, + single-GPU operation will be assumed and a rank of 0 will + automatically be assigned. + world_size : int + Optional, the number of participating GPUs during multi-GPU + operations. If this parameter is None, it will be redefined based on + the `WORLD_SIZE` environment variable. If this environment variable + doesn't exist,single-GPU operation will be assumed and a world_size of + 1 will automatically be assigned. Important attributes: - **model** -- Always points to the core model used. 
@@ -49,12 +58,15 @@ class AdamReconstructor(Reconstructor): def __init__(self, model: CDIModel, dataset: Ptycho2DDataset, - subset: List[int] = None): + subset: List[int] = None, + rank: int = None, + world_size: int = None): # Define the optimizer for use in this subclass optimizer = t.optim.Adam(model.parameters()) - super().__init__(model, dataset, optimizer, subset=subset) + super().__init__(model, dataset, optimizer, + subset=subset, rank=rank, world_size=world_size) diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index 16668149..84a482ad 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -17,11 +17,13 @@ import queue import time from typing import List, Union +from cdtools.tools import multigpu +from torch.utils.data.distributed import DistributedSampler if TYPE_CHECKING: from cdtools.models import CDIModel from cdtools.datasets import CDataset - + __all__ = ['Reconstructor'] @@ -44,6 +46,18 @@ class Reconstructor: The optimizer to use for the reconstruction subset : list(int) or int Optional, a pattern index or list of pattern indices to use + rank : int + Optional, GPU rank assigned during multi-GPU operations. If this + parameter is None, it will be redefined based on the `RANK` + environment variable. If this environment variable doesn't exist, + single-GPU operation will be assumed and a rank of 0 will + automatically be assigned. + world_size : int + Optional, the number of participating GPUs during multi-GPU + operations. If this parameter is None, it will be redefined based on + the `WORLD_SIZE` environment variable. If this environment variable + doesn't exist,single-GPU operation will be assumed and a world_size of + 1 will automatically be assigned. Attributes ---------- @@ -60,8 +74,22 @@ def __init__(self, model: CDIModel, dataset: CDataset, optimizer: t.optim.Optimizer, - subset: Union[int, List[int]] = None): - + subset: Union[int, List[int]] = None, + rank: int = None, + world_size: int = None): + + # If we're running multi-GPU jobs, we need to grab some + # information that should be stored as environment variables. + self.rank = rank if rank is not None else multigpu.get_rank() + self.world_size = world_size if world_size is not None \ + else multigpu.get_world_size() + self.multi_gpu_used = int(self.world_size) > 1 + + # Make sure the model and dataset live on the assigned GPUs + if self.multi_gpu_used: + model.to(f'cuda:{self.rank}') + dataset.get_as(f'cuda:{self.rank}') + # Store parameters as attributes of Reconstructor self.model = model self.optimizer = optimizer @@ -79,7 +107,6 @@ def __init__(self, self.scheduler = None self.data_loader = None - def setup_dataloader(self, batch_size: int = None, shuffle: bool = True): @@ -94,14 +121,31 @@ def setup_dataloader(self, Optional, enable/disable shuffling of the dataset. This option is intended for diagnostic purposes and should be left as True. """ - if batch_size is not None: - self.data_loader = td.DataLoader(self.dataset, - batch_size=batch_size, - shuffle=shuffle) + if self.multi_gpu_used: + self.sampler = \ + DistributedSampler(self.dataset, + num_replicas=self.world_size, + rank=self.rank, + shuffle=shuffle, + drop_last=False) + + # Creating extra threads in children processes may cause problems. + # Leave num_workers at 0. 
+ self.data_loader = \ + td.DataLoader(self.dataset, + batch_size=batch_size//self.world_size, + num_workers=0, + drop_last=False, + pin_memory=False, + sampler=self.sampler) else: - self.data_loader = td.Dataloader(self.dataset) + if batch_size is not None: + self.data_loader = td.DataLoader(self.dataset, + batch_size=batch_size, + shuffle=shuffle) + else: + self.data_loader = td.DataLoader(self.dataset) - def adjust_optimizer(self, **kwargs): """ Change hyperparameters for the utilized optimizer. @@ -111,11 +155,10 @@ """ raise NotImplementedError() - def run_epoch(self, - stop_event: threading.Event = None, - regularization_factor: Union[float, List[float]] = None, - calculation_width: int = 10): + stop_event: threading.Event = None, + regularization_factor: Union[float, List[float]] = None, + calculation_width: int = 10): """ Runs one full epoch of the reconstruction. Intended to be called by Reconstructor.optimize. @@ -150,8 +193,18 @@ 'Reconstructor.run_epoch(), or use Reconstructor.optimize(), ' 'which does it automatically.' ) - - + + # If we're using DistributedSampler (i.e., multi-GPU usage), we need + # to tell it which epoch we're on. Otherwise data shuffling will not + # work properly. + if self.multi_gpu_used: + self.data_loader.sampler.set_epoch(self.model.epoch) + + # This prevents other GPU rank processes from initializing + # rank 0's GPU (i.e., helps avoid unnecessary memory consumption) + if t.cuda.current_device() != self.rank: + t.cuda.set_device(self.rank) + # Initialize some tracking variables normalization = 0 loss = 0 @@ -201,8 +254,14 @@ # And accumulate the gradients loss.backward() + # Average and sync gradients + losses for multi-GPU jobs + if self.multi_gpu_used: + multigpu.sync_and_avg_grads(model=self.model, + world_size=self.world_size) + multigpu.sync_loss(loss) + # Normalize the accumulating total loss - total_loss += loss.detach() + total_loss += loss.detach() / self.world_size # If we have a regularizer, we can calculate it separately, # and the gradients will add to the minibatch gradient @@ -212,18 +271,26 @@ loss = self.model.regularizer(regularization_factor) loss.backward() + # Avg and sync gradients for multi-GPU jobs + if self.multi_gpu_used: + multigpu.sync_and_avg_grads(model=self.model, + world_size=self.world_size) return total_loss # This takes the step for this minibatch loss += self.optimizer.step(closure).detach().cpu().numpy() loss /= normalization - # We step the scheduler after the full epoch if self.scheduler is not None: self.scheduler.step(loss) + if self.multi_gpu_used: + multigpu.sync_lr(rank=self.rank, + optimizer=self.optimizer) + self.model.loss_history.append(loss) + self.model.loss_times.append(time.time() - self.model.INITIAL_TIME) self.model.epoch = len(self.model.loss_history) self.model.latest_iteration_time = time.time() - t0 self.model.training_history += self.model.report() + '\n' diff --git a/src/cdtools/tools/__init__.py b/src/cdtools/tools/__init__.py index 6adb554b..b533a3ae 100644 --- a/src/cdtools/tools/__init__.py +++ b/src/cdtools/tools/__init__.py @@ -23,4 +23,4 @@ from cdtools.tools import propagators from cdtools.tools import measurements from cdtools.tools import analysis - +from cdtools.tools import multigpu diff --git a/src/cdtools/tools/multigpu/__init__.py b/src/cdtools/tools/multigpu/__init__.py new file mode 100644 index 00000000..8a78f615 --- /dev/null +++ b/src/cdtools/tools/multigpu/__init__.py @@ -0,0 +1,2 @@
+from cdtools.tools.multigpu.multigpu import * +from cdtools.tools.multigpu.multigpu import __all__, __doc__ diff --git a/src/cdtools/tools/multigpu/multigpu.py b/src/cdtools/tools/multigpu/multigpu.py new file mode 100644 index 00000000..4b7addfa --- /dev/null +++ b/src/cdtools/tools/multigpu/multigpu.py @@ -0,0 +1,432 @@ +"""Contains functions to make reconstruction scripts compatible +with multi-GPU distributed approaches in PyTorch. + +Multi-GPU computing here is based on distributed data parallelism, where +each GPU is given an identical copy of the model and performs optimization +using a different part of the dataset. After the parameter gradients +are calculated (`loss.backward()`) on each GPU, the gradients need to be +synchronized and averaged across all participating GPUs. + +The functions in this module assist with gradient synchronization, +setting up the conditions necessary to perform distributed computing, and +executing multi-GPU jobs. +""" +from __future__ import annotations +from typing import TYPE_CHECKING, Tuple, Callable + +import torch as t +import torch.distributed as dist +import random +import datetime +import os +from matplotlib import pyplot as plt + +if TYPE_CHECKING: + from cdtools.models import CDIModel + +MIN_INT64 = t.iinfo(t.int64).min +MAX_INT64 = t.iinfo(t.int64).max + + +__all__ = ['get_launch_method', + 'get_rank', + 'get_world_size', + 'sync_and_avg_grads', + 'sync_rng_seed', + 'sync_loss', + 'sync_lr', + 'setup', + 'cleanup', + 'run_speed_test'] + + +def get_launch_method() -> str: + """ + Returns the method used to launch the multi-GPU job. + + It is assumed that multi-GPU jobs will be created through + one of two means: `torchrun` or `torch.multiprocessing.spawn` + + Returns: + launch_method: str + The method used to launch multi-GPU jobs. This value + is either 'torchrun' or 'spawn'. + """ + return 'torchrun' if 'TORCHELASTIC_RUN_ID' in os.environ else 'spawn' + + +def get_rank() -> int: + """ + Returns the rank assigned to the current subprocess via the environment + variable `RANK`. If this environment variable does not exist, a rank of 0 + will be returned. + + This value should range from 0 to `world_size`-1 (`world_size` being the + total number of participating subprocesses/GPUs) + + Returns: + rank: int + Rank of the current subprocess. + """ + rank = os.environ.get('RANK') + return int(rank) if rank is not None else 0 + + +def get_world_size() -> int: + """ + Returns the world_size of the reconstruction job via the environment + variable `WORLD_SIZE`. If this environment variable does not exist, a + world_size of 1 will be returned. + + Returns: + world_size: int + The number of participating GPUs + """ + world_size = os.environ.get('WORLD_SIZE') + return int(world_size) if world_size is not None else 1 + + +def sync_and_avg_grads(model: CDIModel, + world_size: int): + """ + Averages the model parameter gradients across all participating GPUs + using all_reduce, leaving every GPU with the same averaged gradients. + + Parameters: + model: CDIModel + Model for CDI/ptychography reconstruction + world_size: int + Number of participating GPUs + """ + for param in model.parameters(): + if param.requires_grad: + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) + param.grad.data /= world_size + + +def sync_rng_seed(seed: int = None, + rank: int = None,): + """ + Synchronizes the random number generator (RNG) seed used by all + participating GPUs. Specifically, all subprocesses will use + either Rank 0's RNG seed or the seed parameter value.
+ + Parameters: + seed: int + Optional. The random number generator seed. + rank: int + Optional, the rank of the current subprocess. If the multi-GPU + job is created with torch.multiprocessing.spawn, this parameter + should be explicitly defined. + """ + if rank is None: + rank = get_rank() + + if seed is None: + seed_local = t.tensor(random.randint(MIN_INT64, MAX_INT64), + device=f'cuda:{rank}', + dtype=t.int64) + dist.broadcast(seed_local, src=0) + seed = seed_local.item() + + t.manual_seed(seed) + + +def sync_lr(optimizer: t.optim, + rank: int = None): + """ + Synchronizes the learning rate of all participating GPUs to that of + Rank 0's GPU. + + Parameters: + optimizer: t.optim + Optimizer used for reconstructions. + rank: int + Optional, the rank of the current subprocess. If the multi-GPU + job is created with torch.multiprocessing.spawn, this parameter + should be explicitly defined. + """ + if rank is None: + rank = get_rank() + + for param_group in optimizer.param_groups: + lr_tensor = t.tensor(param_group['lr'], + device=f'cuda:{rank}') + dist.broadcast(lr_tensor, src=0) + param_group['lr'] = lr_tensor.item() + + +def sync_loss(loss): + """ + Sums the per-GPU loss values across all participating GPUs using + all_reduce, so that every GPU ends up holding the same total loss. + The loss tensor is modified in place. + + Parameters: + loss: t.Tensor + The loss tensor computed on the current GPU. + """ + dist.all_reduce(loss, op=dist.ReduceOp.SUM) + + +def setup(rank: int = None, + world_size: int = None, + init_method: str = 'env://', + master_addr: str = None, + master_port: str = None, + backend: str = 'nccl', + timeout: int = 30, + seed: int = None, + verbose: bool = False): + """ + Sets up the process group and environment variables needed for + communication between the participating subprocesses. Also synchronizes + the RNG seed used across all subprocesses. Currently, `torchrun` and + `torch.multiprocessing.spawn` are supported for starting up multi-GPU jobs. + + This function blocks until all processes have joined. + + The following parameters need to be explicitly defined if the multi-GPU + job is started up by `torch.multiprocessing.spawn`: `rank`, `world_size`, + `master_addr`, `master_port`. If the multi-GPU job is started using + `torchrun`, this function will use the environment variables `torchrun` + provides to define the parameters discussed above. + + For additional details on these parameters, see + https://docs.pytorch.org/docs/stable/distributed.html + + Parameters: + rank: int + Optional, the rank of the current subprocess. + world_size: int + Optional, the number of participating GPUs. + master_addr: str + Optional, address of the rank 0 node. Required, along with + master_port, when the job is launched with + torch.multiprocessing.spawn. + master_port: str + Optional, a free port on the machine hosting rank 0. + init_method: str + URL specifying how to initialize the process group. + Default is 'env://'. + backend: str + Multi-GPU communication backend to use. Default is the 'nccl' + backend, which is the only supported backend for CDTools. + timeout: int + Timeout for operations executed against the process group in + seconds. Default is 30 seconds. + seed: int + Optional. The random number generator seed. + verbose: bool + Optional.
Shows messages indicating the status of the setup procedure. + """ + # Make sure that the user explicitly defines parameters if spawn is used + if get_launch_method() == 'spawn': + if init_method == 'env://': + if None in (master_addr, master_port): + # We'll check if the master address/port is in the env variables + master_addr = os.environ.get('MASTER_ADDR') + master_port = os.environ.get('MASTER_PORT') + else: + # Set up the environment variables + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = master_port + + if None in (rank, world_size, master_addr, master_port): + raise RuntimeError( + 'torch.multiprocessing.spawn was detected as the launching \n' + 'method, but either rank, world_size, master_addr, or \n' + 'master_port has not been explicitly defined. Please ensure \n' + 'that either these parameters have been explicitly defined,\n' + 'MASTER_ADDR/MASTER_PORT have been defined as environment \n' + 'variables, or launch the multi-GPU job with torchrun.\n' + ) + + if rank is None: + rank = get_rank() + if world_size is None: + world_size = get_world_size() + + t.cuda.set_device(rank) + if rank == 0: + print('[INFO]: Initializing process group.') + + dist.init_process_group(rank=rank, + world_size=world_size, + backend=backend, + init_method=init_method, + timeout=datetime.timedelta(seconds=timeout)) + + if rank == 0: + print('[INFO]: Process group initialized.') + + sync_rng_seed(rank=rank, seed=seed) + + if rank == 0: + print('[INFO]: RNG seed synchronized across all subprocesses.') + + +def cleanup(): + """ + Destroys the process group. + """ + rank = get_rank() + dist.destroy_process_group() + if rank == 0: + print('[INFO]: Process group destroyed.') + + +def run_speed_test(fn: Callable, + gpu_counts: Tuple[int], + runs: int = 3, + show_plot: bool = True): + """ + Perform a speed test comparing the performance of multi-GPU reconstructions + against single-GPU reconstructions. Multi-GPU jobs are created using + `torch.multiprocessing.spawn`. + + The speed test is performed by calling a function-wrapped reconstruction + script `runs` times using the number of GPUs specified in the `gpu_counts` + tuple. For each GPU count specified in `gpu_counts`, the + `model.loss_times` and `model.loss_history` values are accumulated + across all the runs to calculate the mean and standard deviation of + the loss-versus-time/epoch and speed-up-versus-GPU-count curves. + + The function-wrapped reconstruction should use the following syntax: + + ``` + def reconstruct(rank, world_size, conn): + cdtools.tools.multigpu.setup(rank=rank, world_size=world_size) + + # Reconstruction script content goes here + + conn.send((model.loss_times, model.loss_history)) + cdtools.tools.multigpu.cleanup() + ``` + The parameters in this example function are defined internally by + `run_speed_test`; the user does not need to define these within the script. + `world_size` is the number of participating GPUs used. `rank` is an ID + number given to the process driving one of the `world_size` GPUs and + takes values from 0 to `world_size`-1. `conn` is one end of a + `multiprocessing.Pipe` connection that allows `run_speed_test` to + retrieve data from the function-wrapped reconstruction. + + Parameters: + fn: Callable + The reconstruction script wrapped in a function. It is recommended + to comment out all plotting and saving-related functions in the + script to properly assess the multi-GPU performance. + gpu_counts: Tuple[int] + Number of GPUs to use for each test run.
The first element must be + 1 (performance is compared against that of a single GPU). + runs: int + Number of repeat reconstructions to perform for each `gpu_counts`. + show_plot: bool + Optional, shows a plot of time/epoch-versus-loss and relative + speedups that summarize the test results. + + Returns: + loss_mean_list and loss_std_list: List[t.Tensor] + The mean and standard deviation of the epoch/time-dependent losses. + Each element corresponds with the GPU count used for that test. + time_mean_list and time_std_list: List[t.Tensor] + The mean and standard deviation of the epoch-dependent times. + Each element corresponds with the GPU count used for that test. + speed_up_mean_list and speed_up_std_list: List[float] + The mean and standard deviation of the GPU-count-dependent + speed ups. Each element corresponds with the GPU count used for + that test. + + """ + # Make sure that the first element of gpu_counts is 1 + if gpu_counts[0] != 1: + raise RuntimeError('The first element of gpu_counts needs to be 1.') + + # Set stuff up for plots + if show_plot: + fig, (ax1, ax2, ax3) = plt.subplots(1, 3) + + # Store values of the different speed-up factors, losses, and times + # as a function of GPU count + loss_mean_list = [] + loss_std_list = [] + time_mean_list = [] + time_std_list = [] + speed_up_mean_list = [] + speed_up_std_list = [] + + # Set up a parent/child connection to get the loss-versus-time + # data from each GPU test run created by t.multiprocessing.spawn(). + parent_conn, child_conn = t.multiprocessing.Pipe() + + for gpus in gpu_counts: + # Make a list to store the loss-versus-time values from each run + time_list = [] + loss_hist_list = [] + + for run in range(runs): + # Spawn the multi-GPU or single-GPU job + print('[INFO]: Starting run ', + f'{run+1}/{runs} on {gpus} GPU(s).') + t.multiprocessing.spawn(fn, + args=(gpus, child_conn), + nprocs=gpus) + + # Get the loss-versus-time data + while parent_conn.poll(): + loss_times, loss_history = parent_conn.recv() + + # Update the loss-versus-time data + time_list.append(loss_times) + loss_hist_list.append(loss_history) + + # Calculate the statistics over the runs performed + loss_mean = t.tensor(loss_hist_list).mean(dim=0) + loss_std = t.tensor(loss_hist_list).std(dim=0) + time_mean = t.tensor(time_list).mean(dim=0)/60 + time_std = t.tensor(time_list).std(dim=0)/60 + + if gpus == 1: # Assumes 1 GPU is used first in the test + time_1gpu = time_mean[-1] + std_1gpu = time_std[-1] + + # Calculate the speed-up relative to using a single GPU + speed_up_mean = time_1gpu / time_mean[-1] + speed_up_std = speed_up_mean * \ + t.sqrt((std_1gpu/time_1gpu)**2 + (time_std[-1]/time_mean[-1])**2) + + # Store the final loss-vs-time and speed-ups + loss_mean_list.append(loss_mean) + loss_std_list.append(loss_std) + time_mean_list.append(time_mean) + time_std_list.append(time_std) + speed_up_mean_list.append(speed_up_mean.item()) + speed_up_std_list.append(speed_up_std.item()) + + # Add another loss-versus-epoch/time curve + if show_plot: + ax1.errorbar(time_mean, loss_mean, yerr=loss_std, xerr=time_std, + label=f'{gpus} GPUs') + ax2.errorbar(t.arange(0, loss_mean.shape[0]), loss_mean, + yerr=loss_std, label=f'{gpus} GPUs') + ax3.errorbar(gpus, speed_up_mean, yerr=speed_up_std, fmt='o') + + print('[INFO]: Speed test completed.') + + if show_plot: + fig.suptitle(f'Multi-GPU performance test | {runs} runs performed') + ax1.set_yscale('log') + ax1.set_xscale('linear') + ax2.set_yscale('log') + ax2.set_xscale('linear') + ax3.set_yscale('linear') + 
ax3.set_xscale('linear') + ax1.legend() + ax2.legend() + ax1.set_xlabel('Time (min)') + ax1.set_ylabel('Loss') + ax2.set_xlabel('Epochs') + ax3.set_xlabel('Number of GPUs') + ax3.set_ylabel('Speed-up relative to single GPU') + plt.show() + + return loss_mean_list, loss_std_list, \ + time_mean_list, time_std_list, \ + speed_up_mean_list, speed_up_std_list diff --git a/tests/conftest.py b/tests/conftest.py index f0faea57..9de16ca3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,21 +32,32 @@ def pytest_addoption(parser): default=False, help="run slow tests, primarily full reconstruction tests." ) + parser.addoption( + "--runmultigpu", + action="store_true", + default=False, + help="Runs tests using 2 NVIDIA CUDA GPUs." ) def pytest_configure(config): config.addinivalue_line("markers", "slow: mark test as slow to run") + config.addinivalue_line("markers", "multigpu: run the multigpu test using 2 NVIDIA GPUs") def pytest_collection_modifyitems(config, items): - if config.getoption("--runslow"): - # --runslow given in cli: do not skip slow tests - return + # Skip the slow and/or multigpu tests unless --runslow and/or + # --runmultigpu is given on the command line. skip_slow = pytest.mark.skip(reason="need --runslow option to run") + skip_multigpu = pytest.mark.skip(reason='need --runmultigpu option to run') + for item in items: - if "slow" in item.keywords: + if "slow" in item.keywords and not config.getoption("--runslow"): item.add_marker(skip_slow) + if "multigpu" in item.keywords and not config.getoption("--runmultigpu"): + item.add_marker(skip_multigpu) + @pytest.fixture def reconstruction_device(request): @@ -415,3 +426,9 @@ } return [test_dict_1, test_dict_2, test_dict_3] + + +@pytest.fixture(scope='module') +def multigpu_script(pytestconfig): + return str(pytestconfig.rootpath) + \ + '/tests/multi_gpu/multi_gpu_script_plot_and_save.py' diff --git a/tests/multi_gpu/multi_gpu_script_plot_and_save.py b/tests/multi_gpu/multi_gpu_script_plot_and_save.py new file mode 100644 index 00000000..3f90f61c --- /dev/null +++ b/tests/multi_gpu/multi_gpu_script_plot_and_save.py @@ -0,0 +1,85 @@ +import cdtools +from cdtools.tools import multigpu +import os +from matplotlib import pyplot as plt + +rank = multigpu.get_rank() +world_size = multigpu.get_world_size() +cdtools.tools.multigpu.setup(rank=rank, world_size=world_size) + +filename = os.environ.get('CDTOOLS_TESTING_DATA_PATH') +savedir = os.environ.get('CDTOOLS_TESTING_TMP_PATH') +SHOW_PLOT = bool(int(os.environ.get('CDTOOLS_TESTING_SHOW_PLOT'))) + +print('DONT CLOSE ANY OF THE FIGURES OR THE TEST WILL FAIL!') +dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) + +model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=3, + oversampling=2, + probe_support_radius=120, + propagation_distance=5e-3, + units='mm', + obj_view_crop=-50, +) + +device = 'cuda' +model.to(device=device) +dataset.get_as(device=device) + +# Test Ptycho2DDataset.inspect +if SHOW_PLOT: + dataset.inspect() + +# Test Ptycho2DDataset.to_cxi +filename_to_cxi = os.path.join(savedir, + f'RANK_{model.rank}_test_to_cxi.h5') +dataset.to_cxi(filename_to_cxi) + +# Test CDIModel.save_to_h5 +filename_save_to_h5 = os.path.join(savedir, + f'RANK_{model.rank}_test_save_to.h5') +model.save_to_h5(filename_save_to_h5, dataset) + +# Test CDIModel.save_on_exit(), CDIModel.inspect() +filename_save_on_exit = os.path.join(savedir, + f'RANK_{model.rank}_test_save_on_exit.h5') + +with model.save_on_exit(filename_save_on_exit, dataset): + for loss
in model.Adam_optimize(5, dataset, lr=0.02, batch_size=40): + if rank == 0: + print(model.report()) + if SHOW_PLOT: + model.inspect(dataset) + +if SHOW_PLOT: + # Test CDIModel.compare(dataset) + model.compare(dataset) + + # Test CDIModel.save_figures() + filename_save_figures = os.path.join(savedir, + f'RANK_{model.rank}_test_plot_') + model.save_figures(prefix=filename_save_figures, + extension='.png') + + plt.close('all') + +# Test CDIModel.save_checkpoint +filename_save_checkpoint = \ + os.path.join(savedir, f'RANK_{model.rank}_test_save_checkpoint.pt') +model.save_checkpoint(dataset, checkpoint_file=filename_save_checkpoint) + +# Test CDIModel.save_on_exception() +filename_save_on_except = \ + os.path.join(savedir, f'RANK_{model.rank}_test_save_on_except.h5') + +with model.save_on_exception(filename_save_on_except, dataset): + for loss in model.Adam_optimize(10, dataset, lr=0.02, batch_size=40): + if rank == 0 and model.epoch <= 10: + print(model.report()) + elif model.epoch > 10: + raise Exception('This is a deliberate exception raised to ' + + 'test save on exception') + +cdtools.tools.multigpu.cleanup() diff --git a/tests/multi_gpu/test_multi_gpu.py b/tests/multi_gpu/test_multi_gpu.py new file mode 100644 index 00000000..57be53b6 --- /dev/null +++ b/tests/multi_gpu/test_multi_gpu.py @@ -0,0 +1,171 @@ +import cdtools +from cdtools.tools import multigpu +import pytest +import os +import subprocess +import torch as t + +""" +This file contains several tests that are relevant to running multi-GPU +operations in CDTools. +""" + + +def reconstruct(rank, world_size, conn): + """ + An example reconstruction script to test the performance of 1 vs 2 GPU + operation. + """ + filename = os.environ.get('CDTOOLS_TESTING_GOLD_BALL_PATH') + cdtools.tools.multigpu.setup(rank=rank, + world_size=world_size) + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename) + + pad = 10 + dataset.pad(pad) + dataset.inspect() + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=3, + probe_support_radius=50, + propagation_distance=2e-6, + units='um', + probe_fourier_crop=pad + ) + model.translation_offsets.data += 0.7 * \ + t.randn_like(model.translation_offsets) + model.weights.requires_grad = False + + device = 'cuda' + model.to(device=device) + dataset.get_as(device=device) + + recon = cdtools.reconstructors.AdamReconstructor(model, + dataset, + rank=rank, + world_size=world_size) + + for loss in recon.optimize(10, lr=0.005, batch_size=50): + if rank == 0 and model.epoch == 10: + print(model.report()) + conn.send((model.loss_times, model.loss_history)) + cdtools.tools.multigpu.cleanup() + + +@pytest.mark.multigpu +def test_plotting_saving_torchrun(lab_ptycho_cxi, + multigpu_script, + tmp_path, + show_plot): + """ + Run a multi-GPU test via torchrun on a script that executes several + plotting and file-saving methods from CDIModel and ensure they run + without failure. + + Also, make sure that only 1 GPU is generating the plots. + + If this test fails, one of three things happened: + 1) Either something failed while multigpu_script_2 was called + 2) Somehow, something aside from Rank 0 saved results + 3) multigpu_script_2 was not able to save all the data files + we asked it to save. 
+ """ + # Run the test script, which generates several files that either have + # the prefix + cmd = ['torchrun', + '--standalone', + '--nnodes=1', + '--nproc_per_node=2', + multigpu_script] + + child_env = os.environ.copy() + child_env['CDTOOLS_TESTING_DATA_PATH'] = lab_ptycho_cxi + child_env['CDTOOLS_TESTING_TMP_PATH'] = str(tmp_path) + child_env['CDTOOLS_TESTING_SHOW_PLOT'] = str(int(show_plot)) + + try: + subprocess.run(cmd, check=True, env=child_env) + except subprocess.CalledProcessError: + # The called script is designed to throw an exception. + # TODO: Figure out how to distinguish between the engineered error + # in the script versus any other error. + pass + + # Check if all the generated file names only have the prefix 'RANK_0' + filelist = [f for f in os.listdir(tmp_path) + if os.path.isfile(os.path.join(tmp_path, f))] + + assert all([file.startswith('RANK_0') for file in filelist]) + print('All files have the RANK_0 prefix.') + + # Check if plots have been saved + if show_plot: + print('Plots generated: ' + + f"{sum([file.startswith('RANK_0_test_plot') for file in filelist])}") # noqa + assert any([file.startswith('RANK_0_test_plot') for file in filelist]) + else: + print('--plot not enabled. Checks on plotting and figure saving' + + ' will not be conducted.') + + # Check if we have all five data files saved + file_output_suffix = ('test_save_checkpoint.pt', + 'test_save_on_exit.h5', + 'test_save_on_except.h5', + 'test_save_to.h5', + 'test_to_cxi.h5') + + print(f'{sum([file.endswith(file_output_suffix) for file in filelist])}' + + ' out of 5 data files have been generated.') + assert sum([file.endswith(file_output_suffix) for file in filelist]) \ + == len(file_output_suffix) + + +@pytest.mark.multigpu +def test_reconstruction_quality_spawn(gold_ball_cxi, + show_plot): + """ + Run a multi-GPU speed test based on gold_ball_ptycho_speedtest.py + and make sure the final reconstructed loss using 2 GPUs is similar + to 1 GPU. + + This test requires us to have 2 NVIDIA GPUs and makes use of the + multi-GPU speed test. + + If this test fails, it indicates that the reconstruction quality is + getting noticably worse with increased GPU counts. This may be a symptom + of a synchronization/broadcasting issue between the different GPUs. + """ + # Make the gold_ball_cxi file path visible to the reconstruct function + os.environ['CDTOOLS_TESTING_GOLD_BALL_PATH'] = gold_ball_cxi + + loss_mean_list, loss_std_list, \ + _, _, speed_up_mean_list, speed_up_std_list\ + = multigpu.run_speed_test(fn=reconstruct, + gpu_counts=(1, 2), + runs=3, + show_plot=show_plot) + + # Make sure that the final loss values between the 1 and 2 GPU tests + # are comprable to within 1 std of each other. + single_gpu_loss_mean = loss_mean_list[0][-1] + single_gpu_loss_std = loss_std_list[0][-1] + double_gpu_loss_mean = loss_mean_list[1][-1] + double_gpu_loss_std = loss_std_list[1][-1] + + single_gpu_loss_min = single_gpu_loss_mean - single_gpu_loss_std + single_gpu_loss_max = single_gpu_loss_mean + single_gpu_loss_std + multi_gpu_loss_min = double_gpu_loss_mean - double_gpu_loss_std + multi_gpu_loss_max = double_gpu_loss_mean + double_gpu_loss_std + + has_loss_overlap = \ + min(single_gpu_loss_max, multi_gpu_loss_max)\ + > max(single_gpu_loss_min, multi_gpu_loss_min) + + assert has_loss_overlap + + # Make sure the loss mean falls below 3.2e-4. The values of losses I + # recorded at the time of testing were <3.19 e-4. + assert double_gpu_loss_mean < 3.2e-4 + + # Make sure that we have some speed up... 
+ assert speed_up_mean_list[0] < speed_up_mean_list[1]
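
The diff above implements distributed data parallelism by hand: each rank draws its own shard through a `DistributedSampler`, gradients are summed with `all_reduce` and divided by the world size (`sync_and_avg_grads`), and the per-rank losses are summed (`sync_loss`) before reporting. The sketch below exercises that same communication pattern outside of CDTools. It is an illustrative example, not code from this PR: the one-parameter toy model is invented for the demonstration, and it uses the 'gloo' backend so it can run on a CPU-only machine, whereas the PR itself targets the 'nccl' backend on CUDA devices. The localhost/6666 rendezvous mirrors the choice made in the example scripts.

```python
import os
import torch as t
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Minimal stand-in for cdtools.tools.multigpu.setup(), using the 'gloo'
    # backend so the sketch also runs without GPUs.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '6666')
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # A toy "model": a single parameter fit to a rank-dependent data shard.
    param = t.nn.Parameter(t.zeros(1))
    opt = t.optim.Adam([param], lr=0.1)
    shard = t.full((4,), float(rank))  # each rank sees different data

    for _ in range(5):
        opt.zero_grad()
        loss = ((param - shard) ** 2).mean()
        loss.backward()

        # Same pattern as sync_and_avg_grads(): sum the gradients across
        # ranks with all_reduce, then divide by world_size to average them.
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= world_size

        # Same pattern as sync_loss(): sum the per-rank losses so every rank
        # can report the same value (divided by world_size when printing,
        # as Reconstructor.run_epoch does).
        loss_sum = loss.detach().clone()
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)

        opt.step()
        if rank == 0:
            print(f'loss={loss_sum.item() / world_size:.4f}, '
                  f'param={param.item():.4f}')

    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```

For launching the real examples, the torchrun-based script is started with the same flags used in `test_plotting_saving_torchrun`, e.g. `torchrun --standalone --nnodes=1 --nproc_per_node=2 examples/gold_ball_ptycho_torchrun.py`, while the spawn-based example launches itself through `t.multiprocessing.spawn`.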