Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions tests/mq_llm_engine/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,27 @@ async def test_failed_abort(tmp_socket):
await client.check_health()

# Trigger an abort on the client side.
# This request ID does not exist, and will cause the engine to error
await client.abort(request_id="foo")
async def bad_abort_after_2s():
await asyncio.sleep(2.0)
await client.abort(request_id="foo")

# Future generation requests will now fail
# Trigger an abort in 2s from now.
abort_task = asyncio.create_task(bad_abort_after_2s())

# Exception in abort() will happen during this generation.
# This will kill the engine and should return ENGINE_DEAD_ERROR
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
sampling_params=SamplingParams(max_tokens=2000),
request_id=uuid.uuid4()):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored

await abort_task

# This should raise the original error.
with pytest.raises(RAISED_ERROR):
await client.check_health()
Expand Down
7 changes: 6 additions & 1 deletion vllm/engine/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class RPCAbortRequest:
request_id: str


class RPCHealthRequest:
pass


class RPCStartupRequest(Enum):
IS_SERVER_READY = 1

Expand All @@ -52,7 +56,8 @@ class RPCStartupResponse:
tracing_enabled: bool


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest]
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
RPCStartupRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]

Expand Down
51 changes: 29 additions & 22 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
Expand Down Expand Up @@ -94,9 +95,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")

# IPC path for acking heartbeats.
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for ack of check_health requests.
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
Expand All @@ -123,28 +124,34 @@ def get_data_socket(self) -> Iterator[Socket]:
finally:
socket.close(linger=0)

async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually listens to the RPCServer for
heartbeats.
async def run_check_health_loop(self, timeout: int):
"""Background loop that continually probes the RPCServer for health.

The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
the MQLLMEngine server is blocking on.

The Server replies on the HEALTH_SOCKET (rather than on the
OUTPUT_SOCKET such that the messages are not intermingled with
output streaming).
"""

try:
while True:
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
# No heartbeat was received. Set error and exit the loop
self._set_errored(
TimeoutError("No heartbeat received "
"from MQLLMEngine"))
logger.debug("Shutting down MQLLMEngineClient check "
"health loop due to timeout")
break

if await self.health_socket.poll(timeout=timeout) == 0:
# Wakeup every N seconds and do a health probe.
await self._send_one_way_rpc_request(
RPCHealthRequest(), self.input_socket)

# Wait for ack from the health socket.
await self._await_ack(error_message="Health check failed.",
socket=self.health_socket)
else:
# Heartbeat received- check the message
# Server sent a health status message unprompted.
await self._check_success(
error_message="Heartbeat failed.",
socket=self.heartbeat_socket)
error_message="Health check failed.",
socket=self.health_socket)

logger.debug("Heartbeat successful.")
logger.debug("Health probe successful.")

except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")
Expand Down Expand Up @@ -227,7 +234,7 @@ async def setup(self):

# Start health_loop.
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))

def close(self):
"""Destroy the ZeroMQ Context."""
Expand Down
77 changes: 17 additions & 60 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import pickle
import signal
import threading
import time
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

Expand All @@ -17,10 +15,10 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
Expand Down Expand Up @@ -93,30 +91,16 @@ def __init__(self,
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

# Send heartbeats back to client.
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# Send health status back to client.
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

# Error state.
self._errored_with: Optional[BaseException] = None

# Heartbeat thread
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
daemon=True)
self._heartbeat_stop_event = threading.Event()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0

self._last_alive_time = time.time()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0

@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
Expand Down Expand Up @@ -147,8 +131,6 @@ def start(self):
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting heartbeat thread")
self.heartbeat_thread.start()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
Expand All @@ -162,7 +144,6 @@ def start(self):
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self._heartbeat_stop_event.set()
self.ctx.destroy(linger=0)
del self.engine

Expand Down Expand Up @@ -201,11 +182,9 @@ def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""

while True:
self._alive()
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self._alive()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")

Expand All @@ -221,6 +200,7 @@ def run_engine_loop(self):

def engine_step(self) -> List[RequestOutput]:
"""Engine step wrapper with error handling."""

try:
return self.engine.step()
except SystemExit:
Expand Down Expand Up @@ -249,9 +229,10 @@ def handle_new_input(self):
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCHealthRequest):
self._handle_health_request()
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
raise ValueError("Unknown RPCRequest Type: {request}")

except Exception as e:
self._set_errored(e)
Expand Down Expand Up @@ -298,32 +279,13 @@ def _handle_abort_request(self, request: RPCAbortRequest):
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)

def _heartbeat_loop(self):
while not self._heartbeat_stop_event.wait(
timeout=self.heartbeat_interval_seconds):
# Loops until the stop event is set
self._heartbeat()

logger.debug("Exiting MQLLMEngine heartbeat thread")

def _heartbeat(self):
# Send unhealthy if engine has already errored
def _handle_health_request(self):
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)

# Check for life of the main loop
elif time.time() - self._last_alive_time > self.last_alive_threshold:
self._send_unhealthy(RuntimeError("Engine loop has died"))

else:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
Comment on lines +282 to +288
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Return after reporting unhealthy state.

When _errored_with is set, the method sends UNHEALTHY but then still runs check_health() and can send HEALTHY, which can mask failures.

🔧 Proposed fix
 def _handle_health_request(self):
     if self._errored_with is not None:
         self._send_unhealthy(self._errored_with)
+        return
 
     # Raises error if unhealthy.
     self.engine.check_health()
     self._send_healthy()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _handle_health_request(self):
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)
# Check for life of the main loop
elif time.time() - self._last_alive_time > self.last_alive_threshold:
self._send_unhealthy(RuntimeError("Engine loop has died"))
else:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
def _handle_health_request(self):
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)
return
# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
🤖 Prompt for AI Agents
In `@vllm/engine/multiprocessing/engine.py` around lines 282 - 288, In
_handle_health_request, when self._errored_with is set you currently call
self._send_unhealthy(self._errored_with) but then continue to run
self.engine.check_health() and self._send_healthy(), which can mask the error;
fix by returning immediately after calling
self._send_unhealthy(self._errored_with) so that _handle_health_request does not
invoke engine.check_health() or _send_healthy in the errored path (update the
_handle_health_request method to short-circuit on self._errored_with).


def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
Expand All @@ -333,14 +295,12 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):

def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
error_bytes = pickle.dumps(error)
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
error_bytes = pickle.dumps(error)
self.health_socket.send_multipart((error_bytes, ), copy=False)

def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
Expand All @@ -353,9 +313,6 @@ def _set_errored(self, e: BaseException):
if self._errored_with is None:
self._errored_with = e

def _alive(self):
self._last_alive_time = time.time()


def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str):
Expand Down