From f9fdb883ca266e962221cd6ab7c65a9890990bc4 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 17 Nov 2025 14:44:16 -0800 Subject: [PATCH 01/10] use middleware branch of nexus-rpc --- pyproject.toml | 2 +- uv.lock | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9832c8cf0..b7c0b77d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = "MIT" license-files = ["LICENSE"] keywords = ["temporal", "workflow"] dependencies = [ - "nexus-rpc==1.2.0", + "nexus-rpc @ git+https://github.com/nexus-rpc/sdk-python@interceptors", "protobuf>=3.20,<7.0.0", "python-dateutil>=2.8.2,<3 ; python_version < '3.11'", "types-protobuf>=3.20", diff --git a/uv.lock b/uv.lock index 8b1ee82a8..09e9d15ad 100644 --- a/uv.lock +++ b/uv.lock @@ -1761,14 +1761,10 @@ wheels = [ [[package]] name = "nexus-rpc" version = "1.2.0" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/nexus-rpc/sdk-python?rev=interceptors#e8806579a7050fc076bb14861aeec7208d521534" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/50/95d7bc91f900da5e22662c82d9bf0f72a4b01f2a552708bf2f43807707a1/nexus_rpc-1.2.0.tar.gz", hash = "sha256:b4ddaffa4d3996aaeadf49b80dfcdfbca48fe4cb616defaf3b3c5c2c8fc61890", size = 74142, upload-time = "2025-11-17T19:17:06.798Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/13/04/eaac430d0e6bf21265ae989427d37e94be5e41dc216879f1fbb6c5339942/nexus_rpc-1.2.0-py3-none-any.whl", hash = "sha256:977876f3af811ad1a09b2961d3d1ac9233bda43ff0febbb0c9906483b9d9f8a3", size = 28166, upload-time = "2025-11-17T19:17:05.64Z" }, -] [[package]] name = "nh3" @@ -3007,7 +3003,7 @@ dev = [ requires-dist = [ { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, - { name = "nexus-rpc", specifier = "==1.2.0" }, + { name = "nexus-rpc", git = "https://github.com/nexus-rpc/sdk-python?rev=interceptors" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.3,<0.5" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, From 32a26ee0f7d61bcf2b15e19b7ddc78d99bbadd61 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 18 Nov 2025 17:53:16 -0800 Subject: [PATCH 02/10] First draft of nexus interceptors and otel support --- temporalio/contrib/opentelemetry.py | 229 +++++++++++++++++++++++++--- temporalio/worker/__init__.py | 6 + temporalio/worker/_interceptor.py | 108 +++++++++++++ temporalio/worker/_nexus.py | 9 +- tests/contrib/test_opentelemetry.py | 99 ++++++++++++ tests/worker/test_interceptor.py | 83 ++++++++-- 6 files changed, 493 insertions(+), 41 deletions(-) diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 7dfd920ef..298c9ffc5 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses from contextlib import contextmanager from dataclasses import dataclass from typing import ( @@ -14,9 +15,11 @@ Optional, Sequence, Type, + TypeVar, cast, ) +import nexusrpc.handler import opentelemetry.baggage.propagation import opentelemetry.context import opentelemetry.context.context @@ -25,6 +28,7 @@ import opentelemetry.trace import opentelemetry.trace.propagation.tracecontext import opentelemetry.util.types +from nexusrpc.handler import StartOperationResultAsync, StartOperationResultSync from opentelemetry.context import Context from opentelemetry.trace import Status, StatusCode from typing_extensions import Protocol, TypeAlias, TypedDict @@ -135,6 +139,11 @@ def workflow_interceptor_class( ) return TracingWorkflowInboundInterceptor + def intercept_nexus_operation( + self, next: temporalio.worker.NexusOperationInboundInterceptor + ) -> temporalio.worker.NexusOperationInboundInterceptor: + return _TracingNexusOperationInboundInterceptor(next, self) + def _context_to_headers( self, headers: Mapping[str, temporalio.api.common.v1.Payload] ) -> Mapping[str, temporalio.api.common.v1.Payload]: @@ -201,6 +210,45 @@ def _start_as_current_span( if token and context is opentelemetry.context.get_current(): opentelemetry.context.detach(token) + @contextmanager + def _start_as_current_span_nexus( + self, + name: str, + *, + attributes: opentelemetry.util.types.Attributes, + input_headers: Mapping[str, str], + kind: opentelemetry.trace.SpanKind, + context: Optional[Context] = None, + ) -> Iterator[_CarrierDict]: + token = opentelemetry.context.attach(context) if context else None + try: + with self.tracer.start_as_current_span( + name, + attributes=attributes, + kind=kind, + context=context, + set_status_on_exception=False, + ) as span: + new_headers: _CarrierDict = {**input_headers} + self.text_map_propagator.inject(new_headers) + try: + yield new_headers + except Exception as exc: + if ( + not isinstance(exc, ApplicationError) + or exc.category != ApplicationErrorCategory.BENIGN + ): + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + raise + finally: + if token and context is opentelemetry.context.get_current(): + opentelemetry.context.detach(token) + def _completed_workflow_span( self, params: _CompletedWorkflowSpanParams ) -> Optional[_CarrierDict]: @@ -347,6 +395,74 @@ async def execute_activity( return await super().execute_activity(input) +class _NexusTracing: + _ContextT = TypeVar("_ContextT", bound=nexusrpc.handler.OperationContext) + + # TODO(amazzeo): not sure what to do if value happens to be a list + # _CarrierDict represents http headers Map[str, List[str] | str] + # but nexus headers are just Map[str, str] + def _carrier_to_nexus_headers( + self, carrier: _CarrierDict, initial: Mapping[str, str] | None = None + ) -> Mapping[str, str]: + out = {**initial} if initial else {} + for k, v in carrier.items(): + if isinstance(v, list): + out[k] = ",".join(v) + else: + out[k] = v + return out + + def _operation_ctx_with_carrier( + self, ctx: _ContextT, carrier: _CarrierDict + ) -> _ContextT: + return dataclasses.replace( + ctx, headers=self._carrier_to_nexus_headers(carrier, ctx.headers) + ) + + +class _TracingNexusOperationInboundInterceptor( + temporalio.worker.NexusOperationInboundInterceptor, _NexusTracing +): + def __init__( + self, + next: temporalio.worker.NexusOperationInboundInterceptor, + root: TracingInterceptor, + ) -> None: + self._next = next + self._root = root + + def _context_from_nexus_headers(self, headers: Mapping[str, str]): + return self._root.text_map_propagator.extract(headers) + + async def start_operation( + self, input: temporalio.worker.NexusOperationStartInput + ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + with self._root._start_as_current_span_nexus( + f"RunStartNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}", + context=self._context_from_nexus_headers(input.ctx.headers), + attributes={ + "temporalNexusRequestId": input.ctx.request_id, + }, + input_headers=input.ctx.headers, + kind=opentelemetry.trace.SpanKind.SERVER, + ) as new_headers: + input.ctx = self._operation_ctx_with_carrier(input.ctx, new_headers) + return await self._next.start_operation(input) + + async def cancel_operation( + self, input: temporalio.worker.NexusOperationCancelInput + ) -> None: + with self._root._start_as_current_span_nexus( + f"RunCancelNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}", + context=self._context_from_nexus_headers(input.ctx.headers), + attributes={}, + input_headers=input.ctx.headers, + kind=opentelemetry.trace.SpanKind.SERVER, + ) as new_headers: + input.ctx = self._operation_ctx_with_carrier(input.ctx, new_headers) + return await self._next.cancel_operation(input) + + class _InputWithHeaders(Protocol): headers: Mapping[str, temporalio.api.common.v1.Payload] @@ -417,7 +533,7 @@ async def execute_workflow( """ with self._top_level_workflow_context(success_is_complete=True): # Entrypoint of workflow should be `server` in OTel - self._completed_span( + self._completed_span_grpc( f"RunWorkflow:{temporalio.workflow.info().workflow_type}", kind=opentelemetry.trace.SpanKind.SERVER, ) @@ -436,7 +552,7 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non [link_context_header] )[0] with self._top_level_workflow_context(success_is_complete=False): - self._completed_span( + self._completed_span_grpc( f"HandleSignal:{input.signal}", link_context_carrier=link_context_carrier, kind=opentelemetry.trace.SpanKind.SERVER, @@ -468,7 +584,7 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: token = opentelemetry.context.attach(context) try: # This won't be created if there was no context header - self._completed_span( + self._completed_span_grpc( f"HandleQuery:{input.query}", link_context_carrier=link_context_carrier, # Create even on replay for queries @@ -497,7 +613,7 @@ def handle_update_validator( [link_context_header] )[0] with self._top_level_workflow_context(success_is_complete=False): - self._completed_span( + self._completed_span_grpc( f"ValidateUpdate:{input.update}", link_context_carrier=link_context_carrier, kind=opentelemetry.trace.SpanKind.SERVER, @@ -517,7 +633,7 @@ async def handle_update_handler( [link_context_header] )[0] with self._top_level_workflow_context(success_is_complete=False): - self._completed_span( + self._completed_span_grpc( f"HandleUpdate:{input.update}", link_context_carrier=link_context_carrier, kind=opentelemetry.trace.SpanKind.SERVER, @@ -566,7 +682,7 @@ def _top_level_workflow_context( finally: # Create a completed span before detaching context if exception or (success and success_is_complete): - self._completed_span( + self._completed_span_grpc( f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}", exception=exception, kind=opentelemetry.trace.SpanKind.INTERNAL, @@ -598,7 +714,7 @@ def _context_carrier_to_headers( } return headers - def _completed_span( + def _completed_span_grpc( self, span_name: str, *, @@ -609,6 +725,56 @@ def _completed_span( exception: Optional[Exception] = None, kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL, ) -> None: + updated_context_carrier = self._completed_span( + span_name=span_name, + link_context_carrier=link_context_carrier, + new_span_even_on_replay=new_span_even_on_replay, + additional_attributes=additional_attributes, + exception=exception, + kind=kind, + ) + + # Add to outbound if needed + if add_to_outbound and updated_context_carrier: + add_to_outbound.headers = self._context_carrier_to_headers( + updated_context_carrier, add_to_outbound.headers + ) + + def _completed_span_nexus( + self, + span_name: str, + *, + outbound_headers: Mapping[str, str], + link_context_carrier: Optional[_CarrierDict] = None, + new_span_even_on_replay: bool = False, + additional_attributes: opentelemetry.util.types.Attributes = None, + exception: Optional[Exception] = None, + kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL, + ) -> _CarrierDict | None: + new_carrier = self._completed_span( + span_name=span_name, + link_context_carrier=link_context_carrier, + new_span_even_on_replay=new_span_even_on_replay, + additional_attributes=additional_attributes, + exception=exception, + kind=kind, + ) + + if new_carrier: + return {**outbound_headers, **new_carrier} + else: + return {**outbound_headers} + + def _completed_span( + self, + span_name: str, + *, + link_context_carrier: Optional[_CarrierDict] = None, + new_span_even_on_replay: bool = False, + additional_attributes: opentelemetry.util.types.Attributes = None, + exception: Optional[Exception] = None, + kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL, + ) -> _CarrierDict | None: # If we are replaying and they don't want a span on replay, no span if temporalio.workflow.unsafe.is_replaying() and not new_span_even_on_replay: return None @@ -616,12 +782,18 @@ def _completed_span( # Create the span. First serialize current context to carrier. new_context_carrier: _CarrierDict = {} self.text_map_propagator.inject(new_context_carrier) + # Invoke - info = temporalio.workflow.info() - attributes: Dict[str, opentelemetry.util.types.AttributeValue] = { - "temporalWorkflowID": info.workflow_id, - "temporalRunID": info.run_id, - } + # TODO(amazzeo): I think this try/except is necessary once non-workflow callers + # are added to Nexus + attributes: Dict[str, opentelemetry.util.types.AttributeValue] = {} + try: + info = temporalio.workflow.info() + attributes["temporalWorkflowID"] = info.workflow_id + attributes["temporalRunID"] = info.run_id + except temporalio.exceptions.TemporalError: + pass + if additional_attributes: attributes.update(additional_attributes) updated_context_carrier = self._extern_functions[ @@ -641,11 +813,7 @@ def _completed_span( ) ) - # Add to outbound if needed - if add_to_outbound and updated_context_carrier: - add_to_outbound.headers = self._context_carrier_to_headers( - updated_context_carrier, add_to_outbound.headers - ) + return updated_context_carrier def _set_on_context( self, context: opentelemetry.context.Context @@ -654,7 +822,7 @@ def _set_on_context( class _TracingWorkflowOutboundInterceptor( - temporalio.worker.WorkflowOutboundInterceptor + temporalio.worker.WorkflowOutboundInterceptor, _NexusTracing ): def __init__( self, @@ -673,7 +841,7 @@ async def signal_child_workflow( self, input: temporalio.worker.SignalChildWorkflowInput ) -> None: # Create new span and put on outbound input - self.root._completed_span( + self.root._completed_span_grpc( f"SignalChildWorkflow:{input.signal}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.SERVER, @@ -684,7 +852,7 @@ async def signal_external_workflow( self, input: temporalio.worker.SignalExternalWorkflowInput ) -> None: # Create new span and put on outbound input - self.root._completed_span( + self.root._completed_span_grpc( f"SignalExternalWorkflow:{input.signal}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.CLIENT, @@ -695,7 +863,7 @@ def start_activity( self, input: temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: # Create new span and put on outbound input - self.root._completed_span( + self.root._completed_span_grpc( f"StartActivity:{input.activity}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.CLIENT, @@ -706,7 +874,7 @@ async def start_child_workflow( self, input: temporalio.worker.StartChildWorkflowInput ) -> temporalio.workflow.ChildWorkflowHandle: # Create new span and put on outbound input - self.root._completed_span( + self.root._completed_span_grpc( f"StartChildWorkflow:{input.workflow}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.CLIENT, @@ -717,13 +885,26 @@ def start_local_activity( self, input: temporalio.worker.StartLocalActivityInput ) -> temporalio.workflow.ActivityHandle: # Create new span and put on outbound input - self.root._completed_span( + self.root._completed_span_grpc( f"StartActivity:{input.activity}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.CLIENT, ) return super().start_local_activity(input) + async def start_nexus_operation( + self, input: temporalio.worker.StartNexusOperationInput[Any, Any] + ) -> temporalio.workflow.NexusOperationHandle[Any]: + new_carrier = self.root._completed_span_nexus( + f"StartNexusOperation:{input.service}/{input.operation_name}", + kind=opentelemetry.trace.SpanKind.CLIENT, + outbound_headers=input.headers if input.headers else {}, + ) + if new_carrier: + input.headers = self._carrier_to_nexus_headers(new_carrier, input.headers) + + return await super().start_nexus_operation(input) + class workflow: """Contains static methods that are safe to call from within a workflow. @@ -760,6 +941,6 @@ def completed_span( """ interceptor = TracingWorkflowInboundInterceptor._from_context() if interceptor: - interceptor._completed_span( + interceptor._completed_span_grpc( name, additional_attributes=attributes, exception=exception ) diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index 1d7b2558e..9d0a3a47e 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -11,6 +11,9 @@ HandleSignalInput, HandleUpdateInput, Interceptor, + NexusOperationCancelInput, + NexusOperationInboundInterceptor, + NexusOperationStartInput, SignalChildWorkflowInput, SignalExternalWorkflowInput, StartActivityInput, @@ -80,6 +83,7 @@ "ActivityOutboundInterceptor", "WorkflowInboundInterceptor", "WorkflowOutboundInterceptor", + "NexusOperationInboundInterceptor", "Plugin", # Interceptor input "ContinueAsNewInput", @@ -95,6 +99,8 @@ "StartLocalActivityInput", "StartNexusOperationInput", "WorkflowInterceptorClassInput", + "NexusOperationStartInput", + "NexusOperationCancelInput", # Advanced activity classes "SharedStateManager", "SharedHeartbeatSender", diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 7119b0665..beda27535 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -6,6 +6,7 @@ from collections.abc import Callable, Mapping, MutableMapping from dataclasses import dataclass from datetime import timedelta +from functools import reduce from typing import ( Any, Awaitable, @@ -16,6 +17,7 @@ Sequence, Type, Union, + cast, ) import nexusrpc.handler @@ -68,6 +70,112 @@ def workflow_interceptor_class( """ return None + def intercept_nexus_operation( + self, next: NexusOperationInboundInterceptor + ) -> NexusOperationInboundInterceptor: + """Method called for intercepting a Nexus operation. + + Args: + next: The underlying inbound this interceptor + should delegate to. + + Returns: + The new interceptor that should be used for the Nexus operation. + """ + return next + + +@dataclass +class NexusOperationStartInput: + ctx: nexusrpc.handler.StartOperationContext + input: Any + + +@dataclass +class NexusOperationCancelInput: + ctx: nexusrpc.handler.CancelOperationContext + token: str + + +class NexusOperationInboundInterceptor: + def __init__(self, next: NexusOperationInboundInterceptor) -> None: + self.next = next + + async def start_operation( + self, input: NexusOperationStartInput + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + return await self.next.start_operation(input) + + async def cancel_operation(self, input: NexusOperationCancelInput) -> None: + return await self.next.cancel_operation(input) + + +class _NexusOperationHandlerForInterceptor( + nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any] +): + def __init__(self, next_interceptor: NexusOperationInboundInterceptor): + self._next_interceptor = next_interceptor + + async def start( + self, ctx: nexusrpc.handler.StartOperationContext, input: Any + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + return await self._next_interceptor.start_operation( + NexusOperationStartInput(ctx, input) + ) + + async def cancel( + self, ctx: nexusrpc.handler.CancelOperationContext, token: str + ) -> None: + return await self._next_interceptor.cancel_operation( + NexusOperationCancelInput(ctx, token) + ) + + +class _NexusOperationInboundInterceptorImpl(NexusOperationInboundInterceptor): + def __init__( + self, + handler: nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any], + ): + self._handler = handler + + async def start_operation( + self, input: NexusOperationStartInput + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + return await self._handler.start(input.ctx, input.input) + + async def cancel_operation(self, input: NexusOperationCancelInput) -> None: + return await self._handler.cancel(input.ctx, input.token) + + +class _NexusMiddlewareForInterceptors(nexusrpc.handler.OperationHandlerMiddleware): + def __init__(self, interceptors: Sequence[Interceptor]) -> None: + self._interceptors = interceptors + + def intercept( + self, + ctx: nexusrpc.handler.OperationContext, + next: nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any], + ) -> nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any]: + inbound = reduce( + lambda impl, _next: _next.intercept_nexus_operation(impl), + reversed(self._interceptors), + cast( + NexusOperationInboundInterceptor, + _NexusOperationInboundInterceptorImpl(next), + ), + ) + + return _NexusOperationHandlerForInterceptor(inbound) + @dataclass(frozen=True) class WorkflowInterceptorClassInput: diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 1083cc620..36ec74a6b 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -40,7 +40,7 @@ from temporalio.nexus import Info, logger from temporalio.service import RPCError, RPCStatusCode -from ._interceptor import Interceptor +from ._interceptor import Interceptor, _NexusMiddlewareForInterceptors _TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure" @@ -73,10 +73,11 @@ def __init__( self._bridge_worker = bridge_worker self._client = client self._task_queue = task_queue - self._handler = Handler(service_handlers, executor) + + middleware = _NexusMiddlewareForInterceptors(interceptors) + + self._handler = Handler(service_handlers, executor, interceptors=[middleware]) self._data_converter = data_converter - # TODO(nexus-preview): interceptors - self._interceptors = interceptors # TODO(nexus-preview): metric_meter self._metric_meter = metric_meter self._running_tasks: dict[bytes, _RunningNexusTask] = {} diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index fb4759be9..27d7cebff 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -12,6 +12,7 @@ from datetime import timedelta from typing import Callable, Dict, Generator, Iterable, List, Optional, cast +import nexusrpc import opentelemetry.context import pytest from opentelemetry import baggage, context @@ -32,6 +33,7 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers import LogCapturer +from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @dataclass @@ -62,6 +64,7 @@ class TracingWorkflowAction: wait_until_signal_count: int = 0 wait_and_do_update: bool = False wait_and_do_start_with_update: bool = False + start_and_cancel_nexus_operation: bool = False @dataclass @@ -85,6 +88,25 @@ class TracingWorkflowActionContinueAsNew: param: TracingWorkflowParam +class InterceptedOperationHandler(nexusrpc.handler.OperationHandler[str, str]): + async def start( + self, ctx: nexusrpc.handler.StartOperationContext, input: str + ) -> nexusrpc.handler.StartOperationResultAsync: + return nexusrpc.handler.StartOperationResultAsync(input) + + async def cancel( + self, ctx: nexusrpc.handler.CancelOperationContext, token: str + ) -> None: + pass + + +@nexusrpc.handler.service_handler +class InterceptedNexusService: + @nexusrpc.handler.operation_handler + def intercepted_operation(self) -> nexusrpc.handler.OperationHandler[str, str]: + return InterceptedOperationHandler() + + ready_for_update: asyncio.Semaphore ready_for_update_with_start: asyncio.Semaphore @@ -152,6 +174,24 @@ async def run(self, param: TracingWorkflowParam) -> None: if action.wait_and_do_start_with_update: ready_for_update_with_start.release() await workflow.wait_condition(lambda: self._did_update_with_start) + if action.start_and_cancel_nexus_operation: + nexus_client = workflow.create_nexus_client( + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + service=InterceptedNexusService, + ) + + nexus_handle = await nexus_client.start_operation( + operation=InterceptedNexusService.intercepted_operation, + input="hello", + ) + nexus_handle.cancel() + + # in order for the cancel to make progress, the handle must be awaited + # but it hangs indefinitely, so I'm using this temp workaround + try: + await asyncio.wait_for(asyncio.shield(nexus_handle), 0.1) + except asyncio.TimeoutError: + pass async def _raise_on_non_replay(self) -> None: replaying = workflow.unsafe.is_replaying() @@ -410,6 +450,65 @@ async def test_opentelemetry_tracing_update_with_start( ] +async def test_opentelemetry_tracing_nexus(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/1424" + ) + global ready_for_update_with_start + ready_for_update_with_start = asyncio.Semaphore(0) + # Create a tracer that has an in-memory exporter + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = get_tracer(__name__, tracer_provider=provider) + # Create new client with tracer interceptor + client_config = client.config() + client_config["interceptors"] = [TracingInterceptor(tracer)] + client = Client(**client_config) + + task_queue = f"task-queue-{uuid.uuid4()}" + await create_nexus_endpoint(task_queue, client) + async with Worker( + client, + task_queue=task_queue, + workflows=[TracingWorkflow], + activities=[tracing_activity], + nexus_service_handlers=[InterceptedNexusService()], + # Needed so we can wait to send update at the right time + workflow_runner=UnsandboxedWorkflowRunner(), + ): + # Run workflow with various actions + workflow_id = f"workflow_{uuid.uuid4()}" + workflow_params = TracingWorkflowParam( + actions=[ + TracingWorkflowAction(start_and_cancel_nexus_operation=True), + ] + ) + handle = await client.start_workflow( + TracingWorkflow.run, + workflow_params, + id=workflow_id, + task_queue=task_queue, + ) + await handle.result() + + # Dump debug with attributes, but do string assertion test without + logging.debug( + "Spans:\n%s", + "\n".join(dump_spans(exporter.get_finished_spans(), with_attributes=False)), + ) + assert dump_spans(exporter.get_finished_spans(), with_attributes=False) == [ + "StartWorkflow:TracingWorkflow", + " RunWorkflow:TracingWorkflow", + " MyCustomSpan", + " StartNexusOperation:InterceptedNexusService/intercepted_operation", + " RunStartNexusOperationHandler:InterceptedNexusService/intercepted_operation", + " RunCancelNexusOperationHandler:InterceptedNexusService/intercepted_operation", + " CompleteWorkflow:TracingWorkflow", + ] + + def dump_spans( spans: Iterable[ReadableSpan], *, diff --git a/tests/worker/test_interceptor.py b/tests/worker/test_interceptor.py index 1cb6cd25d..aa9521813 100644 --- a/tests/worker/test_interceptor.py +++ b/tests/worker/test_interceptor.py @@ -3,11 +3,18 @@ from datetime import timedelta from typing import Any, Callable, List, NoReturn, Optional, Tuple, Type +import nexusrpc import pytest +from nexusrpc.handler._common import ( + CancelOperationContext, + StartOperationContext, + StartOperationResultAsync, + StartOperationResultSync, +) from temporalio import activity, workflow from temporalio.client import Client, WorkflowUpdateFailedError -from temporalio.exceptions import ApplicationError, NexusOperationError +from temporalio.exceptions import ApplicationError from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( ActivityInboundInterceptor, @@ -19,6 +26,9 @@ HandleSignalInput, HandleUpdateInput, Interceptor, + NexusOperationCancelInput, + NexusOperationInboundInterceptor, + NexusOperationStartInput, SignalChildWorkflowInput, SignalExternalWorkflowInput, StartActivityInput, @@ -46,6 +56,11 @@ def workflow_interceptor_class( ) -> Optional[Type[WorkflowInboundInterceptor]]: return TracingWorkflowInboundInterceptor + def intercept_nexus_operation( + self, next: NexusOperationInboundInterceptor + ) -> NexusOperationInboundInterceptor: + return TracingNexusInboundInterceptor(next) + class TracingActivityInboundInterceptor(ActivityInboundInterceptor): def init(self, outbound: ActivityOutboundInterceptor) -> None: @@ -133,6 +148,39 @@ async def start_nexus_operation( return await super().start_nexus_operation(input) +class TracingNexusInboundInterceptor(NexusOperationInboundInterceptor): + async def start_operation( + self, input: NexusOperationStartInput + ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + interceptor_traces.append( + (f"nexus.start_operation.{input.ctx.service}.{input.ctx.operation}", input) + ) + return await super().start_operation(input) + + async def cancel_operation(self, input: NexusOperationCancelInput) -> None: + interceptor_traces.append( + (f"nexus.cancel_operation.{input.ctx.service}.{input.ctx.operation}", input) + ) + return await super().cancel_operation(input) + + +class InterceptedOperationHandler(nexusrpc.handler.OperationHandler[str, str]): + async def start( + self, ctx: StartOperationContext, input: str + ) -> StartOperationResultAsync: + return StartOperationResultAsync(input) + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + pass + + +@nexusrpc.handler.service_handler +class InterceptedNexusService: + @nexusrpc.handler.operation_handler + def intercepted_operation(self) -> nexusrpc.handler.OperationHandler[str, str]: + return InterceptedOperationHandler() + + @activity.defn async def intercepted_activity(param: str) -> str: if not activity.info().is_local: @@ -176,20 +224,20 @@ async def run(self, style: str) -> None: nexus_client = workflow.create_nexus_client( endpoint=make_nexus_endpoint_name(workflow.info().task_queue), - service="non-existent-nexus-service", + service=InterceptedNexusService, ) + + nexus_handle = await nexus_client.start_operation( + operation=InterceptedNexusService.intercepted_operation, + input="hello", + ) + nexus_handle.cancel() + + # in order for the cancel to make progress, the handle must be awaited + # but it hangs indefinitely, so I'm using this temp workaround try: - await nexus_client.start_operation( - operation="non-existent-nexus-operation", - input={"test": "data"}, - schedule_to_close_timeout=timedelta(microseconds=1), - ) - raise Exception("unreachable") - except NexusOperationError: - # The test requires only that the workflow attempts to schedule the nexus operation. - # Instead of setting up a nexus service, we deliberately schedule a call to a - # non-existent nexus operation with an insufficiently long timeout, and expect this - # error. + await asyncio.wait_for(asyncio.shield(nexus_handle), 0.1) + except asyncio.TimeoutError: pass await self.finish.wait() @@ -232,6 +280,7 @@ async def test_worker_interceptor(client: Client, env: WorkflowEnvironment): workflows=[InterceptedWorkflow], activities=[intercepted_activity], interceptors=[TracingWorkerInterceptor()], + nexus_service_handlers=[InterceptedNexusService()], ): # Run workflow handle = await client.start_workflow( @@ -310,6 +359,14 @@ def pop_trace(name: str, filter: Optional[Callable[[Any], bool]] = None) -> Any: assert pop_trace( "workflow.update.validator", lambda v: v.args[0] == "reject-me" ) + assert pop_trace( + "nexus.start_operation.InterceptedNexusService.intercepted_operation", + lambda v: v.input == "hello", + ) + assert pop_trace( + "nexus.cancel_operation.InterceptedNexusService.intercepted_operation", + lambda v: v.token == "hello", + ) # Confirm no unexpected traces assert not interceptor_traces From fbb745a05b9f22940a39a83a2f71c778ec993eb1 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Wed, 19 Nov 2025 10:34:34 -0800 Subject: [PATCH 03/10] Add docstrings --- temporalio/contrib/opentelemetry.py | 3 + temporalio/worker/_interceptor.py | 201 +++++++++++++++------------- 2 files changed, 112 insertions(+), 92 deletions(-) diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 298c9ffc5..858ab5378 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -142,6 +142,9 @@ def workflow_interceptor_class( def intercept_nexus_operation( self, next: temporalio.worker.NexusOperationInboundInterceptor ) -> temporalio.worker.NexusOperationInboundInterceptor: + """Implementation of + :py:meth:`temporalio.worker.Interceptor.intercept_nexus_operation`. + """ return _TracingNexusOperationInboundInterceptor(next, self) def _context_to_headers( diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index beda27535..e35cee454 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -85,98 +85,6 @@ def intercept_nexus_operation( return next -@dataclass -class NexusOperationStartInput: - ctx: nexusrpc.handler.StartOperationContext - input: Any - - -@dataclass -class NexusOperationCancelInput: - ctx: nexusrpc.handler.CancelOperationContext - token: str - - -class NexusOperationInboundInterceptor: - def __init__(self, next: NexusOperationInboundInterceptor) -> None: - self.next = next - - async def start_operation( - self, input: NexusOperationStartInput - ) -> ( - nexusrpc.handler.StartOperationResultSync[Any] - | nexusrpc.handler.StartOperationResultAsync - ): - return await self.next.start_operation(input) - - async def cancel_operation(self, input: NexusOperationCancelInput) -> None: - return await self.next.cancel_operation(input) - - -class _NexusOperationHandlerForInterceptor( - nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any] -): - def __init__(self, next_interceptor: NexusOperationInboundInterceptor): - self._next_interceptor = next_interceptor - - async def start( - self, ctx: nexusrpc.handler.StartOperationContext, input: Any - ) -> ( - nexusrpc.handler.StartOperationResultSync[Any] - | nexusrpc.handler.StartOperationResultAsync - ): - return await self._next_interceptor.start_operation( - NexusOperationStartInput(ctx, input) - ) - - async def cancel( - self, ctx: nexusrpc.handler.CancelOperationContext, token: str - ) -> None: - return await self._next_interceptor.cancel_operation( - NexusOperationCancelInput(ctx, token) - ) - - -class _NexusOperationInboundInterceptorImpl(NexusOperationInboundInterceptor): - def __init__( - self, - handler: nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any], - ): - self._handler = handler - - async def start_operation( - self, input: NexusOperationStartInput - ) -> ( - nexusrpc.handler.StartOperationResultSync[Any] - | nexusrpc.handler.StartOperationResultAsync - ): - return await self._handler.start(input.ctx, input.input) - - async def cancel_operation(self, input: NexusOperationCancelInput) -> None: - return await self._handler.cancel(input.ctx, input.token) - - -class _NexusMiddlewareForInterceptors(nexusrpc.handler.OperationHandlerMiddleware): - def __init__(self, interceptors: Sequence[Interceptor]) -> None: - self._interceptors = interceptors - - def intercept( - self, - ctx: nexusrpc.handler.OperationContext, - next: nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any], - ) -> nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any]: - inbound = reduce( - lambda impl, _next: _next.intercept_nexus_operation(impl), - reversed(self._interceptors), - cast( - NexusOperationInboundInterceptor, - _NexusOperationInboundInterceptorImpl(next), - ), - ) - - return _NexusOperationHandlerForInterceptor(inbound) - - @dataclass(frozen=True) class WorkflowInterceptorClassInput: """Input for :py:meth:`Interceptor.workflow_interceptor_class`.""" @@ -578,3 +486,112 @@ async def start_nexus_operation( ) -> temporalio.workflow.NexusOperationHandle[OutputT]: """Called for every :py:func:`temporalio.workflow.start_nexus_operation` call.""" return await self.next.start_nexus_operation(input) + + +@dataclass +class NexusOperationStartInput: + """Input for :pyt:meth:`NexusOperationInboundInterceptor.start_operation""" + + ctx: nexusrpc.handler.StartOperationContext + input: Any + + +@dataclass +class NexusOperationCancelInput: + """Input for :pyt:meth:`NexusOperationInboundInterceptor.cancel_operation""" + + ctx: nexusrpc.handler.CancelOperationContext + token: str + + +class NexusOperationInboundInterceptor: + """Inbound interceptor to wrap Nexus operation starting and cancelling. + + This should be extended by any Nexus operation inbound interceptors. + """ + + def __init__(self, next: NexusOperationInboundInterceptor) -> None: + """Create the inbound interceptor. + + Args: + next: The next interceptor in the chain. The default implementation + of all calls is to delegate to the next interceptor. + """ + self.next = next + + async def start_operation( + self, input: NexusOperationStartInput + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + """Called to start a Nexus operation""" + return await self.next.start_operation(input) + + async def cancel_operation(self, input: NexusOperationCancelInput) -> None: + """Called to cancel an in progress Nexus operation""" + return await self.next.cancel_operation(input) + + +class _NexusOperationHandlerForInterceptor( + nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any] +): + def __init__(self, next_interceptor: NexusOperationInboundInterceptor): + self._next_interceptor = next_interceptor + + async def start( + self, ctx: nexusrpc.handler.StartOperationContext, input: Any + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + return await self._next_interceptor.start_operation( + NexusOperationStartInput(ctx, input) + ) + + async def cancel( + self, ctx: nexusrpc.handler.CancelOperationContext, token: str + ) -> None: + return await self._next_interceptor.cancel_operation( + NexusOperationCancelInput(ctx, token) + ) + + +class _NexusOperationInboundInterceptorImpl(NexusOperationInboundInterceptor): + def __init__( + self, + handler: nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any], + ): + self._handler = handler + + async def start_operation( + self, input: NexusOperationStartInput + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + return await self._handler.start(input.ctx, input.input) + + async def cancel_operation(self, input: NexusOperationCancelInput) -> None: + return await self._handler.cancel(input.ctx, input.token) + + +class _NexusMiddlewareForInterceptors(nexusrpc.handler.OperationHandlerMiddleware): + def __init__(self, interceptors: Sequence[Interceptor]) -> None: + self._interceptors = interceptors + + def intercept( + self, + ctx: nexusrpc.handler.OperationContext, + next: nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any], + ) -> nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any]: + inbound = reduce( + lambda impl, _next: _next.intercept_nexus_operation(impl), + reversed(self._interceptors), + cast( + NexusOperationInboundInterceptor, + _NexusOperationInboundInterceptorImpl(next), + ), + ) + + return _NexusOperationHandlerForInterceptor(inbound) From 690f0fe000d2ebc541780b7b6cf71cbfb23cf147 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Wed, 19 Nov 2025 14:02:17 -0800 Subject: [PATCH 04/10] Update tests to use @workflow_run_operation to avoid hacky cancel awaiting --- tests/contrib/test_opentelemetry.py | 50 +++++++++++++++++------------ tests/worker/test_interceptor.py | 48 ++++++++++++++------------- 2 files changed, 54 insertions(+), 44 deletions(-) diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index 27d7cebff..e3165bb20 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -21,7 +21,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import StatusCode, get_tracer -from temporalio import activity, workflow +from temporalio import activity, nexus, workflow from temporalio.client import Client, WithStartWorkflowOperation, WorkflowUpdateStage from temporalio.common import RetryPolicy, WorkflowIDConflictPolicy from temporalio.contrib.opentelemetry import ( @@ -29,7 +29,11 @@ TracingWorkflowInboundInterceptor, ) from temporalio.contrib.opentelemetry import workflow as otel_workflow -from temporalio.exceptions import ApplicationError, ApplicationErrorCategory +from temporalio.exceptions import ( + ApplicationError, + ApplicationErrorCategory, + NexusOperationError, +) from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers import LogCapturer @@ -88,23 +92,27 @@ class TracingWorkflowActionContinueAsNew: param: TracingWorkflowParam -class InterceptedOperationHandler(nexusrpc.handler.OperationHandler[str, str]): - async def start( - self, ctx: nexusrpc.handler.StartOperationContext, input: str - ) -> nexusrpc.handler.StartOperationResultAsync: - return nexusrpc.handler.StartOperationResultAsync(input) - - async def cancel( - self, ctx: nexusrpc.handler.CancelOperationContext, token: str - ) -> None: - pass +@workflow.defn +class ExpectCancelNexusWorkflow: + @workflow.run + async def run(self, input: str): + try: + await asyncio.wait_for(asyncio.Future(), 2) + except asyncio.TimeoutError: + raise ApplicationError("expected cancellation") @nexusrpc.handler.service_handler class InterceptedNexusService: - @nexusrpc.handler.operation_handler - def intercepted_operation(self) -> nexusrpc.handler.OperationHandler[str, str]: - return InterceptedOperationHandler() + @nexus.workflow_run_operation + async def intercepted_operation( + self, ctx: nexus.WorkflowRunOperationContext, input: str + ) -> nexus.WorkflowHandle[None]: + return await ctx.start_workflow( + ExpectCancelNexusWorkflow.run, + input, + id=f"wf-{uuid.uuid4()}-{ctx.request_id}", + ) ready_for_update: asyncio.Semaphore @@ -182,15 +190,13 @@ async def run(self, param: TracingWorkflowParam) -> None: nexus_handle = await nexus_client.start_operation( operation=InterceptedNexusService.intercepted_operation, - input="hello", + input="nexus-workflow", ) nexus_handle.cancel() - # in order for the cancel to make progress, the handle must be awaited - # but it hangs indefinitely, so I'm using this temp workaround try: - await asyncio.wait_for(asyncio.shield(nexus_handle), 0.1) - except asyncio.TimeoutError: + await nexus_handle + except NexusOperationError: pass async def _raise_on_non_replay(self) -> None: @@ -472,7 +478,7 @@ async def test_opentelemetry_tracing_nexus(client: Client, env: WorkflowEnvironm async with Worker( client, task_queue=task_queue, - workflows=[TracingWorkflow], + workflows=[TracingWorkflow, ExpectCancelNexusWorkflow], activities=[tracing_activity], nexus_service_handlers=[InterceptedNexusService()], # Needed so we can wait to send update at the right time @@ -504,6 +510,8 @@ async def test_opentelemetry_tracing_nexus(client: Client, env: WorkflowEnvironm " MyCustomSpan", " StartNexusOperation:InterceptedNexusService/intercepted_operation", " RunStartNexusOperationHandler:InterceptedNexusService/intercepted_operation", + " StartWorkflow:ExpectCancelNexusWorkflow", + " RunWorkflow:ExpectCancelNexusWorkflow", " RunCancelNexusOperationHandler:InterceptedNexusService/intercepted_operation", " CompleteWorkflow:TracingWorkflow", ] diff --git a/tests/worker/test_interceptor.py b/tests/worker/test_interceptor.py index aa9521813..5d0a95f2a 100644 --- a/tests/worker/test_interceptor.py +++ b/tests/worker/test_interceptor.py @@ -6,15 +6,13 @@ import nexusrpc import pytest from nexusrpc.handler._common import ( - CancelOperationContext, - StartOperationContext, StartOperationResultAsync, StartOperationResultSync, ) -from temporalio import activity, workflow +from temporalio import activity, nexus, workflow from temporalio.client import Client, WorkflowUpdateFailedError -from temporalio.exceptions import ApplicationError +from temporalio.exceptions import ApplicationError, NexusOperationError from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( ActivityInboundInterceptor, @@ -164,21 +162,27 @@ async def cancel_operation(self, input: NexusOperationCancelInput) -> None: return await super().cancel_operation(input) -class InterceptedOperationHandler(nexusrpc.handler.OperationHandler[str, str]): - async def start( - self, ctx: StartOperationContext, input: str - ) -> StartOperationResultAsync: - return StartOperationResultAsync(input) - - async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - pass +@workflow.defn +class ExpectCancelNexusWorkflow: + @workflow.run + async def run(self, input: str): + try: + await asyncio.wait_for(asyncio.Future(), 2) + except asyncio.TimeoutError: + raise ApplicationError("expected cancellation") @nexusrpc.handler.service_handler class InterceptedNexusService: - @nexusrpc.handler.operation_handler - def intercepted_operation(self) -> nexusrpc.handler.OperationHandler[str, str]: - return InterceptedOperationHandler() + @nexus.workflow_run_operation + async def intercepted_operation( + self, ctx: nexus.WorkflowRunOperationContext, input: str + ) -> nexus.WorkflowHandle[None]: + return await ctx.start_workflow( + ExpectCancelNexusWorkflow.run, + input, + id=f"wf-{uuid.uuid4()}-{ctx.request_id}", + ) @activity.defn @@ -229,15 +233,13 @@ async def run(self, style: str) -> None: nexus_handle = await nexus_client.start_operation( operation=InterceptedNexusService.intercepted_operation, - input="hello", + input="nexus-workflow", ) nexus_handle.cancel() - # in order for the cancel to make progress, the handle must be awaited - # but it hangs indefinitely, so I'm using this temp workaround try: - await asyncio.wait_for(asyncio.shield(nexus_handle), 0.1) - except asyncio.TimeoutError: + await nexus_handle + except NexusOperationError: pass await self.finish.wait() @@ -277,7 +279,7 @@ async def test_worker_interceptor(client: Client, env: WorkflowEnvironment): async with Worker( client, task_queue=task_queue, - workflows=[InterceptedWorkflow], + workflows=[InterceptedWorkflow, ExpectCancelNexusWorkflow], activities=[intercepted_activity], interceptors=[TracingWorkerInterceptor()], nexus_service_handlers=[InterceptedNexusService()], @@ -361,11 +363,11 @@ def pop_trace(name: str, filter: Optional[Callable[[Any], bool]] = None) -> Any: ) assert pop_trace( "nexus.start_operation.InterceptedNexusService.intercepted_operation", - lambda v: v.input == "hello", + lambda v: v.input == "nexus-workflow", ) + assert pop_trace("workflow.execute", lambda v: v.args[0] == "nexus-workflow") assert pop_trace( "nexus.cancel_operation.InterceptedNexusService.intercepted_operation", - lambda v: v.token == "hello", ) # Confirm no unexpected traces From 6210805bd776a23f3682506b6c42c6273286a502 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 20 Nov 2025 11:38:21 -0800 Subject: [PATCH 05/10] PR feedback --- temporalio/contrib/opentelemetry.py | 262 ++++++++++------------------ temporalio/worker/__init__.py | 8 +- temporalio/worker/_interceptor.py | 82 +-------- temporalio/worker/_nexus.py | 72 +++++++- tests/worker/test_interceptor.py | 27 +-- 5 files changed, 192 insertions(+), 259 deletions(-) diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 858ab5378..eaec9b1c6 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -9,6 +9,7 @@ Any, Callable, Dict, + Generic, Iterator, Mapping, NoReturn, @@ -28,7 +29,6 @@ import opentelemetry.trace import opentelemetry.trace.propagation.tracecontext import opentelemetry.util.types -from nexusrpc.handler import StartOperationResultAsync, StartOperationResultSync from opentelemetry.context import Context from opentelemetry.trace import Status, StatusCode from typing_extensions import Protocol, TypeAlias, TypedDict @@ -60,6 +60,8 @@ _CarrierDict: TypeAlias = Dict[str, opentelemetry.propagators.textmap.CarrierValT] +_ContextT = TypeVar("_ContextT", bound=nexusrpc.handler.OperationContext) + class TracingInterceptor(temporalio.client.Interceptor, temporalio.worker.Interceptor): """Interceptor that supports client and worker OpenTelemetry span creation @@ -180,9 +182,10 @@ def _start_as_current_span( name: str, *, attributes: opentelemetry.util.types.Attributes, - input: Optional[_InputWithHeaders] = None, + input_with_headers: _InputWithHeaders | None = None, + input_with_ctx: _InputWithOperationContext | None = None, kind: opentelemetry.trace.SpanKind, - context: Optional[Context] = None, + context: Context | None = None, ) -> Iterator[None]: token = opentelemetry.context.attach(context) if context else None try: @@ -193,8 +196,19 @@ def _start_as_current_span( context=context, set_status_on_exception=False, ) as span: - if input: - input.headers = self._context_to_headers(input.headers) + if input_with_headers: + input_with_headers.headers = self._context_to_headers( + input_with_headers.headers + ) + if input_with_ctx: + carrier: _CarrierDict = {} + self.text_map_propagator.inject(carrier) + input_with_ctx.ctx = dataclasses.replace( + input_with_ctx.ctx, + headers=_carrier_to_nexus_headers( + carrier, input_with_ctx.ctx.headers + ), + ) try: yield None except Exception as exc: @@ -213,45 +227,6 @@ def _start_as_current_span( if token and context is opentelemetry.context.get_current(): opentelemetry.context.detach(token) - @contextmanager - def _start_as_current_span_nexus( - self, - name: str, - *, - attributes: opentelemetry.util.types.Attributes, - input_headers: Mapping[str, str], - kind: opentelemetry.trace.SpanKind, - context: Optional[Context] = None, - ) -> Iterator[_CarrierDict]: - token = opentelemetry.context.attach(context) if context else None - try: - with self.tracer.start_as_current_span( - name, - attributes=attributes, - kind=kind, - context=context, - set_status_on_exception=False, - ) as span: - new_headers: _CarrierDict = {**input_headers} - self.text_map_propagator.inject(new_headers) - try: - yield new_headers - except Exception as exc: - if ( - not isinstance(exc, ApplicationError) - or exc.category != ApplicationErrorCategory.BENIGN - ): - span.set_status( - Status( - status_code=StatusCode.ERROR, - description=f"{type(exc).__name__}: {exc}", - ) - ) - raise - finally: - if token and context is opentelemetry.context.get_current(): - opentelemetry.context.detach(token) - def _completed_workflow_span( self, params: _CompletedWorkflowSpanParams ) -> Optional[_CarrierDict]: @@ -311,7 +286,7 @@ async def start_workflow( with self.root._start_as_current_span( f"{prefix}:{input.workflow}", attributes={"temporalWorkflowID": input.id}, - input=input, + input_with_headers=input, kind=opentelemetry.trace.SpanKind.CLIENT, ): return await super().start_workflow(input) @@ -320,7 +295,7 @@ async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> A with self.root._start_as_current_span( f"QueryWorkflow:{input.query}", attributes={"temporalWorkflowID": input.id}, - input=input, + input_with_headers=input, kind=opentelemetry.trace.SpanKind.CLIENT, ): return await super().query_workflow(input) @@ -331,7 +306,7 @@ async def signal_workflow( with self.root._start_as_current_span( f"SignalWorkflow:{input.signal}", attributes={"temporalWorkflowID": input.id}, - input=input, + input_with_headers=input, kind=opentelemetry.trace.SpanKind.CLIENT, ): return await super().signal_workflow(input) @@ -342,7 +317,7 @@ async def start_workflow_update( with self.root._start_as_current_span( f"StartWorkflowUpdate:{input.update}", attributes={"temporalWorkflowID": input.id}, - input=input, + input_with_headers=input, kind=opentelemetry.trace.SpanKind.CLIENT, ): return await super().start_workflow_update(input) @@ -359,7 +334,7 @@ async def start_update_with_start_workflow( with self.root._start_as_current_span( f"StartUpdateWithStartWorkflow:{input.start_workflow_input.workflow}", attributes=attrs, - input=input.start_workflow_input, + input_with_headers=input.start_workflow_input, kind=opentelemetry.trace.SpanKind.CLIENT, ): otel_header = input.start_workflow_input.headers.get(self.root.header_key) @@ -398,33 +373,8 @@ async def execute_activity( return await super().execute_activity(input) -class _NexusTracing: - _ContextT = TypeVar("_ContextT", bound=nexusrpc.handler.OperationContext) - - # TODO(amazzeo): not sure what to do if value happens to be a list - # _CarrierDict represents http headers Map[str, List[str] | str] - # but nexus headers are just Map[str, str] - def _carrier_to_nexus_headers( - self, carrier: _CarrierDict, initial: Mapping[str, str] | None = None - ) -> Mapping[str, str]: - out = {**initial} if initial else {} - for k, v in carrier.items(): - if isinstance(v, list): - out[k] = ",".join(v) - else: - out[k] = v - return out - - def _operation_ctx_with_carrier( - self, ctx: _ContextT, carrier: _CarrierDict - ) -> _ContextT: - return dataclasses.replace( - ctx, headers=self._carrier_to_nexus_headers(carrier, ctx.headers) - ) - - class _TracingNexusOperationInboundInterceptor( - temporalio.worker.NexusOperationInboundInterceptor, _NexusTracing + temporalio.worker.NexusOperationInboundInterceptor ): def __init__( self, @@ -437,39 +387,48 @@ def __init__( def _context_from_nexus_headers(self, headers: Mapping[str, str]): return self._root.text_map_propagator.extract(headers) - async def start_operation( - self, input: temporalio.worker.NexusOperationStartInput - ) -> StartOperationResultSync[Any] | StartOperationResultAsync: - with self._root._start_as_current_span_nexus( + async def execute_nexus_operation_start( + self, input: temporalio.worker.ExecuteNexusOperationStartInput + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + with self._root._start_as_current_span( f"RunStartNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}", context=self._context_from_nexus_headers(input.ctx.headers), attributes={ "temporalNexusRequestId": input.ctx.request_id, }, - input_headers=input.ctx.headers, + input_with_ctx=input, kind=opentelemetry.trace.SpanKind.SERVER, - ) as new_headers: - input.ctx = self._operation_ctx_with_carrier(input.ctx, new_headers) - return await self._next.start_operation(input) + ): + return await self._next.execute_nexus_operation_start(input) - async def cancel_operation( - self, input: temporalio.worker.NexusOperationCancelInput + async def execute_nexus_operation_cancel( + self, input: temporalio.worker.ExecuteNexusOperationCancelInput ) -> None: - with self._root._start_as_current_span_nexus( + with self._root._start_as_current_span( f"RunCancelNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}", context=self._context_from_nexus_headers(input.ctx.headers), attributes={}, - input_headers=input.ctx.headers, + input_with_ctx=input, kind=opentelemetry.trace.SpanKind.SERVER, - ) as new_headers: - input.ctx = self._operation_ctx_with_carrier(input.ctx, new_headers) - return await self._next.cancel_operation(input) + ): + return await self._next.execute_nexus_operation_cancel(input) class _InputWithHeaders(Protocol): headers: Mapping[str, temporalio.api.common.v1.Payload] +class _InputWithStringHeaders(Protocol): + headers: Mapping[str, str] | None + + +class _InputWithOperationContext(Generic[_ContextT], Protocol): + ctx: _ContextT + + class _WorkflowExternFunctions(TypedDict): __temporal_opentelemetry_completed_span: Callable[ [_CompletedWorkflowSpanParams], Optional[_CarrierDict] @@ -536,7 +495,7 @@ async def execute_workflow( """ with self._top_level_workflow_context(success_is_complete=True): # Entrypoint of workflow should be `server` in OTel - self._completed_span_grpc( + self._completed_span( f"RunWorkflow:{temporalio.workflow.info().workflow_type}", kind=opentelemetry.trace.SpanKind.SERVER, ) @@ -555,7 +514,7 @@ async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> Non [link_context_header] )[0] with self._top_level_workflow_context(success_is_complete=False): - self._completed_span_grpc( + self._completed_span( f"HandleSignal:{input.signal}", link_context_carrier=link_context_carrier, kind=opentelemetry.trace.SpanKind.SERVER, @@ -587,7 +546,7 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: token = opentelemetry.context.attach(context) try: # This won't be created if there was no context header - self._completed_span_grpc( + self._completed_span( f"HandleQuery:{input.query}", link_context_carrier=link_context_carrier, # Create even on replay for queries @@ -616,7 +575,7 @@ def handle_update_validator( [link_context_header] )[0] with self._top_level_workflow_context(success_is_complete=False): - self._completed_span_grpc( + self._completed_span( f"ValidateUpdate:{input.update}", link_context_carrier=link_context_carrier, kind=opentelemetry.trace.SpanKind.SERVER, @@ -636,7 +595,7 @@ async def handle_update_handler( [link_context_header] )[0] with self._top_level_workflow_context(success_is_complete=False): - self._completed_span_grpc( + self._completed_span( f"HandleUpdate:{input.update}", link_context_carrier=link_context_carrier, kind=opentelemetry.trace.SpanKind.SERVER, @@ -685,7 +644,7 @@ def _top_level_workflow_context( finally: # Create a completed span before detaching context if exception or (success and success_is_complete): - self._completed_span_grpc( + self._completed_span( f"CompleteWorkflow:{temporalio.workflow.info().workflow_type}", exception=exception, kind=opentelemetry.trace.SpanKind.INTERNAL, @@ -717,67 +676,18 @@ def _context_carrier_to_headers( } return headers - def _completed_span_grpc( + def _completed_span( self, span_name: str, *, link_context_carrier: Optional[_CarrierDict] = None, add_to_outbound: Optional[_InputWithHeaders] = None, + add_to_outbound_str: Optional[_InputWithStringHeaders] = None, new_span_even_on_replay: bool = False, additional_attributes: opentelemetry.util.types.Attributes = None, exception: Optional[Exception] = None, kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL, ) -> None: - updated_context_carrier = self._completed_span( - span_name=span_name, - link_context_carrier=link_context_carrier, - new_span_even_on_replay=new_span_even_on_replay, - additional_attributes=additional_attributes, - exception=exception, - kind=kind, - ) - - # Add to outbound if needed - if add_to_outbound and updated_context_carrier: - add_to_outbound.headers = self._context_carrier_to_headers( - updated_context_carrier, add_to_outbound.headers - ) - - def _completed_span_nexus( - self, - span_name: str, - *, - outbound_headers: Mapping[str, str], - link_context_carrier: Optional[_CarrierDict] = None, - new_span_even_on_replay: bool = False, - additional_attributes: opentelemetry.util.types.Attributes = None, - exception: Optional[Exception] = None, - kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL, - ) -> _CarrierDict | None: - new_carrier = self._completed_span( - span_name=span_name, - link_context_carrier=link_context_carrier, - new_span_even_on_replay=new_span_even_on_replay, - additional_attributes=additional_attributes, - exception=exception, - kind=kind, - ) - - if new_carrier: - return {**outbound_headers, **new_carrier} - else: - return {**outbound_headers} - - def _completed_span( - self, - span_name: str, - *, - link_context_carrier: Optional[_CarrierDict] = None, - new_span_even_on_replay: bool = False, - additional_attributes: opentelemetry.util.types.Attributes = None, - exception: Optional[Exception] = None, - kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL, - ) -> _CarrierDict | None: # If we are replaying and they don't want a span on replay, no span if temporalio.workflow.unsafe.is_replaying() and not new_span_even_on_replay: return None @@ -787,15 +697,11 @@ def _completed_span( self.text_map_propagator.inject(new_context_carrier) # Invoke - # TODO(amazzeo): I think this try/except is necessary once non-workflow callers - # are added to Nexus - attributes: Dict[str, opentelemetry.util.types.AttributeValue] = {} - try: - info = temporalio.workflow.info() - attributes["temporalWorkflowID"] = info.workflow_id - attributes["temporalRunID"] = info.run_id - except temporalio.exceptions.TemporalError: - pass + info = temporalio.workflow.info() + attributes: Dict[str, opentelemetry.util.types.AttributeValue] = { + "temporalWorkflowID": info.workflow_id, + "temporalRunID": info.run_id, + } if additional_attributes: attributes.update(additional_attributes) @@ -816,7 +722,17 @@ def _completed_span( ) ) - return updated_context_carrier + # Add to outbound if needed + if updated_context_carrier: + if add_to_outbound: + add_to_outbound.headers = self._context_carrier_to_headers( + updated_context_carrier, add_to_outbound.headers + ) + + if add_to_outbound_str: + add_to_outbound_str.headers = _carrier_to_nexus_headers( + updated_context_carrier, add_to_outbound_str.headers + ) def _set_on_context( self, context: opentelemetry.context.Context @@ -825,7 +741,7 @@ def _set_on_context( class _TracingWorkflowOutboundInterceptor( - temporalio.worker.WorkflowOutboundInterceptor, _NexusTracing + temporalio.worker.WorkflowOutboundInterceptor ): def __init__( self, @@ -844,7 +760,7 @@ async def signal_child_workflow( self, input: temporalio.worker.SignalChildWorkflowInput ) -> None: # Create new span and put on outbound input - self.root._completed_span_grpc( + self.root._completed_span( f"SignalChildWorkflow:{input.signal}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.SERVER, @@ -855,7 +771,7 @@ async def signal_external_workflow( self, input: temporalio.worker.SignalExternalWorkflowInput ) -> None: # Create new span and put on outbound input - self.root._completed_span_grpc( + self.root._completed_span( f"SignalExternalWorkflow:{input.signal}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.CLIENT, @@ -866,7 +782,7 @@ def start_activity( self, input: temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: # Create new span and put on outbound input - self.root._completed_span_grpc( + self.root._completed_span( f"StartActivity:{input.activity}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.CLIENT, @@ -877,7 +793,7 @@ async def start_child_workflow( self, input: temporalio.worker.StartChildWorkflowInput ) -> temporalio.workflow.ChildWorkflowHandle: # Create new span and put on outbound input - self.root._completed_span_grpc( + self.root._completed_span( f"StartChildWorkflow:{input.workflow}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.CLIENT, @@ -888,7 +804,7 @@ def start_local_activity( self, input: temporalio.worker.StartLocalActivityInput ) -> temporalio.workflow.ActivityHandle: # Create new span and put on outbound input - self.root._completed_span_grpc( + self.root._completed_span( f"StartActivity:{input.activity}", add_to_outbound=input, kind=opentelemetry.trace.SpanKind.CLIENT, @@ -898,17 +814,27 @@ def start_local_activity( async def start_nexus_operation( self, input: temporalio.worker.StartNexusOperationInput[Any, Any] ) -> temporalio.workflow.NexusOperationHandle[Any]: - new_carrier = self.root._completed_span_nexus( + self.root._completed_span( f"StartNexusOperation:{input.service}/{input.operation_name}", kind=opentelemetry.trace.SpanKind.CLIENT, - outbound_headers=input.headers if input.headers else {}, + add_to_outbound_str=input, ) - if new_carrier: - input.headers = self._carrier_to_nexus_headers(new_carrier, input.headers) return await super().start_nexus_operation(input) +def _carrier_to_nexus_headers( + carrier: _CarrierDict, initial: Mapping[str, str] | None = None +) -> Mapping[str, str]: + out = {**initial} if initial else {} + for k, v in carrier.items(): + if isinstance(v, list): + out[k] = ",".join(v) + else: + out[k] = v + return out + + class workflow: """Contains static methods that are safe to call from within a workflow. @@ -944,6 +870,6 @@ def completed_span( """ interceptor = TracingWorkflowInboundInterceptor._from_context() if interceptor: - interceptor._completed_span_grpc( + interceptor._completed_span( name, additional_attributes=attributes, exception=exception ) diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index 9d0a3a47e..8388be24d 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -6,14 +6,14 @@ ActivityOutboundInterceptor, ContinueAsNewInput, ExecuteActivityInput, + ExecuteNexusOperationCancelInput, + ExecuteNexusOperationStartInput, ExecuteWorkflowInput, HandleQueryInput, HandleSignalInput, HandleUpdateInput, Interceptor, - NexusOperationCancelInput, NexusOperationInboundInterceptor, - NexusOperationStartInput, SignalChildWorkflowInput, SignalExternalWorkflowInput, StartActivityInput, @@ -99,8 +99,8 @@ "StartLocalActivityInput", "StartNexusOperationInput", "WorkflowInterceptorClassInput", - "NexusOperationStartInput", - "NexusOperationCancelInput", + "ExecuteNexusOperationStartInput", + "ExecuteNexusOperationCancelInput", # Advanced activity classes "SharedStateManager", "SharedHeartbeatSender", diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index e35cee454..bcf76ede9 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -6,7 +6,6 @@ from collections.abc import Callable, Mapping, MutableMapping from dataclasses import dataclass from datetime import timedelta -from functools import reduce from typing import ( Any, Awaitable, @@ -17,7 +16,6 @@ Sequence, Type, Union, - cast, ) import nexusrpc.handler @@ -489,7 +487,7 @@ async def start_nexus_operation( @dataclass -class NexusOperationStartInput: +class ExecuteNexusOperationStartInput: """Input for :pyt:meth:`NexusOperationInboundInterceptor.start_operation""" ctx: nexusrpc.handler.StartOperationContext @@ -497,7 +495,7 @@ class NexusOperationStartInput: @dataclass -class NexusOperationCancelInput: +class ExecuteNexusOperationCancelInput: """Input for :pyt:meth:`NexusOperationInboundInterceptor.cancel_operation""" ctx: nexusrpc.handler.CancelOperationContext @@ -519,79 +517,17 @@ def __init__(self, next: NexusOperationInboundInterceptor) -> None: """ self.next = next - async def start_operation( - self, input: NexusOperationStartInput + async def execute_nexus_operation_start( + self, input: ExecuteNexusOperationStartInput ) -> ( nexusrpc.handler.StartOperationResultSync[Any] | nexusrpc.handler.StartOperationResultAsync ): """Called to start a Nexus operation""" - return await self.next.start_operation(input) + return await self.next.execute_nexus_operation_start(input) - async def cancel_operation(self, input: NexusOperationCancelInput) -> None: - """Called to cancel an in progress Nexus operation""" - return await self.next.cancel_operation(input) - - -class _NexusOperationHandlerForInterceptor( - nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any] -): - def __init__(self, next_interceptor: NexusOperationInboundInterceptor): - self._next_interceptor = next_interceptor - - async def start( - self, ctx: nexusrpc.handler.StartOperationContext, input: Any - ) -> ( - nexusrpc.handler.StartOperationResultSync[Any] - | nexusrpc.handler.StartOperationResultAsync - ): - return await self._next_interceptor.start_operation( - NexusOperationStartInput(ctx, input) - ) - - async def cancel( - self, ctx: nexusrpc.handler.CancelOperationContext, token: str + async def execute_nexus_operation_cancel( + self, input: ExecuteNexusOperationCancelInput ) -> None: - return await self._next_interceptor.cancel_operation( - NexusOperationCancelInput(ctx, token) - ) - - -class _NexusOperationInboundInterceptorImpl(NexusOperationInboundInterceptor): - def __init__( - self, - handler: nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any], - ): - self._handler = handler - - async def start_operation( - self, input: NexusOperationStartInput - ) -> ( - nexusrpc.handler.StartOperationResultSync[Any] - | nexusrpc.handler.StartOperationResultAsync - ): - return await self._handler.start(input.ctx, input.input) - - async def cancel_operation(self, input: NexusOperationCancelInput) -> None: - return await self._handler.cancel(input.ctx, input.token) - - -class _NexusMiddlewareForInterceptors(nexusrpc.handler.OperationHandlerMiddleware): - def __init__(self, interceptors: Sequence[Interceptor]) -> None: - self._interceptors = interceptors - - def intercept( - self, - ctx: nexusrpc.handler.OperationContext, - next: nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any], - ) -> nexusrpc.handler.MiddlewareSafeOperationHandler[Any, Any]: - inbound = reduce( - lambda impl, _next: _next.intercept_nexus_operation(impl), - reversed(self._interceptors), - cast( - NexusOperationInboundInterceptor, - _NexusOperationInboundInterceptorImpl(next), - ), - ) - - return _NexusOperationHandlerForInterceptor(inbound) + """Called to cancel an in progress Nexus operation""" + return await self.next.execute_nexus_operation_cancel(input) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 36ec74a6b..ee3e33032 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -7,6 +7,7 @@ import json import threading from dataclasses import dataclass +from functools import reduce from typing import ( Any, Callable, @@ -16,6 +17,7 @@ Sequence, Type, Union, + cast, ) import google.protobuf.json_format @@ -40,7 +42,12 @@ from temporalio.nexus import Info, logger from temporalio.service import RPCError, RPCStatusCode -from ._interceptor import Interceptor, _NexusMiddlewareForInterceptors +from ._interceptor import ( + ExecuteNexusOperationCancelInput, + ExecuteNexusOperationStartInput, + Interceptor, + NexusOperationInboundInterceptor, +) _TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure" @@ -598,3 +605,66 @@ def cancel(self, reason: str) -> bool: self._thread_evt.set() self._async_evt.set() return True + + +class _NexusOperationHandlerForInterceptor( + nexusrpc.handler.MiddlewareSafeOperationHandler +): + def __init__(self, next_interceptor: NexusOperationInboundInterceptor): + self._next_interceptor = next_interceptor + + async def start( + self, ctx: nexusrpc.handler.StartOperationContext, input: Any + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + return await self._next_interceptor.execute_nexus_operation_start( + ExecuteNexusOperationStartInput(ctx, input) + ) + + async def cancel( + self, ctx: nexusrpc.handler.CancelOperationContext, token: str + ) -> None: + return await self._next_interceptor.execute_nexus_operation_cancel( + ExecuteNexusOperationCancelInput(ctx, token) + ) + + +class _NexusOperationInboundInterceptorImpl(NexusOperationInboundInterceptor): + def __init__(self, handler: nexusrpc.handler.MiddlewareSafeOperationHandler): + self._handler = handler + + async def execute_nexus_operation_start( + self, input: ExecuteNexusOperationStartInput + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + return await self._handler.start(input.ctx, input.input) + + async def execute_nexus_operation_cancel( + self, input: ExecuteNexusOperationCancelInput + ) -> None: + return await self._handler.cancel(input.ctx, input.token) + + +class _NexusMiddlewareForInterceptors(nexusrpc.handler.OperationHandlerMiddleware): + def __init__(self, interceptors: Sequence[Interceptor]) -> None: + self._interceptors = interceptors + + def intercept( + self, + ctx: nexusrpc.handler.OperationContext, + next: nexusrpc.handler.MiddlewareSafeOperationHandler, + ) -> nexusrpc.handler.MiddlewareSafeOperationHandler: + inbound = reduce( + lambda impl, _next: _next.intercept_nexus_operation(impl), + reversed(self._interceptors), + cast( + NexusOperationInboundInterceptor, + _NexusOperationInboundInterceptorImpl(next), + ), + ) + + return _NexusOperationHandlerForInterceptor(inbound) diff --git a/tests/worker/test_interceptor.py b/tests/worker/test_interceptor.py index 5d0a95f2a..fea248022 100644 --- a/tests/worker/test_interceptor.py +++ b/tests/worker/test_interceptor.py @@ -5,10 +5,6 @@ import nexusrpc import pytest -from nexusrpc.handler._common import ( - StartOperationResultAsync, - StartOperationResultSync, -) from temporalio import activity, nexus, workflow from temporalio.client import Client, WorkflowUpdateFailedError @@ -19,25 +15,25 @@ ActivityOutboundInterceptor, ContinueAsNewInput, ExecuteActivityInput, + ExecuteNexusOperationCancelInput, + ExecuteNexusOperationStartInput, ExecuteWorkflowInput, HandleQueryInput, HandleSignalInput, HandleUpdateInput, Interceptor, - NexusOperationCancelInput, NexusOperationInboundInterceptor, - NexusOperationStartInput, SignalChildWorkflowInput, SignalExternalWorkflowInput, StartActivityInput, StartChildWorkflowInput, StartLocalActivityInput, + StartNexusOperationInput, Worker, WorkflowInboundInterceptor, WorkflowInterceptorClassInput, WorkflowOutboundInterceptor, ) -from temporalio.worker._interceptor import StartNexusOperationInput from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name interceptor_traces: List[Tuple[str, Any]] = [] @@ -147,19 +143,24 @@ async def start_nexus_operation( class TracingNexusInboundInterceptor(NexusOperationInboundInterceptor): - async def start_operation( - self, input: NexusOperationStartInput - ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + async def execute_nexus_operation_start( + self, input: ExecuteNexusOperationStartInput + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): interceptor_traces.append( (f"nexus.start_operation.{input.ctx.service}.{input.ctx.operation}", input) ) - return await super().start_operation(input) + return await super().execute_nexus_operation_start(input) - async def cancel_operation(self, input: NexusOperationCancelInput) -> None: + async def execute_nexus_operation_cancel( + self, input: ExecuteNexusOperationCancelInput + ) -> None: interceptor_traces.append( (f"nexus.cancel_operation.{input.ctx.service}.{input.ctx.operation}", input) ) - return await super().cancel_operation(input) + return await super().execute_nexus_operation_cancel(input) @workflow.defn From d86a4b311feee4088e2d212413878ef4ee838053 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 20 Nov 2025 11:40:55 -0800 Subject: [PATCH 06/10] remove OTel attribute that other SDKs are not sending --- temporalio/contrib/opentelemetry.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index eaec9b1c6..10e7f6fe3 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -396,9 +396,7 @@ async def execute_nexus_operation_start( with self._root._start_as_current_span( f"RunStartNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}", context=self._context_from_nexus_headers(input.ctx.headers), - attributes={ - "temporalNexusRequestId": input.ctx.request_id, - }, + attributes={}, input_with_ctx=input, kind=opentelemetry.trace.SpanKind.SERVER, ): From 655fab75fa4b9bb9ca6d5a6919a1dce1090186ea Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 25 Nov 2025 09:45:18 -0800 Subject: [PATCH 07/10] Use uv to update to the head of target nexus-rpc branch. --- pyproject.toml | 5 ++++- temporalio/worker/_nexus.py | 2 +- uv.lock | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b7c0b77d9..db5f7e209 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = "MIT" license-files = ["LICENSE"] keywords = ["temporal", "workflow"] dependencies = [ - "nexus-rpc @ git+https://github.com/nexus-rpc/sdk-python@interceptors", + "nexus-rpc", "protobuf>=3.20,<7.0.0", "python-dateutil>=2.8.2,<3 ; python_version < '3.11'", "types-protobuf>=3.20", @@ -246,3 +246,6 @@ exclude = ["temporalio/bridge/target/**/*"] [tool.uv] # Prevent uv commands from building the package by default package = false + +[tool.uv.sources] +nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python", rev = "interceptors" } diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index ee3e33032..54dd460ec 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -83,7 +83,7 @@ def __init__( middleware = _NexusMiddlewareForInterceptors(interceptors) - self._handler = Handler(service_handlers, executor, interceptors=[middleware]) + self._handler = Handler(service_handlers, executor, middleware=[middleware]) self._data_converter = data_converter # TODO(nexus-preview): metric_meter self._metric_meter = metric_meter diff --git a/uv.lock b/uv.lock index 09e9d15ad..fa1d25cb1 100644 --- a/uv.lock +++ b/uv.lock @@ -1761,7 +1761,7 @@ wheels = [ [[package]] name = "nexus-rpc" version = "1.2.0" -source = { git = "https://github.com/nexus-rpc/sdk-python?rev=interceptors#e8806579a7050fc076bb14861aeec7208d521534" } +source = { git = "https://github.com/nexus-rpc/sdk-python?rev=interceptors#875e2eafc1dcbb39eb2c31c17d89741c58ddb02d" } dependencies = [ { name = "typing-extensions" }, ] From daa5e54907405196f38617db3725bf1e9221eb3a Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 8 Dec 2025 15:07:50 -0800 Subject: [PATCH 08/10] use nexus-rpc 1.3.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 88623182a..9c4392c1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = "MIT" license-files = ["LICENSE"] keywords = ["temporal", "workflow"] dependencies = [ - "nexus-rpc", + "nexus-rpc==1.3.0", "protobuf>=3.20,<7.0.0", "python-dateutil>=2.8.2,<3 ; python_version < '3.11'", "types-protobuf>=3.20", From c2d038aea6adc254824b71b64b33f9936a3d3714 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 8 Dec 2025 15:28:32 -0800 Subject: [PATCH 09/10] fix reference to nexus-rpc 1.3.0 in lock --- pyproject.toml | 3 --- uv.lock | 10 +++++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 47aa1d6ac..d922b1134 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -225,6 +225,3 @@ exclude = ["temporalio/bridge/target/**/*"] [tool.uv] # Prevent uv commands from building the package by default package = false - -[tool.uv.sources] -nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python", rev = "interceptors" } diff --git a/uv.lock b/uv.lock index 68afb7b14..600e3e8fc 100644 --- a/uv.lock +++ b/uv.lock @@ -1760,11 +1760,15 @@ wheels = [ [[package]] name = "nexus-rpc" -version = "1.2.0" -source = { git = "https://github.com/nexus-rpc/sdk-python?rev=interceptors#875e2eafc1dcbb39eb2c31c17d89741c58ddb02d" } +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/2e/f2/d54f5c03d8f4672ccc0875787a385f53dcb61f98a8ae594b5620e85b9cb3/nexus_rpc-1.3.0.tar.gz", hash = "sha256:e56d3b57b60d707ce7a72f83f23f106b86eca1043aa658e44582ab5ff30ab9ad", size = 75650, upload-time = "2025-12-08T22:59:13.002Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d6/74/0afd841de3199c148146c1d43b4bfb5605b2f1dc4c9a9087fe395091ea5a/nexus_rpc-1.3.0-py3-none-any.whl", hash = "sha256:aee0707b4861b22d8124ecb3f27d62dafbe8777dc50c66c91e49c006f971b92d", size = 28873, upload-time = "2025-12-08T22:59:12.024Z" }, +] [[package]] name = "nh3" @@ -3017,7 +3021,7 @@ dev = [ requires-dist = [ { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, - { name = "nexus-rpc", git = "https://github.com/nexus-rpc/sdk-python?rev=interceptors" }, + { name = "nexus-rpc", specifier = "==1.3.0" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.3,<0.5" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, From 1826914060840015aaac774853ff3e4d3d538773 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 9 Dec 2025 14:30:02 -0800 Subject: [PATCH 10/10] fix indentation in docstring --- temporalio/worker/_interceptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 228852722..d3b838679 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -73,7 +73,7 @@ def intercept_nexus_operation( Args: next: The underlying inbound this interceptor - should delegate to. + should delegate to. Returns: The new interceptor that should be used for the Nexus operation.