From 24fe7002eb3f80bbaa52591a312a4e9fb74fa0ec Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 21 Jan 2022 06:02:07 +0500 Subject: [PATCH 01/31] unit test, remove exception, add notes --- deepspeed/runtime/engine.py | 7 ----- deepspeed/runtime/zero/stage_1_and_2.py | 10 +++++++ tests/unit/test_checkpointing.py | 37 +++++++++++++++++++++---- 3 files changed, 41 insertions(+), 13 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6271ee88ced0..1b79f4b0f6e4 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2594,13 +2594,6 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): if zero_sd_list is None: return False - if load_optimizer_states and self.dp_world_size != self.loaded_checkpoint_dp_world_size: - raise ZeRORuntimeException("The checkpoint being loaded used a DP " \ - f"world size of {self.loaded_checkpoint_dp_world_size} but the " \ - f"current world size is {self.dp_world_size}. Automatic adjustment " \ - "of ZeRO's optimizer state partitioning with a new world size is not " \ - "currently supported.") - self.optimizer.load_state_dict( state_dict_list=zero_sd_list, load_optimizer_states=load_optimizer_states, diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 9d9f051e728c..5004c3a4693e 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2170,6 +2170,16 @@ def load_state_dict(self, ckpt_is_rigid = isinstance(current_rank_sd[BASE_OPTIMIZER_STATE], dict) + # padding is always at the last rank/partition + # if DP=1024 and param-group elems=16 -> padding will be 1024-16 across all but one rank + # scenario-1 (shrink): saving w. 4 gpus -> loading w. 2 gpus + # scenario-2 (expand): saving w. 2 gpus -> loading w. 
4 gpus + # if load_optimizer_states: + # if new_dp_size: + # self.strip_padding() + # self.add_padding_w_new_dp_size() + # self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) + if load_optimizer_states: if ckpt_is_rigid: # loading rigid ckpt into either rigid or elastic exec diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index 28c2099e60b1..024e955e63fc 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -1187,6 +1187,10 @@ def _go(): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + if load_optim: + torch.save(model.optimizer.optimizer.state_dict(), + os.path.join(tmpdir, + 'opt-state-dict')) model.save_checkpoint(tmpdir) ds_config["zero_optimization"]["elastic_checkpoint"] = elastic_load @@ -1194,6 +1198,12 @@ def _go(): model=models[1], model_parameters=models[1].parameters()) model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) + + if load_optim: + saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) + curr_sd = model.optimizer.optimizer.state_dict() + assert curr_sd['param_groups'] == saved_sd['param_groups'] + data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, @@ -1249,6 +1259,11 @@ def _go2(models): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + + if load_optim: + torch.save(model.optimizer.optimizer.state_dict(), + os.path.join(tmpdir, + 'opt-state-dict')) model.save_checkpoint(tmpdir) _go2(models) @@ -1257,12 +1272,22 @@ def _go2(models): def _go1(models): ds_config["zero_optimization"]["elastic_checkpoint"] = elastic_load model, _, _, _ = deepspeed.initialize(config=ds_config, - model=models[1], - model_parameters=models[1].parameters()) + model=models[1], + model_parameters=models[1].parameters()) + model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) + if load_optim: - with pytest.raises(deepspeed.runtime.zero.utils.ZeRORuntimeException): - model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) - else: - model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) + saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) + curr_sd = model.optimizer.optimizer.state_dict() + assert curr_sd['param_groups'] == saved_sd['param_groups'] + + data_loader = random_dataloader(model=model, + total_samples=8, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() _go1(models) From aafa4e572c488bf56f79d1b54b880b3f60f9ac0a Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 29 Jan 2022 00:26:28 +0000 Subject: [PATCH 02/31] Move param_shapes to model files --- deepspeed/checkpoint/__init__.py | 0 deepspeed/checkpoint/constants.py | 22 +++++++++++ deepspeed/runtime/constants.py | 6 --- deepspeed/runtime/engine.py | 8 ++-- deepspeed/runtime/fp16/fused_optimizer.py | 2 +- deepspeed/runtime/fp16/unfused_optimizer.py | 2 +- deepspeed/runtime/zero/constants.py | 3 -- deepspeed/runtime/zero/stage3.py | 2 +- deepspeed/runtime/zero/stage_1_and_2.py | 4 +- deepspeed/utils/zero_to_fp32.py | 43 +++++++++++++-------- 10 files changed, 58 insertions(+), 34 deletions(-) create mode 100644 deepspeed/checkpoint/__init__.py create mode 100644 deepspeed/checkpoint/constants.py diff --git a/deepspeed/checkpoint/__init__.py b/deepspeed/checkpoint/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py 
new file mode 100644 index 000000000000..2bd7ac087e1a --- /dev/null +++ b/deepspeed/checkpoint/constants.py @@ -0,0 +1,22 @@ +''' + Various symbolic constants used for model checkpointing +''' + +######################################### +# Optimizer checkpoint keys +######################################### +OPTIMIZER_STATE_DICT = "optimizer_state_dict" +FP32_GROUPS = "fp32_groups" +FP32_FLAT_GROUPS = 'fp32_flat_groups' + +BASE_OPTIMIZER_STATE = 'base_optimizer_state' +SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups" + +PARTITION_COUNT = 'partition_count' +ZERO_STAGE = 'zero_stage' + +######################################### +# Module checkpoint keys +######################################### +PARAM_SHAPES = 'param_shapes' +BUFFER_NAMES = 'buffer_names' diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 84a8325bd8c6..2d16f39433c3 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -446,9 +446,3 @@ class ValidationMode: ''' DATALOADER_DROP_LAST = "dataloader_drop_last" DATALOADER_DROP_LAST_DEFAULT = False - -######################################### -# Optimizer checkpoint keys -######################################### -OPTIMIZER_STATE_DICT = "optimizer_state_dict" -FP32_GROUPS = "fp32_groups" diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 620adfc94d35..5c1f9b9e16fb 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -38,10 +38,10 @@ from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ - PLD_THETA, PLD_GAMMA, OPTIMIZER_STATE_DICT + PLD_THETA, PLD_GAMMA from deepspeed.runtime.zero.constants import \ - ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS, \ - SINGLE_PARTITION_OF_FP32_GROUPS + ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS +from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT from deepspeed.runtime.sparse_tensor import SparseTensor import deepspeed.runtime.lr_schedules as lr_schedules @@ -2921,6 +2921,7 @@ def _save_checkpoint(self, save_dir, tag, client_state={}): buffer_names=self._get_buffer_names(), optimizer=self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None, + param_shapes=self._get_zero_param_shapes(), lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, sparse_tensor_module_names=self.sparse_tensor_module_names, @@ -3014,7 +3015,6 @@ def _copy_recovery_script(self, save_path): def _save_zero_checkpoint(self, save_path, tag): zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag) zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), - param_shapes=self._get_zero_param_shapes(), ds_config=self.config, ds_version=version) torch.save(zero_sd, zero_checkpoint_name) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 6158faaee024..86ffc5ab92c0 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -11,7 +11,7 @@ from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import groups, logger, log_dist -from deepspeed.runtime.constants import OPTIMIZER_STATE_DICT +from deepspeed.checkpoint.constants import 
OPTIMIZER_STATE_DICT import torch.distributed as dist diff --git a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py index 1837fb10b034..56b880feb0ad 100755 --- a/deepspeed/runtime/fp16/unfused_optimizer.py +++ b/deepspeed/runtime/fp16/unfused_optimizer.py @@ -12,7 +12,7 @@ from deepspeed.runtime.utils import get_global_norm, CheckOverflow, get_weight_norm from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import logger -from deepspeed.runtime.constants import OPTIMIZER_STATE_DICT +from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT class FP16_UnfusedOptimizer(object): diff --git a/deepspeed/runtime/zero/constants.py b/deepspeed/runtime/zero/constants.py index 7f7c10e9d4af..af5c5f195398 100755 --- a/deepspeed/runtime/zero/constants.py +++ b/deepspeed/runtime/zero/constants.py @@ -48,9 +48,6 @@ ZERO_OPTIMIZATION_STAGE_2 = 'stage_2' ZERO_OPTIMIZATION_STAGE_3 = 'stage_3' -BASE_OPTIMIZER_STATE = 'base_optimizer_state' -SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups" - ZERO_OPTIMIZATION_STAGE_DEFAULT = ZERO_OPTIMIZATION_DISABLED ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS = 'allgather_partitions' diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index ad945b745989..097e64d4db30 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -33,7 +33,7 @@ from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper -from deepspeed.runtime.constants import OPTIMIZER_STATE_DICT +from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS # Toggle this to true to enable correctness test # with gradient partitioning and without diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 89f1b34f1cb6..b6c1939ad298 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -11,14 +11,14 @@ from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, get_global_norm, see_memory_usage, is_model_parallel_parameter from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS -from deepspeed.runtime.zero.offload_constants import OFFLOAD_CPU_DEVICE, OFFLOAD_OPTIMIZER, OFFLOAD_OPTIMIZER_DEVICE +from deepspeed.runtime.zero.offload_constants import OFFLOAD_CPU_DEVICE, OFFLOAD_OPTIMIZER from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.ops.op_builder import UtilsBuilder from deepspeed.utils import logger from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version -from .constants import SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE +from deepspeed.checkpoint.constants import SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE # Toggle this to true to enable correctness test # with gradient partitioning and without diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 37787a7962af..9850addfba77 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -18,6 +18,14 @@ # DeepSpeed data structures it has to be available in the current python environment. 
import deepspeed from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (OPTIMIZER_STATE_DICT, + PARAM_SHAPES, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) debug = 0 @@ -55,9 +63,9 @@ def get_optim_files(checkpoint_dir): def parse_model_state(file): state_dict = torch.load(file, map_location=device) - if "buffer_names" not in state_dict: + if BUFFER_NAMES not in state_dict: raise ValueError(f"{file} is not a model state checkpoint") - buffer_names = state_dict["buffer_names"] + buffer_names = state_dict[BUFFER_NAMES] if debug: print("Found buffers:", buffer_names) @@ -67,7 +75,9 @@ def parse_model_state(file): for k, v in state_dict["module"].items() if k in buffer_names } - return buffers + param_shapes = state_dict[PARAM_SHAPES] + + return buffers, param_shapes def parse_optim_states(files, ds_checkpoint_dir): @@ -77,11 +87,11 @@ def parse_optim_states(files, ds_checkpoint_dir): for f in files: state_dicts.append(torch.load(f, map_location=device)) - if not "zero_stage" in state_dicts[0]['optimizer_state_dict']: + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: raise ValueError(f"{files[0]} is not a zero checkpoint") - zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"] - world_size = state_dicts[0]['optimizer_state_dict']["partition_count"] - param_shapes = state_dicts[0]["param_shapes"] + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert # parameters can be different from data parallelism for non-expert parameters. So we can just # use the max of the partition_count to get the dp world_size. 
@@ -97,15 +107,15 @@ def parse_optim_states(files, ds_checkpoint_dir): # the groups are named differently in each stage if zero_stage == 2: - fp32_groups_key = "single_partition_of_fp32_groups" + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS elif zero_stage == 3: - fp32_groups_key = "fp32_flat_groups" + fp32_groups_key = FP32_FLAT_GROUPS else: raise ValueError(f"unknown zero stage {zero_stage}") if zero_stage == 2: fp32_flat_groups = [ - state_dicts[i]['optimizer_state_dict'][fp32_groups_key] + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts)) ] elif zero_stage == 3: @@ -116,11 +126,11 @@ def parse_optim_states(files, ds_checkpoint_dir): # will require matching the sub-lists of param_shapes for each param group flattened tensor fp32_flat_groups = [ - torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key], + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) ] - return zero_stage, world_size, param_shapes, fp32_flat_groups + return zero_stage, world_size, fp32_flat_groups def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): @@ -134,12 +144,12 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") optim_files = get_optim_files(ds_checkpoint_dir) - zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) print( f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) - buffers = parse_model_state(model_file) + buffers, param_shapes = parse_model_state(model_file) if zero_stage == 2: return _get_fp32_state_dict_from_zero2_checkpoint(world_size, @@ -165,7 +175,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, if debug: for i in range(world_size): for j in range(len(fp32_flat_groups[0])): - print(f"fp32_flat_groups[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") # XXX: memory usage doubles here (zero2) num_param_groups = len(fp32_flat_groups[0]) @@ -269,7 +280,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, if debug: for i in range(world_size): - print(f"fp32_flat_groups[{i}].shape={fp32_flat_groups[i].shape}") + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") wanted_params = len(param_shapes) wanted_numel = sum(shape.numel() for shape in param_shapes.values()) From 162c19b35ea9aae0f9466f32aeae60f6df49a5fb Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 29 Jan 2022 00:48:18 +0000 Subject: [PATCH 03/31] Remove hard-coded constants --- deepspeed/runtime/zero/stage3.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 097e64d4db30..801b6f612976 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -33,7 +33,7 @@ from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper -from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS +from 
deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE # Toggle this to true to enable correctness test # with gradient partitioning and without @@ -2918,15 +2918,15 @@ def _clear_fp32_optimizer_param_groups(self): def _rigid_state_dict(self): state_dict = {} - state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS + state_dict[ZERO_STAGE] = ZERO_OPTIMIZATION_WEIGHTS state_dict['loss_scaler'] = self.loss_scaler state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow - state_dict['partition_count'] = self.partition_count + state_dict[PARTITION_COUNT] = self.partition_count self._set_fp32_optimizer_param_groups() state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict() - state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat + state_dict[FP32_FLAT_GROUPS] = self.fp32_partitioned_groups_flat self._clear_fp32_optimizer_param_groups() return state_dict @@ -3042,7 +3042,7 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): self._clear_fp32_optimizer_param_groups() # restore fp32 partitions - for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']): + for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict[FP32_FLAT_GROUPS]): curr_param.data.copy_(saved_param.data) # restore fp16 partitions from fp32 From 680e62077b9d554ac9c09e79aa6ec0fdc885a989 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 29 Jan 2022 21:03:27 +0000 Subject: [PATCH 04/31] Conditioned to zero optimizer --- deepspeed/runtime/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 5c1f9b9e16fb..1f655698c8b9 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2921,7 +2921,8 @@ def _save_checkpoint(self, save_dir, tag, client_state={}): buffer_names=self._get_buffer_names(), optimizer=self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None, - param_shapes=self._get_zero_param_shapes(), + param_shapes=self._get_zero_param_shapes() + if self.optimizer and self.zero_optimization() else None, lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, sparse_tensor_module_names=self.sparse_tensor_module_names, @@ -2935,7 +2936,6 @@ def _save_checkpoint(self, save_dir, tag, client_state={}): state.update(client_state) log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1]) - #logger.info('Saving model checkpoint: {}'.format(save_path)) torch.save(state, save_path) self._curr_save_path = None From f1b5d16b4a65571ebcb6d915bf055680976591ea Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 29 Jan 2022 21:08:52 +0000 Subject: [PATCH 05/31] Add zero checkpoint merging --- deepspeed/checkpoint/__init__.py | 5 + deepspeed/checkpoint/constants.py | 28 ++ deepspeed/checkpoint/deepspeed_checkpoint.py | 254 +++++++++++++++++++ deepspeed/checkpoint/reshape_meg_2d.py | 226 +++++++++++++++++ deepspeed/checkpoint/reshape_utils.py | 95 +++++++ deepspeed/checkpoint/utils.py | 30 +++ deepspeed/checkpoint/zero_checkpoint.py | 45 ++++ 7 files changed, 683 insertions(+) create mode 100644 deepspeed/checkpoint/deepspeed_checkpoint.py create mode 100644 deepspeed/checkpoint/reshape_meg_2d.py create mode 100644 deepspeed/checkpoint/reshape_utils.py create mode 100644 deepspeed/checkpoint/utils.py create mode 100644 deepspeed/checkpoint/zero_checkpoint.py 
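Note on the reshaping approach added in this patch: merging the saved zero_pp_rank_* optimizer shards for a run with fewer data-parallel ranks reduces to evenly partitioning the shard files across the new ranks and merging each group's states. A minimal sketch of that mapping, assuming the partition_data helper introduced below in deepspeed/checkpoint/reshape_utils.py; the shard file names here are purely illustrative, not the exact checkpoint naming scheme:

# Sketch: group zero_pp_rank_* files saved with a larger DP degree for a smaller one.
# partition_data mirrors the helper added in reshape_utils.py in this patch.
def partition_data(data_list, num_partitions):
    # even split only; uneven DP resharding is not handled here
    assert len(data_list) % num_partitions == 0
    partition_size = len(data_list) // num_partitions
    return [
        data_list[i:i + partition_size]
        for i in range(0, len(data_list), partition_size)
    ]

# Saved with DP=4, loading with DP=2: each new rank merges two of the old shards.
old_dp_files = [f'zero_pp_rank_{i}_mp_rank_00_optim_states.pt' for i in range(4)]
for new_rank, shard_files in enumerate(partition_data(old_dp_files, 2)):
    print(new_rank, shard_files)
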
diff --git a/deepspeed/checkpoint/__init__.py b/deepspeed/checkpoint/__init__.py index e69de29bb2d1..96090e75d379 100644 --- a/deepspeed/checkpoint/__init__.py +++ b/deepspeed/checkpoint/__init__.py @@ -0,0 +1,5 @@ +from .reshape_meg_2d import reshape_meg_2d_parallel +from .deepspeed_checkpoint import DeepSpeedCheckpoint +from .utils import (get_layer_ckpt_name_for_rank, + get_model_ckpt_name_for_rank, + get_zero_ckpt_name_for_rank) diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 2bd7ac087e1a..9bfc263439bf 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -20,3 +20,31 @@ ######################################### PARAM_SHAPES = 'param_shapes' BUFFER_NAMES = 'buffer_names' + +######################################### +# Checkpoint naming constants +######################################### +MODEL_FILE_PREFIX = 'mp_rank_' +ZERO_FILE_PREFIX = 'zero_pp_rank_' +OPTIM_FILE_SUFFIX = '_optim_states.pt' +MODEL_FILE_SUFFIX = '_model_states.pt' + +######################################### +# Checkpoint utility keys +######################################### +EMBEDDING_LAYER_INDEX = 0 +FINAL_LAYER_NORM_INDEX = -1 +ARGS_KEY = 'args' +CHECKPOINT_INFO_KEY = 'checkpoint_info' +ITERATION_KEY = 'iteration' +SEQUENTIAL_LAYERS = [ + 'input_layernorm.weight', + 'input_layernorm.bias', + 'self_attention.dense.bias', + 'post_attention_layernorm.weight', + 'post_attention_layernorm.bias', + 'mlp.dense_4h_to_h.bias', + 'position_embeddings.weight' +] + +LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1} diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py new file mode 100644 index 000000000000..9ec594c4c4f4 --- /dev/null +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -0,0 +1,254 @@ +import os +from typing import Dict +import torch +from .reshape_utils import (basic_folder_validation, + partition_data, + get_files, + get_files_with_prefix, + ZERO_FILE_PREFIX, + LAYER_FILE_PREFIX, + MP_RANK_FILE_PREFIX) + +from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map +from .zero_checkpoint import ZeROCheckpoint +from .constants import * + + +class DeepSpeedCheckpoint(object): + def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None): + self.dir = dir + self._validate_folder(dir) + + self.file_list = get_files(dir) + self.zero_files = get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX) + self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX) + self.mp_rank_files = get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX) + + self.layer_keys = self._get_layer_keys() + self.layer_count = len(self.layer_keys) + self.original_tp_degree = len( + get_files_with_prefix(self.layer_files, + f'{LAYER_FILE_PREFIX}01')) + self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree + self.original_dp_degree = max( + 1, + len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree)) + + self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree + self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree + self.dp_degree = self.original_dp_degree if dp_degree is None else dp_degree + + self.original_world_size = self.original_tp_degree * self.original_pp_degree * self.original_dp_degree + self.world_size = self.tp_degree * self.pp_degree * self.dp_degree + + self.old_2d_map = meg_2d_parallel_map(self.original_pp_degree, + 
self.original_tp_degree) + self.old_2d_map.simple_init() + self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.original_pp_degree, + old_tp_degree=self.original_tp_degree, + new_pp_degree=self.pp_degree, + new_tp_degree=self.tp_degree) + self.zero_checkpoint = ZeROCheckpoint(dir) + self.global_state = {} + + self._sanity_check() + self.pp_to_transformer_map = self._build_pp_transformer_map() + self.transformer_file_map = self._build_transformer_file_map() + self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX) + self.tp_to_final_norm_map = self._build_tp_other_layer_map( + FINAL_LAYER_NORM_INDEX) + self._build_global_state() + + def show_2d_mapping(self): + print(f'reshaped 2d map ---- begin') + + for i in range(self.pp_degree): + for j in range(self.tp_degree): + file_list = self.get_2d_parallel_files(pp_index=i, tp_index=j) + print(f'[{i}, {j}] = {file_list}') + + print(f'reshaped 2d map ---- end') + + def show_tp_embedding_map(self): + self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers') + + def show_tp_final_norm_map(self): + self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers') + + def show_pp_tranformer_map(self): + self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers') + + def show_transformer_file_map(self): + self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files') + + def _build_global_state(self): + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) + self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) + + def get_zero_checkpoint_state(self, global_rank) -> dict: + return self.zero_checkpoint.get_state_for_global_rank( + self.world_size, + global_rank, + keys_to_ignore=[PARAM_SHAPES]) + + def get_embedding_layer_id(self): + return self.layer_keys[EMBEDDING_LAYER_INDEX] + + def get_final_norm_layer_id(self): + return self.layer_keys[FINAL_LAYER_NORM_INDEX] + + def get_iteration(self): + if not ITERATION_KEY in self.global_state: + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) + + return self.global_state[ITERATION_KEY] + + def get_embedding_state(self, tp_index: int) -> Dict: + assert tp_index in self.tp_to_embedding_map.keys() + sd_list = [ + torch.load(fname, + map_location=torch.device('cpu')) + for fname in self.tp_to_embedding_map[tp_index] + ] + sd = self._merge_state_dicts(sd_list) + return sd + + def _get_checkpoint_value(self, key): + if not key in self.global_state: + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + self.global_state[key] = sd.get(key, None) + + return self.global_state[key] + + def get_args(self): + return self._get_checkpoint_value(ARGS_KEY) + + def get_checkpoint_info(self): + return self._get_checkpoint_value(CHECKPOINT_INFO_KEY) + + def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict: + assert tp_index < self.tp_degree + assert pp_index < self.pp_degree + fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index) + sd_list = [ + torch.load(fname, + map_location=torch.device('cpu')) for fname in fname_list + ] + # HACK HACK HACK, should be merging i.e., sd = self._merge_state_dicts(sd_list) + sd = sd_list[0] + return sd + + def get_transformer_state(self, tp_index: int, pp_index: int) -> list: + assert tp_index < self.tp_degree + assert pp_index < self.pp_degree + t_list = [] + for fname_list in 
self.transformer_file_map[(tp_index, pp_index)]: + sd_list = [ + torch.load(fname, + map_location=torch.device('cpu')) for fname in fname_list + ] + sd = self._merge_state_dicts(sd_list) + t_list.append(sd) + return t_list + + def get_pp_transformer_map(self, pp_index: int) -> list: + assert pp_index < self.pp_degree + return self.pp_to_transformer_map[pp_index] + + def get_final_norm_state(self, tp_index: int) -> Dict: + assert tp_index in self.tp_to_final_norm_map.keys() + sd = torch.load(self.tp_to_final_norm_map[tp_index][0], + map_location=torch.device('cpu')) + return sd + + def _build_tp_other_layer_map(self, layer_index: int): + assert layer_index < len(self.layer_files) + layer_files = get_files_with_prefix(self.layer_files, + self.layer_keys[layer_index]) + layer_file_partitions = partition_data(layer_files, self.tp_degree) + data_map = {i: flist for i, flist in enumerate(layer_file_partitions)} + return data_map + + def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list: + assert tp_index < self.tp_degree + assert pp_index < self.pp_degree + file_indices = self.new_2d_map.get_data(pp_index=pp_index, tp_index=tp_index) + return [self.mp_rank_files[i] for i in file_indices] + + def _build_pp_transformer_map(self): + data_map = {} + transformer_layers = self.layer_keys[1:-1] + layers_per_pp = len(transformer_layers) // self.pp_degree + data_map = { + i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] + for i in range(0, + self.pp_degree) + } + return data_map + + def _dump_mapping(self, data_map, map_tag=None): + if map_tag is not None: + print(f'Dump mapping: {map_tag}') + for k, v in data_map.items(): + print(f'{k} = {v}') + + def _build_transformer_file_map(self): + transformer_layer_keys = self.layer_keys[1:-1] + file_map = {} + layers_per_pp = len(transformer_layer_keys) // self.pp_degree + for key_index, layer_key in enumerate(transformer_layer_keys): + pp_index = key_index // layers_per_pp + layer_files = get_files_with_prefix(self.layer_files, layer_key) + layer_file_partitions = partition_data(layer_files, self.tp_degree) + for tp_index in range(self.tp_degree): + map_key = (tp_index, pp_index) + if not map_key in file_map.keys(): + file_map[map_key] = [] + file_map[map_key].append(layer_file_partitions[tp_index]) + + return file_map + + def _sanity_check(self): + assert len(self.mp_rank_files) % self.tp_degree == 0 + assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0 + assert len(self.layer_keys) > 2 + assert (len(self.layer_keys) - 2) % self.pp_degree == 0 + + def validate_files(self): + for file in self.file_list: + if not os.path.isfile(file): + print(f'Error: {file} is not existent') + + def _get_layer_keys(self): + key_set = set() + key_len = len(LAYER_FILE_PREFIX) + 2 + for file_path in self.layer_files: + _, fname = os.path.split(file_path) + key_set.add(fname[:key_len]) + return sorted(list(key_set)) + + def _merge_state_dicts(self, sd_list): + merged_sd = {} + for key in sd_list[0].keys(): + if not key in SEQUENTIAL_LAYERS: + cat_dim = LAYER_CONCAT_DIM.get(key, 0) + merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim) + else: + merged_sd[key] = sd_list[0][key] + + return merged_sd + + def _validate_folder(self, dir): + basic_folder_validation(dir) + + file_list = get_files(dir) + + for file_prefix in [ + MP_RANK_FILE_PREFIX, + LAYER_FILE_PREFIX, + f'{LAYER_FILE_PREFIX}01' + ]: + ckpt_files = get_files_with_prefix(file_list, file_prefix) + assert len(ckpt_files) > 0, f'{dir} seems a bogus DeepSpeed 
checkpoint folder: Cannot find {file_prefix}* files in there.' diff --git a/deepspeed/checkpoint/reshape_meg_2d.py b/deepspeed/checkpoint/reshape_meg_2d.py new file mode 100644 index 000000000000..0d7cd233c78a --- /dev/null +++ b/deepspeed/checkpoint/reshape_meg_2d.py @@ -0,0 +1,226 @@ +from .reshape_utils import partition_data + + +class meg_2d_parallel_map(object): + def __init__(self, pp_degree, tp_degree): + self.pp_degree = pp_degree + self.tp_degree = tp_degree + self.map = {} + + def simple_init(self): + self.map = { + self._make_key(i // self.tp_degree, + i % self.tp_degree): [i] + for i in range(self.pp_degree * self.tp_degree) + } + + def add_data(self, pp_index, tp_index, data): + self._validate_indices(pp_index, tp_index) + assert type(data) is list + + key = self._make_key(pp_index, tp_index) + if not key in self.map.keys(): + self.map[key] = [] + self.map[key] += data + + def get_data(self, pp_index=None, tp_index=None): + self._validate_indices(pp_index, tp_index) + pp_indices = list(range(self.pp_degree)) if pp_index is None else [pp_index] + tp_indices = list(range(self.tp_degree)) if tp_index is None else [tp_index] + + result = [] + for i in pp_indices: + for j in tp_indices: + result += self.map[self._make_key(i, j)] + + return result + + def print_data(self, tag): + print(f'{tag}') + for key, value in self.map.items(): + print(f'{key} = {value}') + + def _validate_indices(self, pp_index, tp_index): + assert pp_index is None or pp_index < self.pp_degree + assert tp_index is None or tp_index < self.tp_degree + + def _make_key(self, i, j): + return f'{i},{j}' + + +def _reshape_tp_dimension(old_2d_map, new_tp_degree): + old_pp_degree = old_2d_map.pp_degree + new_2d_map = meg_2d_parallel_map(old_pp_degree, new_tp_degree) + for i in range(old_pp_degree): + ranks_for_pp_index = old_2d_map.get_data(pp_index=i, tp_index=None) + split_ranks = partition_data(ranks_for_pp_index, new_tp_degree) + for j in range(new_tp_degree): + new_2d_map.add_data(i, j, split_ranks[j]) + + return new_2d_map + + +def _reshape_pp_dimension(old_2d_map, new_pp_degree): + old_tp_degree = old_2d_map.tp_degree + new_2d_map = meg_2d_parallel_map(new_pp_degree, old_tp_degree) + for i in range(old_tp_degree): + ranks_for_tp_index = old_2d_map.get_data(pp_index=None, tp_index=i) + split_ranks = partition_data(ranks_for_tp_index, new_pp_degree) + for j in range(new_pp_degree): + new_2d_map.add_data(j, i, split_ranks[j]) + + return new_2d_map + + +def reshape_meg_2d_parallel(old_pp_degree, + old_tp_degree, + new_pp_degree, + new_tp_degree, + verbose=False): + assert new_pp_degree <= old_pp_degree + assert new_tp_degree <= old_tp_degree + + old_2d_map = meg_2d_parallel_map(old_pp_degree, old_tp_degree) + old_2d_map.simple_init() + if verbose: + old_2d_map.print_data(f'original_2d_map:') + + if old_tp_degree != new_tp_degree: + new_tp_map = _reshape_tp_dimension(old_2d_map, new_tp_degree) + else: + new_tp_map = old_2d_map + if verbose: + new_tp_map.print_data(f'after_tp_reshape:') + + if old_pp_degree != new_pp_degree: + final_map = _reshape_pp_dimension(new_tp_map, new_pp_degree) + else: + final_map = new_tp_map + + if verbose: + final_map.print_data(f'final_2d_map:') + + return final_map + + +def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None): + """ + Initialize model data parallel groups. + + Arguments: + tp_size: number of GPUs used to parallelize model tensor. + pp_size: number of GPUs used to parallelize model pipeline. + dp_size: number of GPUs used to parallelize model data. 
+ + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 8 tensor model-parallel groups, 4 pipeline model-parallel groups + and 8 data-parallel groups as: + 8 data_parallel groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 8 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 4 pipeline model-parallel groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + + world_size = tp_size * pp_size * dp_size + + print(f"\n\n*** tp={tp_size}, pp={pp_size}, dp={dp_size}, world={world_size}") + + tensor_model_parallel_size = min(tp_size, world_size) + pipeline_model_parallel_size = min(pp_size, world_size) + data_parallel_size = world_size // (tensor_model_parallel_size * + pipeline_model_parallel_size) + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size + num_data_parallel_groups = world_size // data_parallel_size + + # Build the data-parallel groups. + all_dp_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + for j in range(tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) + all_dp_group_ranks.append(list(ranks)) + + print("DP", all_dp_group_ranks) + + # Build the model-parallel groups. + all_pp_group_ranks = [] + for i in range(data_parallel_size): + ranks = [ + data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_dp_group_ranks + ] + all_pp_group_ranks.append(list(ranks)) + + print(f"PP", all_pp_group_ranks) + + # Build the tensor model-parallel groups. + all_tp_group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size) + all_tp_group_ranks.append(list(ranks)) + + print(f"TP", all_tp_group_ranks) + + return all_tp_group_ranks, all_pp_group_ranks, all_dp_group_ranks + + # # Build the pipeline model-parallel groups and embedding groups + # # (first and last rank in each pipeline model-parallel group). 
+ # for i in range(num_pipeline_model_parallel_groups): + # ranks = range(i, world_size, + # num_pipeline_model_parallel_groups) + # print(f"EMB{i}", list(ranks)) + + +def reshape(src, tgt): + """ + reshape([tp_size_src, pp_size_src, dp_size_src], + [tp_size_tgt, pp_size_tgt, dp_size_tgt]) + """ + + print(f"\n\n*** Reshaping: {src} => {tgt}") + + tp_size_src, pp_size_src, dp_size_src = src + tp_size_tgt, pp_size_tgt, dp_size_tgt = tgt + + tp_ranks1, pp_ranks1, dp_ranks1 = get_mpu_ranks(tp_size=tp_size_src, pp_size=pp_size_src, dp_size=dp_size_src) + tp_ranks2, pp_ranks2, dp_ranks2 = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_src, dp_size=dp_size_src) + tp_ranks3, pp_ranks3, dp_ranks3 = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_tgt, dp_size=dp_size_src) + + # handle tp contraction first + print("\n*** TP contraction:") + + for i, r in enumerate(tp_ranks1): + print(f'{tp_ranks1[i]} => {tp_ranks2[i]}') + + # handle pp contraction next + + print("\n*** PP contraction:") + + for i, r in enumerate(pp_ranks1): + print(f'{pp_ranks2[i]} => {pp_ranks3[i]}') + + +# easy +#reshape([2,2,1],[1,1,1]) + +# probably need more logic to suggest how to pack +#reshape([4,4,1],[2,2,1]) + +#reshape([2,4,2], [8,32,1]) + +# get_mpu_ranks(2,2,2) +# get_mpu_ranks(4,2,1) +# get_mpu_ranks(2,4,1) +# get_mpu_ranks(1,1,8) diff --git a/deepspeed/checkpoint/reshape_utils.py b/deepspeed/checkpoint/reshape_utils.py new file mode 100644 index 000000000000..9a3645fd11ea --- /dev/null +++ b/deepspeed/checkpoint/reshape_utils.py @@ -0,0 +1,95 @@ +import os +import torch +from collections import OrderedDict + +ZERO_FILE_PREFIX = 'zero_pp_rank_' +LAYER_FILE_PREFIX = 'layer_' +MP_RANK_FILE_PREFIX = 'mp_rank_' + + +def basic_folder_validation(dir): + assert os.path.exists(dir), f'{dir} path does not exist' + assert os.path.isdir(dir), f'{dir} is not a folder' + + +def get_files_with_prefix(all_files, prefix): + file_list = [] + for file_path in all_files: + _, fname = os.path.split(file_path) + if fname.startswith(prefix): + file_list.append(file_path) + + return sorted(file_list) + + +def validate_files(file_list): + for file in file_list: + if not os.path.isfile(file): + print(f'Error: {file} is not existent') + + +def get_files(dir): + file_list = [] + for root, _, files in os.walk(dir): + for file in files: + file_list.append(os.path.join(root, file)) + return file_list + + +def partition_data(data_list, num_partitions): + num_elems = len(data_list) + assert num_elems % num_partitions == 0 + partition_size = num_elems // num_partitions + partitions_list = [ + data_list[i:i + partition_size] for i in range(0, + num_elems, + partition_size) + ] + return partitions_list + + +def _key_list_to_string(key_list): + return '.'.join(key_list) + + +def merge_state_dict(dict_a, dict_b, key_list): + if dict_a.keys() != dict_b.keys(): + print(f'key_list = {_key_list_to_string(key_list)}') + raise ValueError(f'''Cannot merge dicts with different keys, + a = {dict_a.keys()} + b = {dict_b.keys()} + ''') + + return type(dict_a)({ + key: merge_state(dict_a[key], + dict_b[key], + key_list + [str(key)]) + for key in dict_a.keys() + }) + + +def merge_state_list(list_a, list_b, key_list): + if len(list_a) != len(list_b): + print(f'{_key_list_to_string(key_list)}') + raise ValueError( + f'Cannot merge lists of different lengths, a = {len(list_a)} b = {len(list_b)}' + ) + + return [merge_state(a, b, key_list) for a, b in zip(list_a, list_b)] + + +def merge_state(state_a, state_b, key_list=[]): + if type(state_a) != type(state_b): + 
key_list_string = _key_list_to_string(key_list) + print(f'key_list = {key_list_string}') + raise ValueError( + f'Cannot merge two states of types {type(state_a)} and type {type(state_b)}') + + if type(state_a) in (dict, OrderedDict): + return merge_state_dict(state_a, state_b, key_list) + elif type(state_a) in (list, tuple): + return type(state_a)(merge_state_list(state_a, state_b, key_list)) + elif torch.is_tensor(state_a): + return torch.cat([state_a, state_b], 0) + else: + return state_a diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py new file mode 100644 index 000000000000..3e653ed80489 --- /dev/null +++ b/deepspeed/checkpoint/utils.py @@ -0,0 +1,30 @@ +import os + +from deepspeed.checkpoint.constants import MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX + + +def get_model_ckpt_name_for_rank(base_folder, mp_rank_str, tag=None): + ckpt_name = os.path.join( + base_folder, + str(tag), + MODEL_FILE_PREFIX + mp_rank_str + MODEL_FILE_SUFFIX, + ) + return ckpt_name + + +def get_zero_ckpt_name_for_rank(base_folder, dp_rank, mp_rank, tag=None): + zero_prefix = f'{ZERO_FILE_PREFIX}{dp_rank}' + mp_rank_string = f'_{MODEL_FILE_PREFIX}_{mp_rank:02d}' + + zero_ckpt_name = os.path.join( + base_folder, + str(tag), + zero_prefix + mp_rank_string + OPTIM_FILE_SUFFIX, + ) + return zero_ckpt_name + + +def get_layer_ckpt_name_for_rank(base_folder, layer_id, tp_rank, tag=None): + ckpt_file = f'{layer_id}-model_{tp_rank:02d}{MODEL_FILE_SUFFIX}' + ckpt_path = os.path.join(base_folder, ckpt_file) + return ckpt_path diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py new file mode 100644 index 000000000000..65f67db78155 --- /dev/null +++ b/deepspeed/checkpoint/zero_checkpoint.py @@ -0,0 +1,45 @@ +import torch +from .reshape_utils import (basic_folder_validation, + get_files, + get_files_with_prefix, + merge_state, + ZERO_FILE_PREFIX) + + +class ZeROCheckpoint(object): + def __init__(self, dir): + basic_folder_validation(dir) + self.dir = dir + self.file_list = get_files_with_prefix(get_files(dir), ZERO_FILE_PREFIX) + self.num_files = len(self.file_list) + + def get_files_for_global_rank(self, world_size, global_rank): + assert global_rank < world_size, f'Expected global_rank {global_rank} to be less than world size {world_size}' + if world_size == self.num_files: + return [self.file_list[global_rank]] + elif world_size < self.num_files: + assert self.num_files % world_size == 0, \ + f'Expected world size {world_size} that can divide number of zero files {self.num_files}' + num_files_per_rank = self.num_files // world_size + starting_index = global_rank * num_files_per_rank + return self.file_list[starting_index:(starting_index + num_files_per_rank)] + else: + assert world_size % self.num_files == 0, \ + f'Expected world size {world_size} that is multiple of number of zero files {self.num_files}' + num_ranks_per_file = world_size // self.num_files + return [self.file_list[global_rank // num_ranks_per_file]] + + def get_state_for_global_rank(self, world_size, global_rank, keys_to_ignore=[]): + rank_file_list = self.get_files_for_global_rank(world_size, global_rank) + assert len(rank_file_list) > 0, f'Expected global_rank files count {len(rank_file_list)} > 0' + rank_state = None + for ckpt_file in rank_file_list: + sd = torch.load(ckpt_file, map_location=torch.device('cpu')) + for key in keys_to_ignore: + sd.pop(key, None) + if rank_state is None: + rank_state = sd + else: + rank_state = merge_state(rank_state, sd) + 
+ return rank_state From fd8c3e688f57fb7914966fcb5f6c2c9dcede0576 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 31 Jan 2022 19:24:14 +0000 Subject: [PATCH 06/31] Print checkpoint version --- deepspeed/checkpoint/constants.py | 1 + deepspeed/runtime/zero/stage_1_and_2.py | 14 +++++++++----- deepspeed/utils/zero_to_fp32.py | 10 +++++++--- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 2bd7ac087e1a..0162bf6f27d3 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -20,3 +20,4 @@ ######################################### PARAM_SHAPES = 'param_shapes' BUFFER_NAMES = 'buffer_names' +DS_VERSION = 'ds_version' diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index b6c1939ad298..81ca62af83b5 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -18,7 +18,11 @@ from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version -from deepspeed.checkpoint.constants import SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE +from deepspeed.checkpoint.constants import (DS_VERSION, + PARTITION_COUNT, + SINGLE_PARTITION_OF_FP32_GROUPS, + BASE_OPTIMIZER_STATE, + ZERO_STAGE) # Toggle this to true to enable correctness test # with gradient partitioning and without @@ -2009,10 +2013,10 @@ def state_dict(self): self.single_partition_of_fp32_groups) state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding - state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS - state_dict['partition_count'] = self.partition_count + state_dict[ZERO_STAGE] = ZERO_OPTIMIZATION_GRADIENTS + state_dict[PARTITION_COUNT] = self.partition_count - state_dict['ds_version'] = version + state_dict[DS_VERSION] = version return state_dict @@ -2156,7 +2160,7 @@ def load_state_dict(self, self.dynamic_loss_scale = current_rank_sd['dynamic_loss_scale'] self.overflow = current_rank_sd['overflow'] - ckpt_version = current_rank_sd.get("ds_version", False) + ckpt_version = current_rank_sd.get(DS_VERSION, False) assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" ckpt_version = pkg_version.parse(ckpt_version) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 9850addfba77..7c229518e0bc 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -18,7 +18,8 @@ # DeepSpeed data structures it has to be available in the current python environment. 
import deepspeed from deepspeed.utils import logger -from deepspeed.checkpoint.constants import (OPTIMIZER_STATE_DICT, +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, PARAM_SHAPES, SINGLE_PARTITION_OF_FP32_GROUPS, FP32_FLAT_GROUPS, @@ -77,7 +78,9 @@ def parse_model_state(file): } param_shapes = state_dict[PARAM_SHAPES] - return buffers, param_shapes + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version def parse_optim_states(files, ds_checkpoint_dir): @@ -149,7 +152,8 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) - buffers, param_shapes = parse_model_state(model_file) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') if zero_stage == 2: return _get_fp32_state_dict_from_zero2_checkpoint(world_size, From c8689fd235ad4bad58cd3a2adebb67cd3c065dfa Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 7 Feb 2022 19:02:59 +0000 Subject: [PATCH 07/31] Reshape zero_* ckpt files --- deepspeed/checkpoint/__init__.py | 8 ++ deepspeed/checkpoint/constants.py | 16 --- deepspeed/checkpoint/deepspeed_checkpoint.py | 18 ++++ deepspeed/checkpoint/reshape_3d_utils.py | 108 +++++++++++++++++++ deepspeed/checkpoint/zero_checkpoint.py | 78 ++++++++------ 5 files changed, 182 insertions(+), 46 deletions(-) create mode 100644 deepspeed/checkpoint/reshape_3d_utils.py diff --git a/deepspeed/checkpoint/__init__.py b/deepspeed/checkpoint/__init__.py index 96090e75d379..edb424e9dfa8 100644 --- a/deepspeed/checkpoint/__init__.py +++ b/deepspeed/checkpoint/__init__.py @@ -1,5 +1,13 @@ from .reshape_meg_2d import reshape_meg_2d_parallel + from .deepspeed_checkpoint import DeepSpeedCheckpoint + from .utils import (get_layer_ckpt_name_for_rank, get_model_ckpt_name_for_rank, get_zero_ckpt_name_for_rank) + +from .reshape_utils import (merge_state) + +from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor) + +from .zero_checkpoint import ZeROCheckpoint diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 8fe88b90006a..120a76747148 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -32,20 +32,4 @@ ######################################### # Checkpoint utility keys ######################################### -EMBEDDING_LAYER_INDEX = 0 -FINAL_LAYER_NORM_INDEX = -1 -ARGS_KEY = 'args' -CHECKPOINT_INFO_KEY = 'checkpoint_info' -ITERATION_KEY = 'iteration' -SEQUENTIAL_LAYERS = [ - 'input_layernorm.weight', - 'input_layernorm.bias', - 'self_attention.dense.bias', - 'post_attention_layernorm.weight', - 'post_attention_layernorm.bias', - 'mlp.dense_4h_to_h.bias', - 'position_embeddings.weight' -] - -LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1} DS_VERSION = 'ds_version' diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py index 9ec594c4c4f4..b875a574e7ae 100644 --- a/deepspeed/checkpoint/deepspeed_checkpoint.py +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -13,6 +13,24 @@ from .zero_checkpoint import ZeROCheckpoint from .constants import * +EMBEDDING_LAYER_INDEX = 0 +FINAL_LAYER_NORM_INDEX = -1 +ARGS_KEY = 'args' +CHECKPOINT_INFO_KEY = 'checkpoint_info' +ITERATION_KEY = 'iteration' + +SEQUENTIAL_LAYERS = [ + 
'input_layernorm.weight', + 'input_layernorm.bias', + 'self_attention.dense.bias', + 'post_attention_layernorm.weight', + 'post_attention_layernorm.bias', + 'mlp.dense_4h_to_h.bias', + 'position_embeddings.weight' +] + +LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1} + class DeepSpeedCheckpoint(object): def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None): diff --git a/deepspeed/checkpoint/reshape_3d_utils.py b/deepspeed/checkpoint/reshape_3d_utils.py new file mode 100644 index 000000000000..5722f49d37ed --- /dev/null +++ b/deepspeed/checkpoint/reshape_3d_utils.py @@ -0,0 +1,108 @@ +from .reshape_utils import (get_files, + get_files_with_prefix, + ZERO_FILE_PREFIX, + MP_RANK_FILE_PREFIX, + LAYER_FILE_PREFIX, + partition_data) + +from .reshape_meg_2d import (reshape_meg_2d_parallel, meg_2d_parallel_map) + +PP_DIM = 'PP' +TP_DIM = 'TP' +DP_DIM = 'DP' + + +class model_3d_desc(object): + def __init__(self, pp_degree=1, tp_degree=1, dp_degree=1): + self.pp_degree = pp_degree + self.tp_degree = tp_degree + self.dp_degree = dp_degree + + def reshape(self, target_3d_desc, verbose=False): + valid_reshape, reshape_errors = self.can_reshape(target_3d_desc) + assert valid_reshape, ','.join(reshape_errors) + tgt_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.pp_degree, + old_tp_degree=self.tp_degree, + new_pp_degree=target_3d_desc.pp_degree, + new_tp_degree=target_3d_desc.tp_degree, + verbose=verbose) + + flat_3d_map = flatten_dp_dimension(meg_2d_map=tgt_2d_map, + src_2d_size=self.pp_degree * self.tp_degree, + dp_degree=self.dp_degree) + + return unflatten_dp_dimension(meg_2d_map=flat_3d_map, + dp_degree=target_3d_desc.dp_degree) + + def get_desc(self): + return f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})' + + def is_valid(self, pp_index, tp_index, dp_index): + err_msg = [] + valid = True + for index, degree, dim_name in [ + (pp_index, self.pp_degree, PP_DIM), + (tp_index, self.tp_degree, TP_DIM), + (dp_index, self.dp_degree, DP_DIM)]: + if index >= degree: + valid = False + err_msg.append( + f'{dim_name} indexing error: index {index} >= degree {degree}') + + return valid, err_msg + + def can_reshape(self, target_3d_desc): + err_msg = [] + if target_3d_desc.pp_degree > self.pp_degree: + err_msg.append( + f'Expansion reshape not supported - {PP_DIM}: {self.pp_degree} ---> {target_3d_desc.pp_degree}' + ) + + if target_3d_desc.tp_degree > self.tp_degree: + err_msg.append( + f'Expansion reshape not supported - {TP_DIM}: {self.tp_degree} ---> {target_3d_desc.tp_degree}' + ) + + if target_3d_desc.dp_degree > self.dp_degree: + err_msg.append( + f'Expansion reshape not supported - {DP_DIM}: {self.dp_degree} ---> {target_3d_desc.dp_degree}' + ) + + return len(err_msg) == 0, err_msg + + +def get_model_3d_descriptor(dir): + file_list = get_files(dir) + tp_degree = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01')) + pp_degree = len(get_files_with_prefix(file_list, MP_RANK_FILE_PREFIX)) // tp_degree + num_zero_files = len(get_files_with_prefix(file_list, ZERO_FILE_PREFIX)) + dp_degree = max(1, num_zero_files // (pp_degree * tp_degree)) + return model_3d_desc(pp_degree, tp_degree, dp_degree) + + +def flatten_dp_dimension(meg_2d_map, src_2d_size, dp_degree): + new_meg_2d_map = meg_2d_parallel_map(meg_2d_map.pp_degree, meg_2d_map.tp_degree) + for pp_index in range(meg_2d_map.pp_degree): + for tp_index in range(meg_2d_map.tp_degree): + dp0_indices = meg_2d_map.get_data(pp_index, tp_index) + for idx in 
dp0_indices: + dpX_indices = [idx + (i * src_2d_size) for i in range(dp_degree)] + new_meg_2d_map.add_data(pp_index, tp_index, dpX_indices) + return new_meg_2d_map + + +def unflatten_dp_dimension(meg_2d_map, dp_degree): + pp_degree = meg_2d_map.pp_degree + tp_degree = meg_2d_map.tp_degree + meg_2d_map_list = [ + meg_2d_parallel_map(pp_degree=pp_degree, + tp_degree=tp_degree) for _ in range(dp_degree) + ] + for pp_index in range(pp_degree): + for tp_index in range(tp_degree): + flat_dp_indices = meg_2d_map.get_data(pp_index, tp_index) + partitioned_dp_indices = partition_data(flat_dp_indices, dp_degree) + for dp_indices, _2d_map in zip(partitioned_dp_indices, meg_2d_map_list): + _2d_map.add_data(pp_index, tp_index, dp_indices) + + return meg_2d_map_list diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py index 65f67db78155..31b313e942ae 100644 --- a/deepspeed/checkpoint/zero_checkpoint.py +++ b/deepspeed/checkpoint/zero_checkpoint.py @@ -2,8 +2,10 @@ from .reshape_utils import (basic_folder_validation, get_files, get_files_with_prefix, - merge_state, - ZERO_FILE_PREFIX) + ZERO_FILE_PREFIX, + merge_state) + +from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor) class ZeROCheckpoint(object): @@ -13,33 +15,49 @@ def __init__(self, dir): self.file_list = get_files_with_prefix(get_files(dir), ZERO_FILE_PREFIX) self.num_files = len(self.file_list) - def get_files_for_global_rank(self, world_size, global_rank): - assert global_rank < world_size, f'Expected global_rank {global_rank} to be less than world size {world_size}' - if world_size == self.num_files: - return [self.file_list[global_rank]] - elif world_size < self.num_files: - assert self.num_files % world_size == 0, \ - f'Expected world size {world_size} that can divide number of zero files {self.num_files}' - num_files_per_rank = self.num_files // world_size - starting_index = global_rank * num_files_per_rank - return self.file_list[starting_index:(starting_index + num_files_per_rank)] - else: - assert world_size % self.num_files == 0, \ - f'Expected world size {world_size} that is multiple of number of zero files {self.num_files}' - num_ranks_per_file = world_size // self.num_files - return [self.file_list[global_rank // num_ranks_per_file]] - - def get_state_for_global_rank(self, world_size, global_rank, keys_to_ignore=[]): - rank_file_list = self.get_files_for_global_rank(world_size, global_rank) - assert len(rank_file_list) > 0, f'Expected global_rank files count {len(rank_file_list)} > 0' - rank_state = None - for ckpt_file in rank_file_list: - sd = torch.load(ckpt_file, map_location=torch.device('cpu')) - for key in keys_to_ignore: - sd.pop(key, None) - if rank_state is None: - rank_state = sd + self.src_3d = get_model_3d_descriptor(dir) + self.target_3d = model_3d_desc(pp_degree=self.src_3d.pp_degree, + tp_degree=self.src_3d.tp_degree, + dp_degree=self.src_3d.dp_degree) + self._3d_file_map = self.src_3d.reshape(self.target_3d) + + def get_file_indices_for_rank(self, pp_index, tp_index, dp_index): + assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}' + dp_2d_map = self._3d_file_map[dp_index] + return dp_2d_map.get_data(pp_index, tp_index) + + def get_files_for_rank(self, pp_index, tp_index, dp_index): + file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index) + return [self.file_list[idx] for idx in file_idx_list] + + def get_state_for_rank(self, pp_index, tp_index, dp_index): + state_file_list = 
self.get_files_for_rank(pp_index, tp_index, dp_index) + merged_sd = None + for state_file in state_file_list: + sd = torch.load(state_file, map_location=torch.device('cpu')) + if merged_sd is None: + merged_sd = sd else: - rank_state = merge_state(rank_state, sd) + merged_sd = merge_state(merged_sd, sd) + + return merged_sd + + def print_3d_index_map(self, tag=None): + if tag: + print(f'3D index map: {tag}') + for dp_index, _2d_map in enumerate(self._3d_file_map): + _2d_map.print_data(f'dp = {dp_index}') + + def print_3d_file_map(self, tag=None): + if tag: + print(f'3D file map: {tag}') + for dp_index, _2d_map in enumerate(self._3d_file_map): + for pp_index in _2d_map.pp_degree: + for tp_index in _2d_map.tp_degree: + file_index_list = _2d_map.get_data(pp_index, tp_index) + file_list = [self.file_list[idx] for idx in file_index_list] + print(f'{pp_index}, {tp_index}, {dp_index} => {file_list}') - return rank_state + def reshape(self, target_3d_desc: model_3d_desc): + self.target_3d = target_3d_desc + self._3d_file_map = self.src_3d.reshape(self.target_3d) From 4a86c1a58639b2c327b6cf23cd2065f4dec3663d Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 8 Feb 2022 22:33:27 +0000 Subject: [PATCH 08/31] Merge zero* files contraction --- deepspeed/checkpoint/deepspeed_checkpoint.py | 29 ++++++++++++++++---- deepspeed/checkpoint/utils.py | 16 +++++------ deepspeed/checkpoint/zero_checkpoint.py | 5 +++- deepspeed/runtime/engine.py | 1 + 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py index b875a574e7ae..fade0264ce5f 100644 --- a/deepspeed/checkpoint/deepspeed_checkpoint.py +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -1,6 +1,8 @@ import os from typing import Dict import torch + +from deepspeed.checkpoint.reshape_3d_utils import model_3d_desc from .reshape_utils import (basic_folder_validation, partition_data, get_files, @@ -66,7 +68,15 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None): old_tp_degree=self.original_tp_degree, new_pp_degree=self.pp_degree, new_tp_degree=self.tp_degree) + self.zero_checkpoint = ZeROCheckpoint(dir) + if self.is_change_pp_degree() or self.is_change_tp_degree( + ) or self.is_change_dp_degree(): + self.zero_checkpoint.reshape( + model_3d_desc(self.pp_degree, + self.tp_degree, + self.dp_degree)) + self.global_state = {} self._sanity_check() @@ -77,6 +87,15 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None): FINAL_LAYER_NORM_INDEX) self._build_global_state() + def is_change_tp_degree(self): + return self.tp_degree != self.original_tp_degree + + def is_change_pp_degree(self): + return self.pp_degree != self.original_pp_degree + + def is_change_dp_degree(self): + return self.dp_degree != self.original_dp_degree + def show_2d_mapping(self): print(f'reshaped 2d map ---- begin') @@ -104,11 +123,11 @@ def _build_global_state(self): self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) - def get_zero_checkpoint_state(self, global_rank) -> dict: - return self.zero_checkpoint.get_state_for_global_rank( - self.world_size, - global_rank, - keys_to_ignore=[PARAM_SHAPES]) + def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict: + return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index, + tp_index=tp_index, + dp_index=dp_index, + keys_to_ignore=[PARAM_SHAPES]) def get_embedding_layer_id(self): return 
self.layer_keys[EMBEDDING_LAYER_INDEX] diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py index 3e653ed80489..5221a0531522 100644 --- a/deepspeed/checkpoint/utils.py +++ b/deepspeed/checkpoint/utils.py @@ -1,30 +1,30 @@ +from email.mime import base import os +from deepspeed.checkpoint.constants import (MODEL_FILE_PREFIX, + MODEL_FILE_SUFFIX, + OPTIM_FILE_SUFFIX, + ZERO_FILE_PREFIX) -from deepspeed.checkpoint.constants import MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX - -def get_model_ckpt_name_for_rank(base_folder, mp_rank_str, tag=None): +def get_model_ckpt_name_for_rank(base_folder, mp_rank_str): ckpt_name = os.path.join( base_folder, - str(tag), MODEL_FILE_PREFIX + mp_rank_str + MODEL_FILE_SUFFIX, ) return ckpt_name -def get_zero_ckpt_name_for_rank(base_folder, dp_rank, mp_rank, tag=None): +def get_zero_ckpt_name_for_rank(base_folder, dp_rank, mp_rank): zero_prefix = f'{ZERO_FILE_PREFIX}{dp_rank}' mp_rank_string = f'_{MODEL_FILE_PREFIX}_{mp_rank:02d}' - zero_ckpt_name = os.path.join( base_folder, - str(tag), zero_prefix + mp_rank_string + OPTIM_FILE_SUFFIX, ) return zero_ckpt_name -def get_layer_ckpt_name_for_rank(base_folder, layer_id, tp_rank, tag=None): +def get_layer_ckpt_name_for_rank(base_folder, layer_id, tp_rank): ckpt_file = f'{layer_id}-model_{tp_rank:02d}{MODEL_FILE_SUFFIX}' ckpt_path = os.path.join(base_folder, ckpt_file) return ckpt_path diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py index 31b313e942ae..54494305212d 100644 --- a/deepspeed/checkpoint/zero_checkpoint.py +++ b/deepspeed/checkpoint/zero_checkpoint.py @@ -30,11 +30,14 @@ def get_files_for_rank(self, pp_index, tp_index, dp_index): file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index) return [self.file_list[idx] for idx in file_idx_list] - def get_state_for_rank(self, pp_index, tp_index, dp_index): + def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[]): state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index) merged_sd = None for state_file in state_file_list: sd = torch.load(state_file, map_location=torch.device('cpu')) + for key in keys_to_ignore: + sd.pop(key, None) + if merged_sd is None: merged_sd = sd else: diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e205e2a2f203..be4b997f21f1 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -20,6 +20,7 @@ from torch.distributed.distributed_c10d import _get_global_rank from typing import Callable, Dict, Optional, Union, Iterable +from deepspeed.checkpoint.utils import get_zero_ckpt_name_for_rank from deepspeed.runtime.utils import see_memory_usage, get_ma_status, DummyOptim from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer From f5db8df8607d64d0b25f99f92edd555ec073bb7a Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 23 Feb 2022 17:39:44 +0000 Subject: [PATCH 09/31] Utils for 3D contraction reshaping --- deepspeed/checkpoint/constants.py | 1 + deepspeed/checkpoint/deepspeed_checkpoint.py | 13 +++- deepspeed/checkpoint/reshape_utils.py | 22 +++--- deepspeed/checkpoint/utils.py | 2 +- deepspeed/checkpoint/zero_checkpoint.py | 73 +++++++++++++++++++- deepspeed/runtime/zero/stage_1_and_2.py | 37 +++++----- tests/unit/test_reshape_checkpoint.py | 58 ++++++++++++++++ 7 files changed, 168 insertions(+), 38 deletions(-) create mode 100644 tests/unit/test_reshape_checkpoint.py diff --git a/deepspeed/checkpoint/constants.py 
b/deepspeed/checkpoint/constants.py index 120a76747148..c7d38b4b024a 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -11,6 +11,7 @@ BASE_OPTIMIZER_STATE = 'base_optimizer_state' SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups" +GROUP_PADDINGS = 'group_paddings' PARTITION_COUNT = 'partition_count' ZERO_STAGE = 'zero_stage' diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py index fade0264ce5f..1cc3d4fb11d3 100644 --- a/deepspeed/checkpoint/deepspeed_checkpoint.py +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -4,6 +4,7 @@ from deepspeed.checkpoint.reshape_3d_utils import model_3d_desc from .reshape_utils import (basic_folder_validation, + merge_state, partition_data, get_files, get_files_with_prefix, @@ -173,9 +174,15 @@ def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict: torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list ] - # HACK HACK HACK, should be merging i.e., sd = self._merge_state_dicts(sd_list) - sd = sd_list[0] - return sd + + merged_sd = None + for sd in sd_list: + if merged_sd is None: + merged_sd = sd + else: + merged_sd = merge_state(merged_sd, sd) + + return merged_sd def get_transformer_state(self, tp_index: int, pp_index: int) -> list: assert tp_index < self.tp_degree diff --git a/deepspeed/checkpoint/reshape_utils.py b/deepspeed/checkpoint/reshape_utils.py index 9a3645fd11ea..c4b02434f77b 100644 --- a/deepspeed/checkpoint/reshape_utils.py +++ b/deepspeed/checkpoint/reshape_utils.py @@ -53,19 +53,15 @@ def _key_list_to_string(key_list): def merge_state_dict(dict_a, dict_b, key_list): - if dict_a.keys() != dict_b.keys(): - print(f'key_list = {_key_list_to_string(key_list)}') - raise ValueError(f'''Cannot merge dicts with different keys, - a = {dict_a.keys()} - b = {dict_b.keys()} - ''') - - return type(dict_a)({ - key: merge_state(dict_a[key], - dict_b[key], - key_list + [str(key)]) - for key in dict_a.keys() - }) + merged_dict = type(dict_a)({}) + + for key, value in dict_b.items(): + if key in dict_a.keys(): + merged_dict[key] = merge_state(dict_a[key], dict_b[key], [str(key)]) + else: + merged_dict[key] = value + + return merged_dict def merge_state_list(list_a, list_b, key_list): diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py index 5221a0531522..9cc090be5d3e 100644 --- a/deepspeed/checkpoint/utils.py +++ b/deepspeed/checkpoint/utils.py @@ -16,7 +16,7 @@ def get_model_ckpt_name_for_rank(base_folder, mp_rank_str): def get_zero_ckpt_name_for_rank(base_folder, dp_rank, mp_rank): zero_prefix = f'{ZERO_FILE_PREFIX}{dp_rank}' - mp_rank_string = f'_{MODEL_FILE_PREFIX}_{mp_rank:02d}' + mp_rank_string = f'_{MODEL_FILE_PREFIX}{mp_rank:02d}' zero_ckpt_name = os.path.join( base_folder, zero_prefix + mp_rank_string + OPTIM_FILE_SUFFIX, diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py index 54494305212d..27035bc1d324 100644 --- a/deepspeed/checkpoint/zero_checkpoint.py +++ b/deepspeed/checkpoint/zero_checkpoint.py @@ -1,4 +1,10 @@ import torch + +from deepspeed.checkpoint.constants import (BASE_OPTIMIZER_STATE, + GROUP_PADDINGS, + OPTIMIZER_STATE_DICT, + PARTITION_COUNT) + from .reshape_utils import (basic_folder_validation, get_files, get_files_with_prefix, @@ -7,6 +13,8 @@ from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor) +GROUP_STATE_KEY = 'state' + class ZeROCheckpoint(object): def __init__(self, dir): @@ -30,7 +38,12 @@ def 
get_files_for_rank(self, pp_index, tp_index, dp_index): file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index) return [self.file_list[idx] for idx in file_idx_list] - def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[]): + def get_state_for_rank(self, + pp_index, + tp_index, + dp_index, + keys_to_ignore=[], + strip_tensor_paddings=True): state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index) merged_sd = None for state_file in state_file_list: @@ -38,11 +51,18 @@ def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[]): for key in keys_to_ignore: sd.pop(key, None) + if strip_tensor_paddings: + self._strip_tensor_paddings(sd) + if merged_sd is None: merged_sd = sd else: merged_sd = merge_state(merged_sd, sd) + self._update_partition_count(merged_sd) + if strip_tensor_paddings: + self._clear_group_paddings(merged_sd) + return merged_sd def print_3d_index_map(self, tag=None): @@ -64,3 +84,54 @@ def print_3d_file_map(self, tag=None): def reshape(self, target_3d_desc: model_3d_desc): self.target_3d = target_3d_desc self._3d_file_map = self.src_3d.reshape(self.target_3d) + + def _strip_tensor_paddings(self, sd): + param_group_states = self._get_param_group_states(sd) + if param_group_states is None: + return + + group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS) + if group_paddings is None: + return + + for key, group_state in param_group_states.items(): + if group_paddings[key] == 0: + continue + for state_name, state_value in group_state.items(): + if torch.is_tensor(state_value): + raw_length = state_value.numel() - group_paddings[key] + group_state[state_name] = torch.narrow(state_value, + 0, + 0, + raw_length).clone() + + def _clear_group_paddings(self, sd): + group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS) + if group_paddings: + num_groups = len(group_paddings) + sd[OPTIMIZER_STATE_DICT][GROUP_PADDINGS] = [0] * num_groups + + def _get_optimizer_state(self, sd, state_key): + optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None) + if optimizer_state is None: + return None + + return optimizer_state.get(state_key, None) + + def _get_param_group_states(self, sd): + optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None) + if optimizer_state is None: + return None + + base_optimizer_state = optimizer_state.get(BASE_OPTIMIZER_STATE, None) + if base_optimizer_state is None: + return None + + return base_optimizer_state.get(GROUP_STATE_KEY, None) + + def _update_partition_count(self, sd): + partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT) + if partition_counts: + num_groups = len(partition_counts) + sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree + ] * num_groups diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 9e515d2e42d7..eed13719115a 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -11,6 +11,7 @@ from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, get_global_norm, see_memory_usage, is_model_parallel_parameter from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS +from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_OPTIMIZER_STATES from deepspeed.runtime.zero.offload_constants import OFFLOAD_CPU_DEVICE, OFFLOAD_OPTIMIZER from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.ops.op_builder import UtilsBuilder @@ -19,6 
+20,7 @@ from deepspeed.git_version_info import version from deepspeed.checkpoint.constants import (DS_VERSION, + GROUP_PADDINGS, PARTITION_COUNT, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, @@ -268,15 +270,6 @@ def __init__(self, # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group self.bit16_groups.append(param_group['params']) - # Record padding required to align group to world size - if partition_id == dist.get_world_size( - group=self.real_dp_process_group[i]) - 1: - padding = get_alignment_padding(self.bit16_groups[i], - self.partition_count[i]) - else: - padding = 0 - self.groups_padding.append(padding) - # not sure why apex was cloning the weights before flattening # removing cloning here @@ -311,6 +304,15 @@ def __init__(self, see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False) + # Record padding required for alignment + if partition_id == dist.get_world_size( + group=self.real_dp_process_group[i]) - 1: + padding = self.bit16_groups_flat[i].numel() - sum( + [t.numel() for t in self.round_robin_bit16_groups[i]]) + else: + padding = 0 + self.groups_padding.append(padding) + if dist.get_rank(group=self.real_dp_process_group[i]) == 0: see_memory_usage( f"After Flattening and after emptying param group {i} cache", @@ -2003,17 +2005,12 @@ def state_dict(self): state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow - if self.elastic_checkpoint: - state_dict[BASE_OPTIMIZER_STATE] = self._get_base_optimizer_state() - else: - state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() - - # Remove paddings for DP alignment to enable loading for other alignment values - fp32_groups_without_padding = self._get_groups_without_padding( - self.single_partition_of_fp32_groups) - state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding - - state_dict[ZERO_STAGE] = ZERO_OPTIMIZATION_GRADIENTS + state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() + state_dict[ + SINGLE_PARTITION_OF_FP32_GROUPS] = self.single_partition_of_fp32_groups + state_dict[ + ZERO_STAGE] = ZERO_OPTIMIZATION_GRADIENTS if self.partition_gradients else ZERO_OPTIMIZATION_OPTIMIZER_STATES + state_dict[GROUP_PADDINGS] = self.groups_padding state_dict[PARTITION_COUNT] = self.partition_count state_dict[DS_VERSION] = version diff --git a/tests/unit/test_reshape_checkpoint.py b/tests/unit/test_reshape_checkpoint.py new file mode 100644 index 000000000000..317f3bb1661f --- /dev/null +++ b/tests/unit/test_reshape_checkpoint.py @@ -0,0 +1,58 @@ +import pytest +import deepspeed + +from deepspeed.checkpoint import model_3d_desc + + +def _do_reshape(src_3d, tgt_3d): + assert src_3d.can_reshape(tgt_3d) + new_3d_map = src_3d.reshape(tgt_3d) + + assert len(new_3d_map) == tgt_3d.dp_degree + for new_2d_map in new_3d_map: + assert new_2d_map.pp_degree == tgt_3d.pp_degree + assert new_2d_map.tp_degree == tgt_3d.tp_degree + + return new_3d_map + + +# Specify 3d shape as pp/tp/dp +def test_reshape_222_to_111(): + src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2) + tgt_3d = model_3d_desc(pp_degree=1, tp_degree=1, dp_degree=1) + + new_3d_map = _do_reshape(src_3d, tgt_3d) + + assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 1, 5, 2, 6, 3, 7] + + +def test_reshape_222_to_121(): + src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2) + tgt_3d = model_3d_desc(pp_degree=1, tp_degree=2, dp_degree=1) + + new_3d_map = _do_reshape(src_3d, tgt_3d) + + assert 
new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 2, 6] + assert new_3d_map[0].get_data(pp_index=0, tp_index=1) == [1, 5, 3, 7] + + +def test_reshape_222_to_122(): + src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2) + tgt_3d = model_3d_desc(pp_degree=1, tp_degree=2, dp_degree=2) + + new_3d_map = _do_reshape(src_3d, tgt_3d) + + assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4] + assert new_3d_map[0].get_data(pp_index=0, tp_index=1) == [1, 5] + assert new_3d_map[1].get_data(pp_index=0, tp_index=0) == [2, 6] + assert new_3d_map[1].get_data(pp_index=0, tp_index=1) == [3, 7] + + +def test_reshape_222_to_211(): + src_3d = model_3d_desc(pp_degree=2, tp_degree=2, dp_degree=2) + tgt_3d = model_3d_desc(pp_degree=2, tp_degree=1, dp_degree=1) + + new_3d_map = _do_reshape(src_3d, tgt_3d) + + assert new_3d_map[0].get_data(pp_index=0, tp_index=0) == [0, 4, 1, 5] + assert new_3d_map[0].get_data(pp_index=1, tp_index=0) == [2, 6, 3, 7] From c12a4e7fb122c0043e190463a399ab1bccbd046b Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 16 May 2022 17:27:03 +0500 Subject: [PATCH 10/31] Remove bogus import --- deepspeed/checkpoint/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py index 9cc090be5d3e..55e4a29a9c00 100644 --- a/deepspeed/checkpoint/utils.py +++ b/deepspeed/checkpoint/utils.py @@ -1,4 +1,3 @@ -from email.mime import base import os from deepspeed.checkpoint.constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, From 1031b32431171f699f88d4487fe00d45101628e6 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 20 May 2022 18:21:52 +0500 Subject: [PATCH 11/31] Support bf16_zero ckpts --- deepspeed/checkpoint/constants.py | 2 ++ deepspeed/checkpoint/deepspeed_checkpoint.py | 13 ++++++------ deepspeed/checkpoint/reshape_3d_utils.py | 11 ++++------ deepspeed/checkpoint/reshape_utils.py | 4 ---- deepspeed/checkpoint/utils.py | 8 ++++---- deepspeed/checkpoint/zero_checkpoint.py | 21 ++++++++++++++------ 6 files changed, 31 insertions(+), 28 deletions(-) diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 3a91fc833248..d10d0841c70a 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -29,6 +29,8 @@ ZERO_FILE_PREFIX = 'zero_pp_rank_' OPTIM_FILE_SUFFIX = '_optim_states.pt' MODEL_FILE_SUFFIX = '_model_states.pt' +LAYER_FILE_PREFIX = 'layer_' +BF16_ZERO_FILE_PREFIX = 'bf16_' + ZERO_FILE_PREFIX ######################################### # Checkpoint utility keys diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py index 1cc3d4fb11d3..4c32aefa3a99 100644 --- a/deepspeed/checkpoint/deepspeed_checkpoint.py +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -2,15 +2,14 @@ from typing import Dict import torch -from deepspeed.checkpoint.reshape_3d_utils import model_3d_desc +from .reshape_3d_utils import model_3d_desc from .reshape_utils import (basic_folder_validation, merge_state, partition_data, get_files, - get_files_with_prefix, - ZERO_FILE_PREFIX, - LAYER_FILE_PREFIX, - MP_RANK_FILE_PREFIX) + get_files_with_prefix) + +from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX) from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map from .zero_checkpoint import ZeROCheckpoint @@ -43,7 +42,7 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None): self.file_list = get_files(dir) self.zero_files = 
get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX) self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX) - self.mp_rank_files = get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX) + self.mp_rank_files = get_files_with_prefix(self.file_list, MODEL_FILE_PREFIX) self.layer_keys = self._get_layer_keys() self.layer_count = len(self.layer_keys) @@ -290,7 +289,7 @@ def _validate_folder(self, dir): file_list = get_files(dir) for file_prefix in [ - MP_RANK_FILE_PREFIX, + MODEL_FILE_PREFIX, LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01' ]: diff --git a/deepspeed/checkpoint/reshape_3d_utils.py b/deepspeed/checkpoint/reshape_3d_utils.py index 5722f49d37ed..b625eb222589 100644 --- a/deepspeed/checkpoint/reshape_3d_utils.py +++ b/deepspeed/checkpoint/reshape_3d_utils.py @@ -1,9 +1,6 @@ -from .reshape_utils import (get_files, - get_files_with_prefix, - ZERO_FILE_PREFIX, - MP_RANK_FILE_PREFIX, - LAYER_FILE_PREFIX, - partition_data) +from .reshape_utils import (get_files, get_files_with_prefix, partition_data) + +from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX) from .reshape_meg_2d import (reshape_meg_2d_parallel, meg_2d_parallel_map) @@ -74,7 +71,7 @@ def can_reshape(self, target_3d_desc): def get_model_3d_descriptor(dir): file_list = get_files(dir) tp_degree = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01')) - pp_degree = len(get_files_with_prefix(file_list, MP_RANK_FILE_PREFIX)) // tp_degree + pp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) // tp_degree num_zero_files = len(get_files_with_prefix(file_list, ZERO_FILE_PREFIX)) dp_degree = max(1, num_zero_files // (pp_degree * tp_degree)) return model_3d_desc(pp_degree, tp_degree, dp_degree) diff --git a/deepspeed/checkpoint/reshape_utils.py b/deepspeed/checkpoint/reshape_utils.py index c4b02434f77b..5c3a687967be 100644 --- a/deepspeed/checkpoint/reshape_utils.py +++ b/deepspeed/checkpoint/reshape_utils.py @@ -2,10 +2,6 @@ import torch from collections import OrderedDict -ZERO_FILE_PREFIX = 'zero_pp_rank_' -LAYER_FILE_PREFIX = 'layer_' -MP_RANK_FILE_PREFIX = 'mp_rank_' - def basic_folder_validation(dir): assert os.path.exists(dir), f'{dir} path does not exist' diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py index 55e4a29a9c00..cb25e524a201 100644 --- a/deepspeed/checkpoint/utils.py +++ b/deepspeed/checkpoint/utils.py @@ -1,8 +1,8 @@ import os -from deepspeed.checkpoint.constants import (MODEL_FILE_PREFIX, - MODEL_FILE_SUFFIX, - OPTIM_FILE_SUFFIX, - ZERO_FILE_PREFIX) +from .constants import (MODEL_FILE_PREFIX, + MODEL_FILE_SUFFIX, + OPTIM_FILE_SUFFIX, + ZERO_FILE_PREFIX) def get_model_ckpt_name_for_rank(base_folder, mp_rank_str): diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py index 27035bc1d324..01a6ebe9c1d9 100644 --- a/deepspeed/checkpoint/zero_checkpoint.py +++ b/deepspeed/checkpoint/zero_checkpoint.py @@ -1,14 +1,15 @@ import torch -from deepspeed.checkpoint.constants import (BASE_OPTIMIZER_STATE, - GROUP_PADDINGS, - OPTIMIZER_STATE_DICT, - PARTITION_COUNT) +from .constants import (BASE_OPTIMIZER_STATE, + GROUP_PADDINGS, + OPTIMIZER_STATE_DICT, + PARTITION_COUNT, + ZERO_FILE_PREFIX, + BF16_ZERO_FILE_PREFIX) from .reshape_utils import (basic_folder_validation, get_files, get_files_with_prefix, - ZERO_FILE_PREFIX, merge_state) from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor) @@ -20,8 +21,9 @@ class ZeROCheckpoint(object): def __init__(self, dir): 
basic_folder_validation(dir) self.dir = dir - self.file_list = get_files_with_prefix(get_files(dir), ZERO_FILE_PREFIX) + self.file_list = self._get_zero_files(dir) self.num_files = len(self.file_list) + assert self.num_files > 0, f'No ZeRO files found in {dir}' self.src_3d = get_model_3d_descriptor(dir) self.target_3d = model_3d_desc(pp_degree=self.src_3d.pp_degree, @@ -135,3 +137,10 @@ def _update_partition_count(self, sd): num_groups = len(partition_counts) sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree ] * num_groups + + def _get_zero_files(self, dir): + file_list = get_files(dir) + zero_files = get_files_with_prefix(file_list, ZERO_FILE_PREFIX) + if len(zero_files) > 0: + return zero_files + return get_files_with_prefix(file_list, BF16_ZERO_FILE_PREFIX) From fd1a377f2f453f168cdd6dfa057d9ce4d660bed1 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 21 May 2022 00:29:08 +0500 Subject: [PATCH 12/31] Add param slice mappings --- deepspeed/checkpoint/constants.py | 1 + deepspeed/checkpoint/deepspeed_checkpoint.py | 13 ++++++++++++ deepspeed/runtime/bf16_optimizer.py | 21 +++++++++++++++++++- deepspeed/runtime/engine.py | 1 + 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index d10d0841c70a..87cb365a3072 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -15,6 +15,7 @@ PARTITION_COUNT = 'partition_count' ZERO_STAGE = 'zero_stage' CLIP_GRAD = 'clip_grad' +PARAM_SLICE_MAPPINGS = 'param_slice_mappings' ######################################### # Module checkpoint keys diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py index 4c32aefa3a99..85614609614b 100644 --- a/deepspeed/checkpoint/deepspeed_checkpoint.py +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -129,6 +129,11 @@ def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict: dp_index=dp_index, keys_to_ignore=[PARAM_SHAPES]) + def get_zero_files(self, pp_index, tp_index, dp_index) -> list: + return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index, + tp_index=tp_index, + dp_index=dp_index) + def get_embedding_layer_id(self): return self.layer_keys[EMBEDDING_LAYER_INDEX] @@ -152,6 +157,10 @@ def get_embedding_state(self, tp_index: int) -> Dict: sd = self._merge_state_dicts(sd_list) return sd + def get_embedding_files(self, tp_index: int) -> list: + assert tp_index in self.tp_to_embedding_map.keys() + return self.tp_to_embedding_map[tp_index] + def _get_checkpoint_value(self, key): if not key in self.global_state: sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) @@ -206,6 +215,10 @@ def get_final_norm_state(self, tp_index: int) -> Dict: map_location=torch.device('cpu')) return sd + def get_final_norm_files(self, tp_index: int) -> list: + assert tp_index in self.tp_to_final_norm_map.keys() + return self.tp_to_final_norm_map[tp_index] + def _build_tp_other_layer_map(self, layer_index: int): assert layer_index < len(self.layer_files) layer_files = get_files_with_prefix(self.layer_files, diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 433bb1729b96..d9c647a970fd 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -1,3 +1,4 @@ +from typing import OrderedDict import torch import torch.distributed as dist from deepspeed.runtime.constants import PIPE_REPLICATED @@ -20,7 +21,8 @@ BASE_OPTIMIZER_STATE, 
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, - GROUP_PADDINGS) + GROUP_PADDINGS, + PARAM_SLICE_MAPPINGS) import types @@ -53,6 +55,9 @@ def get_optim_state_fragment(self, key): else: raise ValueError(f'{key} not found in optimizer state fragment') + def get_hp_fragment_address(self): + return self.hp_fragment_address + def get_full_hp_param(self, optim_state_key=None): reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() @@ -75,6 +80,7 @@ def get_full_hp_param(self, optim_state_key=None): class BF16_Optimizer(ZeROOptimizer): def __init__(self, init_optimizer, + param_names, mpu=None, clip_grad=0.0, norm_type=2, @@ -85,6 +91,7 @@ def __init__(self, see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers self.optimizer = init_optimizer + self.param_names = param_names self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim) self.clip_grad = clip_grad @@ -218,6 +225,17 @@ def _setup_for_real_optimizer(self): # Need optimizer states initialized before linking lp to optimizer state self._link_all_hp_params() + self._param_slice_mappings = self._create_param_mapping() + + def _create_param_mapping(self): + param_mapping = OrderedDict() + for i, _ in enumerate(self.optimizer.param_groups): + for lp in self.bf16_groups[i]: + if lp._hp_mapping is not None: + lp_name = self.param_names[lp] + param_mapping[lp_name] = lp._hp_mapping.get_hp_fragment_address() + + return param_mapping def _link_all_hp_params(self): dp_world_size = dist.get_world_size(group=self.dp_process_group) @@ -455,6 +473,7 @@ def state_dict(self): state_dict[GROUP_PADDINGS] = self.group_paddings state_dict[PARTITION_COUNT] = self.partition_count state_dict[DS_VERSION] = version + state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings return state_dict diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index ec75325edf8f..c0e901c49a60 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1334,6 +1334,7 @@ def _configure_bf16_optimizer(self, optimizer): timers = self.timers if self.wall_clock_breakdown() else None optimizer = BF16_Optimizer( optimizer, + self.param_names, mpu=self.mpu, clip_grad=clip_grad, allgather_bucket_size=self.zero_allgather_bucket_size(), From 10083db7f40be36a724de310369fdce05cfb1f60 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 24 May 2022 22:40:36 +0500 Subject: [PATCH 13/31] Load universal checkpoints --- deepspeed/checkpoint/constants.py | 1 + deepspeed/runtime/bf16_optimizer.py | 78 ++++++++++++++++++++++++++++- deepspeed/runtime/config.py | 3 ++ deepspeed/runtime/constants.py | 8 ++- deepspeed/runtime/engine.py | 27 +++++++--- 5 files changed, 108 insertions(+), 9 deletions(-) diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 87cb365a3072..febd908aa5fa 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -16,6 +16,7 @@ ZERO_STAGE = 'zero_stage' CLIP_GRAD = 'clip_grad' PARAM_SLICE_MAPPINGS = 'param_slice_mappings' +FP32_WEIGHT_KEY = "fp32" ######################################### # Module checkpoint keys diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index d9c647a970fd..09192ed876dd 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -1,5 +1,11 @@ +""" +Copyright 2022 The Microsoft DeepSpeed Team +""" + from typing import OrderedDict +from scipy.fft import dst import torch +import os import torch.distributed as dist from 
deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.ops.op_builder import UtilsBuilder @@ -22,7 +28,8 @@ SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, - PARAM_SLICE_MAPPINGS) + PARAM_SLICE_MAPPINGS, + FP32_WEIGHT_KEY) import types @@ -58,6 +65,9 @@ def get_optim_state_fragment(self, key): def get_hp_fragment_address(self): return self.hp_fragment_address + def get_optim_state_keys(self): + return list(self.optim_fragment.keys()) + def get_full_hp_param(self, optim_state_key=None): reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() @@ -77,6 +87,35 @@ def get_full_hp_param(self, optim_state_key=None): return reduce_buffer.reshape_as(self) +def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): + hp_mapping = self._hp_mapping + optim_state_keys = hp_mapping.get_optim_state_keys() + hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys + checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys} + + for file in checkpoint_files.values(): + assert os.path.isfile(file), f'{file} is not a valid file' + + for key in hp_keys: + ckpt_file = checkpoint_files[key] + full_hp_param = torch.load(ckpt_file) + full_param_numel = full_hp_param.numel() + tp_slice_numel = self.numel() + assert full_param_numel == tp_world_size * tp_slice_numel, \ + f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}' + tp_start_offset = tp_rank * tp_slice_numel + dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment( + key) + assert dst_tensor.numel() == tp_slice_numel, \ + f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {tp_slice_numel}' + + tp_hp_slice = torch.narrow(full_hp_param.view(dst_tensor.shape), + 0, + tp_start_offset, + tp_slice_numel) + dst_tensor.data.copy_(tp_hp_slice.data) + + class BF16_Optimizer(ZeROOptimizer): def __init__(self, init_optimizer, @@ -262,6 +301,9 @@ def _init_lp_to_hp_mapping(self, lp_param._hp_mapping = None lp_param._dp_group = dp_group lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param) + lp_param.load_hp_checkpoint_state = types.MethodType( + load_hp_checkpoint_state, + lp_param) # lp_param overlaps with partition if both are true # 1) current_offset < partition_end, # 2) current_offset + lp_param.numel() >= partition_start @@ -489,8 +531,23 @@ def refresh_fp32_params(self): def load_state_dict(self, state_dict_list, + checkpoint_folder, load_optimizer_states=True, load_from_fp32_weights=False): + if checkpoint_folder: + self._load_universal_checkpoint(checkpoint_folder, + load_optimizer_states, + load_from_fp32_weights) + else: + self._load_legacy_checkpoint(state_dict_list, + load_optimizer_states, + load_from_fp32_weights) + + def _load_legacy_checkpoint(self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False): + dp_rank = dist.get_rank(group=self.dp_process_group) current_rank_sd = state_dict_list[dp_rank] @@ -511,11 +568,30 @@ def load_state_dict(self, if load_optimizer_states: self._link_all_hp_params() + def _load_universal_checkpoint(self, + checkpoint_folder, + load_optimizer_states, + load_from_fp32_weights): + self._load_hp_checkpoint_state(checkpoint_folder) + @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups + def _load_hp_checkpoint_state(self, checkpoint_dir): + tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) + tp_world_size 
= self.mpu.get_model_parallel_world_size() + + for i, _ in enumerate(self.optimizer.param_groups): + for lp in self.bf16_groups[i]: + if lp._hp_mapping is not None: + lp.load_hp_checkpoint_state( + os.path.join(checkpoint_dir, + self.param_names[lp]), + tp_rank, + tp_world_size) + def _get_padded_tensor(src_tensor, size): if src_tensor.numel() >= size: diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 1df5912ef172..56777b0aa3ea 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -972,6 +972,9 @@ def _initialize_params(self, param_dict): self.checkpoint_tag_validation_enabled = (validation_mode != ValidationMode.IGNORE) self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL + self.load_universal_checkpoint = checkpoint_params.get( + LOAD_UNIVERSAL_CHECKPOINT, + LOAD_UNIVERSAL_CHECKPOINT_DEFAULT) self.aio_config = get_aio_config(param_dict) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index ee2e51c6109f..8742a3fd7610 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -387,7 +387,10 @@ class ValidationMode: ######################################### # Checkpoint config params ######################################### -# "checkpoint": {tag_validation=["Ignore"|"Warn"|"Fail"]} +# "checkpoint": { +# tag_validation=["Ignore"|"Warn"|"Fail"] +# load_universal=false +# } CHECKPOINT = "checkpoint" CHECKPOINT_TAG_VALIDATION = "tag_validation" CHECKPOINT_TAG_VALIDATION_DEFAULT = ValidationMode.WARN @@ -397,6 +400,9 @@ class ValidationMode: ValidationMode.FAIL ] +LOAD_UNIVERSAL_CHECKPOINT = "load_universal" +LOAD_UNIVERSAL_CHECKPOINT_DEFAULT = False + ######################################### # Quantization ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index c0e901c49a60..f7a17d08795b 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -730,6 +730,9 @@ def loss_scale(self): def gradient_accumulation_steps(self): return self._config.gradient_accumulation_steps + def load_universal_checkpoint(self): + return self._config.load_universal_checkpoint + @property def communication_data_type(self): res = self._config.communication_data_type @@ -2674,18 +2677,28 @@ def get_sparse_tensor_module_names(original_set, return load_path, client_state def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): - zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) - if zero_sd_list is None: - return False + if self.load_universal_checkpoint(): + zero_sd_list = None + checkpoint_folder = f'{os.path.join(load_dir, tag)}_universal' + else: + checkpoint_folder = None + zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) + if zero_sd_list is None: + return False self.optimizer.load_state_dict( state_dict_list=zero_sd_list, load_optimizer_states=load_optimizer_states, load_from_fp32_weights=self.zero_load_from_fp32_weights(), - ) - logger.info( - f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}" - ) + checkpoint_folder=checkpoint_folder) + if self.load_universal_checkpoint(): + logger.info( + f'loaded universal zero checpoints from {checkpoint_folder} for rank {self.global_rank}' + ) + else: + logger.info( + f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}" + ) return True def _get_mp_rank_zero_checkpoint_names(self, From 22c755058e89fedc6cb510f47d3cc2fc709d69fd Mon Sep 17 00:00:00 2001 From: 
Olatunji Ruwase Date: Wed, 25 May 2022 18:40:56 +0500 Subject: [PATCH 14/31] Per group mappings from Stas --- deepspeed/checkpoint/deepspeed_checkpoint.py | 8 +++++++- deepspeed/runtime/bf16_optimizer.py | 8 +++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py index 85614609614b..4b8d31e832d7 100644 --- a/deepspeed/checkpoint/deepspeed_checkpoint.py +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -253,7 +253,11 @@ def _dump_mapping(self, data_map, map_tag=None): def _build_transformer_file_map(self): transformer_layer_keys = self.layer_keys[1:-1] file_map = {} + # XXX: this is not guaranteed layers_per_pp = len(transformer_layer_keys) // self.pp_degree + if layers_per_pp == 0: + layers_per_pp = 1 + #print(f"{transformer_layer_keys} {layers_per_pp}") for key_index, layer_key in enumerate(transformer_layer_keys): pp_index = key_index // layers_per_pp layer_files = get_files_with_prefix(self.layer_files, layer_key) @@ -270,7 +274,9 @@ def _sanity_check(self): assert len(self.mp_rank_files) % self.tp_degree == 0 assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0 assert len(self.layer_keys) > 2 - assert (len(self.layer_keys) - 2) % self.pp_degree == 0 + # XXX: fix me - isn't always the case + # only true with --pp-partition-method 'type:transformer|embedding' \ + # assert (len(self.layer_keys) - 2) % self.pp_degree == 0 def validate_files(self): for file in self.file_list: diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 09192ed876dd..27e10a7ab9b7 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -3,7 +3,6 @@ """ from typing import OrderedDict -from scipy.fft import dst import torch import os import torch.distributed as dist @@ -267,12 +266,15 @@ def _setup_for_real_optimizer(self): self._param_slice_mappings = self._create_param_mapping() def _create_param_mapping(self): - param_mapping = OrderedDict() + param_mapping = [] for i, _ in enumerate(self.optimizer.param_groups): + param_mapping_per_group = OrderedDict() for lp in self.bf16_groups[i]: if lp._hp_mapping is not None: lp_name = self.param_names[lp] - param_mapping[lp_name] = lp._hp_mapping.get_hp_fragment_address() + param_mapping_per_group[ + lp_name] = lp._hp_mapping.get_hp_fragment_address() + param_mapping.append(param_mapping_per_group) return param_mapping From 5df4135cc4ad493217a902bb1c48b9ca1ef275a3 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 27 May 2022 01:56:24 +0500 Subject: [PATCH 15/31] Hack to load bf16 zero files --- deepspeed/checkpoint/constants.py | 4 ++-- deepspeed/runtime/engine.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index febd908aa5fa..dc79df643af2 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -28,11 +28,11 @@ # Checkpoint naming constants ######################################### MODEL_FILE_PREFIX = 'mp_rank_' -ZERO_FILE_PREFIX = 'zero_pp_rank_' +ZERO_FILE_PREFIX = 'bf16_' + 'zero_pp_rank_' OPTIM_FILE_SUFFIX = '_optim_states.pt' MODEL_FILE_SUFFIX = '_model_states.pt' LAYER_FILE_PREFIX = 'layer_' -BF16_ZERO_FILE_PREFIX = 'bf16_' + ZERO_FILE_PREFIX +BF16_ZERO_FILE_PREFIX = ZERO_FILE_PREFIX ######################################### # Checkpoint utility keys diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 
f7a17d08795b..0c56a61dab56 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2576,10 +2576,10 @@ def _load_checkpoint(self, model=self.module, mpu=self.mpu, num_experts=self.num_experts) - - self.load_module_state_dict(state_dict=checkpoint['module'], - strict=load_module_strict, - custom_load_fn=custom_load_fn) + if not self.load_universal_checkpoint(): + self.load_module_state_dict(state_dict=checkpoint['module'], + strict=load_module_strict, + custom_load_fn=custom_load_fn) self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] From ae2825fd3020e3c61c83a5bfac68a5c440d01923 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 1 Jun 2022 04:10:02 +0500 Subject: [PATCH 16/31] Param attributes --- deepspeed/runtime/bf16_optimizer.py | 100 +++++++++++++++++++++++++--- 1 file changed, 90 insertions(+), 10 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 27e10a7ab9b7..646c3dde609c 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -92,27 +92,105 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys} + # XXX: hack to fix + # need to codifying handling of non-parametric tensors + # we are still trying to load a param which is not trained and is created at run time + if "tied_modules.embed.position_embeddings" in folder: + return + # perhaps just check if the file exists and if not return? but this may mask a potential error + # and random weights will be used instead + for file in checkpoint_files.values(): assert os.path.isfile(file), f'{file} is not a valid file' + # need to deal with slices that were averaged. I thought of 2 ways: + # a. find a way for a client to pass a dict with patterns + # b. see below inside the loop + # XXX: the opposite of averaging here becomes an exact copy of the first slice + # implementation a. + # if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS): + # tp_rank = 0 + # tp_world_size = 1 + # the other approach is to assume that the saved data is correct and if full_hp_param.shape == + # self.shape that means we automatically copy? + for key in hp_keys: ckpt_file = checkpoint_files[key] - full_hp_param = torch.load(ckpt_file) + ckpt_dict = torch.load(ckpt_file) + full_hp_param = ckpt_dict['param'] + + # implementation b. (see notes outside of the loop) + # this version requires no additional data passed from the client + # if the shapes already match it must be slices that were averaged - so we just hack around those + if full_hp_param.shape == self.shape: + tp_rank = 0 + tp_world_size = 1 + + # special case for word_embeddings weights which get padded differently depending on TP degree. + # the converter to universal currently strips the original padding completely so the saved + # weight is padding-free and we just need to add new padding depending on the target TP + # degree + tensor_to_pad = ckpt_dict.get('tensor_to_pad', False) + if tensor_to_pad: + # if "word_embeddings.weight" in folder: + + # print(f"Before {full_hp_param.shape=}") + # XXX: simply reshape to the self.shape*tp_degree? + # or how do we bring the new padded vocab size here? 
+ # pad_to = (50257+pad) * 512 # * tp_world_size # 50432 + # pad_to = 50432 + # target = torch.zeros(pad_to, full_hp_param.shape[1]) + # target[:50257, :] = full_hp_param[:50257, :] + # full_hp_param = target + # print(f"After {full_hp_param.shape=}") + + # In the absense of data passed from the user wrt new padded vocab specific to tp degree + # we can again derive that data by reverse engineering the target shapes like so: + target = torch.zeros(self.shape[0] * tp_world_size, self.shape[1]) + # this relies on making sure the padding was stripped when the universal checkpoint was created + target[:full_hp_param.shape[0], :] = full_hp_param[:full_hp_param.shape[0], :] + full_hp_param = target + # print(f"After {full_hp_param.shape=}") + full_param_numel = full_hp_param.numel() tp_slice_numel = self.numel() + assert full_param_numel == tp_world_size * tp_slice_numel, \ f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}' - tp_start_offset = tp_rank * tp_slice_numel dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment( key) - assert dst_tensor.numel() == tp_slice_numel, \ - f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {tp_slice_numel}' - tp_hp_slice = torch.narrow(full_hp_param.view(dst_tensor.shape), - 0, - tp_start_offset, - tp_slice_numel) - dst_tensor.data.copy_(tp_hp_slice.data) + print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}") + print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}") + + # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse + + # of course, we again need to somehow get the names which we don't have + #if "dense_4h_to_h.weight" in folder or "self_attention.dense.weight" in folder: + # chunk_dim = 1 + #else: + # chunk_dim = 0 + + chunk_dim = ckpt_dict.get('cat_dim', 0) + + # this performs the opposite of cat when merging TP slices + tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank] + tp_hp_slice = tp_hp_slice.flatten() + + # this deals with zero fragments when DP>1 + # XXX: I'm not sure this is correct but the direction is right + # I'm not sure the shard should always start with 0 + lp_frag_address = hp_mapping.lp_fragment_address + tp_hp_fragment = tp_hp_slice.narrow(0, + lp_frag_address.start, + lp_frag_address.numel) + assert dst_tensor.numel() == lp_frag_address.numel, \ + f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}' + + print(f"{key} SHAPE: {tp_hp_slice.shape=}") + print(f"{key} SHAPE: {dst_tensor.shape=}") + print(f"{key} SHAPE: {tp_hp_fragment.shape=}") + dst_tensor.data.copy_(tp_hp_fragment.data) class BF16_Optimizer(ZeROOptimizer): @@ -583,11 +661,13 @@ def param_groups(self): def _load_hp_checkpoint_state(self, checkpoint_dir): tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) - tp_world_size = self.mpu.get_model_parallel_world_size() + tp_world_size = self.mpu.get_slice_parallel_world_size() + # get_model_parallel_world_size() for i, _ in enumerate(self.optimizer.param_groups): for lp in self.bf16_groups[i]: if lp._hp_mapping is not None: + print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") lp.load_hp_checkpoint_state( os.path.join(checkpoint_dir, self.param_names[lp]), From d11a8dc24701022a25c53d7d6d9185cef8311af2 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 2 Jun 2022 03:42:12 +0500 Subject: [PATCH 
17/31] WIP --- deepspeed/runtime/bf16_optimizer.py | 3 ++- deepspeed/runtime/engine.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 646c3dde609c..498ae759c1a5 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -144,7 +144,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): # full_hp_param = target # print(f"After {full_hp_param.shape=}") - # In the absense of data passed from the user wrt new padded vocab specific to tp degree + # In the absence of data passed from the user wrt new padded vocab specific to tp degree # we can again derive that data by reverse engineering the target shapes like so: target = torch.zeros(self.shape[0] * tp_world_size, self.shape[1]) # this relies on making sure the padding was stripped when the universal checkpoint was created @@ -660,6 +660,7 @@ def param_groups(self): return self.optimizer.param_groups def _load_hp_checkpoint_state(self, checkpoint_dir): + checkpoint_dir = os.path.join(checkpoint_dir, "zero") tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) tp_world_size = self.mpu.get_slice_parallel_world_size() # get_model_parallel_world_size() diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 0c56a61dab56..225368bd8eca 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2499,7 +2499,9 @@ def load_checkpoint(self, """ if tag is None: - latest_path = os.path.join(load_dir, "latest") + latest_tag = "latest_universal" if self.load_universal_checkpoint( + ) else "latest" + latest_path = os.path.join(load_dir, latest_tag) if os.path.isfile(latest_path): with open(latest_path, "r") as fd: tag = fd.read().strip() @@ -2679,7 +2681,7 @@ def get_sparse_tensor_module_names(original_set, def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): if self.load_universal_checkpoint(): zero_sd_list = None - checkpoint_folder = f'{os.path.join(load_dir, tag)}_universal' + checkpoint_folder = f'{os.path.join(load_dir, tag)}' else: checkpoint_folder = None zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) From 691b29d1804da61d65dea0ab85d9a3d8a8f25a1a Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 3 Jun 2022 00:46:10 +0500 Subject: [PATCH 18/31] Fix api bug --- deepspeed/runtime/zero/stage3.py | 3 ++- deepspeed/runtime/zero/stage_1_and_2.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index e963ef643677..6f1cbb1d0c61 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2812,7 +2812,8 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): def load_state_dict(self, state_dict_list, load_optimizer_states=True, - load_from_fp32_weights=False): + load_from_fp32_weights=False, + checkpoint_folder=None): r"""Loading a ZeRO checkpoint Arguments: state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. 
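The signature change above matters because, since PATCH 13, the engine forwards a `checkpoint_folder` keyword to whatever ZeRO optimizer is active, so the stage 1/2/3 optimizers must accept it even though only `BF16_Optimizer` actually uses it. A minimal sketch of that dispatch, using hypothetical names (`_LegacyZeROOptimizer`, `_call_from_engine`) rather than DeepSpeed source, assuming the behavior shown in these hunks:

    # Minimal sketch, not DeepSpeed code: why stage 1/2/3 load_state_dict now
    # takes `checkpoint_folder` even though these stages ignore it.
    class _LegacyZeROOptimizer:
        def load_state_dict(self,
                            state_dict_list,
                            load_optimizer_states=True,
                            load_from_fp32_weights=False,
                            checkpoint_folder=None):  # accepted for API parity, unused here
            # legacy path: restore from the per-rank state_dict_list as before
            for rank_sd in state_dict_list or []:
                pass  # placeholder for the existing per-partition restore logic

    def _call_from_engine(optimizer, zero_sd_list, checkpoint_folder):
        # The engine passes checkpoint_folder unconditionally; it is None unless
        # load_universal_checkpoint() is enabled in the config.
        optimizer.load_state_dict(state_dict_list=zero_sd_list,
                                  load_optimizer_states=True,
                                  load_from_fp32_weights=False,
                                  checkpoint_folder=checkpoint_folder)
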
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 29cff85a69ab..c0492edf990b 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2124,7 +2124,8 @@ def _restore_elastic_base_optimizer_state(self, all_state_dict): def load_state_dict(self, state_dict_list, load_optimizer_states=True, - load_from_fp32_weights=False): + load_from_fp32_weights=False, + checkpoint_folder=None): r"""Loading ZeRO checkpoint Arguments: From c0a42d360284c6370890dfec6856e4acc5991d9f Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 3 Jun 2022 02:50:06 +0500 Subject: [PATCH 19/31] Update lp with local/remote hp --- deepspeed/runtime/bf16_optimizer.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 498ae759c1a5..d1f181837338 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -12,6 +12,7 @@ from packaging import version as pkg_version from deepspeed.git_version_info import version +from deepspeed.runtime.swap_tensor.partitioned_param_swapper import print_rank_0 from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim, @@ -154,6 +155,9 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): full_param_numel = full_hp_param.numel() tp_slice_numel = self.numel() + if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder: + print_rank_0(f'{full_hp_param[:10]=}', force=True) + assert full_param_numel == tp_world_size * tp_slice_numel, \ f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}' @@ -503,11 +507,6 @@ def step(self, closure=None): self.update_lp_params() - all_gather_dp_groups(partitioned_param_groups=self.bf16_partitioned_groups, - dp_process_group=self.real_dp_process_group, - start_alignment_factor=self.nccl_start_alignment_factor, - allgather_bucket_size=self.allgather_bucket_size) - self.clear_hp_grads() self.step_count += 1 @@ -574,6 +573,14 @@ def update_lp_params(self): for i, (bf16_partitions, fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) bf16_partitions[partition_id].data.copy_(fp32_partition.data) + # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True) + # if i == 0: + # print_rank_0(f'{fp32_partition[:10]=}', force=True) + + all_gather_dp_groups(partitioned_param_groups=self.bf16_partitioned_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) def clear_hp_grads(self): for flat_gradients in self.fp32_groups_gradients_flat: From b4ca4556e1f6eeaa3260e00ae949b56d32afc272 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 4 Jun 2022 05:53:06 +0500 Subject: [PATCH 20/31] Disable vocab padding handling --- deepspeed/runtime/bf16_optimizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index d1f181837338..5b9880bbbddc 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -132,7 +132,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): # weight is padding-free and we just need to add new padding 
depending on the target TP # degree tensor_to_pad = ckpt_dict.get('tensor_to_pad', False) - if tensor_to_pad: + # TODO: How to handle vocab paddings for TP reshapes + if False and tensor_to_pad: # if "word_embeddings.weight" in folder: # print(f"Before {full_hp_param.shape=}") From b8b54c8335fae60f26e07acd897baa706294af0f Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 6 Jun 2022 20:31:03 +0500 Subject: [PATCH 21/31] Update z2 checkpoint --- deepspeed/runtime/zero/stage_1_and_2.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index c0492edf990b..700331c8d8ae 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2007,13 +2007,21 @@ def state_dict(self): state_dict['overflow'] = self.overflow state_dict[CLIP_GRAD] = self.clip_grad - state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() - state_dict[ - SINGLE_PARTITION_OF_FP32_GROUPS] = self.single_partition_of_fp32_groups + if self.elastic_checkpoint: + state_dict[BASE_OPTIMIZER_STATE] = self._get_base_optimizer_state() + else: + state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() + + # Remove paddings for DP alignment to enable loading for other alignment values + fp32_groups_without_padding = self._get_groups_without_padding( + self.single_partition_of_fp32_groups) + state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding + state_dict[ ZERO_STAGE] = ZERO_OPTIMIZATION_GRADIENTS if self.partition_gradients else ZERO_OPTIMIZATION_OPTIMIZER_STATES state_dict[GROUP_PADDINGS] = self.groups_padding state_dict[PARTITION_COUNT] = self.partition_count + state_dict[DS_VERSION] = version return state_dict From be86df9b9cdad7336c020bc03b2a6db66173966d Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 6 Jun 2022 22:12:46 +0500 Subject: [PATCH 22/31] Remove debug prints --- deepspeed/runtime/bf16_optimizer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 5b9880bbbddc..0fe6b0b65c38 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -156,8 +156,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): full_param_numel = full_hp_param.numel() tp_slice_numel = self.numel() - if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder: - print_rank_0(f'{full_hp_param[:10]=}', force=True) + # if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder: + # print_rank_0(f'{full_hp_param[:10]=}', force=True) assert full_param_numel == tp_world_size * tp_slice_numel, \ @@ -165,8 +165,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment( key) - print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}") - print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}") + # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}") + # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}") # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse @@ -192,9 +192,9 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): assert dst_tensor.numel() == lp_frag_address.numel, \ f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel 
{lp_frag_address.numel}' - print(f"{key} SHAPE: {tp_hp_slice.shape=}") - print(f"{key} SHAPE: {dst_tensor.shape=}") - print(f"{key} SHAPE: {tp_hp_fragment.shape=}") + # print(f"{key} SHAPE: {tp_hp_slice.shape=}") + # print(f"{key} SHAPE: {dst_tensor.shape=}") + # print(f"{key} SHAPE: {tp_hp_fragment.shape=}") dst_tensor.data.copy_(tp_hp_fragment.data) From c87543b75e29b7fd68c546d3dfb9ad5e54df8f3f Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 6 Jun 2022 22:25:08 +0500 Subject: [PATCH 23/31] Remove debug prints; Rebase unit test --- deepspeed/runtime/bf16_optimizer.py | 3 +-- tests/unit/test_checkpointing.py | 33 ++++------------------------- 2 files changed, 5 insertions(+), 31 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 0fe6b0b65c38..6bf00bec73e4 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -671,12 +671,11 @@ def _load_hp_checkpoint_state(self, checkpoint_dir): checkpoint_dir = os.path.join(checkpoint_dir, "zero") tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) tp_world_size = self.mpu.get_slice_parallel_world_size() - # get_model_parallel_world_size() for i, _ in enumerate(self.optimizer.param_groups): for lp in self.bf16_groups[i]: if lp._hp_mapping is not None: - print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") + #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") lp.load_hp_checkpoint_state( os.path.join(checkpoint_dir, self.param_names[lp]), diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index 87fbe9df180a..9a2f99e0104b 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -1192,10 +1192,6 @@ def _go(): loss = model(batch[0], batch[1]) model.backward(loss) model.step() - if load_optim: - torch.save(model.optimizer.optimizer.state_dict(), - os.path.join(tmpdir, - 'opt-state-dict')) model.save_checkpoint(tmpdir) ds_config["zero_optimization"]["elastic_checkpoint"] = elastic_load @@ -1203,12 +1199,6 @@ def _go(): model=models[1], model_parameters=models[1].parameters()) model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) - - if load_optim: - saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) - curr_sd = model.optimizer.optimizer.state_dict() - assert curr_sd['param_groups'] == saved_sd['param_groups'] - data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, @@ -1264,11 +1254,6 @@ def _go2(models): loss = model(batch[0], batch[1]) model.backward(loss) model.step() - - if load_optim: - torch.save(model.optimizer.optimizer.state_dict(), - os.path.join(tmpdir, - 'opt-state-dict')) model.save_checkpoint(tmpdir) _go2(models) @@ -1279,21 +1264,11 @@ def _go1(models): model, _, _, _ = deepspeed.initialize(config=ds_config, model=models[1], model_parameters=models[1].parameters()) - model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) - if load_optim: - saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) - curr_sd = model.optimizer.optimizer.state_dict() - assert curr_sd['param_groups'] == saved_sd['param_groups'] - - data_loader = random_dataloader(model=model, - total_samples=8, - hidden_dim=hidden_dim, - device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() + with pytest.raises(deepspeed.runtime.zero.utils.ZeRORuntimeException): + model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) + else: + 
model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) _go1(models) From c18ff2d094fc6c17beca039eb1abd86d1f9fc70a Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 6 Jun 2022 22:43:22 +0500 Subject: [PATCH 24/31] Add reshape assert --- deepspeed/runtime/engine.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 33cd6dbbc630..c5c0e9255b2e 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2684,6 +2684,12 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): zero_sd_list = None checkpoint_folder = f'{os.path.join(load_dir, tag)}' else: + if load_optimizer_states and self.dp_world_size != self.loaded_checkpoint_dp_world_size: + raise ZeRORuntimeException("The checkpoint being loaded used a DP " \ + f"world size of {self.loaded_checkpoint_dp_world_size} but the " \ + f"current world size is {self.dp_world_size}. Automatic adjustment " \ + "of ZeRO's optimizer state partitioning with a new world size is not " \ + "currently supported.") checkpoint_folder = None zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) if zero_sd_list is None: From 4ea36b74c3dd5e86f1e4586fc906a15bd0f56b13 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 7 Jun 2022 03:19:14 +0500 Subject: [PATCH 25/31] Padding --- deepspeed/runtime/bf16_optimizer.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 6bf00bec73e4..c539ec3c85df 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -131,9 +131,9 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): # the converter to universal currently strips the original padding completely so the saved # weight is padding-free and we just need to add new padding depending on the target TP # degree - tensor_to_pad = ckpt_dict.get('tensor_to_pad', False) - # TODO: How to handle vocab paddings for TP reshapes - if False and tensor_to_pad: + vocab_divisibility_padding_tensor = ckpt_dict.get('vocab_divisibility_padding_tensor', None) + # TODO: Currently broken for tp reshaping, e.g. 
2 -> 1 + if vocab_divisibility_padding_tensor is not None: # if "word_embeddings.weight" in folder: # print(f"Before {full_hp_param.shape=}") @@ -148,10 +148,22 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): # In the absence of data passed from the user wrt new padded vocab specific to tp degree # we can again derive that data by reverse engineering the target shapes like so: - target = torch.zeros(self.shape[0] * tp_world_size, self.shape[1]) - # this relies on making sure the padding was stripped when the universal checkpoint was created - target[:full_hp_param.shape[0], :] = full_hp_param[:full_hp_param.shape[0], :] - full_hp_param = target + padded_target_vocab_size = self.shape[0] * tp_world_size + if padded_target_vocab_size > full_hp_param.shape[0]: + # Need to expand + padding_tensor = vocab_divisibility_padding_tensor.expand(padded_target_vocab_size-full_hp_param.shape[0]) + # Implement the following concat in efficient way using pad + #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0) + full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_tensor.shape[0]), "constant", 0) + full_hp_param[:-padding_tensor.shape[0],:] = padding_tensor + else: + # Need to shrink or keep the same + full_hp_param = full_hp_param[:padded_target_vocab_size, :] + + # target = torch.zeros(self.shape[0] * tp_world_size, self.shape[1]) + # # this relies on making sure the padding was stripped when the universal checkpoint was created + # target[:full_hp_param.shape[0], :] = full_hp_param[:full_hp_param.shape[0], :] + # full_hp_param = target # print(f"After {full_hp_param.shape=}") full_param_numel = full_hp_param.numel() From 03715817dddf9f91a6cc0dae5a38d9805eabb7b9 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 7 Jun 2022 03:24:02 +0500 Subject: [PATCH 26/31] Typo --- deepspeed/runtime/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index c5c0e9255b2e..d0f1c9a091b1 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2702,7 +2702,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): checkpoint_folder=checkpoint_folder) if self.load_universal_checkpoint(): logger.info( - f'loaded universal zero checpoints from {checkpoint_folder} for rank {self.global_rank}' + f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}' ) else: logger.info( From a74abc1e63504600272f38731c9db16971058398 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 8 Jun 2022 00:35:03 +0500 Subject: [PATCH 27/31] Catch nonexistent checkpoint path --- deepspeed/runtime/bf16_optimizer.py | 21 +++++++++++++++------ deepspeed/runtime/engine.py | 16 +++++++++++----- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index c539ec3c85df..a060c528d239 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -131,8 +131,10 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): # the converter to universal currently strips the original padding completely so the saved # weight is padding-free and we just need to add new padding depending on the target TP # degree - vocab_divisibility_padding_tensor = ckpt_dict.get('vocab_divisibility_padding_tensor', None) - # TODO: Currently broken for tp reshaping, e.g. 
2 -> 1 + vocab_divisibility_padding_tensor = ckpt_dict.get( + 'vocab_divisibility_padding_tensor', + None) + # TODO: Currently broken for tp reshaping, e.g. 2 -> 1 if vocab_divisibility_padding_tensor is not None: # if "word_embeddings.weight" in folder: @@ -151,11 +153,18 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): padded_target_vocab_size = self.shape[0] * tp_world_size if padded_target_vocab_size > full_hp_param.shape[0]: # Need to expand - padding_tensor = vocab_divisibility_padding_tensor.expand(padded_target_vocab_size-full_hp_param.shape[0]) + padding_tensor = vocab_divisibility_padding_tensor.expand( + padded_target_vocab_size - full_hp_param.shape[0]) # Implement the following concat in efficient way using pad - #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0) - full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_tensor.shape[0]), "constant", 0) - full_hp_param[:-padding_tensor.shape[0],:] = padding_tensor + #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0) + full_hp_param = torch.nn.functional.pad(full_hp_param, + (0, + 0, + 0, + padding_tensor.shape[0]), + "constant", + 0) + full_hp_param[:-padding_tensor.shape[0], :] = padding_tensor else: # Need to shrink or keep the same full_hp_param = full_hp_param[:padded_target_vocab_size, :] diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index d0f1c9a091b1..c93f3424c8ae 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2507,11 +2507,16 @@ def load_checkpoint(self, with open(latest_path, "r") as fd: tag = fd.read().strip() else: - logger.warning( - f"Unable to find latest file at {latest_path}, if trying to load latest " - "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint." - ) - return None, None + if self.load_universal_checkpoint(): + raise ValueError( + f'Invalid for universal checkpoint: {latest_path} does not exist' + ) + else: + logger.warning( + f"Unable to find latest file at {latest_path}, if trying to load latest " + "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint." 
+ ) + return None, None if self.zero_optimization_partition_weights(): # Prepare for checkpoint load by ensuring all parameters are partitioned @@ -2700,6 +2705,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): load_optimizer_states=load_optimizer_states, load_from_fp32_weights=self.zero_load_from_fp32_weights(), checkpoint_folder=checkpoint_folder) + if self.load_universal_checkpoint(): logger.info( f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}' From 529dbaebf275e03f397d8828328ba11ac17cfa5b Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 8 Jun 2022 00:59:31 +0500 Subject: [PATCH 28/31] Cleanup --- deepspeed/runtime/bf16_optimizer.py | 60 +++++------------------------ 1 file changed, 10 insertions(+), 50 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index a060c528d239..485c79e0ea22 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -93,34 +93,24 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys} - # XXX: hack to fix - # need to codifying handling of non-parametric tensors - # we are still trying to load a param which is not trained and is created at run time - if "tied_modules.embed.position_embeddings" in folder: - return - # perhaps just check if the file exists and if not return? but this may mask a potential error - # and random weights will be used instead - for file in checkpoint_files.values(): assert os.path.isfile(file), f'{file} is not a valid file' - # need to deal with slices that were averaged. I thought of 2 ways: - # a. find a way for a client to pass a dict with patterns - # b. see below inside the loop - # XXX: the opposite of averaging here becomes an exact copy of the first slice - # implementation a. - # if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS): - # tp_rank = 0 - # tp_world_size = 1 - # the other approach is to assume that the saved data is correct and if full_hp_param.shape == - # self.shape that means we automatically copy? - for key in hp_keys: ckpt_file = checkpoint_files[key] ckpt_dict = torch.load(ckpt_file) full_hp_param = ckpt_dict['param'] - # implementation b. (see notes outside of the loop) + # need to deal with slices that were averaged. + # the opposite of averaging here becomes an exact copy of the first slice + # I thought of 2 ways: + # implementation a. find a way for a client to pass a dict with patterns + # if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS): + # tp_rank = 0 + # tp_world_size = 1 + # the other approach is to assume that the saved data is correct and if full_hp_param.shape == + # self.shape that means we automatically copy? + # implementation b. # this version requires no additional data passed from the client # if the shapes already match it must be slices that were averaged - so we just hack around those if full_hp_param.shape == self.shape: @@ -134,20 +124,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): vocab_divisibility_padding_tensor = ckpt_dict.get( 'vocab_divisibility_padding_tensor', None) - # TODO: Currently broken for tp reshaping, e.g. 
2 -> 1 if vocab_divisibility_padding_tensor is not None: - # if "word_embeddings.weight" in folder: - - # print(f"Before {full_hp_param.shape=}") - # XXX: simply reshape to the self.shape*tp_degree? - # or how do we bring the new padded vocab size here? - # pad_to = (50257+pad) * 512 # * tp_world_size # 50432 - # pad_to = 50432 - # target = torch.zeros(pad_to, full_hp_param.shape[1]) - # target[:50257, :] = full_hp_param[:50257, :] - # full_hp_param = target - # print(f"After {full_hp_param.shape=}") - # In the absence of data passed from the user wrt new padded vocab specific to tp degree # we can again derive that data by reverse engineering the target shapes like so: padded_target_vocab_size = self.shape[0] * tp_world_size @@ -169,12 +146,6 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): # Need to shrink or keep the same full_hp_param = full_hp_param[:padded_target_vocab_size, :] - # target = torch.zeros(self.shape[0] * tp_world_size, self.shape[1]) - # # this relies on making sure the padding was stripped when the universal checkpoint was created - # target[:full_hp_param.shape[0], :] = full_hp_param[:full_hp_param.shape[0], :] - # full_hp_param = target - # print(f"After {full_hp_param.shape=}") - full_param_numel = full_hp_param.numel() tp_slice_numel = self.numel() # if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder: @@ -190,22 +161,12 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}") # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse - - # of course, we again need to somehow get the names which we don't have - #if "dense_4h_to_h.weight" in folder or "self_attention.dense.weight" in folder: - # chunk_dim = 1 - #else: - # chunk_dim = 0 - chunk_dim = ckpt_dict.get('cat_dim', 0) # this performs the opposite of cat when merging TP slices tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank] tp_hp_slice = tp_hp_slice.flatten() - # this deals with zero fragments when DP>1 - # XXX: I'm not sure this is correct but the direction is right - # I'm not sure the shard should always start with 0 lp_frag_address = hp_mapping.lp_fragment_address tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, @@ -386,7 +347,6 @@ def _link_all_hp_params(self): dp_world_size = dist.get_world_size(group=self.dp_process_group) for i, param_group in enumerate(self.optimizer.param_groups): # Link bf16 and fp32 params in partition - # TODO: Make this configurable partition_id = dist.get_rank(group=self.real_dp_process_group[i]) partition_size = self.bf16_groups_flat[i].numel() // dp_world_size self._link_hp_params(self.bf16_groups[i], From 9e2766fa4f560b41c2cde8a77c246ec256432057 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 10 Jun 2022 05:41:56 +0500 Subject: [PATCH 29/31] Restore checkpoint state comparisons --- tests/unit/test_checkpointing.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index 9a2f99e0104b..b214106e9af7 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -97,7 +97,7 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True): for s0, s1 in zip(state0.values(), state1.values()): if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor): assert id(s0) != id(s1), f'Comparing optimizer state tensor 
against itself: {id(s0)} <====> {id(s1)}' - assert torch.equal(s0, s1) + assert torch.equal(s0.to('cpu'), s1.to('cpu')) else: assert s0 == s1 @@ -1192,6 +1192,10 @@ def _go(): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + if load_optim: + torch.save(model.optimizer.optimizer.state_dict(), + os.path.join(tmpdir, + 'opt-state-dict')) model.save_checkpoint(tmpdir) ds_config["zero_optimization"]["elastic_checkpoint"] = elastic_load @@ -1199,6 +1203,11 @@ def _go(): model=models[1], model_parameters=models[1].parameters()) model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) + + if load_optim: + saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) + curr_sd = model.optimizer.optimizer.state_dict() + assert curr_sd['param_groups'] == saved_sd['param_groups'] data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, @@ -1254,6 +1263,11 @@ def _go2(models): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + + if load_optim: + torch.save(model.optimizer.optimizer.state_dict(), + os.path.join(tmpdir, + 'opt-state-dict')) model.save_checkpoint(tmpdir) _go2(models) From 14980ad48713a4f0fddc5bde60b167eaa631d5f5 Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Mon, 18 Jul 2022 20:55:36 +0000 Subject: [PATCH 30/31] Add torch version guards --- tests/unit/test_checkpointing.py | 27 ++++++++++++++++++--------- tests/unit/util.py | 20 ++++++++++++++++++++ 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index cb432226721e..e38a3abf54aa 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -16,7 +16,7 @@ from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 -from .util import required_torch_version +from .util import required_minimum_torch_version, required_torch_version import itertools import argparse @@ -88,18 +88,22 @@ def compare_model_states(saved_model, assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}' +def _compare_state_dicts(state0, state1): + for (k0, s0), (k1, s1) in zip(state0.items(), state1.items()): + if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor): + assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}' + assert torch.equal(s0.to('cpu'), s1.to('cpu')) + else: + assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}' + + def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True): saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer for state0, state1 in zip(saved_optimizer.state.values(), loaded_optimizer.state.values()): - for s0, s1 in zip(state0.values(), state1.values()): - if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor): - assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}' - assert torch.equal(s0.to('cpu'), s1.to('cpu')) - else: - assert s0 == s1 + _compare_state_dicts(state0, state1) def compare_lr_scheduler_states(saved_model, loaded_model): @@ -1178,6 +1182,10 @@ def test_checkpoint_zero_elastic(tmpdir, elastic_save, elastic_load, load_optim) @distributed_test(world_size=[2]) def _go(): + # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to + # false 
positive mismatches in checkpoint state comparisons. + # Newer torch versions store tensor ids as 0, 1, 2, ... + compare_load_optim = load_optim and required_minimum_torch_version(1, 4) models = [SimpleModel(hidden_dim) for _ in range(2)] model, _, _, _ = deepspeed.initialize(config=ds_config, model=models[0], @@ -1190,7 +1198,7 @@ def _go(): loss = model(batch[0], batch[1]) model.backward(loss) model.step() - if load_optim: + if compare_load_optim: torch.save(model.optimizer.optimizer.state_dict(), os.path.join(tmpdir, 'opt-state-dict')) @@ -1202,10 +1210,11 @@ def _go(): model_parameters=models[1].parameters()) model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) - if load_optim: + if compare_load_optim: saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) curr_sd = model.optimizer.optimizer.state_dict() assert curr_sd['param_groups'] == saved_sd['param_groups'] + data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, diff --git a/tests/unit/util.py b/tests/unit/util.py index 79a459da3c14..0aa72a2ad032 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -25,3 +25,23 @@ def bf16_required_version_check(): return True else: return False + + +def required_minimum_torch_version(major_version, minor_version): + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + if TORCH_MAJOR < major_version: + return False + + return TORCH_MAJOR > major_version or TORCH_MINOR >= minor_version + + +def required_maximum_torch_version(major_version, minor_version): + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + if TORCH_MAJOR > major_version: + return False + + return TORCH_MAJOR < major_version or TORCH_MINOR <= minor_version From 868c463abbce0bd8b3b2ca738ad38681cc7b53cf Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 19 Jul 2022 19:45:48 +0500 Subject: [PATCH 31/31] More precise avoidance of false positives. --- tests/unit/test_checkpointing.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index e38a3abf54aa..dd93e006081f 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -88,8 +88,11 @@ def compare_model_states(saved_model, assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}' -def _compare_state_dicts(state0, state1): +def _compare_state_dicts(state0, state1, expected_mismatch_keys=[]): for (k0, s0), (k1, s1) in zip(state0.items(), state1.items()): + assert k0 == k1, f'failure due to key mismatch {k0} != {k1}' + if k0 in expected_mismatch_keys: + continue if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor): assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}' assert torch.equal(s0.to('cpu'), s1.to('cpu')) @@ -1185,7 +1188,8 @@ def _go(): # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to # false positive mismatches in checkpoint state comparisons. # Newer torch versions store tensor ids as 0, 1, 2, ... 
- compare_load_optim = load_optim and required_minimum_torch_version(1, 4) + expected_mismatch_keys = [] if required_minimum_torch_version(1, + 4) else ['params'] models = [SimpleModel(hidden_dim) for _ in range(2)] model, _, _, _ = deepspeed.initialize(config=ds_config, model=models[0], @@ -1198,7 +1202,7 @@ def _go(): loss = model(batch[0], batch[1]) model.backward(loss) model.step() - if compare_load_optim: + if load_optim: torch.save(model.optimizer.optimizer.state_dict(), os.path.join(tmpdir, 'opt-state-dict')) @@ -1210,10 +1214,13 @@ def _go(): model_parameters=models[1].parameters()) model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) - if compare_load_optim: + if load_optim: saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) curr_sd = model.optimizer.optimizer.state_dict() - assert curr_sd['param_groups'] == saved_sd['param_groups'] + for curr_param_group, saved_param_group in zip(curr_sd['param_groups'], saved_sd['param_groups']): + _compare_state_dicts(curr_param_group, + saved_param_group, + expected_mismatch_keys) data_loader = random_dataloader(model=model, total_samples=8,