Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions vllm/config/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,15 +521,11 @@ def __post_init__(self) -> None:
current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size
):
if not ray_found:
raise ValueError(
"Unable to load Ray: "
f"{ray_utils.ray_import_err}. Ray is "
"required for multi-node inference, "
"please install Ray with `pip install "
"ray`."
)
backend = "ray"
gpu_count = cuda_device_count_stateless()
raise ValueError(
f"Tensor parallel size ({self.world_size}) cannot be "
f"larger than the number of available GPUs ({gpu_count})."
)
elif self.data_parallel_backend == "ray":
logger.info(
"Using ray distributed inference because "
Expand Down
50 changes: 44 additions & 6 deletions vllm/v1/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,33 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
try:
ray.get(pg_ready_ref, timeout=0)
except ray.exceptions.GetTimeoutError:
raise ValueError(
"Cannot provide a placement group of "
f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
"`ray status` and `ray list nodes` to make sure the cluster has "
"enough resources."
) from None
# Provide more helpful error message when GPU count is exceeded
total_gpu_required = sum(spec.get("GPU", 0) for spec in placement_group_specs)
# If more than one GPU is required for the placement group, provide a
# more specific error message.
# We use >1 here because multi-GPU (tensor parallel) jobs are more
# likely to fail due to insufficient cluster resources, and users may
# need to adjust tensor_parallel_size to fit available GPUs.
if total_gpu_required > 1:
raise ValueError(
f"Cannot provide a placement group requiring "
f"{total_gpu_required} GPUs "
f"(placement_group_specs={placement_group_specs}) within "
f"{PG_WAIT_TIMEOUT} seconds.\n"
f"Tensor parallel size may exceed available GPUs in your "
f"cluster. Check resources with `ray status` and "
f"`ray list nodes`.\n"
f"If running on K8s with limited GPUs, consider reducing "
f"--tensor-parallel-size to match available GPU resources."
) from None
else:
raise ValueError(
"Cannot provide a placement group of "
f"{placement_group_specs=} within "
f"{PG_WAIT_TIMEOUT} seconds. See "
"`ray status` and `ray list nodes` to make sure the cluster "
"has enough resources."
) from None


def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
Expand Down Expand Up @@ -299,6 +320,23 @@ def initialize_ray_cluster(
assert_ray_available()
from vllm.platforms import current_platform

# Prevalidate GPU requirements before Ray processing
if current_platform.is_cuda() and parallel_config.world_size > 1:
from vllm.utils import cuda_device_count_stateless

available_gpus = cuda_device_count_stateless()
if parallel_config.world_size > available_gpus:
logger.warning(
"Tensor parallel size (%d) exceeds available GPUs (%d). "
"This may result in Ray placement group allocation failures. "
"Consider reducing tensor_parallel_size to %d or less, "
"or ensure your Ray cluster has %d GPUs available.",
parallel_config.world_size,
available_gpus,
available_gpus,
parallel_config.world_size,
)
Comment on lines +323 to +338
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Check if cuda_device_count_stateless is exported from vllm.utils
rg -n "cuda_device_count_stateless" --type py vllm/utils/ | head -20

Repository: Signal65/vllm-code-review

Length of output: 349


🏁 Script executed:

# Check the parallel_config definition to understand world_size
fd -e py -p "parallel_config" vllm/ | head -5

Repository: Signal65/vllm-code-review

Length of output: 51


🏁 Script executed:

# Search for world_size definition in parallel_config
rg -n "world_size" --type py vllm/ -A 2 | grep -E "(def|world_size|=)" | head -30

Repository: Signal65/vllm-code-review

Length of output: 2699


🏁 Script executed:

# Check the specific file mentioned in the review
cat -n vllm/v1/executor/ray_utils.py | sed -n '320,340p'

Repository: Signal65/vllm-code-review

Length of output: 1108


🏁 Script executed:

# Check if cuda_device_count_stateless is re-exported from vllm/utils/__init__.py
rg -n "cuda_device_count_stateless" vllm/utils/__init__.py

Repository: Signal65/vllm-code-review

Length of output: 51


🏁 Script executed:

# Find the ParallelConfig class definition
fd -e py "parallel_config" vllm/ -x grep -l "class.*ParallelConfig"

Repository: Signal65/vllm-code-review

Length of output: 51


🏁 Script executed:

# Search for world_size property or attribute definition in parallel config
rg -n "def world_size|world_size\s*=" --type py vllm/ -B 2 -A 2 | head -50

Repository: Signal65/vllm-code-review

Length of output: 3976


🏁 Script executed:

# Search for ParallelConfig class definition more broadly
rg -n "class ParallelConfig" --type py vllm/

Repository: Signal65/vllm-code-review

Length of output: 117


🏁 Script executed:

# Check if cuda_device_count_stateless is available via direct import from torch_utils
cat vllm/utils/torch_utils.py | grep -n "def cuda_device_count_stateless" -A 10

Repository: Signal65/vllm-code-review

Length of output: 628


🏁 Script executed:

# Search vllm/utils/__init__.py to see what is exported
head -100 vllm/utils/__init__.py

Repository: Signal65/vllm-code-review

Length of output: 3202


🏁 Script executed:

# Check how parallel_config.world_size is used/defined
rg -n "parallel_config\.world_size" --type py vllm/ -B 1 -A 1 | head -40

Repository: Signal65/vllm-code-review

Length of output: 2649


🏁 Script executed:

# Check the ParallelConfig class definition to see world_size
sed -n '68,200p' vllm/config/parallel.py

Repository: Signal65/vllm-code-review

Length of output: 6774


🏁 Script executed:

# Verify the import path - try to find if cuda_device_count_stateless is actually accessible via vllm.utils
rg -n "from vllm\.utils import cuda_device_count_stateless" --type py vllm/

Repository: Signal65/vllm-code-review

Length of output: 161


🏁 Script executed:

# Check what's actually exported from vllm/utils or torch_utils
rg -n "__all__" vllm/utils/torch_utils.py

Repository: Signal65/vllm-code-review

Length of output: 51


🏁 Script executed:

# Search for world_size as a property or method in ParallelConfig
rg -n "def world_size|@property" vllm/config/parallel.py -A 3 | head -50

Repository: Signal65/vllm-code-review

Length of output: 637


🏁 Script executed:

# Check if there's a __getattr__ or similar in ParallelConfig
rg -n "__getattr__|world_size" vllm/config/parallel.py

Repository: Signal65/vllm-code-review

Length of output: 982


🏁 Script executed:

# Look at the full ParallelConfig class to find world_size
cat vllm/config/parallel.py | grep -n "world_size" -B 2 -A 2

Repository: Signal65/vllm-code-review

Length of output: 2298


🏁 Script executed:

# Check if cuda_device_count_stateless might be re-exported dynamically or if the import works
python3 << 'EOF'
import sys
try:
    from vllm.utils import cuda_device_count_stateless
    print("✓ Import SUCCESS: cuda_device_count_stateless is accessible from vllm.utils")
except ImportError as e:
    print(f"✗ Import FAILED: {e}")
except Exception as e:
    print(f"✗ Unexpected error: {e}")
EOF

Repository: Signal65/vllm-code-review

Length of output: 155


🏁 Script executed:

# Check the imports at the top of vllm/config/parallel.py to see how cuda_device_count_stateless is imported
head -50 vllm/config/parallel.py | grep -E "^from|^import"

Repository: Signal65/vllm-code-review

Length of output: 647


Fix incorrect import path and misleading warning message.

Two issues:

  1. Import path is incorrect (line 325): The import `from vllm.utils import cuda_device_count_stateless` will fail. It should be `from vllm.utils.torch_utils import cuda_device_count_stateless` to match how it is imported elsewhere (e.g., in `vllm/config/parallel.py`).

  2. Warning message is misleading: `parallel_config.world_size` is defined as `tensor_parallel_size × pipeline_parallel_size`, not just the tensor parallel size. The warning says "Tensor parallel size" but should clarify that it is reporting the world size.

Suggested fix
    if current_platform.is_cuda() and parallel_config.world_size > 1:
-       from vllm.utils import cuda_device_count_stateless
+       from vllm.utils.torch_utils import cuda_device_count_stateless
 
         available_gpus = cuda_device_count_stateless()
         if parallel_config.world_size > available_gpus:
             logger.warning(
-                "Tensor parallel size (%d) exceeds available GPUs (%d). "
+                "Required GPU count (world_size=%d) exceeds available GPUs (%d). "
                 "This may result in Ray placement group allocation failures. "
-                "Consider reducing tensor_parallel_size to %d or less, "
+                "Consider reducing tensor_parallel_size or pipeline_parallel_size "
+                "so that their product is %d or less, "
                 "or ensure your Ray cluster has %d GPUs available.",
                 parallel_config.world_size,
                 available_gpus,
                 available_gpus,
                 parallel_config.world_size,
             )
🤖 Prompt for AI Agents
In `vllm/v1/executor/ray_utils.py` around lines 323 - 338, fix the incorrect
import and clarify the warning message. First, change the import of
`cuda_device_count_stateless` so it comes from `vllm.utils.torch_utils`
(replace `from vllm.utils import cuda_device_count_stateless` with the correct
module path). Second, update the `logger.warning` message in the prevalidation
block so it refers to "world_size (tensor_parallel_size ×
pipeline_parallel_size)" instead of "Tensor parallel size", so that it
accurately describes `parallel_config.world_size` when comparing it against
the available GPU count.


if ray.is_initialized():
logger.info("Ray is already initialized. Skipping Ray initialization.")
elif current_platform.is_rocm() or current_platform.is_xpu():
Expand Down