diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py
index 547c1fd02092..4616f363cc04 100644
--- a/tests/entrypoints/openai/test_basic.py
+++ b/tests/entrypoints/openai/test_basic.py
@@ -1,8 +1,6 @@
-import asyncio
 from http import HTTPStatus
 from typing import List

-import openai
 import pytest
 import pytest_asyncio
 import requests
@@ -105,52 +103,3 @@ async def test_check_health(server: RemoteOpenAIServer):
     response = requests.get(server.url_for("health"))

     assert response.status_code == HTTPStatus.OK
-
-
-@pytest.mark.parametrize(
-    "server_args",
-    [
-        pytest.param(["--max-model-len", "10100"],
-                     id="default-frontend-multiprocessing"),
-        pytest.param(
-            ["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
-            id="disable-frontend-multiprocessing")
-    ],
-    indirect=True,
-)
-@pytest.mark.asyncio
-async def test_request_cancellation(server: RemoteOpenAIServer):
-    # clunky test: send an ungodly amount of load in with short timeouts
-    # then ensure that it still responds quickly afterwards
-
-    chat_input = [{"role": "user", "content": "Write a long story"}]
-    client = server.get_async_client(timeout=0.5)
-    tasks = []
-    # Request about 2 million tokens
-    for _ in range(200):
-        task = asyncio.create_task(
-            client.chat.completions.create(messages=chat_input,
-                                           model=MODEL_NAME,
-                                           max_tokens=10000,
-                                           extra_body={"min_tokens": 10000}))
-        tasks.append(task)
-
-    done, pending = await asyncio.wait(tasks,
-                                       return_when=asyncio.ALL_COMPLETED)
-
-    # Make sure all requests were sent to the server and timed out
-    # (We don't want to hide other errors like 400s that would invalidate this
-    # test)
-    assert len(pending) == 0
-    for d in done:
-        with pytest.raises(openai.APITimeoutError):
-            d.result()
-
-    # If the server had not cancelled all the other requests, then it would not
-    # be able to respond to this one within the timeout
-    client = server.get_async_client(timeout=5)
-    response = await client.chat.completions.create(messages=chat_input,
-                                                     model=MODEL_NAME,
-                                                     max_tokens=10)
-
-    assert len(response.choices) == 1
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 32a6b0aed66a..0bc9e5bc32a4 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,7 @@
 import asyncio
 import os
 import socket
+from functools import partial
 from typing import AsyncIterator, Tuple

 import pytest
@@ -25,7 +26,10 @@ async def mock_async_iterator(idx: int):
             print(f"iterator {idx} cancelled")

     iterators = [mock_async_iterator(i) for i in range(3)]
-    merged_iterator = merge_async_iterators(*iterators)
+    merged_iterator = merge_async_iterators(*iterators,
+                                            is_cancelled=partial(asyncio.sleep,
+                                                                 0,
+                                                                 result=False))

     async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
         async for idx, output in generator:
diff --git a/tests/utils.py b/tests/utils.py
index bf3d88194e4c..afeb708f3bcd 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -163,11 +163,12 @@ def get_client(self):
             api_key=self.DUMMY_API_KEY,
         )

