diff --git a/temporalio/activity.py b/temporalio/activity.py index ff46bdea8..6926161a5 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -104,16 +104,25 @@ class Info: heartbeat_details: Sequence[Any] heartbeat_timeout: timedelta | None is_local: bool + namespace: str schedule_to_close_timeout: timedelta | None scheduled_time: datetime start_to_close_timeout: timedelta | None started_time: datetime task_queue: str task_token: bytes - workflow_id: str - workflow_namespace: str - workflow_run_id: str - workflow_type: str + workflow_id: str | None + """ID of the workflow. None if the activity was not started by a workflow.""" + workflow_namespace: str | None + """Namespace of the workflow. None if the activity was not started by a workflow. + + .. deprecated:: + Use :py:attr:`namespace` instead. + """ + workflow_run_id: str | None + """Run ID of the workflow. None if the activity was not started by a workflow.""" + workflow_type: str | None + """Type of the workflow. None if the activity was not started by a workflow.""" priority: temporalio.common.Priority retry_policy: temporalio.common.RetryPolicy | None """The retry policy of this activity. @@ -122,6 +131,14 @@ class Info: If the value is None, it means the server didn't send information about retry policy (e.g. due to old server version), but it may still be defined server-side.""" + activity_run_id: str | None = None + """Run ID of this activity. None for workflow activities.""" + + @property + def in_workflow(self) -> bool: + """Was this activity started by a workflow?""" + return self.workflow_id is not None + # TODO(cretz): Consider putting identity on here for "worker_id" for logger? 
def _logger_details(self) -> Mapping[str, Any]: @@ -129,7 +146,7 @@ def _logger_details(self) -> Mapping[str, Any]: "activity_id": self.activity_id, "activity_type": self.activity_type, "attempt": self.attempt, - "namespace": self.workflow_namespace, + "namespace": self.namespace, "task_queue": self.task_queue, "workflow_id": self.workflow_id, "workflow_run_id": self.workflow_run_id, @@ -238,7 +255,7 @@ def metric_meter(self) -> temporalio.common.MetricMeter: info = self.info() self._metric_meter = self.runtime_metric_meter.with_additional_attributes( { - "namespace": info.workflow_namespace, + "namespace": info.namespace, "task_queue": info.task_queue, "activity_type": info.activity_type, } @@ -577,6 +594,20 @@ def must_from_callable(fn: Callable) -> _Definition: f"Activity {fn_name} missing attributes, was it decorated with @activity.defn?" ) + @classmethod + def get_name_and_result_type( + cls, name_or_run_fn: str | Callable[..., Any] + ) -> tuple[str, type | None]: + if isinstance(name_or_run_fn, str): + return name_or_run_fn, None + elif callable(name_or_run_fn): + defn = cls.must_from_callable(name_or_run_fn) + if not defn.name: + raise ValueError(f"Activity {name_or_run_fn} definition has no name") + return defn.name, defn.ret_type + else: + raise TypeError("Activity must be a string or callable") # type:ignore[reportUnreachable] + @staticmethod def _apply_to_callable( fn: Callable, diff --git a/temporalio/client.py b/temporalio/client.py index b4d5af0fa..5c18424f2 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -39,6 +39,8 @@ from google.protobuf.internal.containers import MessageMap from typing_extensions import Required, Self, TypedDict +import temporalio.activity +import temporalio.api.activity.v1 import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.api.errordetails.v1 @@ -60,6 +62,7 @@ import temporalio.workflow from temporalio.activity import ActivityCancellationDetails from temporalio.converter import ( + 
ActivitySerializationContext, DataConverter, SerializationContext, WithSerializationContext, @@ -79,6 +82,10 @@ from .common import HeaderCodecBehavior from .types import ( AnyType, + CallableAsyncNoParam, + CallableAsyncSingleParam, + CallableSyncNoParam, + CallableSyncSingleParam, LocalReturnType, MethodAsyncNoParam, MethodAsyncSingleParam, @@ -1266,6 +1273,1258 @@ async def count_workflows( ) ) + # async no-param + @overload + async def start_activity( + self, + activity: CallableAsyncNoParam[ReturnType], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # sync no-param + @overload + async def start_activity( + self, + activity: CallableSyncNoParam[ReturnType], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... + + # async single-param + @overload + async def start_activity( + self, + activity: CallableAsyncSingleParam[ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # sync single-param + @overload + async def start_activity( + self, + activity: CallableSyncSingleParam[ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... + + # async multi-param + @overload + async def start_activity( + self, + activity: Callable[..., Awaitable[ReturnType]], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # sync multi-param + @overload + async def start_activity( + self, + activity: Callable[..., ReturnType], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... + + # string name + @overload + async def start_activity( + self, + activity: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> 
ActivityHandle[Any]: ... + + async def start_activity( + self, + activity: ( + str | Callable[..., Awaitable[ReturnType]] | Callable[..., ReturnType] + ), + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + # Either schedule_to_close_timeout or start_to_close_timeout must be present + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: + """Start an activity and return its handle. + + .. warning:: + This API is experimental. + + Args: + activity: String name or callable activity function to execute. + arg: Single argument to the activity. + args: Multiple arguments to the activity. Cannot be set if arg is. + id: Unique identifier for the activity. Required. + task_queue: Task queue to send the activity to. + result_type: For string name activities, optional type to deserialize result into. + schedule_to_close_timeout: Total time allowed for the activity from schedule to completion. + schedule_to_start_timeout: Time allowed for the activity to sit in the task queue. + start_to_close_timeout: Time allowed for a single execution attempt. + heartbeat_timeout: Time between heartbeats before the activity is considered failed. 
+ id_reuse_policy: How to handle reusing activity IDs from closed activities. + Default is ALLOW_DUPLICATE. + id_conflict_policy: How to handle activity ID conflicts with running activities. + Default is FAIL. + retry_policy: Retry policy for the activity. + search_attributes: Search attributes for the activity. + summary: A single-line fixed summary for this activity that may appear + in the UI/CLI. This can be in single-line Temporal markdown format. + priority: Priority of the activity execution. + rpc_metadata: Headers used on the RPC call. + rpc_timeout: Optional RPC deadline to set for the RPC call. + + Returns: + A handle to the started activity. + """ + name, result_type_from_type_annotation = ( + temporalio.activity._Definition.get_name_and_result_type(activity) + ) + return await self._impl.start_activity( + StartActivityInput( + activity_type=name, + args=temporalio.common._arg_or_args(arg, args), + id=id, + task_queue=task_queue, + result_type=result_type or result_type_from_type_annotation, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + headers={}, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + priority=priority, + ) + ) + + # async no-param + @overload + async def execute_activity( + self, + activity: CallableAsyncNoParam[ReturnType], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: 
temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... + + # sync no-param + @overload + async def execute_activity( + self, + activity: CallableSyncNoParam[ReturnType], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
+ + # async single-param + @overload + async def execute_activity( + self, + activity: CallableAsyncSingleParam[ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... + + # sync single-param + @overload + async def execute_activity( + self, + activity: CallableSyncSingleParam[ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
+ + # async multi-param + @overload + async def execute_activity( + self, + activity: Callable[..., Awaitable[ReturnType]], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... + + # sync multi-param + @overload + async def execute_activity( + self, + activity: Callable[..., ReturnType], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
+ + # string name + @overload + async def execute_activity( + self, + activity: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> Any: ... 
+ + async def execute_activity( + self, + activity: ( + str | Callable[..., Awaitable[ReturnType]] | Callable[..., ReturnType] + ), + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + # Either schedule_to_close_timeout or start_to_close_timeout must be present + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: + """Start an activity, wait for it to complete, and return its result. + + .. warning:: + This API is experimental. + + This is a convenience method that combines :py:meth:`start_activity` and + :py:meth:`ActivityHandle.result`. + + Returns: + The result of the activity. + + Raises: + ActivityFailedError: If the activity completed with a failure. 
+ """ + handle: ActivityHandle[ReturnType] = await self.start_activity( + cast(Any, activity), + arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + return await handle.result() + + # async no-param + @overload + async def start_activity_class( + self, + activity: type[CallableAsyncNoParam[ReturnType]], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # sync no-param + @overload + async def start_activity_class( + self, + activity: type[CallableSyncNoParam[ReturnType]], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... + + # async single-param + @overload + async def start_activity_class( + self, + activity: type[CallableAsyncSingleParam[ParamType, ReturnType]], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # sync single-param + @overload + async def start_activity_class( + self, + activity: type[CallableSyncSingleParam[ParamType, ReturnType]], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # async multi-param + @overload + async def start_activity_class( + self, + activity: type[Callable[..., Awaitable[ReturnType]]], # type: ignore[reportInvalidTypeForm] + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # sync multi-param + @overload + async def start_activity_class( + self, + activity: type[Callable[..., ReturnType]], # type: ignore[reportInvalidTypeForm] + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + async def start_activity_class( + self, + activity: type[Callable], # type: ignore[reportInvalidTypeForm] + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[Any]: + """Start an activity from a callable class. + + .. warning:: + This API is experimental. + + See :py:meth:`start_activity` for parameter and return details. 
+ """ + return await self.start_activity( + cast(Any, activity), + arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + + # async no-param + @overload + async def execute_activity_class( + self, + activity: type[CallableAsyncNoParam[ReturnType]], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
+ + # sync no-param + @overload + async def execute_activity_class( + self, + activity: type[CallableSyncNoParam[ReturnType]], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... + + # async single-param + @overload + async def execute_activity_class( + self, + activity: type[CallableAsyncSingleParam[ParamType, ReturnType]], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
+ + # sync single-param + @overload + async def execute_activity_class( + self, + activity: type[CallableSyncSingleParam[ParamType, ReturnType]], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... + + # async multi-param + @overload + async def execute_activity_class( + self, + activity: type[Callable[..., Awaitable[ReturnType]]], # type: ignore[reportInvalidTypeForm] + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) 
-> ReturnType: ... + + # sync multi-param + @overload + async def execute_activity_class( + self, + activity: type[Callable[..., ReturnType]], # type: ignore[reportInvalidTypeForm] + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
+
+    async def execute_activity_class(
+        self,
+        activity: type[Callable],  # type: ignore[reportInvalidTypeForm]
+        arg: Any = temporalio.common._arg_unset,
+        *,
+        args: Sequence[Any] = [],
+        id: str,
+        task_queue: str,
+        result_type: type | None = None,
+        schedule_to_close_timeout: timedelta | None = None,
+        schedule_to_start_timeout: timedelta | None = None,
+        start_to_close_timeout: timedelta | None = None,
+        heartbeat_timeout: timedelta | None = None,
+        id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE,
+        id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL,
+        retry_policy: temporalio.common.RetryPolicy | None = None,
+        search_attributes: temporalio.common.TypedSearchAttributes | None = None,
+        summary: str | None = None,
+        priority: temporalio.common.Priority = temporalio.common.Priority.default,
+        rpc_metadata: Mapping[str, str | bytes] = {},
+        rpc_timeout: timedelta | None = None,
+    ) -> Any:
+        """Start an activity from a callable class and wait for completion.
+
+        .. warning::
+            This API is experimental.
+
+        This is a shortcut for ``await (await client.start_activity_class(...)).result()``.
+ """ + return await self.execute_activity( + cast(Any, activity), + arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + + # async no-param + @overload + async def start_activity_method( + self, + activity: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # async single-param + @overload + async def start_activity_method( + self, + activity: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # async multi-param + @overload + async def start_activity_method( + self, + activity: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + # sync multi-param + @overload + async def start_activity_method( + self, + activity: Callable[Concatenate[SelfType, MultiParamSpec], ReturnType], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[ReturnType]: ... 
+ + async def start_activity_method( + self, + activity: Callable, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str, + result_type: type | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityHandle[Any]: + """Start an activity from a method. + + .. warning:: + This API is experimental. + + See :py:meth:`start_activity` for parameter and return details. 
+ """ + return await self.start_activity( + cast(Any, activity), + arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + + # async no-param + @overload + async def execute_activity_method( + self, + activity: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
+ + # async single-param + @overload + async def execute_activity_method( + self, + activity: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... + + # async multi-param + @overload + async def execute_activity_method( + self, + activity: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + 
) -> ReturnType: ... + + # sync multi-param + @overload + async def execute_activity_method( + self, + activity: Callable[Concatenate[SelfType, MultiParamSpec], ReturnType], + *, + args: Sequence[Any], + id: str, + task_queue: str, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: ... 
+
+    async def execute_activity_method(
+        self,
+        activity: Callable,
+        arg: Any = temporalio.common._arg_unset,
+        *,
+        args: Sequence[Any] = [],
+        id: str,
+        task_queue: str,
+        result_type: type | None = None,
+        schedule_to_close_timeout: timedelta | None = None,
+        schedule_to_start_timeout: timedelta | None = None,
+        start_to_close_timeout: timedelta | None = None,
+        heartbeat_timeout: timedelta | None = None,
+        id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE,
+        id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL,
+        retry_policy: temporalio.common.RetryPolicy | None = None,
+        search_attributes: temporalio.common.TypedSearchAttributes | None = None,
+        summary: str | None = None,
+        priority: temporalio.common.Priority = temporalio.common.Priority.default,
+        rpc_metadata: Mapping[str, str | bytes] = {},
+        rpc_timeout: timedelta | None = None,
+    ) -> Any:
+        """Start an activity from a method and wait for completion.
+
+        .. warning::
+            This API is experimental.
+
+        This is a shortcut for ``await (await client.start_activity_method(...)).result()``.
+ """ + return await self.execute_activity( + cast(Any, activity), + arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + + def list_activities( + self, + query: str | None = None, + *, + limit: int | None = None, + page_size: int = 1000, + next_page_token: bytes | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityExecutionAsyncIterator: + """List activities not started by a workflow. + + .. warning:: + This API is experimental. + + This does not make a request until the first iteration is attempted. + Therefore any errors will not occur until then. + + Args: + query: A Temporal visibility list filter for activities. + limit: Maximum number of activities to return. If unset, all + activities are returned. Only applies if using the + returned :py:class:`ActivityExecutionAsyncIterator` + as an async iterator. + page_size: Maximum number of results for each page. + next_page_token: A previously obtained next page token if doing + pagination. Usually not needed as the iterator automatically + starts from the beginning. + rpc_metadata: Headers used on each RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. + + Returns: + An async iterator that can be used with ``async for``. 
+ """ + return self._impl.list_activities( + ListActivitiesInput( + query=query, + page_size=page_size, + next_page_token=next_page_token, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + limit=limit, + ) + ) + + async def count_activities( + self, + query: str | None = None, + *, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityExecutionCount: + """Count activities not started by a workflow. + + .. warning:: + This API is experimental. + + Args: + query: A Temporal visibility filter for activities. + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for the RPC call. + + Returns: + Count of activities. + """ + return await self._impl.count_activities( + CountActivitiesInput( + query=query, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout + ) + ) + + @overload + def get_activity_handle( + self, + activity_id: str, + *, + activity_run_id: str | None = None, + ) -> ActivityHandle[Any]: ... + + @overload + def get_activity_handle( + self, + activity_id: str, + *, + result_type: type[ReturnType], + activity_run_id: str | None = None, + ) -> ActivityHandle[ReturnType]: ... + + def get_activity_handle( + self, + activity_id: str, + *, + result_type: type | None = None, + activity_run_id: str | None = None, + ) -> ActivityHandle[Any]: + """Get a handle to an existing activity, as the caller of that activity. + + The activity must not have been started by a workflow. + + .. warning:: + This API is experimental. + + To get a handle to an activity execution that you control for manual completion and + heartbeating, see :py:meth:`Client.get_async_activity_handle`. + + Args: + activity_id: The activity ID. + result_type: The result type to deserialize into. + activity_run_id: The activity run ID. If not provided, targets the + latest run. + + Returns: + A handle to the activity. 
+ """ + return ActivityHandle( + self, + activity_id, + activity_run_id=activity_run_id, + result_type=result_type, + ) + + @overload + def get_async_activity_handle( + self, *, activity_id: str, run_id: str | None = None + ) -> AsyncActivityHandle: + pass + @overload def get_async_activity_handle( self, *, workflow_id: str, run_id: str | None, activity_id: str @@ -1276,6 +2535,7 @@ def get_async_activity_handle( def get_async_activity_handle(self, *, task_token: bytes) -> AsyncActivityHandle: pass + # TODO(dan): add typed API get_async_activity_handle_for? def get_async_activity_handle( self, *, @@ -1284,22 +2544,31 @@ def get_async_activity_handle( activity_id: str | None = None, task_token: bytes | None = None, ) -> AsyncActivityHandle: - """Get an async activity handle. + """Get a handle to an activity execution that you control, for manual completion and heartbeating. + + To get a handle to an activity execution as the caller of that activity, see + :py:meth:`Client.get_activity_handle`. + + This function may be used to get a handle to an activity started by a client, or + an activity started by a workflow. + + To get a handle to an activity started by a workflow, use one of the following two calls: + - Supply ``workflow_id``, ``run_id``, and ``activity_id`` + - Supply the activity ``task_token`` alone + + To get a handle to an activity not started by a workflow, supply ``activity_id`` and + ``run_id`` - Either the workflow_id, run_id, and activity_id can be provided, or a - singular task_token can be provided. Args: - workflow_id: Workflow ID for the activity. Cannot be set if - task_token is set. - run_id: Run ID for the activity. Cannot be set if task_token is set. - activity_id: ID for the activity. Cannot be set if task_token is - set. - task_token: Task token for the activity. Cannot be set if any of the - id parameters are set. + workflow_id: Workflow ID for the activity. + run_id: Run ID for the activity. Cannot be + set if task_token is set. 
+ activity_id: ID for the activity. + task_token: Task token for the activity. Returns: - A handle that can be used for completion or heartbeat. + A handle that can be used for completion or heartbeating. """ if task_token is not None: if workflow_id is not None or run_id is not None or activity_id is not None: @@ -1316,7 +2585,18 @@ def get_async_activity_handle( workflow_id=workflow_id, run_id=run_id, activity_id=activity_id ), ) - raise ValueError("Task token or workflow/run/activity ID must be present") + elif activity_id is not None: + return AsyncActivityHandle( + self, + AsyncActivityIDReference( + activity_id=activity_id, + run_id=run_id, + workflow_id=None, + ), + ) + raise ValueError( + "Require task token, or workflow_id & run_id & activity_id, or activity_id & run_id" + ) async def create_schedule( self, @@ -1592,7 +2872,7 @@ def _data_converter(self) -> temporalio.converter.DataConverter: @property def id(self) -> str: - """ID for the workflow.""" + """ID of the workflow.""" return self._id @property @@ -2647,59 +3927,515 @@ def __init__( ) -> None: """Create a WithStartWorkflowOperation. - See :py:meth:`temporalio.client.Client.start_workflow` for documentation of the - arguments. - """ - temporalio.common._warn_on_deprecated_search_attributes( - search_attributes, stack_level=stack_level - ) - name, result_type_from_run_fn = ( - temporalio.workflow._Definition.get_name_and_result_type(workflow) - ) - if id_conflict_policy == temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED: - raise ValueError("WorkflowIDConflictPolicy is required") + See :py:meth:`temporalio.client.Client.start_workflow` for documentation of the + arguments. 
+ """ + temporalio.common._warn_on_deprecated_search_attributes( + search_attributes, stack_level=stack_level + ) + name, result_type_from_run_fn = ( + temporalio.workflow._Definition.get_name_and_result_type(workflow) + ) + if id_conflict_policy == temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED: + raise ValueError("WorkflowIDConflictPolicy is required") + + self._start_workflow_input = UpdateWithStartStartWorkflowInput( + workflow=name, + args=temporalio.common._arg_or_args(arg, args), + id=id, + task_queue=task_queue, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + headers={}, + ret_type=result_type or result_type_from_run_fn, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + priority=priority, + versioning_override=versioning_override, + ) + self._workflow_handle: Future[WorkflowHandle[SelfType, ReturnType]] = Future() + self._used = False + + async def workflow_handle(self) -> WorkflowHandle[SelfType, ReturnType]: + """Wait until workflow is running and return a WorkflowHandle.""" + return await self._workflow_handle + + +class ActivityExecutionAsyncIterator: + """Asynchronous iterator for activity execution values. + + You should typically use ``async for`` on this iterator and not call any of its methods. + + .. warning:: + This API is experimental. + """ + + def __init__( + self, + client: Client, + input: ListActivitiesInput, + ) -> None: + """Create an asynchronous iterator for the given input. + + Users should not create this directly, but rather use + :py:meth:`Client.list_activities`. 
+ """ + self._client = client + self._input = input + self._next_page_token = input.next_page_token + self._current_page: Sequence[ActivityExecution] | None = None + self._current_page_index = 0 + self._limit = input.limit + self._yielded = 0 + + @property + def current_page_index(self) -> int: + """Index of the entry in the current page that will be returned from + the next :py:meth:`__anext__` call. + """ + return self._current_page_index + + @property + def current_page(self) -> Sequence[ActivityExecution] | None: + """Current page, if it has been fetched yet.""" + return self._current_page + + @property + def next_page_token(self) -> bytes | None: + """Token for the next page request if any.""" + return self._next_page_token + + async def fetch_next_page(self, *, page_size: int | None = None) -> None: + """Fetch the next page of results. + + Args: + page_size: Override the page size this iterator was originally + created with. + """ + page_size = page_size or self._input.page_size + if self._limit is not None and self._limit - self._yielded < page_size: + page_size = self._limit - self._yielded + + resp = await self._client.workflow_service.list_activity_executions( + temporalio.api.workflowservice.v1.ListActivityExecutionsRequest( + namespace=self._client.namespace, + page_size=page_size, + next_page_token=self._next_page_token or b"", + query=self._input.query or "", + ), + retry=True, + metadata=self._input.rpc_metadata, + timeout=self._input.rpc_timeout, + ) + + self._current_page = [ + ActivityExecution._from_raw_info( + v, self._client.namespace, self._client.data_converter + ) + for v in resp.executions + ] + self._current_page_index = 0 + self._next_page_token = resp.next_page_token or None + + def __aiter__(self) -> ActivityExecutionAsyncIterator: + """Return self as the iterator.""" + return self + + async def __anext__(self) -> ActivityExecution: + """Get the next execution on this iterator, fetching next page if + necessary. 
+ """ + if self._limit is not None and self._yielded >= self._limit: + raise StopAsyncIteration + while True: + # No page? fetch and continue + if self._current_page is None: + await self.fetch_next_page() + continue + # No more left in page? + if self._current_page_index >= len(self._current_page): + # If there is a next page token, try to get another page and try + # again + if self._next_page_token is not None: + await self.fetch_next_page() + continue + # No more pages means we're done + raise StopAsyncIteration + # Get current, increment page index, and return + ret = self._current_page[self._current_page_index] + self._current_page_index += 1 + self._yielded += 1 + return ret + + +@dataclass(frozen=True) +class ActivityExecution: + """Info for an activity execution not started by a workflow, from list response. + + .. warning:: + This API is experimental. + """ + + activity_id: str + """Activity ID.""" + + activity_run_id: str | None + """Run ID of the activity.""" + + activity_type: str + """Type name of the activity.""" + + close_time: datetime | None + """Time the activity reached a terminal status, if closed.""" + + execution_duration: timedelta | None + """Duration from scheduled to close time, only populated if closed.""" + + namespace: str + """Namespace of the activity (copied from calling client).""" + + raw_info: ( + temporalio.api.activity.v1.ActivityExecutionListInfo + | temporalio.api.activity.v1.ActivityExecutionInfo + ) + """Underlying protobuf info.""" + + scheduled_time: datetime + """Time the activity was originally scheduled.""" + + search_attributes: temporalio.common.SearchAttributes + """Search attributes from the start request.""" + + state_transition_count: int | None + """Number of state transitions, if available.""" + + status: temporalio.common.ActivityExecutionStatus + """Current status of the activity.""" + + task_queue: str + """Task queue the activity was scheduled on.""" + + @classmethod + def _from_raw_info( + cls, + info: 
temporalio.api.activity.v1.ActivityExecutionListInfo, + namespace: str, + _converter: temporalio.converter.DataConverter, + ) -> Self: + """Create from raw proto activity list info.""" + return cls( + activity_id=info.activity_id, + activity_run_id=info.run_id or None, + activity_type=( + info.activity_type.name if info.HasField("activity_type") else "" + ), + close_time=( + info.close_time.ToDatetime().replace(tzinfo=timezone.utc) + if info.HasField("close_time") + else None + ), + execution_duration=( + info.execution_duration.ToTimedelta() + if info.HasField("execution_duration") + else None + ), + namespace=namespace, + raw_info=info, + scheduled_time=( + info.schedule_time.ToDatetime().replace(tzinfo=timezone.utc) + if info.HasField("schedule_time") + else datetime.min + ), + search_attributes=temporalio.converter.decode_search_attributes( + info.search_attributes + ), + state_transition_count=( + info.state_transition_count if info.state_transition_count else None + ), + status=( + temporalio.common.ActivityExecutionStatus(info.status) + if info.status + else temporalio.common.ActivityExecutionStatus.RUNNING + ), + task_queue=info.task_queue, + ) + + +@dataclass(frozen=True) +class ActivityExecutionCountAggregationGroup: + """A single aggregation group from a count activities call. + + .. warning:: + This API is experimental. + """ + + count: int + """Count for this group.""" + + group_values: Sequence[temporalio.common.SearchAttributeValue] + """Values that define this group.""" + + +@dataclass(frozen=True) +class ActivityExecutionCount: + """Representation of a count from a count activities call. + + .. warning:: + This API is experimental. 
+ """ + + count: int + """Total count matching the filter, if any.""" + + groups: Sequence[ActivityExecutionCountAggregationGroup] + """Aggregation groups if requested.""" + + @staticmethod + def _from_raw( + resp: temporalio.api.workflowservice.v1.CountActivityExecutionsResponse, + ) -> ActivityExecutionCount: + """Create from raw proto response.""" + return ActivityExecutionCount( + count=resp.count, + groups=[ + ActivityExecutionCountAggregationGroup( + count=g.count, + group_values=[ + temporalio.converter._decode_search_attribute_value(v) + for v in g.group_values + ], + ) + for g in resp.groups + ], + ) + + +@dataclass(frozen=True) +class ActivityExecutionOutcome: + """Outcome of a completed activity execution. + + .. warning:: + This API is experimental. + """ + + result: Sequence[Any] | None + """The decoded result if the activity completed successfully.""" + + failure: BaseException | None + """The failure if the activity completed unsuccessfully.""" + + +@dataclass(frozen=True) +class ActivityExecutionDescription(ActivityExecution): + """Detailed information about an activity execution not started by a workflow. + + .. warning:: + This API is experimental. 
+ """ + + attempt: int + """Current attempt number.""" + + canceled_reason: str | None + """Reason for cancellation, if cancel was requested.""" + + current_retry_interval: timedelta | None + """Time until the next retry, if applicable.""" + + eager_execution_requested: bool + """Whether eager execution was requested for this activity.""" + + expiration_time: datetime + """Scheduled time plus schedule_to_close_timeout.""" - self._start_workflow_input = UpdateWithStartStartWorkflowInput( - workflow=name, - args=temporalio.common._arg_or_args(arg, args), - id=id, - task_queue=task_queue, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - headers={}, - ret_type=result_type or result_type_from_run_fn, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - priority=priority, - versioning_override=versioning_override, - ) - self._workflow_handle: Future[WorkflowHandle[SelfType, ReturnType]] = Future() - self._used = False + heartbeat_details: Sequence[Any] + """Details from the last heartbeat.""" - async def workflow_handle(self) -> WorkflowHandle[SelfType, ReturnType]: - """Wait until workflow is running and return a WorkflowHandle.""" - return await self._workflow_handle + input: Sequence[Any] | None + """Serialized activity input. 
None if include_input was False.""" + + last_attempt_complete_time: datetime | None + """Time when the last attempt completed.""" + + last_failure: Exception | None + """Failure from the last failed attempt, if any.""" + + last_heartbeat_time: datetime | None + """Time of the last heartbeat.""" + + last_started_time: datetime | None + """Time the last attempt was started.""" + + last_worker_identity: str + """Identity of the last worker that processed the activity.""" + + next_attempt_schedule_time: datetime | None + """Time when the next attempt will be scheduled.""" + + paused: bool + """Whether the activity is paused.""" + + retry_policy: temporalio.common.RetryPolicy | None + """Retry policy for the activity.""" + + run_state: temporalio.common.PendingActivityState | None + """More detailed breakdown if status is RUNNING.""" + + outcome: ActivityExecutionOutcome | None + """Outcome of the activity if completed and include_outcome was True.""" + + long_poll_token: bytes | None + """Token for follow-on long-poll requests. 
None if the activity is complete.""" + + @classmethod + async def _from_execution_info( + cls, + info: temporalio.api.activity.v1.ActivityExecutionInfo, + input: temporalio.api.common.v1.Payloads | None, + outcome: temporalio.api.activity.v1.ActivityExecutionOutcome | None, + long_poll_token: bytes | None, + namespace: str, + data_converter: temporalio.converter.DataConverter, + ) -> Self: + """Create from raw proto activity execution info.""" + # Decode outcome if present + decoded_outcome: ActivityExecutionOutcome | None = None + if outcome is not None: + if outcome.HasField("result"): + decoded_outcome = ActivityExecutionOutcome( + result=await data_converter.decode(outcome.result.payloads), + failure=None, + ) + elif outcome.HasField("failure"): + decoded_outcome = ActivityExecutionOutcome( + result=None, + failure=await data_converter.decode_failure(outcome.failure), + ) + + return cls( + activity_id=info.activity_id, + activity_run_id=info.run_id or None, + activity_type=( + info.activity_type.name if info.HasField("activity_type") else "" + ), + attempt=info.attempt, + canceled_reason=info.canceled_reason or None, + close_time=( + info.close_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("close_time") + else None + ), + current_retry_interval=( + info.current_retry_interval.ToTimedelta() + if info.HasField("current_retry_interval") + else None + ), + eager_execution_requested=getattr(info, "eager_execution_requested", False), + execution_duration=( + info.execution_duration.ToTimedelta() + if info.HasField("execution_duration") + else None + ), + expiration_time=( + info.expiration_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("expiration_time") + else datetime.min + ), + heartbeat_details=( + await data_converter.decode(info.heartbeat_details.payloads) + if info.HasField("heartbeat_details") + else [] + ), + input=( + await data_converter.decode(input.payloads) + if input is not None + else None + ), + last_attempt_complete_time=( + 
info.last_attempt_complete_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("last_attempt_complete_time") + else None + ), + last_failure=( + cast( + Exception | None, + await data_converter.decode_failure(info.last_failure), + ) + if info.HasField("last_failure") + else None + ), + last_heartbeat_time=( + info.last_heartbeat_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("last_heartbeat_time") + else None + ), + last_started_time=( + info.last_started_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("last_started_time") + else None + ), + last_worker_identity=info.last_worker_identity, + long_poll_token=long_poll_token or None, + namespace=namespace, + next_attempt_schedule_time=( + info.next_attempt_schedule_time.ToDatetime(tzinfo=timezone.utc) + if info.HasField("next_attempt_schedule_time") + else None + ), + outcome=decoded_outcome, + paused=getattr(info, "paused", False), + raw_info=info, + retry_policy=temporalio.common.RetryPolicy.from_proto(info.retry_policy) + if info.HasField("retry_policy") + else None, + run_state=( + temporalio.common.PendingActivityState(info.run_state) + if info.run_state + else None + ), + scheduled_time=(info.schedule_time.ToDatetime(tzinfo=timezone.utc)), + search_attributes=temporalio.converter.decode_search_attributes( + info.search_attributes + ), + state_transition_count=( + info.state_transition_count if info.state_transition_count else None + ), + status=( + temporalio.common.ActivityExecutionStatus(info.status) + if info.status + else temporalio.common.ActivityExecutionStatus.RUNNING + ), + task_queue=info.task_queue, + ) @dataclass(frozen=True) -class AsyncActivityIDReference: - """Reference to an async activity by its qualified ID.""" +class ActivityIDReference: + """Information identifying an activity execution. - workflow_id: str + .. warning:: + This API is experimental. 
+ """ + + workflow_id: str | None run_id: str | None activity_id: str +# Deprecated alias +AsyncActivityIDReference = ActivityIDReference + + class AsyncActivityHandle(WithSerializationContext): """Handle representing an external activity for completion and heartbeat.""" @@ -2815,27 +4551,280 @@ async def report_cancellation( ), ) - def with_context(self, context: SerializationContext) -> Self: - """Create a new AsyncActivityHandle with a different serialization context. + def with_context(self, context: SerializationContext) -> Self: + """Create a new AsyncActivityHandle with a different serialization context. + + Payloads received by the activity will be decoded and deserialized using a data converter + with :py:class:`ActivitySerializationContext` set as context. If you are using a custom data + converter that makes use of this context then you can use this method to supply matching + context data to the data converter used to serialize and encode the outbound payloads. + """ + data_converter = self._client.data_converter.with_context(context) + if data_converter is self._client.data_converter: + return self + cls = type(self) + if cls.__init__ is not AsyncActivityHandle.__init__: + raise TypeError( + "If you have subclassed AsyncActivityHandle and overridden the __init__ method " + "then you must override with_context to return an instance of your class." + ) + return cls( + self._client, + self._id_or_token, + data_converter, + ) + + +# TODO: in the future when messages can be sent to activities, we will want the activity handle to +# be generic in the activity type in addition to the return type (as WorkflowHandle), to support +# static type inference for signal/query/update. +class ActivityHandle(Generic[ReturnType]): + """Handle representing an activity execution not started by a workflow. + + .. warning:: + This API is experimental. 
+ """ + + def __init__( + self, + client: Client, + activity_id: str, + *, + activity_run_id: str | None = None, + result_type: type | None = None, + data_converter_override: DataConverter | None = None, + ) -> None: + """Create activity handle.""" + self._client = client + self._activity_id = activity_id + self._activity_run_id = activity_run_id + self._result_type = result_type + self._data_converter_override = data_converter_override + self._known_outcome: ( + temporalio.api.activity.v1.ActivityExecutionOutcome | None + ) = None + self._cached_result: ReturnType | None = None + self._result_fetched: bool = False + + @property + def activity_id(self) -> str: + """ID of the activity.""" + return self._activity_id + + @property + def activity_run_id(self) -> str | None: + """Run ID of the activity.""" + return self._activity_run_id + + def with_context(self, context: SerializationContext) -> Self: + """Create a new ActivityHandle with a different serialization context. + + Payloads received by the activity will be decoded and deserialized using a data converter + with :py:class:`ActivitySerializationContext` set as context. If you are using a custom data + converter that makes use of this context then you can use this method to supply matching + context data to the data converter used to serialize and encode the outbound payloads. + """ + data_converter = self._client.data_converter.with_context(context) + if data_converter is self._client.data_converter: + return self + cls = type(self) + if cls.__init__ is not ActivityHandle.__init__: + raise TypeError( + "If you have subclassed ActivityHandle and overridden the __init__ method " + "then you must override with_context to return an instance of your class." 
+ ) + return cls( + self._client, + activity_id=self._activity_id, + activity_run_id=self._activity_run_id, + result_type=self._result_type, + data_converter_override=data_converter, + ) + + async def result( + self, + *, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ReturnType: + """Wait for result of the activity. + + .. warning:: + This API is experimental. + + The result may already be known if this method has been called before, + in which case no network call is made. Otherwise the result will be + polled for until it is available. + + Args: + rpc_metadata: Headers used on the RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. Note: + this is the timeout for each RPC call while polling, not a + timeout for the function as a whole. If an individual RPC + times out, it will be retried until the result is available. + + Returns: + The result of the activity. + + Raises: + ActivityFailedError: If the activity completed with a failure. + RPCError: Activity result could not be fetched for some reason. 
+ """ + if self._result_fetched: + return cast(ReturnType, self._cached_result) + + result = await self._client._impl.get_activity_result( + GetActivityResultInput[ReturnType]( + activity_id=self._activity_id, + activity_run_id=self._activity_run_id, + result_type=self._result_type, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) + self._cached_result = result + self._result_fetched = True + return result + + async def _poll_until_outcome( + self, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> None: + """Poll for activity result until it's available.""" + if self._known_outcome: + return + + req = temporalio.api.workflowservice.v1.PollActivityExecutionRequest( + namespace=self._client.namespace, + activity_id=self._activity_id, + run_id=self._activity_run_id or "", + ) + + # Continue polling as long as we have no outcome + while True: + try: + res = await self._client.workflow_service.poll_activity_execution( + req, + retry=True, + metadata=rpc_metadata, + timeout=rpc_timeout, + ) + if res.HasField("outcome"): + self._known_outcome = res.outcome + return + except RPCError as err: + if err.status == RPCStatusCode.DEADLINE_EXCEEDED: + # Deadline exceeded is expected with long polling; retry + continue + elif err.status == RPCStatusCode.CANCELLED: + raise asyncio.CancelledError() from err + else: + raise + except asyncio.CancelledError: + raise + + async def cancel( + self, + *, + reason: str | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> None: + """Request cancellation of the activity. + + .. warning:: + This API is experimental. + + Requesting cancellation of an activity does not automatically transition the activity to + canceled status. 
If the activity is heartbeating, a :py:class:`exceptions.CancelledError` + exception will be raised when receiving the heartbeat response; if the activity allows this + exception to bubble out, the activity will transition to canceled status. If the activity + is not heartbeating, this method will have no effect on activity status. + + Args: + reason: Reason for the cancellation. Recorded and available via describe. + rpc_metadata: Headers used on the RPC call. + rpc_timeout: Optional RPC deadline to set for the RPC call. + """ + await self._client._impl.cancel_activity( + CancelActivityInput( + activity_id=self._activity_id, + activity_run_id=self._activity_run_id, + reason=reason, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) + + async def terminate( + self, + *, + reason: str | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> None: + """Terminate the activity execution immediately. + + .. warning:: + This API is experimental. + + Termination does not reach the worker and the activity code cannot react to it. + A terminated activity may have a running attempt and will be requested to be + canceled by the server when it heartbeats. + + Args: + reason: Reason for the termination. + rpc_metadata: Headers used on the RPC call. + rpc_timeout: Optional RPC deadline to set for the RPC call. + """ + await self._client._impl.terminate_activity( + TerminateActivityInput( + activity_id=self._activity_id, + activity_run_id=self._activity_run_id, + reason=reason, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) + + async def describe( + self, + *, + include_input: bool = True, + include_outcome: bool = False, + long_poll_token: bytes | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> ActivityExecutionDescription: + """Describe the activity execution. 
- Payloads received by the activity will be decoded and deserialized using a data converter - with :py:class:`ActivitySerializationContext` set as context. If you are using a custom data - converter that makes use of this context then you can use this method to supply matching - context data to the data converter used to serialize and encode the outbound payloads. + .. warning:: + This API is experimental. + + Args: + include_input: If True, include the activity input in the response. + include_outcome: If True, include the outcome (result/failure) for + completed activities. + long_poll_token: Token from a previous describe response. If provided, + the request will long-poll until the activity state changes. + rpc_metadata: Headers used on the RPC call. + rpc_timeout: Optional RPC deadline to set for the RPC call. + + Returns: + Activity execution description. """ - data_converter = self._client.data_converter.with_context(context) - if data_converter is self._client.data_converter: - return self - cls = type(self) - if cls.__init__ is not AsyncActivityHandle.__init__: - raise TypeError( - "If you have subclassed AsyncActivityHandle and overridden the __init__ method " - "then you must override with_context to return an instance of your class." + return await self._client._impl.describe_activity( + DescribeActivityInput( + activity_id=self._activity_id, + activity_run_id=self._activity_run_id, + include_input=include_input, + include_outcome=include_outcome, + long_poll_token=long_poll_token, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, ) - return cls( - self._client, - self._id_or_token, - data_converter, ) @@ -3011,7 +5000,7 @@ async def memo_value( key: Key to get memo value for. default: Default to use if key is not present. If unset, a :py:class:`KeyError` is raised when the key does not exist. - type_hint: Type hint to use when converting. + type_hint: type hint to use when converting. Returns: Memo value, converted with the type hint if present. 
@@ -4555,7 +6544,7 @@ async def memo_value( key: Key to get memo value for. default: Default to use if key is not present. If unset, a :py:class:`KeyError` is raised when the key does not exist. - type_hint: Type hint to use when converting. + type_hint: type hint to use when converting. Returns: Memo value, converted with the type hint if present. @@ -4804,7 +6793,7 @@ async def memo_value( key: Key to get memo value for. default: Default to use if key is not present. If unset, a :py:class:`KeyError` is raised when the key does not exist. - type_hint: Type hint to use when converting. + type_hint: type hint to use when converting. Returns: Memo value, converted with the type hint if present. @@ -5263,6 +7252,25 @@ def __init__(self) -> None: super().__init__("Timeout or cancellation waiting for update") +class ActivityFailedError(temporalio.exceptions.TemporalError): + """Error that occurs when an activity is unsuccessful. + + .. warning:: + This API is experimental. + """ + + def __init__(self, *, cause: BaseException) -> None: + """Create activity failure error.""" + super().__init__("Activity execution failed") + self.__cause__ = cause + + @property + def cause(self) -> BaseException: + """Cause of the activity failure.""" + assert self.__cause__ + return self.__cause__ + + class AsyncActivityCancelledError(temporalio.exceptions.TemporalError): """Error that occurs when async activity attempted heartbeat but was cancelled.""" @@ -5417,6 +7425,125 @@ class TerminateWorkflowInput: rpc_timeout: timedelta | None +@dataclass +class StartActivityInput: + """Input for :py:meth:`OutboundInterceptor.start_activity`. + + .. warning:: + This API is experimental. 
+ """ + + activity_type: str + args: Sequence[Any] + id: str + task_queue: str + result_type: type | None + schedule_to_close_timeout: timedelta | None + start_to_close_timeout: timedelta | None + schedule_to_start_timeout: timedelta | None + heartbeat_timeout: timedelta | None + id_reuse_policy: temporalio.common.ActivityIDReusePolicy + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy + retry_policy: temporalio.common.RetryPolicy | None + priority: temporalio.common.Priority + search_attributes: temporalio.common.TypedSearchAttributes | None + summary: str | None + headers: Mapping[str, temporalio.api.common.v1.Payload] + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + +@dataclass +class CancelActivityInput: + """Input for :py:meth:`OutboundInterceptor.cancel_activity`. + + .. warning:: + This API is experimental. + """ + + activity_id: str + activity_run_id: str | None + reason: str | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + +@dataclass +class TerminateActivityInput: + """Input for :py:meth:`OutboundInterceptor.terminate_activity`. + + .. warning:: + This API is experimental. + """ + + activity_id: str + activity_run_id: str | None + reason: str | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + +@dataclass +class DescribeActivityInput: + """Input for :py:meth:`OutboundInterceptor.describe_activity`. + + .. warning:: + This API is experimental. + """ + + activity_id: str + activity_run_id: str | None + include_input: bool + include_outcome: bool + long_poll_token: bytes | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + +@dataclass +class GetActivityResultInput(Generic[ReturnType]): + """Input for :py:meth:`OutboundInterceptor.get_activity_result`. + + .. warning:: + This API is experimental. 
+ """ + + activity_id: str + activity_run_id: str | None + result_type: type[ReturnType] | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + +@dataclass +class ListActivitiesInput: + """Input for :py:meth:`OutboundInterceptor.list_activities`. + + .. warning:: + This API is experimental. + """ + + query: str | None + page_size: int + next_page_token: bytes | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + limit: int | None + + +@dataclass +class CountActivitiesInput: + """Input for :py:meth:`OutboundInterceptor.count_activities`. + + .. warning:: + This API is experimental. + """ + + query: str | None + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None + + @dataclass class StartWorkflowUpdateInput: """Input for :py:meth:`OutboundInterceptor.start_workflow_update`.""" @@ -5498,7 +7625,7 @@ class StartWorkflowUpdateWithStartInput: class HeartbeatAsyncActivityInput: """Input for :py:meth:`OutboundInterceptor.heartbeat_async_activity`.""" - id_or_token: AsyncActivityIDReference | bytes + id_or_token: ActivityIDReference | bytes details: Sequence[Any] rpc_metadata: Mapping[str, str | bytes] rpc_timeout: timedelta | None @@ -5509,7 +7636,7 @@ class HeartbeatAsyncActivityInput: class CompleteAsyncActivityInput: """Input for :py:meth:`OutboundInterceptor.complete_async_activity`.""" - id_or_token: AsyncActivityIDReference | bytes + id_or_token: ActivityIDReference | bytes result: Any | None rpc_metadata: Mapping[str, str | bytes] rpc_timeout: timedelta | None @@ -5520,7 +7647,7 @@ class CompleteAsyncActivityInput: class FailAsyncActivityInput: """Input for :py:meth:`OutboundInterceptor.fail_async_activity`.""" - id_or_token: AsyncActivityIDReference | bytes + id_or_token: ActivityIDReference | bytes error: Exception last_heartbeat_details: Sequence[Any] rpc_metadata: Mapping[str, str | bytes] @@ -5532,7 +7659,7 @@ class FailAsyncActivityInput: class ReportCancellationAsyncActivityInput: 
"""Input for :py:meth:`OutboundInterceptor.report_cancellation_async_activity`.""" - id_or_token: AsyncActivityIDReference | bytes + id_or_token: ActivityIDReference | bytes details: Sequence[Any] rpc_metadata: Mapping[str, str | bytes] rpc_timeout: timedelta | None @@ -5751,6 +7878,72 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: """Called for every :py:meth:`WorkflowHandle.terminate` call.""" await self.next.terminate_workflow(input) + ### Activity calls + + async def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]: + """Called for every :py:meth:`Client.start_activity` call. + + .. warning:: + This API is experimental. + """ + return await self.next.start_activity(input) + + async def cancel_activity(self, input: CancelActivityInput) -> None: + """Called for every :py:meth:`ActivityHandle.cancel` call. + + .. warning:: + This API is experimental. + """ + await self.next.cancel_activity(input) + + async def terminate_activity(self, input: TerminateActivityInput) -> None: + """Called for every :py:meth:`ActivityHandle.terminate` call. + + .. warning:: + This API is experimental. + """ + await self.next.terminate_activity(input) + + async def describe_activity( + self, input: DescribeActivityInput + ) -> ActivityExecutionDescription: + """Called for every :py:meth:`ActivityHandle.describe` call. + + .. warning:: + This API is experimental. + """ + return await self.next.describe_activity(input) + + async def get_activity_result( + self, input: GetActivityResultInput[ReturnType] + ) -> ReturnType: + """Called for every :py:meth:`ActivityHandle.result` call. + + .. warning:: + This API is experimental. + """ + return await self.next.get_activity_result(input) + + def list_activities( + self, input: ListActivitiesInput + ) -> ActivityExecutionAsyncIterator: + """Called for every :py:meth:`Client.list_activities` call. + + .. warning:: + This API is experimental. 
+ """ + return self.next.list_activities(input) + + async def count_activities( + self, input: CountActivitiesInput + ) -> ActivityExecutionCount: + """Called for every :py:meth:`Client.count_activities` call. + + .. warning:: + This API is experimental. + """ + return await self.next.count_activities(input) + async def start_workflow_update( self, input: StartWorkflowUpdateInput ) -> WorkflowUpdateHandle[Any]: @@ -6202,6 +8395,230 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout ) + async def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]: + """Start an activity and return a handle to it.""" + if not (input.start_to_close_timeout or input.schedule_to_close_timeout): + raise ValueError( + "Activity must have start_to_close_timeout or schedule_to_close_timeout" + ) + req = await self._build_start_activity_execution_request(input) + + # TODO(dan): any counterpart of WorkflowExecutionAlreadyStartedFailure? 
+ # If RPCError with err.status == RPCStatusCode.ALREADY_EXISTS + + resp = await self._client.workflow_service.start_activity_execution( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + return ActivityHandle( + self._client, + activity_id=input.id, + activity_run_id=resp.run_id, + result_type=input.result_type, + ) + + async def _build_start_activity_execution_request( + self, input: StartActivityInput + ) -> temporalio.api.workflowservice.v1.StartActivityExecutionRequest: + """Build StartActivityExecutionRequest from input.""" + data_converter = self._client.data_converter.with_context( + ActivitySerializationContext( + namespace=self._client.namespace, + activity_id=input.id, + activity_type=input.activity_type, + activity_task_queue=input.task_queue, + is_local=False, + workflow_id=None, + workflow_type=None, + ) + ) + + req = temporalio.api.workflowservice.v1.StartActivityExecutionRequest( + namespace=self._client.namespace, + identity=self._client.identity, + activity_id=input.id, + activity_type=temporalio.api.common.v1.ActivityType( + name=input.activity_type + ), + task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=input.task_queue), + id_reuse_policy=cast( + "temporalio.api.enums.v1.ActivityIdReusePolicy.ValueType", + int(input.id_reuse_policy), + ), + id_conflict_policy=cast( + "temporalio.api.enums.v1.ActivityIdConflictPolicy.ValueType", + int(input.id_conflict_policy), + ), + ) + + if input.schedule_to_close_timeout is not None: + req.schedule_to_close_timeout.FromTimedelta(input.schedule_to_close_timeout) + if input.start_to_close_timeout is not None: + req.start_to_close_timeout.FromTimedelta(input.start_to_close_timeout) + if input.schedule_to_start_timeout is not None: + req.schedule_to_start_timeout.FromTimedelta(input.schedule_to_start_timeout) + if input.heartbeat_timeout is not None: + req.heartbeat_timeout.FromTimedelta(input.heartbeat_timeout) + if input.retry_policy is not None: + 
input.retry_policy.apply_to_proto(req.retry_policy) + + # Set input payloads + if input.args: + req.input.payloads.extend(await data_converter.encode(input.args)) + + # Set search attributes + if input.search_attributes is not None: + temporalio.converter.encode_search_attributes( + input.search_attributes, req.search_attributes + ) + + # Set user metadata + metadata = await _encode_user_metadata(data_converter, input.summary, None) + if metadata is not None: + req.user_metadata.CopyFrom(metadata) + + # Set headers + if input.headers: + await self._apply_headers(input.headers, req.header.fields) + + # Set priority + req.priority.CopyFrom(input.priority._to_proto()) + + return req + + async def cancel_activity(self, input: CancelActivityInput) -> None: + """Cancel an activity.""" + await self._client.workflow_service.request_cancel_activity_execution( + temporalio.api.workflowservice.v1.RequestCancelActivityExecutionRequest( + namespace=self._client.namespace, + activity_id=input.activity_id, + run_id=input.activity_run_id or "", + identity=self._client.identity, + request_id=str(uuid.uuid4()), + reason=input.reason or "", + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + + async def terminate_activity(self, input: TerminateActivityInput) -> None: + """Terminate an activity.""" + await self._client.workflow_service.terminate_activity_execution( + temporalio.api.workflowservice.v1.TerminateActivityExecutionRequest( + namespace=self._client.namespace, + activity_id=input.activity_id, + run_id=input.activity_run_id or "", + reason=input.reason or "", + identity=self._client.identity, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + + async def describe_activity( + self, input: DescribeActivityInput + ) -> ActivityExecutionDescription: + """Describe an activity.""" + resp = await self._client.workflow_service.describe_activity_execution( + 
temporalio.api.workflowservice.v1.DescribeActivityExecutionRequest( + namespace=self._client.namespace, + activity_id=input.activity_id, + run_id=input.activity_run_id or "", + include_input=input.include_input, + include_outcome=input.include_outcome, + long_poll_token=input.long_poll_token or b"", + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + return await ActivityExecutionDescription._from_execution_info( + info=resp.info, + input=resp.input if resp.HasField("input") else None, + outcome=resp.outcome if resp.HasField("outcome") else None, + long_poll_token=resp.long_poll_token or None, + namespace=self._client.namespace, + data_converter=self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.activity_id, # Using activity_id as workflow_id for activities not started by a workflow + ) + ), + ) + + async def get_activity_result( + self, input: GetActivityResultInput[ReturnType] + ) -> ReturnType: + """Get the result of an activity.""" + req = temporalio.api.workflowservice.v1.PollActivityExecutionRequest( + namespace=self._client.namespace, + activity_id=input.activity_id, + run_id=input.activity_run_id or "", + ) + + # Poll until we have an outcome + outcome: temporalio.api.activity.v1.ActivityExecutionOutcome | None = None + while outcome is None: + try: + res = await self._client.workflow_service.poll_activity_execution( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + if res.HasField("outcome"): + outcome = res.outcome + except RPCError as err: + if err.status == RPCStatusCode.DEADLINE_EXCEEDED: + # Deadline exceeded is expected with long polling; retry + continue + elif err.status == RPCStatusCode.CANCELLED: + raise asyncio.CancelledError() from err + else: + raise + except asyncio.CancelledError: + raise + + # Decode the outcome + data_converter = self._client.data_converter + if outcome.HasField("failure"): + 
raise ActivityFailedError( + cause=await data_converter.decode_failure(outcome.failure), + ) + + # Decode result + type_hints: list[type] | None = ( + [input.result_type] if input.result_type else None + ) + results = await data_converter.decode(outcome.result.payloads, type_hints) + if not results: + return cast(ReturnType, None) + return cast(ReturnType, results[0]) + + def list_activities( + self, input: ListActivitiesInput + ) -> ActivityExecutionAsyncIterator: + return ActivityExecutionAsyncIterator(self._client, input) + + async def count_activities( + self, input: CountActivitiesInput + ) -> ActivityExecutionCount: + return ActivityExecutionCount._from_raw( + await self._client.workflow_service.count_activity_executions( + temporalio.api.workflowservice.v1.CountActivityExecutionsRequest( + namespace=self._client.namespace, + query=input.query or "", + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + ) + async def start_workflow_update( self, input: StartWorkflowUpdateInput ) -> WorkflowUpdateHandle[Any]: @@ -6445,7 +8862,7 @@ async def heartbeat_async_activity( if isinstance(input.id_or_token, AsyncActivityIDReference): resp_by_id = await self._client.workflow_service.record_activity_task_heartbeat_by_id( temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatByIdRequest( - workflow_id=input.id_or_token.workflow_id, + workflow_id=input.id_or_token.workflow_id or "", run_id=input.id_or_token.run_id or "", activity_id=input.id_or_token.activity_id, namespace=self._client.namespace, @@ -6500,7 +8917,7 @@ async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> No if isinstance(input.id_or_token, AsyncActivityIDReference): await self._client.workflow_service.respond_activity_task_completed_by_id( temporalio.api.workflowservice.v1.RespondActivityTaskCompletedByIdRequest( - workflow_id=input.id_or_token.workflow_id, + workflow_id=input.id_or_token.workflow_id or "", run_id=input.id_or_token.run_id or "", 
activity_id=input.id_or_token.activity_id, namespace=self._client.namespace, @@ -6537,7 +8954,7 @@ async def fail_async_activity(self, input: FailAsyncActivityInput) -> None: if isinstance(input.id_or_token, AsyncActivityIDReference): await self._client.workflow_service.respond_activity_task_failed_by_id( temporalio.api.workflowservice.v1.RespondActivityTaskFailedByIdRequest( - workflow_id=input.id_or_token.workflow_id, + workflow_id=input.id_or_token.workflow_id or "", run_id=input.id_or_token.run_id or "", activity_id=input.id_or_token.activity_id, namespace=self._client.namespace, @@ -6575,7 +8992,7 @@ async def report_cancellation_async_activity( if isinstance(input.id_or_token, AsyncActivityIDReference): await self._client.workflow_service.respond_activity_task_canceled_by_id( temporalio.api.workflowservice.v1.RespondActivityTaskCanceledByIdRequest( - workflow_id=input.id_or_token.workflow_id, + workflow_id=input.id_or_token.workflow_id or "", run_id=input.id_or_token.run_id or "", activity_id=input.id_or_token.activity_id, namespace=self._client.namespace, diff --git a/temporalio/common.py b/temporalio/common.py index b6dd67a4e..05a428ac4 100644 --- a/temporalio/common.py +++ b/temporalio/common.py @@ -146,6 +146,110 @@ class WorkflowIDConflictPolicy(IntEnum): ) +class ActivityIDReusePolicy(IntEnum): + """How already-closed activity IDs are handled on start. + + .. warning:: + This API is experimental. + + See :py:class:`temporalio.api.enums.v1.ActivityIdReusePolicy`. 
+ """ + + UNSPECIFIED = int( + temporalio.api.enums.v1.ActivityIdReusePolicy.ACTIVITY_ID_REUSE_POLICY_UNSPECIFIED + ) + ALLOW_DUPLICATE = int( + temporalio.api.enums.v1.ActivityIdReusePolicy.ACTIVITY_ID_REUSE_POLICY_ALLOW_DUPLICATE + ) + ALLOW_DUPLICATE_FAILED_ONLY = int( + temporalio.api.enums.v1.ActivityIdReusePolicy.ACTIVITY_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY + ) + REJECT_DUPLICATE = int( + temporalio.api.enums.v1.ActivityIdReusePolicy.ACTIVITY_ID_REUSE_POLICY_REJECT_DUPLICATE + ) + + +class ActivityIDConflictPolicy(IntEnum): + """How already-running activity IDs are handled on start. + + .. warning:: + This API is experimental. + + See :py:class:`temporalio.api.enums.v1.ActivityIdConflictPolicy`. + """ + + UNSPECIFIED = int( + temporalio.api.enums.v1.ActivityIdConflictPolicy.ACTIVITY_ID_CONFLICT_POLICY_UNSPECIFIED + ) + FAIL = int( + temporalio.api.enums.v1.ActivityIdConflictPolicy.ACTIVITY_ID_CONFLICT_POLICY_FAIL + ) + USE_EXISTING = int( + temporalio.api.enums.v1.ActivityIdConflictPolicy.ACTIVITY_ID_CONFLICT_POLICY_USE_EXISTING + ) + + +class ActivityExecutionStatus(IntEnum): + """Status of an activity execution. + + .. warning:: + This API is experimental. + + See :py:class:`temporalio.api.enums.v1.ActivityExecutionStatus`. 
+ """ + + UNSPECIFIED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_UNSPECIFIED + ) + RUNNING = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_RUNNING + ) + COMPLETED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_COMPLETED + ) + FAILED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_FAILED + ) + CANCELED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_CANCELED + ) + TERMINATED = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_TERMINATED + ) + TIMED_OUT = int( + temporalio.api.enums.v1.ActivityExecutionStatus.ACTIVITY_EXECUTION_STATUS_TIMED_OUT + ) + + +class PendingActivityState(IntEnum): + """Detailed state of an activity execution that is in ACTIVITY_EXECUTION_STATUS_RUNNING. + + .. warning:: + This API is experimental. + + See :py:class:`temporalio.api.enums.v1.PendingActivityState`. + """ + + UNSPECIFIED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_UNSPECIFIED + ) + SCHEDULED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_SCHEDULED + ) + STARTED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_STARTED + ) + CANCEL_REQUESTED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_CANCEL_REQUESTED + ) + PAUSED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_PAUSED + ) + PAUSE_REQUESTED = int( + temporalio.api.enums.v1.PendingActivityState.PENDING_ACTIVITY_STATE_PAUSE_REQUESTED + ) + + class QueryRejectCondition(IntEnum): """Whether a query should be rejected in certain conditions. 
diff --git a/temporalio/contrib/openai_agents/_mcp.py b/temporalio/contrib/openai_agents/_mcp.py index c9d1f87ea..8d6a9464a 100644 --- a/temporalio/contrib/openai_agents/_mcp.py +++ b/temporalio/contrib/openai_agents/_mcp.py @@ -445,7 +445,7 @@ def name(self) -> str: def _get_activities(self) -> Sequence[Callable]: def _server_id(): - return self.name + "@" + activity.info().workflow_run_id + return self.name + "@" + (activity.info().workflow_run_id or "") @activity.defn(name=self.name + "-list-tools") async def list_tools() -> list[MCPTool]: @@ -491,7 +491,7 @@ async def connect( ) -> None: heartbeat_task = asyncio.create_task(heartbeat_every(30)) - server_id = self.name + "@" + activity.info().workflow_run_id + server_id = self.name + "@" + (activity.info().workflow_run_id or "") if server_id in self._servers: raise ApplicationError( "Cannot connect to an already running server. Use a distinct name if running multiple servers in one workflow." diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index ef1e52bb2..3e08ea68d 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -355,14 +355,15 @@ async def execute_activity( self, input: temporalio.worker.ExecuteActivityInput ) -> Any: info = temporalio.activity.info() + attributes: dict[str, str] = {"temporalActivityID": info.activity_id} + if info.workflow_id: + attributes["temporalWorkflowID"] = info.workflow_id + if info.workflow_run_id: + attributes["temporalRunID"] = info.workflow_run_id with self.root._start_as_current_span( f"RunActivity:{info.activity_type}", context=self.root._context_from_headers(input.headers), - attributes={ - "temporalWorkflowID": info.workflow_id, - "temporalRunID": info.workflow_run_id, - "temporalActivityID": info.activity_id, - }, + attributes=attributes, kind=opentelemetry.trace.SpanKind.SERVER, ): return await super().execute_activity(input) diff --git a/temporalio/converter.py b/temporalio/converter.py 
index 3849a47f4..f91a54f2c 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -82,15 +82,7 @@ class SerializationContext(ABC): @dataclass(frozen=True) -class BaseWorkflowSerializationContext(SerializationContext): - """Base serialization context shared by workflow and activity serialization contexts.""" - - namespace: str - workflow_id: str - - -@dataclass(frozen=True) -class WorkflowSerializationContext(BaseWorkflowSerializationContext): +class WorkflowSerializationContext(SerializationContext): """Serialization context for workflows. See :py:class:`SerializationContext` for more details. @@ -103,30 +95,51 @@ class WorkflowSerializationContext(BaseWorkflowSerializationContext): when the workflow is created by the schedule. """ - pass + namespace: str + """Namespace.""" + + workflow_id: str | None + """Workflow ID.""" @dataclass(frozen=True) -class ActivitySerializationContext(BaseWorkflowSerializationContext): +class ActivitySerializationContext(SerializationContext): """Serialization context for activities. See :py:class:`SerializationContext` for more details. Attributes: namespace: Workflow/activity namespace. - workflow_id: Workflow ID. Note, when creating/describing schedules, + activity_id: Activity ID. Optional if this is an activity started from a workflow. + activity_type: Activity type. + activity_task_queue: Activity task queue. + workflow_id: Workflow ID. Only set if this is an activity started from a workflow. Note, when creating/describing schedules, this may be the workflow ID prefix as configured, not the final workflow ID when the workflow is created by the schedule. - workflow_type: Workflow Type. - activity_type: Activity Type. - activity_task_queue: Activity task queue. - is_local: Whether the activity is a local activity. + workflow_type: Workflow Type. Only set if this is an activity started from a workflow. + is_local: Whether the activity is a local activity. False if the activity was not started by a workflow. 
""" - workflow_type: str + namespace: str + """Namespace.""" + + activity_id: str | None + """Activity ID. Optional if this is an activity started from a workflow.""" + activity_type: str + """Activity type.""" + activity_task_queue: str + """Activity task queue.""" + + workflow_id: str | None + """Workflow ID if this is an activity started from a workflow.""" + + workflow_type: str | None + """Workflow type if this is an activity started from a workflow.""" + is_local: bool + """Whether the activity is a local activity started from a workflow.""" class WithSerializationContext(ABC): diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index f8f8ca20c..e6f2dc49c 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -247,6 +247,10 @@ class RetryState(IntEnum): ) +# TODO: This error class has required history event fields. I propose we retain it as +# workflow-specific and introduce client.ActivityFailedError for an error in an activity not +# started by a workflow. We could deprecate this name and introduce WorkflowActivityError as a +# preferred-going-forwards alias. class ActivityError(FailureError): """Error raised on activity failure.""" @@ -362,6 +366,8 @@ def retry_state(self) -> RetryState | None: return self._retry_state +# TODO: This error class has required history event fields. Would we retain it as workflow-specific +# and introduce client.NexusOperationFailureError? See related note on ActivityError above. 
class NexusOperationError(FailureError): """Error raised on Nexus operation failure inside a Workflow.""" diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 0098a91e1..ae7d2a38b 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -31,6 +31,7 @@ heartbeat_details=[], heartbeat_timeout=None, is_local=False, + namespace="default", schedule_to_close_timeout=timedelta(seconds=1), scheduled_time=_utc_zero, start_to_close_timeout=timedelta(seconds=1), @@ -43,6 +44,7 @@ workflow_type="test", priority=temporalio.common.Priority.default, retry_policy=None, + activity_run_id=None, ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 93249fad5..6368999ec 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -250,10 +250,11 @@ async def _heartbeat_async( data_converter = self._data_converter if activity.info: context = temporalio.converter.ActivitySerializationContext( - namespace=activity.info.workflow_namespace, + namespace=activity.info.namespace, workflow_id=activity.info.workflow_id, workflow_type=activity.info.workflow_type, activity_type=activity.info.activity_type, + activity_id=activity.info.activity_id, activity_task_queue=self._task_queue, is_local=activity.info.is_local, ) @@ -302,10 +303,11 @@ async def _handle_start_activity_task( ) # Create serialization context for the activity context = temporalio.converter.ActivitySerializationContext( - namespace=start.workflow_namespace, + namespace=start.workflow_namespace or self._client.namespace, workflow_id=start.workflow_execution.workflow_id, workflow_type=start.workflow_type, activity_type=start.activity_type, + activity_id=start.activity_id, activity_task_queue=self._task_queue, is_local=start.is_local, ) @@ -545,6 +547,7 @@ async def _execute_activity( ) from err # Build info + started_by_workflow = bool(start.workflow_execution.workflow_id) info = temporalio.activity.Info( 
activity_id=start.activity_id, activity_type=start.activity_type, @@ -557,6 +560,7 @@ async def _execute_activity( if start.HasField("heartbeat_timeout") else None, is_local=start.is_local, + namespace=start.workflow_namespace or self._client.namespace, schedule_to_close_timeout=_proto_to_non_zero_timedelta( start.schedule_to_close_timeout ) @@ -571,14 +575,17 @@ async def _execute_activity( started_time=_proto_to_datetime(start.started_time), task_queue=self._task_queue, task_token=task_token, - workflow_id=start.workflow_execution.workflow_id, - workflow_namespace=start.workflow_namespace, - workflow_run_id=start.workflow_execution.run_id, - workflow_type=start.workflow_type, + workflow_id=start.workflow_execution.workflow_id or None, + workflow_namespace=start.workflow_namespace or None, + workflow_run_id=start.workflow_execution.run_id or None, + workflow_type=start.workflow_type or None, priority=temporalio.common.Priority._from_proto(start.priority), retry_policy=temporalio.common.RetryPolicy.from_proto(start.retry_policy) if start.HasField("retry_policy") else None, + activity_run_id=getattr(start, "run_id", None) + if not started_by_workflow + else None, ) if self._encode_headers and data_converter.payload_codec is not None: diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 10fd594fd..a8785d6fb 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -790,6 +790,7 @@ def _apply_resolve_activity( workflow_id=self._info.workflow_id, workflow_type=self._info.workflow_type, activity_type=handle._input.activity, + activity_id=handle._input.activity_id, activity_task_queue=( handle._input.task_queue or self._info.task_queue if isinstance(handle._input, StartActivityInput) @@ -2123,6 +2124,7 @@ def get_serialization_context( workflow_id=self._info.workflow_id, workflow_type=self._info.workflow_type, activity_type=activity_handle._input.activity, + 
activity_id=activity_handle._input.activity_id, activity_task_queue=( activity_handle._input.task_queue if isinstance(activity_handle._input, StartActivityInput) @@ -2918,6 +2920,7 @@ def __init__( workflow_id=self._instance._info.workflow_id, workflow_type=self._instance._info.workflow_type, activity_type=self._input.activity, + activity_id=self._input.activity_id, activity_task_queue=( self._input.task_queue or self._instance._info.task_queue if isinstance(self._input, StartActivityInput) diff --git a/tests/test_activity.py b/tests/test_activity.py new file mode 100644 index 000000000..e31216ac8 --- /dev/null +++ b/tests/test_activity.py @@ -0,0 +1,1178 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import timedelta + +import pytest + +from temporalio import activity, workflow +from temporalio.client import ( + ActivityExecutionCount, + ActivityExecutionCountAggregationGroup, + ActivityExecutionDescription, + ActivityFailedError, + ActivityHandle, + CancelActivityInput, + Client, + CountActivitiesInput, + DescribeActivityInput, + GetActivityResultInput, + Interceptor, + ListActivitiesInput, + OutboundInterceptor, + StartActivityInput, + TerminateActivityInput, +) +from temporalio.common import ActivityExecutionStatus, PendingActivityState +from temporalio.exceptions import ApplicationError, CancelledError +from temporalio.service import RPCError, RPCStatusCode +from temporalio.worker import Worker +from tests.helpers import assert_eq_eventually + + +@activity.defn +async def increment(input: int) -> int: + return input + 1 + + +# Activity classes for testing start_activity_class / execute_activity_class +@activity.defn +class IncrementClass: + """Async callable class activity with a parameter.""" + + async def __call__(self, x: int) -> int: + return x + 1 + + +@activity.defn +class NoParamClass: + """Async callable class activity with no parameters.""" + + async def __call__(self) -> str: + return "no-param-result" + + 
+@activity.defn +class SyncIncrementClass: + """Sync callable class activity with a parameter.""" + + def __call__(self, x: int) -> int: + return x + 1 + + +# Activity holder for testing start_activity_method / execute_activity_method +class ActivityHolder: + """Class holding activity methods.""" + + @activity.defn + async def async_increment(self, x: int) -> int: + return x + 1 + + @activity.defn + async def async_no_param(self) -> str: + return "async-method-result" + + @activity.defn + def sync_increment(self, x: int) -> int: + return x + 1 + + +class TestDescribe: + @pytest.fixture + async def activity_handle(self, client: Client): + id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + yield await client.start_activity( + increment, + args=(42,), + id=id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(hours=1), + ) + + async def test_describe(self, client: Client, activity_handle: ActivityHandle): + desc = await activity_handle.describe() + # From ActivityExecution (base class) + assert desc.activity_id == activity_handle.activity_id + assert desc.activity_run_id == activity_handle.activity_run_id + assert desc.activity_type == "increment" + assert desc.close_time is None # not closed yet + assert desc.execution_duration is None # not closed yet + assert desc.namespace == client.namespace + assert desc.raw_info is not None + assert desc.scheduled_time is not None + assert desc.search_attributes == {} + assert desc.state_transition_count is not None + assert desc.status == ActivityExecutionStatus.RUNNING + assert desc.task_queue + # From ActivityExecutionDescription + assert desc.attempt == 1 + assert desc.canceled_reason is None + assert desc.current_retry_interval is None + assert desc.eager_execution_requested is False + assert desc.expiration_time is not None + assert desc.heartbeat_details == [] + assert desc.input == [42] + assert desc.outcome is None + assert desc.run_state == PendingActivityState.SCHEDULED + assert 
desc.last_attempt_complete_time is None + assert desc.last_failure is None + assert desc.last_heartbeat_time is None + assert desc.last_started_time is None + assert desc.last_worker_identity == "" + assert desc.long_poll_token is not None + assert desc.next_attempt_schedule_time is None + assert desc.paused is False + assert desc.retry_policy is not None + + assert (await activity_handle.describe(include_input=False)).input is None + + async def test_describe_include_outcome(self, activity_handle: ActivityHandle): + desc = await activity_handle.describe() + assert desc.outcome is None + assert (await activity_handle.describe(include_outcome=True)).outcome is None + async with Worker( + activity_handle._client, + task_queue=desc.task_queue, + activities=[increment], + ): + await activity_handle.result() + desc = await activity_handle.describe(include_outcome=True) + assert desc.status == ActivityExecutionStatus.COMPLETED + assert desc.run_state is None + assert desc.outcome and desc.outcome.result == [43] + + async def test_describe_long_poll(self, activity_handle: ActivityHandle): + desc1 = await activity_handle.describe() + assert desc1.long_poll_token + desc2_task = asyncio.create_task( + activity_handle.describe(long_poll_token=desc1.long_poll_token) + ) + # Worker poll causes a transition to Started which notifies the waiting long-poll. 
+ async with Worker( + activity_handle._client, + task_queue=desc1.task_queue, + activities=[increment], + ): + desc2 = await desc2_task + assert desc2.state_transition_count and desc1.state_transition_count + assert desc2.state_transition_count > desc1.state_transition_count + + +class ActivityTracingInterceptor(Interceptor): + """Test interceptor that tracks all activity interceptor calls.""" + + def __init__(self) -> None: + super().__init__() + self.start_activity_calls: list[StartActivityInput] = [] + self.get_activity_result_calls: list[GetActivityResultInput] = [] + self.describe_activity_calls: list[DescribeActivityInput] = [] + self.cancel_activity_calls: list[CancelActivityInput] = [] + self.terminate_activity_calls: list[TerminateActivityInput] = [] + self.list_activities_calls: list[ListActivitiesInput] = [] + self.count_activities_calls: list[CountActivitiesInput] = [] + + def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: + return ActivityTracingOutboundInterceptor(self, next) + + +class ActivityTracingOutboundInterceptor(OutboundInterceptor): + def __init__( + self, + parent: ActivityTracingInterceptor, + next: OutboundInterceptor, + ) -> None: + super().__init__(next) + self._parent = parent + + async def start_activity(self, input: StartActivityInput): + assert isinstance(input, StartActivityInput) + self._parent.start_activity_calls.append(input) + return await super().start_activity(input) + + async def get_activity_result(self, input: GetActivityResultInput): + assert isinstance(input, GetActivityResultInput) + self._parent.get_activity_result_calls.append(input) + return await super().get_activity_result(input) + + async def describe_activity(self, input: DescribeActivityInput): + assert isinstance(input, DescribeActivityInput) + self._parent.describe_activity_calls.append(input) + return await super().describe_activity(input) + + async def cancel_activity(self, input: CancelActivityInput): + assert isinstance(input, 
CancelActivityInput) + self._parent.cancel_activity_calls.append(input) + return await super().cancel_activity(input) + + async def terminate_activity(self, input: TerminateActivityInput): + assert isinstance(input, TerminateActivityInput) + self._parent.terminate_activity_calls.append(input) + return await super().terminate_activity(input) + + def list_activities(self, input: ListActivitiesInput): + assert isinstance(input, ListActivitiesInput) + self._parent.list_activities_calls.append(input) + return super().list_activities(input) + + async def count_activities(self, input: CountActivitiesInput): + assert isinstance(input, CountActivitiesInput) + self._parent.count_activities_calls.append(input) + return await super().count_activities(input) + + +async def test_start_activity_calls_interceptor(client: Client): + """Client.start_activity() should call the start_activity interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + assert len(interceptor.start_activity_calls) == 1 + call = interceptor.start_activity_calls[0] + assert call.id == activity_id + assert call.task_queue == task_queue + assert call.activity_type == "increment" + + +async def test_get_activity_result_calls_interceptor(client: Client): + """ActivityHandle.result() should call the get_activity_result interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await intercepted_client.start_activity( + 
increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + intercepted_client, + task_queue=task_queue, + activities=[increment], + ): + result = await activity_handle.result() + assert result == 2 + + assert len(interceptor.get_activity_result_calls) == 1 + call = interceptor.get_activity_result_calls[0] + assert call.activity_id == activity_id + + +async def test_describe_activity_calls_interceptor(client: Client): + """ActivityHandle.describe() should call the describe_activity interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + desc = await activity_handle.describe() + assert isinstance(desc, ActivityExecutionDescription) + + assert len(interceptor.describe_activity_calls) == 1 + call = interceptor.describe_activity_calls[0] + assert call.activity_id == activity_id + + +async def test_cancel_activity_calls_interceptor(client: Client): + """ActivityHandle.cancel() should call the cancel_activity interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + await activity_handle.cancel(reason="test cancellation") + + assert len(interceptor.cancel_activity_calls) == 1 + call = 
interceptor.cancel_activity_calls[0] + assert call.activity_id == activity_id + assert call.reason == "test cancellation" + + +async def test_terminate_activity_calls_interceptor(client: Client): + """ActivityHandle.terminate() should call the terminate_activity interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + await activity_handle.terminate(reason="test termination") + + assert len(interceptor.terminate_activity_calls) == 1 + call = interceptor.terminate_activity_calls[0] + assert call.activity_id == activity_id + assert call.reason == "test termination" + + +async def test_list_activities_calls_interceptor(client: Client): + """Client.list_activities() should call the list_activities interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + query = f'ActivityId = "{activity_id}"' + async for _ in intercepted_client.list_activities(query): + pass + + assert len(interceptor.list_activities_calls) >= 1 + call = interceptor.list_activities_calls[0] + assert call.query == query + + +async def test_count_activities_calls_interceptor(client: Client): + """Client.count_activities() should call the count_activities interceptor.""" + interceptor = ActivityTracingInterceptor() + intercepted_client = Client( + 
service_client=client.service_client, + namespace=client.namespace, + interceptors=[interceptor], + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await intercepted_client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + query = f'ActivityId = "{activity_id}"' + count = await intercepted_client.count_activities(query) + assert isinstance(count, ActivityExecutionCount) + + assert len(interceptor.count_activities_calls) == 1 + call = interceptor.count_activities_calls[0] + assert call.query == query + + +async def test_get_result(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + result_via_execute_activity = client.execute_activity( + increment, + args=(1,), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[increment], + ): + assert await activity_handle.result() == 2 + assert await result_via_execute_activity == 2 + + +async def test_get_activity_handle(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + handle_by_id = client.get_activity_handle(activity_id) + assert handle_by_id.activity_id == activity_id + assert handle_by_id.activity_run_id is None + + handle_by_id_and_run_id = client.get_activity_handle( + activity_id, + activity_run_id=activity_handle.activity_run_id, + ) + assert handle_by_id_and_run_id.activity_id == activity_id + assert handle_by_id_and_run_id.activity_run_id == activity_handle.activity_run_id + + 
handle_with_result_type = client.get_activity_handle( + activity_id, + result_type=int, + activity_run_id=activity_handle.activity_run_id, + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[increment], + ): + assert await handle_by_id.result() == 2 + assert await handle_by_id_and_run_id.result() == 2 + assert await handle_with_result_type.result() == 2 + + +async def test_list_activities(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + executions = [ + e async for e in client.list_activities(f'ActivityId = "{activity_id}"') + ] + assert len(executions) == 1 + execution = executions[0] + assert execution.activity_id == activity_id + assert execution.activity_type == "increment" + assert execution.task_queue == task_queue + assert execution.status == ActivityExecutionStatus.RUNNING + # TODO: not being set by server? 
+ # assert isinstance(execution.state_transition_count, int) + + +async def test_count_activities(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async def fetch_count(): + return await client.count_activities(f'ActivityId = "{activity_id}"') + + await assert_eq_eventually( + ActivityExecutionCount(count=1, groups=[]), + fetch_count, + ) + + +async def test_count_activities_group_by(client: Client): + from temporalio.client import ActivityExecutionCount + + task_queue = str(uuid.uuid4()) + activity_ids = [] + + for _ in range(3): + activity_id = str(uuid.uuid4()) + activity_ids.append(activity_id) + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + ) + + ids_filter = " OR ".join([f'ActivityId = "{aid}"' for aid in activity_ids]) + + async def fetch_count() -> ActivityExecutionCount: + return await client.count_activities(f"({ids_filter}) GROUP BY ExecutionStatus") + + await assert_eq_eventually( + ActivityExecutionCount( + count=3, + groups=[ + ActivityExecutionCountAggregationGroup( + count=3, group_values=["Running"] + ), + ], + ), + fetch_count, + ) + + +@dataclass +class ActivityInput: + event_workflow_id: str + wait_for_activity_start_workflow_id: str | None = None + + +@activity.defn +async def async_activity(input: ActivityInput) -> int: + # Notify test that the activity has started and is ready to be completed manually + await ( + activity.client() + .get_workflow_handle(input.event_workflow_id) + .signal(EventWorkflow.set) + ) + activity.raise_complete_async() + + +async def test_manual_completion(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + 
async_activity, + args=(ActivityInput(event_workflow_id=event_workflow_id),), # TODO: overloads + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[async_activity], + workflows=[EventWorkflow], + ): + # Wait for activity to start + await client.execute_workflow( + EventWorkflow.wait, + id=event_workflow_id, + task_queue=task_queue, + ) + # Complete activity manually + async_activity_handle = client.get_async_activity_handle( + activity_id=activity_id, + run_id=activity_handle.activity_run_id, + ) + await async_activity_handle.complete(7) + assert await activity_handle.result() == 7 + + desc = await activity_handle.describe() + assert desc.status == ActivityExecutionStatus.COMPLETED + + +async def test_manual_cancellation(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + async_activity, + args=(ActivityInput(event_workflow_id=event_workflow_id),), # TODO: overloads + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[async_activity], + workflows=[EventWorkflow], + ): + # Wait for activity to start + await client.execute_workflow( + EventWorkflow.wait, + id=event_workflow_id, + task_queue=task_queue, + ) + async_activity_handle = client.get_async_activity_handle( + activity_id=activity_id, + run_id=activity_handle.activity_run_id, + ) + + # report_cancellation fails if activity is not in CANCELLATION_REQUESTED state + with pytest.raises(RPCError) as err: + await async_activity_handle.report_cancellation("Test cancellation") + assert err.value.status == RPCStatusCode.FAILED_PRECONDITION + assert "invalid transition from Started" in str(err.value) + + # Request cancellation to transition activity to CANCELLATION_REQUESTED 
state + await activity_handle.cancel() + + # Now report_cancellation succeeds + await async_activity_handle.report_cancellation("Test cancellation") + + with pytest.raises(ActivityFailedError) as exc_info: + await activity_handle.result() + assert isinstance(exc_info.value.cause, CancelledError) + assert list(exc_info.value.cause.details) == ["Test cancellation"] + + desc = await activity_handle.describe() + assert desc.status == ActivityExecutionStatus.CANCELED + + +async def test_manual_failure(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + async_activity, + args=(ActivityInput(event_workflow_id=event_workflow_id),), # TODO: overloads + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + async with Worker( + client, + task_queue=task_queue, + activities=[async_activity], + workflows=[EventWorkflow], + ): + await client.execute_workflow( + EventWorkflow.wait, + id=event_workflow_id, + task_queue=task_queue, + ) + async_activity_handle = client.get_async_activity_handle( + activity_id=activity_id, + run_id=activity_handle.activity_run_id, + ) + await async_activity_handle.fail( + ApplicationError("Test failure", non_retryable=True) + ) + with pytest.raises(ActivityFailedError) as err: + await activity_handle.result() + assert isinstance(err.value.cause, ApplicationError) + assert str(err.value.cause) == "Test failure" + + desc = await activity_handle.describe() + assert desc.status == ActivityExecutionStatus.FAILED + + +@activity.defn +async def activity_for_testing_heartbeat(input: ActivityInput) -> str: + info = activity.info() + if info.attempt == 1: + # Signal that activity has started (only on first attempt) + if input.wait_for_activity_start_workflow_id: + await ( + activity.client() + .get_workflow_handle( + workflow_id=input.wait_for_activity_start_workflow_id, + ) + 
.signal(EventWorkflow.set) + ) + wait_for_heartbeat_wf_handle = await activity.client().start_workflow( + EventWorkflow.wait, + id=input.event_workflow_id, + task_queue=activity.info().task_queue, + ) + # Wait for test to notify that it has sent heartbeat + await wait_for_heartbeat_wf_handle.result() + raise Exception("Intentional error to force retry") + elif info.attempt == 2: + [heartbeat_data] = info.heartbeat_details + assert isinstance(heartbeat_data, str) + return heartbeat_data + else: + raise AssertionError(f"Unexpected attempt number: {info.attempt}") + + +async def test_manual_heartbeat(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + wait_for_activity_start_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + activity_for_testing_heartbeat, + args=( + ActivityInput( + event_workflow_id=event_workflow_id, + wait_for_activity_start_workflow_id=wait_for_activity_start_workflow_id, + ), + ), # TODO: overloads + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + wait_for_activity_start_wf_handle = await client.start_workflow( + EventWorkflow.wait, + id=wait_for_activity_start_workflow_id, + task_queue=task_queue, + ) + async with Worker( + client, + task_queue=task_queue, + activities=[activity_for_testing_heartbeat], + workflows=[EventWorkflow], + ): + async_activity_handle = client.get_async_activity_handle( + activity_id=activity_id, + run_id=activity_handle.activity_run_id, + ) + await wait_for_activity_start_wf_handle.result() + await async_activity_handle.heartbeat("Test heartbeat details") + await client.get_workflow_handle( + workflow_id=event_workflow_id, + ).signal(EventWorkflow.set) + assert await activity_handle.result() == "Test heartbeat details" + + +async def test_id_conflict_policy_fail(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + from 
temporalio.common import ActivityIDConflictPolicy + + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + id_conflict_policy=ActivityIDConflictPolicy.FAIL, + ) + + with pytest.raises(RPCError) as err: + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + id_conflict_policy=ActivityIDConflictPolicy.FAIL, + ) + assert err.value.status == RPCStatusCode.ALREADY_EXISTS + + +async def test_id_conflict_policy_use_existing(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + from temporalio.common import ActivityIDConflictPolicy + + handle1 = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + id_conflict_policy=ActivityIDConflictPolicy.USE_EXISTING, + ) + + handle2 = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + id_conflict_policy=ActivityIDConflictPolicy.USE_EXISTING, + ) + + assert handle1.activity_id == handle2.activity_id + assert handle1.activity_run_id == handle2.activity_run_id + + +async def test_id_reuse_policy_reject_duplicate(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + from temporalio.common import ActivityIDReusePolicy + + handle = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + id_reuse_policy=ActivityIDReusePolicy.REJECT_DUPLICATE, + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[increment], + ): + await handle.result() + + with pytest.raises(RPCError) as err: + await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + 
id_reuse_policy=ActivityIDReusePolicy.REJECT_DUPLICATE, + ) + assert err.value.status == RPCStatusCode.ALREADY_EXISTS + + +async def test_id_reuse_policy_allow_duplicate(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + from temporalio.common import ActivityIDReusePolicy + + handle1 = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + id_reuse_policy=ActivityIDReusePolicy.ALLOW_DUPLICATE, + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[increment], + ): + await handle1.result() + + handle2 = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + id_reuse_policy=ActivityIDReusePolicy.ALLOW_DUPLICATE, + ) + + assert handle1.activity_id == handle2.activity_id + assert handle1.activity_run_id != handle2.activity_run_id + + +async def test_search_attributes(client: Client): + from temporalio.common import ( + SearchAttributeKey, + SearchAttributePair, + TypedSearchAttributes, + ) + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + temporal_change_version_key = SearchAttributeKey.for_keyword_list( + "TemporalChangeVersion" + ) + + handle = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + schedule_to_close_timeout=timedelta(seconds=60), + search_attributes=TypedSearchAttributes( + [SearchAttributePair(temporal_change_version_key, ["test-1", "test-2"])] + ), + ) + + desc = await handle.describe() + assert desc.search_attributes is not None + assert desc.search_attributes["TemporalChangeVersion"] == ["test-1", "test-2"] + + +async def test_retry_policy(client: Client): + from temporalio.common import RetryPolicy + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + handle = await client.start_activity( + increment, + 1, + id=activity_id, + task_queue=task_queue, + 
start_to_close_timeout=timedelta(seconds=5), + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=1), + maximum_interval=timedelta(seconds=10), + backoff_coefficient=2.0, + maximum_attempts=3, + ), + ) + + desc = await handle.describe() + assert desc.retry_policy is not None + assert desc.retry_policy.initial_interval == timedelta(seconds=1) + assert desc.retry_policy.maximum_interval == timedelta(seconds=10) + assert desc.retry_policy.backoff_coefficient == 2.0 + assert desc.retry_policy.maximum_attempts == 3 + + +async def test_terminate(client: Client): + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + event_workflow_id = str(uuid.uuid4()) + + activity_handle = await client.start_activity( + async_activity, + args=(ActivityInput(event_workflow_id=event_workflow_id),), + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[async_activity], + workflows=[EventWorkflow], + ): + await client.execute_workflow( + EventWorkflow.wait, + id=event_workflow_id, + task_queue=task_queue, + ) + + await activity_handle.terminate(reason="Test termination") + + with pytest.raises(ActivityFailedError): + await activity_handle.result() + + desc = await activity_handle.describe() + assert desc.status == ActivityExecutionStatus.TERMINATED + + +# Tests for start_activity_class / execute_activity_class + + +async def test_start_activity_class_async(client: Client): + """Test start_activity_class with an async callable class.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + handle = await client.start_activity_class( + IncrementClass, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[IncrementClass()], + ): + result = await handle.result() + assert result == 2 + + +async def 
test_execute_activity_class_async(client: Client): + """Test execute_activity_class with an async callable class.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + async with Worker( + client, + task_queue=task_queue, + activities=[IncrementClass()], + ): + result = await client.execute_activity_class( + IncrementClass, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + assert result == 2 + + +async def test_start_activity_class_no_param(client: Client): + """Test start_activity_class with a no-param callable class.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + handle = await client.start_activity_class( + NoParamClass, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[NoParamClass()], + ): + result = await handle.result() + assert result == "no-param-result" + + +async def test_start_activity_class_sync(client: Client): + """Test start_activity_class with a sync callable class.""" + import concurrent.futures + + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + handle = await client.start_activity_class( + SyncIncrementClass, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + with concurrent.futures.ThreadPoolExecutor() as executor: + async with Worker( + client, + task_queue=task_queue, + activities=[SyncIncrementClass()], + activity_executor=executor, + ): + result = await handle.result() + assert result == 2 + + +# Tests for start_activity_method / execute_activity_method + + +async def test_start_activity_method_async(client: Client): + """Test start_activity_method with an async method.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + holder = ActivityHolder() + handle = await client.start_activity_method( + ActivityHolder.async_increment, + 1, + 
id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[holder.async_increment], + ): + result = await handle.result() + assert result == 2 + + +async def test_execute_activity_method_async(client: Client): + """Test execute_activity_method with an async method.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + holder = ActivityHolder() + async with Worker( + client, + task_queue=task_queue, + activities=[holder.async_increment], + ): + result = await client.execute_activity_method( + ActivityHolder.async_increment, + 1, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + assert result == 2 + + +async def test_start_activity_method_no_param(client: Client): + """Test start_activity_method with a no-param method.""" + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + + holder = ActivityHolder() + handle = await client.start_activity_method( + ActivityHolder.async_no_param, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + + async with Worker( + client, + task_queue=task_queue, + activities=[holder.async_no_param], + ): + result = await handle.result() + assert result == "async-method-result" + + +# Utilities + + +@workflow.defn +class EventWorkflow: + """ + A workflow version of asyncio.Event() + """ + + def __init__(self) -> None: + self.signal_received = asyncio.Event() + + @workflow.run + async def wait(self) -> None: + await self.signal_received.wait() + + @workflow.signal + def set(self) -> None: + self.signal_received.set() diff --git a/tests/test_activity_type_errors.py b/tests/test_activity_type_errors.py new file mode 100644 index 000000000..fadf7d14b --- /dev/null +++ b/tests/test_activity_type_errors.py @@ -0,0 +1,491 @@ +""" +This file exists to test for type-checker false positives and false negatives +for the activity 
client API. + +It doesn't contain any test functions - it uses the machinery in test_type_errors.py +to verify that pyright produces the expected errors. +""" + +from datetime import timedelta +from unittest.mock import Mock + +from temporalio import activity +from temporalio.client import ActivityHandle, Client +from temporalio.service import ServiceClient + + +@activity.defn +async def increment(x: int) -> int: + return x + 1 + + +@activity.defn +async def greet(name: str) -> str: + return f"Hello, {name}" + + +@activity.defn +async def no_return(_: int) -> None: + pass + + +@activity.defn +async def no_param_async() -> str: + return "done" + + +@activity.defn +def increment_sync(x: int) -> int: + return x + 1 + + +@activity.defn +def no_param_sync() -> str: + return "done" + + +@activity.defn +class IncrementClass: + """Async activity defined as a callable class.""" + + async def __call__(self, x: int) -> int: + return x + 1 + + +@activity.defn +class NoParamClass: + """Async activity class with no parameters.""" + + async def __call__(self) -> str: + return "done" + + +@activity.defn +class SyncIncrementClass: + """Sync activity defined as a callable class.""" + + def __call__(self, x: int) -> int: + return x + 1 + + +@activity.defn +class SyncNoParamClass: + """Sync activity class with no parameters.""" + + def __call__(self) -> str: + return "done" + + +class ActivityHolder: + """Class holding activity methods.""" + + @activity.defn + async def increment_method(self, x: int) -> int: + return x + 1 + + @activity.defn + async def no_param_method(self) -> str: + return "done" + + +async def _test_start_activity_typed_callable_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity( + increment, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + _result: int = await _handle.result() + + 
+async def _test_execute_activity_typed_callable_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity( + increment, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_positional_arg_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity( + increment, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_positional_arg_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity( + increment, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_string_name_with_result_type() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle = await client.start_activity( + "increment", + args=[1], + id="activity-id", + task_queue="tq", + result_type=int, + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_no_param_async_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[str] = await client.start_activity( + no_param_async, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_no_param_async_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: str = await client.execute_activity( + no_param_async, + id="activity-id", + task_queue="tq", + 
start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_no_param_sync_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[str] = await client.start_activity( + no_param_sync, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_no_param_sync_happy_path() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: str = await client.execute_activity( + no_param_sync, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_wrong_arg_type() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity( + increment, + # assert-type-error-pyright: 'cannot be assigned to parameter' + "wrong type", # type: ignore + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_wrong_arg_type() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity( + increment, + # assert-type-error-pyright: 'cannot be assigned to parameter' + "wrong type", # type: ignore + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_wrong_result_type_assignment() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + handle = await client.start_activity( + increment, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + # assert-type-error-pyright: 'Type "int" is not assignable to declared type "str"' + _wrong: str = await handle.result() # 
type: ignore + + +async def _test_execute_activity_wrong_result_type_assignment() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # assert-type-error-pyright: 'Type "int" is not assignable to declared type "str"' + _wrong: str = await client.execute_activity( # type: ignore + increment, # type: ignore[arg-type] + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_missing_required_params() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # assert-type-error-pyright: 'No overloads for "start_activity" match' + await client.start_activity( # type: ignore + increment, + args=[1], + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + # assert-type-error-pyright: 'No overloads for "start_activity" match' + await client.start_activity( # type: ignore + increment, + args=[1], + id="activity-id", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_activity_handle_typed_correctly() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + handle_int: ActivityHandle[int] = await client.start_activity( + increment, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + _int_result: int = await handle_int.result() + + handle_str: ActivityHandle[str] = await client.start_activity( + greet, + args=["world"], + id="activity-id-2", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + _str_result: str = await handle_str.result() + + handle_none: ActivityHandle[None] = await client.start_activity( + no_return, + args=[1], + id="activity-id-3", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + _none_result: None = await handle_none.result() # type: ignore[func-returns-value] + + +async def 
_test_activity_handle_wrong_type_parameter() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # assert-type-error-pyright: 'Type "ActivityHandle\[int\]" is not assignable to declared type "ActivityHandle\[str\]"' + _handle: ActivityHandle[str] = await client.start_activity( # type: ignore + increment, # type: ignore[arg-type] + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_sync_activity() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity( + increment_sync, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_sync_activity() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity( + increment_sync, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_sync_no_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[str] = await client.start_activity( + no_param_sync, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +# Tests for start_activity_class and execute_activity_class +# Note: Type inference for callable classes is limited; use args= form + + +async def _test_start_activity_class_single_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity_class( + IncrementClass, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def 
_test_execute_activity_class_single_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity_class( + IncrementClass, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_class_no_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[str] = await client.start_activity_class( + NoParamClass, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_class_no_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: str = await client.execute_activity_class( + NoParamClass, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +# Tests for sync callable classes + + +async def _test_start_activity_class_sync_single_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[int] = await client.start_activity_class( + SyncIncrementClass, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_class_sync_single_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _result: int = await client.execute_activity_class( + SyncIncrementClass, + 1, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_class_sync_no_param() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + _handle: ActivityHandle[str] = await client.start_activity_class( + SyncNoParamClass, + id="activity-id", + task_queue="tq", + 
start_to_close_timeout=timedelta(seconds=5), + ) + + +# Tests for start_activity_method and execute_activity_method +# Note: The _method variants work best with unbound methods (class references). +# For bound methods accessed via instance, use start_activity directly. + + +async def _test_start_activity_method_unbound() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # Using unbound method reference + _handle: ActivityHandle[int] = await client.start_activity_method( + ActivityHolder.increment_method, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_method_unbound() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # Using unbound method reference + _result: int = await client.execute_activity_method( + ActivityHolder.increment_method, + args=[1], + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_start_activity_method_no_param_unbound() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # Using unbound method reference + _handle: ActivityHandle[str] = await client.start_activity_method( + ActivityHolder.no_param_method, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) + + +async def _test_execute_activity_method_no_param_unbound() -> None: # type:ignore[reportUnusedFunction] + client = Client(service_client=Mock(spec=ServiceClient)) + + # Using unbound method reference + _result: str = await client.execute_activity_method( + ActivityHolder.no_param_method, + id="activity-id", + task_queue="tq", + start_to_close_timeout=timedelta(seconds=5), + ) diff --git a/tests/test_serialization_context.py b/tests/test_serialization_context.py index 4e217861b..c346c46bb 100644 --- 
a/tests/test_serialization_context.py +++ b/tests/test_serialization_context.py @@ -180,6 +180,7 @@ async def run(self, data: TraceData) -> TraceData: data, start_to_close_timeout=timedelta(seconds=10), heartbeat_timeout=timedelta(seconds=2), + activity_id="activity-id", ) data = await workflow.execute_child_workflow( EchoWorkflow.run, data, id=f"{workflow.info().workflow_id}_child" @@ -232,6 +233,7 @@ async def test_payload_conversion_calls_follow_expected_sequence_and_contexts( workflow_id=workflow_id, workflow_type=PayloadConversionWorkflow.__name__, activity_type=passthrough_activity.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=False, ) @@ -329,6 +331,7 @@ async def run(self) -> TraceData: initial_interval=timedelta(milliseconds=100), maximum_attempts=2, ), + activity_id="activity-id", ) @@ -371,6 +374,7 @@ async def test_heartbeat_details_payload_conversion(client: Client): workflow_id=workflow_id, workflow_type=HeartbeatDetailsSerializationContextTestWorkflow.__name__, activity_type=activity_with_heartbeat_details.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=False, ) @@ -420,6 +424,7 @@ async def run(self, data: TraceData) -> TraceData: local_activity, data, start_to_close_timeout=timedelta(seconds=10), + activity_id="activity-id", ) @@ -460,6 +465,7 @@ async def test_local_activity_payload_conversion(client: Client): workflow_id=workflow_id, workflow_type=LocalActivityWorkflow.__name__, activity_type=local_activity.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=True, ) @@ -505,7 +511,7 @@ async def test_local_activity_payload_conversion(client: Client): @workflow.defn -class EventWorkflow: +class WaitForSignalWorkflow: # Like a global asyncio.Event() def __init__(self) -> None: @@ -522,10 +528,11 @@ def signal(self) -> None: @activity.defn async def async_activity() -> TraceData: + # Notify test that the activity has started and is ready to be completed 
manually await ( activity.client() .get_workflow_handle("activity-started-wf-id") - .signal(EventWorkflow.signal) + .signal(WaitForSignalWorkflow.signal) ) activity.raise_complete_async() @@ -559,7 +566,7 @@ async def test_async_activity_completion_payload_conversion( task_queue=task_queue, workflows=[ AsyncActivityCompletionSerializationContextTestWorkflow, - EventWorkflow, + WaitForSignalWorkflow, ], activities=[async_activity], workflow_runner=UnsandboxedWorkflowRunner(), # so that we can use isinstance @@ -573,12 +580,13 @@ async def test_async_activity_completion_payload_conversion( workflow_id=workflow_id, workflow_type=AsyncActivityCompletionSerializationContextTestWorkflow.__name__, activity_type=async_activity.__name__, + activity_id="async-activity-id", activity_task_queue=task_queue, is_local=False, ) act_started_wf_handle = await client.start_workflow( - EventWorkflow.run, + WaitForSignalWorkflow.run, id="activity-started-wf-id", task_queue=task_queue, ) @@ -645,6 +653,7 @@ def test_subclassed_async_activity_handle(client: Client): workflow_id="workflow-id", workflow_type="workflow-type", activity_type="activity-type", + activity_id="activity-id", activity_task_queue="activity-task-queue", is_local=False, ) @@ -1059,11 +1068,12 @@ async def run(self) -> Never: failing_activity, start_to_close_timeout=timedelta(seconds=10), retry_policy=RetryPolicy(maximum_attempts=1), + activity_id="activity-id", ) raise Exception("Unreachable") -test_traces: dict[str, list[TraceItem]] = defaultdict(list) +test_traces: dict[str | None, list[TraceItem]] = defaultdict(list) class FailureConverterWithContext(DefaultFailureConverter, WithSerializationContext): @@ -1155,6 +1165,7 @@ async def test_failure_converter_with_context(client: Client): workflow_id=workflow_id, workflow_type=FailureConverterTestWorkflow.__name__, activity_type=failing_activity.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=False, ) @@ -1323,6 +1334,7 @@ async def 
run(self, data: str) -> str: codec_test_local_activity, data, start_to_close_timeout=timedelta(seconds=10), + activity_id="activity-id", ) @@ -1361,6 +1373,7 @@ async def test_local_activity_codec_with_context(client: Client): workflow_id=workflow_id, workflow_type=LocalActivityCodecTestWorkflow.__name__, activity_type=codec_test_local_activity.__name__, + activity_id="activity-id", activity_task_queue=task_queue, is_local=True, ) @@ -1594,6 +1607,7 @@ async def run(self, _data: str) -> str: payload_encryption_activity, "outbound", start_to_close_timeout=timedelta(seconds=10), + activity_id="activity-id", ), workflow.execute_child_workflow( PayloadEncryptionChildWorkflow.run, diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py index e66a42dc0..f811fd1b5 100644 --- a/tests/worker/test_activity.py +++ b/tests/worker/test_activity.py @@ -1259,6 +1259,9 @@ def async_handle(self, client: Client, use_task_token: bool) -> AsyncActivityHan assert self._info if use_task_token: return client.get_async_activity_handle(task_token=self._info.task_token) + assert ( + self._info.workflow_id + ) # These tests are for workflow-triggered activities return client.get_async_activity_handle( workflow_id=self._info.workflow_id, run_id=self._info.workflow_run_id, @@ -1739,8 +1742,8 @@ async def wait_cancel() -> str: req = temporalio.api.workflowservice.v1.ResetActivityRequest( namespace=client.namespace, execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=activity.info().workflow_id, - run_id=activity.info().workflow_run_id, + workflow_id=activity.info().workflow_id or "", + run_id=activity.info().workflow_run_id or "", ), id=activity.info().activity_id, ) @@ -1759,8 +1762,8 @@ def sync_wait_cancel() -> str: req = temporalio.api.workflowservice.v1.ResetActivityRequest( namespace=client.namespace, execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=activity.info().workflow_id, - run_id=activity.info().workflow_run_id, + 
workflow_id=activity.info().workflow_id or "", + run_id=activity.info().workflow_run_id or "", ), id=activity.info().activity_id, ) @@ -1811,8 +1814,8 @@ async def wait_cancel() -> str: req = temporalio.api.workflowservice.v1.ResetActivityRequest( namespace=client.namespace, execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=activity.info().workflow_id, - run_id=activity.info().workflow_run_id, + workflow_id=activity.info().workflow_id or "", + run_id=activity.info().workflow_run_id or "", ), id=activity.info().activity_id, )