1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -353,6 +353,7 @@ steps:
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
- pytest -v -s compile/test_fusion_all_reduce.py

- label: PyTorch Fullgraph Smoke Test # 9min
mirror_hardwares: [amdexperimental]
126 changes: 104 additions & 22 deletions tests/compile/test_fusion_all_reduce.py
@@ -7,22 +7,26 @@

import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
GroupShape, QuantFP8)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

from ..utils import multi_gpu_test
from ..utils import has_module_attribute, multi_gpu_test
from .backend import TestBackend


class TestAllReduceRMSNormModel(torch.nn.Module):

def __init__(self, hidden_size=16, eps=1e-6):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -43,7 +47,7 @@ def ops_in_model_after(self):

class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):

def __init__(self, hidden_size=16, eps=1e-6):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -62,24 +66,101 @@ def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):

def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.quant_fp8 = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)

def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(self.output,
norm_output.contiguous(),
self.scale)
return self.output, residual_output

def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
]


class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):

def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)

round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4)
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
dtype=torch.int32)
Comment on lines +112 to +116

⚠️ Potential issue | 🟡 Minor

Replace the inline round_up lambda (Ruff E731).

Use the existing vllm.utils.round_up helper (or a local def) instead of assigning a lambda.

✅ Suggested fix
-from vllm.utils import update_environment_variables
+from vllm.utils import round_up, update_environment_variables
@@
-        round_up = lambda x, y: (x + y - 1) // y * y
         rounded_m = round_up(token_num, 128)
         scale_n = hidden_size // 16
         rounded_n = round_up(scale_n, 4)
🧰 Tools
🪛 Ruff (0.14.13)
112-112: Do not assign a lambda expression, use a def — rewrite round_up as a def (E731)
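
For reference, a def-based equivalent of the lambda (a minimal sketch; the vllm.utils.round_up helper mentioned above implements the same ceiling-to-multiple rounding):

def round_up(x: int, y: int) -> int:
    # Round x up to the nearest multiple of y, e.g. round_up(17, 128) == 128.
    return (x + y - 1) // y * y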



def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
self.output_scale, self.scale)
return self.output, residual_output, self.output_scale

def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default
]


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"test_model",
[TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel])
@pytest.mark.parametrize("test_model", [
TestAllReduceRMSNormModel,
TestAllReduceFusedAddRMSNormModel,
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
@pytest.mark.skipif(not find_spec("flashinfer"),
reason="flashinfer is not installed")
@pytest.mark.skipif(not current_platform.is_device_capability(100),
reason="Only test on SM100")
@pytest.mark.skipif(
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
num_processes = 2
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and not current_platform.has_device_capability(100)):
pytest.skip("Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)")

def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn,
@@ -113,12 +194,11 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)

vllm_config = VllmConfig(
compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm"],
compile_sizes=[2, 4, 8]))
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"]))
vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True)
enable_fi_allreduce_fusion=True, enable_noop=True)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

# this is a fake model name to construct the model config
@@ -130,14 +210,16 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
seed=42)

all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)

backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)

model = test_model_cls(hidden_size)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)

hidden_states = torch.randn((batch_size * seq_len, hidden_size),
requires_grad=False)
residual = torch.randn((batch_size * seq_len, hidden_size),
requires_grad=False)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)

compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)
12 changes: 12 additions & 0 deletions tests/utils.py
@@ -4,6 +4,7 @@
import asyncio
import copy
import functools
import importlib
import os
import signal
import subprocess
@@ -974,3 +975,14 @@ def get_client_text_logprob_generations(
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for completion in completions for x in completion.choices]


def has_module_attribute(module_name, attribute_name):
"""
Helper function to check if a module has a specific attribute.
"""
try:
module = importlib.import_module(module_name)
return hasattr(module, attribute_name)
except ImportError:
return False
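
For context, a usage sketch of this helper mirroring the skipif guard added in tests/compile/test_fusion_all_reduce.py above (the import path and test name here are illustrative, not part of the diff):

import pytest

from tests.utils import has_module_attribute

@pytest.mark.skipif(
    not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
    reason="flashinfer is not compiled with trtllm_allreduce_fusion")
def test_requires_trtllm_allreduce_fusion():
    # The marker evaluates at collection time, so the test is skipped
    # cleanly when the optional flashinfer symbol is unavailable.
    ...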