-    def get_async_client(self, **kwargs):
-        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
-                                  api_key=self.DUMMY_API_KEY,
-                                  max_retries=0,
-                                  **kwargs)
+    def get_async_client(self):
+        return openai.AsyncOpenAI(
+            base_url=self.url_for("v1"),
+            api_key=self.DUMMY_API_KEY,
+            max_retries=0,
+        )


 def _test_completion(
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index f50e20cf7032..32396fd10188 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1065,20 +1065,16 @@ async def generate(
         >>> # Process and return the final output
         >>> ...
         """
-        try:
-            async for output in await self.add_request(
-                    request_id,
-                    prompt,
-                    sampling_params,
-                    lora_request=lora_request,
-                    trace_headers=trace_headers,
-                    prompt_adapter_request=prompt_adapter_request,
-                    priority=priority,
-            ):
-                yield LLMEngine.validate_output(output, RequestOutput)
-        except asyncio.CancelledError:
-            await self.abort(request_id)
-            raise
+        async for output in await self.add_request(
+                request_id,
+                prompt,
+                sampling_params,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+        ):
+            yield LLMEngine.validate_output(output, RequestOutput)

     async def encode(
         self,
@@ -1151,19 +1147,15 @@ async def encode(
         >>> # Process and return the final output
         >>> ...
         """
-        try:
-            async for output in await self.add_request(
-                    request_id,
-                    prompt,
-                    pooling_params,
-                    lora_request=lora_request,
-                    trace_headers=trace_headers,
-                    priority=priority,
-            ):
-                yield LLMEngine.validate_output(output, PoolingRequestOutput)
-        except asyncio.CancelledError:
-            await self.abort(request_id)
-            raise
+        async for output in await self.add_request(
+                request_id,
+                prompt,
+                pooling_params,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                priority=priority,
+        ):
+            yield LLMEngine.validate_output(output, PoolingRequestOutput)

     async def abort(self, request_id: str) -> None:
         """Abort a request.
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 95da1c6e7b9b..ea3c93f73303 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -17,11 +17,11 @@
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.launcher import serve_http
-from vllm.entrypoints.utils import with_cancellation
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, random_uuid
+from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
+                        random_uuid)
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger("vllm.entrypoints.api_server")
@@ -47,11 +47,6 @@ async def generate(request: Request) -> Response:
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
-    return await _generate(request_dict, raw_request=request)
-
-
-@with_cancellation
-async def _generate(request_dict: dict, raw_request: Request) -> Response:
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
@@ -59,6 +54,8 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:

     assert engine is not None
     results_generator = engine.generate(prompt, sampling_params, request_id)
+    results_generator = iterate_with_cancellation(
+        results_generator, is_cancelled=request.is_disconnected)

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 00e2d1a56f16..14e3a34ce141 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -59,7 +59,6 @@
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.utils import with_cancellation
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -312,7 +311,6 @@ async def health(raw_request: Request) -> Response:


 @router.post("/tokenize")
-@with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)

@@ -327,7 +325,6 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):


 @router.post("/detokenize")
-@with_cancellation
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)

@@ -356,7 +353,6 @@ async def show_version():


 @router.post("/v1/chat/completions")
-@with_cancellation
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -377,7 +373,6 @@ async def create_chat_completion(request: ChatCompletionRequest,


 @router.post("/v1/completions")
-@with_cancellation
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -395,7 +390,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):


 @router.post("/v1/embeddings")
-@with_cancellation
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -413,7 +407,6 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):


 @router.post("/score")
-@with_cancellation
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -431,7 +424,6 @@ async def create_score(request: ScoreRequest, raw_request: Request):


 @router.post("/v1/score")
-@with_cancellation
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 81bce0dd370b..527418c63509 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -32,6 +32,7 @@
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
+from vllm.utils import iterate_with_cancellation

 logger = init_logger(__name__)

@@ -233,6 +234,10 @@ async def create_chat_completion(
         assert len(generators) == 1
         result_generator, = generators

+        if raw_request:
+            result_generator = iterate_with_cancellation(
+                result_generator, raw_request.is_disconnected)
+
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 5cf9df92e296..bd39a4c42e93 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -159,7 +159,8 @@ async def create_completion(
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

-        result_generator = merge_async_iterators(*generators)
+        result_generator = merge_async_iterators(
+            *generators, is_cancelled=raw_request.is_disconnected)

         model_name = self._get_model_name(lora_request)
         num_prompts = len(engine_prompts)
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 879276646d2b..fd501ad4f833 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -202,7 +202,10 @@ async def create_embedding(
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

-        result_generator = merge_async_iterators(*generators)
+        result_generator = merge_async_iterators(
+            *generators,
+            is_cancelled=raw_request.is_disconnected if raw_request else None,
+        )

         num_prompts = len(engine_prompts)

diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 101d170bee4d..6f5cc14ac37c 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -186,7 +186,10 @@ async def create_score(
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

-        result_generator = merge_async_iterators(*generators)
+        result_generator = merge_async_iterators(
+            *generators,
+            is_cancelled=raw_request.is_disconnected if raw_request else None,
+        )

         num_prompts = len(engine_prompts)

diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
deleted file mode 100644
index e8a78d216d0f..000000000000
--- a/vllm/entrypoints/utils.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import asyncio
-import functools
-
-from fastapi import Request
-
-
-async def listen_for_disconnect(request: Request) -> None:
-    """Returns if a disconnect message is received"""
-    while True:
-        message = await request.receive()
-        if message["type"] == "http.disconnect":
-            break
-
-
-def with_cancellation(handler_func):
-    """Decorator that allows a route handler to be cancelled by client
-    disconnections.
-
-    This does _not_ use request.is_disconnected, which does not work with
-    middleware. Instead this follows the pattern from
-    starlette.StreamingResponse, which simultaneously awaits on two tasks- one
-    to wait for an http disconnect message, and the other to do the work that we
-    want done. When the first task finishes, the other is cancelled.
-
-    A core assumption of this method is that the body of the request has already
-    been read. This is a safe assumption to make for fastapi handlers that have
-    already parsed the body of the request into a pydantic model for us.
-    This decorator is unsafe to use elsewhere, as it will consume and throw away
-    all incoming messages for the request while it looks for a disconnect
-    message.
-
-    In the case where a `StreamingResponse` is returned by the handler, this
-    wrapper will stop listening for disconnects and instead the response object
-    will start listening for disconnects.
-    """
-
-    # Functools.wraps is required for this wrapper to appear to fastapi as a
-    # normal route handler, with the correct request type hinting.
-    @functools.wraps(handler_func)
-    async def wrapper(*args, **kwargs):
-
-        # The request is either the second positional arg or `raw_request`
-        request = args[1] if len(args) > 1 else kwargs["raw_request"]
-
-        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
-        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
-
-        done, pending = await asyncio.wait([handler_task, cancellation_task],
-                                           return_when=asyncio.FIRST_COMPLETED)
-        for task in pending:
-            task.cancel()
-
-        if handler_task in done:
-            return handler_task.result()
-        return None
-
-    return wrapper
diff --git a/vllm/utils.py b/vllm/utils.py
index 38c7dea6d2d3..73d2ae25f15c 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -20,7 +20,7 @@
 import uuid
 import warnings
 import weakref
-from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
+from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
 from collections import UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
@@ -370,23 +370,72 @@ def _next_task(iterator: AsyncGenerator[T, None],
     return loop.create_task(iterator.__anext__())  # type: ignore[arg-type]


+async def iterate_with_cancellation(
+    iterator: AsyncGenerator[T, None],
+    is_cancelled: Callable[[], Awaitable[bool]],
+) -> AsyncGenerator[T, None]:
+    """Convert async iterator into one that polls the provided function
+    at least once per second to check for client cancellation.
+    """
+
+    loop = asyncio.get_running_loop()
+
+    awaits: List[Future[T]] = [_next_task(iterator, loop)]
+    next_cancel_check: float = 0
+    while True:
+        done, pending = await asyncio.wait(awaits, timeout=1.5)
+
+        # Check for cancellation at most once per second
+        time_now = time.time()
+        if time_now >= next_cancel_check:
+            if await is_cancelled():
+                with contextlib.suppress(BaseException):
+                    awaits[0].cancel()
+                    await iterator.aclose()
+                raise asyncio.CancelledError("client cancelled")
+            next_cancel_check = time_now + 1
+
+        if done:
+            try:
+                item = await awaits[0]
+                awaits[0] = _next_task(iterator, loop)
+                yield item
+            except StopAsyncIteration:
+                # we are done
+                return
+
+
 async def merge_async_iterators(
-        *iterators: AsyncGenerator[T,
-                                   None], ) -> AsyncGenerator[Tuple[int, T], None]:
+    *iterators: AsyncGenerator[T, None],
+    is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
+) -> AsyncGenerator[Tuple[int, T], None]:
     """Merge multiple asynchronous iterators into a single iterator.

     This method handle the case where some iterators finish before others.
     When it yields, it yields a tuple (i, item) where i is the index of the
     iterator that yields the item.
+
+    It also optionally polls a provided function at least once per second
+    to check for client cancellation.
     """

     loop = asyncio.get_running_loop()
     awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
+    timeout = None if is_cancelled is None else 1.5
+    next_cancel_check: float = 0
     try:
         while awaits:
-            done, _ = await asyncio.wait(awaits.keys(),
-                                         return_when=FIRST_COMPLETED)
+            done, pending = await asyncio.wait(awaits.keys(),
+                                               return_when=FIRST_COMPLETED,
+                                               timeout=timeout)
+            if is_cancelled is not None:
+                # Check for cancellation at most once per second
+                time_now = time.time()
+                if time_now >= next_cancel_check:
+                    if await is_cancelled():
+                        raise asyncio.CancelledError("client cancelled")
+                    next_cancel_check = time_now + 1
             for d in done:
                 pair = awaits.pop(d)
                 try: