diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index f37c76f19..84c3f215c 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -293,21 +293,21 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: async def decode_activation( activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, - codec: temporalio.converter.PayloadCodec, + data_converter: temporalio.converter.DataConverter, decode_headers: bool, ) -> None: """Decode all payloads in the activation.""" await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not decode_headers - ).visit(_Visitor(codec.decode), activation) + ).visit(_Visitor(data_converter._decode_payload_sequence), activation) async def encode_completion( completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, - codec: temporalio.converter.PayloadCodec, + data_converter: temporalio.converter.DataConverter, encode_headers: bool, ) -> None: """Encode all payloads in the completion.""" await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers - ).visit(_Visitor(codec.encode), completion) + ).visit(_Visitor(data_converter._encode_payload_sequence), completion) diff --git a/temporalio/client.py b/temporalio/client.py index b4d5af0fa..849fb953a 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -2977,10 +2977,7 @@ async def memo(self) -> Mapping[str, Any]: Returns: Mapping of all memo keys and they values without type hints. """ - return { - k: (await self.data_converter.decode([v]))[0] - for k, v in self.raw_info.memo.fields.items() - } + return await self.data_converter._convert_from_memo(self.raw_info.memo) @overload async def memo_value( @@ -3019,16 +3016,9 @@ async def memo_value( Raises: KeyError: Key not present and default not set. """ - payload = self.raw_info.memo.fields.get(key) - if not payload: - if default is temporalio.common._arg_unset: - raise KeyError(f"Memo does not have a value for key {key}") - return default - return ( - await self.data_converter.decode( - [payload], [type_hint] if type_hint else None - ) - )[0] + return await self.data_converter._convert_from_memo_field( + self.raw_info.memo, key, default, type_hint + ) @dataclass @@ -4209,18 +4199,9 @@ async def _to_proto( workflow_run_timeout=run_timeout, workflow_task_timeout=task_timeout, retry_policy=retry_policy, - memo=( - temporalio.api.common.v1.Memo( - fields={ - k: v - if isinstance(v, temporalio.api.common.v1.Payload) - else (await data_converter.encode([v]))[0] - for k, v in self.memo.items() - }, - ) - if self.memo - else None - ), + memo=await data_converter._convert_to_memo(self.memo) + if self.memo + else None, user_metadata=await _encode_user_metadata( data_converter, self.static_summary, self.static_details ), @@ -4249,7 +4230,7 @@ async def _to_proto( client.config(active_config=True)["header_codec_behavior"] == HeaderCodecBehavior.CODEC and not self._from_raw, - client.data_converter.payload_codec, + client.data_converter, ) return action @@ -4521,10 +4502,7 @@ async def memo(self) -> Mapping[str, Any]: Returns: Mapping of all memo keys and they values without type hints. 
""" - return { - k: (await self.data_converter.decode([v]))[0] - for k, v in self.raw_description.memo.fields.items() - } + return await self.data_converter._convert_from_memo(self.raw_description.memo) @overload async def memo_value( @@ -4563,16 +4541,9 @@ async def memo_value( Raises: KeyError: Key not present and default not set. """ - payload = self.raw_description.memo.fields.get(key) - if not payload: - if default is temporalio.common._arg_unset: - raise KeyError(f"Memo does not have a value for key {key}") - return default - return ( - await self.data_converter.decode( - [payload], [type_hint] if type_hint else None - ) - )[0] + return await self.data_converter._convert_from_memo_field( + self.raw_description.memo, key, default, type_hint + ) @dataclass @@ -4770,10 +4741,7 @@ async def memo(self) -> Mapping[str, Any]: Returns: Mapping of all memo keys and they values without type hints. """ - return { - k: (await self.data_converter.decode([v]))[0] - for k, v in self.raw_entry.memo.fields.items() - } + return await self.data_converter._convert_from_memo(self.raw_entry.memo) @overload async def memo_value( @@ -4812,16 +4780,9 @@ async def memo_value( Raises: KeyError: Key not present and default not set. """ - payload = self.raw_entry.memo.fields.get(key) - if not payload: - if default is temporalio.common._arg_unset: - raise KeyError(f"Memo does not have a value for key {key}") - return default - return ( - await self.data_converter.decode( - [payload], [type_hint] if type_hint else None - ) - )[0] + return await self.data_converter._convert_from_memo_field( + self.raw_entry.memo, key, default, type_hint + ) @dataclass @@ -6014,8 +5975,7 @@ async def _populate_start_workflow_execution_request( input.retry_policy.apply_to_proto(req.retry_policy) req.cron_schedule = input.cron_schedule if input.memo is not None: - for k, v in input.memo.items(): - req.memo.fields[k].CopyFrom((await data_converter.encode([v]))[0]) + await data_converter._convert_into_memo(input.memo, req.memo) if input.search_attributes is not None: temporalio.converter.encode_search_attributes( input.search_attributes, req.search_attributes @@ -6641,14 +6601,9 @@ async def create_schedule(self, input: CreateScheduleInput) -> ScheduleHandle: initial_patch=initial_patch, identity=self._client.identity, request_id=str(uuid.uuid4()), - memo=None - if not input.memo - else temporalio.api.common.v1.Memo( - fields={ - k: (await self._client.data_converter.encode([v]))[0] - for k, v in input.memo.items() - }, - ), + memo=await self._client.data_converter._convert_to_memo(input.memo) + if input.memo + else None, ) if input.search_attributes: temporalio.converter.encode_search_attributes( @@ -6870,7 +6825,7 @@ async def _apply_headers( dest, self._client.config(active_config=True)["header_codec_behavior"] == HeaderCodecBehavior.CODEC, - self._client.data_converter.payload_codec, + self._client.data_converter, ) @@ -6878,14 +6833,13 @@ async def _apply_headers( source: Mapping[str, temporalio.api.common.v1.Payload] | None, dest: MessageMap[str, temporalio.api.common.v1.Payload], encode_headers: bool, - codec: temporalio.converter.PayloadCodec | None, + data_converter: DataConverter, ) -> None: if source is None: return - if encode_headers and codec is not None: + if encode_headers: for payload in source.values(): - new_payload = (await codec.encode([payload]))[0] - payload.CopyFrom(new_payload) + payload.CopyFrom(await data_converter._encode_payload(payload)) temporalio.common._apply_headers(source, dest) diff --git 
a/temporalio/converter.py b/temporalio/converter.py index 3849a47f4..7de095c81 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -824,45 +824,14 @@ async def encode_failure(self, failure: temporalio.api.failure.v1.Failure) -> No It is not guaranteed that all failures will be encoded with this method rather than encoding the underlying payloads. """ - await self._apply_to_failure_payloads(failure, self.encode_wrapper) + await DataConverter._apply_to_failure_payloads(failure, self.encode_wrapper) async def decode_failure(self, failure: temporalio.api.failure.v1.Failure) -> None: """Decode payloads of a failure. Intended as a helper method, not for overriding. It is not guaranteed that all failures will be decoded with this method rather than decoding the underlying payloads. """ - await self._apply_to_failure_payloads(failure, self.decode_wrapper) - - async def _apply_to_failure_payloads( - self, - failure: temporalio.api.failure.v1.Failure, - cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]], - ) -> None: - if failure.HasField("encoded_attributes"): - # Wrap in payloads and merge back - payloads = temporalio.api.common.v1.Payloads( - payloads=[failure.encoded_attributes] - ) - await cb(payloads) - failure.encoded_attributes.CopyFrom(payloads.payloads[0]) - if failure.HasField( - "application_failure_info" - ) and failure.application_failure_info.HasField("details"): - await cb(failure.application_failure_info.details) - elif failure.HasField( - "timeout_failure_info" - ) and failure.timeout_failure_info.HasField("last_heartbeat_details"): - await cb(failure.timeout_failure_info.last_heartbeat_details) - elif failure.HasField( - "canceled_failure_info" - ) and failure.canceled_failure_info.HasField("details"): - await cb(failure.canceled_failure_info.details) - elif failure.HasField( - "reset_workflow_failure_info" - ) and failure.reset_workflow_failure_info.HasField("last_heartbeat_details"): - await cb(failure.reset_workflow_failure_info.last_heartbeat_details) - if failure.HasField("cause"): - await self._apply_to_failure_payloads(failure.cause, cb) + await DataConverter._apply_to_failure_payloads(failure, self.decode_wrapper) class FailureConverter(ABC): @@ -1238,6 +1207,20 @@ def __init__(self) -> None: super().__init__(encode_common_attributes=True) +@dataclass(frozen=True) +class PayloadLimitsConfig: + """Configuration of payload size limits and the behavior when they are exceeded.""" + + memo_upload_error_limit: int | None = None + """Total encoded memo size, in bytes, above which a PayloadSizeError is raised. None disables the check.""" + memo_upload_warning_limit: int = 2 * 1024 + """Total encoded memo size, in bytes, above which a warning is emitted.""" + payload_upload_error_limit: int | None = None + """Total encoded payloads size, in bytes, above which a PayloadSizeError is raised. None disables the check.""" + payload_upload_warning_limit: int = 512 * 1024 + """Total encoded payloads size, in bytes, above which a warning is emitted.""" + + @dataclass(frozen=True) class DataConverter(WithSerializationContext): """Data converter for converting and encoding payloads to/from Python values. @@ -1261,6 +1244,9 @@ class DataConverter(WithSerializationContext): failure_converter: FailureConverter = dataclasses.field(init=False) """Failure converter created from the :py:attr:`failure_converter_class`.""" + payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() + """Settings for payload size limits.""" + default: ClassVar[DataConverter] """Singleton default data converter.""" @@ -1284,8 +1270,8 @@ async def encode( more than was given.
""" payloads = self.payload_converter.to_payloads(values) - if self.payload_codec: - payloads = await self.payload_codec.encode(payloads) + payloads = await self._encode_payload_sequence(payloads) + return payloads async def decode( @@ -1303,8 +1289,7 @@ async def decode( Returns: Decoded and converted values. """ - if self.payload_codec: - payloads = await self.payload_codec.decode(payloads) + payloads = await self._decode_payload_sequence(payloads) return self.payload_converter.from_payloads(payloads, type_hints) async def encode_wrapper( @@ -1332,15 +1317,13 @@ async def encode_failure( ) -> None: """Convert and encode failure.""" self.failure_converter.to_failure(exception, self.payload_converter, failure) - if self.payload_codec: - await self.payload_codec.encode_failure(failure) + await DataConverter._apply_to_failure_payloads(failure, self._encode_payloads) async def decode_failure( self, failure: temporalio.api.failure.v1.Failure ) -> BaseException: """Decode and convert failure.""" - if self.payload_codec: - await self.payload_codec.decode_failure(failure) + await DataConverter._apply_to_failure_payloads(failure, self._decode_payloads) return self.failure_converter.from_failure(failure, self.payload_converter) def with_context(self, context: SerializationContext) -> Self: @@ -1369,6 +1352,158 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "failure_converter", failure_converter) return cloned + async def _convert_from_memo( + self, + source: temporalio.api.common.v1.Memo, + ) -> Mapping[str, Any]: + mapping: dict[str, Any] = {} + for k, v in source.fields.items(): + mapping[k] = (await self.decode([v]))[0] + return mapping + + async def _convert_from_memo_field( + self, + source: temporalio.api.common.v1.Memo, + key: str, + default: Any, + type_hint: type | None, + ) -> Any: + payload = source.fields.get(key) + if not payload: + if default is temporalio.common._arg_unset: + raise KeyError(f"Memo does not have a value for key {key}") + return default + return (await self.decode([payload], [type_hint] if type_hint else None))[0] + + async def _convert_into_memo( + self, source: Mapping[str, Any], memo: temporalio.api.common.v1.Memo + ): + payloads: list[temporalio.api.common.v1.Payload] = [] + for k, v in source.items(): + payload = v + if not isinstance(v, temporalio.api.common.v1.Payload): + payload = (await self.encode([v]))[0] + memo.fields[k].CopyFrom(payload) + payloads.append(payload) + # Memos have their field payloads validated all together in one unit + self._validate_limits( + payloads, + self.payload_limits.memo_upload_error_limit, + self.payload_limits.memo_upload_warning_limit, + "Memo size exceeded the warning limit.", + ) + + async def _convert_to_memo( + self, source: Mapping[str, Any] + ) -> temporalio.api.common.v1.Memo: + memo = temporalio.api.common.v1.Memo() + await self._convert_into_memo(source, memo) + return memo + + async def _encode_payload( + self, payload: temporalio.api.common.v1.Payload + ) -> temporalio.api.common.v1.Payload: + if self.payload_codec: + payload = (await self.payload_codec.encode([payload]))[0] + self._validate_payload_limits([payload]) + return payload + + async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): + if self.payload_codec: + await self.payload_codec.encode_wrapper(payloads) + self._validate_payload_limits(payloads.payloads) + + async def _encode_payload_sequence( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> 
list[temporalio.api.common.v1.Payload]: + encoded_payloads = list(payloads) + if self.payload_codec: + encoded_payloads = await self.payload_codec.encode(encoded_payloads) + self._validate_payload_limits(encoded_payloads) + return encoded_payloads + + async def _decode_payload( + self, payload: temporalio.api.common.v1.Payload + ) -> temporalio.api.common.v1.Payload: + if self.payload_codec: + payload = (await self.payload_codec.decode([payload]))[0] + return payload + + async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): + if self.payload_codec: + await self.payload_codec.decode_wrapper(payloads) + + async def _decode_payload_sequence( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + decoded_payloads = list(payloads) + if self.payload_codec: + decoded_payloads = await self.payload_codec.decode(decoded_payloads) + return decoded_payloads + + @staticmethod + async def _apply_to_failure_payloads( + failure: temporalio.api.failure.v1.Failure, + cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]], + ) -> None: + if failure.HasField("encoded_attributes"): + # Wrap in payloads and merge back + payloads = temporalio.api.common.v1.Payloads( + payloads=[failure.encoded_attributes] + ) + await cb(payloads) + failure.encoded_attributes.CopyFrom(payloads.payloads[0]) + if failure.HasField( + "application_failure_info" + ) and failure.application_failure_info.HasField("details"): + await cb(failure.application_failure_info.details) + elif failure.HasField( + "timeout_failure_info" + ) and failure.timeout_failure_info.HasField("last_heartbeat_details"): + await cb(failure.timeout_failure_info.last_heartbeat_details) + elif failure.HasField( + "canceled_failure_info" + ) and failure.canceled_failure_info.HasField("details"): + await cb(failure.canceled_failure_info.details) + elif failure.HasField( + "reset_workflow_failure_info" + ) and failure.reset_workflow_failure_info.HasField("last_heartbeat_details"): + await cb(failure.reset_workflow_failure_info.last_heartbeat_details) + if failure.HasField("cause"): + await DataConverter._apply_to_failure_payloads(failure.cause, cb) + + def _validate_payload_limits( + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + ): + self._validate_limits( + payloads, + self.payload_limits.payload_upload_error_limit, + self.payload_limits.payload_upload_warning_limit, + "Payloads size exceeded the warning limit.", + ) + + def _validate_limits( + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + error_limit: int | None, + warning_limit: int, + warning_message: str, + ): + total_size = sum(payload.ByteSize() for payload in payloads) + + if error_limit and error_limit > 0 and total_size > error_limit: + raise temporalio.exceptions.PayloadSizeError( + size=total_size, + limit=error_limit, + ) + + if warning_limit and warning_limit > 0 and total_size > warning_limit: + # TODO: Use a context aware logger to log extra information about workflow/activity/etc + warnings.warn( + f"{warning_message} Size: {total_size} bytes, Limit: {warning_limit} bytes" + ) + DefaultPayloadConverter.default_encoding_payload_converters = ( BinaryNullPayloadConverter(), diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index f8f8ca20c..96a644a61 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -446,3 +446,28 @@ def is_cancelled_exception(exception: BaseException) -> bool: and isinstance(exception.cause, CancelledError) ) ) + + +class 
PayloadSizeError(TemporalError): + """Error raised when payloads size exceeds payload size limits.""" + + def __init__(self, size: int, limit: int): + """Initialize a payloads limit error. + + Args: + size: Actual payloads size in bytes. + limit: Payloads size limit in bytes. + """ + super().__init__("Payloads size exceeded the error limit") + self._size = size + self._limit = limit + + @property + def payloads_size(self) -> int: + """Actual payloads size in bytes.""" + return self._size + + @property + def payloads_limit(self) -> int: + """Payloads size limit in bytes.""" + return self._limit diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 23f2ed5cc..6d6126ace 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -380,6 +380,19 @@ async def _handle_start_activity_task( temporalio.exceptions.CancelledError("Cancelled"), completion.result.cancelled.failure, ) + elif isinstance( + err, + temporalio.exceptions.PayloadSizeError, + ): + temporalio.activity.logger.warning( + "Activity task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes", + err.payloads_size, + err.payloads_limit, + extra={"__temporal_error_identifier": "ActivityFailure"}, + ) + await data_converter.encode_failure( + err, completion.result.failed.failure + ) else: if ( isinstance( @@ -577,10 +590,9 @@ async def _execute_activity( else None, ) - if self._encode_headers and data_converter.payload_codec is not None: + if self._encode_headers: for payload in start.header_fields.values(): - new_payload = (await data_converter.payload_codec.decode([payload]))[0] - payload.CopyFrom(new_payload) + payload.CopyFrom(await data_converter._decode_payload(payload)) running_activity.info = info input = ExecuteActivityInput( diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 16e0de5e8..48891f299 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -4,6 +4,7 @@ import asyncio import concurrent.futures +import dataclasses import logging import os import sys @@ -270,20 +271,21 @@ async def _handle_activation( data_converter = self._data_converter.with_context(workflow_context) if self._data_converter.payload_codec: assert data_converter.payload_codec - if not workflow: - payload_codec = data_converter.payload_codec - else: - payload_codec = _CommandAwarePayloadCodec( - workflow.instance, - context_free_payload_codec=self._data_converter.payload_codec, - workflow_context_payload_codec=data_converter.payload_codec, - workflow_context=workflow_context, + if workflow: + data_converter = dataclasses.replace( + data_converter, + payload_codec=_CommandAwarePayloadCodec( + workflow.instance, + context_free_payload_codec=self._data_converter.payload_codec, + workflow_context_payload_codec=data_converter.payload_codec, + workflow_context=workflow_context, + ), ) - await temporalio.bridge.worker.decode_activation( - act, - payload_codec, - decode_headers=self._encode_headers, - ) + await temporalio.bridge.worker.decode_activation( + act, + data_converter, + decode_headers=self._encode_headers, + ) if not workflow: assert init_job workflow = _RunningWorkflow( @@ -351,27 +353,46 @@ async def _handle_activation( # Encode completion if self._data_converter.payload_codec and workflow: assert data_converter.payload_codec - payload_codec = _CommandAwarePayloadCodec( - workflow.instance, - context_free_payload_codec=self._data_converter.payload_codec, - 
workflow_context_payload_codec=data_converter.payload_codec, - workflow_context=temporalio.converter.WorkflowSerializationContext( - namespace=self._namespace, - workflow_id=workflow.workflow_id, - ), - ) - try: - await temporalio.bridge.worker.encode_completion( - completion, - payload_codec, - encode_headers=self._encode_headers, - ) - except Exception as err: - logger.exception( - "Failed encoding completion on workflow with run ID %s", act.run_id + if workflow: + data_converter = dataclasses.replace( + data_converter, + payload_codec=_CommandAwarePayloadCodec( + workflow.instance, + context_free_payload_codec=self._data_converter.payload_codec, + workflow_context_payload_codec=data_converter.payload_codec, + workflow_context=temporalio.converter.WorkflowSerializationContext( + namespace=self._namespace, + workflow_id=workflow.workflow_id, + ), + ), ) - completion.failed.Clear() - completion.failed.failure.message = f"Failed encoding completion: {err}" + + try: + await temporalio.bridge.worker.encode_completion( + completion, + data_converter, + encode_headers=self._encode_headers, + ) + except temporalio.exceptions.PayloadSizeError as err: + # TODO: Would like to use temporalio.workflow.logger here, but + # that requires being in the workflow event loop. Possibly refactor + # the logger core functionality into shareable class and update + # LoggerAdapter to be a decorator. + logger.warning( + "Workflow task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes", + err.payloads_size, + err.payloads_limit, + ) + completion.failed.Clear() + await data_converter.encode_failure(err, completion.failed.failure) + # TODO: Add WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE to API + # completion.failed.force_cause = WorkflowTaskFailedCause.WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE + except Exception as err: + logger.exception( + "Failed encoding completion on workflow with run ID %s", act.run_id + ) + completion.failed.Clear() + completion.failed.failure.message = f"Failed encoding completion: {err}" # Send off completion if LOG_PROTOS: diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py index e66a42dc0..b97f08a8a 100644 --- a/tests/worker/test_activity.py +++ b/tests/worker/test_activity.py @@ -1687,7 +1687,7 @@ async def raise_error(): assert handler._trace_identifiers == 1 finally: - activity.logger.base_logger.removeHandler(CustomLogHandler()) + activity.logger.base_logger.removeHandler(handler) async def test_activity_heartbeat_context( diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 41e6ccad9..5604b8542 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -1,8 +1,10 @@ +import dataclasses from collections.abc import MutableSequence from google.protobuf.duration_pb2 import Duration import temporalio.bridge.worker +import temporalio.converter from temporalio.api.common.v1.message_pb2 import ( Payload, Payloads, @@ -228,7 +230,12 @@ async def test_bridge_encoding(): ), ) - await temporalio.bridge.worker.encode_completion(comp, SimpleCodec(), True) + data_converter = dataclasses.replace( + temporalio.converter.default(), + payload_codec=SimpleCodec(), + ) + + await temporalio.bridge.worker.encode_completion(comp, data_converter, True) cmd = comp.successful.commands[0] sa = cmd.schedule_activity diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 4d9299f09..d8401d684 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -15,6 
+15,7 @@ import time import typing import uuid +import warnings from abc import ABC, abstractmethod from collections.abc import Awaitable, Mapping, Sequence from dataclasses import dataclass @@ -88,6 +89,7 @@ DefaultPayloadConverter, PayloadCodec, PayloadConverter, + PayloadLimitsConfig, ) from temporalio.exceptions import ( ActivityError, @@ -95,18 +97,23 @@ ApplicationErrorCategory, CancelledError, ChildWorkflowError, + PayloadSizeError, TemporalError, TimeoutError, + TimeoutType, WorkflowAlreadyStartedError, ) from temporalio.runtime import ( BUFFERED_METRIC_KIND_COUNTER, BUFFERED_METRIC_KIND_HISTOGRAM, + LogForwardingConfig, + LoggingConfig, MetricBuffer, MetricBufferDurationFormat, PrometheusConfig, Runtime, TelemetryConfig, + TelemetryFilter, ) from temporalio.service import RPCError, RPCStatusCode, __version__ from temporalio.testing import WorkflowEnvironment @@ -8445,7 +8452,7 @@ async def run(self): class CustomLogHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: - import httpx # type: ignore[reportUnusedImport] + import httpx # type: ignore[reportUnusedImport] # noqa async def test_disable_logger_sandbox( @@ -8483,3 +8490,469 @@ async def test_disable_logger_sandbox( run_timeout=timedelta(seconds=1), retry_policy=RetryPolicy(maximum_attempts=1), ) + + +@dataclass +class LargePayloadWorkflowInput: + activity_input_data_size: int + activity_output_data_size: int + workflow_output_data_size: int + data: list[int] + + +@dataclass +class LargePayloadWorkflowOutput: + data: list[int] + + +@dataclass +class LargePayloadActivityInput: + output_data_size: int + data: list[int] + + +@dataclass +class LargePayloadActivityOutput: + data: list[int] + + +@activity.defn +async def large_payload_activity( + input: LargePayloadActivityInput, +) -> LargePayloadActivityOutput: + return LargePayloadActivityOutput(data=[0] * input.output_data_size) + + +@workflow.defn +class LargePayloadWorkflow: + @workflow.run + async def run(self, input: LargePayloadWorkflowInput) -> LargePayloadWorkflowOutput: + await workflow.execute_activity( + large_payload_activity, + LargePayloadActivityInput( + output_data_size=input.activity_output_data_size, + data=[0] * input.activity_input_data_size, + ), + schedule_to_close_timeout=timedelta(seconds=5), + ) + return LargePayloadWorkflowOutput(data=[0] * input.workflow_output_data_size) + + +async def test_large_payload_error_workflow_input(client: Client): + config = client.config() + error_limit = 5 * 1024 + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, payload_upload_warning_limit=1024 + ), + ) + client = Client(**config) + + with pytest.raises(PayloadSizeError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[0] * 6 * 1024, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue="test-queue", + ) + + assert error_limit == err.value.payloads_limit + + +async def test_large_payload_error_workflow_memo(client: Client): + config = client.config() + error_limit = 128 + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig(memo_upload_error_limit=error_limit), + ) + client = Client(**config) + + with pytest.raises(PayloadSizeError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + 
LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue="test-queue", + memo={"key1": [0] * 256}, + ) + + assert error_limit == err.value.payloads_limit + + +async def test_large_payload_warning_workflow_input(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024 + ), + ) + client = Client(**config) + + with warnings.catch_warnings(record=True) as w: + async with new_worker( + client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[0] * 2 * 1024, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message) + + +async def test_large_payload_warning_workflow_memo(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig(payload_upload_warning_limit=128), + ) + client = Client(**config) + + with warnings.catch_warnings(record=True) as w: + async with new_worker( + client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + memo={"key1": [0] * 256}, + ) + + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message) + + +async def test_large_payload_error_workflow_result(client: Client): + # Create worker runtime with forwarded logger + worker_logger = logging.getLogger(f"log-{uuid.uuid4()}") + worker_runtime = Runtime( + telemetry=TelemetryConfig( + logging=LoggingConfig( + filter=TelemetryFilter(core_level="WARN", other_level="ERROR"), + forwarding=LogForwardingConfig(logger=worker_logger), + ) + ) + ) + + # Create client for worker with custom payload limits + error_limit = 5 * 1024 + worker_client = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=worker_runtime, + data_converter=dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, + payload_upload_warning_limit=1024, + ), + ), + ) + + with ( + LogCapturer().logs_captured(worker_logger) as worker_logger_capturer, + LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=6 * 1024, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=3), + ) + 
+ assert isinstance(err.value.cause, TimeoutError) + assert err.value.cause.type == TimeoutType.START_TO_CLOSE + + def worker_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Payloads size exceeded the error limit" in record.msg + ) + + assert worker_logger_capturer.find(worker_logger_predicate) + + def root_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Workflow task failed: payloads size exceeded the error limit." + in record.msg + and f"Limit: {error_limit} bytes" in record.msg + ) + + assert root_logger_capturer.find(root_logger_predicate) + + +async def test_large_payload_warning_workflow_result(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024 + ), + ) + worker_client = Client(**config) + + with warnings.catch_warnings(record=True) as w: + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=2 * 1024, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=3), + ) + + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message) + + +async def test_large_payload_error_activity_input(client: Client): + # Create worker runtime with forwarded logger + worker_logger = logging.getLogger(f"log-{uuid.uuid4()}") + worker_runtime = Runtime( + telemetry=TelemetryConfig( + logging=LoggingConfig( + filter=TelemetryFilter(core_level="WARN", other_level="ERROR"), + forwarding=LogForwardingConfig(logger=worker_logger), + ) + ) + ) + + # Create client for worker with custom payload limits + error_limit = 5 * 1024 + worker_client = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=worker_runtime, + data_converter=dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, + payload_upload_warning_limit=1024, + ), + ), + ) + + with ( + LogCapturer().logs_captured(worker_logger) as worker_logger_capturer, + LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=6 * 1024, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=3), + ) + + assert isinstance(err.value.cause, TimeoutError) + + def worker_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Payloads size exceeded the error limit" in record.msg + ) + + assert worker_logger_capturer.find(worker_logger_predicate) + + def root_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Workflow 
task failed: payloads size exceeded the error limit." + in record.msg + and f"Limit: {error_limit} bytes" in record.msg + ) + + assert root_logger_capturer.find(root_logger_predicate) + + +async def test_large_payload_warning_activity_input(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024 + ), + ) + worker_client = Client(**config) + + with warnings.catch_warnings(record=True) as w: + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=2 * 1024, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message) + + +async def test_large_payload_error_activity_result(client: Client): + # Create worker runtime with forwarded logger + worker_logger = logging.getLogger(f"log-{uuid.uuid4()}") + worker_runtime = Runtime( + telemetry=TelemetryConfig( + logging=LoggingConfig( + filter=TelemetryFilter(core_level="WARN", other_level="ERROR"), + forwarding=LogForwardingConfig(logger=worker_logger), + ) + ) + ) + + # Create client for worker with custom payload limits + error_limit = 5 * 1024 + worker_client = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=worker_runtime, + data_converter=dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, + payload_upload_warning_limit=1024, + ), + ), + ) + + with ( + LogCapturer().logs_captured( + activity.logger.base_logger + ) as activity_logger_capturer, + # LogCapturer().logs_captured(worker_logger) as worker_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=6 * 1024, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + assert isinstance(err.value.cause, ActivityError) + assert isinstance(err.value.cause.cause, ApplicationError) + + def activity_logger_predicate(record: logging.LogRecord) -> bool: + return ( + hasattr(record, "__temporal_error_identifier") + and getattr(record, "__temporal_error_identifier") == "ActivityFailure" + and record.levelname == "WARNING" + and "Activity task failed: payloads size exceeded the error limit." + in record.msg + and f"Limit: {error_limit} bytes" in record.msg + ) + + assert activity_logger_capturer.find(activity_logger_predicate) + + # Worker logger is not emitting this follow message. Maybe activity completion failures + # are not routed through the log forwarder whereas workflow completion failures are? 
+ # def worker_logger_predicate(record: logging.LogRecord) -> bool: + # return "Payloads size exceeded the error limit" in record.msg + + # assert worker_logger_capturer.find(worker_logger_predicate) + + +async def test_large_payload_warning_activity_result(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024 + ), + ) + worker_client = Client(**config) + + with warnings.catch_warnings(record=True) as w: + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=2 * 1024, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message)
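
Usage sketch (not part of the diff): a minimal example of how the new PayloadLimitsConfig and PayloadSizeError introduced above might be used from application code, mirroring the dataclasses.replace pattern in the tests. The server address, task queue name, and MyWorkflow definition are placeholders, not names from this change.

import dataclasses
import uuid

import temporalio.converter
from temporalio.client import Client
from temporalio.converter import PayloadLimitsConfig
from temporalio.exceptions import PayloadSizeError


async def main() -> None:
    # Swap in custom limits on the default data converter; converters and codec
    # are left untouched.
    client = await Client.connect(
        "localhost:7233",  # placeholder target host
        data_converter=dataclasses.replace(
            temporalio.converter.default(),
            payload_limits=PayloadLimitsConfig(
                payload_upload_error_limit=5 * 1024,  # bytes; None disables the error
                payload_upload_warning_limit=1024,  # bytes; exceeded -> warnings.warn
            ),
        ),
    )
    try:
        await client.execute_workflow(
            MyWorkflow.run,  # placeholder workflow definition
            "x" * 10_000,  # encoded argument exceeds the 5 KiB error limit
            id=f"wf-{uuid.uuid4()}",
            task_queue="my-task-queue",  # placeholder task queue
        )
    except PayloadSizeError as err:
        # Raised in the client process while encoding the start request
        print(f"Rejected: {err.payloads_size} bytes > {err.payloads_limit} bytes")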