From 3a123bec156c90d46b599e38f9ecccd8aee28961 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Fri, 20 Feb 2026 00:42:28 +0000 Subject: [PATCH 1/4] Enable user to specify MOE expert calibration ratio Signed-off-by: Chenjie Luo --- examples/llm_ptq/example_utils.py | 10 +++++ examples/llm_ptq/hf_ptq.py | 10 +++++ modelopt/torch/export/moe_utils.py | 2 +- modelopt/torch/quantization/config.py | 10 +++++ modelopt/torch/quantization/mode.py | 6 +++ .../torch/quantization/plugins/huggingface.py | 37 +++++++++++++------ 6 files changed, 63 insertions(+), 12 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index d8bff7ba2..ccc594612 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -201,6 +201,7 @@ def build_quant_cfg( model_type, quant_cfg_choices, kv_quant_cfg_choices, + moe_calib_experts_ratio, ) -> dict[str, Any]: quant_cfg = {} assert qformat in quant_cfg_choices, ( @@ -232,6 +233,15 @@ def build_quant_cfg( getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"], ) + if moe_calib_experts_ratio: + if isinstance(quant_cfg["algorithm"], str): + quant_cfg["algorithm"] = { + "method": quant_cfg["algorithm"], + "moe_calib_experts_ratio": moe_calib_experts_ratio, + } + else: + quant_cfg["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio + # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. if model_type == "gemma" and "int8_sq" in qformat: quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d7aadf994..7d3db93f3 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -906,6 +906,7 @@ def quantize_main( model_type, QUANT_CFG_CHOICES, KV_QUANT_CFG_CHOICES, + args.moe_calib_experts_ratio, ) # Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92) @@ -1126,6 +1127,15 @@ def parse_args() -> argparse.Namespace: "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." ), ) + parser.add_argument( + "--moe_calib_experts_ratio", + type=float, + default=1.0 / 4, + help=( + "Percentage of experts to calibrate during forward pass. Only used for MOE models. " + "This is used to reduce the number of experts to calibrate during forward pass. " + ), + ) return parser.parse_args() diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index a5ba465b1..dc3574868 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -48,7 +48,7 @@ def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | Non "th, td { border: 1px solid #ccc; padding: 4px 8px; text-align: right; }", "th { background: #f0f0f0; }", "", - "

</style>",
-        "<table><caption>Expert Token Counts (per MoE layer)</caption>",
+        "<table><caption>Expert Calib Token Counts (per MoE layer)</caption>",
         "<tr><th>Layer/Expert</th>",
     ]
     html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 291acba03..26c1e865f 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -1070,6 +1070,16 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
         title="This field specifies the name of the calibration algorithm. If None, no calibration is performed.",
     )
 
+    moe_calib_experts_ratio: float | None = ModeloptField(
+        default=None,
+        title="Ratio of experts to calibrate during the forward pass.",
+        description=(
+            "If specified, tokens are force-forwarded to this ratio of the experts during the"
+            " calibration pass. This forward is for calibration purposes only and does not"
+            " affect the actual inference."
+        ),
+    )
+
 
 class MaxCalibConfig(QuantizeAlgorithmConfig):
     """The config for max calibration algorithm.
diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py
index bfcdb64da..37e41ae31 100644
--- a/modelopt/torch/quantization/mode.py
+++ b/modelopt/torch/quantization/mode.py
@@ -216,6 +216,12 @@ def wrapped_calib_func(
         # For backward compatibility
         kwargs["algorithm"] = method
 
+    moe_calib_experts_ratio = kwargs.pop("moe_calib_experts_ratio", None)
+    if moe_calib_experts_ratio is not None:
+        for module in model.modules():
+            if hasattr(module, "_moe_calib_experts_ratio"):
+                module._moe_calib_experts_ratio = moe_calib_experts_ratio
+
     if func is not None:
         # Call the function with forward_loop as a separate argument
         func(model, forward_loop=forward_loop, **kwargs)
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index aa274ea7e..cc8301927 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -458,8 +458,9 @@ def _setup(self):
         elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
             num_experts = self.experts.num_experts
 
-        self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
+        self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cuda")
         self._count_expert_tokens = False
+        self._moe_calib_experts_ratio = None
 
         if num_experts == 0:
             warnings.warn(
@@ -483,36 +484,50 @@ def _gate_forward_hook(self, module, input, output):
         logits = output if not isinstance(output, tuple) else output[0]
         top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
         _, indices = torch.topk(logits.float(), top_k, dim=-1)
-        counts = torch.bincount(
-            indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
-        )
-        self.expert_token_count += counts
+        counts = torch.bincount(indices.reshape(-1), minlength=len(self.expert_token_count))
+        self.expert_token_count += counts.to(self.expert_token_count.device)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
-        if is_calib:
+        self._count_expert_tokens = is_calib
+        if is_calib and self._moe_calib_experts_ratio:
+            self._count_expert_tokens = True
+            assert 0 < self._moe_calib_experts_ratio <= 1, (
+                "moe_calib_experts_ratio must be between 0 and 1"
+            )
             # If any of the experts are in calibration mode, we will forward all tokens to all experts
             # This is used only for calibration, we need to re-calculate the actual outputs again using
             # the original top_k
             if TRANSFORMERS_VERSION_GE_5_0:
                 assert hasattr(self, "gate") and hasattr(self.gate, "top_k")
                 original_top_k =
self.gate.top_k - self.gate.top_k = self.gate.num_experts + self.gate.top_k = round(self.gate.num_experts * self._moe_calib_experts_ratio) + assert self.gate.top_k >= original_top_k, ( + f"moe_calib_experts_ratio {self._moe_calib_experts_ratio}," + f" calib top_k {self.gate.top_k} smaller than original" + f" top_k {original_top_k}" + ) super().forward(hidden_states) self.gate.top_k = original_top_k else: # Path for transformers < 5.0 original_top_k = self.top_k if hasattr(self, "num_experts"): - self.top_k = self.num_experts + self.top_k = round(self.num_experts * self._moe_calib_experts_ratio) elif hasattr(self, "experts"): - self.top_k = self.experts.num_experts + self.top_k = round(self.experts.num_experts * self._moe_calib_experts_ratio) else: raise ValueError(f"Could not find num_experts in module {self}") + assert self.top_k >= original_top_k, ( + f"moe_calib_experts_ratio {self._moe_calib_experts_ratio}," + f" calib top_k {self.top_k} smaller than original" + f" top_k {original_top_k}" + ) super().forward(hidden_states) self.top_k = original_top_k - # Enable counting only for the real-routing forward during calibration - self._count_expert_tokens = is_calib + self._count_expert_tokens = False + else: + self._count_expert_tokens = True output = super().forward(hidden_states) self._count_expert_tokens = False return output From e16f04493519439b38c514f104f183cae8c5d669 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Fri, 20 Feb 2026 03:18:03 +0000 Subject: [PATCH 2/4] Add changelog Signed-off-by: Chenjie Luo --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 744238656..a9933a1fe 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ NVIDIA Model Optimizer Changelog (Linux) - User does not need to manually register MOE modules to cover experts calibration coverage in PTQ workflow. - ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory. +- Add ``--moe_calib_experts_percentage`` flag in ``hf_ptq.py`` to specify the percentage of experts to calibrate during forward pass to improve expert coverage during calibration. Default to 1/4 of all the experts. - Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. 0.42 (2026-02-xx) From 440798133f06533c3721efbf32f5e51fde31e33f Mon Sep 17 00:00:00 2001 From: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> Date: Thu, 19 Feb 2026 19:29:26 -0800 Subject: [PATCH 3/4] Update CHANGELOG.rst Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> --- CHANGELOG.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a9933a1fe..b7047213b 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,7 +8,7 @@ NVIDIA Model Optimizer Changelog (Linux) - User does not need to manually register MOE modules to cover experts calibration coverage in PTQ workflow. - ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory. -- Add ``--moe_calib_experts_percentage`` flag in ``hf_ptq.py`` to specify the percentage of experts to calibrate during forward pass to improve expert coverage during calibration. 
Default to 1/4 of all the experts.
+- Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during the forward pass, improving expert coverage during calibration. Defaults to 1/4 of all experts.
 - Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage.
 
 0.42 (2026-02-xx)

From f89ef8ac05ea9f18143c23c3a3b3f8c3c4c75825 Mon Sep 17 00:00:00 2001
From: Chenjie Luo
Date: Fri, 20 Feb 2026 19:12:09 +0000
Subject: [PATCH 4/4] Fix

Signed-off-by: Chenjie Luo
---
 modelopt/torch/quantization/plugins/huggingface.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index cc8301927..cb54faf1b 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -458,7 +458,9 @@ def _setup(self):
         elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
             num_experts = self.experts.num_experts
 
-        self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cuda")
+        self.expert_token_count = torch.zeros(
+            num_experts, dtype=torch.long, device=next(self.parameters()).device
+        )
         self._count_expert_tokens = False
         self._moe_calib_experts_ratio = None
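
The sketch below is illustrative only and not part of the patch set: it mirrors what ``build_quant_cfg()`` does with ``--moe_calib_experts_ratio`` for callers that drive ``mtq.quantize()`` directly. ``FP8_DEFAULT_CFG`` stands in for any recipe in ``QUANT_CFG_CHOICES``, and ``model`` / ``calib_forward_loop`` are assumed to be supplied by the caller.

# Sketch only: FP8_DEFAULT_CFG stands in for any recipe from QUANT_CFG_CHOICES, and
# `model` / `calib_forward_loop` are assumed to come from the caller.
import copy

import modelopt.torch.quantization as mtq


def quantize_with_expert_ratio(model, calib_forward_loop, ratio=0.25):
    """Attach moe_calib_experts_ratio to a stock recipe, as build_quant_cfg() does."""
    quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    algorithm = quant_cfg.get("algorithm", "max")
    if isinstance(algorithm, str):
        # Promote the plain method name to the dict form that carries the extra field.
        quant_cfg["algorithm"] = {"method": algorithm, "moe_calib_experts_ratio": ratio}
    else:
        quant_cfg["algorithm"]["moe_calib_experts_ratio"] = ratio
    # wrapped_calib_func() pops the field and sets _moe_calib_experts_ratio on MoE modules.
    return mtq.quantize(model, quant_cfg, forward_loop=calib_forward_loop)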
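
A second illustrative snippet, also not from the patch, restates the constraint behind the new asserts: the calibration-time top_k is ``round(num_experts * moe_calib_experts_ratio)`` and must not drop below the router's original top_k, otherwise calibration would route tokens to fewer experts than inference does.

# Hypothetical helper, not part of the patch: it restates the constraint enforced by the
# new asserts, where the calibration-time top_k is round(num_experts * ratio) and must not
# drop below the router's original top_k.
def calib_top_k(num_experts: int, ratio: float) -> int:
    """Number of experts each token is routed to during calibration, per the patch's formula."""
    assert 0 < ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
    return round(num_experts * ratio)


if __name__ == "__main__":
    # Mixtral-style router: 8 experts, original top_k = 2; the 1/4 default keeps top_k at 2.
    assert calib_top_k(8, 0.25) >= 2
    # 64 experts with original top_k = 8: the default widens calibration routing to 16 experts.
    assert calib_top_k(64, 0.25) >= 8
    # Too small a ratio would trip the assert in forward(): round(64 * 0.05) = 3 < 8.
    assert calib_top_k(64, 0.05) < 8
    print("ratio checks passed")

With the 1/4 default this holds for common routers such as 8 experts with top_k 2 or 64 experts with top_k 8; a model whose original top_k exceeds a quarter of its experts needs a larger ``--moe_calib_experts_ratio``.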