[Disagg][Perf] Use CUDA event sync instead of blocking tolist to av… #31 (Open)

MitchLewis930 wants to merge 1 commit into sample_token_ids_before from sample_token_ids_after

Conversation


@MitchLewis930 MitchLewis930 commented Jan 24, 2026

test

Summary by CodeRabbit

  • Refactor
    • Improved efficiency of GPU-to-CPU memory transfers for token processing through optimized synchronization mechanisms.


[Disagg][Perf] Use CUDA event sync instead of blocking tolist to avoid unintentional copy ops blocking across different CUDA streams, improving disagg TTIT/TTFT (vllm-project#22760)

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Zijing Liu <liuzijing2014@users.noreply.github.com>
@coderabbitai

coderabbitai bot commented Jan 24, 2026

📝 Walkthrough

Added GPU-to-CPU transfer optimization in GPUModelRunner using pinned memory allocation and CUDA event synchronization. Replaced direct tensor-to-list conversion with a new helper method that uses non-blocking copy for sampled token IDs.

Changes

GPU-to-CPU Transfer Optimization — vllm/v1/worker/gpu_model_runner.py
Added transfer_event (CUDA event) and sampled_token_ids_pinned_cpu (pinned CPU tensor) instance attributes. Introduced a _to_list() helper method that performs a non-blocking copy into the pinned buffer, records and waits on the CUDA event for synchronization, and returns the data as a Python list. Replaced the .tolist() call in execute_model with the new _to_list() method.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐰 A bunny's hack for speed so keen,
GPU to CPU, pinned and clean,
Events that sync, no blocking pain,
Async transfers through the lane,
Faster tokens flow like rabbit's mane!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 33.33%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title directly relates to the main change: replacing a blocking tolist() call with CUDA event synchronization for a performance improvement in the GPU model runner.




@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@vllm/v1/worker/gpu_model_runner.py`:
- Around line 319-324: The pinned CPU buffer self.sampled_token_ids_pinned_cpu
is allocated with the wrong first dimension: it uses max_model_len but should be
sized to hold sampled_token_ids which are shaped (num_reqs, 1) and bounded by
max_num_reqs; fix the allocation in GPUModelRunner (where
self.sampled_token_ids_pinned_cpu is created) to use (self.max_num_reqs, 1) (or
max_num_reqs) instead of (self.max_model_len, 1) and keep dtype=torch.int64,
device="cpu", pin_memory=True so the transfer_event-based copy of
sampled_token_ids fits without waste or overflow.
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6fad29b and b395b3b.

📒 Files selected for processing (1)
  • vllm/v1/worker/gpu_model_runner.py
🔇 Additional comments (2)
vllm/v1/worker/gpu_model_runner.py (2)

1698-1700: LGTM!

The integration of the new _to_list helper method is correct. It's appropriately used only in the non-speculative decode path where max_gen_len == 1.


3243-3256: Good optimization approach.

The CUDA event synchronization pattern correctly avoids the stream-wide sync that direct .tolist() on a GPU tensor would cause. The comment clearly explains the motivation and references the relevant issue.


Comment on lines +319 to +324
self.transfer_event = torch.cuda.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
    (self.max_model_len, 1),
    dtype=torch.int64,
    device="cpu",
    pin_memory=True)

⚠️ Potential issue | 🔴 Critical

Incorrect buffer dimension: should use max_num_reqs instead of max_model_len.

The pinned buffer is used to copy sampled_token_ids which has shape (num_reqs, 1) where the first dimension is bounded by max_num_reqs, not max_model_len. Using max_model_len is semantically incorrect and can cause:

  1. Memory waste: max_model_len can be 128K+ while max_num_reqs is typically ~256
  2. Potential buffer overflow if max_num_reqs > max_model_len (edge case with short context models)
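To make the sizing argument concrete, here is an illustrative sketch; the value of max_num_reqs is an assumed typical bound, not a number taken from this PR, and pinning is omitted so the snippet runs without a GPU.

```python
import torch

# The buffer holds one sampled token id per in-flight request, so its first
# dimension should be bounded by the request cap (max_num_reqs), not by the
# context length (max_model_len). max_num_reqs = 256 is an assumed value.
max_num_reqs = 256
pinned = torch.empty((max_num_reqs, 1), dtype=torch.int64)  # add pin_memory=True on CUDA builds

# A maximally full batch of sampled ids fits exactly, independent of how long
# the model's context window is.
sampled = torch.arange(max_num_reqs, dtype=torch.int64).unsqueeze(1)
pinned[:sampled.shape[0]].copy_(sampled)
print(tuple(pinned.shape))  # (256, 1)
```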
🐛 Proposed fix
         self.transfer_event = torch.cuda.Event()
         self.sampled_token_ids_pinned_cpu = torch.empty(
-            (self.max_model_len, 1),
+            (self.max_num_reqs, 1),
             dtype=torch.int64,
             device="cpu",
             pin_memory=True)
