Skip to content
44 changes: 32 additions & 12 deletions src/nexusrpc/handler/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Generic, Optional

from nexusrpc._common import Link, OutputT
Expand All @@ -21,17 +22,23 @@ class OperationTaskCancellation(ABC):

@abstractmethod
def is_cancelled(self) -> bool:
"""Return True if the associated task has been cancelled."""
"""
Return True if the associated task has been cancelled.
"""
raise NotImplementedError

@abstractmethod
def cancellation_reason(self) -> Optional[str]:
"""Provide additional context for the cancellation, if available."""
"""
Provide additional context for the cancellation, if available.
"""
raise NotImplementedError

@abstractmethod
def wait_until_cancelled_sync(self, timeout: Optional[float] = None) -> bool:
"""Block until cancellation occurs or the optional timeout elapses. Nexus worker implementations may return `True` for :py:attr:`is_cancelled` before this method returns and therefore may cause a race condition if both are used in tandem."""
"""
Block until cancellation occurs or the optional timeout elapses. Nexus worker implementations may return `True` for :py:attr:`is_cancelled` before this method returns and therefore may cause a race condition if both are used in tandem.
"""
raise NotImplementedError

@abstractmethod
Expand All @@ -40,11 +47,13 @@ async def wait_until_cancelled(self) -> None:
raise NotImplementedError


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class OperationContext(ABC):
"""Context for the execution of the requested operation method.
"""
Context for the execution of the requested operation method.

Includes information from the request."""
Includes information from the request.
"""

def __new__(cls, *args: Any, **kwargs: Any):
if cls is OperationContext:
Expand All @@ -67,17 +76,26 @@ def __new__(cls, *args: Any, **kwargs: Any):
"""
Optional header fields sent by the caller.
"""

task_cancellation: OperationTaskCancellation
"""
Task cancellation information indicating that a running task should be interrupted. This is distinct from operation cancellation.
"""

request_deadline: Optional[datetime] = None
"""
The deadline for the operation handler method. Note that this is the time by which the
current _request_ should complete, not the _operation_'s deadline.
"""

@dataclass(frozen=True)

@dataclass(frozen=True, kw_only=True)
class StartOperationContext(OperationContext):
"""Context for the start method.
"""
Context for the start method.

Includes information from the request."""
Includes information from the request.
"""

request_id: str
"""
Expand Down Expand Up @@ -112,11 +130,13 @@ class StartOperationContext(OperationContext):
"""


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class CancelOperationContext(OperationContext):
"""Context for the cancel method.
"""
Context for the cancel method.

Includes information from the request."""
Includes information from the request.
"""


@dataclass(frozen=True)
Expand Down