Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions vllm/model_executor/layers/fused_moe/all2all_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_pplx

Expand Down Expand Up @@ -77,12 +80,17 @@ def maybe_make_prepare_finalize(

prepare_finalize: FusedMoEPrepareAndFinalize | None = None

# TODO(rob): update this as part of the MoE refactor.
assert not moe.use_flashinfer_cutlass_kernels, (
"Must be created in modelopt.py or fp8.py"
)
if moe.use_flashinfer_cutlass_kernels:
assert quant_config is not None
use_deepseek_fp8_block_scale = (
quant_config is not None and quant_config.is_block_quantized
)
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)

if moe.use_pplx_kernels:
elif moe.use_pplx_kernels:
assert quant_config is not None

hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,7 @@ def flashinfer_cutlass_moe_fp4(
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(
use_dp=False, use_nvfp4=True, enable_alltoallv=False
),
create_flashinfer_prepare_finalize(use_dp=False),
FlashInferExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
Expand Down
10 changes: 1 addition & 9 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend,
register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights,
Expand Down Expand Up @@ -150,7 +149,7 @@ def get_fp8_moe_backend(
if block_quant and current_platform.is_device_capability_family(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization on SM100. Please use "
"support block quantization. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
Expand Down Expand Up @@ -1103,13 +1102,6 @@ def maybe_make_prepare_finalize(
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)

def select_gemm_impl(
Expand Down
15 changes: 0 additions & 15 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf,
Expand Down Expand Up @@ -751,17 +750,6 @@ def maybe_make_prepare_finalize(
# TRT LLM not supported with all2all yet.
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None

prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)

def select_gemm_impl(
Expand Down Expand Up @@ -1456,9 +1444,6 @@ def maybe_make_prepare_finalize(
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
Expand Down