From b0eb1c5b831dfe00ffee9a5d9432bc3699453015 Mon Sep 17 00:00:00 2001 From: Samuel Weinbach Date: Wed, 7 Sep 2022 12:44:12 +0000 Subject: [PATCH 01/13] Allreduce bfloat in fp32 in zero --- deepspeed/runtime/engine.py | 1417 +++++++++++++---------- deepspeed/runtime/pipe/engine.py | 2 +- deepspeed/runtime/zero/stage_1_and_2.py | 2 +- 3 files changed, 838 insertions(+), 583 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 1fd7d7e964e8..56b9a91524a5 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -23,34 +23,56 @@ from .zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus -from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException +from deepspeed.runtime.zero.utils import ( + is_zero_supported_optimizer, + ZeRORuntimeException, +) from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.bf16_optimizer import BF16_Optimizer -from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \ - ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ - TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER +from deepspeed.runtime.config import ( + DeepSpeedConfig, + DEEPSPEED_OPTIMIZERS, + ADAGRAD_OPTIMIZER, + ADAM_OPTIMIZER, + ADAMW_OPTIMIZER, + LAMB_OPTIMIZER, + ONEBIT_ADAM_OPTIMIZER, + ONEBIT_LAMB_OPTIMIZER, + TORCH_ADAM_PARAM, + ADAM_W_MODE, + ADAM_W_MODE_DEFAULT, + ZERO_ONE_ADAM_OPTIMIZER, +) from deepspeed.runtime.dataloader import DeepSpeedDataLoader -from deepspeed.runtime.constants import \ - ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ - PLD_THETA, PLD_GAMMA, BFLOAT16, FP16 +from deepspeed.runtime.constants import ( + ROUTE_TRAIN, + ROUTE_PREDICT, + ROUTE_EVAL, + PLD_THETA, + PLD_GAMMA, + BFLOAT16, + FP16, +) from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.compression import compression_scheduler -from deepspeed.compression.constants import \ - WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \ - WEIGHT_QUANTIZATION, SHARED_PARAMETERS, \ - WEIGHT_QUANTIZE_ENABLED, \ - WEIGHT_QUANTIZE_GROUPS, \ - WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \ - WEIGHT_QUANTIZE_CHANGE_RATIO, \ - WEIGHT_QUANTIZE_TYPE, \ - WEIGHT_QUANTIZE_ROUNDING, \ - WEIGHT_QUANTIZE_VERBOSE, \ - WEIGHT_QUANTIZE_KERNEL +from deepspeed.compression.constants import ( + WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, + WEIGHT_QUANTIZATION, + SHARED_PARAMETERS, + WEIGHT_QUANTIZE_ENABLED, + WEIGHT_QUANTIZE_GROUPS, + WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, + WEIGHT_QUANTIZE_CHANGE_RATIO, + WEIGHT_QUANTIZE_TYPE, + WEIGHT_QUANTIZE_ROUNDING, + WEIGHT_QUANTIZE_VERBOSE, + WEIGHT_QUANTIZE_KERNEL, +) from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT from deepspeed.runtime.sparse_tensor import SparseTensor @@ -64,7 +86,9 @@ from deepspeed.runtime.utils import clip_grad_norm_ from deepspeed.runtime.eigenvalue import Eigenvalue from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler -from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine +from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import ( + TorchCheckpointEngine, +) from .pipe.module import 
PipelineModule from .utils import ensure_directory_exists, get_ma_status @@ -83,13 +107,15 @@ MEMORY_OPT_ALLREDUCE_SIZE = 500000000 -DeepSpeedOptimizerCallable = \ - Callable[[Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer] +DeepSpeedOptimizerCallable = Callable[ + [Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer +] DeepSpeedSchedulerCallable = Callable[[Optimizer], _LRScheduler] try: import apex from apex import amp + APEX_INSTALLED = True except ImportError: # Fail silently so we don't spam logs unnecessarily if user isn't using amp @@ -103,11 +129,13 @@ def split_half_float_double_sparse(tensors): "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor", - SparseTensor.type() + SparseTensor.type(), ] for t in tensors: - assert t.type() in supported_types, f"attempting to reduce an unsupported grad type: {t.type()}" + assert ( + t.type() in supported_types + ), f"attempting to reduce an unsupported grad type: {t.type()}" buckets = [] for i, dtype in enumerate(supported_types): @@ -124,20 +152,21 @@ def print_configuration(args, name): logger.info(" {} {} {}".format(arg, dots, getattr(args, arg))) -FORWARD_MICRO_TIMER = 'forward_microstep' -FORWARD_GLOBAL_TIMER = 'forward' -BACKWARD_MICRO_TIMER = 'backward_microstep' -BACKWARD_GLOBAL_TIMER = 'backward' -BACKWARD_INNER_MICRO_TIMER = 'backward_inner_microstep' -BACKWARD_INNER_GLOBAL_TIMER = 'backward_inner' -BACKWARD_REDUCE_MICRO_TIMER = 'backward_allreduce_microstep' -BACKWARD_REDUCE_GLOBAL_TIMER = 'backward_allreduce' -STEP_MICRO_TIMER = 'step_microstep' -STEP_GLOBAL_TIMER = 'step' +FORWARD_MICRO_TIMER = "forward_microstep" +FORWARD_GLOBAL_TIMER = "forward" +BACKWARD_MICRO_TIMER = "backward_microstep" +BACKWARD_GLOBAL_TIMER = "backward" +BACKWARD_INNER_MICRO_TIMER = "backward_inner_microstep" +BACKWARD_INNER_GLOBAL_TIMER = "backward_inner" +BACKWARD_REDUCE_MICRO_TIMER = "backward_allreduce_microstep" +BACKWARD_REDUCE_GLOBAL_TIMER = "backward_allreduce" +STEP_MICRO_TIMER = "step_microstep" +STEP_GLOBAL_TIMER = "step" class EngineTimers(object): r"""Wallclock timers for DeepSpeedEngine""" + def __init__(self, enable_micro_timers, enable_global_timers): self.forward_timers = [] self.backward_timers = [] @@ -158,7 +187,7 @@ def __init__(self, enable_micro_timers, enable_global_timers): BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, - STEP_MICRO_TIMER + STEP_MICRO_TIMER, ] if enable_global_timers: @@ -172,12 +201,13 @@ def __init__(self, enable_micro_timers, enable_global_timers): BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, - STEP_GLOBAL_TIMER + STEP_GLOBAL_TIMER, ] class DeepSpeedEngine(Module): r"""DeepSpeed engine for training.""" + def __init__( self, args, @@ -223,12 +253,15 @@ def __init__( self.moe_layers = [] self._step_applied = False self._global_grad_norm = None - self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. + self.use_ds_comm = ( + False # False --> Use torch.dist, True --> Use ds.comm backend. 
+ ) self.checkpoint_engine = None global dist from deepspeed import comm as dist + self._is_gradient_accumulation_boundary = None self.scale_wrt_gas = None @@ -243,10 +276,12 @@ def __init__( self.config = config_params from deepspeed.comm import supported_torch_version + # This supported_torch_version check is for torch1.2 compatibility only if supported_torch_version: - dist.init_distributed(dist_backend=self.dist_backend, - dist_init_required=dist_init_required) + dist.init_distributed( + dist_backend=self.dist_backend, dist_init_required=dist_init_required + ) else: if dist_init_required is None: dist_init_required = not dist.is_initialized() @@ -262,13 +297,15 @@ def __init__( self._do_args_sanity_check(args) self._configure_with_arguments(args, mpu) self._do_sanity_check() - see_memory_usage(f"DeepSpeed Engine: After args sanity test", - force=self.memory_breakdown()) + see_memory_usage( + f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown() + ) if mpu is not None: if self.elasticity_enabled(): if not self.is_elastic_model_parallel_supported(): assert not self.elasticity_enabled(), ( - "Elasticity is not currently supported" " with model parallelism." + "Elasticity is not currently supported" + " with model parallelism." ) self._set_distributed_vars(args) @@ -301,8 +338,10 @@ def __init__( monitor_memory=False, ) - log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", - ranks=[0]) + log_dist( + f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", + ranks=[0], + ) if self.flops_profiler_enabled(): self.flops_profiler = FlopsProfiler(self.module, self) @@ -332,12 +371,14 @@ def __init__( self.sparse_tensor_module_names = set() # if self.sparse_gradients_enabled(): for name, module in self.module.named_modules(): - if isinstance(module, - (torch.nn.Embedding, - torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled(): + if ( + isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) + and self.sparse_gradients_enabled() + ): self.sparse_tensor_module_names.add(name + ".weight") logger.info( - "Will convert {} to sparse tensor during training".format(name)) + "Will convert {} to sparse tensor during training".format(name) + ) self.save_non_zero_checkpoint = False self.save_zero_checkpoint = False @@ -358,7 +399,8 @@ def __init__( self.engine_timers = EngineTimers( enable_micro_timers=self.wall_clock_breakdown(), enable_global_timers=self.wall_clock_breakdown() - or self.flops_profiler_enabled()) + or self.flops_profiler_enabled(), + ) if self.global_rank == 0: self._config.print("DeepSpeedEngine configuration") @@ -371,7 +413,7 @@ def __init__( self.unflatten = util_ops.unflatten def destroy(self): - if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): + if self.optimizer is not None and hasattr(self.optimizer, "destroy"): self.optimizer.destroy() def _get_model_parameters(self): @@ -391,10 +433,12 @@ def _get_model_parameters(self): if p.requires_grad: trainable_num_params += n if self.global_rank == 0: - self.autotuning_model_info[ - "num_params"] = num_params * self.mp_world_size - self.autotuning_model_info[ - "trainable_num_params"] = trainable_num_params * self.mp_world_size + self.autotuning_model_info["num_params"] = ( + num_params * self.mp_world_size + ) + self.autotuning_model_info["trainable_num_params"] = ( + trainable_num_params * self.mp_world_size + ) logger.info(f"model parameter = {num_params}") @@ -424,13 +468,18 @@ def set_train_batch_size(self, train_batch_size): ValueError: if 
``train_batch_size`` is not divisible by the configured micro-batch size and data parallelism. """ - if train_batch_size % (self.train_micro_batch_size_per_gpu() * - self.dp_world_size) != 0: - #print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}') + if ( + train_batch_size + % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) + != 0 + ): + # print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}') raise ValueError( - f'Train batch size must be divisible by micro-batch data parallelism') - new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * - self.dp_world_size) + f"Train batch size must be divisible by micro-batch data parallelism" + ) + new_gas = train_batch_size // ( + self.train_micro_batch_size_per_gpu() * self.dp_world_size + ) # overwrite config self._config.train_batch_size = train_batch_size self._config.gradient_accumulation_steps = new_gas @@ -454,14 +503,15 @@ def __getattr__(self, name): _module = {} if "module" in self.__dict__: - _module = self.__dict__['module'] + _module = self.__dict__["module"] if name in dir(self): return getattr(self, name) elif name in dir(_module): return getattr(_module, name) else: raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'") + f"'{type(self).__name__}' object has no attribute '{name}'" + ) def checkpoint_tag_validation_enabled(self): return self._config.checkpoint_tag_validation_enabled @@ -475,7 +525,11 @@ def elasticity_enabled(self): def is_elastic_model_parallel_supported(self): if self.elasticity_enabled(): # Add code for finding number of GPUs per node automatically - if self._config.num_gpus_per_node % self._config.elastic_model_parallel_size == 0: + if ( + self._config.num_gpus_per_node + % self._config.elastic_model_parallel_size + == 0 + ): return True else: return False @@ -576,10 +630,11 @@ def autotuning_metric(self): return self._config.autotuning_config.metric def autotuning_profile_model_info(self): - return self.autotuning_enabled( - ) and self._config.autotuning_config.model_info and self._config.autotuning_config.model_info.get( - "profile", - False) + return ( + self.autotuning_enabled() + and self._config.autotuning_config.model_info + and self._config.autotuning_config.model_info.get("profile", False) + ) def sparse_gradients_enabled(self): return self._config.sparse_gradients_enabled @@ -591,8 +646,11 @@ def train_micro_batch_size_per_gpu(self): return self._config.train_micro_batch_size_per_gpu def optimizer_name(self): - return (self.client_optimizer.__class__.__name__ - if self.client_optimizer else self._config.optimizer_name) + return ( + self.client_optimizer.__class__.__name__ + if self.client_optimizer + else self._config.optimizer_name + ) def optimizer_params(self): return self._config.optimizer_params @@ -608,24 +666,33 @@ def scheduler_params(self): def quantize_training(self): return ( - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_IN_FORWARD_ENABLED], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_ENABLED], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_GROUPS], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_CHANGE_RATIO], - 
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_TYPE], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_ROUNDING], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_VERBOSE], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_KERNEL], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_IN_FORWARD_ENABLED + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_ENABLED + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_GROUPS + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_CHANGE_RATIO + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_TYPE + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_ROUNDING + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_VERBOSE + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_KERNEL + ], ) def zero_optimization(self): @@ -650,13 +717,16 @@ def zero_use_cpu_optimizer(self): if self._config.zero_config.offload_optimizer is not None: return self._config.zero_config.offload_optimizer.device in [ OffloadDeviceEnum.cpu, - OffloadDeviceEnum.nvme + OffloadDeviceEnum.nvme, ] return False def zero_cpu_offload(self): if self._config.zero_config.offload_optimizer is not None: - return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu + return ( + self._config.zero_config.offload_optimizer.device + == OffloadDeviceEnum.cpu + ) return False def zero_sub_group_size(self): @@ -745,10 +815,10 @@ def communication_data_type(self): res = self._config.communication_data_type if res is not None: return res - elif self.fp16_enabled() or self.zero_optimization_stage(): + elif self.fp16_enabled(): return torch.float16 elif self.bfloat16_enabled(): - return torch.bfloat16 + return torch.float32 return torch.float32 @@ -794,28 +864,33 @@ def _configure_lr_scheduler(self, client_lr_scheduler): if lr_scheduler: log_dist( f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", - ranks=[0]) + ranks=[0], + ) self.lr_scheduler = lr_scheduler else: if isinstance(client_lr_scheduler, Callable): - log_dist('DeepSpeed using client callable to create LR scheduler', - ranks=[0]) + log_dist( + "DeepSpeed using client callable to create LR scheduler", ranks=[0] + ) self.lr_scheduler = client_lr_scheduler(self.basic_optimizer) else: - log_dist('DeepSpeed using client LR scheduler', ranks=[0]) + log_dist("DeepSpeed using client LR scheduler", ranks=[0]) self.lr_scheduler = client_lr_scheduler - log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) + log_dist(f"DeepSpeed LR Scheduler = {self.lr_scheduler}", ranks=[0]) def _configure_checkpointing(self, dist_init_required): self.checkpoint_engine = TorchCheckpointEngine() if self._config is not None and self._config.nebula_config.enabled: try: - from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ - NebulaCheckpointEngine + from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import ( + NebulaCheckpointEngine, + ) + 
self.checkpoint_engine = NebulaCheckpointEngine( - config_params=self._config.nebula_config) + config_params=self._config.nebula_config + ) except ImportError as err: logger.error( f"No torch_nebula was found! Will fall back to torch.save. Details: {err}" @@ -828,7 +903,8 @@ def _configure_checkpointing(self, dist_init_required): # only the first data parallel process needs to store the model checkpoint self.save_non_zero_checkpoint = ( - dp_rank == 0) or self.zero_optimization_partition_weights() + dp_rank == 0 + ) or self.zero_optimization_partition_weights() if self.zero_optimization() or self.bfloat16_enabled(): param_rank = dist.get_rank(group=self.optimizer.dp_process_group) @@ -856,9 +932,11 @@ def _scheduler_from_config(self, optimizer): return None def _set_distributed_vars(self, args): - device_rank = args.device_rank if args is not None and hasattr( - args, - 'device_rank') else self.local_rank + device_rank = ( + args.device_rank + if args is not None and hasattr(args, "device_rank") + else self.local_rank + ) if device_rank >= 0: torch.cuda.set_device(device_rank) self.device = torch.device("cuda", device_rank) @@ -878,19 +956,21 @@ def _configure_with_arguments(self, args, mpu): if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") - local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank) - assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \ + local_rank = os.environ.get("LOCAL_RANK", ompi_local_rank) + assert ompi_local_rank == local_rank, ( + f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " "not sure how to proceed as we're seeing conflicting local rank info." - os.environ['LOCAL_RANK'] = local_rank + ) + os.environ["LOCAL_RANK"] = local_rank - self.local_rank = int(os.environ['LOCAL_RANK']) - if hasattr(args, 'local_rank'): + self.local_rank = int(os.environ["LOCAL_RANK"]) + if hasattr(args, "local_rank"): args.local_rank = self.local_rank if self.config is None: - self.config = (args.deepspeed_config - if hasattr(args, - "deepspeed_config") else None) + self.config = ( + args.deepspeed_config if hasattr(args, "deepspeed_config") else None + ) self._config = DeepSpeedConfig(self.config, mpu) # Validate command line arguments @@ -905,13 +985,18 @@ def _do_args_sanity_check(self, args): ), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" args.deepspeed_config = args.deepscale_config - assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \ - "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \ + assert ( + "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ + ), ( + "DeepSpeed requires the LOCAL_RANK environment " + "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " "different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." 
+ ) - if hasattr(args, 'local_rank') and args.local_rank != None: + if hasattr(args, "local_rank") and args.local_rank != None: assert isinstance( - args.local_rank, int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" + args.local_rank, int + ), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" if args.local_rank >= 0: env_local_rank = int(os.environ.get("LOCAL_RANK")) assert ( @@ -920,8 +1005,7 @@ def _do_args_sanity_check(self, args): if self.config is None: assert ( - hasattr( - args, "deepspeed_config") and args.deepspeed_config is not None + hasattr(args, "deepspeed_config") and args.deepspeed_config is not None ), "DeepSpeed requires --deepspeed_config to specify configuration file" assert os.path.isfile( @@ -931,10 +1015,10 @@ def _do_args_sanity_check(self, args): ) def _is_supported_optimizer(self, optimizer_name): - return (optimizer_name in DEEPSPEED_OPTIMIZERS - or getattr(torch.optim, - optimizer_name, - None) is not None) + return ( + optimizer_name in DEEPSPEED_OPTIMIZERS + or getattr(torch.optim, optimizer_name, None) is not None + ) def _supported_optims(self): FairseqOptimizer = None @@ -953,8 +1037,9 @@ def _supported_optims(self): def _do_sanity_check(self): expected_optim_types = self._supported_optims() expected_optim_types += [type(None), Callable] - assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \ - f'Client Optimizer is of unexpected type {type(self.client_optimizer)}' + assert isinstance( + self.client_optimizer, tuple(expected_optim_types) + ), f"Client Optimizer is of unexpected type {type(self.client_optimizer)}" if not self.client_optimizer: if self.optimizer_name() is not None: @@ -964,8 +1049,10 @@ def _do_sanity_check(self): self.optimizer_name() ) - if (self.optimizer_name() == LAMB_OPTIMIZER - or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER): + if ( + self.optimizer_name() == LAMB_OPTIMIZER + or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER + ): assert ( self.dynamic_loss_scale() ), "DeepSpeed {} optimizer requires dynamic loss scaling".format( @@ -974,8 +1061,9 @@ def _do_sanity_check(self): # Detect invalid combinations of client optimizer and client scheduler if isinstance(self.client_lr_scheduler, _LRScheduler): - assert isinstance(self.client_optimizer, Optimizer), \ - f'Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated' + assert isinstance( + self.client_optimizer, Optimizer + ), f"Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated" def _broadcast_model(self): def is_replicated(p): @@ -987,20 +1075,26 @@ def is_replicated(p): # Broadcast the model for different parameters if is_moe_param(p): if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p, - groups._get_expert_broadcast_src_rank(p.group_name), - group=self.expert_data_parallel_group[p.group_name]) + dist.broadcast( + p, + groups._get_expert_broadcast_src_rank(p.group_name), + group=self.expert_data_parallel_group[p.group_name], + ) else: if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p, - groups._get_broadcast_src_rank(), - group=self.data_parallel_group) + dist.broadcast( + p, + groups._get_broadcast_src_rank(), + group=self.data_parallel_group, + ) @staticmethod def __check_params(model: Module, dtype: torch.dtype) -> None: return - if not all(param.dtype == dtype - for param in model.parameters()) and dist.get_rank() == 0: + if ( + not 
all(param.dtype == dtype for param in model.parameters()) + and dist.get_rank() == 0 + ): raise ValueError( f"{dtype} is enabled but the following parameters have dtype that is " f"not {dtype}: " @@ -1009,23 +1103,25 @@ def __check_params(model: Module, dtype: torch.dtype) -> None: def _set_client_model(self, model): # register client model in _modules so that nn.module methods work correctly - modules = self.__dict__.get('_modules') - modules['module'] = model + modules = self.__dict__.get("_modules") + modules["module"] = model # register module attribute in engine but avoid getattr - self.__dict__['module'] = model + self.__dict__["module"] = model def _configure_distributed_model(self, model): self._set_client_model(model) if self.fp16_enabled(): if self.zero_optimization_partition_weights() and any( - [hasattr(param, - "ds_id") for param in self.module.parameters()]): + [hasattr(param, "ds_id") for param in self.module.parameters()] + ): if not all( - [param.dtype == torch.half for param in self.module.parameters()]): + [param.dtype == torch.half for param in self.module.parameters()] + ): names = [ - n for n, - p in self.module.named_parameters() if p.dtype != torch.half + n + for n, p in self.module.named_parameters() + if p.dtype != torch.half ] raise ValueError( f"fp16 is enabled but the following parameters have dtype that is not fp16: {', '.join(names)}" @@ -1033,8 +1129,8 @@ def _configure_distributed_model(self, model): self.module.half() elif self.bfloat16_enabled(): if self.zero_optimization_partition_weights() and any( - hasattr(param, - 'ds_id') for param in self.module.parameters()): + hasattr(param, "ds_id") for param in self.module.parameters() + ): self.__check_params(self.module, torch.bfloat16) if self.zero_optimization_stage() == 0 and not self.pipeline_parallelism: raise NotImplementedError( @@ -1070,7 +1166,7 @@ def _configure_distributed_model(self, model): # Set deepspeed parallelism spec. for the model including expert parallelism for _, module in self.module.named_modules(): - if hasattr(module, 'set_deepspeed_parallelism'): + if hasattr(module, "set_deepspeed_parallelism"): module.set_deepspeed_parallelism() # Query the groups module to get information about various parallel groups @@ -1091,12 +1187,17 @@ def _check_for_duplicates(self, optimizer): def ids_list(group): return [id(param) for param in group] - occurrence = sum([ - ids_list(group['params']).count(param_id) - if param_id in ids_list(group['params']) else 0 - for group in optimizer.param_groups - ]) - assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour." + occurrence = sum( + [ + ids_list(group["params"]).count(param_id) + if param_id in ids_list(group["params"]) + else 0 + for group in optimizer.param_groups + ] + ) + assert ( + occurrence <= 1 + ), f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour." 
# Configure optimizer def _configure_optimizer(self, client_optimizer, model_parameters): @@ -1107,24 +1208,28 @@ def _configure_optimizer(self, client_optimizer, model_parameters): ] log_dist( "Removing param_group that has no 'params' in the client Optimizer", - ranks=[0]) + ranks=[0], + ) basic_optimizer = client_optimizer - log_dist('Using client Optimizer as basic optimizer', ranks=[0]) + log_dist("Using client Optimizer as basic optimizer", ranks=[0]) else: basic_optimizer = client_optimizer(model_parameters) - log_dist('Using client callable to create basic optimizer', ranks=[0]) + log_dist("Using client callable to create basic optimizer", ranks=[0]) else: basic_optimizer = self._configure_basic_optimizer(model_parameters) log_dist( f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", - ranks=[0]) + ranks=[0], + ) self._check_for_duplicates(basic_optimizer) self.basic_optimizer = basic_optimizer - log_dist("DeepSpeed Basic Optimizer = {basic_optimizer.__class__.__name__}", - ranks=[0]) + log_dist( + "DeepSpeed Basic Optimizer = {basic_optimizer.__class__.__name__}", + ranks=[0], + ) if self.zero_optimization(): assert ( @@ -1143,7 +1248,9 @@ def _configure_optimizer(self, client_optimizer, model_parameters): # while ZeRO optimizer itself wraps the original optimizer. self.optimizer = self._configure_zero_optimizer(basic_optimizer) elif self.amp_enabled(): - assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode" + assert not ( + self.fp16_enabled() or self.bfloat16_enabled() + ), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode" amp_params = self.amp_params() log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0]) try: @@ -1151,7 +1258,8 @@ def _configure_optimizer(self, client_optimizer, model_parameters): except NameError: # If apex/amp is available it will be imported above raise RuntimeError( - "Unable to import apex/amp, please make sure it is installed") + "Unable to import apex/amp, please make sure it is installed" + ) model, self.optimizer = amp.initialize( self.module, basic_optimizer, **amp_params ) @@ -1164,8 +1272,9 @@ def _configure_optimizer(self, client_optimizer, model_parameters): self.optimizer = self._configure_bf16_optimizer(basic_optimizer) else: self.optimizer = basic_optimizer - log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), - ranks=[0]) + log_dist( + "DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0] + ) self.compression_scheduler = self._configure_compression_scheduler() self.quantizer = self._configure_quantization() @@ -1180,32 +1289,44 @@ def _configure_basic_optimizer(self, model_parameters): "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" ) - if self.optimizer_name() in [ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER]: + if self.optimizer_name() in [ + ADAGRAD_OPTIMIZER, + ADAM_OPTIMIZER, + ADAMW_OPTIMIZER, + ]: torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False) adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT) # Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set - effective_adam_w_mode = self.optimizer_name( - ) == ADAMW_OPTIMIZER or adam_w_mode + effective_adam_w_mode = ( + self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode + ) if torch_adam: if 
not effective_adam_w_mode: - optimizer = torch.optim.Adam(model_parameters, - **optimizer_parameters) + optimizer = torch.optim.Adam( + model_parameters, **optimizer_parameters + ) else: - optimizer = torch.optim.AdamW(model_parameters, - **optimizer_parameters) + optimizer = torch.optim.AdamW( + model_parameters, **optimizer_parameters + ) else: if self.zero_use_cpu_optimizer(): if self.optimizer_name() == ADAGRAD_OPTIMIZER: from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad - optimizer = DeepSpeedCPUAdagrad(model_parameters, - **optimizer_parameters) + + optimizer = DeepSpeedCPUAdagrad( + model_parameters, **optimizer_parameters + ) else: from deepspeed.ops.adam import DeepSpeedCPUAdam - optimizer = DeepSpeedCPUAdam(model_parameters, - **optimizer_parameters, - adamw_mode=effective_adam_w_mode) + + optimizer = DeepSpeedCPUAdam( + model_parameters, + **optimizer_parameters, + adamw_mode=effective_adam_w_mode, + ) else: from deepspeed.ops.adam import FusedAdam @@ -1235,7 +1356,8 @@ def _configure_basic_optimizer(self, model_parameters): optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): logger.warning( - f'Currently the convergence of 0/1 Adam is only verified under FP16') + f"Currently the convergence of 0/1 Adam is only verified under FP16" + ) elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER: assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO" from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb @@ -1266,7 +1388,9 @@ def _configure_quantization(self): use_quantizer_kernel, ) = self.quantize_training() if quantize_enabled and not quantize_weight_in_forward: - assert self.fp16_enabled(), "MoQ (quantize in optimization step) weight quantization is only supported for FP16" + assert ( + self.fp16_enabled() + ), "MoQ (quantize in optimization step) weight quantization is only supported for FP16" quantizer = None if quantize_enabled and not quantize_weight_in_forward: from deepspeed.runtime.quantize import Quantizer @@ -1292,8 +1416,10 @@ def _configure_fp16_optimizer(self, optimizer): fused_opts = (apex.optimizers.FusedAdam, FusedAdam) else: fused_opts = FusedAdam - if isinstance(optimizer, fused_opts) \ - or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]: + if isinstance(optimizer, fused_opts) or self.optimizer_name() in [ + ONEBIT_ADAM_OPTIMIZER, + ZERO_ONE_ADAM_OPTIMIZER, + ]: if self.dynamic_loss_scale(): log_dist("Creating fp16 optimizer with dynamic loss scale", ranks=[0]) timers = self.timers if self.wall_clock_breakdown() else None @@ -1312,7 +1438,8 @@ def _configure_fp16_optimizer(self, optimizer): else: log_dist( "Creating fp16 optimizer with static loss scale: {}".format( - self.loss_scale()), + self.loss_scale() + ), ranks=[0], ) optimizer = FP16_Optimizer( @@ -1325,8 +1452,9 @@ def _configure_fp16_optimizer(self, optimizer): has_moe_layers=self.has_moe_layers, ) else: - log_dist("Creating fp16 unfused optimizer with dynamic loss scale", - ranks=[0]) + log_dist( + "Creating fp16 unfused optimizer with dynamic loss scale", ranks=[0] + ) optimizer = FP16_UnfusedOptimizer( optimizer, deepspeed=self, @@ -1346,7 +1474,7 @@ def _configure_bf16_optimizer(self, optimizer): if optimizer is None: optimizer = DummyOptim(list(self.module.parameters())) - log_dist('Creating BF16 optimizer', ranks=[0]) + log_dist("Creating BF16 optimizer", ranks=[0]) timers = self.timers if self.wall_clock_breakdown() else None optimizer = BF16_Optimizer( @@ -1356,13 +1484,14 @@ def 
_configure_bf16_optimizer(self, optimizer): clip_grad=clip_grad, allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.data_parallel_group, - timers=timers) + timers=timers, + ) return optimizer def _configure_zero_optimizer(self, optimizer): zero_stage = self.zero_optimization_stage() - assert self.communication_data_type in (torch.float16, torch.bfloat16), "ZeRO supports only 'communication_data_type': ['fp16', 'bfp16']" + # assert self.communication_data_type in (torch.float16, torch.bfloat16), "ZeRO supports only 'communication_data_type': ['fp16', 'bfp16']" timers = self.timers if self.wall_clock_breakdown() else None if optimizer is None: @@ -1377,10 +1506,13 @@ def _configure_zero_optimizer(self, optimizer): overlap_comm = self.zero_overlap_comm() contiguous_gradients = self.zero_contiguous_gradients() round_robin_gradients = self.zero_round_robin_gradients() - assert not isinstance(optimizer, DummyOptim), "zero stage 2 requires an optimizer" + assert not isinstance( + optimizer, DummyOptim + ), "zero stage 2 requires an optimizer" - log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), - ranks=[0]) + log_dist( + "Creating fp16 ZeRO stage {} optimizer".format(zero_stage), ranks=[0] + ) # Overlap and contiguous grads are meaningless in stage 1 and are ignored if zero_stage == ZeroStageEnum.optimizer_states: overlap_comm = False @@ -1404,9 +1536,11 @@ def _configure_zero_optimizer(self, optimizer): allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.data_parallel_group, expert_parallel_group=self.expert_parallel_group - if self.has_moe_layers else None, + if self.has_moe_layers + else None, expert_data_parallel_group=self.expert_data_parallel_group - if self.has_moe_layers else None, + if self.has_moe_layers + else None, reduce_scatter=self.zero_reduce_scatter(), overlap_comm=overlap_comm, cpu_offload=self.zero_cpu_offload(), @@ -1418,10 +1552,10 @@ def _configure_zero_optimizer(self, optimizer): partition_grads=zero_stage == ZeroStageEnum.gradients, round_robin_gradients=round_robin_gradients, has_moe_layers=self.has_moe_layers, - fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients( - ), + fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(), communication_data_type=self.communication_data_type, - elastic_checkpoint=self.zero_elastic_checkpoint()) + elastic_checkpoint=self.zero_elastic_checkpoint(), + ) elif zero_stage == ZeroStageEnum.weights: assert not self.has_moe_layers, "MoE not supported with Stage 3" @@ -1438,11 +1572,15 @@ def _configure_zero_optimizer(self, optimizer): param_persistence_threshold=self.zero_param_persistence_threshold(), model_persistence_threshold=self.zero_model_persistence_threshold(), offload_param_config=self.zero_offload_param(), - mpu=self.mpu) + mpu=self.mpu, + ) else: - log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), - ranks=[0]) + log_dist( + "Creating fp16 ZeRO stage {} optimizer".format(zero_stage), + ranks=[0], + ) from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 + optimizer = DeepSpeedZeroOptimizer_Stage3( self.module, optimizer, @@ -1470,10 +1608,13 @@ def _configure_zero_optimizer(self, optimizer): gradient_predivide_factor=self.gradient_predivide_factor(), gradient_accumulation_steps=self.gradient_accumulation_steps(), aio_config=self.aio_config(), - communication_data_type=self.communication_data_type) + communication_data_type=self.communication_data_type, + ) else: - raise 
NotImplementedError("ZeRO stage {} not implemented".format(zero_stage)) + raise NotImplementedError( + "ZeRO stage {} not implemented".format(zero_stage) + ) return optimizer @@ -1505,9 +1646,9 @@ def is_map_style_dataset(obj): @staticmethod def is_iterable_style_dataset(obj): - return isinstance(obj, - torch.utils.data.IterableDataset - ) # hasattr(obj, "__iter__") should work as well + return isinstance( + obj, torch.utils.data.IterableDataset + ) # hasattr(obj, "__iter__") should work as well def dataloader_drop_last(self): return self._config.dataloader_drop_last @@ -1522,16 +1663,20 @@ def was_step_applied(self) -> bool: """ return self._step_applied - def deepspeed_io(self, - dataset, - batch_size=None, - route=ROUTE_TRAIN, - pin_memory=True, - data_sampler=None, - collate_fn=None, - num_local_io_workers=None): - if not (self.is_map_style_dataset(dataset) - or self.is_iterable_style_dataset(dataset)): + def deepspeed_io( + self, + dataset, + batch_size=None, + route=ROUTE_TRAIN, + pin_memory=True, + data_sampler=None, + collate_fn=None, + num_local_io_workers=None, + ): + if not ( + self.is_map_style_dataset(dataset) + or self.is_iterable_style_dataset(dataset) + ): raise ValueError("Training data must be a torch Dataset") if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL): @@ -1555,17 +1700,19 @@ def deepspeed_io(self, data_parallel_world_size = self.mpu.get_data_parallel_world_size() data_parallel_rank = self.mpu.get_data_parallel_rank() - return DeepSpeedDataLoader(dataset=dataset, - batch_size=batch_size, - pin_memory=pin_memory, - collate_fn=collate_fn, - local_rank=self.local_rank, - tput_timer=deepspeed_io_timer, - num_local_io_workers=num_local_io_workers, - data_sampler=data_sampler, - data_parallel_world_size=data_parallel_world_size, - data_parallel_rank=data_parallel_rank, - dataloader_drop_last=self.dataloader_drop_last()) + return DeepSpeedDataLoader( + dataset=dataset, + batch_size=batch_size, + pin_memory=pin_memory, + collate_fn=collate_fn, + local_rank=self.local_rank, + tput_timer=deepspeed_io_timer, + num_local_io_workers=num_local_io_workers, + data_sampler=data_sampler, + data_parallel_world_size=data_parallel_world_size, + data_parallel_rank=data_parallel_rank, + dataloader_drop_last=self.dataloader_drop_last(), + ) def train(self, mode=True): r"""""" @@ -1612,16 +1759,21 @@ def forward(self, *inputs, **kwargs): else: see_memory_usage("Engine before forward", force=self.memory_breakdown()) - flops_profiler_active = (self.flops_profiler_enabled() and self.global_steps - == self.flops_profiler_profile_step() - and self.global_rank == 0) + flops_profiler_active = ( + self.flops_profiler_enabled() + and self.global_steps == self.flops_profiler_profile_step() + and self.global_rank == 0 + ) # used to check quantization happens at step 0! 
if self.global_steps == 0 and hasattr(self, "compression_scheduler"): self.compression_scheduler.step(step_zero_check=True) if self.quantizer: - tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage( - ) == 2 else self.optimizer.fp16_groups + tensor_to_quantize = ( + self.optimizer.bit16_groups + if self.zero_optimization_stage() == 2 + else self.optimizer.fp16_groups + ) if self.compression_scheduler.weight_quantization_enabled: self.quantizer.quantize( tensor_to_quantize, @@ -1643,10 +1795,11 @@ def forward(self, *inputs, **kwargs): if self.module.training and self.curriculum_enabled(): self.curriculum_scheduler.update_difficulty(self.global_steps + 1) if self.curriculum_params()["curriculum_type"] == "seqlen": - kwargs.update({ - "curriculum_seqlen": - self.curriculum_scheduler.get_current_difficulty() - }) + kwargs.update( + { + "curriculum_seqlen": self.curriculum_scheduler.get_current_difficulty() + } + ) if self.zero_optimization_partition_weights(): # Enable automated discovery of external parameters by indicating that @@ -1678,9 +1831,9 @@ def forward(self, *inputs, **kwargs): if self.autotuning_profile_model_info(): activation_mem = get_ma_status() - ma self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem - print_json_dist(self.autotuning_model_info, - [0], - path=self.autotuning_model_info_path()) + print_json_dist( + self.autotuning_model_info, [0], path=self.autotuning_model_info_path() + ) exit() else: see_memory_usage("Engine after forward", force=self.memory_breakdown()) @@ -1697,7 +1850,7 @@ def _cast_inputs_half(self, inputs): for k, v in inputs: new_inputs[k] = self._cast_inputs_half(v) return new_inputs - elif hasattr(inputs, 'half'): + elif hasattr(inputs, "half"): return inputs.half() else: return inputs @@ -1709,11 +1862,11 @@ def print_forward_breakdown(self, fwd_time): salltoall = 0.0 for gate in self.gate_modules: - #logger.info(f"Individual TopK gate time: {gate.gate_time:.2f} ms") + # logger.info(f"Individual TopK gate time: {gate.gate_time:.2f} ms") gate_time += gate.gate_time for l in self.moe_layers: - #logger.info(f"MoE layer; total: {l.time_moe:.2f} ms, first alltoall: {l.time_falltoall:.2f}, second alltoall: {l.time_salltoall:.2f}") + # logger.info(f"MoE layer; total: {l.time_moe:.2f} ms, first alltoall: {l.time_falltoall:.2f}, second alltoall: {l.time_salltoall:.2f}") moe_time += l.time_moe falltoall += l.time_falltoall salltoall += l.time_salltoall @@ -1723,15 +1876,15 @@ def print_forward_breakdown(self, fwd_time): # if deepspeed.comm.get_rank() == 0: log_dist( f"rank={dist.get_rank()} time (ms) | forward: {fwd_time:.2f} (forward_moe: {moe_time:.2f}, 1st alltoall: {falltoall:.2f}, 2nd alltoall: {salltoall:.2f}, top-k: {gate_time:.2f})", - ranks=[0]) + ranks=[0], + ) @instrument_w_nvtx def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): - assert not (self.bfloat16_enabled() and self.pipeline_parallelism), \ - f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled' # Pass (PP) gas boundary flag to optimizer (required for zero) - self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary( + self.optimizer.is_gradient_accumulation_boundary = ( + self.is_gradient_accumulation_boundary() ) # ZeRO stage 2 communicates during non gradient accumulation boundaries as well @@ -1742,17 +1895,20 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): elif self.is_gradient_accumulation_boundary(): if self.zero_optimization_stage() == 
ZeroStageEnum.optimizer_states: self.optimizer.reduce_gradients( - pipeline_parallel=self.pipeline_parallelism) + pipeline_parallel=self.pipeline_parallelism + ) else: self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) @instrument_w_nvtx - def backward(self, - loss, - allreduce_gradients=True, - release_loss=False, - retain_graph=False, - scale_wrt_gas=True): + def backward( + self, + loss, + allreduce_gradients=True, + release_loss=False, + retain_graph=False, + scale_wrt_gas=True, + ): r"""Execute backward pass on the loss Arguments: loss: Torch tensor on which to execute backward propagation @@ -1779,31 +1935,35 @@ def backward(self, if self.monitor.enabled: if self.is_gradient_accumulation_boundary(): if self.global_rank == 0: - self.summary_events = [( - f"Train/Samples/train_loss", - loss.mean().item() * self.gradient_accumulation_steps(), - self.global_samples, - )] + self.summary_events = [ + ( + f"Train/Samples/train_loss", + loss.mean().item() * self.gradient_accumulation_steps(), + self.global_samples, + ) + ] self.monitor.write_events(self.summary_events) self._start_timers(self.engine_timers.backward_timers) - assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ - "must provide optimizer during init in order to use backward" + assert self.optimizer is not None and not isinstance( + self.optimizer, DummyOptim + ), "must provide optimizer during init in order to use backward" self._start_timers(self.engine_timers.backward_inner_timers) if self.zero_optimization(): - self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary( + self.optimizer.is_gradient_accumulation_boundary = ( + self.is_gradient_accumulation_boundary() ) self.optimizer.backward(loss, retain_graph=retain_graph) elif self.amp_enabled(): # AMP requires delaying unscale when inside gradient accumulation boundaries # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations delay_unscale = not self.is_gradient_accumulation_boundary() - with amp.scale_loss(loss, - self.optimizer, - delay_unscale=delay_unscale) as scaled_loss: + with amp.scale_loss( + loss, self.optimizer, delay_unscale=delay_unscale + ) as scaled_loss: scaled_loss.backward(retain_graph=retain_graph) elif self.fp16_enabled(): if self.eigenvalue_enabled(): @@ -1846,8 +2006,7 @@ def is_gradient_accumulation_boundary(self): bool: if the current step is a gradient accumulation boundary. 
""" if self._is_gradient_accumulation_boundary is None: - return (self.micro_steps + 1) % \ - self.gradient_accumulation_steps() == 0 + return (self.micro_steps + 1) % self.gradient_accumulation_steps() == 0 else: return self._is_gradient_accumulation_boundary @@ -1882,31 +2041,42 @@ def zero_grad(self): param.grad = None def clip_fp32_gradients(self): - clip_grad_norm_(parameters=self.module.parameters(), - max_norm=self.gradient_clipping(), - mpu=self.mpu) + clip_grad_norm_( + parameters=self.module.parameters(), + max_norm=self.gradient_clipping(), + mpu=self.mpu, + ) def _take_model_step(self, lr_kwargs, block_eigenvalue={}): if self.gradient_clipping() > 0.0: - if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() - or self.zero_optimization()): + if not ( + self.fp16_enabled() + or self.bfloat16_enabled() + or self.amp_enabled() + or self.zero_optimization() + ): self.clip_fp32_gradients() elif self.amp_enabled(): # AMP's recommended way of doing clipping # https://nvidia.github.io/apex/advanced.html#gradient-clipping master_params = amp.master_params(self.optimizer) - clip_grad_norm_(parameters=master_params, - max_norm=self.gradient_clipping(), - mpu=self.mpu) + clip_grad_norm_( + parameters=master_params, + max_norm=self.gradient_clipping(), + mpu=self.mpu, + ) self.optimizer.step() - if hasattr(self.optimizer, '_global_grad_norm'): + if hasattr(self.optimizer, "_global_grad_norm"): self._global_grad_norm = self.optimizer._global_grad_norm # Quantize the updated parameter if there is no overflow if self.quantizer: - tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage( - ) == 2 else self.optimizer.fp16_groups + tensor_to_quantize = ( + self.optimizer.bit16_groups + if self.zero_optimization_stage() == 2 + else self.optimizer.fp16_groups + ) if self.compression_scheduler.weight_quantization_enabled: self.quantizer.quantize( tensor_to_quantize, @@ -1962,14 +2132,17 @@ def step(self, lr_kwargs=None): # Check early because self.global_steps is incremented at some point here. # TODO: Delay self.global_steps increment until very end of this function. 
- flops_profiler_active = self.flops_profiler_enabled( - ) and self.global_steps == self.flops_profiler_profile_step( - ) and self.global_rank == 0 + flops_profiler_active = ( + self.flops_profiler_enabled() + and self.global_steps == self.flops_profiler_profile_step() + and self.global_rank == 0 + ) self._start_timers(self.engine_timers.step_timers) - assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ - "must provide optimizer during init in order to use step" + assert self.optimizer is not None and not isinstance( + self.optimizer, DummyOptim + ), "must provide optimizer during init in order to use step" report_progress = self.global_rank == 0 if self.global_rank else True @@ -1979,21 +2152,28 @@ def step(self, lr_kwargs=None): if self.is_gradient_accumulation_boundary(): self.gas_boundary_ctr += 1 - if (self.eigenvalue_enabled() and - (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0) - and self.quantizer.any_precision_switch()): + if ( + self.eigenvalue_enabled() + and ( + self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() + == 0 + ) + and self.quantizer.any_precision_switch() + ): log_dist(f"computing eigenvalue...", ranks=[0]) self.block_eigenvalue = self.eigenvalue.compute_eigenvalue( - self.module, - self.device, - self.optimizer.cur_scale) + self.module, self.device, self.optimizer.cur_scale + ) if self.progressive_layer_drop: self.progressive_layer_drop.update_state(self.global_steps) - if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % - self.eigenvalue_gas_boundary_resolution() - and self.quantizer.any_precision_switch()): + if ( + self.eigenvalue_enabled() + and not self.gas_boundary_ctr + % self.eigenvalue_gas_boundary_resolution() + and self.quantizer.any_precision_switch() + ): self._take_model_step(lr_kwargs, self.block_eigenvalue) else: self._take_model_step(lr_kwargs) @@ -2006,26 +2186,33 @@ def step(self, lr_kwargs=None): if self.monitor.enabled: if self.is_gradient_accumulation_boundary(): if self.global_rank == 0: - self.summary_events = [(f"Train/Samples/lr", - self.get_lr()[0], - self.global_samples)] + self.summary_events = [ + (f"Train/Samples/lr", self.get_lr()[0], self.global_samples) + ] if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"): - self.summary_events.append(( - f"Train/Samples/loss_scale", - self.optimizer.cur_scale, - self.global_samples, - )) - - if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % - self.eigenvalue_gas_boundary_resolution()): + self.summary_events.append( + ( + f"Train/Samples/loss_scale", + self.optimizer.cur_scale, + self.global_samples, + ) + ) + + if ( + self.eigenvalue_enabled() + and not self.gas_boundary_ctr + % self.eigenvalue_gas_boundary_resolution() + ): ev_values = self.block_eigenvalue.values() for i in range(len(ev_values)): - self.summary_events.append(( - f"Train/Eigenvalues/ModelBlockParam_{i}", - self.ev_values[i][0], - self.global_samples, - )) + self.summary_events.append( + ( + f"Train/Eigenvalues/ModelBlockParam_{i}", + self.ev_values[i][0], + self.global_samples, + ) + ) self.monitor.write_events(self.summary_events) # Check flops profiling @@ -2043,13 +2230,16 @@ def step(self, lr_kwargs=None): self.flops_profiler.end_profile() if self.autotuning_enabled() and self.global_steps == ( - self.autotuning_end_profile_step() + 1): + self.autotuning_end_profile_step() + 1 + ): self._autotuning_exit() if self.wall_clock_breakdown(): # Log micro timing and reset - 
self.timers.log(names=self.engine_timers.micro_timers, - memory_breakdown=self.memory_breakdown()) + self.timers.log( + names=self.engine_timers.micro_timers, + memory_breakdown=self.memory_breakdown(), + ) if self.wall_clock_breakdown() or self.flops_profiler_enabled(): # Log global timing and reset @@ -2071,29 +2261,37 @@ def _start_timers(self, timer_names): self.timers(name).start() def _stop_timers(self, timer_names): - record = self.is_gradient_accumulation_boundary() and \ - self.flops_profiler_enabled() and \ - (self.global_steps >= self.flops_profiler_profile_step()) + record = ( + self.is_gradient_accumulation_boundary() + and self.flops_profiler_enabled() + and (self.global_steps >= self.flops_profiler_profile_step()) + ) for name in timer_names: self.timers(name).stop(record=record) def _autotuning_exit(self): if self.global_rank == 0: - msg = self.timers.get_mean([ - FORWARD_GLOBAL_TIMER, - BACKWARD_GLOBAL_TIMER, - STEP_GLOBAL_TIMER, - ], - reset=False) - titer = msg[FORWARD_GLOBAL_TIMER] + msg[BACKWARD_GLOBAL_TIMER] + msg[ - STEP_GLOBAL_TIMER] + msg = self.timers.get_mean( + [ + FORWARD_GLOBAL_TIMER, + BACKWARD_GLOBAL_TIMER, + STEP_GLOBAL_TIMER, + ], + reset=False, + ) + titer = ( + msg[FORWARD_GLOBAL_TIMER] + + msg[BACKWARD_GLOBAL_TIMER] + + msg[STEP_GLOBAL_TIMER] + ) msg["latency"] = titer - msg["FLOPS_per_gpu"] = self.flops * self.gradient_accumulation_steps( - ) / titer - msg["throughput"] = self.train_batch_size() * 1000 / \ - msg["latency"] + msg["FLOPS_per_gpu"] = ( + self.flops * self.gradient_accumulation_steps() / titer + ) + msg["throughput"] = self.train_batch_size() * 1000 / msg["latency"] print_json_dist(msg, [0], path=self.autotuning_metric_path()) import atexit + atexit.register(print, "Autotuning: done with running current ds config.") exit() @@ -2160,8 +2358,9 @@ def get_pld_theta(self): def _report_progress(self, step): lr = self.get_lr() mom = self.get_mom() - log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", - ranks=[0]) + log_dist( + f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0] + ) def allreduce_bucket(self, bucket, dp_group): tensor = self.flatten(bucket) @@ -2178,14 +2377,20 @@ def allreduce_bucket(self, bucket, dp_group): dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.gradient_average: if self.gradient_predivide_factor() != dist.get_world_size( - group=dp_group): - tensor_to_allreduce.mul_(self.gradient_predivide_factor() / - dist.get_world_size(group=dp_group)) + group=dp_group + ): + tensor_to_allreduce.mul_( + self.gradient_predivide_factor() + / dist.get_world_size(group=dp_group) + ) else: - tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group)) + tensor_to_allreduce.mul_(1.0 / dist.get_world_size(group=dp_group)) dist.all_reduce(tensor_to_allreduce, group=dp_group) - if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: + if ( + self.communication_data_type != tensor.dtype + and tensor is not tensor_to_allreduce + ): tensor.copy_(tensor_to_allreduce) return tensor @@ -2222,9 +2427,9 @@ def _get_gradients_for_reduction(self): # rank is reducing the same size. In some cases it may make # sense in the future to support the ability to average not # w.r.t. world size but with a different value. 
- param.grad = torch.zeros(param.size(), - dtype=param.dtype, - device=param.device) + param.grad = torch.zeros( + param.size(), dtype=param.dtype, device=param.device + ) grad_data = param.grad.data if param_name in self.sparse_tensor_module_names or grad_data.is_sparse: @@ -2251,9 +2456,9 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer): if bucket_type == SparseTensor.type(): self.sparse_allreduce_no_retain(bucket, dp_group=dp_group) else: - self.allreduce_no_retain(bucket, - dp_group=dp_group, - numel_per_bucket=elements_per_buffer) + self.allreduce_no_retain( + bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer + ) def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): for ep_name, expert_grads_group in expert_grads.items(): @@ -2262,20 +2467,23 @@ def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): bucket_type, bucket = bucket_tuple if bucket_type == SparseTensor.type(): self.sparse_allreduce_no_retain( - bucket, - groups._get_expert_data_parallel_group(ep_name)) + bucket, groups._get_expert_data_parallel_group(ep_name) + ) else: # Separate between diff groups self.allreduce_no_retain( bucket, dp_group=groups._get_expert_data_parallel_group(ep_name), - numel_per_bucket=elements_per_buffer) + numel_per_bucket=elements_per_buffer, + ) def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000): if grads is None: non_expert_grads, expert_grads = self._get_gradients_for_reduction() else: - assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE" + assert ( + not self.has_moe_layers + ), "attempting to reduce grads in unsupported way w.r.t. MoE" non_expert_grads = grads self._reduce_non_expert_gradients(non_expert_grads, elements_per_buffer) @@ -2312,10 +2520,12 @@ def sparse_allreduce(self, sparse, dp_group): if self.postscale_gradients(): if self.gradient_average: - values.mul_(self.gradient_predivide_factor() / - dist.get_world_size(group=dp_group)) + values.mul_( + self.gradient_predivide_factor() + / dist.get_world_size(group=dp_group) + ) else: - values.mul_(1. 
/ dist.get_world_size(group=dp_group)) + values.mul_(1.0 / dist.get_world_size(group=dp_group)) indices_device_list = self.sparse_all_gather(indices, dp_group) values_device_list = self.sparse_all_gather(values, dp_group) @@ -2342,8 +2552,7 @@ def sparse_all_gather(self, value, dp_group): if fill_size > 0: value = torch.cat([value, value.new_empty(fill_size, value.size()[1])]) tensor_list = [ - value.new_empty(max_size, - value.size()[1]) + value.new_empty(max_size, value.size()[1]) for _ in range(dist.get_world_size(group=dp_group)) ] @@ -2352,10 +2561,10 @@ def sparse_all_gather(self, value, dp_group): for dev_idx, t in enumerate(tensor_list): size = all_sizes[dev_idx][0] tensors.append( - t.index_select(0, - torch.arange(size, - dtype=torch.long, - device=self.device))) + t.index_select( + 0, torch.arange(size, dtype=torch.long, device=self.device) + ) + ) return tensors @@ -2372,36 +2581,46 @@ def module_state_dict(self, destination=None, prefix="", keep_vars=False): return sd @staticmethod - def load_moe_state_dict(checkpoint_path, - tag, - state_dict, - old_moe_load, - model=None, - mpu=None, - num_experts=1, - checkpoint_engine=TorchCheckpointEngine()): + def load_moe_state_dict( + checkpoint_path, + tag, + state_dict, + old_moe_load, + model=None, + mpu=None, + num_experts=1, + checkpoint_engine=TorchCheckpointEngine(), + ): if old_moe_load: expp_rank = groups._get_expert_data_parallel_rank( - groups._get_max_expert_size_name()) + groups._get_max_expert_size_name() + ) num_local_experts = max( - num_experts) // groups._get_expert_parallel_world_size( - groups._get_max_expert_size_name()) + num_experts + ) // groups._get_expert_parallel_world_size( + groups._get_max_expert_size_name() + ) for local_expert_id in range(num_local_experts): global_expert_id = expp_rank * num_local_experts + local_expert_id - expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name( - checkpoint_path, - -1, # -1 means ignore layer_id - global_expert_id, - tag, - mpu), - map_location=torch.device('cpu')) + expert_state_dict = checkpoint_engine.load( + DeepSpeedEngine._get_expert_ckpt_name( + checkpoint_path, + -1, # -1 means ignore layer_id + global_expert_id, + tag, + mpu, + ), + map_location=torch.device("cpu"), + ) # Updating global -> local expert ids - moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' + moe_str_prefix = ".deepspeed_moe.experts.deepspeed_experts." for key in list(expert_state_dict.keys()): - local_key = key.replace(f'{moe_str_prefix}{global_expert_id}', - f'{moe_str_prefix}{local_expert_id}') + local_key = key.replace( + f"{moe_str_prefix}{global_expert_id}", + f"{moe_str_prefix}{local_expert_id}", + ) expert_state_dict[local_key] = expert_state_dict.pop(key) state_dict.update(expert_state_dict) @@ -2414,22 +2633,27 @@ def load_moe_state_dict(checkpoint_path, expp_rank = groups._get_expert_parallel_rank(group_name) # loop all local_experts for local_expert_id in range(num_local_experts): - global_expert_id = expp_rank * num_local_experts + local_expert_id + global_expert_id = ( + expp_rank * num_local_experts + local_expert_id + ) expert_state_dict = checkpoint_engine.load( DeepSpeedEngine._get_expert_ckpt_name( checkpoint_path, moe_layer_id, global_expert_id, tag, - mpu), - map_location=torch.device('cpu')) + mpu, + ), + map_location=torch.device("cpu"), + ) # print(expert_state_dict.keys()) # Updating global -> local expert ids - moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' 
+ moe_str_prefix = ".deepspeed_moe.experts.deepspeed_experts." for key in list(expert_state_dict.keys()): local_key = key.replace( - f'{moe_str_prefix}{global_expert_id}', - f'{moe_str_prefix}{local_expert_id}') + f"{moe_str_prefix}{global_expert_id}", + f"{moe_str_prefix}{local_expert_id}", + ) expert_state_dict[local_key] = expert_state_dict.pop(key) state_dict.update(expert_state_dict) moe_layer_id += 1 @@ -2438,18 +2662,14 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): if custom_load_fn: custom_load_fn(src=state_dict, dst=self.module) else: - self.module.load_state_dict(state_dict, # TODO - strict=strict) + self.module.load_state_dict(state_dict, strict=strict) # TODO def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode): return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}' - def _get_rank_zero_ckpt_name(self, - checkpoints_path, - tag, - mp_rank, - dp_rank, - bf16_mode): + def _get_rank_zero_ckpt_name( + self, checkpoints_path, tag, mp_rank, dp_rank, bf16_mode + ): file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode) zero_ckpt_name = os.path.join( checkpoints_path, @@ -2462,11 +2682,9 @@ def _get_zero_ckpt_name(self, checkpoints_path, tag): mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() pp_rank = dist.get_rank(group=self.optimizer.dp_process_group) bf16_mode = self.bfloat16_enabled() - return self._get_rank_zero_ckpt_name(checkpoints_path, - tag, - mp_rank, - pp_rank, - bf16_mode) + return self._get_rank_zero_ckpt_name( + checkpoints_path, tag, mp_rank, pp_rank, bf16_mode + ) def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None): if mp_placeholder is not None: @@ -2477,7 +2695,8 @@ def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None): if self.zero_optimization_partition_weights(): filename = "zero_pp_rank_{}".format( - dist.get_rank(group=self.optimizer.dp_process_group)) + dist.get_rank(group=self.optimizer.dp_process_group) + ) ckpt_name = os.path.join( checkpoints_path, str(tag), @@ -2496,7 +2715,8 @@ def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank): ckpt_name = os.path.join( checkpoints_path, str(tag), - f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt') + f"expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt", + ) return ckpt_name @staticmethod @@ -2506,36 +2726,39 @@ def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu=None): # Used to support old checkpoint loading ckpt_name = os.path.join( checkpoints_path, - '' if tag is None else str(tag), - f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt') + "" if tag is None else str(tag), + f"expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt", + ) else: # Used to support new checkpoint loading ckpt_name = os.path.join( checkpoints_path, - '' if tag is None else str(tag), - f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt' + "" if tag is None else str(tag), + f"layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt", ) return ckpt_name def _get_all_ckpt_names(self, checkpoints_path, tag): # It is required that (checkpoints_path, tag) are consistent among all ranks. 
- ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, - tag, - mp_placeholder="*") + ckpt_file_pattern = self._get_ckpt_name( + checkpoints_path, tag, mp_placeholder="*" + ) import glob ckpt_files = glob.glob(ckpt_file_pattern) ckpt_files.sort() return ckpt_files - def load_checkpoint(self, - load_dir, - tag=None, - load_module_strict=True, - load_optimizer_states=True, - load_lr_scheduler_states=True, - load_module_only=False, - custom_load_fn=None): + def load_checkpoint( + self, + load_dir, + tag=None, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True, + load_module_only=False, + custom_load_fn=None, + ): """Load training checkpoint Arguments: load_dir: Required. Directory to load the checkpoint from @@ -2556,8 +2779,9 @@ def load_checkpoint(self, """ if tag is None: - latest_tag = "latest_universal" if self.load_universal_checkpoint( - ) else "latest" + latest_tag = ( + "latest_universal" if self.load_universal_checkpoint() else "latest" + ) latest_path = os.path.join(load_dir, latest_tag) if os.path.isfile(latest_path): with open(latest_path, "r") as fd: @@ -2565,7 +2789,7 @@ def load_checkpoint(self, else: if self.load_universal_checkpoint(): raise ValueError( - f'Invalid for universal checkpoint: {latest_path} does not exist' + f"Invalid for universal checkpoint: {latest_path} does not exist" ) else: logger.warning( @@ -2578,20 +2802,21 @@ def load_checkpoint(self, # Prepare for checkpoint load by ensuring all parameters are partitioned self.optimizer.checkpoint_event_prologue() - load_path, client_states = self._load_checkpoint(load_dir, - tag, - load_module_strict=load_module_strict, - load_optimizer_states=load_optimizer_states, - load_lr_scheduler_states=load_lr_scheduler_states, - load_module_only=load_module_only, - custom_load_fn=custom_load_fn) + load_path, client_states = self._load_checkpoint( + load_dir, + tag, + load_module_strict=load_module_strict, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states, + load_module_only=load_module_only, + custom_load_fn=custom_load_fn, + ) load_zero_checkpoint = self.zero_optimization() or self.bfloat16_enabled() if load_zero_checkpoint and load_path is not None: success = self._load_zero_checkpoint( - load_dir, - tag, - load_optimizer_states=load_optimizer_states) + load_dir, tag, load_optimizer_states=load_optimizer_states + ) if not success: self.optimizer._restore_from_bit16_weights() @@ -2600,21 +2825,23 @@ def load_checkpoint(self, return load_path, client_states - def _load_checkpoint(self, - load_dir, - tag, - load_module_strict=True, - load_optimizer_states=True, - load_lr_scheduler_states=True, - load_module_only=False, - custom_load_fn=None): + def _load_checkpoint( + self, + load_dir, + tag, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True, + load_module_only=False, + custom_load_fn=None, + ): from deepspeed.runtime.state_dict_factory import SDLoaderFactory ckpt_list = self._get_all_ckpt_names(load_dir, tag) sd_loader = SDLoaderFactory.get_sd_loader( - ckpt_list, - checkpoint_engine=self.checkpoint_engine) + ckpt_list, checkpoint_engine=self.checkpoint_engine + ) is_pipe_parallel = isinstance(self.module, PipelineModule) @@ -2633,55 +2860,66 @@ def _load_checkpoint(self, if self.has_moe_layers: # print(checkpoint.keys()) old_moe_load = False - if not isinstance(checkpoint['num_experts'], list): + if not isinstance(checkpoint["num_experts"], list): old_moe_load = True - 
DeepSpeedEngine.load_moe_state_dict(load_dir, - tag, - state_dict=checkpoint['module'], - old_moe_load=old_moe_load, - model=self.module, - mpu=self.mpu, - num_experts=self.num_experts, - checkpoint_engine=self.checkpoint_engine) + DeepSpeedEngine.load_moe_state_dict( + load_dir, + tag, + state_dict=checkpoint["module"], + old_moe_load=old_moe_load, + model=self.module, + mpu=self.mpu, + num_experts=self.num_experts, + checkpoint_engine=self.checkpoint_engine, + ) if not self.load_universal_checkpoint(): - self.load_module_state_dict(state_dict=checkpoint['module'], - strict=load_module_strict, - custom_load_fn=custom_load_fn) + self.load_module_state_dict( + state_dict=checkpoint["module"], + strict=load_module_strict, + custom_load_fn=custom_load_fn, + ) - self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] + self.loaded_checkpoint_dp_world_size = checkpoint["dp_world_size"] if load_module_only: - deepspeed_states = ['module'] + deepspeed_states = ["module"] if self.optimizer is not None and self.fp16_enabled(): self.optimizer.refresh_fp32_params() else: if self.has_moe_layers: largest_group_name = groups._get_max_expert_size_name() expp_rank = groups._get_expert_parallel_rank(largest_group_name) - optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) + optim_load_path = self._get_optimizer_ckpt_name( + load_dir, tag, expp_rank + ) optim_checkpoint = self.checkpoint_engine.load( - optim_load_path, - map_location=torch.device('cpu')) + optim_load_path, map_location=torch.device("cpu") + ) else: optim_checkpoint = checkpoint - has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled( + has_zero_optimizer_state = ( + self.zero_optimization() or self.bfloat16_enabled() ) - if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state: + if ( + load_optimizer_states + and self.optimizer is not None + and not has_zero_optimizer_state + ): if self.fp16_enabled(): self.optimizer.load_state_dict( - optim_checkpoint['optimizer'], - load_optimizer_states=load_optimizer_states) + optim_checkpoint["optimizer"], + load_optimizer_states=load_optimizer_states, + ) else: - self.optimizer.load_state_dict(optim_checkpoint['optimizer']) + self.optimizer.load_state_dict(optim_checkpoint["optimizer"]) if load_lr_scheduler_states and self.lr_scheduler is not None: - self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - def get_sparse_tensor_module_names(original_set, - loaded_set, - original_parameters, - loaded_parameters): + def get_sparse_tensor_module_names( + original_set, loaded_set, original_parameters, loaded_parameters + ): result = set() for name in original_set: @@ -2692,14 +2930,15 @@ def get_sparse_tensor_module_names(original_set, for name in loaded_set: if name in original_parameters: result.add( - name) # parameter exists in both configs and it was sparse + name + ) # parameter exists in both configs and it was sparse return result - if 'sparse_tensor_module_names' in checkpoint: - sparse_tensor_module_names = checkpoint['sparse_tensor_module_names'] - elif 'csr_tensor_module_names' in checkpoint: - sparse_tensor_module_names = checkpoint['csr_tensor_module_names'] + if "sparse_tensor_module_names" in checkpoint: + sparse_tensor_module_names = checkpoint["sparse_tensor_module_names"] + elif "csr_tensor_module_names" in checkpoint: + sparse_tensor_module_names = checkpoint["csr_tensor_module_names"] else: sparse_tensor_module_names = None if 
sparse_tensor_module_names is not None: @@ -2710,51 +2949,57 @@ def get_sparse_tensor_module_names(original_set, self.sparse_tensor_module_names, sparse_tensor_module_names, dict(self.module.named_parameters()), - checkpoint["module"]) + checkpoint["module"], + ) - self.global_steps = checkpoint['global_steps'] + self.global_steps = checkpoint["global_steps"] self.global_samples = checkpoint.get( - 'global_samples', - self.global_steps * self.train_batch_size()) - self.skipped_steps = checkpoint['skipped_steps'] - self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size'] + "global_samples", self.global_steps * self.train_batch_size() + ) + self.skipped_steps = checkpoint["skipped_steps"] + self.loaded_checkpoint_mp_world_size = checkpoint["mp_world_size"] deepspeed_states = [ - 'module', - 'sparse_tensor_module_names', - 'skipped_steps', - 'global_steps', - 'dp_world_size', - 'mp_world_size' + "module", + "sparse_tensor_module_names", + "skipped_steps", + "global_steps", + "dp_world_size", + "mp_world_size", ] client_state = {} if load_lr_scheduler_states: - deepspeed_states.append('lr_scheduler') + deepspeed_states.append("lr_scheduler") if load_optimizer_states: - deepspeed_states.append('optimizer') + deepspeed_states.append("optimizer") client_state = { key: value - for key, - value in checkpoint.items() if not key in deepspeed_states + for key, value in checkpoint.items() + if not key in deepspeed_states } if not load_optimizer_states and not load_module_only: - client_state['optimizer'] = optim_checkpoint['optimizer'] + client_state["optimizer"] = optim_checkpoint["optimizer"] return load_path, client_state def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): if self.load_universal_checkpoint(): zero_sd_list = None - checkpoint_folder = f'{os.path.join(load_dir, tag)}' + checkpoint_folder = f"{os.path.join(load_dir, tag)}" else: - if load_optimizer_states and self.dp_world_size != self.loaded_checkpoint_dp_world_size: - raise ZeRORuntimeException("The checkpoint being loaded used a DP " \ - f"world size of {self.loaded_checkpoint_dp_world_size} but the " \ - f"current world size is {self.dp_world_size}. Automatic adjustment " \ - "of ZeRO's optimizer state partitioning with a new world size is not " \ - "currently supported.") + if ( + load_optimizer_states + and self.dp_world_size != self.loaded_checkpoint_dp_world_size + ): + raise ZeRORuntimeException( + "The checkpoint being loaded used a DP " + f"world size of {self.loaded_checkpoint_dp_world_size} but the " + f"current world size is {self.dp_world_size}. Automatic adjustment " + "of ZeRO's optimizer state partitioning with a new world size is not " + "currently supported." 
+ ) checkpoint_folder = None zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) if zero_sd_list is None: @@ -2764,11 +3009,12 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): state_dict_list=zero_sd_list, load_optimizer_states=load_optimizer_states, load_from_fp32_weights=self.zero_load_from_fp32_weights(), - checkpoint_folder=checkpoint_folder) + checkpoint_folder=checkpoint_folder, + ) if self.load_universal_checkpoint(): logger.info( - f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}' + f"loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}" ) else: logger.info( @@ -2776,19 +3022,18 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): ) return True - def _get_mp_rank_zero_checkpoint_names(self, - load_dir, - tag, - mp_rank, - dp_world_size, - bf16_mode): + def _get_mp_rank_zero_checkpoint_names( + self, load_dir, tag, mp_rank, dp_world_size, bf16_mode + ): zero_ckpt_names = [] for dp_rank in range(dp_world_size): - ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir, - tag=tag, - mp_rank=mp_rank, - dp_rank=dp_rank, - bf16_mode=bf16_mode) + ckpt_name = self._get_rank_zero_ckpt_name( + checkpoints_path=load_dir, + tag=tag, + mp_rank=mp_rank, + dp_rank=dp_rank, + bf16_mode=bf16_mode, + ) zero_ckpt_names.append(ckpt_name) return zero_ckpt_names @@ -2800,14 +3045,16 @@ def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode): tag=tag, mp_rank=mp_rank, dp_world_size=self.loaded_checkpoint_dp_world_size, - bf16_mode=bf16_mode) + bf16_mode=bf16_mode, + ) invalid_zero_ckpt_paths = [] for i, ckpt_name in enumerate(zero_ckpt_names): if not os.path.exists(ckpt_name): # transparently handle the old file pattern for optim_states if "optim_states.pt" in ckpt_name: - ckpt_name_try = ckpt_name.replace("_optim_states.pt", - "optim_states.pt") + ckpt_name_try = ckpt_name.replace( + "_optim_states.pt", "optim_states.pt" + ) if os.path.exists(ckpt_name_try): zero_ckpt_names[i] = ckpt_name_try continue @@ -2826,11 +3073,13 @@ def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): for i, ckpt_name in enumerate(zero_ckpt_names): _state = None # Fully load state for current rank - if self.zero_elastic_checkpoint() or dist.get_rank( - group=self.optimizer.dp_process_group) == i: + if ( + self.zero_elastic_checkpoint() + or dist.get_rank(group=self.optimizer.dp_process_group) == i + ): _state = self.checkpoint_engine.load( ckpt_name, - map_location='cpu', + map_location="cpu", ) else: _state = {OPTIMIZER_STATE_DICT: None} @@ -2845,16 +3094,15 @@ def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): def _get_all_zero_checkpoints(self, load_dir, tag): for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]: zero_ckpt_names = self._get_all_zero_checkpoint_names( - load_dir, - tag, - bf16_mode) + load_dir, tag, bf16_mode + ) if zero_ckpt_names is not None: # Warn if loading checkpoint of different bit16 type if bf16_mode is not self.bfloat16_enabled(): checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16 engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16 logger.warn( - f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine' + f"Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine" ) return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) @@ -2872,7 +3120,8 @@ def _checkpoint_tag_validation(self, tag): msg = ( f"[rank={dist.get_rank()}] The 
checkpoint tag name '{tag}' is not consistent across " "all ranks. Including rank unique information in checkpoint tag could cause issues when " - "restoring with different world sizes.") + "restoring with different world sizes." + ) if self.checkpoint_tag_validation_fail(): assert valid, msg elif not valid: @@ -2931,7 +3180,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) # Save latest checkpoint tag self.checkpoint_engine.commit(tag) if save_latest and self.global_rank == 0: - with open(os.path.join(save_dir, 'latest'), 'w') as fd: + with open(os.path.join(save_dir, "latest"), "w") as fd: fd.write(tag) dist.barrier() @@ -2940,10 +3189,10 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) def _get_non_moe_state_dict(self, full_state_dict): """ - Get the state dict of the non-moe layers + Get the state dict of the non-moe layers """ for key in list(full_state_dict.keys()): - if 'expert' in key and 'moe.gate.wg.weight' not in key: + if "expert" in key and "moe.gate.wg.weight" not in key: full_state_dict.pop(key) return full_state_dict @@ -2970,9 +3219,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): # get all moe parameters moe_state_dict = {} for n, p in module.state_dict().items(): - if 'expert' in n and 'moe.gate.wg.weight' not in n: - moe_state_dict[n_module + '.' + n] = p - moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' + if "expert" in n and "moe.gate.wg.weight" not in n: + moe_state_dict[n_module + "." + n] = p + moe_str_prefix = ".deepspeed_moe.experts.deepspeed_experts." # print(moe_state_dict.keys()) # until now, everything is fine. So the bug happens at next few lines # Reorder the moe name rank, so that each checkpoint only has one expert experts_state_dict = defaultdict(dict) @@ -2981,26 +3230,27 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): local_expert_id = None if not m: - logger.warn(f'No expert found in key {key}.') + logger.warn(f"No expert found in key {key}.") else: local_expert_id = m.group(1) - global_expert_id = expp_rank * \ - num_local_experts + int(local_expert_id) - expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}', - f'{moe_str_prefix}{global_expert_id}') - experts_state_dict[str( - global_expert_id)][expert_key] = moe_state_dict.pop(key) + global_expert_id = expp_rank * num_local_experts + int( + local_expert_id + ) + expert_key = key.replace( + f"{moe_str_prefix}{local_expert_id}", + f"{moe_str_prefix}{global_expert_id}", + ) + experts_state_dict[str(global_expert_id)][ + expert_key + ] = moe_state_dict.pop(key) # let save the moe parameters for global_expert_id, expert_state_dict in experts_state_dict.items(): # save the moe parameters moe_save_path = self._get_expert_ckpt_name( - save_dir, - moe_layer_id, - global_expert_id, - tag, - self.mpu) + save_dir, moe_layer_id, global_expert_id, tag, self.mpu + ) self.checkpoint_engine.save(expert_state_dict, moe_save_path) moe_layer_id += 1 @@ -3018,9 +3268,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): # Save optimizer states. They are different across each exp parallel rank. 
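A small worked example (values are made up, not taken from any real configuration) of the local/global expert-id arithmetic used by _save_moe_checkpoint above and load_moe_state_dict earlier:

# With 8 experts split over an expert-parallel world size of 4, each rank
# owns num_local_experts = 8 // 4 = 2 experts.
num_experts = 8
ep_world_size = 4
num_local_experts = num_experts // ep_world_size

def to_global_expert_id(expp_rank: int, local_expert_id: int) -> int:
    return expp_rank * num_local_experts + local_expert_id

assert to_global_expert_id(expp_rank=2, local_expert_id=1) == 5
# Checkpoint keys are renamed accordingly, e.g. a parameter under
# '...deepspeed_moe.experts.deepspeed_experts.1.' on expert-parallel rank 2
# is stored under '...deepspeed_moe.experts.deepspeed_experts.5.'.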
optimizer_state = { - 'optimizer': - self.optimizer.state_dict() - if self.optimizer and not self.zero_optimization() else None + "optimizer": self.optimizer.state_dict() + if self.optimizer and not self.zero_optimization() + else None } # TODO: why use BufferedWriter not the path file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) @@ -3032,34 +3282,27 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): if expp_rank == 0: # TODO: update num experts info,.. in checkpoint state = { - 'module': - model_state_dict, - 'lr_scheduler': - self.lr_scheduler.state_dict() - if self.lr_scheduler is not None else None, - 'sparse_tensor_module_names': - self.sparse_tensor_module_names, - 'skipped_steps': - self.skipped_steps, - 'global_steps': - self.global_steps, - 'global_samples': - self.global_samples, - 'dp_world_size': - self.dp_world_size, - 'mp_world_size': - self.mp_world_size, - 'num_experts': - self.num_experts + "module": model_state_dict, + "lr_scheduler": self.lr_scheduler.state_dict() + if self.lr_scheduler is not None + else None, + "sparse_tensor_module_names": self.sparse_tensor_module_names, + "skipped_steps": self.skipped_steps, + "global_steps": self.global_steps, + "global_samples": self.global_samples, + "dp_world_size": self.dp_world_size, + "mp_world_size": self.mp_world_size, + "num_experts": self.num_experts, } state.update(client_state) - logger.info(f'Saving model checkpoint: {save_path}') + logger.info(f"Saving model checkpoint: {save_path}") self.checkpoint_engine.save(state, save_path) self._curr_save_path = None def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): - name_function = (self._get_zero_ckpt_name - if zero_checkpoint else self._get_ckpt_name) + name_function = ( + self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name + ) try: checkpoint_name = name_function(save_dir, tag) ensure_directory_exists(checkpoint_name) @@ -3088,25 +3331,30 @@ def _save_checkpoint(self, save_dir, tag, client_state={}): # then instead just returns None. 
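The "latest" bookkeeping written by save_checkpoint above pairs with the read in load_checkpoint earlier; a hedged sketch of that tag round trip, with hypothetical helper names:

import os

def write_latest_tag(save_dir: str, tag: str) -> None:
    # save_checkpoint (global rank 0) records the most recent tag here.
    with open(os.path.join(save_dir, "latest"), "w") as fd:
        fd.write(tag)

def read_latest_tag(load_dir: str) -> str:
    # load_checkpoint falls back to this file when no tag is passed.
    with open(os.path.join(load_dir, "latest"), "r") as fd:
        return fd.read().strip()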
self._curr_ckpt_path = os.path.join(save_dir, tag) zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() - state = dict(module=self.module_state_dict(), - buffer_names=self._get_buffer_names(), - optimizer=self.optimizer.state_dict() - if self.optimizer and not zero_optimizer_state else None, - param_shapes=self._get_zero_param_shapes() - if self.optimizer and zero_optimizer_state else None, - lr_scheduler=self.lr_scheduler.state_dict() - if self.lr_scheduler is not None else None, - sparse_tensor_module_names=self.sparse_tensor_module_names, - skipped_steps=self.skipped_steps, - global_steps=self.global_steps, - global_samples=self.global_samples, - dp_world_size=self.dp_world_size, - mp_world_size=self.mp_world_size, - ds_config=self.config, - ds_version=version) + state = dict( + module=self.module_state_dict(), + buffer_names=self._get_buffer_names(), + optimizer=self.optimizer.state_dict() + if self.optimizer and not zero_optimizer_state + else None, + param_shapes=self._get_zero_param_shapes() + if self.optimizer and zero_optimizer_state + else None, + lr_scheduler=self.lr_scheduler.state_dict() + if self.lr_scheduler is not None + else None, + sparse_tensor_module_names=self.sparse_tensor_module_names, + skipped_steps=self.skipped_steps, + global_steps=self.global_steps, + global_samples=self.global_samples, + dp_world_size=self.dp_world_size, + mp_world_size=self.mp_world_size, + ds_config=self.config, + ds_version=version, + ) state.update(client_state) - log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1]) + log_dist(message=f"Saving model checkpoint: {save_path}", ranks=[0, 1]) self.checkpoint_engine.save(state, save_path) self._curr_save_path = None @@ -3152,8 +3400,11 @@ def _get_zero_param_shapes(self): elif self.bfloat16_enabled() and not self.zero_optimization(): bit16_groups = self.optimizer.bf16_groups else: - bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage( - ) == 2 else self.optimizer.fp16_groups + bit16_groups = ( + self.optimizer.bit16_groups + if self.zero_optimization_stage() == 2 + else self.optimizer.fp16_groups + ) for bit16_group in bit16_groups: param_shapes = OrderedDict() @@ -3178,22 +3429,24 @@ def _copy_recovery_script(self, save_path): script = "zero_to_fp32.py" src = os.path.join(base_dir, "utils", script) dst = os.path.join(save_path, script) - #logger.info(f"creating recovery script {dst}") + # logger.info(f"creating recovery script {dst}") copyfile(src, dst) # make executable os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC) def _save_zero_checkpoint(self, save_path, tag): zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag) - zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), - ds_config=self.config, - ds_version=version) + zero_sd = dict( + optimizer_state_dict=self.optimizer.state_dict(), + ds_config=self.config, + ds_version=version, + ) self.checkpoint_engine.save(zero_sd, zero_checkpoint_name) if self.global_rank == 0: self._copy_recovery_script(save_path) - ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero' - logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}') + ckpt_type = "zero" if self.zero_optimization() else "bf16_zero" + logger.info(f"{ckpt_type} checkpoint saved {zero_checkpoint_name}") def _zero3_consolidated_16bit_state_dict(self): """ @@ -3216,10 +3469,10 @@ def _zero3_consolidated_16bit_state_dict(self): def get_layer_state_dict(module, prefix=""): # gather one layer at a time to be memory-efficient # must use 
modifier_rank=0 to release GPU memory after each layer gathered - #see_memory_usage("before GatheredParameters", force=True) - with deepspeed.zero.GatheredParameters(list( - module.parameters(recurse=False)), - modifier_rank=0): + # see_memory_usage("before GatheredParameters", force=True) + with deepspeed.zero.GatheredParameters( + list(module.parameters(recurse=False)), modifier_rank=0 + ): if dist.get_rank() == 0: # handle params for name, param in module.named_parameters(recurse=False): @@ -3231,19 +3484,21 @@ def get_layer_state_dict(module, prefix=""): # (and shared params will have the same param.ds_id) if param.ds_id in shared_params: # shared weights - #print(f"`{key}` is shared with `{shared_params[param.ds_id]}`") + # print(f"`{key}` is shared with `{shared_params[param.ds_id]}`") state_dict[key] = state_dict[shared_params[param.ds_id]] else: state_dict[key] = param.detach().cpu() shared_params[param.ds_id] = key - #print(f"param {param.ds_id} {param.shape} {key} ") + # print(f"param {param.ds_id} {param.shape} {key} ") # now buffers - not sure if need to take care of potentially shared weights here for name, buf in module.named_buffers(recurse=False): - if (buf is not None - and name not in module._non_persistent_buffers_set): + if ( + buf is not None + and name not in module._non_persistent_buffers_set + ): state_dict[prefix + name] = buf.detach().cpu() - #see_memory_usage("after GatheredParameters", force=True) + # see_memory_usage("after GatheredParameters", force=True) for name, child in module.named_children(): if child is not None: diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 33edc2db1a6a..b20cad1f71e8 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -254,7 +254,7 @@ def _exec_reduce_grads(self): self._bf16_reduce_grads() else: assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported" - raise NotImplementedError() + self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE) else: self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE) self._force_grad_boundary = False diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index c3511e2819f9..396fca35dc18 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -222,7 +222,7 @@ def __init__(self, assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32. Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}. Either disable fp16_master_weights_and_gradients or enable ZeRO-2 Offload with DeepSpeedCPUAdam" if self.reduce_scatter: - assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-2 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" + #assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-2 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. 
Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled" assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled" From ecb85a5e700e812c1be89cb4bef52996a3fb1494 Mon Sep 17 00:00:00 2001 From: Samuel Weinbach Date: Wed, 7 Sep 2022 13:04:50 +0000 Subject: [PATCH 02/13] Tied grad reduce for bfloat --- deepspeed/runtime/pipe/engine.py | 616 ++++++++++++++++++------------- 1 file changed, 353 insertions(+), 263 deletions(-) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index b20cad1f71e8..a98c2fb1ea70 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -34,11 +34,12 @@ def _tensor_bytes(tensor): class PipelineEngine(DeepSpeedEngine): - """ A training engine hybrid pipeline, data, and model parallel training. + """A training engine hybrid pipeline, data, and model parallel training. This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule` is provided. """ + ID_TO_DTYPE = [ torch.float32, torch.float64, @@ -51,7 +52,7 @@ class PipelineEngine(DeepSpeedEngine): torch.int16, torch.int32, torch.int64, - torch.bool + torch.bool, ] DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)} @@ -59,7 +60,9 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): super().__init__(*super_args, **super_kwargs) assert isinstance(self.module, PipelineModule), "model must base PipelineModule" - assert self.zero_optimization_stage() < 2, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism" + assert ( + self.zero_optimization_stage() < 2 + ), "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism" # We schedule the all-reduces, so disable it in super().backward() self.enable_backward_allreduce = False @@ -72,8 +75,10 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): if self.elasticity_enabled(): if not self.is_elastic_model_parallel_supported(): - assert not self.elasticity_enabled(), "Elasticity is not currently supported" \ - " with pipeline parallelism." + assert not self.elasticity_enabled(), ( + "Elasticity is not currently supported" + " with pipeline parallelism." 
+ ) # pipeline step for logging self.log_batch_step_id = -1 @@ -84,14 +89,18 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): # Set Grid and Communication Groups self.grid = self.module._grid if self.grid.get_global_rank() == 0: - logger.info(f'CONFIG: micro_batches={self.micro_batches} ' - f'micro_batch_size={self.micro_batch_size}') + logger.info( + f"CONFIG: micro_batches={self.micro_batches} " + f"micro_batch_size={self.micro_batch_size}" + ) self.global_rank = self.grid.get_global_rank() assert self.dp_world_size == self.grid.data_parallel_size - assert self.train_batch_size() == \ - self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size + assert ( + self.train_batch_size() + == self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size + ) # Set Stage Inf self.num_stages = self.grid.pipe_parallel_size @@ -104,12 +113,13 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): self._force_grad_boundary = False - self.batch_timer = ThroughputTimer(batch_size=self.micro_batch_size * - self.micro_batches, - num_workers=self.dp_world_size, - logging_fn=self.tput_log, - monitor_memory=False, - steps_per_output=self.steps_per_print()) + self.batch_timer = ThroughputTimer( + batch_size=self.micro_batch_size * self.micro_batches, + num_workers=self.dp_world_size, + logging_fn=self.tput_log, + monitor_memory=False, + steps_per_output=self.steps_per_print(), + ) # PipelineEngine needs to handle data loading specially due to only the first # and last stages loading inputs/labels. We construct a sampler that uses @@ -132,35 +142,38 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): if self.module.tied_comms: tied_params = 0 for key, d in self.module.tied_comms.items(): - if self.global_rank != min(d['ranks']): - tied_params += sum(p.numel() for p in d['module'].parameters()) + if self.global_rank != min(d["ranks"]): + tied_params += sum(p.numel() for p in d["module"].parameters()) unique_params -= tied_params - params_tensor = torch.LongTensor(data=[num_params, - unique_params]).to(self.device) + params_tensor = torch.LongTensor(data=[num_params, unique_params]).to( + self.device + ) dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group()) params_tensor = params_tensor.tolist() total_params = params_tensor[0] unique_params = params_tensor[1] if self.grid.data_parallel_id == 0: - logger.info(f'RANK={self.global_rank} ' - f'STAGE={self.stage_id} ' - f'LAYERS={self.module._local_stop - self.module._local_start} ' - f'[{self.module._local_start}, {self.module._local_stop}) ' - f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) ' - f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) ' - f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)') - - #initialize peer-2-peer communication and allreduce groups + logger.info( + f"RANK={self.global_rank} " + f"STAGE={self.stage_id} " + f"LAYERS={self.module._local_stop - self.module._local_start} " + f"[{self.module._local_start}, {self.module._local_stop}) " + f"STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) " + f"TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) " + f"UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)" + ) + + # initialize peer-2-peer communication and allreduce groups if self.is_pipe_parallel: p2p.init_process_groups(self.grid) # Pipeline buffers self.num_pipe_buffers = 0 self.pipe_buffers = { - 'inputs' : [], # batch input and received activations - 'labels' : [], # labels from batch input 
- 'outputs' : [], # activations - 'output_tensors' : [], # tensor object to preserve backward graph + "inputs": [], # batch input and received activations + "labels": [], # labels from batch input + "outputs": [], # activations + "output_tensors": [], # tensor object to preserve backward graph } self.pipe_recv_buf = None self.grad_layer = None @@ -170,22 +183,23 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): self.first_output_send = True self.first_gradient_send = True - #stores the loss for the current micro batch being processed + # stores the loss for the current micro batch being processed self.loss = torch.tensor(0.0).to(self.device) - #stores the loss for the entire batch + # stores the loss for the entire batch self.total_loss = None self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device) self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device) - if self._config.pipeline['activation_checkpoint_interval'] > 0: + if self._config.pipeline["activation_checkpoint_interval"] > 0: self.module.activation_checkpoint_interval = self._config.pipeline[ - 'activation_checkpoint_interval'] + "activation_checkpoint_interval" + ] if self.is_last_stage(): self.loss_model = self.module.loss_fn - self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe' + self.has_attention_mask = self.module.__class__.__name__ == "GPT2ModelPipe" # Initialize pipeline communicators. Just send a 0. if is_even(self.stage_id): if not self.is_last_stage(): @@ -201,18 +215,18 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): # XXX look into timer reporting timing # Initialize some timers because of early weirdness. if self.wall_clock_breakdown(): - self.timers('forward_microstep').start() - self.timers('forward_microstep').stop() - self.timers('backward_microstep').start() - self.timers('backward_microstep').stop() - self.timers('backward_inner_microstep').start() - self.timers('backward_inner_microstep').stop() - self.timers('backward_allreduce_microstep').start() - self.timers('backward_allreduce_microstep').stop() - self.timers('backward_allreduce').start() - self.timers('backward_allreduce').stop() - self.timers('step_microstep').start() - self.timers('step_microstep').stop() + self.timers("forward_microstep").start() + self.timers("forward_microstep").stop() + self.timers("backward_microstep").start() + self.timers("backward_microstep").stop() + self.timers("backward_inner_microstep").start() + self.timers("backward_inner_microstep").stop() + self.timers("backward_allreduce_microstep").start() + self.timers("backward_allreduce_microstep").stop() + self.timers("backward_allreduce").start() + self.timers("backward_allreduce").stop() + self.timers("step_microstep").start() + self.timers("step_microstep").stop() def set_has_attention_mask(self, value): assert isinstance(value, bool) @@ -223,7 +237,8 @@ def _build_data_iter(self, dataset): dataset, num_replicas=self.dp_world_size, rank=self.mpu.get_data_parallel_rank(), - shuffle=False) + shuffle=False, + ) # Build a loader and make it repeating. 
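Before the repeating wrapper applied just below, the sampler built in _build_data_iter above shards the dataset across data-parallel ranks. A toy, self-contained example of that sharding (sizes and ranks are made up):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16).float())            # 16 toy samples
sampler = DistributedSampler(dataset, num_replicas=4, rank=1, shuffle=False)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
# With shuffle=False, rank 1 of 4 sees samples 1, 5, 9, 13:
print([int(x) for batch in loader for x in batch[0]])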
pipe_dataloader = self.deepspeed_io(dataset, data_sampler=sampler) pipe_dataloader = RepeatingLoader(pipe_dataloader) @@ -243,9 +258,20 @@ def _exec_reduce_tied_grads(self): weight_group_list = self.module.get_tied_weights_and_groups() for weight, group in weight_group_list: - grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad + if self.bfloat16_enabled() and self.zero_optimization(): + grad = weight.grad + else: + grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad + + is_bfloat = False + if grad.dtype == torch.bfloat16: + is_bfloat = True + grad = grad.to(torch.float32) dist.all_reduce(grad, group=group) + if is_bfloat: + grad = grad.to(torch.bfloat16) + def _exec_reduce_grads(self): self._force_grad_boundary = True if self.pipeline_enable_backward_allreduce: @@ -253,7 +279,9 @@ def _exec_reduce_grads(self): if self.zero_optimization_stage() == 0: self._bf16_reduce_grads() else: - assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported" + assert ( + self.zero_optimization_stage() == 1 + ), "only bf16 + z1 are supported" self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE) else: self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE) @@ -262,8 +290,10 @@ def _exec_reduce_grads(self): def _bf16_reduce_grads(self): # Make our own list of gradients from the optimizer's FP32 grads grads = [] - self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(), - elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) + self.buffered_allreduce_fallback( + grads=self.optimizer.get_grads_for_reduction(), + elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE, + ) def _reserve_pipe_buffers(self, num_buffers): """Ensure that each pipeline buffer has at least ``num_buffers`` slots. @@ -317,17 +347,20 @@ def train_batch(self, data_iter=None): """ if not torch._C.is_grad_enabled(): raise RuntimeError( - f'train_batch() requires gradients enabled. Use eval_batch() instead.') + f"train_batch() requires gradients enabled. Use eval_batch() instead." 
+ ) # Curriculum learning could change activation shape if self.curriculum_enabled(): - new_difficulty = self.curriculum_scheduler.update_difficulty( \ - self.global_steps + 1) + new_difficulty = self.curriculum_scheduler.update_difficulty( + self.global_steps + 1 + ) if self.global_steps == 0 or self.curriculum_scheduler.first_step: self.reset_activation_shape() self.curriculum_scheduler.first_step = False - elif new_difficulty != self.curriculum_scheduler.get_difficulty( \ - self.global_steps): + elif new_difficulty != self.curriculum_scheduler.get_difficulty( + self.global_steps + ): self.reset_activation_shape() if data_iter: @@ -338,49 +371,59 @@ def train_batch(self, data_iter=None): self._compute_loss = True # Do the work - self.timers('train_batch').start() - sched = schedule.TrainSchedule(micro_batches=self.micro_batches, - stages=self.num_stages, - stage_id=self.stage_id) + self.timers("train_batch").start() + sched = schedule.TrainSchedule( + micro_batches=self.micro_batches, + stages=self.num_stages, + stage_id=self.stage_id, + ) self._exec_schedule(sched) self.agg_train_loss = self._aggregate_total_loss() - self.timers('train_batch').stop() + self.timers("train_batch").stop() if self.global_steps % self.steps_per_print() == 0: if self.global_rank == 0: - elapsed = self.timers('train_batch').elapsed(reset=True) / 1000.0 + elapsed = self.timers("train_batch").elapsed(reset=True) / 1000.0 iter_time = elapsed / self.steps_per_print() tput = self.train_batch_size() / iter_time - print(f'steps: {self.global_steps} ' - f'loss: {self.agg_train_loss:0.4f} ' - f'iter time (s): {iter_time:0.3f} ' - f'samples/sec: {tput:0.3f}') + print( + f"steps: {self.global_steps} " + f"loss: {self.agg_train_loss:0.4f} " + f"iter time (s): {iter_time:0.3f} " + f"samples/sec: {tput:0.3f}" + ) # Monitoring if self.global_rank == 0 and self.monitor.enabled: - self.summary_events = [(f'Train/Samples/train_loss', - self.agg_train_loss.mean().item(), - self.global_samples)] + self.summary_events = [ + ( + f"Train/Samples/train_loss", + self.agg_train_loss.mean().item(), + self.global_samples, + ) + ] self.monitor.write_events(self.summary_events) - if self.wall_clock_breakdown( - ) and self.global_steps % self.steps_per_print() == 0: - self.timers.log([ - 'pipe_send_output', - 'pipe_send_grad', - 'pipe_recv_input', - 'pipe_recv_grad' - ]) + if ( + self.wall_clock_breakdown() + and self.global_steps % self.steps_per_print() == 0 + ): + self.timers.log( + [ + "pipe_send_output", + "pipe_send_grad", + "pipe_recv_input", + "pipe_recv_grad", + ] + ) # TODO: should return precisely what loss returned and allow others to be queried? return self.agg_train_loss - def eval_batch(self, - data_iter, - return_logits=False, - compute_loss=True, - reduce_output='avg'): + def eval_batch( + self, data_iter, return_logits=False, compute_loss=True, reduce_output="avg" + ): """Evaluate the pipeline on a batch of data from ``data_iter``. The engine will evaluate ``self.train_batch_size()`` total samples collectively across all workers. 
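The tied-gradient hunk above upcasts bfloat16 gradients to float32 before the all-reduce; here is a minimal hedged sketch of that round trip (illustrative function, not the engine's API). Note that Tensor.to() returns a new tensor, so the reduced values have to be copied back for the original gradient to see them:

import torch
import torch.distributed as dist

def allreduce_bf16_grad_in_fp32(grad: torch.Tensor, group=None) -> torch.Tensor:
    if grad.dtype == torch.bfloat16:
        buf = grad.to(torch.float32)            # .to() allocates a new fp32 tensor
        dist.all_reduce(buf, group=group)
        grad.copy_(buf)                         # write the reduced result back into the bf16 grad
    else:
        dist.all_reduce(grad, group=group)      # fp16/fp32 grads can be reduced in place
    return grad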
@@ -412,13 +455,15 @@ def eval_batch(self, # Curriculum learning could change activation shape if self.curriculum_enabled(): - new_difficulty = self.curriculum_scheduler.update_difficulty( \ - self.global_steps + 1) + new_difficulty = self.curriculum_scheduler.update_difficulty( + self.global_steps + 1 + ) if self.global_steps == 0 or self.curriculum_scheduler.first_step: self.reset_activation_shape() self.curriculum_scheduler.first_step = False - elif new_difficulty != self.curriculum_scheduler.get_difficulty( \ - self.global_steps): + elif new_difficulty != self.curriculum_scheduler.get_difficulty( + self.global_steps + ): self.reset_activation_shape() eval_output = None @@ -430,9 +475,11 @@ def eval_batch(self, self.set_dataiterator(data_iter) # Do the work - sched = schedule.InferenceSchedule(micro_batches=self.micro_batches, - stages=self.num_stages, - stage_id=self.stage_id) + sched = schedule.InferenceSchedule( + micro_batches=self.micro_batches, + stages=self.num_stages, + stage_id=self.stage_id, + ) # prevent dead-lock with multiple evals sequence dist.barrier() @@ -447,16 +494,20 @@ def eval_batch(self, eval_output = self._bcast_pipe_scalar(eval_output) if self.global_rank == 0 and self.monitor.enabled: - self.summary_events = [(f'Train/Samples/eval_loss', - eval_output.mean().item(), - self.global_samples)] + self.summary_events = [ + ( + f"Train/Samples/eval_loss", + eval_output.mean().item(), + self.global_samples, + ) + ] self.monitor.write_events(self.summary_events) # Restore the training iterator self.set_dataiterator(train_iterator) # Reset any buffers that may have been populated during the forward passes. - #ds_checkpointing.reset() + # ds_checkpointing.reset() self.eval_return_logits = False if return_logits: outputs = self.outputs @@ -485,11 +536,11 @@ def is_last_stage(self): """True if this process is in the last stage in the pipeline.""" return self.stage_id == self.num_stages - 1 - def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True): + def _reduce_outputs(self, outputs, reduce="avg", reduce_dp=True): if reduce is None: return outputs - if reduce.lower() == 'avg': + if reduce.lower() == "avg": # first sum over all microbatches if torch.is_tensor(outputs[0]): reduced = sum(outputs) @@ -509,13 +560,14 @@ def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True): reduced /= self.dp_world_size else: for idx in range(len(reduced)): - dist.all_reduce(reduced[idx], - group=self.mpu.get_data_parallel_group()) + dist.all_reduce( + reduced[idx], group=self.mpu.get_data_parallel_group() + ) reduced[idx] /= self.dp_world_size return reduced else: - raise NotImplementedError(f'reduction type {reduce} not supported.') + raise NotImplementedError(f"reduction type {reduce} not supported.") def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32): # Default to last stage (e.g., for broadcasting loss) @@ -526,11 +578,11 @@ def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32): if self.global_rank == src_rank: result = data.clone().detach() else: - result = torch.Tensor([0.]).type(dtype).to(self.device) + result = torch.Tensor([0.0]).type(dtype).to(self.device) - dist.broadcast(tensor=result, - src=src_rank, - group=self.mpu.get_pipe_parallel_group()) + dist.broadcast( + tensor=result, src=src_rank, group=self.mpu.get_pipe_parallel_group() + ) return result @@ -542,25 +594,27 @@ def _aggregate_total_loss(self): ## Average loss across all data-parallel groups agg_loss = self.dp_group_loss.clone().detach() - 
#print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True) + # print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True) if self.is_data_parallel: dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group()) agg_loss /= self.dp_world_size assert self.global_rank in self.grid.pp_group losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device) - dist.broadcast(tensor=losses, - src=self.global_rank, - group=self.mpu.get_pipe_parallel_group()) + dist.broadcast( + tensor=losses, + src=self.global_rank, + group=self.mpu.get_pipe_parallel_group(), + ) else: # Get loss from last stage src_rank = self.grid.stage_to_global(self.num_stages - 1) assert src_rank in self.grid.pp_group - losses = torch.Tensor([0., 0.]).to(self.device) - dist.broadcast(tensor=losses, - src=src_rank, - group=self.grid.get_pipe_parallel_group()) + losses = torch.Tensor([0.0, 0.0]).to(self.device) + dist.broadcast( + tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group() + ) self.dp_group_loss = losses[0].clone().detach() agg_loss = losses[1].clone().detach() @@ -573,7 +627,7 @@ def set_dataloader(self, loader): self.data_iterator = iter(self.training_dataloader) def set_dataiterator(self, iterator): - """ Store an iterator to sample for training data. """ + """Store an iterator to sample for training data.""" if self.is_first_stage() or self.is_last_stage(): self.training_dataloader = None self.data_iterator = iterator @@ -601,14 +655,15 @@ def log_for_device(self, *msg): if LOG_STAGE == self.stage_id or LOG_STAGE == -1: if DATA_PARALLEL_ID == self.grid.data_parallel_id or DATA_PARALLEL_ID == -1: print( - f'RANK={dist.get_rank()} ' - f'PIPE-ID={self.stage_id} ' - f'DATA-ID={self.grid.data_parallel_id} ' - f'MBATCH-ID={self.microbatch_id} ' - f'STEP-ID={self.log_batch_step_id} ' - '::', + f"RANK={dist.get_rank()} " + f"PIPE-ID={self.stage_id} " + f"DATA-ID={self.grid.data_parallel_id} " + f"MBATCH-ID={self.microbatch_id} " + f"STEP-ID={self.log_batch_step_id} " + "::", *msg, - flush=True) + flush=True, + ) def tput_log(self, *msg): if self.global_rank == 0 and self.global_steps % self.steps_per_print() == 0: @@ -628,27 +683,28 @@ def _next_batch(self): def _exec_forward_pass(self, buffer_id): self.tput_timer.start() - self.mem_status('BEFORE FWD', reset_max=True) + self.mem_status("BEFORE FWD", reset_max=True) - if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple): - inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id]) + if isinstance(self.pipe_buffers["inputs"][buffer_id], tuple): + inputs = tuple(t.clone() for t in self.pipe_buffers["inputs"][buffer_id]) else: - inputs = self.pipe_buffers['inputs'][buffer_id].clone() + inputs = self.pipe_buffers["inputs"][buffer_id].clone() # collect the partitioned input from the previous stage if self.is_pipe_partitioned and not self.is_first_stage(): part_input = PartitionedTensor.from_meta( meta=inputs[0], local_part=inputs[1], - group=self.grid.get_slice_parallel_group()) + group=self.grid.get_slice_parallel_group(), + ) inputs = (part_input.full(), *inputs[2:]) inputs[0].requires_grad = True # skip mask - #inputs[1].requires_grad = True + # inputs[1].requires_grad = True part_input = None inputs = inputs[0] if len(inputs) == 1 else inputs - self.pipe_buffers['inputs'][buffer_id] = inputs + self.pipe_buffers["inputs"][buffer_id] = inputs # Zero out the gradients each time we use the tensor because only the data in # tensor 
changes across batches @@ -661,31 +717,34 @@ def _exec_forward_pass(self, buffer_id): if isinstance(outputs, tuple): first_output = outputs[0] # TODO: Improve pipe partitioning to pass multiple tensors that require grads - assert all([ - torch.is_tensor(elt) and elt.requires_grad is False - for elt in outputs[1:] - ]) + assert all( + [ + torch.is_tensor(elt) and elt.requires_grad is False + for elt in outputs[1:] + ] + ) outputs_tail = outputs[1:] elif torch.is_tensor(outputs): first_output = outputs outputs_tail = [] else: raise ValueError("expecting a tensor or a tuple of tensors") - part = PartitionedTensor(tensor=first_output, - group=self.grid.get_slice_parallel_group()) + part = PartitionedTensor( + tensor=first_output, group=self.grid.get_slice_parallel_group() + ) # Clear the large output data, but save the computation graph first_output.data = torch.zeros(1) - self.pipe_buffers['output_tensors'][buffer_id] = first_output + self.pipe_buffers["output_tensors"][buffer_id] = first_output # Inject the partitioned tensor into the output before sending outputs = (part.to_meta(), part.data(), *outputs_tail) part = None - self.pipe_buffers['outputs'][buffer_id] = outputs + self.pipe_buffers["outputs"][buffer_id] = outputs # Optionally compute loss on the last device if self.is_last_stage(): if self._compute_loss and self.module.loss_fn is not None: - labels = self.pipe_buffers['labels'][buffer_id] + labels = self.pipe_buffers["labels"][buffer_id] self.loss = self.module.loss_fn(outputs, labels) else: # Some models just return loss from forward() @@ -707,25 +766,26 @@ def _exec_forward_pass(self, buffer_id): self.total_loss[idx] += l.detach() def _exec_backward_pass(self, buffer_id): - assert self.optimizer is not None, "must provide optimizer during " \ - "init in order to use backward" + assert self.optimizer is not None, ( + "must provide optimizer during " "init in order to use backward" + ) - self.mem_status('BEFORE BWD', reset_max=True) + self.mem_status("BEFORE BWD", reset_max=True) # The last stage just runs backward on the loss using DeepSpeed's typical # mechanisms. if self.is_last_stage(): super().backward(self.loss) - self.mem_status('AFTER BWD') + self.mem_status("AFTER BWD") return - outputs = self.pipe_buffers['outputs'][buffer_id] + outputs = self.pipe_buffers["outputs"][buffer_id] if self.wall_clock_breakdown(): - self.timers('backward_microstep').start() - self.timers('backward').start() - self.timers('backward_inner_microstep').start() - self.timers('backward_inner').start() + self.timers("backward_microstep").start() + self.timers("backward").start() + self.timers("backward_inner_microstep").start() + self.timers("backward_inner").start() # Reconstruct if we previously partitioned the output. We must be # careful to also restore the computational graph of the tensors we partitioned. 
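A conceptual sketch, not the real PartitionedTensor API, of the partition/reconstruct round trip that the forward and backward passes perform on activations via from_meta(...) and .full(): each slice-parallel rank keeps an equal chunk, and reconstruction all-gathers the chunks back into the original shape.

import torch
import torch.distributed as dist

def partition(tensor: torch.Tensor, group=None):
    world = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    flat = tensor.flatten()
    chunk = flat.numel() // world                     # assume divisibility for simplicity
    meta = {"shape": tuple(tensor.shape), "chunk": chunk}
    local = flat[rank * chunk:(rank + 1) * chunk].clone()
    return meta, local

def reconstruct(meta, local, group=None):
    world = dist.get_world_size(group=group)
    parts = [torch.empty_like(local) for _ in range(world)]
    dist.all_gather(parts, local, group=group)        # gather every rank's chunk
    return torch.cat(parts).view(meta["shape"])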
@@ -734,26 +794,28 @@ def _exec_backward_pass(self, buffer_id): part_output = PartitionedTensor.from_meta( meta=outputs[0], local_part=outputs[1], - group=self.grid.get_slice_parallel_group()) - self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full() - outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:]) + group=self.grid.get_slice_parallel_group(), + ) + self.pipe_buffers["output_tensors"][buffer_id].data = part_output.full() + outputs = (self.pipe_buffers["output_tensors"][buffer_id], *outputs[2:]) else: # Already restored from partition - self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0] - outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:]) + self.pipe_buffers["output_tensors"][buffer_id].data = outputs[0] + outputs = (self.pipe_buffers["output_tensors"][buffer_id], *outputs[1:]) grad_tensors = self.grad_layer if self.is_grad_partitioned: - #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') + # print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') part_grad = PartitionedTensor.from_meta( meta=self.grad_layer[0], local_part=self.grad_layer[1], - group=self.grid.get_slice_parallel_group()) + group=self.grid.get_slice_parallel_group(), + ) grad_tensors = (part_grad.full(), *grad_tensors[2:]) part_grad = None - #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') + # print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') - if self.bfloat16_enabled() and not self.is_last_stage(): + if not self.zero_optimization() and self.bfloat16_enabled() and not self.is_last_stage(): # manually call because we don't call optimizer.backward() self.optimizer.clear_lp_grads() @@ -763,28 +825,32 @@ def _exec_backward_pass(self, buffer_id): assert len(out_tensors) == len(grad_tensors) torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors) else: - torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, )) + torch.autograd.backward(tensors=(outputs,), grad_tensors=(grad_tensors,)) - if self.bfloat16_enabled() and not self.is_last_stage(): + if ( + not self.zero_optimization() + and self.bfloat16_enabled() + and not self.is_last_stage() + ): # manually call because we don't call optimizer.backward() self.optimizer.update_hp_grads(clear_lp_grads=False) # Free up the memory from the output of forward() - self.pipe_buffers['output_tensors'][buffer_id] = None - self.pipe_buffers['outputs'][buffer_id] = None + self.pipe_buffers["output_tensors"][buffer_id] = None + self.pipe_buffers["outputs"][buffer_id] = None grad_tensors = None if self.wall_clock_breakdown(): - self.timers('backward_inner').stop() - self.timers('backward_inner_microstep').stop() - self.timers('backward').stop() - self.timers('backward_microstep').stop() + self.timers("backward_inner").stop() + self.timers("backward_inner_microstep").stop() + self.timers("backward").stop() + self.timers("backward_microstep").stop() - self.mem_status('AFTER BWD') + self.mem_status("AFTER BWD") def _exec_load_micro_batch(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('batch_input').start() + self.timers("batch_input").start() batch = self._next_batch() @@ -804,7 +870,7 @@ def _exec_load_micro_batch(self, buffer_id): loaded.append(mine) loaded = tuple(loaded) - self.pipe_buffers['inputs'][buffer_id] = 
loaded + self.pipe_buffers["inputs"][buffer_id] = loaded if self.is_last_stage(): loaded = batch[1] @@ -818,13 +884,13 @@ def _exec_load_micro_batch(self, buffer_id): loaded.append(x) loaded = tuple(loaded) - self.pipe_buffers['labels'][buffer_id] = loaded + self.pipe_buffers["labels"][buffer_id] = loaded if self.wall_clock_breakdown(): - self.timers('batch_input').stop() + self.timers("batch_input").stop() def _send_tensor_meta(self, buffer, recv_stage): - """ Communicate metadata about upcoming p2p transfers. + """Communicate metadata about upcoming p2p transfers. Metadata is communicated in this order: * type (0: tensor, 1: list) @@ -843,7 +909,7 @@ def _send_tensor_meta(self, buffer, recv_stage): p2p.send(send_shape, recv_stage) send_bytes += _tensor_bytes(buffer) elif isinstance(buffer, list): - assert (False) + assert False type_tensor = torch.LongTensor(data=[1]).to(self.device) p2p.send(type_tensor, recv_stage) count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device) @@ -865,12 +931,13 @@ def _send_tensor_meta(self, buffer, recv_stage): send_shape = torch.LongTensor(data=tensor.size()).to(self.device) send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device) send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to( - self.device) + self.device + ) p2p.send(send_dtype, recv_stage) p2p.send(send_ndims, recv_stage) p2p.send(send_shape, recv_stage) # Useful for performance debugging. - ''' + """ new_bytes = _tensor_bytes(tensor) send_bytes += _tensor_bytes(tensor) # Useful for performance debugging. @@ -878,15 +945,15 @@ def _send_tensor_meta(self, buffer, recv_stage): print( f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB' ) - ''' + """ else: - raise NotImplementedError(f'Could not send meta type {type(buffer)}') + raise NotImplementedError(f"Could not send meta type {type(buffer)}") # Useful for performance debugging. - ''' + """ if self.grid.data_parallel_id == 0: print(f'STAGE={self.stage_id} pipe-send-volume: {send_bytes/1024**2:0.2f}MB') - ''' + """ def _recv_tensor_meta(self, send_stage): """Receive metadata about upcoming p2p transfers and return allocated buffers. @@ -940,13 +1007,13 @@ def _recv_tensor_meta(self, send_stage): return buffers else: - raise NotImplementedError(f'Could not receive type {type(recv_type)}') + raise NotImplementedError(f"Could not receive type {type(recv_type)}") def _exec_send_activations(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('pipe_send_output').start() + self.timers("pipe_send_output").start() - outputs = self.pipe_buffers['outputs'][buffer_id] + outputs = self.pipe_buffers["outputs"][buffer_id] # NCCL does not like to send torch.BoolTensor types, so cast the mask to half(). 
# We could do char, but with half() we can eventually flatten with other fp16 @@ -966,8 +1033,9 @@ def _exec_send_activations(self, buffer_id): for idx, buffer in enumerate(outputs): p2p.send(buffer, self.next_stage) else: - raise NotImplementedError('Could not send output of type ' - f'{type(outputs)}') + raise NotImplementedError( + "Could not send output of type " f"{type(outputs)}" + ) # Restore the boolean tensor if self.has_attention_mask or self.has_bool_tensors: @@ -976,13 +1044,13 @@ def _exec_send_activations(self, buffer_id): outputs = tuple(outputs) if self.wall_clock_breakdown(): - self.timers('pipe_send_output').stop() + self.timers("pipe_send_output").stop() def _exec_send_grads(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('pipe_send_grad').start() + self.timers("pipe_send_grad").start() - inputs = self.pipe_buffers['inputs'][buffer_id] + inputs = self.pipe_buffers["inputs"][buffer_id] # Partition the gradient if self.is_grad_partitioned: @@ -998,8 +1066,9 @@ def _exec_send_grads(self, buffer_id): else: raise ValueError("expecting a tensor or a tuple of tensors") assert torch.is_tensor(first_input) - part = PartitionedTensor(tensor=first_input.grad, - group=self.grid.get_slice_parallel_group()) + part = PartitionedTensor( + tensor=first_input.grad, group=self.grid.get_slice_parallel_group() + ) inputs = (part.to_meta(), part.data(), *inputs_grad_tail) @@ -1032,14 +1101,14 @@ def _exec_send_grads(self, buffer_id): p2p.send(buffer.grad, self.prev_stage) # We can free up the input buffer now - self.pipe_buffers['inputs'][buffer_id] = None + self.pipe_buffers["inputs"][buffer_id] = None if self.wall_clock_breakdown(): - self.timers('pipe_send_grad').stop() + self.timers("pipe_send_grad").stop() def _exec_recv_activations(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('pipe_recv_input').start() + self.timers("pipe_recv_input").start() recvd = None @@ -1059,9 +1128,9 @@ def _exec_recv_activations(self, buffer_id): # XXX hardcode meta type if self.is_pipe_partitioned and idx == 0 and buffer.dtype != torch.long: if self.meta_buffer is None: - self.meta_buffer = torch.zeros(buffer.size(), - dtype=torch.long, - device=self.device) + self.meta_buffer = torch.zeros( + buffer.size(), dtype=torch.long, device=self.device + ) buffer = self.meta_buffer p2p.recv(buffer, self.prev_stage) @@ -1077,35 +1146,36 @@ def _exec_recv_activations(self, buffer_id): for buffer in recvd: buffer.requires_grad = buffer.is_floating_point() - self.pipe_buffers['inputs'][buffer_id] = recvd + self.pipe_buffers["inputs"][buffer_id] = recvd if self.wall_clock_breakdown(): - self.timers('pipe_recv_input').stop() + self.timers("pipe_recv_input").stop() def _exec_recv_grads(self, buffer_id): if self.wall_clock_breakdown(): - self.timers('pipe_recv_grad').start() + self.timers("pipe_recv_grad").start() - outputs = self.pipe_buffers['outputs'][buffer_id] + outputs = self.pipe_buffers["outputs"][buffer_id] # XXX these shapes are hardcoded for Megatron # Restore partitioned output if it was partitioned and we are sending full gradients if self.is_pipe_partitioned and not self.is_grad_partitioned: part_output = PartitionedTensor.from_meta( meta=outputs[0], local_part=outputs[1], - group=self.grid.get_slice_parallel_group()) + group=self.grid.get_slice_parallel_group(), + ) outputs[0].data = part_output.full() outputs = (outputs[0], *outputs[2:]) # save for backward - self.pipe_buffers['outputs'][buffer_id] = outputs + self.pipe_buffers["outputs"][buffer_id] = outputs # Allocate 
gradient if necessary if self.grad_layer is None: if isinstance(outputs, torch.Tensor): s = list(outputs.size()) - self.grad_layer = self._allocate_buffer(s, - dtype=outputs.dtype, - num_buffers=1)[0] + self.grad_layer = self._allocate_buffer( + s, dtype=outputs.dtype, num_buffers=1 + )[0] else: # XXX This is a HACK # When we exchange activations/gradients, the two pipe stages @@ -1123,16 +1193,21 @@ def _exec_recv_grads(self, buffer_id): # metadata tensor. if self.is_grad_partitioned: sizes_and_dtypes = [ - (list(t.size()), - t.dtype) for t in outputs[:2] - ] + [(list(t.size()), - t.dtype) for t in outputs[2:] if t.is_floating_point()] + (list(t.size()), t.dtype) for t in outputs[:2] + ] + [ + (list(t.size()), t.dtype) + for t in outputs[2:] + if t.is_floating_point() + ] else: - sizes_and_dtypes = [(list(t.size()), - t.dtype) for t in outputs - if t.is_floating_point()] - self.grad_layer = self._allocate_buffers(sizes_and_dtypes, - num_buffers=1)[0] + sizes_and_dtypes = [ + (list(t.size()), t.dtype) + for t in outputs + if t.is_floating_point() + ] + self.grad_layer = self._allocate_buffers( + sizes_and_dtypes, num_buffers=1 + )[0] if isinstance(self.grad_layer, torch.Tensor): p2p.recv(self.grad_layer, self.next_stage) @@ -1141,57 +1216,65 @@ def _exec_recv_grads(self, buffer_id): for idx, buffer in enumerate(self.grad_layer): # XXX GPT-2 hack if self.is_grad_partitioned and idx == 0 and buffer.dtype != torch.long: - buffer.data = torch.zeros(buffer.size(), - dtype=torch.long, - device=self.device) + buffer.data = torch.zeros( + buffer.size(), dtype=torch.long, device=self.device + ) p2p.recv(buffer, self.next_stage) if self.wall_clock_breakdown(): - self.timers('pipe_recv_grad').stop() + self.timers("pipe_recv_grad").stop() def _exec_optimizer_step(self, lr_kwargs=None): if self.wall_clock_breakdown(): - self.timers('step_microstep').start() - self.timers('step').start() - self.mem_status('BEFORE STEP', reset_max=True) + self.timers("step_microstep").start() + self.timers("step").start() + self.mem_status("BEFORE STEP", reset_max=True) self._force_grad_boundary = True self._take_model_step(lr_kwargs) self._force_grad_boundary = False - self.mem_status('AFTER STEP') + self.mem_status("AFTER STEP") if self.global_rank == 0 and self.monitor.enabled: - self.summary_events = [(f'Train/Samples/lr', - self.get_lr()[0], - self.global_samples)] - if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'): - self.summary_events.append((f'Train/Samples/loss_scale', - self.optimizer.cur_scale, - self.global_samples)) + self.summary_events = [ + (f"Train/Samples/lr", self.get_lr()[0], self.global_samples) + ] + if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"): + self.summary_events.append( + ( + f"Train/Samples/loss_scale", + self.optimizer.cur_scale, + self.global_samples, + ) + ) self.monitor.write_events(self.summary_events) if self.wall_clock_breakdown(): - self.timers('step_microstep').stop() - self.timers('step').stop() + self.timers("step_microstep").stop() + self.timers("step").stop() if self.global_steps % self.steps_per_print() == 0: - self.timers.log([ - 'batch_input', - 'forward_microstep', - 'backward_microstep', - 'backward_inner_microstep', - 'backward_allreduce_microstep', - 'backward_tied_allreduce_microstep', - 'step_microstep' - ]) + self.timers.log( + [ + "batch_input", + "forward_microstep", + "backward_microstep", + "backward_inner_microstep", + "backward_allreduce_microstep", + "backward_tied_allreduce_microstep", + "step_microstep", + ] + ) if 
self.global_steps % self.steps_per_print() == 0: - self.timers.log([ - 'forward', - 'backward', - 'backward_inner', - 'backward_allreduce', - 'step' - ]) + self.timers.log( + [ + "forward", + "backward", + "backward_inner", + "backward_allreduce", + "step", + ] + ) def _zero_grads(self, inputs): if isinstance(inputs, torch.Tensor): @@ -1203,7 +1286,7 @@ def _zero_grads(self, inputs): t.grad.data.zero_() def _allocate_zeros(self, shape, **kwargs): - """ Allocate a tensor of zeros on the engine's device. + """Allocate a tensor of zeros on the engine's device. Arguments: shape: the shape of the tensor to allocate @@ -1236,29 +1319,30 @@ def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers= buffer = [] for shape, dtype in shapes_and_dtypes: buffer.append( - self._allocate_zeros(shape, - dtype=dtype, - requires_grad=requires_grad)) + self._allocate_zeros( + shape, dtype=dtype, requires_grad=requires_grad + ) + ) buffers.append(buffer) return buffers def forward(self, *args, **kwargs): - """Disabled for pipeline parallel training. See ``train_batch()``. """ + """Disabled for pipeline parallel training. See ``train_batch()``.""" raise PipelineError("Only train_batch() is accessible in pipeline mode.") def backward(self, *args, **kwargs): - """Disabled for pipeline parallel training. See ``train_batch()``. """ + """Disabled for pipeline parallel training. See ``train_batch()``.""" raise PipelineError("Only train_batch() is accessible in pipeline mode.") def step(self, *args, **kwargs): - """Disabled for pipeline parallel training. See ``train_batch()``. """ + """Disabled for pipeline parallel training. See ``train_batch()``.""" raise PipelineError("Only train_batch() is accessible in pipeline mode.") def mem_status(self, msg, print_rank=-1, reset_max=False): return global mem_alloced, mem_cached if not self.global_steps == 0 or not self.global_steps == 9: - #return + # return pass if self.mpu.get_data_parallel_rank() != 0: return @@ -1297,10 +1381,10 @@ def mem_status(self, msg, print_rank=-1, reset_max=False): max_cached /= 1024**3 print( - f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS', + f"RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS", msg, - f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) ' - f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)' + f"current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) " + f"current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)", ) def module_state_dict(self): @@ -1314,11 +1398,13 @@ def module_state_dict(self): None """ assert isinstance(self.module, PipelineModule) - assert self._curr_ckpt_path is not None, \ - "PipelineEngine expects module_state_dict() to be called from save_checkpoint()" + assert ( + self._curr_ckpt_path is not None + ), "PipelineEngine expects module_state_dict() to be called from save_checkpoint()" - self.module.save_state_dict(self._curr_ckpt_path, - checkpoint_engine=self.checkpoint_engine) + self.module.save_state_dict( + self._curr_ckpt_path, checkpoint_engine=self.checkpoint_engine + ) return None def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): @@ -1332,14 +1418,18 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): state_dict (str, None): unused strict (bool, optional): Strict state loading. Defaults to True. 
""" - assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism" + assert ( + custom_load_fn is None + ), "custom_load_fn not supported w. pipeline parallelism" if (state_dict is not None) and (not isinstance(state_dict, str)): super().load_module_state_dict(state_dict, strict) return - self.module.load_state_dir(load_dir=self._curr_ckpt_path, - strict=strict, - checkpoint_engine=self.checkpoint_engine) + self.module.load_state_dir( + load_dir=self._curr_ckpt_path, + strict=strict, + checkpoint_engine=self.checkpoint_engine, + ) # A map of PipeInstruction types to methods. Each method will be executed with the # kwargs provided to the PipeInstruction from the scheduler. @@ -1367,7 +1457,7 @@ def _exec_schedule(self, pipe_schedule): for cmd in step_cmds: if type(cmd) not in self._INSTRUCTION_MAP: raise RuntimeError( - f'{self.__class__.__name__} does not understand instruction {repr(cmd)}' + f"{self.__class__.__name__} does not understand instruction {repr(cmd)}" ) # Equivalent to: self._exec_forward_pass(buffer_id=0) From 55bf9f846058f8bf0d135e872782e5f4b2ed09c6 Mon Sep 17 00:00:00 2001 From: rrutmann <97447451+rrutmann@users.noreply.github.com> Date: Tue, 20 Sep 2022 11:32:53 +0200 Subject: [PATCH 03/13] Correct op_builder import --- deepspeed/env_report.py | 2 +- deepspeed/ops/adagrad/cpu_adagrad.py | 2 +- deepspeed/ops/adam/cpu_adam.py | 2 +- deepspeed/ops/adam/fused_adam.py | 2 +- deepspeed/ops/aio/__init__.py | 2 +- deepspeed/ops/lamb/fused_lamb.py | 2 +- deepspeed/ops/quantizer/quantizer.py | 2 +- deepspeed/ops/transformer/inference/moe_inference.py | 2 +- deepspeed/ops/transformer/inference/transformer_inference.py | 2 +- deepspeed/ops/transformer/transformer.py | 2 +- deepspeed/runtime/bf16_optimizer.py | 2 +- deepspeed/runtime/engine.py | 2 +- deepspeed/runtime/zero/stage3.py | 2 +- deepspeed/runtime/zero/stage_1_and_2.py | 2 +- 14 files changed, 14 insertions(+), 14 deletions(-) diff --git a/deepspeed/env_report.py b/deepspeed/env_report.py index c5949a8c4d31..69c795e07256 100644 --- a/deepspeed/env_report.py +++ b/deepspeed/env_report.py @@ -2,7 +2,7 @@ import deepspeed import subprocess import argparse -from .ops.op_builder import ALL_OPS +from op_builder import ALL_OPS from .git_version_info import installed_ops, torch_info GREEN = '\033[92m' diff --git a/deepspeed/ops/adagrad/cpu_adagrad.py b/deepspeed/ops/adagrad/cpu_adagrad.py index 2527259b1382..6ed3b256c8c3 100755 --- a/deepspeed/ops/adagrad/cpu_adagrad.py +++ b/deepspeed/ops/adagrad/cpu_adagrad.py @@ -3,7 +3,7 @@ ''' import torch -from ..op_builder import CPUAdagradBuilder +from op_builder import CPUAdagradBuilder from deepspeed.utils.logging import should_log_le diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index 911e4924dfbc..ad95a6e995c0 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -4,7 +4,7 @@ import torch from cpuinfo import get_cpu_info -from ..op_builder import CPUAdamBuilder +from op_builder import CPUAdamBuilder from deepspeed.utils import logger from deepspeed.utils.logging import should_log_le diff --git a/deepspeed/ops/adam/fused_adam.py b/deepspeed/ops/adam/fused_adam.py index 5a1a1ddcaed3..315b4aeb4dca 100644 --- a/deepspeed/ops/adam/fused_adam.py +++ b/deepspeed/ops/adam/fused_adam.py @@ -9,7 +9,7 @@ from .multi_tensor_apply import MultiTensorApply multi_tensor_applier = MultiTensorApply(2048 * 32) -from ..op_builder import FusedAdamBuilder +from op_builder import FusedAdamBuilder class 
FusedAdam(torch.optim.Optimizer): diff --git a/deepspeed/ops/aio/__init__.py b/deepspeed/ops/aio/__init__.py index d25f815739aa..755930443679 100755 --- a/deepspeed/ops/aio/__init__.py +++ b/deepspeed/ops/aio/__init__.py @@ -3,4 +3,4 @@ Licensed under the MIT license. ''' -from ..op_builder import AsyncIOBuilder +from op_builder import AsyncIOBuilder diff --git a/deepspeed/ops/lamb/fused_lamb.py b/deepspeed/ops/lamb/fused_lamb.py index e9210cdda9bc..ba175de4eb4b 100644 --- a/deepspeed/ops/lamb/fused_lamb.py +++ b/deepspeed/ops/lamb/fused_lamb.py @@ -6,7 +6,7 @@ ''' import types import torch -from ..op_builder import FusedLambBuilder +from op_builder import FusedLambBuilder class FusedLamb(torch.optim.Optimizer): diff --git a/deepspeed/ops/quantizer/quantizer.py b/deepspeed/ops/quantizer/quantizer.py index 6b25d02d87e7..c5c21179c76b 100755 --- a/deepspeed/ops/quantizer/quantizer.py +++ b/deepspeed/ops/quantizer/quantizer.py @@ -3,7 +3,7 @@ ''' import torch -from ..op_builder import QuantizerBuilder +from op_builder import QuantizerBuilder # Cuda modules will be imported if needed quantizer_cuda_module = None diff --git a/deepspeed/ops/transformer/inference/moe_inference.py b/deepspeed/ops/transformer/inference/moe_inference.py index ca4b5b7a9702..373f95bd08ec 100644 --- a/deepspeed/ops/transformer/inference/moe_inference.py +++ b/deepspeed/ops/transformer/inference/moe_inference.py @@ -5,7 +5,7 @@ import math import torch from torch.autograd import Function -from ... import op_builder +import op_builder #from ...inference.engine import inference_cuda_module, specialized_mode # Cuda modules will be imported if needed inference_cuda_module = None diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index fa28a34f04a2..071e98a8df2f 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -5,7 +5,7 @@ import math import torch from torch.autograd import Function -from ... 
import op_builder +import op_builder import torch.nn as nn from deepspeed import comm as dist from deepspeed.utils.logging import log_dist diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py index 7963d11774e3..fc6cabb2c997 100755 --- a/deepspeed/ops/transformer/transformer.py +++ b/deepspeed/ops/transformer/transformer.py @@ -7,7 +7,7 @@ from torch import nn from torch.autograd import Function -from ..op_builder import TransformerBuilder, StochasticTransformerBuilder +from op_builder import TransformerBuilder, StochasticTransformerBuilder # Cuda modules will be imported if needed transformer_cuda_module = None diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 40b5b769bad1..f984afc8efb0 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -7,7 +7,7 @@ import os from deepspeed import comm as dist from deepspeed.runtime.constants import PIPE_REPLICATED -from deepspeed.ops.op_builder import UtilsBuilder +from op_builder import UtilsBuilder from deepspeed.runtime import ZeROOptimizer from packaging import version as pkg_version diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 56b9a91524a5..449e909cfa39 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -92,7 +92,7 @@ from .pipe.module import PipelineModule from .utils import ensure_directory_exists, get_ma_status -from ..ops.op_builder import UtilsBuilder +from op_builder import UtilsBuilder from ..ops.adam import FusedAdam from ..moe.sharded_moe import TopKGate, MOELayer from ..moe.layer import MoE diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 580b72ee119f..4c01c6c9b960 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -20,7 +20,7 @@ from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.ops.adam import DeepSpeedCPUAdam -from deepspeed.ops.op_builder import UtilsBuilder +from op_builder import UtilsBuilder from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 396fca35dc18..f38024c9e298 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -19,7 +19,7 @@ from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.ops.adam import DeepSpeedCPUAdam -from deepspeed.ops.op_builder import UtilsBuilder +from op_builder import UtilsBuilder from deepspeed.utils import logger from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version From f58fc1bf2d6778413396ce4b0221b58c2ebc22ef Mon Sep 17 00:00:00 2001 From: rrutmann <97447451+rrutmann@users.noreply.github.com> Date: Tue, 20 Sep 2022 17:36:54 +0200 Subject: [PATCH 04/13] Correct import from git_version_info --- op_builder/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index 4fc1d40eff53..228471380c64 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -458,7 +458,7 @@ def 
builder(self): extra_link_args=self.strip_empty_entries(self.extra_ldflags())) def load(self, verbose=True): - from ...git_version_info import installed_ops, torch_info + from deepspeed.git_version_info import installed_ops, torch_info if installed_ops[self.name]: # Ensure the op we're about to load was compiled with the same # torch/cuda versions we are currently using at runtime. From e794448b0854c14e55d097d5961e99a8ea082a62 Mon Sep 17 00:00:00 2001 From: mali Date: Mon, 24 Oct 2022 11:05:31 +0200 Subject: [PATCH 05/13] Test communication type --- tests/unit/runtime/test_bf16.py | 48 ++++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/tests/unit/runtime/test_bf16.py b/tests/unit/runtime/test_bf16.py index d75b92441117..fb4285ca8465 100644 --- a/tests/unit/runtime/test_bf16.py +++ b/tests/unit/runtime/test_bf16.py @@ -4,7 +4,7 @@ from deepspeed.ops.adam import FusedAdam from tests.unit.common import DistributedTest from deepspeed.ops.op_builder import CPUAdamBuilder -from tests.unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader +from tests.unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader, LinearStackPipe from tests.unit.util import bf16_required_version_check from deepspeed import comm as dist @@ -29,6 +29,7 @@ def test(self, zero_stage=2, use_cpu_offload=False): "lr": 0.00015 } }, + "train_micro_batch_size_per_gpu": 1, "scheduler": { "type": "OneCycle", "params": { @@ -58,8 +59,8 @@ def test(self, zero_stage=2, use_cpu_offload=False): hidden_dim = 10 model = SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, - model=model, - model_parameters=model.parameters()) + model=model, + model_parameters=model.parameters()) data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, @@ -353,3 +354,44 @@ def custom_reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, async_op=False) model.backward(loss) model.step() dist.reduce = orig_torch_reduce + + +class TestBF16Training(DistributedTest): + def set_up(self): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "bf16": { + "enabled": True + }, + "zero_optimization": { + "stage": 0 + }, + "communication_data_type": "fp32" + } + + input_dim = 1 + hidden_dim = 10 + output_dim = 10 + num_layers = 4 + num_stages = 2 + + model = LinearStackPipe( + input_dim=input_dim, + hidden_dim=hidden_dim, + output_dim=output_dim, + num_layers=num_layers, + num_stages=num_stages, + ) + optimizer = torch.optim.Adam(model.parameters()) + model, _, _, _ = deepspeed.initialize(config=config_dict, + model=model, + optimizer=optimizer) + data_loader = random_dataloader(model=model, + total_samples=2, + hidden_dim=hidden_dim, + device=model.device, + dtype="fp32") + + def test_communication_data_type(self): + self.set_up() From 1466e72aa91fb1700fcb8cd888fef4817ccdaaf2 Mon Sep 17 00:00:00 2001 From: mali Date: Tue, 25 Oct 2022 10:19:35 +0200 Subject: [PATCH 06/13] Test parameters' type & communication type --- tests/unit/runtime/test_bf16.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/unit/runtime/test_bf16.py b/tests/unit/runtime/test_bf16.py index fb4285ca8465..0a1a81bd7bd5 100644 --- a/tests/unit/runtime/test_bf16.py +++ b/tests/unit/runtime/test_bf16.py @@ -357,6 +357,8 @@ def custom_reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, async_op=False) class TestBF16Training(DistributedTest): + + #TODO: Use @pytest.fixture def set_up(self): 
config_dict = { "train_batch_size": 2, @@ -376,22 +378,29 @@ def set_up(self): num_layers = 4 num_stages = 2 - model = LinearStackPipe( + pipe_model = LinearStackPipe( input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, num_layers=num_layers, num_stages=num_stages, ) - optimizer = torch.optim.Adam(model.parameters()) - model, _, _, _ = deepspeed.initialize(config=config_dict, - model=model, - optimizer=optimizer) - data_loader = random_dataloader(model=model, - total_samples=2, - hidden_dim=hidden_dim, - device=model.device, - dtype="fp32") + optimizer = torch.optim.Adam(pipe_model.parameters()) + deepspeed_model, _, _, _ = deepspeed.initialize( + config=config_dict, + model=pipe_model, + optimizer=optimizer, + ) + + self.model = deepspeed_model + + def test_parameter_type(self): + self.set_up() + params = list(self.model.parameters()) + + for p in params: + assert (p.dtype == torch.bfloat16) def test_communication_data_type(self): self.set_up() + assert (self.model.communication_data_type == torch.float32) From 5ee97ef1c73337c3d22adf93f9529eed0516a14b Mon Sep 17 00:00:00 2001 From: mali Date: Tue, 25 Oct 2022 10:53:40 +0200 Subject: [PATCH 07/13] Test Zero stage 0 and 1 --- tests/unit/runtime/test_bf16.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/unit/runtime/test_bf16.py b/tests/unit/runtime/test_bf16.py index 0a1a81bd7bd5..9734f1639e96 100644 --- a/tests/unit/runtime/test_bf16.py +++ b/tests/unit/runtime/test_bf16.py @@ -358,8 +358,8 @@ def custom_reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, async_op=False) class TestBF16Training(DistributedTest): - #TODO: Use @pytest.fixture - def set_up(self): + # TODO: Use @pytest.fixture + def set_up(self, zero_stage: int): config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -367,7 +367,7 @@ def set_up(self): "enabled": True }, "zero_optimization": { - "stage": 0 + "stage": zero_stage, }, "communication_data_type": "fp32" } @@ -394,13 +394,21 @@ def set_up(self): self.model = deepspeed_model - def test_parameter_type(self): - self.set_up() + def _check_params(self): params = list(self.model.parameters()) for p in params: assert (p.dtype == torch.bfloat16) + def test_parameter_type(self): + self.set_up(zero_stage=0) + self._check_params() + self.set_up(zero_stage=1) + self._check_params() + def test_communication_data_type(self): - self.set_up() + self.set_up(zero_stage=0) + assert (self.model.communication_data_type == torch.float32) + + self.set_up(zero_stage=1) assert (self.model.communication_data_type == torch.float32) From 86d5b4267034096ff91026bd4c2015981f6e192a Mon Sep 17 00:00:00 2001 From: rrutmann <97447451+rrutmann@users.noreply.github.com> Date: Wed, 26 Oct 2022 11:20:37 +0200 Subject: [PATCH 08/13] Test _exec_reduce_tied_grads --- tests/unit/runtime/test_bf16.py | 90 ++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) diff --git a/tests/unit/runtime/test_bf16.py b/tests/unit/runtime/test_bf16.py index 0a1a81bd7bd5..47b3d4cb2ed0 100644 --- a/tests/unit/runtime/test_bf16.py +++ b/tests/unit/runtime/test_bf16.py @@ -1,12 +1,14 @@ +from types import MethodType import torch import deepspeed import pytest from deepspeed.ops.adam import FusedAdam from tests.unit.common import DistributedTest -from deepspeed.ops.op_builder import CPUAdamBuilder -from tests.unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader, LinearStackPipe +from op_builder import CPUAdamBuilder +from tests.unit.simple_model import 
SimpleModel, SimpleOptimizer, random_dataloader, LinearStackPipe, TiedLinearStackPipe from tests.unit.util import bf16_required_version_check from deepspeed import comm as dist +from deepspeed.runtime.pipe import schedule class TestAdamBF16ZeroOneCycleCompatibility(DistributedTest): @@ -394,6 +396,48 @@ def set_up(self): self.model = deepspeed_model + def set_up_tied_model(self): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "bf16": { + "enabled": True + }, + "zero_optimization": { + "stage": 1 + }, + "communication_data_type": "fp32" + } + + input_dim = 10 + hidden_dim = 10 + output_dim = 10 + num_layers = 4 + num_stages = 2 + + pipe_model = TiedLinearStackPipe( + input_dim=input_dim, + hidden_dim=hidden_dim, + output_dim=output_dim, + num_layers=num_layers, + num_stages=num_stages, + ) + + optimizer = torch.optim.Adam(pipe_model.parameters()) + deepspeed_model, _, _, _ = deepspeed.initialize( + config=config_dict, + model=pipe_model, + optimizer=optimizer, + ) + self.data_loader = random_dataloader(model=deepspeed_model, + total_samples=2, + hidden_dim=hidden_dim, + device=deepspeed_model.device, + dtype=torch.bfloat16) + + self.tied_model = deepspeed_model + self.tied_model.set_dataloader(self.data_loader) + def test_parameter_type(self): self.set_up() params = list(self.model.parameters()) @@ -404,3 +448,45 @@ def test_parameter_type(self): def test_communication_data_type(self): self.set_up() assert (self.model.communication_data_type == torch.float32) + + def test__exec_reduce_tied_grads(self): + self.set_up_tied_model() + for n, batch in enumerate(self.data_loader): + self.tied_model.module.train() + self.tied_model.total_loss = None + self.tied_model._compute_loss = True + + # Do the work + self.tied_model.timers("train_batch").start() + + sched = schedule.TrainSchedule( + micro_batches=self.tied_model.micro_batches, + stages=self.tied_model.num_stages, + stage_id=self.tied_model.stage_id, + ) + # Reserve and reset buffers. 
+ self.tied_model._reserve_pipe_buffers(sched.num_pipe_buffers()) + self.tied_model.fwd_outputs = [] + for step_cmds in sched: + # For each instruction in the step + for cmd in step_cmds: + if type(cmd) not in self.tied_model._INSTRUCTION_MAP: + raise RuntimeError( + f"{self.__class__.__name__} does not understand instruction {repr(cmd)}" + ) + + # Equivalent to: self._exec_forward_pass(buffer_id=0) + self.tied_model._exec_instr = MethodType(self.tied_model._INSTRUCTION_MAP[type(cmd)], self.tied_model) + if type(cmd) == schedule.ReduceTiedGrads: + # check the gradient data types before and after executing ReduceTiedGrads + # during the execution it is not possible to access the gradients + weight_group_list = self.tied_model.module.get_tied_weights_and_groups() + for weight, group in weight_group_list: + assert weight.grad.dtype == torch.bfloat16 + self.tied_model._exec_instr(**cmd.kwargs) + weight_group_list = self.tied_model.module.get_tied_weights_and_groups() + for weight, group in weight_group_list: + assert weight.grad.dtype == torch.bfloat16 + else: + self.tied_model._exec_instr(**cmd.kwargs) + break From c6ba9ae5ac5184626a778ee82f0992eeea7d1c78 Mon Sep 17 00:00:00 2001 From: rrutmann <97447451+rrutmann@users.noreply.github.com> Date: Wed, 26 Oct 2022 11:21:36 +0200 Subject: [PATCH 09/13] Add TiedLinearStackPipe --- tests/unit/simple_model.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py index f94ee288d4b2..391ecb525f3b 100755 --- a/tests/unit/simple_model.py +++ b/tests/unit/simple_model.py @@ -3,7 +3,7 @@ import argparse import torch -from deepspeed.pipe import PipelineModule, LayerSpec +from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec from deepspeed.moe.layer import MoE import deepspeed.comm as dist @@ -151,6 +151,36 @@ def __init__(self, super().__init__(layers=layers, loss_fn=torch.nn.CrossEntropyLoss(), **kwargs) +class TiedLinearStackPipe(PipelineModule): + def __init__(self, + input_dim=128, + hidden_dim=128, + output_dim=128, + num_layers=4, + **kwargs): + self.input_dim = input_dim + self.output_dim = output_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + + layers = [] + + layers.append(LayerSpec(torch.nn.Linear, self.input_dim, self.hidden_dim)) + layers.append( + TiedLayerSpec("tied", torch.nn.Linear, self.hidden_dim, self.hidden_dim, tied_weight_attr='weight')) + for x in range(self.num_layers): + layers.append( + LayerSpec(torch.nn.Linear, + self.hidden_dim, + self.hidden_dim, + bias=False)) + layers.append(lambda x: x) + layers.append(TiedLayerSpec("tied", torch.nn.Linear, self.hidden_dim, self.hidden_dim, tied_weight_attr='weight')) + layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.output_dim)) + + super().__init__(layers=layers, loss_fn=torch.nn.CrossEntropyLoss(), **kwargs) + + class SimpleOptimizer(torch.optim.Optimizer): def __init__(self, params, lr=0.11072018): defaults = dict(lr=lr) From 12c4315d529114fb5ec4bc253930b9cc246cfd71 Mon Sep 17 00:00:00 2001 From: rrutmann <97447451+rrutmann@users.noreply.github.com> Date: Tue, 8 Nov 2022 14:42:52 +0100 Subject: [PATCH 10/13] Initial draft of test__exec_backward_pass --- tests/unit/runtime/test_bf16.py | 50 ++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/tests/unit/runtime/test_bf16.py b/tests/unit/runtime/test_bf16.py index f3303bf15778..85b03d15244d 100644 --- a/tests/unit/runtime/test_bf16.py +++ 
b/tests/unit/runtime/test_bf16.py @@ -396,7 +396,7 @@ def set_up(self, zero_stage: int): self.model = deepspeed_model - def set_up_tied_model(self): + def set_up_tied_model(self, zero_stage: int): config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -404,7 +404,7 @@ def set_up_tied_model(self): "enabled": True }, "zero_optimization": { - "stage": 1 + "stage": zero_stage }, "communication_data_type": "fp32" } @@ -438,7 +438,6 @@ def set_up_tied_model(self): self.tied_model = deepspeed_model self.tied_model.set_dataloader(self.data_loader) - def _check_params(self): params = list(self.model.parameters()) @@ -459,7 +458,8 @@ def test_communication_data_type(self): assert (self.model.communication_data_type == torch.float32) def test__exec_reduce_tied_grads(self): - self.set_up_tied_model() + # self.set_up_tied_model(1) + self.set_up_tied_model(1) for n, batch in enumerate(self.data_loader): self.tied_model.module.train() self.tied_model.total_loss = None @@ -499,3 +499,45 @@ def test__exec_reduce_tied_grads(self): else: self.tied_model._exec_instr(**cmd.kwargs) break + + def test__exec_backward_pass(self): + self.set_up_tied_model(0) + for n, batch in enumerate(self.data_loader): + self.tied_model.module.train() + self.tied_model.total_loss = None + self.tied_model._compute_loss = True + + # Do the work + self.tied_model.timers("train_batch").start() + + sched = schedule.TrainSchedule( + micro_batches=self.tied_model.micro_batches, + stages=self.tied_model.num_stages, + stage_id=self.tied_model.stage_id, + ) + # Reserve and reset buffers. + self.tied_model._reserve_pipe_buffers(sched.num_pipe_buffers()) + self.tied_model.fwd_outputs = [] + for step_cmds in sched: + # For each instruction in the step + for cmd in step_cmds: + if type(cmd) not in self.tied_model._INSTRUCTION_MAP: + raise RuntimeError( + f"{self.__class__.__name__} does not understand instruction {repr(cmd)}" + ) + + # Equivalent to: self._exec_forward_pass(buffer_id=0) + self.tied_model._exec_instr = MethodType(self.tied_model._INSTRUCTION_MAP[type(cmd)], + self.tied_model) + if type(cmd) == schedule.BackwardPass: + # check the gradient data types before and after executing ReduceTiedGrads + # during the execution it is not possible to access the gradients + self.tied_model._exec_instr(**cmd.kwargs) + if not self.tied_model.is_last_stage(): + for group in self.tied_model.optimizer.bf16_groups: + for param in group: + assert param.grad is None + print() + else: + self.tied_model._exec_instr(**cmd.kwargs) + break From d5d8ea16ca20ed6d87cc8251bc29eb5aea495927 Mon Sep 17 00:00:00 2001 From: mali Date: Tue, 8 Nov 2022 15:15:05 +0100 Subject: [PATCH 11/13] Adapt assertion --- tests/unit/runtime/test_bf16.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/unit/runtime/test_bf16.py b/tests/unit/runtime/test_bf16.py index 85b03d15244d..1f92e1005d06 100644 --- a/tests/unit/runtime/test_bf16.py +++ b/tests/unit/runtime/test_bf16.py @@ -396,16 +396,17 @@ def set_up(self, zero_stage: int): self.model = deepspeed_model - def set_up_tied_model(self, zero_stage: int): + def set_up_tied_model(self, zero_stage: int, enable_bf16: bool): config_dict = { "train_batch_size": 2, "steps_per_print": 1, "bf16": { - "enabled": True + "enabled": enable_bf16, }, "zero_optimization": { "stage": zero_stage }, + # TODO: check where defined "communication_data_type": "fp32" } @@ -458,8 +459,7 @@ def test_communication_data_type(self): assert (self.model.communication_data_type == 
torch.float32)
 
     def test__exec_reduce_tied_grads(self):
-        # self.set_up_tied_model(1)
-        self.set_up_tied_model(1)
+        self.set_up_tied_model(zero_stage=1, enable_bf16=True)
         for n, batch in enumerate(self.data_loader):
             self.tied_model.module.train()
             self.tied_model.total_loss = None
@@ -500,8 +500,7 @@ def test__exec_reduce_tied_grads(self):
                     self.tied_model._exec_instr(**cmd.kwargs)
             break
 
-    def test__exec_backward_pass(self):
-        self.set_up_tied_model(0)
+    def _execute_instructions(self):
         for n, batch in enumerate(self.data_loader):
             self.tied_model.module.train()
             self.tied_model.total_loss = None
@@ -532,12 +531,18 @@ def test__exec_backward_pass(self):
                 if type(cmd) == schedule.BackwardPass:
                     # check the gradient data types before and after executing ReduceTiedGrads
                     # during the execution it is not possible to access the gradients
-                    self.tied_model._exec_instr(**cmd.kwargs)
-                    if not self.tied_model.is_last_stage():
-                        for group in self.tied_model.optimizer.bf16_groups:
-                            for param in group:
-                                assert param.grad is None
-                        print()
+                    try:
+                        self.tied_model._exec_instr(**cmd.kwargs)
+                    except Exception as e:
+                        assert False, f"'exec_backward_pass' raised an exception {e}"
+                    return
                 else:
                     self.tied_model._exec_instr(**cmd.kwargs)
-            break
+
+    def test__exec_backward_pass(self):
+        self.set_up_tied_model(zero_stage=1, enable_bf16=True)
+        self._execute_instructions()
+
+        self.set_up_tied_model(zero_stage=0, enable_bf16=True)
+        self._execute_instructions()
+

From 2d343ff2ff00526076952772343d509bfa3acf85 Mon Sep 17 00:00:00 2001
From: mali
Date: Tue, 29 Nov 2022 10:37:41 +0100
Subject: [PATCH 12/13] Add comments

---
 deepspeed/runtime/pipe/engine.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index a98c2fb1ea70..d58eedb4d987 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -264,6 +264,7 @@ def _exec_reduce_tied_grads(self):
             grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad
             is_bfloat = False
+            # Make sure that we perform the computations in float32 when using bf16
             if grad.dtype == torch.bfloat16:
                 is_bfloat = True
                 grad = grad.to(torch.float32)
@@ -815,6 +816,7 @@ def _exec_backward_pass(self, buffer_id):
             part_grad = None
             # print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')

+        # When using ZeRO stage 1, we employ the DeepSpeedZeroOptimizer, which doesn't provide clear_lp_grads()
         if not self.zero_optimization() and self.bfloat16_enabled() and not self.is_last_stage():
             # manually call because we don't call optimizer.backward()
             self.optimizer.clear_lp_grads()
@@ -828,6 +830,8 @@ def _exec_backward_pass(self, buffer_id):
             torch.autograd.backward(tensors=(outputs,), grad_tensors=(grad_tensors,))

         if (
+            # When ZeRO stage 1 is used, we employ the DeepSpeedZeroOptimizer, which doesn't provide
+            # update_hp_grads()
             not self.zero_optimization()
             and self.bfloat16_enabled()
             and not self.is_last_stage()
@@ -851,12 +855,14 @@ def _exec_backward_pass(self, buffer_id):
     def _exec_load_micro_batch(self, buffer_id):
         if self.wall_clock_breakdown():
             self.timers("batch_input").start()
-
+        # OPENGPT-X: (input, target), where input is a tuple
+        # (tokens, position_ids, attention_mask), (labels, loss_mask)
         batch = self._next_batch()

         if self.is_first_stage():
             loaded = None
             if torch.is_tensor(batch[0]):
+                # OPENGPT-X: why cloning?
                 loaded = batch[0].clone().to(self.device).detach()
                 loaded.requires_grad = loaded.is_floating_point()
             else:

From 256ca0bd099184aaa252fd155e29e9a05841d9a1 Mon Sep 17 00:00:00 2001
From: mali
Date: Tue, 29 Nov 2022 10:38:32 +0100
Subject: [PATCH 13/13] Add comment

---
 deepspeed/runtime/engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 449e909cfa39..0e38949f9c55 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -818,6 +818,7 @@ def communication_data_type(self):
         elif self.fp16_enabled():
             return torch.float16
         elif self.bfloat16_enabled():
+            # Communicate in torch.float32 since bf16 communication is not available for NCCL versions < 2.10.3
             return torch.float32

         return torch.float32
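
Note: taken together, the patches above keep model parameters and gradients in bfloat16 while gradient communication is performed in float32. Below is a minimal, hypothetical usage sketch of that configuration, mirroring the new tests in tests/unit/runtime/test_bf16.py. The toy Sequential model, the Adam optimizer and the assumption of a launch through the deepspeed launcher are illustrative placeholders only; the config keys and the two assertions are the ones exercised by the tests and by the communication_data_type property patched above.

import torch
import deepspeed

# Hypothetical stand-in model; the tests use SimpleModel / LinearStackPipe instead.
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
optimizer = torch.optim.Adam(model.parameters())

ds_config = {
    "train_batch_size": 2,
    "train_micro_batch_size_per_gpu": 1,
    "steps_per_print": 1,
    "bf16": {"enabled": True},           # parameters and gradients held in bfloat16
    "zero_optimization": {"stage": 1},   # stages 0 and 1 are what the tests cover
    "communication_data_type": "fp32",   # gradients are allreduced in float32
}

engine, optimizer, _, _ = deepspeed.initialize(
    config=ds_config,
    model=model,
    optimizer=optimizer,
)

# bf16 parameters, fp32 on the wire, matching the assertions in TestBF16Training.
assert all(p.dtype == torch.bfloat16 for p in engine.module.parameters())
assert engine.communication_data_type == torch.float32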