Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
📝 Walkthrough

This pull request introduces a new local Hessian-based calibration method for PyTorch quantization. It adds a new quantization configuration, a calibration function with helper logic, a mode descriptor for registry integration, and corresponding GPU tests.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Model as PyTorch Model
    participant LHC as local_hessian_calibrate()
    participant Helper as LocalHessianHelper
    participant Calib as MSE Calibrator

    Model->>LHC: Initialize with forward_loop
    LHC->>LHC: Run initial max_calibrate()
    LHC->>Helper: Create helpers for quantized modules
    LHC->>Model: Register forward hooks
    LHC->>LHC: Cache activations via forward passes
    Helper->>Helper: Compute per-block Hessians
    LHC->>Calib: Replace with local Hessian MSE calibrator
    Calib->>Calib: Search amax with Hessian-weighted loss
    LHC->>LHC: Cleanup hooks and caches
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 3 passed
realAsma left a comment:

overall looks great, have you run any experiments with this?

I'm still getting the numbers, will update here when it's ready.
Codecov Report

❌ Patch coverage is

```
@@            Coverage Diff             @@
##             main     #788      +/-   ##
==========================================
- Coverage   73.45%   73.03%   -0.43%
==========================================
  Files         205      205
  Lines       22034    22200     +166
==========================================
+ Hits        16185    16213      +28
- Misses       5849     5987     +138
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@modelopt/torch/quantization/config.py`:
- Around line 391-411: The config dict name NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG
does not match the tests which expect mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG;
either rename the dict to NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG or create an alias
assignment (NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG =
NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG) and ensure the config name is included in
the module's choices set (add "NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG" to the
choices collection) so tests and any selection logic can find it.
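A minimal sketch of the alias option described above (the config names come from the review text; the surrounding `config.py` structure, including the `choices` collection, is an assumption):

```python
# Hypothetical sketch in modelopt/torch/quantization/config.py: alias the
# existing config dict under the name the tests expect, rather than renaming.
NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG = NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG

# ...and register the new name wherever the module collects valid config names:
choices.add("NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG")
```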
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 620-623: The cleanup loop only iterates weight_quantizers_info and
misses modules that had their forward patched in setup() but were later disabled
(is_enabled set to False); update setup() (the code path that patches
is_quantized_linear modules and attaches module.local_hessian/patches forward)
to record every module whose forward is patched into a new list (e.g.,
patched_modules) or append such modules into weight_quantizers_info regardless
of is_enabled, then change the cleanup block (after setting
LocalHessianHelper.cache_mode = False) to iterate over that recorded list and
call module.local_hessian.cleanup() (and restore/unpatch the forward if needed)
to ensure all patched modules are cleaned up.
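A sketch of the suggested bookkeeping, not the actual `model_calib.py` code; `is_quantized_linear` and `LocalHessianHelper` follow the review text, while `patched_forward` and the `_orig_forward` attribute are invented for illustration:

```python
import torch

# Record every module whose forward gets patched, independent of is_enabled.
patched_modules: list[torch.nn.Module] = []

def setup(model: torch.nn.Module):
    for module in model.modules():
        if is_quantized_linear(module):
            module.local_hessian = LocalHessianHelper(module)
            module._orig_forward = module.forward
            module.forward = patched_forward(module)
            patched_modules.append(module)  # recorded even if later disabled

def cleanup():
    LocalHessianHelper.cache_mode = False
    for module in patched_modules:             # not just weight_quantizers_info
        module.local_hessian.cleanup()
        module.forward = module._orig_forward  # restore the original forward
        del module._orig_forward
    patched_modules.clear()
```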
- Around line 486-497: The code in local_hessian_error creates a huge temporary
via hessian.repeat(cout,1,1); instead compute the Hessian-weighted quadratic
form without materializing the repeated tensor by leveraging broadcasting or
einsum. Replace the repeat + matrix-mult sequence (hessian_expanded =
hessian.repeat(...); block_loss = (dw @ hessian_expanded @
dw.transpose(...)).squeeze(...)) with a memory-efficient einsum or broadcasted
matmul, e.g. compute block_loss using torch.einsum('nbk,bkl,nbl->n',
dw.squeeze(1), hessian, dw.squeeze(1)) or align dims with unsqueeze on hessian
and rely on broadcasting so no hessian.repeat is created; update use sites for
local_hessian_error, hessian, dw and block_loss accordingly.
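To make the suggested fix concrete, here is a runnable sketch of the memory-efficient contraction; the shapes are assumptions inferred from the comment (`dw` as `(cout, num_blocks, k)` after squeezing, `hessian` as `(num_blocks, k, k)`):

```python
import torch

cout, num_blocks, k = 8, 4, 16
dw = torch.randn(cout, num_blocks, k)
hessian = torch.randn(num_blocks, k, k)

# Before: hessian.repeat(cout, 1, 1) materializes a (cout * num_blocks, k, k)
# temporary. After: contract directly; no repeated tensor is ever allocated.
block_loss = torch.einsum("nbk,bkl,nbl->n", dw, hessian, dw)

# Equivalent broadcasted matmul: (cout, b, 1, k) @ (b, k, k) @ (cout, b, k, 1)
block_loss_bc = (
    (dw.unsqueeze(2) @ hessian @ dw.unsqueeze(-1)).squeeze(-1).squeeze(-1).sum(dim=1)
)
assert torch.allclose(block_loss, block_loss_bc, atol=1e-3, rtol=1e-4)
```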
In `@tests/gpu/torch/quantization/test_quantize_cuda.py`:
- Line 90: The test references a non-existent/mismatched config name
NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG; update the test to use the correct exported
config NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG (or add an explicit alias/export in
the module where configs are defined) and remove the misleading "WEIGHT_ONLY"
wording if the config enables input quantization; locate occurrences in the test
(e.g., where mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG is used) and replace them
with mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG or add the alias in the config
package to avoid AttributeError during test collection.
🧹 Nitpick comments (2)
modelopt/torch/quantization/model_calib.py (2)
552-572: `quant_func` duplicates the existing `_mse_quant_func` (line 240).

The closure at lines 552–572 is nearly identical to the top-level `_mse_quant_func`. Reusing it via `partial(_mse_quant_func, quantizer=weight_quantizer)` (as done in `mse_calibrate` at line 330) would eliminate the duplication. Reuse `_mse_quant_func`:

```diff
-def quant_func(x, amax, quantizer=weight_quantizer):
-    original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
-    quantizer._amax = amax
-
-    with (
-        enable_quant(quantizer),
-        disable_calib(quantizer),
-        enable_fake_quant(quantizer),
-    ):
-        if hasattr(quantizer, "_original_shape"):
-            x = quantizer._reset_to_original_shape(x)
-        xq = quantizer(x)
-        if hasattr(quantizer, "_block_reshape_size"):
-            xq = xq.reshape(quantizer._block_reshape_size)
-
-    if original_amax is not None:
-        quantizer._amax = original_amax
-    else:
-        delattr(quantizer, "_amax")
-
-    return xq
+quant_func = partial(_mse_quant_func, quantizer=weight_quantizer)
```
417-477: `LocalHessianHelper` and `accumulate_hessian` are well-structured.

The pattern of a nested helper class with a class-level `cache_mode` flag follows the established `AWQLiteHelper` design. Minor note: the matmul at line 475 operates in the input tensor's dtype before converting to `float32`, whereas the existing `update_hessian` function (line 1478) converts to `float()` before the matmul. On GPU with TensorCores this is fine, but for consistency you may want to cast `x` to `float32` before the matmul.
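A small runnable sketch of the consistency fix (shapes are assumptions; `x` stands in for a cached `(num_tokens, in_features)` activation batch, possibly bf16/fp16):

```python
import torch

x = torch.randn(64, 128, dtype=torch.bfloat16)
hessian = torch.zeros(128, 128, dtype=torch.float32)

xf = x.float()        # cast first, as update_hessian (line 1478) does,
hessian += xf.T @ xf  # so the matmul itself accumulates in float32
```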
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
What does this PR do?
Type of change: new feature
Overview:

Add a new calibration method for weight scale search. It takes activation information into account by weighing scale candidates with a local Hessian matrix. Initial experiments with Qwen3 8B NVFP4 show improvements.
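For readers skimming the diff, a sketch of the core idea as described in the walkthrough; the shapes and the per-block handling in `model_calib.py` differ, and this function is illustrative rather than the actual implementation:

```python
import torch

def hessian_weighted_error(w: torch.Tensor, wq: torch.Tensor, hessian: torch.Tensor) -> torch.Tensor:
    """Score one amax candidate by dw H dw^T instead of plain MSE.

    Illustrative shapes: w, wq are (cout, cin) original vs. fake-quantized
    weights; hessian is (cin, cin), accumulated as sum(x^T x) over the
    activations cached during calibration forward passes.
    """
    dw = (w - wq).float()
    return torch.einsum("nk,kl,nl->n", dw, hessian, dw).sum()

# The amax candidate with the smallest weighted error wins the search.
```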
Usage

Use the `NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG` quantization config for quantization and evaluation, e.g. add this line to `QUANT_CFG_CHOICES` in `examples/llm_ptq/hf_ptq.py`:

`"nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,`

Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Tests