@@ -60,14 +60,21 @@ def parallelize_llama(
             "fused_rmsnorm is not compatible with torch.compile yet. "
             "Please use rmsnorm or layernorm."
         )
-        apply_compile(model)
+        apply_compile(model, fullgraph=not parallel_dims.cp_enabled)
 
-    if parallel_dims.dp_shard_enabled:
+    if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
         if parallel_dims.dp_replicate_enabled:
-            dp_mesh = world_mesh["dp_replicate", "dp_shard"]
+            if parallel_dims.cp_enabled:
+                dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+            else:
+                dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         else:
-            dp_mesh = world_mesh["dp"]
+            if parallel_dims.cp_enabled:
+                dp_mesh_dim_names = ("dp_shard_cp",)
+            else:
+                dp_mesh_dim_names = ("dp",)
 
+        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
         apply_fsdp(model, dp_mesh, param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param],
                    reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce])
         if parallel_dims.dp_replicate_enabled:
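For context: the new `dp_shard_cp` branch assumes the world mesh exposes a dimension that flattens `dp_shard` and `cp`, so FSDP shards parameters across both the data-parallel shard ranks and the context-parallel ranks. Below is a minimal sketch of that slicing, assuming an 8-rank job, a 2x2x2 mesh shape, and the private `DeviceMesh._flatten` API; the names mirror the diff, but the setup is illustrative rather than the exact torchtitan mesh construction.

```python
# Illustrative sketch only: 2x2x2 shape and the private _flatten API are
# assumptions, not the exact torchtitan setup.
from torch.distributed.device_mesh import init_device_mesh

# 8 ranks: 2 replicas x 2 FSDP shards x 2 context-parallel ranks.
world_mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "cp")
)

# Flatten dp_shard and cp into one dim so FSDP also shards across CP ranks.
world_mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_shard_cp")

# Matches the new code path: HSDP replicates over dp_replicate and shards
# over the flattened dp_shard_cp dimension.
dp_mesh = world_mesh[("dp_replicate", "dp_shard_cp")]
```

When CP is disabled, the slice reduces to the previous `("dp_replicate", "dp_shard")` and `("dp",)` cases, so non-CP runs keep their old mesh.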
@@ -239,16 +246,16 @@ def apply_ac(model: nn.Module, ac_config: Config):
     logger.info(f"Applied {ac_config.ac_mode} activation checkpointing to the model")
 
 
-def apply_compile(model: nn.Module):
+def apply_compile(model: nn.Module, fullgraph: bool = True):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
+        transformer_block = torch.compile(transformer_block, fullgraph=fullgraph)
         model.layers.register_module(layer_id, transformer_block)
 
-    logger.info("Compiling each TransformerBlock with torch.compile")
+    logger.info(f"Compiling each TransformerBlock with torch.compile (fullgraph={fullgraph})")
 
 
 def apply_fsdp(
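On the `fullgraph` toggle: `torch.compile(..., fullgraph=True)` turns any graph break into a hard error, while `fullgraph=False` lets Dynamo fall back to eager around the break. The `fullgraph=not parallel_dims.cp_enabled` default suggests the context-parallel code paths currently introduce graph breaks inside each TransformerBlock. A self-contained illustration of the difference, where `Block` is a stand-in for a TransformerBlock and `print()` stands in for an untraceable op:

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Stand-in for a TransformerBlock."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        print("side effect")  # print() is untraceable and forces a graph break
        return self.proj(x)


x = torch.randn(2, 8)

# fullgraph=False: Dynamo splits the graph at the break and keeps running.
torch.compile(Block(), fullgraph=False)(x)

# fullgraph=True: the same break is now a hard error.
try:
    torch.compile(Block(), fullgraph=True)(x)
except Exception as err:
    print(f"fullgraph=True failed: {type(err).__name__}")
```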