diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b04811..72943ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,14 +3,15 @@ name: CI on: pull_request: push: - branches: [ main ] + branches: + - main jobs: - lint-test-docs: + lint-test: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ['3.10', '3.13', '3.14'] + python-version: ['3.10', '3.14'] os: [ubuntu-latest, macos-latest, windows-latest] steps: @@ -38,35 +39,3 @@ jobs: with: name: coverage-html-report-${{ matrix.os }}-${{ matrix.python-version }} path: coverage_html_report/ - - deploy-docs: - runs-on: ubuntu-latest - needs: lint-test-docs - # TODO(preview): deploy on releases only - permissions: - contents: read - pages: write - id-token: write - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v6 - with: - python-version: '3.10' - - - name: Install dependencies - run: uv sync - - - name: Build API docs - run: uv run poe docs - - - name: Upload docs to GitHub Pages - uses: actions/upload-pages-artifact@v3 - with: - path: apidocs - - - name: Deploy to GitHub Pages - uses: actions/deploy-pages@v4 diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml new file mode 100644 index 0000000..ddff9d6 --- /dev/null +++ b/.github/workflows/deploy-docs.yml @@ -0,0 +1,37 @@ +name: Deploy Docs + +on: + push: + branches: + - main + +jobs: + deploy-docs: + runs-on: ubuntu-latest + permissions: + contents: read + pages: write + id-token: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + python-version: '3.10' + + - name: Install dependencies + run: uv sync + + - name: Build API docs + run: uv run poe docs + + - name: Upload docs to GitHub Pages + uses: actions/upload-pages-artifact@v3 + with: + path: apidocs + + - name: Deploy to GitHub Pages + uses: actions/deploy-pages@v4 diff --git a/src/nexusrpc/__init__.py b/src/nexusrpc/__init__.py index bc02763..595ce5c 100644 --- a/src/nexusrpc/__init__.py +++ b/src/nexusrpc/__init__.py @@ -25,7 +25,7 @@ OperationErrorState, OutputT, ) -from ._serializer import Content, LazyValue +from ._serializer import Content, LazyValue, LazyValueT, Serializer from ._service import Operation, OperationDefinition, ServiceDefinition, service from ._util import ( get_operation, @@ -42,12 +42,14 @@ "HandlerErrorType", "InputT", "LazyValue", + "LazyValueT", "Link", "Operation", "OperationDefinition", "OperationError", "OperationErrorState", "OutputT", + "Serializer", "service", "ServiceDefinition", "set_operation", diff --git a/src/nexusrpc/handler/__init__.py b/src/nexusrpc/handler/__init__.py index 9b9fdf2..40567d2 100644 --- a/src/nexusrpc/handler/__init__.py +++ b/src/nexusrpc/handler/__init__.py @@ -18,16 +18,19 @@ StartOperationResultAsync, StartOperationResultSync, ) -from ._core import Handler as Handler +from ._core import Handler, OperationHandlerMiddleware from ._decorators import operation_handler, service_handler, sync_operation -from ._operation_handler import OperationHandler as OperationHandler +from ._operation_handler import MiddlewareSafeOperationHandler, OperationHandler __all__ = [ + "MiddlewareSafeOperationHandler", "CancelOperationContext", "Handler", "OperationContext", "OperationHandler", "OperationTaskCancellation", + "OperationHandlerMiddleware", + "operation_handler", "service_handler", "StartOperationContext", "StartOperationResultAsync", diff --git a/src/nexusrpc/handler/_core.py b/src/nexusrpc/handler/_core.py index c39aa57..eb3b01a 100644 --- a/src/nexusrpc/handler/_core.py +++ b/src/nexusrpc/handler/_core.py @@ -102,7 +102,7 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable, Mapping, Sequence from dataclasses import dataclass -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast from typing_extensions import Self, TypeGuard @@ -113,11 +113,13 @@ from ._common import ( CancelOperationContext, + OperationContext, StartOperationContext, StartOperationResultAsync, StartOperationResultSync, ) from ._operation_handler import ( + MiddlewareSafeOperationHandler, OperationHandler, collect_operation_handler_factories_by_method_name, ) @@ -248,7 +250,9 @@ def __init__( self, user_service_handlers: Sequence[Any], executor: Optional[concurrent.futures.Executor] = None, + middleware: Sequence[OperationHandlerMiddleware] | None = None, ): + self._middleware = cast(Sequence[OperationHandlerMiddleware], middleware or []) super().__init__(user_service_handlers, executor=executor) if not self.executor: self._validate_all_operation_handlers_are_async() @@ -268,17 +272,11 @@ async def start_operation( input: The input to the operation, as a LazyValue. """ service_handler = self._get_service_handler(ctx.service) - op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage] + op_handler = self._get_operation_handler(ctx, service_handler, ctx.operation) + op_defn = service_handler.service.operation_definitions[ctx.operation] deserialized_input = await input.consume(as_type=op_defn.input_type) - # TODO(preview): apply middleware stack - if is_async_callable(op_handler.start): - return await op_handler.start(ctx, deserialized_input) - else: - assert self.executor - return await self.executor.submit_to_event_loop( - op_handler.start, ctx, deserialized_input - ) + return await op_handler.start(ctx, deserialized_input) async def cancel_operation(self, ctx: CancelOperationContext, token: str) -> None: """Handle a Cancel Operation request. @@ -288,12 +286,23 @@ async def cancel_operation(self, ctx: CancelOperationContext, token: str) -> Non token: The operation token. """ service_handler = self._get_service_handler(ctx.service) - op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage] - if is_async_callable(op_handler.cancel): - return await op_handler.cancel(ctx, token) - else: - assert self.executor - return self.executor.submit(op_handler.cancel, ctx, token).result() + op_handler = self._get_operation_handler(ctx, service_handler, ctx.operation) + return await op_handler.cancel(ctx, token) + + def _get_operation_handler( + self, ctx: OperationContext, service_handler: ServiceHandler, operation: str + ) -> MiddlewareSafeOperationHandler: + """ + Get the specified handler for the specified operation from the given service_handler and apply all middleware. + """ + op_handler: MiddlewareSafeOperationHandler = _EnsuredAwaitableOperationHandler( + self.executor, service_handler.get_operation_handler(operation) + ) + + for middleware in reversed(self._middleware): + op_handler = middleware.intercept(ctx, op_handler) + + return op_handler def _validate_all_operation_handlers_are_async(self) -> None: for service_handler in self.service_handlers.values(): @@ -360,7 +369,7 @@ def from_user_instance(cls, user_instance: Any) -> Self: operation_handlers=op_handlers, ) - def _get_operation_handler(self, operation_name: str) -> OperationHandler[Any, Any]: + def get_operation_handler(self, operation_name: str) -> OperationHandler[Any, Any]: """Return an operation handler, given the operation name.""" if operation_name not in self.service.operation_definitions: raise HandlerError( @@ -401,3 +410,70 @@ def submit( self, fn: Callable[..., Any], *args: Any ) -> concurrent.futures.Future[Any]: return self._executor.submit(fn, *args) + + +class OperationHandlerMiddleware(ABC): + """ + Middleware for operation handlers. + + This should be extended by any operation handler middelware. + """ + + @abstractmethod + def intercept( + self, + ctx: OperationContext, # type: ignore[reportUnusedParameter] + next: MiddlewareSafeOperationHandler, + ) -> MiddlewareSafeOperationHandler: + """ + Method called for intercepting operation handlers. + + Args: + ctx: The :py:class:`OperationContext` that will be passed to the operation handler. + next: The underlying operation handler that this middleware + should delegate to. + + Returns: + The new middleware that will be used to invoke + :py:attr:`OperationHandler.start` or :py:attr:`OperationHandler.cancel`. + """ + ... + + +class _EnsuredAwaitableOperationHandler(MiddlewareSafeOperationHandler): + """ + An :py:class:`AwaitableOperationHandler` that wraps an :py:class:`OperationHandler` and uses an :py:class:`_Executor` to ensure + that the :py:attr:`start` and :py:attr:`cancel` methods are awaitable. + """ + + def __init__( + self, + executor: _Executor | None, + op_handler: OperationHandler[Any, Any], + ): + self._executor = executor + self._op_handler = op_handler + + async def start( + self, ctx: StartOperationContext, input: Any + ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + """ + Start the operation using the wrapped :py:class:`OperationHandler`. + """ + if is_async_callable(self._op_handler.start): + return await self._op_handler.start(ctx, input) + else: + assert self._executor + return await self._executor.submit_to_event_loop( + self._op_handler.start, ctx, input + ) + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + """ + Cancel an operation using the wrapped :py:class:`OperationHandler`. + """ + if is_async_callable(self._op_handler.cancel): + return await self._op_handler.cancel(ctx, token) + else: + assert self._executor + return self._executor.submit(self._op_handler.cancel, ctx, token).result() diff --git a/src/nexusrpc/handler/_operation_handler.py b/src/nexusrpc/handler/_operation_handler.py index 02a55b9..5a9c16a 100644 --- a/src/nexusrpc/handler/_operation_handler.py +++ b/src/nexusrpc/handler/_operation_handler.py @@ -3,7 +3,7 @@ import inspect from abc import ABC, abstractmethod from collections.abc import Awaitable -from typing import Any, Callable, Generic, Optional, Union +from typing import Any, Callable, Generic, Optional from nexusrpc._common import InputT, OutputT, ServiceHandlerT from nexusrpc._service import Operation, OperationDefinition, ServiceDefinition @@ -39,12 +39,11 @@ class OperationHandler(ABC, Generic[InputT, OutputT]): @abstractmethod def start( self, ctx: StartOperationContext, input: InputT - ) -> Union[ - StartOperationResultSync[OutputT], - Awaitable[StartOperationResultSync[OutputT]], - StartOperationResultAsync, - Awaitable[StartOperationResultAsync], - ]: + ) -> ( + StartOperationResultSync[OutputT] + | StartOperationResultAsync + | Awaitable[StartOperationResultSync[OutputT] | StartOperationResultAsync] + ): """ Start the operation, completing either synchronously or asynchronously. @@ -54,9 +53,7 @@ def start( ... @abstractmethod - def cancel( - self, ctx: CancelOperationContext, token: str - ) -> Union[None, Awaitable[None]]: + def cancel(self, ctx: CancelOperationContext, token: str) -> None | Awaitable[None]: """ Cancel the operation. """ @@ -104,6 +101,31 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: ) +class MiddlewareSafeOperationHandler(OperationHandler[Any, Any], ABC): + """ + An :py:class:`OperationHandler` where :py:attr:`start` and :py:attr:`cancel` + can be awaited by an async runtime. It can produce a result synchronously by returning + :py:class:`StartOperationResultSync` or asynchronously by returning :py:class:`StartOperationResultAsync` + in the same fashion that :py:class:`OperationHandler` does. + """ + + @abstractmethod + async def start( + self, ctx: StartOperationContext, input: Any + ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + """ + Start the operation and return it's result or an async token. + """ + ... + + @abstractmethod + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + """ + Cancel an in progress operation identified by the given token. + """ + ... + + def collect_operation_handler_factories_by_method_name( user_service_cls: type[ServiceHandlerT], service: Optional[ServiceDefinition], diff --git a/tests/handler/test_async_operation.py b/tests/handler/test_async_operation.py index bbb99a8..867de88 100644 --- a/tests/handler/test_async_operation.py +++ b/tests/handler/test_async_operation.py @@ -9,9 +9,9 @@ OperationHandler, StartOperationContext, StartOperationResultAsync, + operation_handler, service_handler, ) -from nexusrpc.handler._decorators import operation_handler from tests.helpers import DummySerializer, TestOperationTaskCancellation _operation_results: dict[str, int] = {} diff --git a/tests/handler/test_middleware.py b/tests/handler/test_middleware.py new file mode 100644 index 0000000..f5636b3 --- /dev/null +++ b/tests/handler/test_middleware.py @@ -0,0 +1,208 @@ +import concurrent.futures +import logging +import uuid +from typing import Any + +import pytest + +from nexusrpc import LazyValue +from nexusrpc.handler import ( + CancelOperationContext, + Handler, + MiddlewareSafeOperationHandler, + OperationContext, + OperationHandler, + OperationHandlerMiddleware, + StartOperationContext, + StartOperationResultAsync, + StartOperationResultSync, + operation_handler, + service_handler, + sync_operation, +) +from tests.helpers import DummySerializer, TestOperationTaskCancellation + +_operation_results: dict[str, int] = {} + +logger = logging.getLogger() + + +class MyAsyncOperationHandler(OperationHandler[int, int]): + async def start( + self, ctx: StartOperationContext, input: int + ) -> StartOperationResultAsync: + token = str(uuid.uuid4()) + _operation_results[token] = input + 1 + return StartOperationResultAsync(token) + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + del _operation_results[token] + + +@service_handler +class MyService: + @operation_handler + def incr(self) -> OperationHandler[int, int]: + return MyAsyncOperationHandler() + + +@service_handler +class MyServiceSync: + @sync_operation + def incr(self, ctx: StartOperationContext, input: int) -> int: # type: ignore[reportUnusedParameter] + return input + 1 + + +class CountingMiddleware(OperationHandlerMiddleware): + def __init__(self) -> None: + self.num_start = 0 + self.num_cancel = 0 + + def intercept( + self, ctx: OperationContext, next: MiddlewareSafeOperationHandler + ) -> MiddlewareSafeOperationHandler: + return CountingOperationHandler(next, self) + + +class CountingOperationHandler(MiddlewareSafeOperationHandler): + """ + An :py:class:`AwaitableOperationHandler` that wraps a counting middleware + that counts the number of calls to each handler method. + """ + + def __init__( + self, + next: MiddlewareSafeOperationHandler, + middleware: CountingMiddleware, + ) -> None: + self._next = next + self._middleware = middleware + + async def start( + self, ctx: StartOperationContext, input: Any + ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + self._middleware.num_start += 1 + return await self._next.start(ctx, input) + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + self._middleware.num_cancel += 1 + return await self._next.cancel(ctx, token) + + +class MustBeFirstMiddleware(OperationHandlerMiddleware): + def __init__(self, counter: CountingMiddleware) -> None: + self._counter = counter + + def intercept( + self, ctx: OperationContext, next: MiddlewareSafeOperationHandler + ) -> MiddlewareSafeOperationHandler: + return MustBeFirstOperationHandler(next, self._counter) + + +class MustBeFirstOperationHandler(MiddlewareSafeOperationHandler): + """ + An :py:class:`AwaitableOperationHandler` that wraps a counting middleware + and asserts that the wrapped middleware has a count of 0 for each handler method + """ + + def __init__( + self, + next: MiddlewareSafeOperationHandler, + counter: CountingMiddleware, + ) -> None: + self._next = next + self._counter = counter + + async def start( + self, ctx: StartOperationContext, input: Any + ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + assert self._counter.num_start == 0 + logger.info("%s.%s: start operation", ctx.service, ctx.operation) + + result = await self._next.start(ctx, input) + + if isinstance(result, StartOperationResultAsync): + logger.info( + "%s.%s: start operation completed async. token=%s", + ctx.service, + ctx.operation, + result.token, + ) + else: + logger.info( + "%s.%s: start operation completed sync. value=%s", + ctx.service, + ctx.operation, + result.value, + ) + + return result + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + assert self._counter.num_cancel == 0 + logger.info("%s.%s: cancel token=%s", ctx.service, ctx.operation, token) + return await self._next.cancel(ctx, token) + + +@pytest.mark.asyncio +async def test_async_operation_middleware_applied(): + counting_middleware = CountingMiddleware() + handler = Handler( + user_service_handlers=[MyService()], + middleware=[ + MustBeFirstMiddleware(counting_middleware), + counting_middleware, + ], + ) + start_ctx = StartOperationContext( + service="MyService", + operation="incr", + headers={}, + request_id="request_id", + task_cancellation=TestOperationTaskCancellation(), + ) + start_result = await handler.start_operation( + start_ctx, LazyValue(DummySerializer(1), headers={}) + ) + assert isinstance(start_result, StartOperationResultAsync) + assert start_result.token + + cancel_ctx = CancelOperationContext( + service="MyService", + operation="incr", + headers={}, + task_cancellation=TestOperationTaskCancellation(), + ) + await handler.cancel_operation(cancel_ctx, start_result.token) + assert start_result.token not in _operation_results + + assert counting_middleware.num_start == 1 + assert counting_middleware.num_cancel == 1 + + +@pytest.mark.asyncio +async def test_sync_operation_middleware_applied(): + counting_middleware = CountingMiddleware() + handler = Handler( + user_service_handlers=[MyServiceSync()], + executor=concurrent.futures.ThreadPoolExecutor(), + middleware=[ + MustBeFirstMiddleware(counting_middleware), + counting_middleware, + ], + ) + start_ctx = StartOperationContext( + service="MyServiceSync", + operation="incr", + headers={}, + request_id="request_id", + task_cancellation=TestOperationTaskCancellation(), + ) + start_result = await handler.start_operation( + start_ctx, LazyValue(DummySerializer(1), headers={}) + ) + assert isinstance(start_result, StartOperationResultSync) + assert start_result.value == 2 + + assert counting_middleware.num_start == 1 + assert counting_middleware.num_cancel == 0