Changes from all commits (34 commits):
d4597f3  add implementation (maxyanghu, Jan 14, 2026)
7cbf291  add impl (maxyanghu, Jan 19, 2026)
8713291  add flashinfer (maxyanghu, Jan 20, 2026)
f9362fb  fix tp (maxyanghu, Jan 21, 2026)
d48087f  Temporary change for ViT (Anerudhan, Jan 20, 2026)
71eeda2  fix workspace_buffer device. (b-mu, Jan 22, 2026)
392b3ac  change max_seqlen to 128k. (b-mu, Jan 22, 2026)
772a17b  remove duplicate multiplier. (b-mu, Jan 22, 2026)
c38e8c4  fix accuracy and refactor (maxyanghu, Jan 23, 2026)
19d5ffa  more fix (maxyanghu, Jan 23, 2026)
47af3e1  change dockerfile (maxyanghu, Jan 26, 2026)
a09a785  format (maxyanghu, Jan 26, 2026)
bfd41ec  fix version (maxyanghu, Jan 26, 2026)
5599eb4  change python version (maxyanghu, Jan 26, 2026)
76b1482  remove qwen25 transformer support (maxyanghu, Jan 26, 2026)
fec4833  change dockerfile (maxyanghu, Jan 26, 2026)
9a8c2d5  add build versions (maxyanghu, Jan 26, 2026)
f6a2ee7  chagne version (maxyanghu, Jan 26, 2026)
4b9aa2a  change version (maxyanghu, Jan 26, 2026)
f782e97  change (maxyanghu, Jan 26, 2026)
56868a9  change (maxyanghu, Jan 26, 2026)
c2ca450  change (maxyanghu, Jan 26, 2026)
1d8b7ec  change (maxyanghu, Jan 26, 2026)
413260e  change (maxyanghu, Jan 26, 2026)
7a2ac66  build image (maxyanghu, Jan 26, 2026)
e8d34b7  change back (maxyanghu, Jan 26, 2026)
5adb294  change to 10.0f (maxyanghu, Jan 26, 2026)
bc90e8f  fix fi import (maxyanghu, Jan 26, 2026)
2d1286d  change to build in dev image (maxyanghu, Jan 26, 2026)
42858c6  change location (maxyanghu, Jan 26, 2026)
c9a8f9b  change location (maxyanghu, Jan 26, 2026)
89703a4  change (maxyanghu, Jan 26, 2026)
9431a61  change cubin and jitcache to wheels (maxyanghu, Jan 27, 2026)
0e0f19e  change (maxyanghu, Jan 27, 2026)

24 changes: 15 additions & 9 deletions docker/Dockerfile
@@ -492,16 +492,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
rm /tmp/requirements-cuda.txt /tmp/common.txt

# Install FlashInfer pre-compiled kernel cache and binaries
# This is ~1.1GB and only changes when FlashInfer version bumps
# https://docs.flashinfer.ai/installation.html
# From versions.json: .flashinfer.version
ARG FLASHINFER_VERSION=0.5.3
# Install FlashInfer from CentML fork
# https://github.com/CentML/flashinfer/tree/mlperf-inf-mm-q3vl-v6.0
ARG FLASHINFER_REPO=https://github.com/CentML/flashinfer.git
ARG FLASHINFER_BRANCH=mlperf-inf-mm-q3vl-v6.0
ARG FLASHINFER_CUBIN_VERSION=0.5.3
ARG FLASHINFER_JIT_CACHE_VERSION=0.5.3
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system flashinfer-cubin==${FLASHINFER_VERSION} \
&& uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
&& flashinfer show-config
# Clone and build FlashInfer
git clone --recursive -b ${FLASHINFER_BRANCH} ${FLASHINFER_REPO} /tmp/flashinfer && \
cd /tmp/flashinfer && \
uv pip install --system --no-build-isolation -v . && \
uv pip install --system flashinfer-cubin==${FLASHINFER_CUBIN_VERSION} && \
uv pip install --system flashinfer-jit-cache==${FLASHINFER_JIT_CACHE_VERSION} \
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
rm -rf /tmp/flashinfer && \
flashinfer show-config

# ============================================================
# OPENAI API SERVER DEPENDENCIES
11 changes: 10 additions & 1 deletion docker/versions.json
@@ -67,7 +67,16 @@
"RUN_WHEEL_CHECK": {
"default": "true"
},
"FLASHINFER_VERSION": {
"FLASHINFER_REPO": {
"default": "https://github.com/CentML/flashinfer.git"
},
"FLASHINFER_BRANCH": {
"default": "mlperf-inf-mm-q3vl-v6.0"
},
"FLASHINFER_CUBIN_VERSION": {
"default": "0.5.3"
},
"FLASHINFER_JIT_CACHE_VERSION": {
"default": "0.5.3"
},
"GDRCOPY_CUDA_VERSION": {
39 changes: 37 additions & 2 deletions vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -13,6 +13,7 @@
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_fa4_flash_attn_wrapper,
vit_flash_attn_wrapper,
vit_flashinfer_wrapper,
vit_torch_sdpa_wrapper,
)

@@ -34,6 +35,7 @@ def __init__(
num_kv_heads: int | None = None,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
workspace_buffer: torch.Tensor | None = None, # Only used for FlashInfer
) -> None:
"""
Args:
@@ -49,10 +51,10 @@

self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.scale = 1.0 / (head_size**0.5) if scale is None else scale

Review thread on self.scale:
Reviewer: What is this default scale factor based on?
maxyanghu (author): This is based on the MHA formula, scale = rsqrt(head_dim).

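For reference, this default matches the standard scaled dot-product attention convention: logits are divided by sqrt(head_dim) before the softmax. A minimal sketch in plain PyTorch (illustrative shapes only, not the vLLM attention wrapper itself):

import torch

num_heads, head_dim, seq_len = 8, 64, 16        # hypothetical sizes
q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)

# Default used above when no explicit scale is passed: 1 / sqrt(head_dim).
scale = 1.0 / (head_dim ** 0.5)

# Attention logits per head, shape [num_heads, seq_len, seq_len].
logits = torch.einsum("qhd,khd->hqk", q, k) * scale
probs = logits.softmax(dim=-1)

torch.nn.functional.scaled_dot_product_attention applies the same default when no scale argument is given.
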
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix

self.workspace_buffer = workspace_buffer
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
@@ -185,6 +187,27 @@ def _forward_fa(
output = output.reshape(bsz, q_len, -1)
return output

def _forward_flashinfer(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return vit_flashinfer_wrapper(
q=query,
k=key,
v=value,
scale=self.scale,
workspace_buffer=self.workspace_buffer,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)

def _forward_fa4(
self,
query: torch.Tensor,
@@ -226,6 +249,8 @@ def forward_native(
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)

@@ -236,11 +261,17 @@ def forward_cuda(
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
if self.is_fa4_backend:
return self._forward_fa4(query, key, value, cu_seqlens, max_seqlen)
elif self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
return self._forward_flashinfer(
query, key, value, cu_seqlens, max_seqlen, sequence_lengths
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
@@ -256,6 +287,8 @@ def forward_cpu(
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)

@@ -266,6 +299,8 @@ def forward_xpu(
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
assert self.is_flash_attn_backend, (
"XPU only supports FLASH_ATTN for vision attention."
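Note on the shared calling convention: every forward_* path above receives query/key/value packed across all sequences, with cu_seqlens marking segment boundaries, max_seqlen consumed by the Flash Attention paths, and the new sequence_lengths consumed by the FlashInfer cuDNN path. A rough, self-contained reference for that packed varlen layout (plain PyTorch with made-up shapes; this is not the vit_flashinfer_wrapper implementation):

import torch
import torch.nn.functional as F

# Three hypothetical segments of 4, 6, and 2 tokens packed into one buffer.
cu_seqlens = torch.tensor([0, 4, 10, 12])
total_tokens, num_heads, head_dim = 12, 8, 64
q = torch.randn(total_tokens, num_heads, head_dim)
k = torch.randn(total_tokens, num_heads, head_dim)
v = torch.randn(total_tokens, num_heads, head_dim)

# Reference semantics: attention runs independently within each cu_seqlens
# segment, never across segment boundaries.
out = torch.empty_like(q)
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    qs = q[start:end].transpose(0, 1).unsqueeze(0)   # [1, heads, tokens, dim]
    ks = k[start:end].transpose(0, 1).unsqueeze(0)
    vs = v[start:end].transpose(0, 1).unsqueeze(0)
    o = F.scaled_dot_product_attention(qs, ks, vs)   # default 1/sqrt(dim) scale
    out[start:end] = o.squeeze(0).transpose(0, 1)

The fused kernels (Flash Attention, FlashInfer) implement the same semantics without the Python loop; sequence_lengths is simply the per-segment token count, i.e. cu_seqlens[1:] - cu_seqlens[:-1], as computed in qwen3_vl.py below.
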
4 changes: 4 additions & 0 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -310,6 +310,7 @@ def __init__(
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
workspace_buffer: torch.Tensor | None = None, # Only used for FlashInfer
) -> None:
super().__init__()
# Per attention head and per partition values.
@@ -355,6 +356,7 @@ def __init__(
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
multimodal_config=multimodal_config,
workspace_buffer=workspace_buffer,
)

self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
@@ -366,6 +368,7 @@ def forward(
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -406,6 +409,7 @@ def forward(
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)

context_layer = einops.rearrange(
92 changes: 89 additions & 3 deletions vllm/model_executor/models/qwen3_vl.py
@@ -51,7 +51,7 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.distributed import get_pp_group, parallel_state
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.conv import Conv3dLayer
@@ -131,6 +131,10 @@
# Official recommended max pixels is 24576 * 32 * 32
_MAX_FRAMES_PER_VIDEO = 24576

# Batch buckets for cuDNN graph caching - graphs are cached per bucket size
# This avoids creating a new graph for each unique batch size at runtime
BATCH_BUCKETS = [8, 16, 32, 64]


class Qwen3_VisionPatchEmbed(nn.Module):
def __init__(
@@ -214,6 +218,7 @@ def __init__(
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
workspace_buffer: torch.Tensor | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -227,6 +232,7 @@ def __init__(
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
workspace_buffer=workspace_buffer,
)
self.mlp = Qwen3_VisionMLP(
dim,
@@ -245,13 +251,15 @@ def forward(
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)

x = x + self.mlp(self.norm2(x))
@@ -335,6 +343,17 @@ def __init__(
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.num_grid_per_side = int(self.num_position_embeddings**0.5)

use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.tp_size = (
1
if use_data_parallel
else parallel_state.get_tensor_model_parallel_world_size()
)

# NOTE: This is used for creating empty tensor for all_gather for
# DP ViT. Here out_hidden_size is enlarged due to deepstack
self.out_hidden_size = vision_config.out_hidden_size * (
@@ -399,10 +418,18 @@ def __init__(
AttentionBackendEnum.FLASH_ATTN_CUTE,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASHINFER,
}:
raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now."
)

workspace_buffer = (
None
if self.attn_backend != AttentionBackendEnum.FLASHINFER
else torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=self.device)
)

self.blocks = nn.ModuleList(
[
Qwen3_VisionBlock(
Expand All @@ -414,6 +441,7 @@ def __init__(
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
workspace_buffer=workspace_buffer,
)
for layer_idx in range(vision_config.depth)
]
@@ -540,11 +568,55 @@ def compute_attn_mask_seqlen(
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.FLASH_ATTN_CUTE
or self.attn_backend == AttentionBackendEnum.FLASHINFER
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

def add_padding_to_fi_seqlens(
self, seq: np.ndarray, batch_size: int, padding_value: int
) -> np.ndarray:
batch_size_padded = next(
(b for b in BATCH_BUCKETS if b >= batch_size), BATCH_BUCKETS[-1]
)
if batch_size_padded == batch_size:
return seq
return np.concatenate(
[
seq,
np.full(
(batch_size_padded - batch_size,), padding_value, dtype=seq.dtype
),
]
)

def compute_flashinfer_cu_seqlens(
self,
cu_seqlens: np.ndarray,
rotary_pos_emb_cos: torch.Tensor | None = None,
rotary_pos_emb_sin: torch.Tensor | None = None,
) -> np.ndarray:
batch_size = len(cu_seqlens) - 1
scale = self.hidden_size // self.tp_size
cu_seqlens = cu_seqlens * scale
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
cu_seqlens_qk = cu_seqlens * 2
else:
cu_seqlens_qk = cu_seqlens * 3
cu_seqlens_v = cu_seqlens * 3
cu_seqlens_o = cu_seqlens
cu_seqlens_qk = self.add_padding_to_fi_seqlens(
cu_seqlens_qk, batch_size, cu_seqlens_qk[-1]
)
cu_seqlens_v = self.add_padding_to_fi_seqlens(
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
)
cu_seqlens_o = self.add_padding_to_fi_seqlens(
cu_seqlens_o, batch_size, cu_seqlens_o[-1]
)
return np.concatenate([cu_seqlens_qk, cu_seqlens_v, cu_seqlens_o])

def forward(
self,
x: torch.Tensor,
@@ -568,11 +640,24 @@ def forward(
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
if self.attn_backend == AttentionBackendEnum.FLASHINFER:
sequence_lengths = self.add_padding_to_fi_seqlens(
sequence_lengths, len(sequence_lengths), 0
)
cu_seqlens = self.compute_flashinfer_cu_seqlens(
cu_seqlens, rotary_pos_emb_cos, rotary_pos_emb_sin
)
cu_seqlens = torch.from_numpy(cu_seqlens)

sequence_lengths = torch.from_numpy(sequence_lengths)
hidden_states = hidden_states.unsqueeze(1)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = (
torch.tensor(128 * 1024, device=self.device)
if self.attn_backend == AttentionBackendEnum.FLASHINFER
else self.compute_attn_mask_seqlen(cu_seqlens)
)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
sequence_lengths = sequence_lengths.to(self.device, non_blocking=True)

deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
@@ -582,6 +667,7 @@
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
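The BATCH_BUCKETS machinery introduced in this file keeps the number of distinct shapes seen by the FlashInfer/cuDNN path small: add_padding_to_fi_seqlens rounds the batch up to the next bucket so cached graphs can be reused, and compute_flashinfer_cu_seqlens builds the padded offset arrays for the packed QK, V, and output buffers. A small NumPy-only walk-through of the padding step (illustrative values, independent of the module above):

import numpy as np

BATCH_BUCKETS = [8, 16, 32, 64]   # mirrors the constant in the diff above

def pad_to_bucket(seq: np.ndarray, batch_size: int, padding_value: int) -> np.ndarray:
    # Round the batch up to the next bucket; beyond 64, keep the largest bucket.
    padded_size = next((b for b in BATCH_BUCKETS if b >= batch_size), BATCH_BUCKETS[-1])
    if padded_size == batch_size:
        return seq
    pad = np.full((padded_size - batch_size,), padding_value, dtype=seq.dtype)
    return np.concatenate([seq, pad])

# Three sequences of 4, 6, and 2 tokens padded out to the 8-element bucket.
seq_lens = np.array([4, 6, 2], dtype=np.int32)
print(pad_to_bucket(seq_lens, len(seq_lens), 0))   # [4 6 2 0 0 0 0 0]

In the diff, per-sequence lengths are padded with zeros while the cumulative cu_seqlens arrays are padded by repeating their final offset, so the padded tail contributes no extra tokens; on this path max_seqlen is likewise pinned to 128 * 1024 rather than computed from cu_seqlens.
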
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
@@ -363,6 +363,7 @@ def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.FLASH_ATTN_CUTE,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASHINFER,
]

@classmethod