diff --git a/docker/Dockerfile b/docker/Dockerfile
index 227f4a3355c8..ac7cea073e35 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -492,16 +492,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
         --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
     rm /tmp/requirements-cuda.txt /tmp/common.txt
 
-# Install FlashInfer pre-compiled kernel cache and binaries
-# This is ~1.1GB and only changes when FlashInfer version bumps
-# https://docs.flashinfer.ai/installation.html
-# From versions.json: .flashinfer.version
-ARG FLASHINFER_VERSION=0.5.3
+# Install FlashInfer from CentML fork
+# https://github.com/CentML/flashinfer/tree/mlperf-inf-mm-q3vl-v6.0
+ARG FLASHINFER_REPO=https://github.com/CentML/flashinfer.git
+ARG FLASHINFER_BRANCH=mlperf-inf-mm-q3vl-v6.0
+ARG FLASHINFER_CUBIN_VERSION=0.5.3
+ARG FLASHINFER_JIT_CACHE_VERSION=0.5.3
 RUN --mount=type=cache,target=/root/.cache/uv \
-    uv pip install --system flashinfer-cubin==${FLASHINFER_VERSION} \
-    && uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
-        --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
-    && flashinfer show-config
+    # Clone and build FlashInfer
+    git clone --recursive -b ${FLASHINFER_BRANCH} ${FLASHINFER_REPO} /tmp/flashinfer && \
+    cd /tmp/flashinfer && \
+    uv pip install --system --no-build-isolation -v . && \
+    uv pip install --system flashinfer-cubin==${FLASHINFER_CUBIN_VERSION} && \
+    uv pip install --system flashinfer-jit-cache==${FLASHINFER_JIT_CACHE_VERSION} \
+        --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
+    rm -rf /tmp/flashinfer && \
+    flashinfer show-config
 
 # ============================================================
 # OPENAI API SERVER DEPENDENCIES
diff --git a/docker/versions.json b/docker/versions.json
index 045955bc46ce..3bb174eea948 100644
--- a/docker/versions.json
+++ b/docker/versions.json
@@ -67,7 +67,16 @@
     "RUN_WHEEL_CHECK": {
         "default": "true"
     },
-    "FLASHINFER_VERSION": {
+    "FLASHINFER_REPO": {
+        "default": "https://github.com/CentML/flashinfer.git"
+    },
+    "FLASHINFER_BRANCH": {
+        "default": "mlperf-inf-mm-q3vl-v6.0"
+    },
+    "FLASHINFER_CUBIN_VERSION": {
+        "default": "0.5.3"
+    },
+    "FLASHINFER_JIT_CACHE_VERSION": {
         "default": "0.5.3"
     },
     "GDRCOPY_CUDA_VERSION": {
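The Dockerfile change above builds the CentML FlashInfer fork from source and still installs the upstream cubin and JIT-cache wheels next to it. As a rough post-build check, the sketch below (not part of the patch) verifies from Python that the package and the cuDNN prefill entry point used by the vision attention changes further down resolve correctly; only the import path shown in the diff and the standard __version__ attribute are assumed.

# Sanity-check sketch, assuming it runs inside the built image.
# Mirrors roughly what `flashinfer show-config` already verifies at build time.
import flashinfer
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache  # used by the new ViT path

print("flashinfer version:", flashinfer.__version__)
print("cudnn prefill kernel importable:", callable(cudnn_batch_prefill_with_kv_cache))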
diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py
index 33e120e7660e..28d83776ebe5 100644
--- a/vllm/model_executor/layers/attention/mm_encoder_attention.py
+++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -13,6 +13,7 @@
 from vllm.v1.attention.ops.vit_attn_wrappers import (
     vit_fa4_flash_attn_wrapper,
     vit_flash_attn_wrapper,
+    vit_flashinfer_wrapper,
     vit_torch_sdpa_wrapper,
 )
 
@@ -34,6 +35,7 @@ def __init__(
         num_kv_heads: int | None = None,
         prefix: str = "",
         multimodal_config: MultiModalConfig | None = None,
+        workspace_buffer: torch.Tensor | None = None,  # Only used for FlashInfer
     ) -> None:
         """
         Args:
@@ -49,10 +51,10 @@ def __init__(
         self.num_heads = num_heads
         self.head_size = head_size
-        self.scale = scale
+        self.scale = 1.0 / (head_size**0.5) if scale is None else scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.layer_name = prefix
-
+        self.workspace_buffer = workspace_buffer
         assert self.num_heads % self.num_kv_heads == 0, (
             f"num_heads ({self.num_heads}) is not "
             f"divisible by num_kv_heads ({self.num_kv_heads})"
@@ -185,6 +187,27 @@ def _forward_fa(
         output = output.reshape(bsz, q_len, -1)
         return output
 
+    def _forward_flashinfer(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens: torch.Tensor | None = None,
+        max_seqlen: torch.Tensor | None = None,
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
+    ) -> torch.Tensor:
+        return vit_flashinfer_wrapper(
+            q=query,
+            k=key,
+            v=value,
+            scale=self.scale,
+            workspace_buffer=self.workspace_buffer,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            sequence_lengths=sequence_lengths,
+        )
+
     def _forward_fa4(
         self,
         query: torch.Tensor,
@@ -226,6 +249,8 @@ def forward_native(
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         return self._forward_sdpa(query, key, value, cu_seqlens)
 
@@ -236,11 +261,17 @@ def forward_cuda(
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         if self.is_fa4_backend:
             return self._forward_fa4(query, key, value, cu_seqlens, max_seqlen)
         elif self.is_flash_attn_backend:
             return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
+        elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
+            return self._forward_flashinfer(
+                query, key, value, cu_seqlens, max_seqlen, sequence_lengths
+            )
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             return self._forward_sdpa(query, key, value, cu_seqlens)
         else:
@@ -256,6 +287,8 @@ def forward_cpu(
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         return self._forward_sdpa(query, key, value, cu_seqlens)
 
@@ -266,6 +299,8 @@ def forward_xpu(
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor
+        | None = None,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         assert self.is_flash_attn_backend, (
             "XPU only supports FLASH_ATTN for vision attention."
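For reference, the scale change above only affects callers that pass scale=None: the layer now falls back to the conventional 1/sqrt(head_size) softmax scaling instead of storing None. Qwen2.5-VL, for instance, already passes head_size**-0.5 explicitly (see the next file), which equals the new default, so its behaviour is unchanged. A tiny illustration with a made-up head size:

# Illustration only (head_size value is made up, not from the patch).
head_size = 80
default_scale = 1.0 / (head_size ** 0.5)      # what the layer now computes for scale=None
assert abs(default_scale - head_size ** -0.5) < 1e-12  # same value Qwen passes explicitly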
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 9cfd12a31903..a10e5a4689b6 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -310,6 +310,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         multimodal_config: MultiModalConfig | None = None,
         prefix: str = "",
+        workspace_buffer: torch.Tensor | None = None,  # Only used for FlashInfer
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -355,6 +356,7 @@ def __init__(
             head_size=self.hidden_size_per_attention_head,
             scale=self.hidden_size_per_attention_head**-0.5,
             multimodal_config=multimodal_config,
+            workspace_buffer=workspace_buffer,
         )
         self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
 
@@ -366,6 +368,7 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -406,6 +409,7 @@ def forward(
             value=v,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
+            sequence_lengths=sequence_lengths,
         )
 
         context_layer = einops.rearrange(
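The new sequence_lengths argument threaded through the Qwen2.5-VL attention above (and through Qwen3-VL below) is just the per-image length vector recovered from the cumulative-length tensor; the Qwen3-VL forward pass computes it exactly this way before optionally padding it for FlashInfer. A small worked example with made-up token counts:

import numpy as np

# Made-up example: three images contributing 256, 1024 and 256 vision tokens.
cu_seqlens = np.array([0, 256, 1280, 1536], dtype=np.int32)
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]   # -> array([256, 1024, 256])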
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 44b45a08dc1e..40099edac6da 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -51,7 +51,7 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import MultiModalConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
-from vllm.distributed import get_pp_group
+from vllm.distributed import get_pp_group, parallel_state
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
 from vllm.model_executor.layers.conv import Conv3dLayer
@@ -131,6 +131,10 @@
 # Official recommended max pixels is 24576 * 32 * 32
 _MAX_FRAMES_PER_VIDEO = 24576
 
+# Batch buckets for cuDNN graph caching - graphs are cached per bucket size
+# This avoids creating a new graph for each unique batch size at runtime
+BATCH_BUCKETS = [8, 16, 32, 64]
+
 
 class Qwen3_VisionPatchEmbed(nn.Module):
     def __init__(
@@ -214,6 +218,7 @@ def __init__(
         multimodal_config: MultiModalConfig | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        workspace_buffer: torch.Tensor | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -227,6 +232,7 @@ def __init__(
             quant_config=quant_config,
             multimodal_config=multimodal_config,
             prefix=f"{prefix}.attn",
+            workspace_buffer=workspace_buffer,
         )
         self.mlp = Qwen3_VisionMLP(
             dim,
@@ -245,6 +251,7 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        sequence_lengths: torch.Tensor,  # Only used for FlashInfer CuDNN backend
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
@@ -252,6 +259,7 @@ def forward(
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
+            sequence_lengths=sequence_lengths,
         )
         x = x + self.mlp(self.norm2(x))
 
@@ -335,6 +343,17 @@ def __init__(
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
         self.num_grid_per_side = int(self.num_position_embeddings**0.5)
 
+        use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )
+        self.tp_size = (
+            1
+            if use_data_parallel
+            else parallel_state.get_tensor_model_parallel_world_size()
+        )
+
         # NOTE: This is used for creating empty tensor for all_gather for
         # DP ViT. Here out_hidden_size is enlarged due to deepstack
         self.out_hidden_size = vision_config.out_hidden_size * (
@@ -399,10 +418,18 @@ def __init__(
             AttentionBackendEnum.FLASH_ATTN_CUTE,
             AttentionBackendEnum.TORCH_SDPA,
             AttentionBackendEnum.ROCM_AITER_FA,
+            AttentionBackendEnum.FLASHINFER,
         }:
             raise RuntimeError(
                 f"Qwen3-VL does not support {self.attn_backend} backend now."
             )
+
+        workspace_buffer = (
+            None
+            if self.attn_backend != AttentionBackendEnum.FLASHINFER
+            else torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=self.device)
+        )
+
         self.blocks = nn.ModuleList(
             [
                 Qwen3_VisionBlock(
@@ -414,6 +441,7 @@ def __init__(
                     quant_config=quant_config,
                     multimodal_config=multimodal_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
+                    workspace_buffer=workspace_buffer,
                 )
                 for layer_idx in range(vision_config.depth)
             ]
@@ -540,11 +568,55 @@ def compute_attn_mask_seqlen(
         if (
             self.attn_backend == AttentionBackendEnum.FLASH_ATTN
             or self.attn_backend == AttentionBackendEnum.FLASH_ATTN_CUTE
+            or self.attn_backend == AttentionBackendEnum.FLASHINFER
             or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
 
+    def add_padding_to_fi_seqlens(
+        self, seq: np.ndarray, batch_size: int, padding_value: int
+    ) -> np.ndarray:
+        batch_size_padded = next(
+            (b for b in BATCH_BUCKETS if b >= batch_size), BATCH_BUCKETS[-1]
+        )
+        if batch_size_padded == batch_size:
+            return seq
+        return np.concatenate(
+            [
+                seq,
+                np.full(
+                    (batch_size_padded - batch_size,), padding_value, dtype=seq.dtype
+                ),
+            ],
+        )
+
+    def compute_flashinfer_cu_seqlens(
+        self,
+        cu_seqlens: np.ndarray,
+        rotary_pos_emb_cos: torch.Tensor | None = None,
+        rotary_pos_emb_sin: torch.Tensor | None = None,
+    ) -> np.ndarray:
+        batch_size = len(cu_seqlens) - 1
+        scale = self.hidden_size // self.tp_size
+        cu_seqlens = cu_seqlens * scale
+        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
+            cu_seqlens_qk = cu_seqlens * 2
+        else:
+            cu_seqlens_qk = cu_seqlens * 3
+        cu_seqlens_v = cu_seqlens * 3
+        cu_seqlens_o = cu_seqlens
+        cu_seqlens_qk = self.add_padding_to_fi_seqlens(
+            cu_seqlens_qk, batch_size, cu_seqlens_qk[-1]
+        )
+        cu_seqlens_v = self.add_padding_to_fi_seqlens(
+            cu_seqlens_v, batch_size, cu_seqlens_v[-1]
+        )
+        cu_seqlens_o = self.add_padding_to_fi_seqlens(
+            cu_seqlens_o, batch_size, cu_seqlens_o[-1]
+        )
+        return np.concatenate([cu_seqlens_qk, cu_seqlens_v, cu_seqlens_o])
+
     def forward(
         self,
         x: torch.Tensor,
@@ -568,11 +640,24 @@ def forward(
             axis=0, dtype=np.int32
         )
         cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
+        sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+        if self.attn_backend == AttentionBackendEnum.FLASHINFER:
+            sequence_lengths = self.add_padding_to_fi_seqlens(
+                sequence_lengths, len(sequence_lengths), 0
+            )
+            cu_seqlens = self.compute_flashinfer_cu_seqlens(
+                cu_seqlens, rotary_pos_emb_cos, rotary_pos_emb_sin
+            )
         cu_seqlens = torch.from_numpy(cu_seqlens)
-
+        sequence_lengths = torch.from_numpy(sequence_lengths)
         hidden_states = hidden_states.unsqueeze(1)
-        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = (
+            torch.tensor(128 * 1024, device=self.device)
+            if self.attn_backend == AttentionBackendEnum.FLASHINFER
+            else self.compute_attn_mask_seqlen(cu_seqlens)
+        )
         cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
+        sequence_lengths = sequence_lengths.to(self.device, non_blocking=True)
 
         deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):
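To make the bucketing above concrete: add_padding_to_fi_seqlens rounds the effective batch up to the next entry in BATCH_BUCKETS so that cuDNN can reuse one cached graph per bucket instead of rebuilding a graph for every distinct batch size. Below is a standalone restatement of that logic with illustrative values; pad_to_bucket is a hypothetical name used only here, not part of the patch.

import numpy as np

BATCH_BUCKETS = [8, 16, 32, 64]

def pad_to_bucket(seq: np.ndarray, batch_size: int, padding_value: int) -> np.ndarray:
    # Same logic as add_padding_to_fi_seqlens, pulled out for illustration.
    bucket = next((b for b in BATCH_BUCKETS if b >= batch_size), BATCH_BUCKETS[-1])
    if bucket == batch_size:
        return seq
    pad = np.full((bucket - batch_size,), padding_value, dtype=seq.dtype)
    return np.concatenate([seq, pad])

lengths = np.array([256, 1024, 256], dtype=np.int32)   # a batch of 3 images
print(pad_to_bucket(lengths, len(lengths), 0))          # padded up to the 8-wide bucket
# -> [ 256 1024  256    0    0    0    0    0]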
@@ -582,6 +667,7 @@ def forward(
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
+                sequence_lengths=sequence_lengths,
             )
             if layer_num in self.deepstack_visual_indexes:
                 deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 8f315881df45..020e948a4a40 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -363,6 +363,7 @@ def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
             AttentionBackendEnum.TORCH_SDPA,
             AttentionBackendEnum.FLASH_ATTN_CUTE,
             AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.FLASHINFER,
         ]
 
     @classmethod
diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py
index 8fa0a442c3a5..84b1438fb1b0 100644
--- a/vllm/v1/attention/ops/vit_attn_wrappers.py
+++ b/vllm/v1/attention/ops/vit_attn_wrappers.py
@@ -262,3 +262,88 @@ def vit_torch_sdpa_wrapper(
     cu_seqlens: torch.Tensor | None = None,
 ) -> torch.Tensor:
     return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens)
+
+
+def flashinfer_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+    sequence_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
+
+    is_reshaped = q.dim() == 4
+
+    if is_reshaped:
+        reshape_batch_size = q.shape[0]
+        q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+
+    assert len(cu_seqlens) % 3 == 0, "cu_seqlens must be divisible by 3"
+    cu_seqlength = len(cu_seqlens) // 3
+    batch_offsets_qk = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
+    batch_offsets_v = cu_seqlens[cu_seqlength : cu_seqlength * 2].view(-1, 1, 1, 1)
+    batch_offsets_o = cu_seqlens[cu_seqlength * 2 :].view(-1, 1, 1, 1)
+    sequence_lengths = sequence_lengths.view(-1, 1, 1, 1)
+    max_seqlen = max_seqlen.item()
+
+    output, _ = cudnn_batch_prefill_with_kv_cache(
+        q,
+        k,
+        v,
+        scale,
+        workspace_buffer,
+        max_token_per_sequence=max_seqlen,
+        max_sequence_kv=max_seqlen,
+        actual_seq_lens_q=sequence_lengths,
+        actual_seq_lens_kv=sequence_lengths,
+        causal=False,
+        return_lse=False,
+        batch_offsets_q=batch_offsets_qk,
+        batch_offsets_k=batch_offsets_qk,
+        batch_offsets_v=batch_offsets_v,
+        batch_offsets_o=batch_offsets_o,
+    )
+
+    if is_reshaped:
+        output = einops.rearrange(output, "(b s) h d -> b s h d", b=reshape_batch_size)
+
+    return output
+
+
+def vit_flashinfer_wrapper_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+    sequence_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+    return torch.empty_like(q)
+
+
+direct_register_custom_op(
+    op_name="flashinfer_wrapper",
+    op_func=flashinfer_wrapper,
+    fake_impl=vit_flashinfer_wrapper_fake,
+)
+
+
+def vit_flashinfer_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
+    sequence_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+    return torch.ops.vllm.flashinfer_wrapper(
+        q, k, v, scale, workspace_buffer, cu_seqlens, max_seqlen, sequence_lengths
+    )
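Usage note on the new op: flashinfer_wrapper expects cu_seqlens to carry three offset vectors packed back-to-back (q/k offsets, then v offsets, then output offsets, which is the layout compute_flashinfer_cu_seqlens produces) and slices it into equal thirds before handing each slice to the cuDNN prefill kernel. The snippet below only illustrates that packing contract with made-up offsets; it does not run the kernel and assumes nothing beyond the slicing shown in the diff.

import torch

# Made-up offsets, just to show the packing that the wrapper unpacks.
offsets_qk = torch.tensor([0, 10, 20, 30], dtype=torch.int32)
offsets_v = torch.tensor([0, 15, 30, 45], dtype=torch.int32)
offsets_o = torch.tensor([0, 5, 10, 15], dtype=torch.int32)
cu_seqlens = torch.cat([offsets_qk, offsets_v, offsets_o])

assert len(cu_seqlens) % 3 == 0
third = len(cu_seqlens) // 3
assert torch.equal(cu_seqlens[:third], offsets_qk)
assert torch.equal(cu_seqlens[third : 2 * third], offsets_v)
assert torch.equal(cu_seqlens[2 * third :], offsets_o)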