From 597f655783575c6b358cd68cc0d829e83656a36a Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Mon, 28 Jul 2025 15:34:15 +0200 Subject: [PATCH 1/5] initial deepseek commit --- maester/config.py | 18 +- maester/models/__init__.py | 22 +- maester/models/deepseek/__init__.py | 133 +++++ maester/models/deepseek/args.py | 102 ++++ maester/models/deepseek/attention.py | 245 ++++++++++ maester/models/deepseek/expert_parallel.py | 295 +++++++++++ maester/models/deepseek/model.py | 391 +++++++++++++++ maester/models/deepseek/moe.py | 377 ++++++++++++++ maester/models/deepseek/moe_indices.py | 334 +++++++++++++ maester/models/deepseek/optimizer.py | 89 ++++ maester/parallelisms/__init__.py | 1 + maester/parallelisms/parallel_dims.py | 135 +++++- maester/parallelisms/parallelize_deepseek.py | 324 +++++++++++++ maester/profiling.py | 2 +- maester/utils.py | 16 +- scripts/convert/deepseek/from_dcp.py | 458 ++++++++++++++++++ scripts/convert/deepseek/to_dcp.py | 260 ++++++++++ .../gemma/from_dcp.py} | 0 .../gemma/to_dcp.py} | 0 .../llama/from_dcp.py} | 0 .../mup_llama/from_dcp.py} | 0 train.py | 127 +++-- 22 files changed, 3238 insertions(+), 91 deletions(-) create mode 100644 maester/models/deepseek/__init__.py create mode 100644 maester/models/deepseek/args.py create mode 100644 maester/models/deepseek/attention.py create mode 100644 maester/models/deepseek/expert_parallel.py create mode 100644 maester/models/deepseek/model.py create mode 100644 maester/models/deepseek/moe.py create mode 100644 maester/models/deepseek/moe_indices.py create mode 100644 maester/models/deepseek/optimizer.py create mode 100644 maester/parallelisms/parallelize_deepseek.py create mode 100644 scripts/convert/deepseek/from_dcp.py create mode 100644 scripts/convert/deepseek/to_dcp.py rename scripts/{convert_gemma_from_dcp.py => convert/gemma/from_dcp.py} (100%) rename scripts/{convert_gemma_to_dcp.py => convert/gemma/to_dcp.py} (100%) rename scripts/{convert_dcp_to_hf.py => convert/llama/from_dcp.py} (100%) rename scripts/{convert_mup_to_hf.py => convert/mup_llama/from_dcp.py} (100%) diff --git a/maester/config.py b/maester/config.py index 740c6f5..850e38b 100644 --- a/maester/config.py +++ b/maester/config.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, ConfigDict, Field, ImportString from pydantic_settings import BaseSettings, SettingsConfigDict -from typing import Callable, Type, Any +from typing import Callable, Type, Any, Literal from pathlib import Path import torch @@ -39,6 +39,7 @@ class Config(BaseSettings): data_parallel_shard_degree: int = 8 data_parallel_replicate_degree: int = 32 tensor_parallel_degree: int = 1 + expert_parallel_degree: int = 1 train_batch_size: int = 2 # per device; 2 * 8 gpus * 32 nodes * 8192 seqlen = ~4M tokens per batch train_num_steps: int = 22000 # ~92B tokens compile: bool = True # TODO: only compiles TransformerBlocks until PyTorch supports full fsdp2 @@ -100,10 +101,25 @@ class Config(BaseSettings): # fsdp mixed_precision_param: str = 'bfloat16' mixed_precision_reduce: str = 'float32' + enable_cpu_offload: bool = False + fsdp_reshard_after_forward: Literal["default", "never", "always"] = "default" # activation checkpointing ac_mode: str = "none" # "full" | "selective" | "none" selective_ac_option: str | int = "op" + per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = Field( + default_factory=lambda: ["moe.router.gate"] + ) + """ + When per-op selective ac is used, this list of fully qualified names is used + to determine which mm shapes to force recompute, rather than being 
considered + by rest of the sac policy, e.g save every other mm. Only nn.Linear modules are + supported today. + + Note: this config applies to mms not limited to those matching the specified + fqns, e.g. if "moe.router.gate", corresponding to Linear(in, out), is specified, + ANY mm with shape matching (*, in) x (in, out) will be force recomputed. + """ # experimental enable_async_tensor_parallel: bool = False diff --git a/maester/models/__init__.py b/maester/models/__init__.py index 85f6766..53d4108 100644 --- a/maester/models/__init__.py +++ b/maester/models/__init__.py @@ -6,13 +6,16 @@ from maester.models.llama import llama2_configs, llama3_configs, mistral_configs, Transformer from maester.models.gemma import gemma3_configs, GemmaTextModel -from maester.parallelisms import parallelize_gemma, parallelize_llama +from maester.models.deepseek import deepseek_configs, DeepSeekModel, build_deepseek_optimizers +from maester.parallelisms import parallelize_gemma, parallelize_llama, parallelize_deepseek +from maester.optimizers import build_optimizers models_config = { "llama2": llama2_configs, "llama3": llama3_configs, "mistral": mistral_configs, "gemma3": gemma3_configs, + "deepseek": deepseek_configs, } model_name_to_cls = { @@ -20,13 +23,7 @@ "llama3": Transformer, "mistral": Transformer, "gemma3": GemmaTextModel, -} - -model_name_to_tokenizer = { - "llama2": "sentencepiece", - "llama3": "tiktoken", - "mistral": "sentencepiece", - "gemma3": "sentencepiece", + "deepseek": DeepSeekModel, } model_name_to_parallelize = { @@ -34,4 +31,13 @@ "llama3": parallelize_llama, "mistral": parallelize_llama, "gemma3": parallelize_gemma, + "deepseek": parallelize_deepseek, +} + +model_name_to_optimizers_builder = { + "llama2": build_optimizers, + "llama3": build_optimizers, + "mistral": build_optimizers, + "gemma3": build_optimizers, + "deepseek": build_deepseek_optimizers, } diff --git a/maester/models/deepseek/__init__.py b/maester/models/deepseek/__init__.py new file mode 100644 index 0000000..93a1345 --- /dev/null +++ b/maester/models/deepseek/__init__.py @@ -0,0 +1,133 @@ +from maester.models.deepseek.model import DeepSeekModelArgs, DeepSeekModel +from maester.models.deepseek.optimizer import build_deepseek_optimizers + +__all__ = ["DeepSeekModel", "DeepSeekModelArgs", "build_deepseek_optimizers"] + +deepseek_configs = { + "debug": DeepSeekModelArgs( + max_batch_size=32, + vocab_size=102400, + dim=512, + inter_dim=2048, + moe_inter_dim=512, + n_layers=8, + n_dense_layers=1, + n_heads=16, + n_routed_experts=8, + n_shared_experts=2, + n_activated_experts=3, + route_scale=1.0, + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ), + "debug_flex": DeepSeekModelArgs( + vocab_size=2000, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=3, + n_dense_layers=1, + n_heads=16, + n_routed_experts=8, + n_shared_experts=2, + n_activated_experts=3, + route_scale=1.0, + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "16B": DeepSeekModelArgs( + vocab_size=102400, + dim=2048, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=27, + n_dense_layers=1, + n_heads=16, + n_routed_experts=64, + n_shared_experts=2, + n_activated_experts=6, + route_scale=1.0, + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="causal", 
+ ), + "45B_custom": DeepSeekModelArgs( + vocab_size=102400, + dim=3072, + inter_dim=8192, + moe_inter_dim=1408, + n_layers=20, + n_dense_layers=1, + n_heads=128, + n_routed_experts=160, + n_shared_experts=2, + n_activated_experts=6, + route_scale=16.0, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="causal", + ), + # TODO: check correctness of configs below here + "236B": DeepSeekModelArgs( + vocab_size=102400, + dim=5120, + inter_dim=12288, + moe_inter_dim=1536, + n_layers=60, + n_dense_layers=1, + n_heads=128, + n_routed_experts=160, + n_shared_experts=2, + n_activated_experts=6, + route_scale=16.0, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="causal", + ), + "685B": DeepSeekModelArgs( + vocab_size=129280, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + n_layers=61, + n_dense_layers=3, + n_heads=128, + n_routed_experts=256, + n_shared_experts=1, + n_activated_experts=8, + route_scale=2.5, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=1.0, + use_flex_attn=True, + attn_mask_type="causal", + ), +} \ No newline at end of file diff --git a/maester/models/deepseek/args.py b/maester/models/deepseek/args.py new file mode 100644 index 0000000..4af3d9f --- /dev/null +++ b/maester/models/deepseek/args.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass +from typing import Optional, Literal +from torch import nn +from maester.log_utils import logger + +@dataclass +class DeepSeekModelArgs: + max_batch_size: int = 16 + max_seq_len: int = 4096 * 4 + dtype: Literal["bf16", "fp8"] = "bf16" + vocab_size: int = 102400 + dim: int = 2048 + inter_dim: int = 10944 + moe_inter_dim: int = 1408 + n_layers: int = 27 + n_dense_layers: int = 1 + n_heads: int = 16 + norm_eps: float = 1e-5 # eps used for RMSNorm + tied_embeddings: bool = False # always False for DeepSeek + + # MoE + n_routed_experts: int = 8 + n_shared_experts: int = 0 + n_activated_experts: int = 4 + n_expert_groups: int = 1 + n_limited_groups: int = 1 + score_func: Literal["softmax", "sigmoid"] = "softmax" + route_scale: float = 1.0 + use_grouped_mm: bool = True + load_balance_coeff: float = 1e-3 + + # MLA + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + use_flex_attn: bool = False + attn_mask_type: str = "causal" + + # Yarn + original_seq_len: int = 4096 + rope_theta: float = 10000.0 + rope_factor: float = 40 + beta_fast: int = 32 + beta_slow: int = 1 + mscale: float = 1.0 + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + """ + Adopted from llama4 implementation. 
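+        Returns (nparams, num_flops_per_token). Only the activated fraction of the
+        routed-expert parameters (n_activated_experts / n_routed_experts) is counted
+        toward the per-token FLOPs estimate.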
+ """ + nparams_embedding = 0 + nparams_moe_router = 0 + nparams_shared_expert = 0 + nparams_experts = 0 + nparams_dense = 0 + + for name, p in model.named_parameters(): + if "embedding" in name: + nparams_embedding += p.numel() + nparams_dense += p.numel() + elif "moe.shared_expert" in name: + nparams_shared_expert += p.numel() + elif "moe.router" in name: + nparams_moe_router += p.numel() + elif "moe.experts" in name: + nparams_experts += p.numel() + else: + nparams_dense += p.numel() + + nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts + nparams = nparams_dense + nparams_sparse + nparams_sparse_active = ( + nparams_moe_router + + nparams_shared_expert + + nparams_experts * self.n_activated_experts // self.n_routed_experts + ) + + logger.info( + f"Total parameter count: dense {nparams_dense:,}, " + f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}" + ) + + l, h, q, t = ( + self.n_layers, + self.n_heads, + self.dim // self.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = ( + 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + + 12 * l * h * q * t + ) + + return nparams, num_flops_per_token \ No newline at end of file diff --git a/maester/models/deepseek/attention.py b/maester/models/deepseek/attention.py new file mode 100644 index 0000000..99b48e1 --- /dev/null +++ b/maester/models/deepseek/attention.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from typing import Callable, ClassVar + +import torch +import torch.nn.functional as F +from torch.nn.attention import sdpa_kernel, SDPBackend +from torch.nn.attention.flex_attention import ( + _mask_mod_signature, + BlockMask, + create_block_mask, + flex_attention, +) + +from maester.utils import has_cuda_capability + +# FlexAttention mask type. For each mask type, we initialize it at most once per +# batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to +# track the initialized mask. +FLEX_ATTN_MASK_T = tuple[str, int | None] + + +class FlexAttention(torch.nn.Module): + """FlexAttention module that uses torch.nn.attention.flex_attention. + + This module is a wrapper around torch.nn.attention.flex_attention. This module + implements certain common attention types, such as causal and block_causal. + + Args: + attn_mask_type (str): The type of attention mask. Currently, we support + "causal" and "block_causal". "causal" means the lower triangle of the + attention matrix is masked. "block_causal" means the attention matrix + is divided into blocks, where block boundary is defined by EOS token, + and the lower triangle of each block is masked. + fixed_block_size (int | None): The block size to be used to perform attention. + If specified, each sequence will be further divided to blocks, where each + block has the maximum size of ``fixed_block_size``. 
A query will only attend + to the keys within the same block. + """ + + # We registered flex_attention related attributes as class variables as we + # need to amortize the cost of compilation. + flex_attn: ClassVar[Callable] = torch.compile( + flex_attention, #mode="max-autotune-no-cudagraphs" + ) + compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask) + used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set() + # Attention mask type to the created BlockMask. + # This allows us to keep track the created block masks for each + # new batch. We will use this to update the block mask when a + # new batch is created. This also allows user to create different + # block masks for different layers. + block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {} + + # Instance variables. + attn_mask_type: str + + def __init__( + self, attn_mask_type: str, fixed_block_size: int | None = None + ) -> None: + super().__init__() + if attn_mask_type not in ["causal", "block_causal"]: + raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") + self.attn_mask_type = attn_mask_type + self.fixed_block_size = fixed_block_size + + FlexAttention.used_attn_mask_types.add(self.mask_key) + + @property + def mask_key(self) -> FLEX_ATTN_MASK_T: + return (self.attn_mask_type, self.fixed_block_size) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float | None = None, + ) -> torch.Tensor: + block_mask = FlexAttention.block_masks[self.mask_key] + return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) + + @staticmethod + def _get_causal_mask_mod() -> _mask_mod_signature: + def causal_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ): + return q_idx >= kv_idx + + return causal_mask + + @staticmethod + def _get_block_causal_mask_mod( + batch: torch.Tensor, eos_id: int + ) -> _mask_mod_signature: + # batch is [b, s, h, d] shape + mask = batch == eos_id + mask[:, -1] = True + acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1) + seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32) + seq_idx[:, 1:] = acc_mask[:, :-1] + + def block_causal_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ): + return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx) + + return block_causal_mask + + @staticmethod + def _fixed_block_mask_mod( + mask_mod: _mask_mod_signature, fixed_block_size: int + ) -> _mask_mod_signature: + """ + Given an arbirary mask_mod, divide the input sequence to blocks + and only allow attention within the same block. + + Args: + mask_mod: The mask mod to apply to the documents + fixed_block_size: The number of tokens in each block. + """ + + # Credit to @drisspg. 
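+        # Illustrative example: with fixed_block_size=4, q_idx=6 and kv_idx=5 both fall in
+        # block 1, so attention is allowed iff the wrapped mask_mod permits the re-based
+        # pair (q_idx % 4, kv_idx % 4) = (2, 1); tokens in different blocks never attend.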
+ def blocked_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ): + # Get the block index of the query and key + q_block = q_idx // fixed_block_size + kv_block = kv_idx // fixed_block_size + # Only allow attention within the same block + same_block = q_block == kv_block + # Apply the original mask mod + inner_mask = mask_mod( + b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size + ) + + return same_block & inner_mask + + blocked_mask_mod.__name__ = ( + f"blocked_mask_mod_{mask_mod.__name__}_fixed_block_size_{fixed_block_size}" + ) + + return blocked_mask_mod + + @staticmethod + @torch.no_grad() + def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: + # batch is [b, s, h, d] shape + for mask_key in FlexAttention.used_attn_mask_types: + attn_mask_type, fixed_block_size = mask_key + match attn_mask_type: + case "causal": + if FlexAttention.block_masks.get(mask_key, None) is not None: + continue + # We don't care about batch dimension -- + # all samples have the same lower triangle mask. + batch_dimension = 1 + mask_mod = FlexAttention._get_causal_mask_mod() + case "block_causal": + if eos_id is None: + raise RuntimeError( + "eos_id must be provided for block_causal mask." + ) + batch_dimension = batch.shape[0] + mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id) + case _: + raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}") + + if fixed_block_size is not None and fixed_block_size > 0: + mask_mod = FlexAttention._fixed_block_mask_mod( + mask_mod, fixed_block_size + ) + + seq_len = batch.shape[1] + block_mask = FlexAttention.compiled_create_block_mask( + mask_mod, batch_dimension, None, seq_len, seq_len + ) + FlexAttention.block_masks[mask_key] = block_mask + + +class ScaledDotProductAttention(torch.nn.Module): + backends: ClassVar[list[SDPBackend]] = [] + + def __init__(self, attn_mask_type: str) -> None: + super().__init__() + if attn_mask_type != "causal": + raise ValueError( + "TorchTitan with SDPA currently only supports causal mask." + ) + + ScaledDotProductAttention._init_backend() + + @classmethod + def _init_backend(cls) -> None: + if cls.backends: + return + + # Add CuDNN on B200 w/ highest priority + cls.backends = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + if has_cuda_capability(10, 0): + cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float | None = None, + ) -> torch.Tensor: + assert self.backends, "SDPA Backends should not be empty." + #with sdpa_kernel(self.backends, set_priority=True): + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): # force CuDNN for perf, or crash + print(q.shape, k.shape, v.shape, scale) + return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale) + + +def build_attention( + use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None +): + if use_flex_attn: + return FlexAttention(attn_mask_type, fixed_block_size) + else: + if fixed_block_size is not None: + raise ValueError( + "TorchTitan with SDPA currently does not support fixed_block_size." + ) + if attn_mask_type != "causal": + raise ValueError( + "TorchTitan with SDPA currently only supports causal mask." 
+ ) + return ScaledDotProductAttention(attn_mask_type) + + +def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: + FlexAttention.init_attention_mask(batch, eos_id) diff --git a/maester/models/deepseek/expert_parallel.py b/maester/models/deepseek/expert_parallel.py new file mode 100644 index 0000000..ff88b08 --- /dev/null +++ b/maester/models/deepseek/expert_parallel.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from functools import partial +from typing import Callable + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._functional_collectives import all_to_all_single_autograd +from torch.distributed.tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Replicate, + Shard, +) +from torch.distributed.tensor.parallel import ParallelStyle +from torch.distributed.tensor.placement_types import Placement + + +# implementation of Tensor Parallel for the GroupedExperts in MoE +class TensorParallel(ParallelStyle): + def _partition_fn(self, name, module, device_mesh): + module.register_parameter( + "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)])) + ) # Column-wise sharding + module.register_parameter( + "w2", + nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])), + ) # Row-wise sharding + module.register_parameter( + "w3", + nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])), + ) # Column-wise sharding + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._partition_fn, + ) + + +# NOTE: This is to achieve replicate computation on the gate module in the MoE router. +# It does nothing other than (1) setting the module parameters as DTensors on the given mesh +# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back. +# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, +# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation. 
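+# A minimal usage sketch (illustrative; the module/mesh names are hypothetical):
+#     from torch.distributed.tensor.parallel import parallelize_module
+#     parallelize_module(moe.router.gate, tp_mesh, NoParallel())
+# This leaves the gate's weight as a Replicate()-placed DTensor on tp_mesh and converts
+# activations to/from DTensor at the module boundary, without sharding any computation.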
+class NoParallel(ParallelStyle): + def __init__( + self, + *, + input_layout: Placement | None = None, + output_layout: Placement | None = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layout = input_layout or Replicate() + self.output_layout = output_layout or Replicate() + self.desired_input_layout = Replicate() + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, (input_layout,), run_check=False + ) + + if input_layout != desired_input_layout: + input_tensor = input_tensor.redistribute( + placements=(desired_input_layout,), async_op=True + ) + return (input_tensor, *inputs[1:]) + + @staticmethod + def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): + if outputs.placements != (output_layout,): + outputs = outputs.redistribute(placements=(output_layout,), async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + None, + partial( + self._prepare_input_fn, self.input_layout, self.desired_input_layout + ), + partial(self._prepare_output_fn, self.output_layout, self.use_local_output), + ) + + +class ExpertParallel(ParallelStyle): + def __init__(self): + super().__init__() + self.input_splits = None + self.output_splits = None + + # performing all-to-all dispatch on the input + def _token_dispatch(self, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + routed_input, num_tokens_per_expert = inputs + + # generate the input splits and output splits for all-to-all + with torch.no_grad(): + num_tokens_per_expert_group = num_tokens_per_expert.new_empty( + num_tokens_per_expert.shape[0] + ) + dist.all_to_all_single( + num_tokens_per_expert_group, + num_tokens_per_expert, + group=device_mesh.get_group(), + ) + # NOTE: this would incur a device-to-host sync + self.input_splits = ( + num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist() + ) + self.output_splits = ( + num_tokens_per_expert_group.view(device_mesh.shape[0], -1) + .sum(dim=1) + .tolist() + ) + + # perform all-to-all + routed_input = all_to_all_single_autograd( + routed_input, + self.output_splits, + self.input_splits, + device_mesh.get_group(), + ) + + # NOTE: After this all-to-all, the routed input is put on proper EP rank. + # However, the num_tokens_per_expert_group is not of the final target format + # [#tokens for local expert 0, #tokens for local expert 1, ...] + # Rather, it is of the format + # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ..., + # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...] + # We need to perform another shuffle to get the correct format -- this is done via the function + # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens + # each expert gets locally is a multiple of ALIGN_SIZE_M. 
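+        # Illustrative example with 2 EP ranks and 2 local experts per rank: this rank may
+        # receive num_tokens_per_expert_group = [3, 1, 2, 4], i.e. 3 tokens for local
+        # expert 0 and 1 token for local expert 1 from EP rank 0, then 2 and 4 from EP
+        # rank 1; generate_permute_indices regroups these into per-expert chunks of
+        # 3 + 2 = 5 and 1 + 4 = 5 tokens (plus padding to a multiple of ALIGN_SIZE_M).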
+ + return routed_input, num_tokens_per_expert_group + + @staticmethod + def _partition_fn(name, mod, device_mesh): + # shard on the expert dimension + for name, param in mod.named_parameters(recurse=False): + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) + mod.register_parameter(name, dist_param) + + # performing all-to-all combine on the output + def _token_combine(self, mod, routed_output, device_mesh): + routed_output = all_to_all_single_autograd( + routed_output, + self.input_splits, + self.output_splits, + device_mesh.get_group(), + ) + return routed_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=ExpertParallel._partition_fn, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +# This class is for dp2ep with TP (without TP we can just use ExpertParallel) +class ExpertTensorParallel(ExpertParallel): + def __init__( + self, + tp_mesh: DeviceMesh, + ep_mesh: DeviceMesh, + ): + super().__init__() + # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh, + # as DeviceMesh doesn't support slicing from a submesh. + self.tp_mesh = tp_mesh + self.ep_mesh = ep_mesh + + def _token_dispatch(self, mod, inputs, device_mesh): + # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_dispatch(mod, inputs, self.ep_mesh) + + def _partition_fn_2d(self, name, mod, ep_tp_mesh): + mod.register_parameter( + "w1", + nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + mod.register_parameter( + "w2", + nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(1)])), + ) # Row-wise sharding + mod.register_parameter( + "w3", + nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + + def _token_combine(self, mod, routed_output, device_mesh): + # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_combine(mod, routed_output, self.ep_mesh) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn_2d, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +def expert_parallel(func: Callable) -> Callable: + """ + This is a wrapper applied to the GroupedExperts computation, serving + the following three purposes: + 1. Convert parameters from DTensors to plain Tensors, to work with + dynamic-shape inputs which cannot be easily expressed as DTensors. + 2. In Expert Parallel, apply the generate_permute_indices kernel to + permute the inputs to be ordered by local experts (see the _token_dispatch + function in ExpertParallel) and permute the outputs back. + 3. In order to use torch._grouped_mm, we need to make sure the number of + tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices + kernel also helps achieve this via padding, without incurring synchronization + between device and host. Note that this will create side effects when wrapping + the for-loop implementation of GroupedExperts, as it does not need padding. + + Among the above: + 1 and 2 are needed only when expert_parallel_degree > 1. + 3 is needed even for single-device computation. + 2 can be moved to ExpertParallel _token_dispatch if not coupled with 3. 
+ """ + + def wrapper( + self, + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(w1, DTensor): + w1 = w1.to_local() + w2 = w2.to_local() + w3 = w3.to_local() + + if num_tokens_per_expert is not None: + from .moe_indices import ( + generate_permute_indices, + ) + + experts_per_ep_rank = w1.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + ALIGN_SIZE_M = 16 + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, # offsets, + ) = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, + ALIGN_SIZE_M, + ) + + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] + + out = func(self, w1, w2, w3, x, num_tokens_per_expert) + + if num_tokens_per_expert is not None: + out_unpermuted = out.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = out + out = out_unpermuted[:-1] + + return out + + return wrapper diff --git a/maester/models/deepseek/model.py b/maester/models/deepseek/model.py new file mode 100644 index 0000000..0451d29 --- /dev/null +++ b/maester/models/deepseek/model.py @@ -0,0 +1,391 @@ +import math +import torch +import torch.distributed as dist + +import torch.distributed._symmetric_memory as symm_mem +import torch.nn.functional as f +import torch.utils.checkpoint + +from typing import Optional, Tuple +from torch import nn + +from cut_cross_entropy import linear_cross_entropy, LinearCrossEntropyImpl +from .args import DeepSeekModelArgs +from .moe import FeedForward, MoE +from .attention import build_attention, init_attention_mask + +# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 +def precompute_freqs_cis(args: DeepSeekModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (DeepSeekModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + dim = args.qk_rope_head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + + def find_correction_dim( + num_rotations: float, dim: int, base: float, max_seq_len: int + ) -> float: + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range( + low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int + ) -> Tuple[int, int]: + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. 
+ max_seq_len (int): Maximum sequence length. + + Returns: + Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. + """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Basic RoPE frequency calculation + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training. + if seqlen > args.original_seq_len: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, args.original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + # Create position indices + t = torch.arange(seqlen) + + # Outer product: [positions] × [frequencies] + freqs = torch.outer(t, freqs) + + # Convert to complex exponentials: e^(i*freq*pos) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + +class Attention(nn.Module): + """ + Multi-head attention (MLA) module. 
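+    Queries are produced through an optional low-rank down/up projection (q_lora_rank),
+    keys/values are reconstructed from a compressed kv_lora_rank latent, and the rotary
+    (RoPE) part of the key is shared across heads, following DeepSeek's Multi-head Latent
+    Attention (MLA) design.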
+ """ + + def __init__(self, model_args: DeepSeekModelArgs): + super().__init__() + self.dim = model_args.dim + self.n_heads = model_args.n_heads + self.q_lora_rank = model_args.q_lora_rank + self.kv_lora_rank = model_args.kv_lora_rank + self.qk_nope_head_dim = model_args.qk_nope_head_dim + self.qk_rope_head_dim = model_args.qk_rope_head_dim + self.qk_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim + self.v_head_dim = model_args.v_head_dim + + if self.q_lora_rank == 0: + self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim, bias=False) + else: + self.wq_a = nn.Linear(self.dim, self.q_lora_rank, bias=False) + self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps) + self.wq_b = nn.Linear( + self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False + ) + self.wkv_a = nn.Linear( + self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False + ) + self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps) + self.wkv_b = nn.Linear( + self.kv_lora_rank, + self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim, bias=False) + self.softmax_scale = self.qk_head_dim**-0.5 + + if model_args.max_seq_len > model_args.original_seq_len: + mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0 + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + + # Query projection + if self.q_lora_rank == 0: + q = self.wq(x) # (bsz, seqlen, n_heads * qk_head_dim) + else: + q = self.wq_a(x) + q = self.wq_b(self.q_norm(q)) + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of q and kv as TP may have sharded them after + # the above linear ops. 
+ q = q.view(bsz, seqlen, -1, self.qk_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) + + # Key-value projection + kv = self.wkv_a(x) # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim) + kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pe = apply_rotary_emb( + k_pe.unsqueeze(2), freqs_cis + ) # (bsz, seqlen, 1, qk_rope_head_dim) + + kv = self.wkv_b( + self.kv_norm(kv) + ) # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim)) + kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat( + [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1 + ) # (bsz, seqlen, n_heads, qk_head_dim) + + q = q.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) + + output = self.sdpa(q, k, v, scale=self.softmax_scale) + + # Reshape and project output + output = output.transpose(1, 2) # (bsz, seqlen, n_heads, v_head_dim) + output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) + return self.wo(output) # (bsz, seqlen, dim) + + def init_weights(self, init_std: float): + linear_list = [ + self.wkv_a, + self.wkv_b, + ] + if self.q_lora_rank > 0: + linear_list.extend([self.wq_a, self.wq_b]) + else: + linear_list.append(self.wq) + + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + self.kv_norm.reset_parameters() + if self.q_lora_rank > 0: + self.q_norm.reset_parameters() + + +class TransformerBlock(nn.Module): + """ + Transformer block with attention and feed-forward layers. + """ + + def __init__(self, layer_id: int, model_args: DeepSeekModelArgs): + + super().__init__() + self.attention = Attention(model_args) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.moe_enabled = layer_id >= model_args.n_dense_layers + + if self.moe_enabled: + self.moe = MoE(model_args) + else: + self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim) + + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + self.layer_id = layer_id + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + x = x + self.attention(self.attention_norm(x), freqs_cis) + if self.moe_enabled: + x = x + self.moe(self.ffn_norm(x)) + else: + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + if self.moe_enabled: + self.moe.init_weights(self.weight_init_std, buffer_device) + else: + self.feed_forward.init_weights(self.weight_init_std) + +class DeepSeekModel(nn.Module): + """ + DeepSeek-V3 Transformer model with attention and feed-forward layers. 
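+    Example (illustrative only; any DeepSeekModelArgs config works):
+        >>> from maester.models.deepseek import deepseek_configs
+        >>> model = DeepSeekModel.from_model_args(deepseek_configs["debug"])
+        >>> logits = model(tokens)  # tokens: (batch, seq_len) int64 -> (batch, seq_len, vocab_size)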
+ """ + + def __init__(self, model_args: DeepSeekModelArgs): + super().__init__() + self.max_seq_len = model_args.max_seq_len + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.register_buffer( + "freqs_cis", precompute_freqs_cis(model_args), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = nn.RMSNorm(model_args.dim) + self.output = nn.Linear( + model_args.dim, + model_args.vocab_size, + dtype=torch.get_default_dtype(), + bias=False, + ) + self.model_args = model_args + self.init_weights() + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = precompute_freqs_cis(self.model_args) + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def forward( + self, + tokens: torch.Tensor, + eos_id: int | None = None, + input_batch: torch.Tensor | None = None, + ): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + input_batch (torch.Tensor): The input batch read from the dataloader. + This will always be the input batch regardless of the pipeline stage. + This field is required for non-first PP stages to perform document + masking attention (to analyze the boundary of the document). + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). + """ + if self.model_args.use_flex_attn: + init_attention_mask( + input_batch if input_batch is not None else tokens, eos_id=eos_id + ) + + h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + h = self.norm(h) if self.norm is not None else h + output = self.output(h) if self.output is not None else h + return output + + @classmethod + def from_model_args(cls, model_args: DeepSeekModelArgs) -> "DeepSeekModel": + """Initialize from model args (compatible with training loop).""" + return cls(model_args) \ No newline at end of file diff --git a/maester/models/deepseek/moe.py b/maester/models/deepseek/moe.py new file mode 100644 index 0000000..2730bb5 --- /dev/null +++ b/maester/models/deepseek/moe.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch import nn +from .expert_parallel import expert_parallel + +from .args import DeepSeekModelArgs + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. 
+ hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float = 0.02): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class GroupedExperts(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + use_grouped_mm: bool, + ): + super().__init__() + self.num_experts = num_experts + self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) + self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.use_grouped_mm = use_grouped_mm + + def forward( + self, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.use_grouped_mm: + return self._run_experts_grouped_mm( + self.w1, self.w2, self.w3, x, num_tokens_per_expert + ) + else: + return self._run_experts_for_loop( + self.w1, self.w2, self.w3, x, num_tokens_per_expert + ) + + # TODO: keeping this for-loop implementation for comparison + # and readability, may remove later + @expert_parallel + def _run_experts_for_loop( + self, + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = F.silu(torch.matmul(x_expert, w1[expert_idx])) + h = h * torch.matmul(x_expert, w3[expert_idx]) + h = torch.matmul(h, w2[expert_idx]) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, w1)) + h = h * torch.bmm(x, w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, w2) + + return out + + @expert_parallel + def _run_experts_grouped_mm( + self, + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = 
torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets)) + h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets) + out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x) + + return out + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) + + +class TokenChoiceTopKRouter(nn.Module): + """This class implements token-choice routing. In token-choice top-K routing, each token is + routed to top K experts based on the router scores. + + Args: + gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). + num_experts (int): Number of experts in each moe layer. + top_k (int): Number of experts each token will be routed to in token-choice routing. + use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. + """ + + def __init__( + self, + dim: int, + num_experts: int, + top_k: int, + use_sigmoid: bool = False, + route_scaling_factor: float = 1.0, + ): + super().__init__() + + self.dim = dim + self.num_experts = num_experts + self.top_k = top_k + self.use_sigmoid = use_sigmoid + self.route_scaling_factor = route_scaling_factor + self.gate = nn.Linear(self.dim, self.num_experts, bias=False) + + def forward( + self, x: torch.Tensor, expert_bias: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + TODO: We haven't implement the group-based routing (node limit routing), + and currently EP is not supporting node limit routing yet. + + Args: + x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. + + Returns: + routed_input (torch.Tensor): + Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. + token_indices (torch.Tensor): + Token indices for routed_input with shape ``(bs*slen*top_k,)``. + num_tokens_per_expert (torch.Tensor): + Number of tokens assigned to each expert with shape ``(num_experts,)``. + """ + # scores shape (bs*slen, num_experts) + scores = self.gate(x) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if self.use_sigmoid: + scores = torch.sigmoid(scores.to(torch.float32)) + else: + scores = F.softmax(scores.to(torch.float32), dim=1) + + # top scores shape (bs*slen, top_k) + # NOTE: The expert_bias is only used for routing. The gating value + # top_scores is still derived from the original scores. 
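+        # e.g. if scores = [0.1, 0.5, 0.4], expert_bias = [0.6, 0.0, 0.0] and top_k = 1,
+        # expert 0 wins the selection (0.7 > 0.5) but its gating weight stays at 0.1.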
+ if expert_bias is not None: + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk( + scores, k=self.top_k, dim=1 + ) + + if self.use_sigmoid: + denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 + top_scores = top_scores / denominator + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + # Reorder the token indices to match the order of the experts + # token_indices_experts_sorted shape (bs*slen*top_k,) + token_indices_experts_sorted = torch.argsort( + selected_experts_indices.view(-1), stable=True + ) + + # reorder the scores to match the order of the token indices + top_scores = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + + top_scores = ( + top_scores * self.route_scaling_factor + ) # must multiply the scaling factor + return top_scores, token_indices_experts_sorted, num_tokens_per_expert + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + + +class MoE(nn.Module): + def __init__(self, model_args: DeepSeekModelArgs): + + super().__init__() + dim = model_args.dim + + num_experts = model_args.n_routed_experts + hidden_dim = model_args.moe_inter_dim + top_k = model_args.n_activated_experts + route_scaling_factor = model_args.route_scale + + self.experts = GroupedExperts( + dim=dim, + hidden_dim=hidden_dim, + num_experts=num_experts, + use_grouped_mm=model_args.use_grouped_mm, + ) + self.router = TokenChoiceTopKRouter( + dim=dim, + num_experts=num_experts, + top_k=top_k, + use_sigmoid=model_args.score_func == "sigmoid", + route_scaling_factor=route_scaling_factor, + ) + self.shared_expert = ( + # Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py#L517 + GroupedExperts( + dim=dim, + hidden_dim=hidden_dim * model_args.n_shared_experts, + num_experts=1, # Here needs to be 1 to make it equivalent to the MLP + use_grouped_mm=model_args.use_grouped_mm, + ) + if model_args.n_shared_experts > 0 + else None + ) + + # auxiliary-loss-free load balancing + self.load_balance_coeff = model_args.load_balance_coeff + if self.load_balance_coeff is not None: + assert self.load_balance_coeff > 0.0 + self.register_buffer( + "expert_bias", + torch.zeros(num_experts, dtype=torch.float32), + persistent=False + ) + self.register_buffer( + "tokens_per_expert", + torch.zeros(num_experts, dtype=torch.float32), + persistent=False + ) + else: + self.expert_bias = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. + """ + bs, slen, dim = x.shape + + # top_scores and selected_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_tokens_per_expert, + ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + + # tokens_per_expert will be used to update the expert bias for load balancing. + # Prevent extra local tokens accumulation on evaluation or activation recomputation. 
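+        # (The accumulated counts are intended to be consumed outside this module, e.g. by
+        # the optimizer step, to adjust expert_bias toward a balanced load in the spirit of
+        # the auxiliary-loss-free balancing noted in __init__; only counting happens here.)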
+ if self.load_balance_coeff is not None and torch.is_grad_enabled(): + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + # shape (bs*slen*top_k, dim) + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + + # shape (bs*slen*top_k, dim) + routed_output = self.experts(routed_input, num_tokens_per_expert) + + routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( + x.dtype + ) + + # shared expert + if self.shared_expert is not None: + out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( + bs * slen, dim + ) + else: + out = torch.zeros_like(x.reshape(bs * slen, dim)) + + # Accumulate multiple expert results becase each token can be routed to multiple experts + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.reshape(bs, slen, dim) + return out + + def init_weights( + self, + init_std: float, + buffer_device: torch.device, + ): + self.experts.init_weights(init_std) + self.router.init_weights(init_std) + if self.shared_expert is not None: + self.shared_expert.init_weights(init_std) + + if self.load_balance_coeff is not None: + with torch.device(buffer_device): + self.expert_bias = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + self.tokens_per_expert = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) diff --git a/maester/models/deepseek/moe_indices.py b/maester/models/deepseek/moe_indices.py new file mode 100644 index 0000000..30f7d98 --- /dev/null +++ b/maester/models/deepseek/moe_indices.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
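+# Token-permutation indices for MoE grouped matmuls: a parallel Triton kernel plus a CPU
+# reference implementation that reorder tokens to be contiguous per local expert and pad
+# each expert's chunk to an alignment boundary. See generate_permute_indices below.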
+ +import torch +import triton +import triton.language as tl + + +__all__ = ["generate_permute_indices"] + + +# parallelized kernel +@triton.jit +def _fill_indices_kernel( + tokens_per_expert_group_ptr, + start_index_values_ptr, + write_offsets_ptr, + output_ptr, + experts_per_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # Number of threads per block +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # map programs (blocks) to the experts and loop (grid stride) if needed + for expert_id in range(pid, experts_per_rank, num_programs): + # read this experts write offset + write_offset = tl.load(write_offsets_ptr + expert_id) + + for r in range(num_ranks): + # index into tokens_per_expert_group array + i = r * experts_per_rank + expert_id + + # load start index and number of tokens for this expert-rank pair + start_index = tl.load(start_index_values_ptr + i) + length = tl.load(tokens_per_expert_group_ptr + i) + + # each thread in block processes tokens in parallel + offsets = tl.arange(0, BLOCK_SIZE) + + # tokens are processed in chunks of BLOCK_SIZE + for chunk_start in range(0, length, BLOCK_SIZE): + chunk_offsets = chunk_start + offsets + + # mask valid indices + mask = chunk_offsets < length + + values = start_index + chunk_offsets + + # destination + dest_indices = write_offset + chunk_offsets + + # store + tl.store(output_ptr + dest_indices, values, mask=mask) + + # update write offset for next rank + write_offset += length + + +# ============== +# wrapper +# ============== + + +def fill_indices_wrapper( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + block_size: int = 128, + max_blocks: int = 1024, # cap on total number of blocks to launch +): + # preallocate output + permuted_indices = torch.full( + (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device + ) + + # write offsets is per local expert... + num_blocks = min(experts_per_rank, max_blocks) + # grid = one block per expert unless capped and then we loop... 
+ grid = (num_blocks,) + + # launch kernel + _fill_indices_kernel[grid]( + tokens_per_expert_group, + start_index_values, + write_offsets, + permuted_indices, + experts_per_rank, + num_ranks, + BLOCK_SIZE=block_size, + ) + return permuted_indices + + +# reference +def fill_indices_cpu( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, +): + # We need to preallocate the output - we ignore device and force it on cpu + # device = tokens_per_expert_group.device + permuted_indices = torch.full( + (max_len,), + -1, + dtype=torch.int32, + ) # device=device) + # Fill the permuted indices + # For each local expert + for e in range(experts_per_rank): + write_start = write_offsets[e].item() + # For each remote rank + for r in range(num_ranks): + i = r * experts_per_rank + e + start_index = start_index_values[i].item() + length = tokens_per_expert_group[i].item() + # Fill in the indices + if length > 0: + end_idx = min(write_start + length, max_len) + permuted_indices[write_start:end_idx] = torch.arange( + start_index, + start_index + (end_idx - write_start), + dtype=torch.int32, + # device=device, + ) + write_start += length + return permuted_indices + + +def generate_permute_indices( + tokens_per_expert_group: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + alignment: int, + use_cpu: bool = False, +): + """ + Prepare permutation indices and the number of tokens for each expert. + + Args: + tokens_per_expert_group: number of tokens for each expert from all ranks. + experts_per_rank: number of experts per rank. + num_ranks: number of ranks. + max_len: maximum length of the output index vector. + alignment: alignment for each returned element in `m_sizes` and padding min for zero token experts. + use_cpu: whether to use CPU implementation. + + + Returns: + permuted_indices: Tensor of indices that map original token order to the expert-grouped order. + m_sizes: aligned number of tokens for each expert (padded to alignment boundary). + m_offsets: Cumulative sum of m_sizes. The exclusive ending position for each expert's tokens. + + Explanatory details: + `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example: + From: | rank 0 | rank 1 | + To: | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 | + | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 | + """ + + # prefix sum to get start index of each expert (parallel scan kernel in future?) 
+ start_index_values = ( + torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group + ) + + # total tokens for each expert (sum over ranks) + total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) + + # pad out empty experts to alignment requirement + total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment) + + # align the chunk sizes (cdiv) + m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to( + torch.int32 + ) + + # additional prefix sum to get write offset of each expert in permuted_indices + # write offsets is per local expert, not global + m_offsets = torch.cumsum(m_sizes, 0) + write_offsets = m_offsets - m_sizes + + # Select the implementation to use + if use_cpu: + permuted_indices = fill_indices_cpu( + tokens_per_expert_group, + start_index_values, + write_offsets, + experts_per_rank, + num_ranks, + max_len, + ) + else: + permuted_indices = fill_indices_wrapper( + tokens_per_expert_group, + start_index_values, + write_offsets, + experts_per_rank, + num_ranks, + max_len, + ) + + return permuted_indices, m_sizes, m_offsets.to(torch.int32) + + +# Below is for testing only + + +def simple_test(): + device = torch.device("cuda", 0) + experts_per_rank = 4 + num_ranks = 4 + tokens_per_expert_group = torch.full( + (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device + ) + max_len = 128 + alignment = 32 + # Use the GPU kernel + permuted_indices_gpu, m_sizes, _ = generate_permute_indices( + tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment + ) + # Use the CPU method + permuted_indices_cpu, m_sizes, _ = generate_permute_indices( + tokens_per_expert_group, + experts_per_rank, + num_ranks, + max_len, + alignment, + use_cpu=True, + ) + # Check that the results are the same + + assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu) + assert torch.equal( + torch.remainder(m_sizes, alignment), + torch.zeros(experts_per_rank, device=device), + ) + # Print the results + print(f"{permuted_indices_gpu=}, \n{permuted_indices_cpu=}") + print(f"{m_sizes=}") + print("Success") + return True # assert would have failed meaning getting here is success. 
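+
+
+# Illustrative sketch, not part of the original test suite: a worked example of
+# generate_permute_indices on the rank-major layout from its docstring. The helper
+# name and the parameter choices (alignment=4, max_len=32, use_cpu=True) are chosen
+# here purely for illustration.
+def _docstring_example():
+    # tokens per (rank, expert) pair, laid out rank-major:
+    # | rank 0     | rank 1     |
+    # | 4  2  1  3 | 1  2  3  4 |
+    tokens_per_expert_group = torch.tensor([4, 2, 1, 3, 1, 2, 3, 4], dtype=torch.int32)
+    permuted_indices, m_sizes, m_offsets = generate_permute_indices(
+        tokens_per_expert_group,
+        experts_per_rank=4,
+        num_ranks=2,
+        max_len=32,
+        alignment=4,
+        use_cpu=True,
+    )
+    # Per-expert totals are [5, 4, 4, 7], so with alignment=4:
+    #   m_sizes   == [8, 4, 4, 8]
+    #   m_offsets == [8, 12, 16, 24]
+    # Expert 0's aligned slot holds its rank-0 token positions [0, 1, 2, 3], then its
+    # rank-1 token position [10], padded with -1 up to its aligned size of 8:
+    #   permuted_indices[:8] == [0, 1, 2, 3, 10, -1, -1, -1]
+    print(permuted_indices[: m_offsets[-1].item()], m_sizes, m_offsets)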
+ + +def test_with_zero_tokens(): + device = torch.device("cuda", 0) + experts_per_rank = 4 + num_ranks = 2 + + # Create a test case where some experts have zero tokens + tokens_per_expert_group = torch.tensor( + [4, 0, 2, 3, 1, 0, 0, 5], # Some experts have zero tokens + dtype=torch.int32, + device=device, + ) + + max_len = 128 + alignment = 8 + + # Use the GPU kernel + permuted_indices_gpu, m_sizes, m_offsets = generate_permute_indices( + tokens_per_expert_group, + experts_per_rank, + num_ranks, + max_len, + alignment, + ) + + # Use the CPU method + permuted_indices_cpu, m_sizes_cpu, m_offsets_cpu = generate_permute_indices( + tokens_per_expert_group, + experts_per_rank, + num_ranks, + max_len, + alignment, + use_cpu=True, + ) + + # Check that the results are the same + assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu) + assert torch.equal(m_sizes, m_sizes_cpu) + + # Verify that experts with zero tokens have at least min_slots_per_expert + total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) + zero_token_experts = total_tokens_per_expert == 0 + if zero_token_experts.any(): + assert (m_sizes[zero_token_experts] >= alignment).all() + + # Check alignment + assert torch.equal( + torch.remainder(m_sizes, alignment), + torch.zeros(experts_per_rank, device=device), + ) + + # Print the results + print(f"tokens_per_expert_group = {tokens_per_expert_group}") + print(f"total_tokens_per_expert = {total_tokens_per_expert}") + print(f"m_sizes = {m_sizes}") + print(f"m_offsets = {m_offsets}") + print(f"permuted_indices = {permuted_indices_gpu[:sum(m_sizes).item()]}") + + # Check that experts with zero tokens have -1 in their slots + for e in range(experts_per_rank): + start = (m_offsets[e] - m_sizes[e]).item() + end = m_offsets[e].item() + expert_indices = permuted_indices_gpu[start:end] + if total_tokens_per_expert[e] == 0: + assert ( + expert_indices == -1 + ).all(), f"Expert {e} with zero tokens should have all -1 indices" + assert ( + expert_indices.size(0) >= alignment + ), f"Expert {e} with zero tokens should have at least {alignment} slots" + print( + f"Expert {e} has zero tokens and {expert_indices.size(0)} slots with all -1" + ) + + print("All tests passed successfully!") + return True + + +if __name__ == "__main__": + simple_test() + test_with_zero_tokens() diff --git a/maester/models/deepseek/optimizer.py b/maester/models/deepseek/optimizer.py new file mode 100644 index 0000000..79f4be1 --- /dev/null +++ b/maester/models/deepseek/optimizer.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from maester.config import Config +from maester.parallelisms import ParallelDims +from maester.optimizers import OptimizersContainer, build_optimizers + + +def build_deepseek_optimizers(model: nn.Module, cfg: Config, parallel_dims: ParallelDims) -> OptimizersContainer: + """Build optimizers for DeepSeek model with MoE load balancing hooks. + + This adds auxiliary-loss-free load balancing for MoE layers as described + in the DeepSeek-V3 paper. 
+ + Args: + model: DeepSeek model instance + cfg: Training configuration + parallel_dims: Parallelism dimensions + + Returns: + OptimizersContainer with MoE hooks registered + """ + # Start with base optimizer configuration + optimizers = build_optimizers(model, cfg, parallel_dims) + + # TODO: This is a temporary solution for storing optimizer hook statistics. + # Consider refactoring OptimizersContainer to have official metrics support. + optimizers._hook_stats = {} + + # Define expert bias update function + def _update_expert_bias(model, parallel_dims, hook_stats): + # Note: We don't support context parallelism (cp) yet, so just use dp_enabled + dp_cp_mesh = parallel_dims.world_mesh["dp"] if parallel_dims.dp_enabled else None + + # Clear previous stats + hook_stats.clear() + + # Iterate through model layers + for layer_name, layer in model.layers.items(): + if hasattr(layer, 'moe_enabled') and layer.moe_enabled: + moe = layer.moe + if hasattr(moe, 'load_balance_coeff') and moe.load_balance_coeff is not None: + # Sync tokens_per_expert across data parallel ranks if needed + if dp_cp_mesh is not None: + torch.distributed.all_reduce( + moe.tokens_per_expert, group=dp_cp_mesh.get_group() + ) + + # Collect expert balancing statistics before update + with torch.no_grad(): + tokens_mean = moe.tokens_per_expert.mean().item() + + if tokens_mean > 0: + # Max load ratio: how overloaded is the busiest expert + max_load_ratio = moe.tokens_per_expert.max().item() / tokens_mean + # Min load ratio: how underutilized is the least used expert + min_load_ratio = moe.tokens_per_expert.min().item() / tokens_mean + else: + max_load_ratio = 1.0 + min_load_ratio = 1.0 + + # Store stats for logging - extract layer number from name + layer_idx = layer_name.split('.')[-1] if '.' 
in layer_name else layer_name + hook_stats[f"expert_balance/layer{layer_idx}/max_load_ratio"] = max_load_ratio + hook_stats[f"expert_balance/layer{layer_idx}/min_load_ratio"] = min_load_ratio + + # Update expert bias for load balancing + with torch.no_grad(): + expert_bias_delta = moe.load_balance_coeff * torch.sign( + moe.tokens_per_expert.mean() - moe.tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + moe.expert_bias.add_(expert_bias_delta) + moe.tokens_per_expert.zero_() + + # Register the hook to run before optimizer steps using lambda to capture context + optimizers.register_step_pre_hook( + lambda *args, **kwargs: _update_expert_bias( + model, parallel_dims, optimizers._hook_stats + ) + ) + + return optimizers \ No newline at end of file diff --git a/maester/parallelisms/__init__.py b/maester/parallelisms/__init__.py index 8dd5d05..0ffaf88 100644 --- a/maester/parallelisms/__init__.py +++ b/maester/parallelisms/__init__.py @@ -10,4 +10,5 @@ from maester.parallelisms.parallel_dims import ParallelDims from maester.parallelisms.parallelize_llama import parallelize_llama from maester.parallelisms.parallelize_gemma import parallelize_gemma +from maester.parallelisms.parallelize_deepseek import parallelize_deepseek diff --git a/maester/parallelisms/parallel_dims.py b/maester/parallelisms/parallel_dims.py index 058108a..f971a2d 100644 --- a/maester/parallelisms/parallel_dims.py +++ b/maester/parallelisms/parallel_dims.py @@ -1,9 +1,10 @@ from dataclasses import dataclass from functools import cached_property -from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.device_mesh import init_device_mesh, DeviceMesh from maester.log_utils import logger +from maester.utils import device_type @dataclass @@ -11,15 +12,18 @@ class ParallelDims: dp_replicate: int dp_shard: int tp: int + # cp: int # TODO: implement context parallelism + ep: int world_size: int enable_loss_parallel: bool def __post_init__(self): self._validate() + self._world_mesh = None # Lazy initialization def _validate(self): - dp_replicate, dp_shard, tp = self.dp_replicate, self.dp_shard, self.tp - for d in (dp_replicate, tp): + dp_replicate, dp_shard, tp, ep = self.dp_replicate, self.dp_shard, self.tp, self.ep + for d in (dp_replicate, tp, ep): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." 
@@ -36,7 +40,18 @@ def _validate(self): f"tp({tp}) != WORLD_SIZE({self.world_size})" ) - def build_mesh(self, device_type): + if ep > 1: + #assert ep % cp == 0 and (dp_shard * cp) % ep == 0 + assert dp_shard % ep == 0 + + def build_mesh(self): + if self.ep > 1: + return self._build_mesh_with_ep() + else: + return self._build_mesh_without_ep() + + def _build_mesh_without_ep(self) -> DeviceMesh: + # TODO: this might be wrong, investigate dims = [] names = [] for d, name in zip( @@ -45,22 +60,91 @@ def build_mesh(self, device_type): ): if d > 1: dims.append(d) - if (name == "dp_replicate" and self.dp_shard == 1) or ( - name == "dp_shard" and self.dp_replicate == 1 - ): - names.append("dp") - else: - names.append(name) + names.append(name) if dims == []: # edge case for non-distributed mesh w/ 1 GPU dims = [1] names = ("dp",) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - names = tuple(names) mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + # Create all the submesh here to ensure all required process groups are - # initialized - if self.dp_replicate > 1 and self.dp_shard > 1: - mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp") + # initialized: + # Mesh for data loading (no communication on this mesh) + dp_mesh_dim_names = [] + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + if self.dp_shard_enabled: + dp_mesh_dim_names.append("dp_shard") + + if dp_mesh_dim_names != []: + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + + return mesh + + def _build_mesh_with_ep(self) -> DeviceMesh: + # With ep, dp_shard and ep are derived submeshes: + # dp_shard = dp_shard_mod_ep * dp_shard_in_ep + # ep = dp_shard_in_ep * cp + # NOTE: cp not implemented + dp_shard_mod_ep = self.dp_shard // self.ep + dp_shard_in_ep = self.ep + + dims = [] + names = [] + for d, name in zip( + [ + self.dp_replicate, + dp_shard_mod_ep, + dp_shard_in_ep, + self.tp, + ], + ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "tp"], + ): + # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping + # helps the MoE layers do mixed precision training + if d > 1 or name == "dp_shard_mod_ep": + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + # Create all the submesh here to ensure all required process groups are + # initialized: + # Mesh for data loading (no communication on this mesh) + dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] + # Mesh for ep + ep_mesh_dim_names = [] + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") + # dp_shard_mod_ep is always needed, even if it's 1 + dp_mesh_dim_names.append("dp_shard_mod_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") + dp_cp_mesh_dim_names.append("dp_shard_mod_ep") + if "dp_shard_in_ep" in names: + dp_mesh_dim_names.append("dp_shard_in_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") + dp_cp_mesh_dim_names.append("dp_shard_in_ep") + ep_mesh_dim_names.append("dp_shard_in_ep") + # if self.cp_enabled: + # dp_shard_cp_mesh_dim_names.append("cp") + # dp_cp_mesh_dim_names.append("cp") + # ep_mesh_dim_names.append("cp") + + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard") + 
mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") + + logger.info(f"Built EP device mesh: {mesh}") + return mesh @property @@ -79,10 +163,31 @@ def dp_shard_enabled(self): def tp_enabled(self): return self.tp > 1 + @property + def ep_enabled(self): + return self.ep > 1 + + @property + def fsdp_enabled(self): + return self.dp_shard_enabled + @property def loss_parallel_enabled(self): return self.tp > 1 and self.enable_loss_parallel + + @property + def world_mesh(self) -> DeviceMesh: + # doing late init so ParallelDims can still be used as a lightweight + # dataclass without having to initialize the world mesh + if self._world_mesh is None: + self._world_mesh = self.build_mesh() + return self._world_mesh @cached_property def model_parallel_size(self): - return self.tp \ No newline at end of file + return self.tp + + @cached_property + def dense_params_mesh_ndim(self): + # Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh + return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled \ No newline at end of file diff --git a/maester/parallelisms/parallelize_deepseek.py b/maester/parallelisms/parallelize_deepseek.py new file mode 100644 index 0000000..f17c549 --- /dev/null +++ b/maester/parallelisms/parallelize_deepseek.py @@ -0,0 +1,324 @@ +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import ( + MixedPrecisionPolicy, + fully_shard +) +from torch.distributed._composable.replicate import replicate +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as ptd_checkpoint_wrapper +) +from torch.distributed.tensor import Partial, Replicate, Shard +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + parallelize_module, + PrepareModuleInputOutput +) +from typing import Optional +from collections import defaultdict + +from maester.config import Config, TORCH_DTYPE_MAP +from maester.parallelisms.parallel_dims import ParallelDims +from maester.log_utils import logger +from maester.models.deepseek.expert_parallel import ExpertParallel, ExpertTensorParallel, TensorParallel, NoParallel + +def parallelize_deepseek( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + config: Config, +): + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled and parallel_dims.ep_enabled + else None + ), + ) + + if config.ac_mode != "none": + apply_ac(model, config) + + if config.compile: + if parallel_dims.ep_enabled: + logger.warning("Compiling MoE layers is broken") + apply_compile(model) + else: + apply_compile(model, fullgraph=False) + + dp_mesh: DeviceMesh | None = None + if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard") + else: + dp_mesh_dim_names = ("dp_shard",) + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + dp_mod_ep_mesh_dim_names = [] + if parallel_dims.ep_enabled: + if 
parallel_dims.dp_replicate_enabled: + dp_mod_ep_mesh_dim_names.append("dp_replicate") + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + + apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce], + cpu_offload=config.enable_cpu_offload, + reshard_after_forward_policy=config.fsdp_reshard_after_forward, + dp_mod_ep_mesh=( + world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + if dp_mod_ep_mesh_dim_names + else None + ), + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if config.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, +} + + +def _apply_ac_to_transformer_block( + module: nn.Module, ac_config: Config, *, base_fqn: Optional[str] = None +): + valid_ac_modes = ("full", "selective") + if ac_config.ac_mode not in valid_ac_modes: + raise ValueError( + f"Invalid AC mode: {ac_config.ac_mode}. Valid modes: {valid_ac_modes}" + ) + + if ac_config.ac_mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.ac_mode == "selective", f"{ac_config.ac_mode}" + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + if not use_op_sac and not use_layer_sac: + raise ValueError( + f"Invalid selective AC option: {ac_config.selective_ac_option}. 
" + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, + ) + + mm_recompute_shapes = set() + if False:#len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: + for module_fqn, submod in module.named_modules(): + fqn = module_fqn + if base_fqn is not None: + fqn = f"{base_fqn}.{module_fqn}" + if not any( + filter_fqn in fqn + for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns + ): + continue + if not isinstance(submod, nn.Linear): + raise ValueError( + "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match " + f"a nn.Linear, but got: {submod}" + ) + out_f, in_f = submod.weight.shape + mm_recompute_shapes.add((in_f, out_f)) + logger.debug( + f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" + ) + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + if args[1].shape in mm_recompute_shapes: + return CheckpointPolicy.PREFER_RECOMPUTE + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + + +def apply_ac(model: nn.Module, ac_config: Config): + """Apply activation checkpointing to the model.""" + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = _apply_ac_to_transformer_block( + transformer_block, ac_config, base_fqn=f"layers.{layer_id}" + ) + model.layers.register_module(layer_id, transformer_block) + + logger.info(f"Applied {ac_config.ac_mode} activation checkpointing to the model") + +def apply_compile(model: nn.Module, fullgraph: bool = False): + """Compile each transformer layer individually.""" + for layer_id, layer in model.layers.items(): + compiled_layer = torch.compile(layer, fullgraph=fullgraph) + model.layers[layer_id] = compiled_layer + logger.info("Compiled each transformer layer with torch.compile") + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", + dp_mod_ep_mesh: DeviceMesh | None = None, +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. 
+ reduce_dtype (torch.dtype): The data type to use for reduction operations. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + for layer_id, transformer_block in model.layers.items(): + if reshard_after_forward_policy == "always": + reshard_after_forward = True + elif reshard_after_forward_policy == "never": + reshard_after_forward = False + elif reshard_after_forward_policy == "default": + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < len(model.layers) - 1 + else: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + # NOTE: in an MoE layer, the router and the shared experts + # are sharded together with the TransformerBlock + if transformer_block.moe_enabled and dp_mod_ep_mesh: + fsdp_mod_ep_config = fsdp_config.copy() + fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + fully_shard( + transformer_block.moe.experts, + **fsdp_mod_ep_config, + reshard_after_forward=reshard_after_forward, + ) + + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + fully_shard(model, **fsdp_config, reshard_after_forward=True) + +def apply_moe_ep_tp( + model: nn.Module, + tp_mesh: DeviceMesh | None, + ep_mesh: DeviceMesh | None, + ep_tp_mesh: DeviceMesh | None, +): + for transformer_block in model.layers.values(): + if not transformer_block.moe_enabled: + continue + + if tp_mesh is not None: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "moe.router.gate": NoParallel(), + # input Replicate, output Partial + "moe.shared_expert": TensorParallel(), + } + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) + + # if ep_mesh is not None: + experts_mesh, experts_plan = None, None + if ep_mesh is None: + experts_mesh = tp_mesh + # input Replicate, output Partial + experts_plan = TensorParallel() + elif tp_mesh is None: + experts_mesh = ep_mesh + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() + else: + experts_mesh = ep_tp_mesh + experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) + parallelize_module( + module=transformer_block.moe.experts, + device_mesh=experts_mesh, + parallelize_plan=experts_plan, + ) \ No newline at end of file diff --git a/maester/profiling.py b/maester/profiling.py index f3d8b34..c9f9eeb 100644 --- a/maester/profiling.py +++ b/maester/profiling.py @@ -71,7 +71,7 @@ def 
trace_handler(prof): torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], - schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=3), on_trace_ready=trace_handler, ) as torch_profiler: yield torch_profiler diff --git a/maester/utils.py b/maester/utils.py index 74a66d2..830f33e 100644 --- a/maester/utils.py +++ b/maester/utils.py @@ -21,10 +21,18 @@ _has_foreach_support, ) from torch.nn.utils.clip_grad import _clip_grads_with_norm_ +from torch._utils import _get_available_device_type, _get_device_module from maester.log_utils import logger from maester.models.gemma.model import Embedding as GemmaEmbedding +def get_device_info() -> tuple[str, torch.device]: + device_type = _get_available_device_type() or "cuda" + device_module = _get_device_module(device_type) # default device_module:torch.cuda + return device_type, device_module + +device_type, device_module = get_device_info() + def dist_max(x: int | float, mesh: DeviceMesh) -> float: tensor = torch.tensor(x).cuda() return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh) @@ -301,4 +309,10 @@ def clean_param_name(name: str) -> str: elif 'output' in name: return f"output/{'/'.join(parts[-2:])}" else: - return f"other/{'/'.join(parts[-2:])}" \ No newline at end of file + return f"other/{'/'.join(parts[-2:])}" + +def has_cuda_capability(major: int, minor: int) -> bool: + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + major, + minor, + ) \ No newline at end of file diff --git a/scripts/convert/deepseek/from_dcp.py b/scripts/convert/deepseek/from_dcp.py new file mode 100644 index 0000000..a0dbcaf --- /dev/null +++ b/scripts/convert/deepseek/from_dcp.py @@ -0,0 +1,458 @@ +""" +Convert trained DeepSeek models from DCP format back to HuggingFace format. + +This script handles the reverse conversion after training, converting DCP +checkpoint weights back to HuggingFace DeepSeek V2 format. 
+
+Usage:
+    python scripts/convert/deepseek/from_dcp.py checkpoints/step_1000 output_hf
+
+The script automatically:
+- Maps DCP weight names back to HuggingFace format
+- Handles MLA components correctly
+- Converts MoE layers with proper expert indexing
+- Creates properly sharded safetensors files if the model is large
+"""
+
+import argparse
+import json
+from pathlib import Path
+from typing import Dict, Any, Optional
+
+import torch
+import torch.distributed.checkpoint as DCP
+from safetensors.torch import save_file
+from torch.distributed.checkpoint import FileSystemReader
+from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE, TensorStorageMetadata
+from torch.distributed.checkpoint._traverse import set_element
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+
+# Import our model configs
+from maester.models.deepseek import deepseek_configs
+
+
+def ungroup_expert_weights(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """Ungroup expert weights from grouped tensors back to individual tensors."""
+    new_state_dict = {}
+
+    for key, value in state_dict.items():
+        # Handle grouped expert weights
+        if "moe.experts.w" in key and value.dim() == 3:
+            # This is a grouped tensor [num_experts, in_features, out_features]
+            # Split it into individual experts
+            parts = key.split(".")
+            layer_idx = parts[1]
+            weight_type = parts[4]  # w1, w2, or w3
+
+            num_experts = value.shape[0]
+            for expert_idx in range(num_experts):
+                # Transpose back from [in_features, out_features] to [out_features, in_features]
+                expert_weight = value[expert_idx].t().contiguous()
+                new_key = f"layers.{layer_idx}.moe.experts.{expert_idx}.{weight_type}.weight"
+                new_state_dict[new_key] = expert_weight
+            print(f"Ungrouped {num_experts} experts from {key}")
+
+        # Handle shared expert weights
+        elif "moe.shared_expert.w" in key and value.dim() == 3:
+            # This is a shared expert tensor [1, in_features, out_features]
+            # Don't ungroup - just keep as is for mapping
+            new_state_dict[key] = value
+            print(f"Keeping shared expert {key} with shape {value.shape} for mapping")
+
+        else:
+            # Keep other weights as-is
+            new_state_dict[key] = value
+
+    return new_state_dict
+
+
+class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
+    """
+    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
+    Useful for loading in state_dict without first initializing a model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
+        assert not state_dict
+
+        # rebuild the state dict from the metadata
+        for k, v in metadata.state_dict_metadata.items():
+            if isinstance(v, TensorStorageMetadata):
+                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
+            if k in metadata.planner_data:
+                set_element(state_dict, metadata.planner_data[k], v)
+            else:
+                state_dict[k] = v
+
+        super().set_up_planner(state_dict, metadata, is_coordinator)
+
+
+@torch.inference_mode()
+def convert_deepseek_from_dcp(
+    checkpoint_dir: Path,
+    output_dir: Path,
+    original_model_dir: Optional[Path] = None
+):
+    """
+    Convert DeepSeek weights from DCP format back to HuggingFace format.
+ + Args: + checkpoint_dir: Directory containing DCP checkpoint + output_dir: Directory to save HuggingFace format weights + original_model_dir: Optional directory with original model (for config files) + """ + + print(f"Loading checkpoint from {checkpoint_dir}") + + # Load the DCP checkpoint + state_dict: STATE_DICT_TYPE = {} + storage_reader = FileSystemReader(str(checkpoint_dir)) + + _load_state_dict( + state_dict, + storage_reader=storage_reader, + planner=_EmptyStateDictLoadPlanner(), + no_dist=True, + ) + + # Check if this is a full checkpoint or model-only + if 'model' in state_dict: + print(f"Full checkpoint detected, extracting model weights only. All keys: {list(state_dict.keys())}") + state_dict = state_dict['model'] + + # Remove '_orig_mod' suffix if present (from torch.compile) + state_dict = {k.replace('._orig_mod', ''): v for k, v in state_dict.items()} + + # Convert to bfloat16 to match the expected output format + print("Converting weights to bfloat16...") + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + state_dict[k] = v.to(torch.bfloat16) + + # Load config + config = None + job_config = None + + # 1. Look for job config.json in the job root directory + if "checkpoints" in str(checkpoint_dir): + current_path = checkpoint_dir + while current_path.name != "checkpoints" and current_path.parent != current_path: + current_path = current_path.parent + + if current_path.name == "checkpoints": + job_root = current_path.parent + config_path = job_root / "config.json" + + if config_path.exists(): + with open(config_path, "r") as f: + job_config = json.load(f) + print(f"Found job config at {config_path}") + + # 2. If we found a job config, load the architecture config from our definitions + if job_config and "model_name" in job_config and "flavor" in job_config: + model_name = job_config["model_name"] + flavor = job_config["flavor"] + + if model_name == "deepseek" and flavor in deepseek_configs: + model_args = deepseek_configs[flavor] + config = { + "vocab_size": model_args.vocab_size, + "hidden_size": model_args.dim, + "intermediate_size": model_args.inter_dim, + "moe_intermediate_size": model_args.moe_inter_dim, + "num_hidden_layers": model_args.n_layers, + "num_attention_heads": model_args.n_heads, + "num_key_value_heads": model_args.n_heads, # DeepSeek uses MLA, not GQA + "hidden_act": "silu", + "max_position_embeddings": model_args.max_seq_len, + "rms_norm_eps": model_args.norm_eps, + "tie_word_embeddings": model_args.tied_embeddings, + "rope_theta": model_args.rope_theta, + "attention_bias": False, + "attention_dropout": 0.0, + "n_routed_experts": model_args.n_routed_experts, + "n_shared_experts": model_args.n_shared_experts, + "num_experts_per_tok": model_args.n_activated_experts, + "first_k_dense_replace": model_args.n_dense_layers, + "norm_topk_prob": model_args.score_func == "softmax", + "scoring_func": model_args.score_func, + "aux_loss_alpha": model_args.load_balance_coeff, + "seq_aux": True, + "model_type": "deepseek_v2", + "q_lora_rank": model_args.q_lora_rank, + "kv_lora_rank": model_args.kv_lora_rank, + "qk_rope_head_dim": model_args.qk_rope_head_dim, + "qk_nope_head_dim": model_args.qk_nope_head_dim, + "v_head_dim": model_args.v_head_dim, + } + print(f"Loaded architecture config for {model_name} {flavor}") + + # 3. 
Finally, try original model directory (HF format) + if not config and original_model_dir and (original_model_dir / "config.json").exists(): + with open(original_model_dir / "config.json", "r") as f: + config = json.load(f) + print("Loaded config from original model (HF format)") + + # Convert to HF format + hf_state_dict = convert_to_hf_deepseek(state_dict, config) + + # Save in HuggingFace format + save_hf_checkpoint(hf_state_dict, output_dir, original_model_dir, config) + + print(f"Successfully converted to {output_dir}") + + +def convert_to_hf_deepseek(state_dict: Dict[str, torch.Tensor], config: Optional[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: + """Convert DCP DeepSeek model to HuggingFace format.""" + # First, ungroup expert weights + ungrouped_state_dict = ungroup_expert_weights(state_dict) + + # Keys to skip (runtime buffers, not model parameters) + skip_keys = ["freqs_cis", "expert_bias", "tokens_per_expert"] + + hf_state_dict = {} + for key, value in ungrouped_state_dict.items(): + # Skip known runtime buffers + if any(skip_key in key for skip_key in skip_keys): + continue + + # Map our keys to HF keys + new_key = map_to_hf_deepseek_key(key, value) + if new_key: + if new_key.startswith("__shared__"): + # Handle shared expert - squeeze and transpose + actual_key = new_key[10:] # Remove "__shared__" prefix + squeezed_transposed = value.squeeze(0).t().contiguous() + hf_state_dict[actual_key] = squeezed_transposed + print(f"Converted shared expert {key} from {value.shape} to {squeezed_transposed.shape}") + else: + hf_state_dict[new_key] = value + else: + print(f"Warning: Unmapped key {key}") + + # Add lm_head if using tied embeddings + if config and config.get("tie_word_embeddings", False) and "model.embed_tokens.weight" in hf_state_dict: + hf_state_dict["lm_head.weight"] = hf_state_dict["model.embed_tokens.weight"] + + # Fix embedding size if needed (for checkpoints trained before vocab_size fix) + if "model.embed_tokens.weight" in hf_state_dict: + embed_weight = hf_state_dict["model.embed_tokens.weight"] + embed_size = embed_weight.shape[0] + print(f"Embedding vocab size: {embed_size}") + + if config and "vocab_size" in config: + config_vocab_size = config["vocab_size"] + print(f"Config vocab size: {config_vocab_size}") + + if embed_size < config_vocab_size: + print(f"Padding embeddings from {embed_size} to {config_vocab_size}") + # Pad with zeros for the missing tokens + padding_size = config_vocab_size - embed_size + padding = torch.zeros(padding_size, embed_weight.shape[1], dtype=embed_weight.dtype) + hf_state_dict["model.embed_tokens.weight"] = torch.cat([embed_weight, padding], dim=0) + + # Also pad lm_head if it exists and has the same issue + if "lm_head.weight" in hf_state_dict: + lm_head_weight = hf_state_dict["lm_head.weight"] + if lm_head_weight.shape[0] == embed_size: + lm_head_padding = torch.zeros(padding_size, lm_head_weight.shape[1], dtype=lm_head_weight.dtype) + hf_state_dict["lm_head.weight"] = torch.cat([lm_head_weight, lm_head_padding], dim=0) + print(f"Also padded lm_head from {embed_size} to {config_vocab_size}") + + return hf_state_dict + + +def map_to_hf_deepseek_key(key: str, tensor: torch.Tensor) -> Optional[str]: + """Map our DeepSeek model keys to HuggingFace format.""" + + # Embeddings + if key == "tok_embeddings.weight": + return "model.embed_tokens.weight" + + # Final norm + if key == "norm.weight": + return "model.norm.weight" + + # LM head + if key == "output.weight": + return "lm_head.weight" + + # Layer components + if 
key.startswith("layers."): + parts = key.split(".") + layer_idx = parts[1] + + # Check normalization layers first (they're at the layer level) + if key == f"layers.{layer_idx}.attention_norm.weight": + return f"model.layers.{layer_idx}.input_layernorm.weight" + elif key == f"layers.{layer_idx}.ffn_norm.weight": + return f"model.layers.{layer_idx}.post_attention_layernorm.weight" + + # MLA (Multi-head Latent Attention) components + elif "attention." in key: + if key.endswith("wq.weight"): + return f"model.layers.{layer_idx}.self_attn.q_proj.weight" + elif key.endswith("wkv_a.weight"): + return f"model.layers.{layer_idx}.self_attn.kv_a_proj_with_mqa.weight" + elif key.endswith("attention.kv_norm.weight"): + return f"model.layers.{layer_idx}.self_attn.kv_a_layernorm.weight" + elif key.endswith("wkv_b.weight"): + return f"model.layers.{layer_idx}.self_attn.kv_b_proj.weight" + elif key.endswith("wo.weight"): + return f"model.layers.{layer_idx}.self_attn.o_proj.weight" + + # Dense FFN (for first layer) + elif "feed_forward" in key: + if key.endswith("w1.weight"): + return f"model.layers.{layer_idx}.mlp.gate_proj.weight" + elif key.endswith("w3.weight"): + return f"model.layers.{layer_idx}.mlp.up_proj.weight" + elif key.endswith("w2.weight"): + return f"model.layers.{layer_idx}.mlp.down_proj.weight" + + # MoE components + elif "moe" in key: + if key.endswith("router.gate.weight"): + return f"model.layers.{layer_idx}.mlp.gate.weight" + + # Shared expert (grouped tensor that needs special handling) + elif "shared_expert.w" in key and not key.endswith(".weight"): + # Extract weight type and handle the grouped tensor + if key.endswith("w1"): + # Need to squeeze and transpose the tensor + # This will be done in the main conversion function + return f"__shared__model.layers.{layer_idx}.mlp.shared_experts.gate_proj.weight" + elif key.endswith("w3"): + return f"__shared__model.layers.{layer_idx}.mlp.shared_experts.up_proj.weight" + elif key.endswith("w2"): + return f"__shared__model.layers.{layer_idx}.mlp.shared_experts.down_proj.weight" + + # Routed experts + elif "experts" in key and not "shared" in key: + # Extract expert index + expert_match = key.split(".") + expert_idx = expert_match[4] # layers.X.moe.experts.Y + + if key.endswith("w1.weight"): + return f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight" + elif key.endswith("w3.weight"): + return f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight" + elif key.endswith("w2.weight"): + return f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight" + + return None + + +def save_hf_checkpoint( + state_dict: Dict[str, torch.Tensor], + output_dir: Path, + original_model_dir: Optional[Path], + config: Optional[Dict[str, Any]] = None +): + """Save checkpoint in HuggingFace format.""" + output_dir.mkdir(parents=True, exist_ok=True) + + # Determine if we need to shard + total_size = sum(v.numel() * v.element_size() for v in state_dict.values()) + shard_size = 5 * 1024 * 1024 * 1024 # 5GB per shard + + if total_size > shard_size: + # Shard the checkpoint + shards = [] + current_shard = {} + current_size = 0 + + for key, tensor in state_dict.items(): + tensor_size = tensor.numel() * tensor.element_size() + if current_size + tensor_size > shard_size and current_shard: + shards.append(current_shard) + current_shard = {} + current_size = 0 + + current_shard[key] = tensor + current_size += tensor_size + + if current_shard: + shards.append(current_shard) + + # Save shards and create index + weight_map = {} + for i, 
shard in enumerate(shards): + shard_name = f"model-{i+1:05d}-of-{len(shards):05d}.safetensors" + save_file(shard, output_dir / shard_name) + for key in shard: + weight_map[key] = shard_name + + # Create index file + index = { + "metadata": {"total_size": total_size}, + "weight_map": weight_map + } + with open(output_dir / "model.safetensors.index.json", "w") as f: + json.dump(index, f, indent=2) + else: + # Save as single file + save_file(state_dict, output_dir / "model.safetensors") + + # Save config if we have one + if config: + with open(output_dir / "config.json", "w") as f: + json.dump(config, f, indent=2) + + # Copy other files from original model if available + if original_model_dir: + import shutil + + # Config files + config_files = ["config.json", "generation_config.json", "tokenizer_config.json"] + + # Tokenizer files + tokenizer_files = [ + "tokenizer.json", + "tokenizer.model", + "special_tokens_map.json", + ] + + for file_name in config_files + tokenizer_files: + src = original_model_dir / file_name + if src.exists(): + shutil.copy2(src, output_dir / file_name) + print(f"Copied {file_name}") + + print(f"Saved checkpoint to {output_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert DeepSeek weights from DCP to HuggingFace format.") + parser.add_argument( + "checkpoint_dir", + type=Path, + help="Input directory with DCP checkpoint" + ) + parser.add_argument( + "output_dir", + type=Path, + help="Output directory for HuggingFace format" + ) + parser.add_argument( + "--original-model-dir", + type=Path, + help="Original model directory (for config and tokenizer files)" + ) + + args = parser.parse_args() + + convert_deepseek_from_dcp( + args.checkpoint_dir, + args.output_dir, + args.original_model_dir + ) \ No newline at end of file diff --git a/scripts/convert/deepseek/to_dcp.py b/scripts/convert/deepseek/to_dcp.py new file mode 100644 index 0000000..9d7e90b --- /dev/null +++ b/scripts/convert/deepseek/to_dcp.py @@ -0,0 +1,260 @@ +""" +Convert DeepSeek models from HuggingFace format to DCP format for training. + +This script handles DeepSeek V2 Lite models with MoE (Mixture of Experts) architecture. 
+
+Usage:
+    python scripts/convert/deepseek/to_dcp.py /path/to/deepseek-v2-lite output_dir
+
+The script:
+- Converts HuggingFace DeepSeek weights to DCP format
+- Handles MLA (Multi-head Latent Attention) components
+- Converts MoE layers with shared and routed experts
+- Maps layer normalization weights correctly
+"""
+
+import argparse
+import json
+from pathlib import Path
+from typing import Dict, Any
+
+import torch
+import torch.distributed.checkpoint as DCP
+from safetensors import safe_open
+
+
+def group_expert_weights(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """Group individual expert weights into grouped tensors."""
+    new_state_dict = {}
+    expert_groups = {}
+
+    # First pass: collect expert weights to group
+    for key, value in state_dict.items():
+        if key.startswith("__group__"):
+            # Extract grouping info for routed experts
+            actual_key = key[9:]  # Remove "__group__" prefix
+            parts = actual_key.split(".")
+            # Format: layers.X.moe.experts.Y.wN
+            layer_idx = parts[1]
+            expert_idx = int(parts[4])
+            weight_type = parts[5]  # w1, w2, or w3
+
+            group_key = f"layers.{layer_idx}.moe.experts.{weight_type}"
+            if group_key not in expert_groups:
+                expert_groups[group_key] = {}
+            expert_groups[group_key][expert_idx] = value
+        elif key.startswith("__shared__"):
+            # Handle shared expert - needs to be transposed and unsqueezed
+            actual_key = key[10:]  # Remove "__shared__" prefix
+            # Transpose from [out_features, in_features] to [in_features, out_features]
+            # then add unsqueezed dimension to make it [1, in_features, out_features]
+            transposed = value.t()
+            new_state_dict[actual_key] = transposed.unsqueeze(0)
+            print(f"Converted shared expert {actual_key} from {value.shape} to {transposed.unsqueeze(0).shape}")
+        else:
+            # Keep non-expert weights as-is
+            new_state_dict[key] = value
+
+    # Second pass: create grouped tensors
+    for group_key, experts in expert_groups.items():
+        # Sort by expert index
+        sorted_experts = sorted(experts.items())
+        max_expert_idx = max(experts.keys())
+
+        # Stack expert weights into a single tensor
+        # Create list with None placeholders for missing experts
+        expert_list = []
+        for i in range(max_expert_idx + 1):
+            if i in experts:
+                # Transpose from [out_features, in_features] to [in_features, out_features]
+                expert_list.append(experts[i].t())
+            else:
+                # This shouldn't happen with DeepSeek, but handle gracefully
+                print(f"Warning: Missing expert {i} in {group_key}")
+                # Use zeros with same shape as other experts
+                shape = next(iter(experts.values())).shape
+                expert_list.append(torch.zeros(shape[1], shape[0]))  # Transposed shape
+
+        # Stack along dim 0 to create [num_experts, in_features, out_features]
+        grouped_tensor = torch.stack(expert_list, dim=0)
+        new_state_dict[group_key] = grouped_tensor
+        print(f"Grouped {len(expert_list)} experts into {group_key} with shape {grouped_tensor.shape}")
+
+    return new_state_dict
+
+
+@torch.inference_mode()
+def convert_deepseek_weights(input_dir: Path, output_dir: Path):
+    """
+    Convert DeepSeek weights to DCP format.
+ + Args: + input_dir: Directory containing DeepSeek weights (safetensors format) + output_dir: Directory to save DCP checkpoint + """ + + # Find safetensors files + safetensors_files = sorted(list(input_dir.glob("*.safetensors"))) + if not safetensors_files: + # Check for index file + index_file = input_dir / "model.safetensors.index.json" + if index_file.exists(): + with open(index_file, "r") as f: + index = json.load(f) + safetensors_files = sorted(list(set( + input_dir / fname for fname in index["weight_map"].values() + ))) + + if not safetensors_files: + raise ValueError(f"No safetensors files found in {input_dir}") + + print(f"Found {len(safetensors_files)} safetensors files") + + # Load config if available + config_path = input_dir / "config.json" + if config_path.exists(): + with open(config_path, "r") as f: + config = json.load(f) + print(f"Loaded config: {config.get('model_type', 'unknown')} model") + else: + config = {} + + # Convert weights + state_dict = convert_deepseek_model(safetensors_files) + + # Group expert weights + state_dict = group_expert_weights(state_dict) + + # Save to DCP format + print("Writing to DCP format...") + output_dir.mkdir(parents=True, exist_ok=True) + storage_writer = DCP.filesystem.FileSystemWriter(str(output_dir)) + DCP.save({"model": state_dict}, storage_writer=storage_writer) + + # Copy config from input directory if available + if config_path.exists(): + import shutil + shutil.copy2(config_path, output_dir / "config.json") + print("Copied config.json from input directory") + + print(f"Successfully converted to {output_dir}") + + +def convert_deepseek_model(safetensors_files: list) -> Dict[str, torch.Tensor]: + """Convert DeepSeek model weights.""" + state_dict = {} + + for file_path in safetensors_files: + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensor = f.get_tensor(key) + + # Map HF names to our names + new_key = map_deepseek_key(key) + if new_key: + state_dict[new_key] = tensor.clone() + else: + print(f"Warning: Unmapped key {key}") + + return state_dict + + +def map_deepseek_key(key: str) -> str: + """Map HuggingFace DeepSeek keys to our format.""" + + # Embeddings + if key == "model.embed_tokens.weight": + return "tok_embeddings.weight" + + # Final norm + if key == "model.norm.weight": + return "norm.weight" + + # LM head + if key == "lm_head.weight": + # DeepSeek doesn't use tied embeddings by default + return "output.weight" + + # Layer components + if key.startswith("model.layers."): + parts = key.split(".") + layer_idx = parts[2] + + # MLA (Multi-head Latent Attention) components + if "self_attn" in key: + if key.endswith("q_proj.weight"): + return f"layers.{layer_idx}.attention.wq.weight" + elif key.endswith("kv_a_proj_with_mqa.weight"): + # This combines kv_a projection with rope dimensions + return f"layers.{layer_idx}.attention.wkv_a.weight" + elif key.endswith("kv_a_layernorm.weight"): + return f"layers.{layer_idx}.attention.kv_norm.weight" + elif key.endswith("kv_b_proj.weight"): + return f"layers.{layer_idx}.attention.wkv_b.weight" + elif key.endswith("o_proj.weight"): + return f"layers.{layer_idx}.attention.wo.weight" + + # MLP/MoE components + elif "mlp" in key: + # Dense layers (layer 0 in V2-Lite) + if "experts" not in key and "shared_experts" not in key and "mlp.gate." 
not in key: + if key.endswith("gate_proj.weight"): + return f"layers.{layer_idx}.feed_forward.w1.weight" + elif key.endswith("up_proj.weight"): + return f"layers.{layer_idx}.feed_forward.w3.weight" + elif key.endswith("down_proj.weight"): + return f"layers.{layer_idx}.feed_forward.w2.weight" + + # MoE layers + elif key.endswith("mlp.gate.weight"): + return f"layers.{layer_idx}.moe.router.gate.weight" + + # Shared experts - need to add dimension for GroupedExperts format + elif "shared_experts" in key: + if key.endswith("gate_proj.weight"): + return f"__shared__layers.{layer_idx}.moe.shared_expert.w1" + elif key.endswith("up_proj.weight"): + return f"__shared__layers.{layer_idx}.moe.shared_expert.w3" + elif key.endswith("down_proj.weight"): + return f"__shared__layers.{layer_idx}.moe.shared_expert.w2" + + # Routed experts - need to be collected into grouped tensors + elif "experts" in key: + # We'll handle this separately - mark for grouping + expert_idx = parts[5] # model.layers.X.mlp.experts.Y + if key.endswith("gate_proj.weight"): + return f"__group__layers.{layer_idx}.moe.experts.{expert_idx}.w1" + elif key.endswith("up_proj.weight"): + return f"__group__layers.{layer_idx}.moe.experts.{expert_idx}.w3" + elif key.endswith("down_proj.weight"): + return f"__group__layers.{layer_idx}.moe.experts.{expert_idx}.w2" + + # Normalization layers + elif key.endswith("input_layernorm.weight"): + return f"layers.{layer_idx}.attention_norm.weight" + elif key.endswith("post_attention_layernorm.weight"): + return f"layers.{layer_idx}.ffn_norm.weight" + + return None + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert DeepSeek weights to DCP format.") + parser.add_argument( + "input_dir", + type=Path, + help="Input directory with DeepSeek weights (safetensors format)" + ) + parser.add_argument( + "output_dir", + type=Path, + help="Output directory for DCP checkpoint" + ) + + args = parser.parse_args() + + # Convert weights + convert_deepseek_weights( + args.input_dir, + args.output_dir + ) \ No newline at end of file diff --git a/scripts/convert_gemma_from_dcp.py b/scripts/convert/gemma/from_dcp.py similarity index 100% rename from scripts/convert_gemma_from_dcp.py rename to scripts/convert/gemma/from_dcp.py diff --git a/scripts/convert_gemma_to_dcp.py b/scripts/convert/gemma/to_dcp.py similarity index 100% rename from scripts/convert_gemma_to_dcp.py rename to scripts/convert/gemma/to_dcp.py diff --git a/scripts/convert_dcp_to_hf.py b/scripts/convert/llama/from_dcp.py similarity index 100% rename from scripts/convert_dcp_to_hf.py rename to scripts/convert/llama/from_dcp.py diff --git a/scripts/convert_mup_to_hf.py b/scripts/convert/mup_llama/from_dcp.py similarity index 100% rename from scripts/convert_mup_to_hf.py rename to scripts/convert/mup_llama/from_dcp.py diff --git a/train.py b/train.py index 2275ead..000a1a0 100644 --- a/train.py +++ b/train.py @@ -34,6 +34,7 @@ model_name_to_cls, models_config, model_name_to_parallelize, + model_name_to_optimizers_builder, ) from maester.parallelisms import ParallelDims from maester.profiling import (maybe_enable_memory_snapshot, @@ -79,6 +80,7 @@ def main(): else: logger.info("Using configuration from config.py") cfg = Config() + logger.info(f"Configuration: {cfg}") # take control of garbage collection to avoid stragglers gc.disable() @@ -90,6 +92,7 @@ def main(): dp_shard=cfg.data_parallel_shard_degree, dp_replicate=cfg.data_parallel_replicate_degree, tp=cfg.tensor_parallel_degree, + ep=cfg.expert_parallel_degree, 
world_size=world_size, enable_loss_parallel=cfg.enable_loss_parallel, ) @@ -102,7 +105,7 @@ def main(): ) as memory_profiler: # build meshes - world_mesh = parallel_dims.build_mesh(device_type="cuda") + world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] dp_degree = dp_mesh.size() @@ -110,7 +113,8 @@ def main(): else: dp_degree, dp_rank = 1, 0 logger.info(f"world mesh: {world_mesh}") - logger.info(f"dp mesh: {dp_mesh}") + if parallel_dims.dp_enabled: + logger.info(f"dp mesh: {dp_mesh}") # Get tokenizer to determine vocab size if os.path.isfile(cfg.tokenizer_name): @@ -127,9 +131,9 @@ def main(): # 3. max_seq_len base on inputs model_config.norm_type = cfg.norm_type # Get vocab size from tokenizer (vocab_size is base vocabulary without added tokens) - if hasattr(model_config, 'vocab_size') and model_config.vocab_size > 0: - model_config.vocab_size = tokenizer.vocab_size - else: # rely on tokenizer to provide vocab size + # If model config already has a vocab_size set, respect it (e.g. for padded vocabularies) + if not hasattr(model_config, 'vocab_size') or model_config.vocab_size <= 0: + # Only set vocab size from tokenizer if not already configured model_config.vocab_size = len(tokenizer) model_config.max_seq_len = cfg.seq_len if cfg.enable_mup: @@ -150,13 +154,15 @@ def main(): model = model_cls.from_model_args(model_config) # log model size - model_param_count = get_num_params(model) - model_param_count_without_embedding = get_num_params(model, exclude_embedding=True) - num_flop_per_token = get_num_flop_per_token( - model_param_count if model.model_args.tied_embeddings else model_param_count_without_embedding, # count lm head matmul only - model_config, - cfg.seq_len, - ) + # model_param_count = get_num_params(model) + # model_param_count_without_embedding = get_num_params(model, exclude_embedding=True) + # num_flop_per_token = get_num_flop_per_token( + # model_param_count if model.model_args.tied_embeddings else model_param_count_without_embedding, # count lm head matmul only + # model_config, + # cfg.seq_len, + # ) + model_param_count, num_flop_per_token = model_config.get_nparams_and_flops(model, cfg.seq_len) + model_param_count_without_embedding = 0 logger.info( f"Model {cfg.model_name} {cfg.flavor} " f"size: {model_param_count:,} total parameters ({model_param_count_without_embedding:,} without embeddings)" @@ -225,29 +231,11 @@ def fw_hook(mod: torch.nn.Module, inp, out, key: str): # data_monitor = DataMonitor(train_state, log_freq=cfg.log_freq) - if cfg.enable_mup: - mup_decay_params = [] - decay_params = [] - nodecay_params = [] - for name, param in model.named_parameters(): - if param.dim() >= 2: - if 'attention' in name or 'feed_forward' in name: - # logger.info(f"Mup weight: {name}") - mup_decay_params.append(param) - else: - # logger.info(f"Decay weight: {name}") - decay_params.append(param) - else: - # logger.info(f"Nodecay weight: {name}") - nodecay_params.append(param) - optimizer: torch.optim.Optimizer = cfg.opt_class([ - {'params': mup_decay_params, 'weight_decay': cfg.opt_cfg['weight_decay'], 'lr': cfg.opt_cfg['lr'] / model_config.mup_width_mul}, - {'params': decay_params, 'weight_decay': cfg.opt_cfg['weight_decay'], 'lr': cfg.opt_cfg['lr']}, - {'params': nodecay_params, 'weight_decay': 0.0, 'lr': cfg.opt_cfg['lr']}, - ], **cfg.opt_cfg) - else: - optimizer: torch.optim.Optimizer = cfg.opt_class(model.parameters(), **cfg.opt_cfg) - scheduler = get_lr_scheduler(optimizer, cfg) + # Build optimizers using model-specific builder + 
optimizers_builder = model_name_to_optimizers_builder[cfg.model_name] + optimizers = optimizers_builder(model, cfg, parallel_dims) + + scheduler = get_lr_scheduler(optimizers, cfg) metric_logger = build_metric_logger(cfg) @@ -265,15 +253,15 @@ def loss_fn(pred, labels): # training loop cleanup_before_training() model.train() - if hasattr(optimizer, 'train'): # some optimizers need to be put in train mode (e.g. schedule free) - optimizer.train() # type: ignore (.train obviously exists) + if hasattr(optimizers, 'train'): # some optimizers need to be put in train mode (e.g. schedule free) + optimizers.train() # type: ignore (.train obviously exists) weight_scale_monitor = WeightScaleMonitor(model, log_freq=cfg.log_freq) # checkpointing checkpoint = CheckpointManager( model=model, - optimizer=optimizer, + optimizer=optimizers, lr_scheduler=scheduler, dataloader=data_loader, states={"train_state": train_state}, @@ -316,7 +304,7 @@ def loss_fn(pred, labels): input_ids = input_ids.cuda() labels = labels.cuda() - optimizer.zero_grad() + optimizers.zero_grad() # data_monitor.log_batch_samples(input_ids, labels, data_loader.dataset) # data_monitor.log_dataset_stats(data_loader.dataset) @@ -336,10 +324,12 @@ def loss_fn(pred, labels): del pred loss.backward() - grad_norms = clip_grad_norm( # note: maester.utils.clip_grad_norm, not torch.nn.utils.clip_grad_norm_ - model.parameters(), cfg.max_grad_norm, foreach=True - ) - optimizer.step() + + # TODO: re-enable grad clipping (broken w/ MoE) and/or monitoring? + # grad_norms = clip_grad_norm( # note: maester.utils.clip_grad_norm, not torch.nn.utils.clip_grad_norm_ + # model.parameters(), cfg.max_grad_norm, foreach=True + # ) + optimizers.step() scheduler.step() weight_scale_stats = weight_scale_monitor.step_monitor() @@ -361,19 +351,20 @@ def loss_fn(pred, labels): else: global_avg_loss, global_max_loss = avg_loss, max_loss - param_to_name = {param: name for name, param in model.named_parameters()} - exp_avgs, exp_avg_sqs, param_names = [], [], [] - for group in optimizer.param_groups: - for p in group['params']: - if p.grad is None: - continue - state = optimizer.state[p] - if 'exp_avg' in state: # Check if states initialized - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - param_names.append(param_to_name[p]) - exp_avg_norms = torch._foreach_norm(exp_avgs, 2) - exp_avg_sq_norms = torch._foreach_norm(exp_avg_sqs, 2) + # TODO: re-enable grad norm logging? 
+ # param_to_name = {param: name for name, param in model.named_parameters()} + # exp_avgs, exp_avg_sqs, param_names = [], [], [] + # for group in optimizer.param_groups: + # for p in group['params']: + # if p.grad is None: + # continue + # state = optimizer.state[p] + # if 'exp_avg' in state: # Check if states initialized + # exp_avgs.append(state['exp_avg']) + # exp_avg_sqs.append(state['exp_avg_sq']) + # param_names.append(param_to_name[p]) + # exp_avg_norms = torch._foreach_norm(exp_avgs, 2) + # exp_avg_sq_norms = torch._foreach_norm(exp_avg_sqs, 2) time_delta = timer() - time_last_log @@ -405,23 +396,29 @@ def loss_fn(pred, labels): "memory/num_alloc_retries": gpu_mem_stats.num_alloc_retries, "memory/num_ooms": gpu_mem_stats.num_ooms, } - for i in range(len(optimizer.param_groups)): + for i in range(len(optimizers.param_groups)): metrics[f"lr/group{i}"] = scheduler.get_last_lr()[i] - for gn, (name, _) in zip(grad_norms, model.named_parameters()): - cn = clean_param_name(name) - metrics[f"{cn}/grad_norm"] = gn - for exp_avg_norm, exp_avg_sq_norm, name in zip(exp_avg_norms, exp_avg_sq_norms, param_names): - cn = clean_param_name(name) - metrics[f"{cn}/exp_avg_norm"] = exp_avg_norm - metrics[f"{cn}/exp_avg_sq_norm"] = exp_avg_sq_norm + # for gn, (name, _) in zip(grad_norms, model.named_parameters()): + # cn = clean_param_name(name) + # metrics[f"{cn}/grad_norm"] = gn + # for exp_avg_norm, exp_avg_sq_norm, name in zip(exp_avg_norms, exp_avg_sq_norms, param_names): + # cn = clean_param_name(name) + # metrics[f"{cn}/exp_avg_norm"] = exp_avg_norm + # metrics[f"{cn}/exp_avg_sq_norm"] = exp_avg_sq_norm if cfg.enable_mup and cfg.mup_log_coord_check: for key in activation_stats: # type: ignore if activation_stats[key]: # type: ignore metrics[f'act/{key}_abs_mean'] = np.mean(activation_stats[key]) # type: ignore activation_stats = defaultdict(list) # reset # metrics.update(get_logits_metrics()) - if weight_scale_stats: + if weight_scale_stats is not None: metrics.update(weight_scale_stats) + + # Collect optimizer hook statistics if available + # TODO: This is a temporary solution - consider refactoring for cleaner API + if hasattr(optimizers, '_hook_stats') and optimizers._hook_stats: + metrics.update(optimizers._hook_stats) + if metric_logger is not None: metric_logger.log(metrics, step=train_state.step) From 0c81588ae8393f17ec54df9c0236b6b9cc74a2b0 Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Mon, 28 Jul 2025 15:37:51 +0200 Subject: [PATCH 2/5] add missing new file optimizers.py --- maester/optimizers.py | 171 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 maester/optimizers.py diff --git a/maester/optimizers.py b/maester/optimizers.py new file mode 100644 index 0000000..594354f --- /dev/null +++ b/maester/optimizers.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List +import torch +import torch.nn as nn +from torch.optim import Optimizer +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.state_dict import ( + get_optimizer_state_dict, + set_optimizer_state_dict, +) + +from maester.config import Config +from maester.parallelisms import ParallelDims + + +class OptimizersContainer(Optimizer, Stateful): + """Container for potentially multiple optimizers of different types. 
+ + This container allows using multiple optimizers (potentially of different types) + in a single training loop. By inheriting from Optimizer, we get hook functionality + for free. + + Note: Unlike TorchTitan, we support mixed optimizer types but don't support + checkpoint resharding. Checkpoints must be loaded with the same configuration. + """ + + def __init__(self, model: nn.Module, optimizers: List[torch.optim.Optimizer]): + self.model = model + self.optimizers = optimizers + + # Collect all parameters from all optimizers for parent init + all_params = [] + for opt in self.optimizers: + for group in opt.param_groups: + all_params.extend(group['params']) + + # Call parent Optimizer.__init__ to enable hook functionality + # We pass empty defaults since each optimizer has its own config + super().__init__(all_params, defaults={}) + + # HACK: Override param_groups to aggregate from all optimizers + # This is needed for LR scheduler compatibility + # TODO: Consider redesigning how LR schedulers interact with multi-optimizer setups + self.param_groups = [] + for opt in self.optimizers: + self.param_groups.extend(opt.param_groups) + + def zero_grad(self, *args, **kwargs) -> None: + """Zero gradients for all optimizers.""" + for optimizer in self.optimizers: + optimizer.zero_grad(*args, **kwargs) + + def step(self, *args, **kwargs) -> None: + """Step all optimizers.""" + # Sync learning rates from our param_groups to underlying optimizers + # This ensures LR scheduler updates are propagated + group_idx = 0 + for opt in self.optimizers: + for opt_group in opt.param_groups: + opt_group['lr'] = self.param_groups[group_idx]['lr'] + group_idx += 1 + + for optimizer in self.optimizers: + optimizer.step(*args, **kwargs) + + def state_dict(self) -> Dict[str, Any]: + """Return state dict for all optimizers. + + For mixed optimizer types, we use a simple numbered scheme. + This doesn't support resharding but allows different optimizer types. + """ + state_dict = {} + for i, opt in enumerate(self.optimizers): + # For single optimizer case, maintain backward compatibility + if len(self.optimizers) == 1: + return get_optimizer_state_dict(self.model, opt) + else: + opt_state = get_optimizer_state_dict(self.model, opt) + # Prefix keys to avoid collisions + for k, v in opt_state.items(): + state_dict[f"opt{i}_{k}"] = v + return state_dict + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Load state dict for all optimizers.""" + if len(self.optimizers) == 1: + # Single optimizer - direct load for backward compatibility + set_optimizer_state_dict(self.model, self.optimizers[0], state_dict) + else: + # Multiple optimizers - split by prefix + for i, opt in enumerate(self.optimizers): + opt_state = {} + prefix = f"opt{i}_" + for k, v in state_dict.items(): + if k.startswith(prefix): + opt_state[k[len(prefix):]] = v + if opt_state: + set_optimizer_state_dict(self.model, opt, opt_state) + + # @property + # def param_groups(self) -> List[Dict[str, Any]]: + # """Return all param groups from all optimizers (for LR scheduler compatibility).""" + # all_param_groups = [] + # for opt in self.optimizers: + # all_param_groups.extend(opt.param_groups) + # return all_param_groups + + def __len__(self) -> int: + return len(self.optimizers) + + def __iter__(self): + return iter(self.optimizers) + + +def build_optimizers(model: nn.Module, cfg: Config, parallel_dims: ParallelDims) -> OptimizersContainer: + """Build default optimizers with optional MuP support. 
+ + Args: + model: The model to optimize + cfg: Training configuration + parallel_dims: Parallelism dimensions + + Returns: + OptimizersContainer with configured optimizer(s) + """ + if cfg.enable_mup: + # MuP parameter grouping + mup_decay_params = [] + decay_params = [] + nodecay_params = [] + + for name, param in model.named_parameters(): + if param.dim() >= 2: + if 'attention' in name or 'feed_forward' in name: + mup_decay_params.append(param) + else: + decay_params.append(param) + else: + nodecay_params.append(param) + + # Get MuP width multiplier from model config + mup_width_mul = getattr(model.model_args, 'mup_width_mul', 1.0) + + optimizer = cfg.opt_class([ + { + 'params': mup_decay_params, + 'weight_decay': cfg.opt_cfg['weight_decay'], + 'lr': cfg.opt_cfg['lr'] / mup_width_mul + }, + { + 'params': decay_params, + 'weight_decay': cfg.opt_cfg['weight_decay'], + 'lr': cfg.opt_cfg['lr'] + }, + { + 'params': nodecay_params, + 'weight_decay': 0.0, + 'lr': cfg.opt_cfg['lr'] + }, + ], **cfg.opt_cfg) + else: + optimizer = cfg.opt_class(model.parameters(), **cfg.opt_cfg) + + return OptimizersContainer(model, [optimizer]) + + From 4bd507755ffd2d94644a6899b5cb4c26aa770747 Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Mon, 28 Jul 2025 15:41:45 +0200 Subject: [PATCH 3/5] llama conversion scripts --- scripts/convert/llama/from_dcp.py | 7 ++-- scripts/convert/llama/to_dcp.py | 63 +++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 scripts/convert/llama/to_dcp.py diff --git a/scripts/convert/llama/from_dcp.py b/scripts/convert/llama/from_dcp.py index ec7731c..ee572b7 100644 --- a/scripts/convert/llama/from_dcp.py +++ b/scripts/convert/llama/from_dcp.py @@ -60,7 +60,7 @@ def set_up_planner( parser = argparse.ArgumentParser() parser.add_argument("src", type=str, help="Path to the source DCP model") parser.add_argument("dst", type=str, help="Path to the destination model") - parser.add_argument("--base", type=str, default=None, help="Path to HF model this is based on, also uses tokenizer unless --tokenizer is specified") # TODO: can we not do this?? + parser.add_argument("--base", type=str, required=True, help="Path to HF model this is based on, also uses tokenizer unless --tokenizer is specified") # TODO: can we not do this?? parser.add_argument("--tokenizer", type=str, default=None, help="Path to HF tokenizer this is based on") # TODO: can we not do this?? parser.add_argument("--name", type=str, required=True, help="Name (variant) of the model checkpoint to load, e.g. 
step-1000") parser.add_argument("--type", type=str, default="hf", choices=["hf", "pt"], help="Type of the destination model") @@ -96,7 +96,7 @@ def set_up_planner( hf_config.torch_dtype = dtype hf_config.num_hidden_layers = max([int(re.search(r'layers.(\d+)', k).group(1)) for k in sd.keys() if 'layers' in k]) + 1 hf_config.hidden_size = sd['layers.0.attention.wq.weight'].shape[0] - hf_config.num_attention_heads = 32 # TODO: read all these from a config + hf_config.num_attention_heads = 24 # TODO: read all these from a config hf_config.num_key_value_heads = 8 hf_config.intermediate_size = sd['layers.0.feed_forward.w1.weight'].shape[0] hf_config.vocab_size = sd['tok_embeddings.weight'].shape[0] @@ -164,12 +164,13 @@ def permute(w, n_heads, dim1=hf_config.hidden_size, dim2=hf_config.hidden_size): final_result[new_key] = value + print(f"embedding weights: {final_result['model.embed_tokens.weight']}") # Save weights # torch.save(weights_state_dict, os.path.join(args.dst, 'pytorch_model.bin')) safetensors.torch.save_file(final_result, os.path.join(dst_dir, 'model.safetensors'), metadata={"format": "pt"}) print('#' * 30) - print(f'HF checkpoint folder successfully created at {args.dst}.') + print(f'HF checkpoint folder successfully created at {dst_dir}.') if args.upload: from huggingface_hub import HfApi diff --git a/scripts/convert/llama/to_dcp.py b/scripts/convert/llama/to_dcp.py new file mode 100644 index 0000000..279de7f --- /dev/null +++ b/scripts/convert/llama/to_dcp.py @@ -0,0 +1,63 @@ +import argparse +import json +from pathlib import Path + +import torch +import torch.distributed.checkpoint as DCP + +@torch.inference_mode() +def convert_llama_weights(input_dir, output_dir): + with open(args.input_dir / "params.json", "r") as f: + params = json.load(f) + n_layers = params["n_layers"] + n_heads = params["n_heads"] + dim = params["dim"] + dims_per_head = dim // n_heads + + checkpoint_list = sorted([file for file in input_dir.rglob("*.pth")]) + print(f"Loading from these files: {checkpoint_list}.") + shards = [torch.load(ckpt, map_location="cpu", weights_only=True) for ckpt in checkpoint_list] + + n_heads_per_shard = n_heads // len(shards) + num_key_value_heads = params["n_kv_heads"] + n_kv_heads_per_shard = num_key_value_heads // len(shards) + key_value_dim = dims_per_head * num_key_value_heads + + if len(shards) == 1: + state_dict = shards[0] + else: # sharded + state_dict = {} + for layer in range(n_layers): + state_dict[f"layers.{layer}.attention_norm.weight"] = shards[0][f"layers.{layer}.attention_norm.weight"].clone() # replicated + state_dict[f"layers.{layer}.ffn_norm.weight"] = shards[0][f"layers.{layer}.ffn_norm.weight"].clone() # replicated + + for wn, nh in [("wq", n_heads_per_shard), + ("wk", n_kv_heads_per_shard), + ("wv", n_kv_heads_per_shard)]: + state_dict[f"layers.{layer}.attention.{wn}.weight"] = torch.cat( + [shards[i][f"layers.{layer}.attention.{wn}.weight"].view(nh, dims_per_head, dim) for i in range(len(shards))], dim=0 + ).reshape(nh * len(shards) * dims_per_head, dim) + + state_dict[f"layers.{layer}.attention.wo.weight"] = torch.cat([shards[i][f"layers.{layer}.attention.wo.weight"] for i in range(len(shards))], dim=1) + state_dict[f"layers.{layer}.feed_forward.w1.weight"] = torch.cat([shards[i][f"layers.{layer}.feed_forward.w1.weight"] for i in range(len(shards))], dim=0) + state_dict[f"layers.{layer}.feed_forward.w2.weight"] = torch.cat([shards[i][f"layers.{layer}.feed_forward.w2.weight"] for i in range(len(shards))], dim=1) + 
state_dict[f"layers.{layer}.feed_forward.w3.weight"] = torch.cat([shards[i][f"layers.{layer}.feed_forward.w3.weight"] for i in range(len(shards))], dim=0) + + state_dict["norm.weight"] = shards[0]["norm.weight"] + state_dict["tok_embeddings.weight"] = torch.cat([shards[i]["tok_embeddings.weight"] for i in range(len(shards))], dim=0) + state_dict["output.weight"] = torch.cat([shards[i]["output.weight"] for i in range(len(shards))], dim=0) + + print("Writing to DCP...") + args.output_dir.mkdir(parents=True, exist_ok=True) + storage_writer = DCP.filesystem.FileSystemWriter(output_dir) + DCP.save({"model": state_dict}, storage_writer=storage_writer) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Llama weights to DCP format.") + parser.add_argument( + "input_dir", type=Path, help="Input directory with original Llama weights." + ) + parser.add_argument("output_dir", type=Path, help="Output directory for DCP.") + args = parser.parse_args() + + convert_llama_weights(args.input_dir, args.output_dir) From 3cfa0fb4d50a023ab6989f7e8a93d3f6d8bdfe85 Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Fri, 1 Aug 2025 11:22:43 +0200 Subject: [PATCH 4/5] register triton op for compatibility with sac --- maester/models/deepseek/moe_indices.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/maester/models/deepseek/moe_indices.py b/maester/models/deepseek/moe_indices.py index 30f7d98..d2b852d 100644 --- a/maester/models/deepseek/moe_indices.py +++ b/maester/models/deepseek/moe_indices.py @@ -65,7 +65,7 @@ def _fill_indices_kernel( # wrapper # ============== - +@torch.library.triton_op("moe_indices::fill_indices", mutates_args={}) def fill_indices_wrapper( tokens_per_expert_group: torch.Tensor, start_index_values: torch.Tensor, @@ -87,7 +87,7 @@ def fill_indices_wrapper( grid = (num_blocks,) # launch kernel - _fill_indices_kernel[grid]( + torch.library.wrap_triton(_fill_indices_kernel)[grid]( tokens_per_expert_group, start_index_values, write_offsets, From 62b0e75aa3d20f0bbc64fa4c12e84dc4e8866531 Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Fri, 1 Aug 2025 12:31:23 +0200 Subject: [PATCH 5/5] triton type annotation --- maester/models/deepseek/moe_indices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maester/models/deepseek/moe_indices.py b/maester/models/deepseek/moe_indices.py index d2b852d..58af6d1 100644 --- a/maester/models/deepseek/moe_indices.py +++ b/maester/models/deepseek/moe_indices.py @@ -75,7 +75,7 @@ def fill_indices_wrapper( max_len: int, block_size: int = 128, max_blocks: int = 1024, # cap on total number of blocks to launch -): +) -> torch.Tensor: # preallocate output permuted_indices = torch.full( (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device