diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 6455447ac..dcd9007a2 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -95,13 +95,18 @@ def _check_moe_calibration_complete(quantizer, parallel_state): @torch.no_grad() -def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True): +def max_calibrate( + model: nn.Module, + forward_loop: ForwardLoop | None = None, + distributed_sync=True, +): """Calibrate the model using max. Args: model: Model to be calibrated. forward_loop: A callable which takes the model as argument and forwards calibration data through the model. + distributed_sync: Whether to sync input_quantizer amax across distributed processes. See :class:`MaxCalibConfig ` for details on the remaining arguments. @@ -113,7 +118,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis forward_loop(model) finish_stats_collection(model) - # Sync amax across local experts within each rank (for SequentialMLP) + # Sync input_quantizer amax across local experts within each rank (for SequentialMLP) for name, module in model.named_modules(): if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax() diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 6e92fce90..e84735ae9 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -575,10 +575,10 @@ def _setup(self): expert.linear_fc2.parallel_state = self.parallel_state def layer_sync_moe_local_experts_amax(self): - """Sync amax across local experts in a SequentialMLP. + """Sync input quantizer amax across local experts in a SequentialMLP. - Synchronize the amax values across local experts in a lyaer such that all local experts will - share the same amax. This function operates on a single rank and does not require distributed sync. + Ensures all experts have the same input quantizer amax.This function operates + on a single rank and does not require distributed sync. Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate(). This function should be called before the distributed sync to ensure the amax values @@ -586,15 +586,19 @@ def layer_sync_moe_local_experts_amax(self): Note: Because there are logic which calls collective communication based on whether amax is not None, - We need to garuantee that all experts must have amax. Otherwise, there will be deadlock - when synchroizing over EP since some ranks may have amax None and not calling the collective + We need to guarantee that all experts must have amax. Otherwise, there will be deadlock + when synchronizing over EP since some ranks may have amax None and not calling the collective communication. """ # Collect amax from all local experts amax_dict = {} for expert in self.local_experts: for name, module in expert.named_modules(): - if isinstance(module, TensorQuantizer) and module.amax is not None: + if ( + isinstance(module, TensorQuantizer) + and module.amax is not None + and "input_quantizer" in name + ): stored_amax = amax_dict.get(name) amax_tensor = module.amax.detach().clone() amax_dict[name] = ( diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index b107eca71..e1866cba3 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -473,10 +473,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device, @pytest.mark.parametrize( "config", - [ - NVFP4_GEMM_KV_CFG, - FP8_GEMM_KV_CFG, - ], + [NVFP4_GEMM_KV_CFG, FP8_GEMM_KV_CFG, mtq.MAMBA_MOE_NVFP4_CONSERVATIVE_CFG], ) def test_homogeneous_sharded_state_dict_hybrid(tmp_path, config): """Test sharded state dict for hybrid Mamba MOE models.""" @@ -735,6 +732,81 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus): ) +@pytest.mark.parametrize("ep_size", [1, 2]) +@pytest.mark.parametrize("moe_grouped_gemm", [True, False]) +def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm): + """Test expert model parallel synchronization.""" + size = torch.cuda.device_count() + if size < ep_size: + pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test") + + spawn_multiprocess_job( + size=size, + job=partial( + _test_layer_sync_moe_local_experts_amax, + ep_size, + moe_grouped_gemm, + ), + backend="nccl", + ) + + +def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size): + initialize_for_megatron( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=1, + seed=SEED, + ) + model = _gpt_model_provider( + tp_size=1, + ep_size=ep_size, + etp_size=1, + hidden_size=256, + moe_grouped_gemm=moe_grouped_gemm, + use_te=moe_grouped_gemm, + num_moe_experts=8, + transformer_impl="modelopt", + ) + quant_cfg = mtq.FP8_DEFAULT_CFG + model = mtq.quantize(model, quant_cfg, get_forward(model)) + + for layer in model.decoder.layers: + layer.mlp.experts.layer_sync_moe_local_experts_amax() + + for layer in model.decoder.layers: + # Check input quantizer amax is synced across local experts + fc1_amax = None + fc2_amax = None + for expert in layer.mlp.experts.local_experts: + assert expert.linear_fc1.input_quantizer.amax is not None + assert expert.linear_fc2.input_quantizer.amax is not None + if fc1_amax is None: + fc1_amax = expert.linear_fc1.input_quantizer.amax + else: + assert torch.allclose(fc1_amax, expert.linear_fc1.input_quantizer.amax) + if fc2_amax is None: + fc2_amax = expert.linear_fc2.input_quantizer.amax + else: + assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax) + + # Check weight quantizer amax is different across local experts + fc1_amax = None + fc2_amax = None + for expert in layer.mlp.experts.local_experts: + assert expert.linear_fc1.weight_quantizer.amax is not None + assert expert.linear_fc2.weight_quantizer.amax is not None + if fc1_amax is None: + fc1_amax = expert.linear_fc1.weight_quantizer.amax + else: + assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax) + if fc2_amax is None: + fc2_amax = expert.linear_fc2.weight_quantizer.amax + else: + assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax) + + def _test_expert_model_parallel_amax_sync( tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size ): @@ -815,9 +887,6 @@ def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm): if size < ep_size * etp_size: pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test") - if moe_grouped_gemm: - pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently") - spawn_multiprocess_job( size=size, job=partial(