
Conversation


maxyanghu commented Jan 23, 2026

This branch adds cuDNN backend support for ViT (Vision Transformer) attention in multimodal models like Qwen2.5-VL and Qwen3-VL.

Key File Changes

1. New FlashInfer/cuDNN Wrapper (vllm/v1/attention/ops/vit_attn_wrappers.py)

Added flashinfer_wrapper and vit_flashinfer_wrapper functions that:

  • Call cudnn_batch_prefill_with_kv_cache from FlashInfer
  • Handle the special cu_seqlens format (3x longer, containing batch_offsets_qk, batch_offsets_v, batch_offsets_o)
  • Support both 3D and 4D tensor inputs (with automatic reshape)
  • Are registered as a custom op for torch.compile compatibility (a minimal sketch of the tensor handling follows below)
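
A minimal sketch of the tensor plumbing described above, with illustrative helper names (the real wrapper then passes these pieces, plus the workspace buffer, to FlashInfer's cudnn_batch_prefill_with_kv_cache; its exact argument list is not reproduced here):

```python
import torch

def split_cudnn_cu_seqlens(cu_seqlens: torch.Tensor):
    """Split the packed 3-section cu_seqlens into Q/K, V, and O batch offsets."""
    n = cu_seqlens.shape[0] // 3
    return cu_seqlens[:n], cu_seqlens[n:2 * n], cu_seqlens[2 * n:]

def flatten_if_4d(x: torch.Tensor) -> torch.Tensor:
    """Accept 3D (tokens, heads, head_dim) or 4D (batch, seq, heads, head_dim) input."""
    if x.dim() == 4:
        # Collapse (batch, seq) into a single token dimension before the kernel call.
        return x.reshape(-1, x.shape[-2], x.shape[-1])
    return x
```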

2. MMEncoderAttention Updates (vllm/model_executor/layers/attention/mm_encoder_attention.py)

  • Added workspace_buffer parameter for cuDNN backend
  • New _forward_flashinfer() method
  • Updated forward_cuda() to dispatch to FlashInfer when the FLASHINFER backend is selected (see the dispatch sketch below)
  • Added a sequence_lengths parameter throughout the call chain
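
A rough illustration of the dispatch described above; everything other than the forward_cuda()/_forward_flashinfer() names is a placeholder, not the actual MMEncoderAttention code:

```python
from enum import Enum, auto

class _Backend(Enum):
    TORCH_SDPA = auto()
    FLASHINFER = auto()

class MMEncoderAttentionSketch:
    """Toy stand-in showing only the new dispatch, not the real layer."""

    def __init__(self, backend: _Backend, workspace_buffer=None):
        self.backend = backend
        # The workspace buffer is only required by the FlashInfer/cuDNN path.
        self.workspace_buffer = workspace_buffer

    def forward_cuda(self, q, k, v, cu_seqlens, max_seqlen=None, sequence_lengths=None):
        if self.backend == _Backend.FLASHINFER:
            # New path: hand off to the cuDNN-backed FlashInfer kernel.
            return self._forward_flashinfer(q, k, v, cu_seqlens, max_seqlen, sequence_lengths)
        # Existing paths remain unchanged.
        return self._forward_default(q, k, v, cu_seqlens, max_seqlen)

    def _forward_flashinfer(self, q, k, v, cu_seqlens, max_seqlen, sequence_lengths):
        raise NotImplementedError("sketch only; the real method calls the custom op")

    def _forward_default(self, q, k, v, cu_seqlens, max_seqlen):
        raise NotImplementedError("sketch only")
```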

3. Qwen2.5-VL Model Updates (vllm/model_executor/models/qwen2_5_vl.py)

  • Added workspace_buffer allocation (128 MB) when using the FlashInfer backend (a sketch of the allocation follows below)
  • Propagated workspace_buffer and sequence_lengths through attention layers
  • Added FlashInfer to supported backends
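
A sketch of the kind of one-time allocation this refers to; the 128 MB figure comes from this PR, while the dtype and constant name are assumptions:

```python
import torch

# 128 MB byte buffer shared by the ViT attention layers on the FlashInfer path,
# allocated once at model construction rather than per forward pass.
VIT_FI_WORKSPACE_BYTES = 128 * 1024 * 1024  # hypothetical constant name
workspace_buffer = torch.zeros(VIT_FI_WORKSPACE_BYTES, dtype=torch.uint8, device="cuda")
```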

4. Qwen3-VL Model Updates (vllm/model_executor/models/qwen3_vl.py)

More extensive changes for cuDNN compatibility:

  • New add_padding_to_fi_seqlens() method to pad the sequence lengths to a batch size of 8
  • New compute_flashinfer_cu_seqlens() method to compute the 3-section cu_seqlens format (both helpers are sketched below)
  • Fixed max_seqlen at 128K for the FlashInfer backend
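
Hedged reimplementations of the two helpers, just to make the shapes concrete; the real methods in qwen3_vl.py may differ (in particular, padding with zero-length entries and using identical prefix sums for Q/K, V, and O are assumptions):

```python
import torch

def add_padding_to_fi_seqlens(seq_lens: torch.Tensor, batch_size: int = 8) -> torch.Tensor:
    # Pad the per-image sequence-length tensor so the batch dimension seen by
    # cuDNN is always `batch_size` (8 in this PR). Zero-length padding entries
    # are an assumption about how the real method pads.
    pad = batch_size - seq_lens.shape[0]
    if pad <= 0:
        return seq_lens
    return torch.cat([seq_lens, seq_lens.new_zeros(pad)])

def compute_flashinfer_cu_seqlens(seq_lens: torch.Tensor) -> torch.Tensor:
    # Build the 3-section cu_seqlens expected by the cuDNN path: batch offsets
    # for Q/K, V, and O concatenated into one tensor. Identical prefix sums are
    # assumed here because ViT Q/K, V, and O share the same token layout.
    prefix = torch.zeros(seq_lens.shape[0] + 1, dtype=seq_lens.dtype, device=seq_lens.device)
    prefix[1:] = torch.cumsum(seq_lens, dim=0)
    return torch.cat([prefix, prefix, prefix])
```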

Technical Details

The cuDNN batch attention API requires:

| Parameter | Description |
| --- | --- |
| cu_seqlens | Split into 3 sections: Q/K offsets, V offsets, O offsets |
| sequence_lengths | Padded to batch size 8 |
| max_seqlen | Fixed at 128K for FlashInfer |
| workspace_buffer | 128 MB pre-allocated buffer |

cu_seqlens Format

  • cu_seqlens is 3x the normal length and is split into three equal sections:

```python
cu_seqlength = len(cu_seqlens) // 3
batch_offsets_qk = cu_seqlens[:cu_seqlength]                  # Q/K batch offsets
batch_offsets_v = cu_seqlens[cu_seqlength:cu_seqlength * 2]   # V batch offsets
batch_offsets_o = cu_seqlens[cu_seqlength * 2:]               # O batch offsets
```
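
For example, with toy numbers and assuming Q/K, V, and O share one token layout (so all three sections carry the same prefix sums), the split recovers:

```python
import torch

seq_lens = torch.tensor([4, 6])                      # two image sequences
prefix = torch.cat([torch.zeros(1, dtype=seq_lens.dtype), seq_lens.cumsum(0)])
cu_seqlens = torch.cat([prefix, prefix, prefix])     # length 9 == 3 * (batch + 1)

cu_seqlength = len(cu_seqlens) // 3                  # 3
print(cu_seqlens[:cu_seqlength])                     # tensor([ 0,  4, 10])  Q/K offsets
print(cu_seqlens[cu_seqlength:cu_seqlength * 2])     # tensor([ 0,  4, 10])  V offsets
print(cu_seqlens[cu_seqlength * 2:])                 # tensor([ 0,  4, 10])  O offsets
```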

@maxyanghu maxyanghu self-assigned this Jan 23, 2026
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@maxyanghu maxyanghu marked this pull request as draft January 23, 2026 22:19
@maxyanghu maxyanghu force-pushed the vit-attn-cudnn-backend branch from 3372e3a to 28d139b Compare January 26, 2026 00:26
@maxyanghu maxyanghu changed the base branch from main to mlperf-inf-mm-q3vl-v6.0 January 26, 2026 00:26
@maxyanghu maxyanghu marked this pull request as ready for review January 26, 2026 00:27
@maxyanghu maxyanghu force-pushed the vit-attn-cudnn-backend branch from 28d139b to 47af3e1 Compare January 26, 2026 00:34
@wangshangsam wangshangsam added the enhancement New feature or request label Jan 26, 2026
zhandaz previously approved these changes Jan 26, 2026

zhandaz left a comment


Overall LGTM.

The Dockerfile could be polished. Please run a docker image build and an end-to-end functionality run; then it is good to be merged!


I am not quite familiar with the flashinfer build. Please correct me if I'm wrong.

  1. I found there is a script tool/flashinfer-build.sh. While the usage may differ, I feel we should also set FI_TORCH_CUDA_ARCH_LIST in our case, which could reduce our image build time a lot.
  2. python3 is used throughout the rest of this Dockerfile. Let's switch to python3 for consistency; python alone may not even work.

f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)

workspace_buffer = (


I vaguely remember that this is specified through VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE (it was you who introduced this env var)?

maxyanghu (Author)


VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE controls the text model's FlashInfer workspace buffer size. It's different from this one.


wangshangsam Jan 26, 2026


Should the ViT FI workspace size be specifiable via (a different) env var too? Do you imagine this number to be different on different GPUs?


For fprop this is count_of_ragged_tensors (Q, K, V = 3 in this case) * batch_size * sizeof(TMADescriptor) (128 B) + 16 bytes for alignment + 4 bytes for tile size. So 128 MB should be more than sufficient across architectures.
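
Plugging numbers into that formula (the batch size below is chosen arbitrarily for illustration) shows how much headroom the 128 MB buffer leaves:

```python
ragged_tensors = 3           # Q, K, V
batch_size = 1024            # assumed batch size, purely illustrative
tma_descriptor_bytes = 128   # sizeof(TMADescriptor)

workspace_bytes = ragged_tensors * batch_size * tma_descriptor_bytes + 16 + 4
print(workspace_bytes)                # 393236 bytes, roughly 0.38 MB
print(workspace_bytes / (128 << 20))  # ~0.003 of the pre-allocated 128 MB buffer
```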

```diff
 self.num_heads = num_heads
 self.head_size = head_size
-self.scale = scale
+self.scale = 1.0 / (head_size**0.5) if scale is None else scale
```


What is this default scale factor based on?


zhandaz commented Jan 26, 2026

@maxyanghu One more thing: flashinfer-python==0.5.3 is specified in requirements/cuda.txt and is installed during the docker build. While I feel it won't have any effect, since our built version should also be 0.5.3 and that install may be skipped, please double-check.


b-mu commented Jan 26, 2026

I think we should document that once the upcoming cuDNN FE release is available, those hard-coded max_seqlen and padding values should be removed.
