
Commit 9ceae8a

llama context parallelism
1 parent 547476b

2 files changed: +18 −8 lines

maester/models/llama/model.py: 4 additions & 1 deletion

@@ -20,6 +20,8 @@
 from maester.models.norms import create_norm
 from maester.models.llama.tied_linear import TiedLinear
 
+from torch.distributed.device_mesh import DeviceMesh
+
 
 @dataclass
 class ModelArgs:

@@ -457,12 +459,13 @@ def forward(
         return output
 
     @classmethod
-    def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
+    def from_model_args(cls, model_args: ModelArgs, cp_device_mesh: Optional[DeviceMesh] = None) -> "Transformer":
         """
         Initialize a Transformer model from a ModelArgs object.
 
         Args:
            model_args (ModelArgs): Model configuration arguments.
+            cp_device_mesh (Optional[DeviceMesh]): Device mesh for context parallelism.
 
         Returns:
            Transformer: Transformer model.
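
The new `cp_device_mesh` argument only changes the constructor signature in this hunk; the commit does not show how the mesh is built or how the model consumes it. A minimal sketch of a plausible call site, assuming the caller passes the "cp" slice of a larger world mesh (the mesh sizes, dimension names, and default `ModelArgs()` below are illustrative assumptions, not code from this commit):

```python
# Illustrative only: build a mesh with a context-parallel dimension and hand
# its "cp" slice to the model constructor. Assumes torch.distributed is
# already initialized (e.g. launched with torchrun).
from torch.distributed.device_mesh import init_device_mesh

from maester.models.llama.model import ModelArgs, Transformer  # module path assumed

# Example layout: 8 GPUs split into 4-way data parallel x 2-way context parallel.
world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "cp"))

model_args = ModelArgs()  # defaults; real runs take this from the maester config
model = Transformer.from_model_args(model_args, cp_device_mesh=world_mesh["cp"])
```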

maester/parallelisms/parallelize_llama.py: 14 additions & 7 deletions

@@ -60,14 +60,21 @@ def parallelize_llama(
                 "fused_rmsnorm is not compatible with torch.compile yet. "
                 "Please use rmsnorm or layernorm."
             )
-        apply_compile(model)
+        apply_compile(model, fullgraph=not parallel_dims.cp_enabled)
 
-    if parallel_dims.dp_shard_enabled:
+    if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
         if parallel_dims.dp_replicate_enabled:
-            dp_mesh = world_mesh["dp_replicate", "dp_shard"]
+            if parallel_dims.cp_enabled:
+                dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+            else:
+                dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         else:
-            dp_mesh = world_mesh["dp"]
+            if parallel_dims.cp_enabled:
+                dp_mesh_dim_names = ("dp_shard_cp",)
+            else:
+                dp_mesh_dim_names = ("dp",)
 
+        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
         apply_fsdp(model, dp_mesh, param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param],
                    reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce])
         if parallel_dims.dp_replicate_enabled:
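
When CP is enabled, the hunk above indexes `world_mesh` by a flattened "dp_shard_cp" dimension so that FSDP shards parameters over both the dp_shard and cp ranks. The mesh construction itself is not part of this commit; a minimal sketch of one way such a mesh could be produced, using the private `DeviceMesh._flatten` helper in the style of torchtitan (the dimension sizes and the use of that helper are assumptions, not code from this repo):

```python
# Sketch only: a world mesh whose "dp_shard" and "cp" dims are flattened into
# "dp_shard_cp" for FSDP sharding, while "cp" stays available for the
# context-parallel attention path.
from torch.distributed.device_mesh import init_device_mesh

# Example: 16 GPUs = 2 replicate groups x 4-way shard x 2-way context parallel.
world_mesh = init_device_mesh(
    "cuda", (2, 4, 2), mesh_dim_names=("dp_replicate", "dp_shard", "cp")
)

# Private API in recent PyTorch (as used by torchtitan); assumed here, not
# shown in this commit.
world_mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_shard_cp")

# Mirrors the selection logic in the hunk above when both flags are enabled.
dp_mesh = world_mesh[("dp_replicate", "dp_shard_cp")]
```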
@@ -239,16 +246,16 @@ def apply_ac(model: nn.Module, ac_config: Config):
     logger.info(f"Applied {ac_config.ac_mode} activation checkpointing to the model")
 
 
-def apply_compile(model: nn.Module):
+def apply_compile(model: nn.Module, fullgraph: bool = True):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
+        transformer_block = torch.compile(transformer_block, fullgraph=fullgraph)
         model.layers.register_module(layer_id, transformer_block)
 
-    logger.info("Compiling each TransformerBlock with torch.compile")
+    logger.info(f"Compiling each TransformerBlock with torch.compile (fullgraph={fullgraph})")
 
 
 def apply_fsdp(
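
The commit relaxes `fullgraph` when CP is enabled, presumably because the context-parallel attention path introduces graph breaks that `fullgraph=True` would turn into hard compile errors; the diff itself does not state the reason. It also does not show the training-loop side of context parallelism. For reference, a hedged sketch of how PyTorch's experimental `context_parallel` context manager is typically applied over sequence-dimension buffers, continuing from the `world_mesh` sketch above (the buffer names and sequence dims are assumptions, not code from this repo):

```python
# Sketch only: shard sequence-dim tensors over the "cp" mesh for one training
# step. `inputs`, `labels`, `model`, and `loss_fn` stand in for the batch, the
# constructed Transformer, and the loss; they are illustrative names.
from torch.distributed.tensor.experimental import context_parallel

cp_mesh = world_mesh["cp"]

with context_parallel(
    cp_mesh,
    buffers=[inputs, labels, model.freqs_cis],  # tensors sharded along their sequence dim
    buffer_seq_dims=[1, 1, 0],                  # which dim of each buffer is the sequence dim
    no_restore_buffers={inputs, labels},        # consumed within the step; skip restoring them
):
    loss = loss_fn(model(inputs), labels)
    loss.backward()
```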
