From 44e002eb5137415a215eb7290f8b7531e3114223 Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Fri, 16 Jan 2026 00:16:33 +0000
Subject: [PATCH 01/10] add local hessian calibration

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/config.py      | 133 +++++++++++
 modelopt/torch/quantization/mode.py        |  27 +++-
 modelopt/torch/quantization/model_calib.py | 259 ++++++++++++++++++++-
 3 files changed, 417 insertions(+), 2 deletions(-)

diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index e1b48ee60..b3fcbb0ab 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -388,6 +388,69 @@
     "algorithm": "max",
 }
 
+NVFP4_WEIGHT_ACT_MSE_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "mse",
+        "step_size": 0.25,
+        "start_multiplier": 0.25,
+        "stop_multiplier": 2.0,
+    },
+}
+
+NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "mse",
+        "fp8_scale_sweep": True,
+    },
+}
+
+
+NVFP4_LOCAL_HESSIAN_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "local_hessian",
+        "fp8_scale_sweep": True,
+    },
+}
+
 NVFP4_AWQ_LITE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
@@ -1060,6 +1123,76 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
     )
 
 
+class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
+    """Configuration for local Hessian-weighted MSE calibration.
+
+    This algorithm uses activation information to optimize per-block scales for weight
+    quantization. It minimizes the output reconstruction error by weighting the loss
+    with the local Hessian matrix computed from input activations.
+
+    The local Hessian loss for each block is: ``(dw @ H @ dw.T)`` where:
+    - ``dw = weight - quantized_weight`` (weight reconstruction error per block)
+    - ``H = X @ X.T`` is the local Hessian computed from input activations X
+
+    This method is particularly effective for NVFP4 weight-only quantization where
+    activation information helps select better per-block scales.
+
+    """
+
+    method: Literal["local_hessian"] = ModeloptField("local_hessian")
+
+    step_size: float | None = ModeloptField(
+        default=0.1,
+        gt=0.0,
+        title="Step size for amax search.",
+        description="Step size between amax candidates. The number of candidates is computed as "
+        "ceil((stop_multiplier - start_multiplier) / step_size) + 1.",
+    )
+
+    start_multiplier: float | None = ModeloptField(
+        default=0.25,
+        gt=0.0,
+        title="Starting multiplier for amax search.",
+        description="Starting multiplier for amax search range (multiplies initial amax).",
+    )
+
+    stop_multiplier: float | None = ModeloptField(
+        default=4.0,
+        gt=0.0,
+        title="Ending multiplier for amax search.",
+        description="Ending multiplier for amax search range (multiplies initial amax).",
+    )
+
+    fp8_scale_sweep: bool | None = ModeloptField(
+        default=True,
+        title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
+        description="If True, sweep over all 128 possible FP8 E4M3 scale values "
+        "for NVFP4 per-block quantization instead of using multipliers. "
+        "This is the recommended setting for NVFP4 quantization.",
+    )
+
+    block_size: int | None = ModeloptField(
+        default=16,
+        gt=0,
+        title="Block size for local Hessian computation.",
+        description="The block size used for computing the local Hessian matrix. "
+        "This should match the block size used in the quantization config. "
+        "Default is 16 for NVFP4.",
+    )
+
+    distributed_sync: bool | None = ModeloptField(
+        default=True,
+        title="Whether to sync the amax across the distributed processes.",
+        description="If True, the amax will be synced across the distributed processes.",
+    )
+
+    debug: bool | None = ModeloptField(
+        default=False,
+        title="Debug mode.",
+        description="If True, module's local Hessian metadata will be kept as a module attribute.",
+    )
+
+
 class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
     """The config for ``smoothquant`` algorithm (SmoothQuant).
 
diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py
index bfcdb64da..1f3346ea9 100644
--- a/modelopt/torch/quantization/mode.py
+++ b/modelopt/torch/quantization/mode.py
@@ -38,6 +38,7 @@
     AWQLiteCalibConfig,
     CompressConfig,
     GPTQLiteConfig,
+    LocalHessianCalibConfig,
     MaxCalibConfig,
     MseCalibConfig,
     QuantizeAlgoCfgType,
@@ -56,7 +57,15 @@
     restore_svdquant_model,
     update_quantize_metadata,
 )
-from .model_calib import awq, gptq_lite, max_calibrate, mse_calibrate, smoothquant, svdquant
+from .model_calib import (
+    awq,
+    gptq_lite,
+    local_hessian_calibrate,
+    max_calibrate,
+    mse_calibrate,
+    smoothquant,
+    svdquant,
+)
 
 __all__ = ["BaseCalibrateModeDescriptor"]
 
@@ -377,6 +386,22 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
     _calib_func = mse_calibrate
 
 
+@CalibrateModeRegistry.register_mode
+class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor):
+    """Mode for local Hessian-weighted MSE calibration algorithm.
+
+    This algorithm uses activation information to optimize per-block scales for weight
+    quantization by minimizing output reconstruction error instead of weight reconstruction error.
+ """ + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return LocalHessianCalibConfig + + _calib_func = local_hessian_calibrate + + @CalibrateModeRegistry.register_mode class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor): """Mode for smoothquant calibration algorithm.""" diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d7b3a32b9..be85d2a1f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -18,6 +18,7 @@ import math import os import warnings +from collections.abc import Callable from functools import partial import torch @@ -48,7 +49,7 @@ weight_attr_names, ) -__all__ = ["awq", "max_calibrate", "smoothquant", "svdquant"] +__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"] def weight_only_quantize(model: nn.Module): @@ -372,6 +373,262 @@ def mse_calibrate( # TODO: Sync amax across distributed processes +@torch.no_grad() +def local_hessian_calibrate( + model: nn.Module, + forward_loop: ForwardLoop | None = None, + distributed_sync: bool = True, + step_size: float = 0.1, + start_multiplier: float = 0.25, + stop_multiplier: float = 4.0, + fp8_scale_sweep: bool = True, + block_size: int = 16, + debug: bool = False, +): + """Calibrate the model using local Hessian-weighted MSE search. + + This calibration method collects input activations during forward pass, computes + per-block local Hessian matrices (H = X @ X.T), and uses them to weight the + MSE loss for scale selection. This minimizes output reconstruction error rather + than weight reconstruction error. + + Args: + model: Model to be calibrated. + forward_loop: A callable which takes the model as argument and + forwards calibration data through the model. Required for this algorithm. + distributed_sync: Whether to sync amax across distributed processes. + step_size: Step size for amax search (default: 0.1). + start_multiplier: Starting multiplier for amax search (default: 0.25). + stop_multiplier: Ending multiplier for amax search (default: 4.0). + fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values + for NVFP4 per-block quantization (default: True). + block_size: Block size for local Hessian computation (default: 16). + debug: If True, keep the local Hessian metadata on modules. + + See :class:`LocalHessianCalibConfig ` + for details on the configuration options. 
+ """ + if forward_loop is None: + warnings.warn("forward_loop must be provided for local_hessian; skipping local_hessian") + return + + class LocalHessianHelper: + """Helper class to collect activations and compute local Hessian per module.""" + + cache_mode: bool = False + + def __init__(self, module, name): + self.name = name + self.module = module + self.weight_shape = module.weight.shape # (cout, cin) + self.cout, self.cin = self.weight_shape + self.block_size = block_size + self.num_blocks_per_cin = self.cin // block_size + self.is_enabled = True + + # Accumulated Hessian per block: (cin // block_size, block_size, block_size) + self.hessian_per_block = torch.zeros( + self.num_blocks_per_cin, + block_size, + block_size, + dtype=torch.float32, + device=module.weight.device, + ) + self.num_samples = 0 + + def setup(self): + """Set up the forward hook to collect activations.""" + module = self.module + bind_forward_method(module, forward, "_forward_no_local_hessian") + + # Check if cin is divisible by block_size + if self.cin % self.block_size != 0: + warnings.warn( + f"Module {self.name}: input features ({self.cin}) not divisible by " + f"block_size ({self.block_size}). Skipping local Hessian for this module." + ) + self.is_enabled = False + + def cleanup(self): + """Clean up the forward hook.""" + unpatch_forward_method(self.module, "_forward_no_local_hessian") + if not debug: + if hasattr(self.module, "local_hessian"): + delattr(self.module, "local_hessian") + + def accumulate_hessian(self, input_tensor: torch.Tensor): + """Accumulate local Hessian from input activations. + + Args: + input_tensor: Input tensor of shape (..., cin) + """ + if not self.is_enabled: + return + + # Flatten to (num_tokens, cin) + x = input_tensor.reshape(-1, self.cin).T # (cin, num_tokens) + x = x.reshape(self.num_blocks_per_cin, self.block_size, -1) # (num_blocks, bs, n) + + # Compute H = X @ X.T for each block and accumulate + hessian_batch = (x @ x.transpose(-1, -2)).to(torch.float32) + self.hessian_per_block += hessian_batch + self.num_samples += input_tensor.numel() // self.cin + + def get_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: + """Get the local Hessian error function for MSE calibration.""" + cout = self.cout + bs = self.block_size + # Normalize hessian by number of samples + hessian = self.hessian_per_block / max(self.num_samples, 1) + + def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: + """Compute local Hessian-weighted error.""" + original_shape = x.shape + dw = (x - xq).view(-1, 1, bs) # (num_blocks, 1, block_size) + # Repeat hessian for each output channel + hessian_expanded = hessian.repeat( + cout, 1, 1 + ) # (num_blocks, block_size, block_size) + # Per-block loss: (num_blocks,) + block_loss = (dw @ hessian_expanded @ dw.transpose(-1, -2)).squeeze(-1).squeeze(-1) + error = block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape) + return error + + return local_hessian_error + + def forward(self, input, *args, **kwargs): + """Custom forward that collects activations in cache mode.""" + if LocalHessianHelper.cache_mode and self.local_hessian.is_enabled: + # Get local tensor from DTensor if applicable + input_local = input.to_local() if hasattr(input, "to_local") else input + self.local_hessian.accumulate_hessian(input_local) + + # Forward without quantization during caching + if LocalHessianHelper.cache_mode: + self.weight_quantizer.disable() + out = self._forward_no_local_hessian(input, *args, **kwargs) + 
+            self.weight_quantizer.enable()
+            return out
+
+        return self._forward_no_local_hessian(input, *args, **kwargs)
+
+    # Setup helpers for all quantized linear modules
+    name_to_module = dict(model.named_modules())
+    weight_quantizers_info = []
+
+    for name, module in name_to_module.items():
+        if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
+            with enable_weight_access_and_writeback(module, model, name_to_module):
+                module.local_hessian = LocalHessianHelper(module, name)
+                module.local_hessian.setup()
+                if module.local_hessian.is_enabled:
+                    weight_quantizers_info.append((name, module))
+
+    # Cache activations by running forward loop
+    LocalHessianHelper.cache_mode = True
+    print_rank_0("local_hessian: Caching activations and computing local Hessian...")
+    forward_loop(model)
+
+    # TODO(fridah-nv): Sync Hessian across distributed processes if needed
+
+    # Get initial amax using max calibration on weights
+    print_rank_0("local_hessian: Computing initial amax with max calibration...")
+    for name, module in weight_quantizers_info:
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            max_calibrate(module, lambda m: m.weight_quantizer(m.weight), distributed_sync)
+
+    # Replace calibrators with MseCalibrator using local Hessian error function
+    print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...")
+    for name, module in weight_quantizers_info:
+        weight_quantizer = module.weight_quantizer
+        helper = module.local_hessian
+
+        if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
+            continue
+
+        initial_amax = weight_quantizer._amax.clone().detach()
+
+        def quant_func(x, amax, quantizer=weight_quantizer):
+            original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
+            quantizer._amax = amax
+
+            with (
+                enable_quant(quantizer),
+                disable_calib(quantizer),
+                enable_fake_quant(quantizer),
+            ):
+                if hasattr(quantizer, "_original_shape"):
+                    x = quantizer._reset_to_original_shape(x)
+                xq = quantizer(x)
+                if hasattr(quantizer, "_block_reshape_size"):
+                    xq = xq.reshape(quantizer._block_reshape_size)
+
+            if original_amax is not None:
+                quantizer._amax = original_amax
+            else:
+                delattr(quantizer, "_amax")
+
+            return xq
+
+        is_nvfp4_per_block = (
+            fp8_scale_sweep
+            and weight_quantizer.is_static_block_quant
+            and weight_quantizer._num_bits == (2, 1)
+            and weight_quantizer._block_sizes is not None
+            and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
+        )
+
+        error_func = helper.get_error_func()
+
+        weight_quantizer._calibrator = MseCalibrator(
+            amax=initial_amax,
+            axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None,
+            step_size=step_size,
+            start_multiplier=start_multiplier,
+            stop_multiplier=stop_multiplier,
+            quant_func=quant_func,
+            error_func=error_func,
+            fp8_scale_sweep=is_nvfp4_per_block,
+        )
+
+    # Calibrate weights with local Hessian MSE
+    for name, module in weight_quantizers_info:
+        weight_quantizer = module.weight_quantizer
+        if weight_quantizer._calibrator is None:
+            continue
+
+        weight_quantizer.disable_quant()
+        weight_quantizer.enable_calib()
+
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            weight = module.weight
+            weight_quantizer(weight)
+
+    # Compute optimal amax and load it
+    for name, module in weight_quantizers_info:
+        weight_quantizer = module.weight_quantizer
+        if weight_quantizer._calibrator is None:
+            continue
+
+        cal = weight_quantizer._calibrator
+        if cal.compute_amax() is not None:
+            weight_quantizer.load_calib_amax()
+
+        weight_quantizer.enable_quant()
+        weight_quantizer.disable_calib()
+
+    # Cleanup and free memory
+    LocalHessianHelper.cache_mode = False
+    for name, module in weight_quantizers_info:
+        module.local_hessian.cleanup()
+        if hasattr(module.weight_quantizer, "_calibrator"):
+            cal = module.weight_quantizer._calibrator
+            if hasattr(cal, "clear"):
+                cal.clear()
+
+    print_rank_0("local_hessian: Calibration complete.")
+
+
 def enable_stats_collection(model: nn.Module):
     """Enable stats collection for all quantizers in the model."""
     for name, module in model.named_modules():

From a36a7a962a2aaf5baabff51dda52ad879f534644 Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Thu, 5 Feb 2026 22:51:46 +0000
Subject: [PATCH 02/10] Use NVFP4StaticQuantizer and NVFP4MSECalibrator

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/model_calib.py | 37 ++++++++++++++--------
 1 file changed, 24 insertions(+), 13 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index be85d2a1f..701ac7d6f 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -570,26 +570,37 @@ def quant_func(x, amax, quantizer=weight_quantizer):
 
             return xq
 
-        is_nvfp4_per_block = (
-            fp8_scale_sweep
-            and weight_quantizer.is_static_block_quant
+        is_nvfp4_static = (
+            weight_quantizer.is_static_block_quant
             and weight_quantizer._num_bits == (2, 1)
             and weight_quantizer._block_sizes is not None
             and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
         )
+        if is_nvfp4_static:
+            global_amax = reduce_amax(initial_amax, axis=None)
+            NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)
+
         error_func = helper.get_error_func()
 
-        weight_quantizer._calibrator = MseCalibrator(
-            amax=initial_amax,
-            axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None,
-            step_size=step_size,
-            start_multiplier=start_multiplier,
-            stop_multiplier=stop_multiplier,
-            quant_func=quant_func,
-            error_func=error_func,
-            fp8_scale_sweep=is_nvfp4_per_block,
-        )
+        if fp8_scale_sweep and is_nvfp4_static:
+            weight_quantizer._calibrator = NVFP4MSECalibrator(
+                amax=initial_amax,
+                axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None,
+                global_amax=weight_quantizer.global_amax,
+                quant_func=quant_func,
+                error_func=error_func,
+            )
+        else:
+            weight_quantizer._calibrator = MseCalibrator(
+                amax=initial_amax,
+                axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None,
+                step_size=step_size,
+                start_multiplier=start_multiplier,
+                stop_multiplier=stop_multiplier,
+                quant_func=quant_func,
+                error_func=error_func,
+            )
 
     # Calibrate weights with local Hessian MSE
     for name, module in weight_quantizers_info:

From 7f07f7f003fd216c007a76fb2b471a7e222d7fda Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Thu, 5 Feb 2026 23:06:27 +0000
Subject: [PATCH 03/10] minor

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/config.py | 49 +--------
 1 file changed, 1 insertion(+), 48 deletions(-)

diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index b3fcbb0ab..bf4cda475 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -388,51 +388,7 @@
"max", } -NVFP4_WEIGHT_ACT_MSE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "mse", - "step_size": 0.25, - "start_multiplier": 0.25, - "stop_multiplier": 2.0, - }, -} - -NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "mse", - "fp8_scale_sweep": True, - }, -} - - -NVFP4_LOCAL_HESSIAN_CFG = { +NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG = { "quant_cfg": { "*weight_quantizer": { "num_bits": (2, 1), @@ -1134,9 +1090,6 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig): - ``dw = weight - quantized_weight`` (weight reconstruction error per block) - ``H = X @ X.T`` is the local Hessian computed from input activations X - This method is particularly effective for NVFP4 weight-only quantization where - activation information helps select better per-block scales. - """ method: Literal["local_hessian"] = ModeloptField("local_hessian") From 21aa442feeb7789247531c604c402717bfb6c9ff Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 5 Feb 2026 23:10:49 +0000 Subject: [PATCH 04/10] unit test Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- tests/gpu/torch/quantization/test_quantize_cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py index b5aca034a..f4d333c1e 100644 --- a/tests/gpu/torch/quantization/test_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_quantize_cuda.py @@ -87,6 +87,7 @@ mtq.NVFP4_AWQ_LITE_CFG, mtq.NVFP4_AWQ_CLIP_CFG, mtq.NVFP4_AWQ_FULL_CFG, + mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG, mtq.MXFP8_DEFAULT_CFG, mtq.MXFP6_DEFAULT_CFG, mtq.MXFP4_DEFAULT_CFG, @@ -113,6 +114,7 @@ def test_quantize(model_cls, config): mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, NVFP4_WEIGHT_ACT_MSE_CFG, NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, + mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG, ]: if get_cuda_ext_mx() is None: pytest.skip("cuda_ext_mx is not available") From 2931f61c1472c908c98751d65fad1156d6da4dd5 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:03:21 +0000 Subject: [PATCH 05/10] add reviewers feedback Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 7 +++- modelopt/torch/quantization/model_calib.py | 45 ++++++++-------------- 2 files changed, 20 insertions(+), 32 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index bf4cda475..857724d26 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -388,7 +388,7 @@ "algorithm": "max", } -NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG = { +NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = { "quant_cfg": { "*weight_quantizer": { "num_bits": (2, 1), @@ -397,7 +397,10 @@ "enable": True, }, "*input_quantizer": { - "enable": False, + "num_bits": (2, 1), + 
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, }, **_default_disabled_quantizer_cfg, }, diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 701ac7d6f..310c7382f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -387,10 +387,12 @@ def local_hessian_calibrate( ): """Calibrate the model using local Hessian-weighted MSE search. - This calibration method collects input activations during forward pass, computes - per-block local Hessian matrices (H = X @ X.T), and uses them to weight the - MSE loss for scale selection. This minimizes output reconstruction error rather - than weight reconstruction error. + Instead of minimizing weight error ||W - Wq||², this minimizes Hessian-weighted error: + loss = (W - Wq)ᵀ H (W - Wq) + where H = X @ X.T approximates output reconstruction error ||WX - WqX||². + + Per-block Hessians of shape (cin // block_size, block_size, block_size) are accumulated + during forward pass and used to weight the MSE loss during scale search. Args: model: Model to be calibrated. @@ -512,6 +514,11 @@ def forward(self, input, *args, **kwargs): return self._forward_no_local_hessian(input, *args, **kwargs) + # First, run max_calibrate on the whole model to get initial amax for all quantizers + # This calibrates both weight_quantizer and input_quantizer with max calibration + print_rank_0("local_hessian: Running max calibration for all quantizers...") + max_calibrate(model, forward_loop, distributed_sync) + # Setup helpers for all quantized linear modules name_to_module = dict(model.named_modules()) weight_quantizers_info = [] @@ -531,12 +538,6 @@ def forward(self, input, *args, **kwargs): # TODO(fridah-nv): Sync Hessian across distributed processes if needed - # Get initial amax using max calibration on weights - print_rank_0("local_hessian: Computing initial amax with max calibration...") - for name, module in weight_quantizers_info: - with enable_weight_access_and_writeback(module, model, name_to_module): - max_calibrate(module, lambda m: m.weight_quantizer(m.weight), distributed_sync) - # Replace calibrators with MseCalibrator using local Hessian error function print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...") for name, module in weight_quantizers_info: @@ -608,34 +609,18 @@ def quant_func(x, amax, quantizer=weight_quantizer): if weight_quantizer._calibrator is None: continue - weight_quantizer.disable_quant() - weight_quantizer.enable_calib() - + # Enable calibration mode for the weight quantizer + enable_stats_collection(module) with enable_weight_access_and_writeback(module, model, name_to_module): weight = module.weight weight_quantizer(weight) - - # Compute optimal amax and load it - for name, module in weight_quantizers_info: - weight_quantizer = module.weight_quantizer - if weight_quantizer._calibrator is None: - continue - - cal = weight_quantizer._calibrator - if cal.compute_amax() is not None: - weight_quantizer.load_calib_amax() - - weight_quantizer.enable_quant() - weight_quantizer.disable_calib() + finish_stats_collection(module, method="mse") + weight_quantizer._calibrator.reset() # Cleanup and free memory LocalHessianHelper.cache_mode = False for name, module in weight_quantizers_info: module.local_hessian.cleanup() - if hasattr(module.weight_quantizer, "_calibrator"): - cal = module.weight_quantizer._calibrator - if hasattr(cal, "clear"): - cal.clear() 
 
     print_rank_0("local_hessian: Calibration complete.")

From e391ea1a43ebad947843d30aa4e17f6d420e221d Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Fri, 6 Feb 2026 21:02:12 +0000
Subject: [PATCH 06/10] add rabbit feedback

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/model_calib.py   | 17 +++++++++--------
 .../torch/quantization/test_quantize_cuda.py |  4 ++--
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 310c7382f..df622ae3f 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -486,13 +486,12 @@ def get_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
             def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
                 """Compute local Hessian-weighted error."""
                 original_shape = x.shape
-                dw = (x - xq).view(-1, 1, bs)  # (num_blocks, 1, block_size)
-                # Repeat hessian for each output channel
-                hessian_expanded = hessian.repeat(
-                    cout, 1, 1
-                )  # (num_blocks, block_size, block_size)
-                # Per-block loss: (num_blocks,)
-                block_loss = (dw @ hessian_expanded @ dw.transpose(-1, -2)).squeeze(-1).squeeze(-1)
+                # Reshape to (cout, num_blocks_per_cin, block_size)
+                dw = (x - xq).view(cout, -1, bs)
+                # Use einsum to avoid materializing cout-repeated Hessian
+                # dw: (cout, n_blocks, bs), hessian: (n_blocks, bs, bs) -> (cout, n_blocks)
+                block_loss = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw)
+                block_loss = block_loss.reshape(-1)
                 error = block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape)
                 return error
 
@@ -522,12 +521,14 @@
     # Setup helpers for all quantized linear modules
     name_to_module = dict(model.named_modules())
     weight_quantizers_info = []
+    all_patched_modules = []  # Track all modules for cleanup (including disabled ones)
 
     for name, module in name_to_module.items():
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model, name_to_module):
                 module.local_hessian = LocalHessianHelper(module, name)
                 module.local_hessian.setup()
+                all_patched_modules.append((name, module))
                 if module.local_hessian.is_enabled:
                     weight_quantizers_info.append((name, module))
 
@@ -619,7 +620,7 @@ def quant_func(x, amax, quantizer=weight_quantizer):
 
     # Cleanup and free memory
     LocalHessianHelper.cache_mode = False
-    for name, module in weight_quantizers_info:
+    for name, module in all_patched_modules:
         module.local_hessian.cleanup()
 
     print_rank_0("local_hessian: Calibration complete.")
diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py
index f4d333c1e..3e9ff4256 100644
--- a/tests/gpu/torch/quantization/test_quantize_cuda.py
+++ b/tests/gpu/torch/quantization/test_quantize_cuda.py
@@ -87,7 +87,7 @@
     mtq.NVFP4_AWQ_LITE_CFG,
     mtq.NVFP4_AWQ_CLIP_CFG,
     mtq.NVFP4_AWQ_FULL_CFG,
-    mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG,
+    mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
     mtq.MXFP8_DEFAULT_CFG,
     mtq.MXFP6_DEFAULT_CFG,
     mtq.MXFP4_DEFAULT_CFG,
@@ -113,7 +114,7 @@ def test_quantize(model_cls, config):
         mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
         NVFP4_WEIGHT_ACT_MSE_CFG,
         NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
-        mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG,
+        mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
     ]:
         if get_cuda_ext_mx() is None:
             pytest.skip("cuda_ext_mx is not available")
From 8e40de0ce27748ff0335715687ea3949e907fdb3 Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Fri, 6 Feb 2026 21:10:36 +0000
Subject: [PATCH 07/10] fix doc build

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/model_calib.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index df622ae3f..241005edf 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -387,11 +387,11 @@ def local_hessian_calibrate(
 ):
     """Calibrate the model using local Hessian-weighted MSE search.
 
-    Instead of minimizing weight error ||W - Wq||², this minimizes Hessian-weighted error:
-        loss = (W - Wq)ᵀ H (W - Wq)
-    where H = X @ X.T approximates output reconstruction error ||WX - WqX||².
+    Instead of minimizing weight error ``||W - Wq||²``, this minimizes Hessian-weighted error
+    ``loss = (W - Wq)ᵀ H (W - Wq)`` where ``H = X @ X.T`` approximates output reconstruction
+    error ``||WX - WqX||²``.
 
-    Per-block Hessians of shape (cin // block_size, block_size, block_size) are accumulated
+    Per-block Hessians of shape ``(cin // block_size, block_size, block_size)`` are accumulated
     during forward pass and used to weight the MSE loss during scale search.

From 35a6aea5e4c9aacaaf80d9ffe04cfb36f797cb21 Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Thu, 12 Feb 2026 23:12:30 -0800
Subject: [PATCH 08/10] minor

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/model_calib.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 241005edf..8ad546973 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -455,8 +455,8 @@ def cleanup(self):
             """Clean up the forward hook."""
             unpatch_forward_method(self.module, "_forward_no_local_hessian")
             if not debug:
-                if hasattr(self.module, "local_hessian"):
-                    delattr(self.module, "local_hessian")
+                if hasattr(self.module, "hessian_helper"):
+                    delattr(self.module, "hessian_helper")
 
         def accumulate_hessian(self, input_tensor: torch.Tensor):
             """Accumulate local Hessian from input activations.
@@ -499,10 +499,10 @@ def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
 
     def forward(self, input, *args, **kwargs):
         """Custom forward that collects activations in cache mode."""
-        if LocalHessianHelper.cache_mode and self.local_hessian.is_enabled:
+        if LocalHessianHelper.cache_mode and self.hessian_helper.is_enabled:
             # Get local tensor from DTensor if applicable
             input_local = input.to_local() if hasattr(input, "to_local") else input
-            self.local_hessian.accumulate_hessian(input_local)
+            self.hessian_helper.accumulate_hessian(input_local)
 
         # Forward without quantization during caching
         if LocalHessianHelper.cache_mode:
@@ -526,10 +526,10 @@ def forward(self, input, *args, **kwargs):
     for name, module in name_to_module.items():
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model, name_to_module):
-                module.local_hessian = LocalHessianHelper(module, name)
-                module.local_hessian.setup()
+                module.hessian_helper = LocalHessianHelper(module, name)
+                module.hessian_helper.setup()
                 all_patched_modules.append((name, module))
-                if module.local_hessian.is_enabled:
+                if module.hessian_helper.is_enabled:
                     weight_quantizers_info.append((name, module))
 
     # Cache activations by running forward loop
@@ -543,7 +543,7 @@ def forward(self, input, *args, **kwargs):
     print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...")
     for name, module in weight_quantizers_info:
         weight_quantizer = module.weight_quantizer
-        helper = module.local_hessian
+        helper = module.hessian_helper
 
         if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
             continue
@@ -621,7 +621,7 @@ def quant_func(x, amax, quantizer=weight_quantizer):
     # Cleanup and free memory
     LocalHessianHelper.cache_mode = False
     for name, module in all_patched_modules:
-        module.local_hessian.cleanup()
+        module.hessian_helper.cleanup()
 
     print_rank_0("local_hessian: Calibration complete.")

From d686ac9d9a3270b53d5fed9848bf151f154814a4 Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Thu, 12 Feb 2026 23:29:51 -0800
Subject: [PATCH 09/10] Memory optimization

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/model_calib.py | 85 ++++++++++++++++++----
 1 file changed, 71 insertions(+), 14 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 8ad546973..9f04f8403 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -359,16 +359,42 @@ def mse_calibrate(
                 weight_quantizers.append((parent_module, weight_name, weight_quantizer))
                 seen_modules.add(parent_module)
 
-    # Step 3: Calibrate weight quantizers once with MSE calibration
-    # This ensures weights are only calibrated once, not during every forward pass
-    for parent_module, weight_name, weight_quantizer in weight_quantizers:
+    # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
+    # This prevents massive memory accumulation seen in large models
+    for idx, (parent_module, weight_name, weight_quantizer) in enumerate(weight_quantizers):
         # Enable calibration mode for the weight quantizer
-        enable_stats_collection(parent_module)
+        weight_quantizer.disable_quant()
+        weight_quantizer.enable_calib()
         with enable_weight_access_and_writeback(parent_module, model):
             weight = getattr(parent_module, weight_name)
             weight_quantizer(weight)
-        finish_stats_collection(parent_module, method="mse")
-        weight_quantizer._calibrator.reset()
+
+        # IMMEDIATELY compute amax and reset calibrator to free memory
+        cal = getattr(weight_quantizer, "_calibrator", None)
+        if cal is not None and cal.compute_amax() is not None:
+            weight_quantizer.load_calib_amax()
+
+        weight_quantizer.enable_quant()
+        weight_quantizer.disable_calib()
+
+        # Synchronize ALL CUDA devices before resetting to ensure all async operations complete
+        # This is critical for multi-GPU setups where tensors may be on different devices
+        if torch.cuda.is_available():
+            for dev_id in range(torch.cuda.device_count()):
+                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+
+        if cal is not None and hasattr(cal, "reset"):
+            cal.reset()
+
+        if (idx + 1) % 10 == 0 and torch.cuda.is_available():
+            for dev_id in range(torch.cuda.device_count()):
+                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+            torch.cuda.empty_cache()
+
+    if torch.cuda.is_available():
+        for dev_id in range(torch.cuda.device_count()):
+            torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+        torch.cuda.empty_cache()
 
     # TODO: Sync amax across distributed processes
 
@@ -604,19 +630,50 @@ def quant_func(x, amax, quantizer=weight_quantizer):
                 error_func=error_func,
             )
 
-    # Calibrate weights with local Hessian MSE
-    for name, module in weight_quantizers_info:
+    # Free cached memory before heavy calibration
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # Process weights ONE AT A TIME with immediate amax computation and cleanup
+    weight_list = [
+        (name, module)
+        for name, module in weight_quantizers_info
+        if module.weight_quantizer._calibrator is not None
+    ]
+
+    for idx, (name, module) in enumerate(weight_list):
         weight_quantizer = module.weight_quantizer
-        if weight_quantizer._calibrator is None:
-            continue
+        cal = weight_quantizer._calibrator
 
-        # Enable calibration mode for the weight quantizer
-        enable_stats_collection(module)
+        # Step 1: Calibrate this weight
+        weight_quantizer.disable_quant()
+        weight_quantizer.enable_calib()
         with enable_weight_access_and_writeback(module, model, name_to_module):
             weight = module.weight
             weight_quantizer(weight)
-        finish_stats_collection(module, method="mse")
-        weight_quantizer._calibrator.reset()
+
+        # Step 2: IMMEDIATELY compute amax (before calibration data grows)
+        if cal.compute_amax() is not None:
+            weight_quantizer.load_calib_amax()
+
+        weight_quantizer.enable_quant()
+        weight_quantizer.disable_calib()
+
+        # Step 3: Sync all devices and reset calibrator for next weight
+        if torch.cuda.is_available():
+            for dev_id in range(torch.cuda.device_count()):
+                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+
+        if hasattr(cal, "reset"):
+            cal.reset()
+
+        if (idx + 1) % 10 == 0 and torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    if torch.cuda.is_available():
+        for dev_id in range(torch.cuda.device_count()):
+            torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+        torch.cuda.empty_cache()
 
     # Cleanup and free memory
     LocalHessianHelper.cache_mode = False

From 7fd24b4852b67511ea77e0eec8e3d78067bf9e41 Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Thu, 19 Feb 2026 22:36:37 +0000
Subject: [PATCH 10/10] add progress bar for weight MSE calibration

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/model_calib.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 9f04f8403..1ef0dbca3 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -361,7 +361,9 @@ def mse_calibrate(
 
     # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
    # This prevents massive memory accumulation seen in large models
-    for idx, (parent_module, weight_name, weight_quantizer) in enumerate(weight_quantizers):
+    for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
+        tqdm(weight_quantizers, desc="MSE weight calibration")
+    ):
         # Enable calibration mode for the weight quantizer
         weight_quantizer.disable_quant()
         weight_quantizer.enable_calib()
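
For reference, a minimal standalone sketch of the Hessian-weighted block loss these
patches implement. The einsum mirrors the vectorized form from PATCH 06; the tensor
sizes and the round-to-nearest stand-in quantizer are illustrative assumptions, not
code from the series:

import torch

# Illustrative sizes only; the patches default to block_size=16 for NVFP4.
cout, cin, block_size, num_tokens = 8, 64, 16, 32
num_blocks = cin // block_size

weight = torch.randn(cout, cin)
x = torch.randn(num_tokens, cin)  # calibration activations

# Per-block local Hessians H = X @ X.T, normalized by sample count, as
# LocalHessianHelper.accumulate_hessian / get_error_func do.
xb = x.T.reshape(num_blocks, block_size, -1)        # (num_blocks, bs, num_tokens)
hessian = (xb @ xb.transpose(-1, -2)) / num_tokens  # (num_blocks, bs, bs)

def fake_quant(w, amax, num_steps=7):
    # Stand-in round-to-nearest quantizer so the sketch is self-contained;
    # the patches use the module's weight_quantizer via quant_func instead.
    scale = amax / num_steps
    return (w / scale).round().clamp(-num_steps, num_steps) * scale

amax = weight.abs().amax()
dw = (weight - fake_quant(weight, amax)).view(cout, num_blocks, block_size)

# Hessian-weighted loss per (output channel, block): dw @ H @ dw.T, using the
# same einsum PATCH 06 adopts to avoid materializing a cout-repeated Hessian.
block_loss = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw)

# Summed over channels and blocks, this is a block-diagonal approximation of
# ||W @ X.T - Wq @ X.T||^2 (up to the 1/num_tokens normalization above).
print(block_loss.sum())

Assuming the usual modelopt entry point, the path exercised by the unit test is
mtq.quantize(model, mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, forward_loop), which
dispatches to local_hessian_calibrate through the "local_hessian" algorithm key.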