18 changes: 17 additions & 1 deletion maester/config.py
@@ -1,6 +1,6 @@
from pydantic import BaseModel, ConfigDict, Field, ImportString
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Callable, Type, Any
from typing import Callable, Type, Any, Literal
from pathlib import Path
import torch

@@ -39,6 +39,7 @@ class Config(BaseSettings):
data_parallel_shard_degree: int = 8
data_parallel_replicate_degree: int = 32
tensor_parallel_degree: int = 1
expert_parallel_degree: int = 1
train_batch_size: int = 2 # per device; 2 * 8 gpus * 32 nodes * 8192 seqlen = ~4M tokens per batch
train_num_steps: int = 22000 # ~92B tokens
compile: bool = True # TODO: only compiles TransformerBlocks until PyTorch supports full fsdp2
@@ -100,10 +101,25 @@ class Config(BaseSettings):
# fsdp
mixed_precision_param: str = 'bfloat16'
mixed_precision_reduce: str = 'float32'
enable_cpu_offload: bool = False
fsdp_reshard_after_forward: Literal["default", "never", "always"] = "default"

# activation checkpointing
ac_mode: str = "none" # "full" | "selective" | "none"
selective_ac_option: str | int = "op"
per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = Field(
default_factory=lambda: ["moe.router.gate"]
)
"""
When per-op selective ac is used, this list of fully qualified names is used
to determine which mm shapes to force recompute, rather than being considered
by rest of the sac policy, e.g save every other mm. Only nn.Linear modules are
supported today.

Note: this config applies to mms not limited to those matching the specified
fqns, e.g. if "moe.router.gate", corresponding to Linear(in, out), is specified,
ANY mm with shape matching (*, in) x (in, out) will be force recomputed.
"""

# experimental
enable_async_tensor_parallel: bool = False
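For context on the new per_op_sac_force_recompute_mm_shapes_by_fqns option, the sketch below shows one way such a list could be consumed when building a per-op SAC policy with PyTorch's torch.utils.checkpoint API. This is a minimal illustration, not code from this PR; the helper names and the fallback policy are assumptions.

# Illustrative sketch only: resolve the configured FQNs to mm shapes and force
# their recomputation inside a per-op selective-AC policy.
import torch
import torch.nn as nn
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

def collect_mm_recompute_shapes(module: nn.Module, fqns: list[str]) -> set[torch.Size]:
    # Each configured FQN must name an nn.Linear; record the (in, out) operand
    # shape of the mm its forward performs.
    shapes: set[torch.Size] = set()
    for name, submod in module.named_modules():
        if any(name == fqn or name.endswith("." + fqn) for fqn in fqns):
            assert isinstance(submod, nn.Linear), f"only nn.Linear is supported, got {type(submod)}"
            shapes.add(torch.Size((submod.in_features, submod.out_features)))
    return shapes

def make_sac_context_fn(mm_recompute_shapes: set[torch.Size]):
    def policy(ctx, func, *args, **kwargs):
        if func == torch.ops.aten.mm.default and args[1].shape in mm_recompute_shapes:
            # Any mm whose second operand matches a configured Linear's (in, out)
            # shape is recomputed in the backward rather than saved.
            return CheckpointPolicy.PREFER_RECOMPUTE
        # A real policy would fall through to the remaining selective rules here
        # (e.g. save every other mm); saving everything else is only a placeholder.
        return CheckpointPolicy.PREFER_SAVE
    # Pass as context_fn to torch.utils.checkpoint.checkpoint(..., use_reentrant=False).
    return lambda: create_selective_checkpoint_contexts(policy)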
22 changes: 14 additions & 8 deletions maester/models/__init__.py
@@ -6,32 +6,38 @@

from maester.models.llama import llama2_configs, llama3_configs, mistral_configs, Transformer
from maester.models.gemma import gemma3_configs, GemmaTextModel
from maester.parallelisms import parallelize_gemma, parallelize_llama
from maester.models.deepseek import deepseek_configs, DeepSeekModel, build_deepseek_optimizers
from maester.parallelisms import parallelize_gemma, parallelize_llama, parallelize_deepseek
from maester.optimizers import build_optimizers

models_config = {
"llama2": llama2_configs,
"llama3": llama3_configs,
"mistral": mistral_configs,
"gemma3": gemma3_configs,
"deepseek": deepseek_configs,
}

model_name_to_cls = {
"llama2": Transformer,
"llama3": Transformer,
"mistral": Transformer,
"gemma3": GemmaTextModel,
"deepseek": DeepSeekModel,
}

model_name_to_tokenizer = {
"llama2": "sentencepiece",
"llama3": "tiktoken",
"mistral": "sentencepiece",
"gemma3": "sentencepiece",
}

model_name_to_parallelize = {
"llama2": parallelize_llama,
"llama3": parallelize_llama,
"mistral": parallelize_llama,
"gemma3": parallelize_gemma,
"deepseek": parallelize_deepseek,
}

model_name_to_optimizers_builder = {
"llama2": build_optimizers,
"llama3": build_optimizers,
"mistral": build_optimizers,
"gemma3": build_optimizers,
"deepseek": build_deepseek_optimizers,
}
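A usage sketch of how the extended registries could be consumed by the training entry point. The cfg attribute names and the parallelize/optimizer-builder call signatures are assumptions, not part of this diff.

# Hypothetical wiring showing the new "deepseek" entries flowing through the
# registries; attribute and argument names are illustrative.
from maester.models import (
    models_config,
    model_name_to_cls,
    model_name_to_parallelize,
    model_name_to_optimizers_builder,
)

def build_model_and_optimizers(cfg, world_mesh, parallel_dims):
    model_args = models_config[cfg.model_name][cfg.flavor]    # e.g. "deepseek", "16B"
    model = model_name_to_cls[cfg.model_name](model_args)     # DeepSeekModel(model_args)

    # Model-specific parallelization (FSDP / TP / EP, activation checkpointing, compile).
    model_name_to_parallelize[cfg.model_name](model, world_mesh, parallel_dims, cfg)

    # DeepSeek routes through build_deepseek_optimizers so MoE expert parameters
    # can be treated differently from dense parameters if needed.
    optimizers = model_name_to_optimizers_builder[cfg.model_name]([model], cfg)
    return model, optimizers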
133 changes: 133 additions & 0 deletions maester/models/deepseek/__init__.py
@@ -0,0 +1,133 @@
from maester.models.deepseek.model import DeepSeekModelArgs, DeepSeekModel
from maester.models.deepseek.optimizer import build_deepseek_optimizers

__all__ = ["DeepSeekModel", "DeepSeekModelArgs", "build_deepseek_optimizers"]

deepseek_configs = {
"debug": DeepSeekModelArgs(
max_batch_size=32,
vocab_size=102400,
dim=512,
inter_dim=2048,
moe_inter_dim=512,
n_layers=8,
n_dense_layers=1,
n_heads=16,
n_routed_experts=8,
n_shared_experts=2,
n_activated_experts=3,
route_scale=1.0,
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
),
"debug_flex": DeepSeekModelArgs(
vocab_size=2000,
dim=256,
inter_dim=1024,
moe_inter_dim=256,
n_layers=3,
n_dense_layers=1,
n_heads=16,
n_routed_experts=8,
n_shared_experts=2,
n_activated_experts=3,
route_scale=1.0,
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
use_flex_attn=True,
attn_mask_type="block_causal",
),
"16B": DeepSeekModelArgs(
vocab_size=102400,
dim=2048,
inter_dim=10944,
moe_inter_dim=1408,
n_layers=27,
n_dense_layers=1,
n_heads=16,
n_routed_experts=64,
n_shared_experts=2,
n_activated_experts=6,
route_scale=1.0,
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
use_flex_attn=True,
attn_mask_type="causal",
),
"45B_custom": DeepSeekModelArgs(
vocab_size=102400,
dim=3072,
inter_dim=8192,
moe_inter_dim=1408,
n_layers=20,
n_dense_layers=1,
n_heads=128,
n_routed_experts=160,
n_shared_experts=2,
n_activated_experts=6,
route_scale=16.0,
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
use_flex_attn=True,
attn_mask_type="causal",
),
# TODO: check correctness of configs below here
"236B": DeepSeekModelArgs(
vocab_size=102400,
dim=5120,
inter_dim=12288,
moe_inter_dim=1536,
n_layers=60,
n_dense_layers=1,
n_heads=128,
n_routed_experts=160,
n_shared_experts=2,
n_activated_experts=6,
route_scale=16.0,
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
use_flex_attn=True,
attn_mask_type="causal",
),
"685B": DeepSeekModelArgs(
vocab_size=129280,
dim=7168,
inter_dim=18432,
moe_inter_dim=2048,
n_layers=61,
n_dense_layers=3,
n_heads=128,
n_routed_experts=256,
n_shared_experts=1,
n_activated_experts=8,
route_scale=2.5,
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=1.0,
use_flex_attn=True,
attn_mask_type="causal",
),
}
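As a quick sanity check, a flavor from the table above can be instantiated and sized with get_nparams_and_flops from args.py below. The DeepSeekModel constructor call is an assumption about its interface.

# Illustrative only: look up a flavor and report its size and per-token FLOPs.
from maester.models.deepseek import deepseek_configs, DeepSeekModel

args = deepseek_configs["debug"]
model = DeepSeekModel(args)  # assumes the constructor takes DeepSeekModelArgs directly
nparams, flops_per_token = args.get_nparams_and_flops(model, seq_len=args.max_seq_len)
print(f"{nparams:,} parameters, {flops_per_token:,} FLOPs/token")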
102 changes: 102 additions & 0 deletions maester/models/deepseek/args.py
@@ -0,0 +1,102 @@
from dataclasses import dataclass
from typing import Optional, Literal
from torch import nn
from maester.log_utils import logger

@dataclass
class DeepSeekModelArgs:
max_batch_size: int = 16
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
moe_inter_dim: int = 1408
n_layers: int = 27
n_dense_layers: int = 1
n_heads: int = 16
norm_eps: float = 1e-5 # eps used for RMSNorm
tied_embeddings: bool = False # always False for DeepSeek

# MoE
n_routed_experts: int = 8
n_shared_experts: int = 0
n_activated_experts: int = 4
n_expert_groups: int = 1
n_limited_groups: int = 1
score_func: Literal["softmax", "sigmoid"] = "softmax"
route_scale: float = 1.0
use_grouped_mm: bool = True
load_balance_coeff: float = 1e-3

# MLA
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
use_flex_attn: bool = False
attn_mask_type: str = "causal"

# Yarn
original_seq_len: int = 4096
rope_theta: float = 10000.0
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.0

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
"""
Adapted from the llama4 implementation.
"""
nparams_embedding = 0
nparams_moe_router = 0
nparams_shared_expert = 0
nparams_experts = 0
nparams_dense = 0

for name, p in model.named_parameters():
if "embedding" in name:
nparams_embedding += p.numel()
nparams_dense += p.numel()
elif "moe.shared_expert" in name:
nparams_shared_expert += p.numel()
elif "moe.router" in name:
nparams_moe_router += p.numel()
elif "moe.experts" in name:
nparams_experts += p.numel()
else:
nparams_dense += p.numel()

nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
nparams = nparams_dense + nparams_sparse
nparams_sparse_active = (
nparams_moe_router
+ nparams_shared_expert
+ nparams_experts * self.n_activated_experts // self.n_routed_experts
)

logger.info(
f"Total parameter count: dense {nparams_dense:,}, "
f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
)

l, h, q, t = (
self.n_layers,
self.n_heads,
self.dim // self.n_heads,
seq_len,
)
# Reasoning behind the factor of 12 for the self-attention part of the formula:
# 1. each self-attention has 2 matmuls in the forward and 4 in the backward (6)
# 2. flash attention does 1 more matmul recomputation in the backward,
# but recomputation should not be counted when calculating MFU (+0)
# 3. each matmul performs 1 multiplication and 1 addition (*2)
# 4. we follow convention and do not account for sparsity in causal attention
num_flops_per_token = (
6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
+ 12 * l * h * q * t
)

return nparams, num_flops_per_token
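A back-of-the-envelope check of the attention term in the formula above, using the "16B" flavor's dimensions at seq_len 4096; this covers only the 12*l*h*q*t part, since the 6*N term needs the instantiated module tree.

# Worked arithmetic for the attention term of num_flops_per_token ("16B" flavor).
l, h = 27, 16
q = 2048 // 16        # dim // n_heads = 128
t = 4096              # sequence length
attn_flops_per_token = 12 * l * h * q * t
print(f"{attn_flops_per_token:,}")  # 2,717,908,992 ≈ 2.7 GFLOPs/token from attention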