diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 76b2f494d5b2..49cfc5aa04c3 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -153,20 +153,27 @@ async def test_failed_abort(tmp_socket): await client.check_health() # Trigger an abort on the client side. - # This request ID does not exist, and will cause the engine to error - await client.abort(request_id="foo") + async def bad_abort_after_2s(): + await asyncio.sleep(2.0) + await client.abort(request_id="foo") - # Future generation requests will now fail + # Trigger an abort in 2s from now. + abort_task = asyncio.create_task(bad_abort_after_2s()) + + # Exception in abort() will happen during this generation. + # This will kill the engine and should return ENGINE_DEAD_ERROR # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( inputs="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), + sampling_params=SamplingParams(max_tokens=2000), request_id=uuid.uuid4()): pass assert "KeyError" in repr(execinfo.value) assert client.errored + await abort_task + # This should raise the original error. 
with pytest.raises(RAISED_ERROR): await client.check_health() diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 165e6cc2146c..700332864d17 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -43,6 +43,10 @@ class RPCAbortRequest: request_id: str +class RPCHealthRequest: + pass + + class RPCStartupRequest(Enum): IS_SERVER_READY = 1 @@ -52,7 +56,8 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest] +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest, + RPCStartupRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 7e397cf408fb..aa9dbbd448af 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -20,8 +20,9 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCProcessRequest, - RPCStartupRequest, RPCStartupResponse) + RPCError, RPCHealthRequest, + RPCProcessRequest, RPCStartupRequest, + RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptInputs @@ -94,9 +95,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig): self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") - # IPC path for acking heartbeats. - self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) - self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + # IPC path for ack of check_health requests. + self.health_socket: Socket = self.context.socket(zmq.constants.PULL) + self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. 
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -123,28 +124,34 @@ def get_data_socket(self) -> Iterator[Socket]: finally: socket.close(linger=0) - async def run_heartbeat_loop(self, timeout: int): - """Background loop that continually listens to the RPCServer for - heartbeats. + async def run_check_health_loop(self, timeout: int): + """Background loop that continually probes the RPCServer for health. + + The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which + the MQLLMEngine server is blocking on. + + The Server replies on the HEALTH_SOCKET (rather than on the + OUTPUT_SOCKET such that the messages are not intermingled with + output streaming). """ + try: while True: - if await self.heartbeat_socket.poll(timeout=timeout) == 0: - # No heartbeat was received. Set error and exit the loop - self._set_errored( - TimeoutError("No heartbeat received " - "from MQLLMEngine")) - logger.debug("Shutting down MQLLMEngineClient check " - "health loop due to timeout") - break - + if await self.health_socket.poll(timeout=timeout) == 0: + # Wakeup every N seconds and do a health probe. + await self._send_one_way_rpc_request( + RPCHealthRequest(), self.input_socket) + + # Wait for ack from the health socket. + await self._await_ack(error_message="Health check failed.", + socket=self.health_socket) else: - # Heartbeat received- check the message + # Server sent a health status message unprompted. await self._check_success( - error_message="Heartbeat failed.", - socket=self.heartbeat_socket) + error_message="Health check failed.", + socket=self.health_socket) - logger.debug("Heartbeat successful.") + logger.debug("Health probe successful.") except asyncio.CancelledError: logger.debug("Shutting down MQLLMEngineClient check health loop.") @@ -227,7 +234,7 @@ async def setup(self): # Start health_loop. 
self.health_loop = asyncio.create_task( - self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) + self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) def close(self): """Destroy the ZeroMQ Context.""" diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index b1dd9915cbbf..485db0bab129 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,7 +1,5 @@ import pickle import signal -import threading -import time from contextlib import contextmanager from typing import Iterator, List, Optional, Union @@ -17,10 +15,10 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCProcessRequest, - RPCStartupRequest, RPCStartupResponse) + RPCError, RPCHealthRequest, + RPCProcessRequest, RPCStartupRequest, + RPCStartupResponse) # yapf: enable -from vllm.envs import VLLM_RPC_TIMEOUT from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -93,9 +91,9 @@ def __init__(self, self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") - # Send heartbeats back to client. - self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) - self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + # Send health status back to client. + self.health_socket = self.ctx.socket(zmq.constants.PUSH) + self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -103,20 +101,6 @@ def __init__(self, # Error state. 
self._errored_with: Optional[BaseException] = None - # Heartbeat thread - self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, - daemon=True) - self._heartbeat_stop_event = threading.Event() - # The heartbeat needs to be faster than what the client will wait for - # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds - self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0 - - self._last_alive_time = time.time() - # The heartbeats can tolerate a long period of the engine chugging - # away at a generation request. - # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds - self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0 - @property def dead_error(self) -> BaseException: if self._errored_with is not None: @@ -147,8 +131,6 @@ def start(self): try: logger.debug("Starting Startup Loop.") self.run_startup_loop() - logger.debug("Starting heartbeat thread") - self.heartbeat_thread.start() logger.debug("Starting Engine Loop.") self.run_engine_loop() except Exception as e: @@ -162,7 +144,6 @@ def start(self): def cleanup(self): """Cleanup zeromq state on shutdown.""" # Closes all sockets and destroys context. - self._heartbeat_stop_event.set() self.ctx.destroy(linger=0) del self.engine @@ -201,11 +182,9 @@ def run_engine_loop(self): """Core busy loop of the LLMEngine.""" while True: - self._alive() if not self.engine.has_unfinished_requests(): # Poll until there is work to do. 
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - self._alive() self.engine.do_log_stats() logger.debug("Waiting for new requests in engine loop.") @@ -221,6 +200,7 @@ def run_engine_loop(self): def engine_step(self) -> List[RequestOutput]: """Engine step wrapper with error handling.""" + try: return self.engine.step() except SystemExit: @@ -249,9 +229,10 @@ def handle_new_input(self): self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) + elif isinstance(request, RPCHealthRequest): + self._handle_health_request() else: - raise ValueError("Unknown RPCRequest Type: " - f"{type(request)}") + raise ValueError(f"Unknown RPCRequest Type: {request}") except Exception as e: self._set_errored(e) @@ -298,32 +279,14 @@ def _handle_abort_request(self, request: RPCAbortRequest): if self.log_requests: logger.info("Aborted request %s.", request.request_id) - def _heartbeat_loop(self): - while not self._heartbeat_stop_event.wait( - timeout=self.heartbeat_interval_seconds): - # Loops until the stop event is set - self._heartbeat() - - logger.debug("Exiting MQLLMEngine heartbeat thread") - - def _heartbeat(self): - # Send unhealthy if engine has already errored + def _handle_health_request(self): if self._errored_with is not None: self._send_unhealthy(self._errored_with) + return - # Check for life of the main loop - elif time.time() - self._last_alive_time > self.last_alive_threshold: - self._send_unhealthy(RuntimeError("Engine loop has died")) - - else: - # Otherwise- check health of the engine - # self.engine.check_health() raises on unhealthy - try: - self.engine.check_health() - self._send_healthy() - except Exception as e: - self._set_errored(e) - self._send_unhealthy(e) + # Raises error if unhealthy.
+ self.engine.check_health() + self._send_healthy() def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" @@ -333,14 +296,12 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): def _send_healthy(self): """Send HEALTHY message to RPCClient.""" - if not self.heartbeat_socket.closed: - self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) def _send_unhealthy(self, error: BaseException): """Send UNHEALTHY message to RPCClient.""" - if not self.heartbeat_socket.closed: - error_bytes = pickle.dumps(error) - self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) + error_bytes = pickle.dumps(error) + self.health_socket.send_multipart((error_bytes, ), copy=False) def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): @@ -353,9 +314,6 @@ def _set_errored(self, e: BaseException): if self._errored_with is None: self._errored_with = e - def _alive(self): - self._last_alive_time = time.time() - def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str):