diff --git a/scripts/fsdp_hybrid.py b/scripts/fsdp_hybrid.py
index 506c9c6..c057a9a 100644
--- a/scripts/fsdp_hybrid.py
+++ b/scripts/fsdp_hybrid.py
@@ -1,9 +1,12 @@
 import contextlib
 import gc
+import itertools
+import math
 import os
 import time
 from dataclasses import dataclass
 from datetime import timedelta
+from functools import partial
 from timeit import default_timer as timer
 from typing import Any, Type
@@ -17,10 +20,12 @@ from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
+from torch.distributed.device_mesh import DeviceMesh
 from transformers import AutoConfig

 from maester.checkpoint import CheckpointManager
-from maester.datasets import build_hf_data_loader, create_tokenizer, MosaicDataset
+from maester.datasets import (MosaicDataset, build_hf_data_loader,
+                              create_tokenizer)
 from maester.datasets.experimental_otf import build_experimental_data_loader
 from maester.datasets.mosaic_dataset import MosaicDataLoader
 from maester.log_utils import init_logger, logger
@@ -29,10 +34,89 @@ from maester.metrics import build_gpu_memory_monitor, build_metric_logger
 from maester.models import (model_name_to_cls, model_name_to_tokenizer,
                             models_config)
-from maester.parallelize_llama import ParallelDims, parallelize_llama
+from maester.parallelisms import ParallelDims, parallelize_llama
 from maester.profiling import maybe_enable_profiling
-from maester.utils import (dist_max, dist_mean, get_num_flop_per_token, get_num_params, get_peak_flops,
-                           init_distributed, set_pg_timeouts)
+from maester.utils import (dist_max, dist_mean, get_num_flop_per_token,
+                           get_num_params, get_peak_flops, init_distributed,
+                           set_pg_timeouts)
+
+# not merged into pytorch so define it here
+# from torch.distributed.utils import _sync_module_states_with_mesh
+from torch.distributed.utils import _verify_param_shape_across_processes
+def _sync_module_states_with_mesh(module: torch.nn.Module, mesh: "DeviceMesh") -> None:
+    """
+    Broadcast from the module states of the first rank of ``mesh`` to other ranks
+    within the same ``mesh``.
+    This API is similar to ``_sync_module_states`` but is designed for DeviceMesh.
+    Instead of extending ``_sync_module_states``, creating a new API makes the
+    semantics simpler (e.g., the meaning of ``src`` is different for PG and
+    DeviceMesh).
+    """
+    module_states: list[torch.Tensor] = []
+
+    # Lazy import to avoid circular dependency
+    from torch.distributed._tensor import DTensor
+
+    for state in itertools.chain(module.parameters(), module.buffers()):
+        module_states.append(state.to_local() if isinstance(state, DTensor) else state)
+
+    with torch.no_grad():
+        pg = mesh.get_group()
+        src = dist.get_process_group_ranks(pg)[0]
+        _verify_param_shape_across_processes(pg, module_states)
+
+        for state in module_states:
+            # `dist._broadcast_coalesced` will increase the peak memory usage due to
+            # recordStream. We can implement the broadcast coalescing to speed up.
+            dist.broadcast(state, src, group=pg)
+
+def reshape(self, new_shape: tuple[int, ...], new_dim_names: tuple[str, ...] | None = None) -> DeviceMesh:
+    """
+    Reshape the DeviceMesh to a new shape while preserving the total number of devices.
+
+    Args:
+        new_shape (Tuple[int, ...]): The new shape for the DeviceMesh.
+        new_dim_names (Optional[Tuple[str, ...]]): New names for the dimensions of the reshaped mesh.
+            If provided, must have the same length as new_shape.
If not provided, dimension names will be kept if compatible, otherwise reset. + + Returns: + The same DeviceMesh object with the reshaped structure. + + Raises: + ValueError: If the new shape is incompatible with the total number of devices, + or if new_dim_names is provided but has a different length than new_shape. + """ + if math.prod(new_shape) != self.mesh.numel(): + raise ValueError("New shape must have the same number of elements as the current mesh.") + + if new_dim_names is not None: + if len(new_dim_names) != len(new_shape): + raise ValueError("new_dim_names must have the same length as new_shape.") + if len(set(new_dim_names)) != len(new_dim_names): + raise ValueError("Each name in new_dim_names must be unique.") + + # Reshape the mesh tensor in-place + self.mesh = self.mesh.reshape(new_shape) + + # Update mesh_dim_names + if new_dim_names is not None: + self.mesh_dim_names = new_dim_names + elif self.mesh_dim_names is not None and len(self.mesh_dim_names) != len(new_shape): + self.mesh_dim_names = None # Reset if incompatible + + # Update the coordinate of the current rank on the reshaped mesh + rank_coords = (self.mesh == self.get_rank()).nonzero() + self._coordinate_on_dim = rank_coords[0].tolist() if rank_coords.size(0) > 0 else None + + # Only reinitialize process groups if the number of dimensions has changed + if len(new_shape) != self.ndim: + self._init_process_groups() + + return self + +# Add the reshape method to the DeviceMesh class +DeviceMesh.reshape = reshape + class DatasetConfig(BaseModel): data_logical_shards: int = 1024 @@ -48,11 +132,14 @@ class Config(BaseModel): model_config = ConfigDict(frozen=True, protected_namespaces=(), arbitrary_types_allowed=True) job_folder: str = "jobs/" - job_name: str = "fineweb-1B-llama2-v2" + job_name: str = "fineweb-1B-llama2-testing" max_grad_norm: float = 1.0 gc_freq: int = 4 + data_parallel_type: str = "fsdp" data_parallel_degree: int = -1 + data_parallel_replicate: int = 1 # for hsdp, not ready to use (set to 1) + context_parallel_degree: int = 1 # not ready for use tensor_parallel_degree: int = 1 pipeline_parallel_degree: int = 1 train_batch_size: int = 2 # per device; 2 * 8 gpus * 32 nodes * 4096 seqlen = 2.1M tokens per batch @@ -84,6 +171,7 @@ class Config(BaseModel): # model model_name: str = "llama3" flavor: str = "1B-v2" + num_future_tokens: int = 4 seq_len: int = 4096 norm_type: str = "rmsnorm" @@ -92,6 +180,7 @@ class Config(BaseModel): opt_cfg: dict[str, Any] = dict( # TODO: don't use dict, not validateable lr = 4e-4, # max lr, schedule reduces it at points betas = (0.9, 0.95), + weight_decay=0.1, # foreach=True, # foreach might work where fused doesn't fused=True ) @@ -108,6 +197,10 @@ class Config(BaseModel): ac_mode: str = "none" # "full" | "selective" | "none" selective_ac_option: str | int = "op" + # experimental + enable_async_tensor_parallel: bool = False + enable_compiled_autograd: bool = False + # profiling enable_profiling: bool = False traces_folder: str = "traces" @@ -138,7 +231,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: @record def main(): init_logger() - logger.info(f"Starting training.") + logger.info(f"Starting job: ") cfg = Config() # TODO: enable configuring? 
@@ -150,15 +243,44 @@ def main(): world_size = int(os.environ["WORLD_SIZE"]) parallel_dims = ParallelDims( dp=cfg.data_parallel_degree, + cp=cfg.context_parallel_degree, tp=cfg.tensor_parallel_degree, pp=cfg.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=cfg.enable_loss_parallel, + dp_type=cfg.data_parallel_type, + dp_replicate=cfg.data_parallel_replicate, ) torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) init_distributed(cfg) + # build meshes world_mesh = parallel_dims.build_mesh(device_type="cuda") + if parallel_dims.dp_enabled: + dp_mesh = world_mesh["dp"] + if parallel_dims.cp_enabled: + dp_mesh = dp_mesh.reshape( + (dp_mesh.size() // parallel_dims.cp, parallel_dims.cp), + ("dp", "cp") + )["dp"] + dp_degree = dp_mesh.size() + dp_rank = dp_mesh.get_local_rank() + else: + dp_degree, dp_rank = 1, 0 + + # if parallel_dims.cp_enabled: + # cp_mesh = world_mesh["cp"] + # context_parallel_ctx = partial( + # context_parallel_buffers, + # cp_rank=cp_mesh.get_local_rank(), + # cp_world_size=cp_mesh.size(), + # ) + # else: + # context_parallel_ctx = partial( + # context_parallel_buffers, + # cp_rank=0, + # cp_world_size=1, + # ) # hf_config = AutoConfig.from_pretrained(cfg.tokenizer_name) # for vocab size below TODO: fails on LUMI? @@ -172,6 +294,7 @@ def main(): model_config.norm_type = cfg.norm_type model_config.vocab_size = 32000 # hf_config.vocab_size model_config.max_seq_len = cfg.seq_len + model_config.n_future_tokens = cfg.num_future_tokens # del hf_config # only needed for vocab size with torch.device("meta"): @@ -211,14 +334,7 @@ def main(): f"({gpu_mem_stats.max_reserved_pct:.2f}%)" ) - # build dataloader - if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] - dp_degree = dp_mesh.size() - dp_rank = dp_mesh.get_local_rank() - else: - dp_degree, dp_rank = 1, 0 # data_loader = build_hf_data_loader( # "c4_mini", # "src/maester/datasets/c4_mini", @@ -273,10 +389,17 @@ def loss_fn(pred, labels): states={"train_state": train_state}, cfg=cfg, ) - checkpoint.load() + checkpoint_loaded = checkpoint.load() # TODO: do we want to checkpoint metrics? 
+ if not checkpoint_loaded and parallel_dims.dp_enabled and parallel_dims.dp_replicate > 1: + # sync params if hsdp + replicate_mesh = dp_mesh.reshape( + (parallel_dims.dp_replicate, dp_mesh.size() // parallel_dims.dp_replicate) + ) + _sync_module_states_with_mesh(model, replicate_mesh) + data_iterator = iter(data_loader) logger.info(f"Training starts at step {train_state.step}") @@ -305,23 +428,51 @@ def loss_fn(pred, labels): ntokens_since_last_log += labels.numel() data_loading_times.append(timer() - data_load_start) - input_ids = input_ids.cuda() - labels = labels.cuda() - optimizer.zero_grad() - # non-pp loss parallel, pp is not implemented - with loss_parallel_ctx(): - pred = sharded_model(input_ids) - loss = loss_fn(pred, labels) - # pred.shape=(bs, seq_len, vocab_size) - # need to free to before bwd to avoid peaking memory - del pred - loss.backward() + # with context_parallel_ctx( + # buffers=[input_ids, labels, model.freqs_cis], + # seq_dims=[1,1,0], + # keep_orig_buffers=[False, False, True] + # ): + with contextlib.nullcontext(): + input_ids = input_ids.cuda() + labels = labels.cuda() + + # non-pp loss parallel, pp is not implemented + with loss_parallel_ctx(): + z = sharded_model.trunk(input_ids, sharded_model.freqs_cis) # (bsz, seq_len, dim) + d = z.detach() + d.requires_grad_(True) + # print(f"d shape: {d.shape}, requires_grad: {d.requires_grad}") + + for i, head in enumerate(sharded_model.heads.values()): + h = head(d, sharded_model.freqs_cis) + # print(f"h shape after head {i}: {h.shape}, requires_grad: {h.requires_grad}") + h = sharded_model.norm(h) + # print(f"h shape after norm {i}: {h.shape}, requires_grad: {h.requires_grad}") + pred = sharded_model.output(h) + # print(f"pred shape for head {i}: {pred.shape}, requires_grad: {pred.requires_grad}") + loss = loss_fn(pred, labels) # TODO: labels for num_future_tokens > 1 + # print(f"loss for head {i}: {loss.item()}, requires_grad: {loss.requires_grad}") + + # pred.shape=(bs, seq_len, vocab_size) + # need to free before bwd to avoid peaking memory + del h, pred + loss.backward() + # print(f"d.grad after head {i}: {d.grad is not None}, shape: {d.grad.shape if d.grad is not None else None}") + + # print(f"Final d.grad shape: {d.grad.shape if d.grad is not None else None}") + # print(f"z grad_fn: {z.grad_fn}") + z.backward(gradient=d.grad) + total_grad_norm = torch.nn.utils.clip_grad_norm_( sharded_model.parameters(), cfg.max_grad_norm, foreach=True ) + + # optimizer step + checkpoint.wait_for_staging() optimizer.step() scheduler.step() diff --git a/src/maester/checkpoint.py b/src/maester/checkpoint.py index 611f5b1..c942dc1 100644 --- a/src/maester/checkpoint.py +++ b/src/maester/checkpoint.py @@ -228,3 +228,9 @@ def load(self, step: int = -1) -> bool: f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds" ) return True + + def wait_for_staging(self) -> None: + """ + To be implemented as part of async checkpointing. 
+ """ + pass diff --git a/src/maester/models/llama/__init__.py b/src/maester/models/llama/__init__.py index 3b2cb13..477a7af 100644 --- a/src/maester/models/llama/__init__.py +++ b/src/maester/models/llama/__init__.py @@ -57,6 +57,15 @@ multiple_of=1024, rope_theta=500000, ), + "1B-v2": ModelArgs( + dim=1536, + n_layers=24, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=512, + rope_theta=500000, + ), "8B": ModelArgs( dim=4096, n_layers=32, diff --git a/src/maester/models/llama/model.py b/src/maester/models/llama/model.py index 4615ab9..f803ffa 100644 --- a/src/maester/models/llama/model.py +++ b/src/maester/models/llama/model.py @@ -9,6 +9,7 @@ from dataclasses import dataclass +import itertools from typing import Optional, Tuple import torch @@ -23,6 +24,7 @@ class ModelArgs: dim: int = 4096 n_layers: int = 32 + n_future_tokens: int = 1 n_heads: int = 32 n_kv_heads: Optional[int] = None vocab_size: int = -1 # defined later by tokenizer @@ -349,6 +351,33 @@ def init_weights(self): self.attention.init_weights(self.weight_init_std) self.feed_forward.init_weights(self.weight_init_std) +class TransformerTrunk(nn.Module): + def __init__(self, model_args: ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.n_future_tokens = model_args.n_future_tokens + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers - model_args.n_future_tokens): # do not build last layers + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + def init_weights(self): + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + layer.init_weights() + + def forward(self, tokens: torch.Tensor, freqs_cis: torch.Tensor): + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + # Model trunk + for layer in self.layers.values(): + h = layer(h, freqs_cis) + return h class Transformer(nn.Module): """ @@ -361,8 +390,11 @@ class Transformer(nn.Module): model_args (ModelArgs): Model configuration arguments. vocab_size (int): Vocabulary size. n_layers (int): Number of layers in the model. + n_future_tokens (int): Number of prediction heads in the model (= 1 + `len(extra_heads)`). tok_embeddings (ParallelEmbedding): Token embeddings. layers (torch.nn.ModuleList): List of Transformer blocks. + extra_heads (torch.nn.ModuleList): List of Transformer blocks + (additional prediction heads for multi-token prediction). norm (RMSNorm): Layer normalization for the model output. output (ColumnParallelLinear): Linear layer for final output. freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. @@ -374,8 +406,7 @@ def __init__(self, model_args: ModelArgs): self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers - - self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.n_future_tokens = model_args.n_future_tokens # TODO persistent should be set to false, since this buffer can be recomputed. # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, @@ -386,9 +417,13 @@ def __init__(self, model_args: ModelArgs): # just the non-persistent buffers that is called after loading checkpoints. 
         self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=False)

-        self.layers = torch.nn.ModuleDict()
-        for layer_id in range(model_args.n_layers):
-            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
+        self.trunk = TransformerTrunk(model_args)
+
+        # Prediction heads (both the last layer and all extras)
+        # `layer_id` counts contiguously from the first Transformer block.
+        self.heads = torch.nn.ModuleDict()
+        for layer_id in range(self.n_layers - self.n_future_tokens, self.n_layers):
+            self.heads[str(layer_id)] = TransformerBlock(layer_id, model_args)

         self.norm = create_norm(
             model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
@@ -411,8 +446,9 @@ def init_weights(self):
         """
         with torch.device(self.freqs_cis.device):
             self.freqs_cis = self._precompute_freqs_cis()
-        nn.init.normal_(self.tok_embeddings.weight)
-        for layer in self.layers.values():
+
+        self.trunk.init_weights()
+        for layer in self.heads.values():
             layer.init_weights()
         self.norm.reset_parameters()
         final_out_std = self.model_args.dim**-0.5
@@ -434,22 +470,37 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )

-    def forward(self, tokens: torch.Tensor):
+    def forward(self, tokens: torch.Tensor, return_all_heads: bool = False):
         """
         Perform a forward pass through the Transformer model.
+        It is not efficient to call this for training with return_all_heads=True.
+        Use `trunk` and `heads` directly instead.

         Args:
             tokens (torch.Tensor): Input token indices.
+            return_all_heads (bool, optional): Whether to return logits
+                for all prediction heads. Defaults to False.

         Returns:
-            torch.Tensor: Output logits after applying the Transformer model.
+            torch.Tensor: Output logits after applying the Transformer model,
+                of shape (batch_size, seq_len, n_future_tokens, vocab_size).
+
+        Note:
+            If return_all_heads is False, the output logits broadcast to
+            (batch_size, seq_len, vocab_size) and are compatible with standard
+            decoding.
         """
-        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
-        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
-
-        for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+        h = self.trunk(tokens, self.freqs_cis)
+
+        # Prediction heads (the last layer plus all extras), each applied to the
+        # shared trunk output, matching the training loop in scripts/fsdp_hybrid.py.
+        latents = []
+        prediction_heads = list(self.heads.values())
+        n_heads_to_use = self.n_future_tokens if return_all_heads else 1
+        for layer in prediction_heads[:n_heads_to_use]:
+            latents.append(layer(h, self.freqs_cis))
+
+        h = torch.stack(latents, dim=-2)  # (_bsz, seqlen, n_heads_to_use, dim)
         h = self.norm(h) if self.norm else h
         output = self.output(h).float() if self.output else h
         return output
diff --git a/src/maester/parallelisms/__init__.py b/src/maester/parallelisms/__init__.py
new file mode 100644
index 0000000..3469d91
--- /dev/null
+++ b/src/maester/parallelisms/__init__.py
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +from dataclasses import dataclass +from functools import cached_property + +from torch.distributed.device_mesh import init_device_mesh +from maester.log_utils import logger +from maester.parallelisms.parallelize_llama import parallelize_llama + +@dataclass +class ParallelDims: + dp: int + cp: int + tp: int + pp: int + world_size: int + enable_loss_parallel: bool + dp_type: str + dp_replicate: int + + def __post_init__(self): + self.dp_type = self.dp_type.lower() + self._validate() + + def _validate(self): + dp, cp, tp, pp = self.dp, self.cp, self.tp, self.pp + if dp == -1: + self.dp = dp = self.world_size // (cp * tp * pp) + assert dp >= 1, dp + assert dp % self.dp_replicate == 0, (self.dp_replicate, dp) + assert cp >= 1, cp + assert tp >= 1, tp + assert pp >= 1, pp + assert ( + dp * cp * tp * pp == self.world_size + ), f"Invalid parallel dims: dp({dp}) * cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" + assert self.dp_type in ("fsdp", "ddp", "hsdp") + + def build_mesh(self, device_type): + dims = [] + names = [] + for d, name in zip( + [self.pp, self.dp * self.cp, self.tp], ["pp", "dp", "tp"], strict=True + ): + if d > 1: + dims.append(d) + names.append(name) + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + names = tuple(names) + return init_device_mesh(device_type, dims, mesh_dim_names=names) + + @property + def dp_enabled(self): + return self.dp > 1 + + @property + def cp_enabled(self): + return self.cp > 1 + + @property + def tp_enabled(self): + return self.tp > 1 + + @property + def pp_enabled(self): + return self.pp > 1 + + @property + def loss_parallel_enabled(self): + return self.tp > 1 and self.enable_loss_parallel + + @cached_property + def model_parallel_size(self): + return self.tp * self.pp \ No newline at end of file diff --git a/src/maester/parallelisms/parallelize_llama.py b/src/maester/parallelisms/parallelize_llama.py new file mode 100644 index 0000000..f90c6cf --- /dev/null +++ b/src/maester/parallelisms/parallelize_llama.py @@ -0,0 +1,367 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# this file applies the PTD parallelisms and various training techniques to the +# llama model, i.e. activation checkpointing, etc. 
+ +from collections import defaultdict +import itertools + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import (MixedPrecisionPolicy, + fully_shard) +from torch.distributed._composable.replicate import replicate +from torch.distributed._tensor import Replicate, Shard + +try: + from torch.distributed._tensor.experimental.attention import \ + enable_context_parallel +except ImportError: + print("The PyTorch version does not include the experimental CP APIs.") +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ + checkpoint_wrapper as ptd_checkpoint_wrapper +from torch.distributed.tensor.parallel import (ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + parallelize_module) + +from maester.log_utils import logger + +# for selective AC +no_recompute_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, +} + + +def checkpoint_wrapper(module: torch.nn.Module, ac_config): + valid_ac_modes = ("full", "selective") + if ac_config.ac_mode not in valid_ac_modes: + raise ValueError(f"Invalid AC mode: {ac_config.ac_mode}. Valid modes: {valid_ac_modes}") + + if ac_config.ac_mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.ac_mode == "selective", ac_config.ac_mode + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + if not use_op_sac and not use_layer_sac: + raise ValueError(f"Invalid selective_ac_option: {ac_config.selective_ac_option}. Valid options: 'op' or a positive integer.") + if use_op_sac: + from torch.utils.checkpoint import ( + CheckpointPolicy, create_selective_checkpoint_contexts) + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in no_recompute_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + if ac_freq <= 0: + raise ValueError( + f"Selective layer AC expects a positive int as selective_ac_option but got {ac_freq}" + ) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + +def get_tp_parallel_strategy(config): + """Get the parallel strategy for the transformer model. 
+
+    This function handles the special case of using float8 with tensor parallelism (not implemented).
+    """
+    # if config.training.fp8_linear == "dynamic":
+    #     from float8_experimental.float8_tensor_parallel import (
+    #         Float8ColwiseParallel,
+    #         Float8RowwiseParallel,
+    #         PrepareFloat8ModuleInput,
+    #     )
+
+    #     return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
+    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
+
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    config,
+):
+    """Apply tensor parallelism."""
+
+    tp_mesh = world_mesh["tp"]
+    # Parallel styles used for transformer block linear weights and their
+    # inputs may be different for float8 linears
+    (
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_module_input,
+    ) = get_tp_parallel_strategy(config)
+    loss_parallel = parallel_dims.loss_parallel_enabled
+
+    # 1. Parallelize the embedding (now under `trunk`) and shard its outputs
+    #    (which are the first transformer block's inputs)
+    # 2. Parallelize the root norm layer over the sequence dim
+    # 3. Parallelize the final linear output layer
+    model = parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "trunk.tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": colwise_parallel_weight(
+                input_layouts=Shard(1),
+                output_layouts=Shard(-1) if loss_parallel else Replicate(),
+                use_local_output=not loss_parallel,
+            ),
+        },
+    )
+
+    # Apply tensor + sequence parallelism to every transformer block
+    # (trunk blocks and prediction heads alike)
+    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
+    # by folding (and unfolding) the batch dimension and the sequence dimension.
+    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
+    for layer_id, transformer_block in itertools.chain(
+        model.trunk.layers.items(), model.heads.items()
+    ):
+        layer_plan = {
+            "attention_norm": SequenceParallel(),
+            "attention": prepare_module_input(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wq": colwise_parallel_weight(),
+            "attention.wk": colwise_parallel_weight(),
+            "attention.wv": colwise_parallel_weight(),
+            "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)),
+            "ffn_norm": SequenceParallel(),
+            "feed_forward": prepare_module_input(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": colwise_parallel_weight(),
+            "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel_weight(),
+        }
+
+        # Adjust attention module to use the local number of heads
+        attn_layer = transformer_block.attention
+        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
+        )
+
+    if config.enable_async_tensor_parallel:
+        from torch.distributed._symmetric_memory import \
+            enable_symm_mem_for_group
+
+        torch._inductor.config._micro_pipeline_tp = True
+        enable_symm_mem_for_group(tp_mesh.get_group().group_name)
+
+    logger.info("Applied Tensor Parallelism to the model")
+    return model
+
+
+def apply_ac(model: nn.Module, ac_config):
+    """Apply activation checkpointing to the model."""
+
+    # The transformer blocks now live under `model.trunk.layers` and `model.heads`.
+    for block_container in (model.trunk.layers, model.heads):
+        for layer_id, transformer_block in block_container.named_children():
+            transformer_block = checkpoint_wrapper(transformer_block, ac_config)
+            block_container.register_module(layer_id, transformer_block)
+
+    logger.info(f"Applied {ac_config.ac_mode} activation checkpointing to the model")
+    return model
+
+
+def apply_compile(model: nn.Module, config):
+    """Apply torch.compile to each transformer block."""
+
+    if config.norm_type == "fused_rmsnorm":
+        raise NotImplementedError(
+            "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
+        )
+
+    # compile trunk layers
+    for layer_id, transformer_block in model.trunk.layers.items():
+        # TODO: dynamic shapes have some issues so we turn them off for now.
+        # TODO: inline inbuilt nn modules does not work yet, enable it to accelerate
+        # compile time.
+        # torch._dynamo.config.inline_inbuilt_nn_modules = True
+        transformer_block = torch.compile(transformer_block, dynamic=False)
+        model.trunk.layers.register_module(layer_id, transformer_block)
+
+    # compile head layers
+    for layer_id, transformer_block in model.heads.items():
+        # TODO: dynamic shapes have some issues so we turn them off for now.
+        # TODO: inline inbuilt nn modules does not work yet, enable it to accelerate
+        # compile time.
+        # torch._dynamo.config.inline_inbuilt_nn_modules = True
+        transformer_block = torch.compile(transformer_block, dynamic=False)
+        model.heads.register_module(layer_id, transformer_block)
+
+    logger.info("Compiled each TransformerBlock with torch.compile")
+    return model
+
+
+def apply_cp(model, world_mesh, parallel_dims, config):
+    """
+    Apply context parallelism to the model. This is an experimental feature.
+    """
+    if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
+        raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
+    dp_mesh = world_mesh["dp"]
+    cp_mesh = dp_mesh.reshape(
+        (dp_mesh.size() // parallel_dims.cp, parallel_dims.cp), ("dp", "cp")
+    )["cp"]
+    callers = []
+    # Collect the attention modules of the trunk blocks and the prediction heads.
+    for transformer_block in itertools.chain(
+        model.trunk.layers.values(), model.heads.values()
+    ):
+        callers.append(transformer_block.attention)
+    enable_context_parallel(seq_dim=2, callers=callers, device_mesh=cp_mesh)
+    logger.info("Applied CP to the model")
+
+    return model
+
+
+def apply_fsdp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    config,
+):
+    """
+    Apply data parallelism to the model. FSDP2 is used here.
+    """
+
+    # This mesh also includes the cp degree if it is larger than 1.
+ if parallel_dims.dp_type == "fsdp": + dp_mesh = world_mesh["dp"] + else: + assert parallel_dims.dp_type == "hsdp", parallel_dims.dp_type + dp_mesh = world_mesh["dp"] + dp_mesh = dp_mesh.reshape( + (parallel_dims.dp_replicate, dp_mesh.size() // parallel_dims.dp_replicate), + ("dp_replicate", "dp_shard"), + ) + # assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names + + mp_policy = config.mixed_precision_policy + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + + for layer_id, transformer_block in itertools.chain(model.trunk.layers.items(), model.heads.items()): + if parallel_dims.pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for + # the heads since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < len(model.trunk.layers) + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + # these must be wrapped individually, in order to call them for multi-token prediction + fully_shard(model.trunk, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled) + fully_shard(model.norm, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled) + fully_shard(model.output, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled) + + fully_shard( + model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled + ) + + logger.info("Applied FSDP to the model") + return model + + +def apply_ddp( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + config, +): + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism.") + + if config.compile: + if config.enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") + return model + + + +def parallelize_llama( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + config, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + if parallel_dims.tp_enabled: + model = apply_tp(model, world_mesh, parallel_dims, config) + + if config.ac_mode != "none": + model = apply_ac(model, config) + + if config.compile: + model = apply_compile(model, config) + + if parallel_dims.cp_enabled: + model = apply_cp(model, world_mesh, parallel_dims, config) + + if parallel_dims.dp_enabled: + if parallel_dims.dp_type == "fsdp" or parallel_dims.dp_type == "hsdp": + model = apply_fsdp(model, world_mesh, parallel_dims, config) + else: + model = apply_ddp(model, world_mesh, parallel_dims, config) + + return model \ No newline at end of file diff --git a/src/maester/parallelize_llama.py b/src/maester/parallelize_llama.py deleted file mode 100644 index 56b0d7e..0000000 --- a/src/maester/parallelize_llama.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# this file applies the PTD parallelisms and various training techniques to the -# llama model, i.e. activation checkpointing, etc. - -from collections import defaultdict -from dataclasses import dataclass -from functools import cached_property -from typing import Tuple - -import torch - -from torch.distributed.device_mesh import init_device_mesh, DeviceMesh -from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy -from torch.distributed._tensor import Replicate, Shard -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper as ptd_checkpoint_wrapper, - CheckpointImpl, -) -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - parallelize_module, - PrepareModuleInput, - RowwiseParallel, - SequenceParallel, -) - -from torch.utils.checkpoint import checkpoint -import torch._dynamo.config - -from maester.log_utils import logger - - -@dataclass -class ParallelDims: - dp: int - tp: int - pp: int - world_size: int - enable_loss_parallel: bool - - def __post_init__(self): - self._validate() - - def _validate(self): - dp, tp, pp = self.dp, self.tp, self.pp - if dp == -1: - self.dp = dp = self.world_size // (tp * pp) - assert dp >= 1, dp - assert tp >= 1, tp - assert pp >= 1, pp - assert ( - dp * tp * pp == self.world_size - ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" - - def build_mesh(self, device_type): - dims = [] - names = [] - for d, name in zip( - [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True - ): - if d > 1: - dims.append(d) - names.append(name) - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - names = tuple(names) - return init_device_mesh(device_type, dims, mesh_dim_names=names) - - @property - def dp_enabled(self): - return self.dp > 1 - - @property - def tp_enabled(self): - return self.tp > 1 - - @property - def pp_enabled(self): - return self.pp > 1 - - @property - def loss_parallel_enabled(self): - return self.tp > 1 and self.enable_loss_parallel - - @cached_property - def model_parallel_size(self): - return self.tp * self.pp - - -# for selective AC -no_recompute_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, -} - - -# Uses PTD FSDP AC wrapper -# currently selective per op and per layer checkpointing are supported -def checkpoint_wrapper(module, config): - if config.ac_mode == "selective" and config.selective_ac_option == "op": - from torch.utils.checkpoint import create_selective_checkpoint_contexts, CheckpointPolicy - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): - mode = "recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: - meta[mm_count_key] += 1 - # Saves output of all compute ops, except every second mm - to_save = func in no_recompute_list and not ( - func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - ) - return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE - - return _custom_policy - - def selective_checkpointing_context_fn(): - meta = defaultdict(int) - return create_selective_checkpoint_contexts(_get_custom_policy(meta)) - - return ptd_checkpoint_wrapper( - module, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - checkpoint_fn=checkpoint, - 
context_fn=selective_checkpointing_context_fn, - use_reentrant=False, - preserve_rng_state=False, - ) - elif config.ac_mode == "full": - return ptd_checkpoint_wrapper( - module, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - checkpoint_fn=checkpoint, - use_reentrant=False, - preserve_rng_state=False, - ) - - elif config.ac_mode == "selective" and config.selective_ac_option.isdigit(): - """enables selective checkpointing of candidate layers. - Usage: - 'selective_ac_option' with a positive 'int' value in config controls which layers to checkpoint. - 1 == checkpointing every one (all). - 2 == checkpoint every 2nd one - """ - ac_freq = int(config.selective_ac_option) - assert ( - ac_freq >= 0 - ), f"selective layer AC policy (ac_freq) expects a positive integer, received {ac_freq}" - - checkpoint_wrapper.__dict__.setdefault("_count", 0) - - checkpoint_wrapper._count += 1 - if not ac_freq or checkpoint_wrapper._count % ac_freq == 0: - return ptd_checkpoint_wrapper( - module, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - checkpoint_fn=checkpoint, - use_reentrant=False, - preserve_rng_state=False, - ) - # skip activation checkpointing and store activations for this layer - else: - return module - - else: - raise NotImplementedError( - "Unknown AC type or AC config. Only selective op and selective layer ac implemented currently." - ) - - -def parallelize_llama(model, world_mesh: DeviceMesh, parallel_dims, cfg) -> torch.nn.Module: - """ - Apply parallelisms and activation checkpointing to the model. - - NOTE: The passed-in model preferably should be on meta device. Otherwise, - the model must fit on GPU or CPU memory. - """ - if parallel_dims.pp_enabled: - raise NotImplementedError("PP not implemented yet.") - - if parallel_dims.tp_enabled: - if cfg.norm_type == "fused_rmsnorm": - raise NotImplementedError( - "fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm." - ) - - tp_mesh = world_mesh["tp"] - row_parallel_strategy, col_parallel_strategy = [RowwiseParallel, ColwiseParallel] # no FP8 support - loss_parallel = parallel_dims.loss_parallel_enabled - - # 1. Parallelize the first embedding and the last linear proj layer - # 2. Parallelize the root norm layer over the sequence dim - # 3. 
Shard the first transformer block's inputs - model = parallelize_module( - model, - tp_mesh, - { - "tok_embeddings": RowwiseParallel( - input_layouts=Replicate(), - output_layouts=Shard(1), - ), - "output": col_parallel_strategy( - input_layouts=Shard(1), - output_layouts=Shard(-1) if loss_parallel else Replicate(), - use_local_output=not loss_parallel, - ), - "norm": SequenceParallel(), - }, - ) - - # Apply tensor + sequence parallelism to every transformer block - for layer_id, transformer_block in model.layers.items(): - layer_plan = { - "attention": PrepareModuleInput( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), - ), - "attention.wq": col_parallel_strategy(), - "attention.wk": col_parallel_strategy(), - "attention.wv": col_parallel_strategy(), - "attention.wo": row_parallel_strategy(output_layouts=Shard(1)), - "attention_norm": SequenceParallel(), - "feed_forward": PrepareModuleInput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - ), - "feed_forward.w1": col_parallel_strategy(), - "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)), - "feed_forward.w3": col_parallel_strategy(), - "ffn_norm": SequenceParallel(), - } - - # Adjust attention module to use the local number of heads - attn_layer = transformer_block.attention - attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() - attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() - - parallelize_module( - module=transformer_block, - device_mesh=tp_mesh, - parallelize_plan=layer_plan, - ) - - logger.info("Applied Tensor Parallelism to the model") - - # apply AC + torch.compile - for layer_id, transformer_block in model.layers.items(): - if cfg.ac_mode in ("full", "selective"): - transformer_block = checkpoint_wrapper(transformer_block, cfg) - if cfg.compile: - # turn on per-transformer block compile after AC wrapping and before FSDP - # TODO: dynamic shape have some issues so we turn it off for now. - # TODO: inline inbuilt nn modules does not work yet, enable it to accelerate - # compile time. - # torch._dynamo.config.inline_inbuilt_nn_modules = True - transformer_block = torch.compile(transformer_block, dynamic=False) - model.layers[layer_id] = transformer_block - - if cfg.ac_mode in ("full", "selective"): - logger.info(f"Applied {cfg.ac_mode} activation checkpointing to the model") - if ( - cfg.compile - and cfg.ac_mode == "selective" - and cfg.selective_ac_option == "op" - ): - # TODO: still needed? some temp flags for torch.compile enablement + SAC - pass - # torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( - # True - # ) - if cfg.compile: - if cfg.norm_type == "fused_rmsnorm": - raise NotImplementedError( - "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm." 
- ) - logger.info("Compiled each TransformerBlock with torch.compile") - - # apply DP (FSDP2) - if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh - assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names - mp_policy = cfg.mixed_precision_policy - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} - for layer_id, transformer_block in model.layers.items(): - # As an optimization, do not reshard after forward for the last - # transformer block since FSDP would prefetch it immediately - reshard_after_forward = ( - int(layer_id) < len(model.layers) - 1 - ) - fully_shard( - transformer_block, - **fsdp_config, - reshard_after_forward=reshard_after_forward, - ) - model.layers[layer_id] = transformer_block - model = fully_shard(model, **fsdp_config) - logger.info("Applied FSDP to the model") - - return model diff --git a/src/maester/utils.py b/src/maester/utils.py index 64d04d7..d7a4781 100644 --- a/src/maester/utils.py +++ b/src/maester/utils.py @@ -62,7 +62,7 @@ def set_pg_timeouts(timeout, world_mesh): def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int: num_params = sum(p.numel() for p in model.parameters()) if exclude_embedding: - num_params -= model.tok_embeddings.weight.numel() + num_params -= model.trunk.tok_embeddings.weight.numel() return num_params
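For reference, a minimal usage sketch of the new trunk/heads split and the `return_all_heads` flag introduced above. This is not part of the patch: it assumes the remaining `ModelArgs` fields (e.g. `norm_type`, `norm_eps`) have sensible defaults, and the tiny hyperparameters are illustrative only.

```python
import torch

from maester.models.llama.model import ModelArgs, Transformer

# Tiny illustrative config: 4 layers in total, the last 2 acting as prediction heads.
args = ModelArgs(dim=64, n_layers=4, n_heads=4, vocab_size=128,
                 max_seq_len=32, n_future_tokens=2)
model = Transformer(args)
model.init_weights()

tokens = torch.randint(0, args.vocab_size, (2, args.max_seq_len))

# Standard decoding path: only the first prediction head is used.
logits = model(tokens)                              # (2, 32, 1, vocab_size)

# Multi-token prediction path: one set of logits per head.
all_logits = model(tokens, return_all_heads=True)   # (2, 32, 2, vocab_size)

# Equivalent manual decomposition, mirroring the training loop in fsdp_hybrid.py:
z = model.trunk(tokens, model.freqs_cis)            # shared trunk representation
per_head = [model.output(model.norm(head(z, model.freqs_cis)))
            for head in model.heads.values()]       # list of (2, 32, vocab_size)
```

The manual decomposition is what the training loop uses so that each head's loss can be backpropagated separately into the detached trunk output before a single `z.backward(gradient=d.grad)` pass through the trunk.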