Add FlashInfer allreduce RMSNorm Quant fusion (#21069) #32

Open
MitchLewis930 wants to merge 1 commit into ROCM_bug_before from ROCM_bug_after

Conversation


@MitchLewis930 MitchLewis930 commented Jan 24, 2026

test

Summary by CodeRabbit

Release Notes

  • Tests

    • Expanded test coverage for all-reduce fusion with static FP8/FP4 quantization variants.
    • Added new test models for fused RMS norm with integrated quantization paths.
  • Features

    • Enhanced all-reduce fusion to support static FP8/FP4 quantization operations.
    • Extended fused RMS norm functionality with quantization support.
  • Chores

    • Increased default token limit configuration for all-reduce fusion optimization.

Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>

coderabbitai bot commented Jan 24, 2026

📝 Walkthrough

Walkthrough

The pull request extends all-reduce fusion capabilities to support static FP8 and FP4 quantization paths combined with RMS normalization. It introduces new pattern classes for quantized fusion operations, adds test infrastructure for quantization variants, wires new compilation passes, and adjusts token limits for fusion optimization.

Changes

  • Build Configuration: .buildkite/test-pipeline.yaml
    Adds compile/test_fusion_all_reduce.py to the PyTorch Compilation Unit Tests group.
  • Test Infrastructure: tests/compile/test_fusion_all_reduce.py
    Introduces TestAllReduceFusedAddRMSNormStaticQuantFP8Model and TestAllReduceFusedAddRMSNormStaticQuantFP4Model test classes. Updates existing test models to accept a token_num parameter. Wires FixFunctionalizationPass and NoOpEliminationPass into the backend. Extends test coverage for FP8/FP4 quantization paths and adjusts tensor shapes to use token-based sizing. Updates skip conditions for hardware capability checks.
  • Test Utilities: tests/utils.py
    Adds a has_module_attribute() helper function for dynamic module attribute checking with fallback error handling.
  • Quantization Fusion Patterns: vllm/compilation/collective_fusion.py
    Introduces AllReduceFusedRMSNormStaticQuantFP8Pattern, AllReduceFusedAddRMSNormStaticQuantFP8Pattern, and corresponding NVFP4 variants. Adds STATIC_FP8_QUANT_OP and STATIC_FP4_QUANT_OP constants. Extends the call_trtllm_fused_allreduce_norm() signature with quantization output parameters. Updates AllReduceFusionPass to dynamically compute max_num_token, register multiple pattern variants across epsilons, and clear the inductor cache after registration. Renames AllReduceRMSNORMPattern to AllReduceRMSNormPattern.
  • Configuration: vllm/config.py
    Updates the PassConfig.fi_allreduce_fusion_max_token_num default from 1024 to 16384.
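As a side note on the new test utility: a has_module_attribute() helper like the one described in the Test Utilities entry above might be sketched as follows. This is a minimal illustration of the described behavior (import the module, fall back gracefully on failure); the actual implementation in tests/utils.py may differ.

```python
import importlib


def has_module_attribute(module_name: str, attribute_name: str) -> bool:
    """Return True if module_name imports cleanly and exposes attribute_name.

    Sketch only: mirrors the "dynamic module attribute checking with
    fallback error handling" described in the summary above.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Covers ModuleNotFoundError as well; a missing module simply
        # means the attribute is unavailable.
        return False
    return hasattr(module, attribute_name)
```

Such a helper lets test skip conditions probe optional dependencies (e.g. FlashInfer ops) without hard-importing them at collection time.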

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test Runner
    participant Pass as AllReduceFusionPass
    participant PatternMatcher as Pattern Matcher
    participant Backend as TestBackend
    participant Runtime as TRTLLM Runtime

    Test->>Pass: Initialize with vllm_config
    Pass->>Pass: Compute max_num_token from model/token config
    Pass->>Pass: Register AllReduceRMSNormPattern (eps1, eps2)
    Pass->>Pass: Register AllReduceFusedRMSNormStaticQuantFP8Pattern (eps1, eps2)
    Pass->>Pass: Register AllReduceFusedAddRMSNormStaticQuantFP8Pattern (eps1, eps2)
    alt Device supports NVFP4
        Pass->>Pass: Register NVFP4 pattern variants (eps1, eps2)
    end
    Pass->>PatternMatcher: Clear inductor cache after each epsilon
    Test->>Backend: Compile model with fusion passes
    Backend->>PatternMatcher: Match patterns in graph
    PatternMatcher->>Backend: Return matched fusion opportunities
    Backend->>Runtime: Execute fused all-reduce+norm+quant operations
    Runtime->>Test: Return fused computation results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 New patterns emerge, FP8 and FP4 too,
Quantized fusion paths now shining through,
All-reduce embraces the norm and the quant,
With token-wise sizing, we've given you what you want! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 3.92%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title directly reflects the main changes: adding FlashInfer allreduce RMSNorm quantization fusion support across multiple files and test cases.

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/compile/test_fusion_all_reduce.py (1)

29-34: Silence unused token_num args to keep Ruff clean.

Ruff ARG002 flags these constructor parameters as unused. Either store them or rename to _token_num to keep lint green.

✅ Suggested fix
 class TestAllReduceRMSNormModel(torch.nn.Module):

     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
         self.norm = RMSNorm(hidden_size, eps)
+        self.token_num = token_num
@@
 class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):

     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
         self.norm = RMSNorm(hidden_size, eps)
+        self.token_num = token_num

Also applies to: 50-55

🤖 Fix all issues with AI agents
In `@tests/compile/test_fusion_all_reduce.py`:
- Around line 112-116: The inline lambda named round_up should be replaced with
the shared helper to avoid defining lambdas in expressions; locate the lambda
used to compute rounded_m/rounded_n (the round_up = lambda x, y: ... and its
uses before self.output_scale = torch.empty(...)) and replace it with a call to
vllm.utils.round_up (or define a local def round_up(x, y): ...) so the same
rounding logic is reused and the lambda assignment is removed, then update the
calls computing rounded_m and rounded_n to use the helper.

In `@vllm/compilation/collective_fusion.py`:
- Around line 499-503: Fix the typo in the comment ("outpput" -> "output") and
avoid the subtle logic bug by tracking whether norm_out was originally provided
before it gets reassigned: capture a boolean like original_norm_out_provided =
(norm_out is not None) immediately before any internal assignment to norm_out
(around where norm_out is assigned at/near the code referencing line 488), then
replace the condition `if scale_factor is None or norm_out is not None:` with
`if scale_factor is None or original_norm_out_provided:` so the copy_
(allreduce_in.copy_(allreduce_out)) only runs when scale_factor is None or
norm_out was passed in by the caller; update the comment to read "output" and
briefly explain the check uses the original-provided flag.
🧹 Nitpick comments (3)
vllm/compilation/collective_fusion.py (3)

704-780: Consider consistency in get_inputs placement.

The get_inputs() function is defined inside register() rather than as a class method like in AllReduceRMSNormPattern and other base patterns. While functional, this is inconsistent with the established pattern style in this file.

The pattern logic and return indices are correct.


1084-1088: Add bounds checking for max_num_token.

If hidden_dim * tp_size is very large, the integer division could result in 0, which would likely cause issues downstream. Consider adding a minimum bound or guard.

Proposed fix
         max_num_token = min(
             _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
             (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
             config.compilation_config.pass_config.
             fi_allreduce_fusion_max_token_num)
+        if max_num_token <= 0:
+            logger.warning(
+                "Computed max_num_token is %d, skipping allreduce fusion pass",
+                max_num_token)
+            return
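The arithmetic behind this guard can be shown standalone. The function below is a hypothetical extraction of the computation; the byte budget, dtype widths, and config cap mirror the quoted snippet, and the guard raises instead of logging so the zero case is explicit.

```python
def compute_max_num_token(max_size_bytes: int, hidden_dim: int, tp_size: int,
                          use_fp32_lamport: bool, config_max: int) -> int:
    """Sketch of the token-limit computation with the proposed floor guard."""
    # Lamport buffers use fp32 (4 bytes/elem) or fp16/bf16 (2 bytes/elem).
    bytes_per_elem = 4 if use_fp32_lamport else 2
    max_num_token = min(
        max_size_bytes // (hidden_dim * tp_size * bytes_per_elem),
        config_max,
    )
    # Very large hidden_dim * tp_size can drive the quotient to 0, which
    # would silently disable or break the fusion downstream.
    if max_num_token <= 0:
        raise ValueError("computed max_num_token is 0; skip the fusion pass")
    return max_num_token
```

For example, a 1 MiB budget with hidden_dim=4096, tp_size=2, and 2-byte elements yields 64 tokens, while an oversized hidden_dim * tp_size product trips the guard.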

1149-1151: Add version guard and document PyTorch compatibility for _seen_patterns cache hack.

The torch._inductor.pattern_matcher._seen_patterns is a private internal API not covered by PyTorch's stability guarantees and may change across versions. While pinned to torch == 2.7.1 in the repository, consider wrapping this in a version check (e.g., if torch.__version__.startswith("2.7")) and documenting the minimum PyTorch version tested for this specific pattern, consistent with the version guard pattern used elsewhere in vllm/compilation/compiler_interface.py.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2dff2e2 and 6e672da.

📒 Files selected for processing (5)
  • .buildkite/test-pipeline.yaml
  • tests/compile/test_fusion_all_reduce.py
  • tests/utils.py
  • vllm/compilation/collective_fusion.py
  • vllm/config.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/compile/test_fusion_all_reduce.py (9)
vllm/compilation/fix_functionalization.py (1)
  • FixFunctionalizationPass (19-191)
vllm/compilation/noop_elimination.py (1)
  • NoOpEliminationPass (18-165)
vllm/distributed/communication_op.py (1)
  • tensor_model_parallel_all_reduce (12-14)
vllm/model_executor/layers/layernorm.py (1)
  • RMSNorm (89-200)
vllm/model_executor/layers/quantization/utils/quant_utils.py (1)
  • GroupShape (24-32)
vllm/model_executor/layers/quantization/input_quant_fp8.py (1)
  • QuantFP8 (24-103)
tests/utils.py (2)
  • has_module_attribute (980-988)
  • multi_gpu_test (893-906)
vllm/utils/__init__.py (1)
  • round_up (981-982)
vllm/model_executor/layers/quantization/rtn.py (1)
  • shape (104-111)
vllm/compilation/collective_fusion.py (2)
vllm/distributed/communication_op.py (1)
  • tensor_model_parallel_all_reduce (12-14)
vllm/distributed/parallel_state.py (2)
  • all_reduce (105-110)
  • all_reduce (341-364)
🪛 Ruff (0.14.13)
tests/compile/test_fusion_all_reduce.py

29-29: Unused method argument: token_num

(ARG002)


50-50: Unused method argument: token_num

(ARG002)


112-112: Do not assign a lambda expression, use a def

Rewrite round_up as a def

(E731)

🔇 Additional comments (11)
.buildkite/test-pipeline.yaml (1)

356-356: LGTM!

The new test file compile/test_fusion_all_reduce.py is appropriately added to the PyTorch Compilation Unit Tests section, which aligns with the fusion pass functionality being tested.

vllm/compilation/collective_fusion.py (10)

40-41: LGTM!

The new quantization operation constants follow the established naming convention and are appropriately defined at module scope for use in the pattern classes.


524-536: LGTM!

The custom op registration correctly declares all mutated arguments including the new quant_out and scale_out parameters.


541-571: LGTM!

The FlashInferFusedAllReduceParams class is cleanly extended with the new fuse_rms_quant parameter while maintaining backward compatibility through the default value.


574-636: LGTM!

The class naming is now consistent (PascalCase for "Norm"), and the pattern replacement correctly uses the fused operation with proper return tuple indices matching the mutates_args declaration.


639-701: LGTM!

The pattern correctly handles the fused add RMS norm case where norm_out=None indicates in-place operation, and the return indices properly extract allreduce_in and residual from the result tuple.


783-868: LGTM!

The FP8 quantization pattern with residual addition is correctly implemented. The return indices properly extract quant_out and the updated residual.


871-960: LGTM!

The NVFP4 quantization pattern correctly handles the three-output case (quantized output, allreduce output, and output scale) with proper return indices.


963-1052: LGTM!

The NVFP4 pattern with residual addition is correctly implemented with appropriate return values for the three outputs.


1123-1135: LGTM on device capability gating.

The NVFP4 patterns are appropriately gated behind has_device_capability(100) for Blackwell GPUs.


1165-1169: LGTM!

The cleanup logic properly destroys IPC workspace resources when the pass is deleted.

Comment on lines +112 to +116
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4)
self.output_scale = torch.empty((rounded_m, rounded_n // 4),

⚠️ Potential issue | 🟡 Minor

Replace the inline round_up lambda (Ruff E731).

Use the existing vllm.utils.round_up helper (or a local def) instead of assigning a lambda.

✅ Suggested fix
-from vllm.utils import update_environment_variables
+from vllm.utils import round_up, update_environment_variables
@@
-        round_up = lambda x, y: (x + y - 1) // y * y
         rounded_m = round_up(token_num, 128)
         scale_n = hidden_size // 16
         rounded_n = round_up(scale_n, 4)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-round_up = lambda x, y: (x + y - 1) // y * y
 rounded_m = round_up(token_num, 128)
 scale_n = hidden_size // 16
 rounded_n = round_up(scale_n, 4)
 self.output_scale = torch.empty((rounded_m, rounded_n // 4),
🧰 Tools
🪛 Ruff (0.14.13)

112-112: Do not assign a lambda expression, use a def

Rewrite round_up as a def

(E731)

🤖 Prompt for AI Agents
In `@tests/compile/test_fusion_all_reduce.py` around lines 112 - 116, The inline
lambda named round_up should be replaced with the shared helper to avoid
defining lambdas in expressions; locate the lambda used to compute
rounded_m/rounded_n (the round_up = lambda x, y: ... and its uses before
self.output_scale = torch.empty(...)) and replace it with a call to
vllm.utils.round_up (or define a local def round_up(x, y): ...) so the same
rounding logic is reused and the lambda assignment is removed, then update the
calls computing rounded_m and rounded_n to use the helper.
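For reference, the rounding logic being replaced is equivalent to the following plain def (the behavior matches the lambda shown in the diff; whether it matches vllm.utils.round_up exactly should be verified against that helper):

```python
def round_up(x: int, y: int) -> int:
    """Round x up to the nearest multiple of y; same logic as the lambda."""
    return (x + y - 1) // y * y


# Shapes as computed in the test model (values per the quoted snippet):
token_num, hidden_size = 100, 512
rounded_m = round_up(token_num, 128)   # 128
scale_n = hidden_size // 16            # 32
rounded_n = round_up(scale_n, 4)       # 32
```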

Comment on lines +499 to +503
if scale_factor is None or norm_out is not None:
# we need to return allreduce outpput
# in cases of non quant fused AR + RMS norm
# and fused AR + RMS norm + quant without fused add
allreduce_in.copy_(allreduce_out)

⚠️ Potential issue | 🟡 Minor

Typo in comment and potential logic clarity issue.

Line 500 has a typo: "outpput" should be "output".

Also, the condition norm_out is not None on line 499 will now always be true in the else branch because norm_out is assigned at line 488 when it was originally None. This means the copy_ always executes in the non-fused path when scale_factor is None. Consider clarifying with a local variable to track whether norm_out was originally provided vs assigned internally.

Proposed fix for typo
-            if scale_factor is None or norm_out is not None:
-                # we need to return allreduce outpput
+            if scale_factor is None or norm_out is not None:
+                # we need to return allreduce output
🤖 Prompt for AI Agents
In `@vllm/compilation/collective_fusion.py` around lines 499 - 503, Fix the typo
in the comment ("outpput" -> "output") and avoid the subtle logic bug by
tracking whether norm_out was originally provided before it gets reassigned:
capture a boolean like original_norm_out_provided = (norm_out is not None)
immediately before any internal assignment to norm_out (around where norm_out is
assigned at/near the code referencing line 488), then replace the condition `if
scale_factor is None or norm_out is not None:` with `if scale_factor is None or
original_norm_out_provided:` so the copy_ (allreduce_in.copy_(allreduce_out))
only runs when scale_factor is None or norm_out was passed in by the caller;
update the comment to read "output" and briefly explain the check uses the
original-provided flag.
