From 87977ef55ddff8583ce16548bd01e2293a374623 Mon Sep 17 00:00:00 2001 From: Charles Duffy Date: Mon, 26 Jan 2026 17:03:50 -0600 Subject: [PATCH 01/16] feat: Propagate exceptions to AfterToolCallEvent for decorated tools (#1565) --- src/strands/tools/decorator.py | 9 +- src/strands/tools/executors/_executor.py | 7 +- src/strands/types/_events.py | 26 ++++-- tests/strands/agent/hooks/test_events.py | 2 - .../strands/tools/executors/test_executor.py | 92 +++++++++++++++++++ tests/strands/tools/test_decorator.py | 77 ++++++++++++++++ 6 files changed, 201 insertions(+), 12 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index f72a8ccf1..de7b968f9 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -613,6 +613,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw "status": "error", "content": [{"text": f"Error: {error_msg}"}], }, + exception=e, ) except Exception as e: # Return error result with exception details for any other error @@ -625,14 +626,15 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw "status": "error", "content": [{"text": f"Error: {error_type} - {error_msg}"}], }, + exception=e, ) - def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: + def _wrap_tool_result(self, tool_use_d: str, result: Any, exception: Exception | None = None) -> ToolResultEvent: # FORMAT THE RESULT for Strands Agent if isinstance(result, dict) and "status" in result and "content" in result: # Result is already in the expected format, just add toolUseId result["toolUseId"] = tool_use_d - return ToolResultEvent(cast(ToolResult, result)) + return ToolResultEvent(cast(ToolResult, result), exception=exception) else: # Wrap any other return value in the standard format # Always include at least one content item for consistency @@ -641,7 +643,8 @@ def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: "toolUseId": tool_use_d, "status": "success", "content": [{"text": str(result)}], - } + }, + exception=exception, ) @property diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index ef000fbd6..0da6b5715 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -215,6 +215,9 @@ async def _stream( return if structured_output_context.is_enabled: kwargs["structured_output_context"] = structured_output_context + + exception: Exception | None = None + async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. 
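The net effect of this commit: when a decorated tool raises, the decorator attaches the original exception to the ToolResultEvent, and the executor (second hunk below) forwards it into AfterToolCallEvent. A minimal sketch of a hook callback that consumes the new field — the import path and registration call are assumed from the SDK's hook conventions, not shown in this diff:

```python
import logging

# Import path assumed; AfterToolCallEvent is the event type the tests below assert on.
from strands.hooks import AfterToolCallEvent

logger = logging.getLogger(__name__)


def log_tool_failures(event: AfterToolCallEvent) -> None:
    """Hypothetical hook callback: event.exception is None unless the tool raised."""
    if event.exception is not None:
        # The full Exception object is available for type-based handling or re-raising.
        logger.warning(
            "tool call failed: %s: %s",
            type(event.exception).__name__,
            event.exception,
        )


# Registration is assumed to follow the existing hook-callback convention, e.g.:
# agent.hooks.add_callback(AfterToolCallEvent, log_tool_failures)
```

Note that a tool which *returns* an error-status result without raising still yields `exception is None`, as the tests added later in this patch verify.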
@@ -227,6 +230,8 @@ async def _stream( return if isinstance(event, ToolResultEvent): + # Preserve exception from decorated tools before extracting tool_result + exception = event.exception # below the last "event" must point to the tool_result event = event.tool_result break @@ -239,7 +244,7 @@ async def _stream( result = cast(ToolResult, event) after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result + agent, selected_tool, tool_use, invocation_state, result, exception=exception ) # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 0896d48e1..9f3f0c4e3 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -274,15 +274,29 @@ def prepare(self, invocation_state: dict) -> None: class ToolResultEvent(TypedEvent): - """Event emitted when a tool execution completes.""" + """Event emitted when a tool execution completes. - def __init__(self, tool_result: ToolResult) -> None: - """Initialize with the completed tool result. + Stores the full Exception object as an instance attribute for debugging while + keeping the event dict JSON-serializable. The exception can be accessed via + the `exception` property for re-raising or type-based error handling in hooks. - Args: - tool_result: Final result from the tool execution - """ + Parameters: + tool_result: Final result from the tool execution. + exception: Optional exception that occurred during tool execution. + """ + + def __init__(self, tool_result: ToolResult, exception: Exception | None = None) -> None: + """Initialize tool result event.""" super().__init__({"type": "tool_result", "tool_result": tool_result}) + self._exception = exception + + @property + def exception(self) -> Exception | None: + """The original exception that occurred, if any. + + Can be used for re-raising or type-based error handling. 
+ """ + return self._exception @property def tool_use_id(self) -> str: diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 762b77452..de551d137 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -206,8 +206,6 @@ def test_invocation_state_is_available_in_model_call_events(agent): assert after_event.invocation_state["request_id"] == "req-456" - - def test_before_invocation_event_messages_default_none(agent): """Test that BeforeInvocationEvent.messages defaults to None for backward compatibility.""" event = BeforeInvocationEvent(agent=agent) diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 78e35c2aa..4a5479503 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -482,6 +482,98 @@ async def test_executor_stream_updates_invocation_state_with_agent( assert empty_invocation_state["agent"] is agent +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_exception_in_hook( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that exceptions from @tool-decorated functions reach AfterToolCallEvent.""" + exception = ValueError("decorated tool error") + + @strands.tool(name="decorated_error_tool") + def failing_tool(): + """A tool that raises an exception.""" + raise exception + + agent.tool_registry.register_tool(failing_tool) + tool_use = {"name": "decorated_error_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is exception + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_runtime_error_in_hook( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that RuntimeError from @tool-decorated functions reach AfterToolCallEvent.""" + exception = RuntimeError("runtime error from decorated tool") + + @strands.tool(name="runtime_error_tool") + def runtime_error_tool(): + """A tool that raises a RuntimeError.""" + raise exception + + agent.tool_registry.register_tool(runtime_error_tool) + tool_use = {"name": "runtime_error_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is exception + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_no_exception_on_success( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that AfterToolCallEvent.exception is None when decorated tool succeeds.""" + + @strands.tool(name="success_decorated_tool") + def success_tool(): + """A tool that succeeds.""" + return "success" + + agent.tool_registry.register_tool(success_tool) + tool_use = {"name": "success_decorated_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.result["status"] == "success" + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_error_result_without_exception( + executor, agent, 
tool_results, invocation_state, hook_events, alist +): + """Test that exception is None when a tool returns an error result without throwing.""" + + @strands.tool(name="error_result_tool") + def error_result_tool(): + """A tool that returns an error result dict without raising.""" + return {"status": "error", "content": [{"text": "something went wrong"}]} + + agent.tool_registry.register_tool(error_result_tool) + tool_use = {"name": "error_result_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.result["status"] == "error" + + @pytest.mark.asyncio async def test_executor_stream_no_retry_set(executor, agent, tool_results, invocation_state, alist): """Test default behavior when retry is not set - tool executes once.""" diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 4757e5587..bb1431c9d 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1823,3 +1823,80 @@ def test_tool_decorator_annotated_field_with_inner_default(): @strands.tool def inner_default_tool(name: str, level: Annotated[int, Field(description="A level value", default=10)]) -> str: return f"{name} is at level {level}" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_runtime_error(alist): + """Test that ToolResultEvent carries exception when tool raises RuntimeError.""" + + @strands.tool + def error_tool(): + """Tool that raises a RuntimeError.""" + raise RuntimeError("test runtime error") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, RuntimeError) + assert str(result_event.exception) == "test runtime error" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_value_error(alist): + """Test that ToolResultEvent carries exception when tool raises ValueError.""" + + @strands.tool + def validation_error_tool(): + """Tool that raises a ValueError.""" + raise ValueError("validation failed") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(validation_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, ValueError) + assert str(result_event.exception) == "validation failed" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_no_exception_on_success(alist): + """Test that ToolResultEvent.exception is None when tool succeeds.""" + + @strands.tool + def success_tool(): + """Tool that succeeds.""" + return "success" + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(success_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.exception is None + assert result_event.tool_result["status"] == "success" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_assertion_error(alist): + """Test that ToolResultEvent carries 
AssertionError for unexpected failures.""" + + @strands.tool + def assertion_error_tool(): + """Tool that raises an AssertionError.""" + raise AssertionError("unexpected assertion failure") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(assertion_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert isinstance(result_event.exception, AssertionError) + assert "unexpected assertion failure" in str(result_event.exception) + assert result_event.tool_result["status"] == "error" From 9f0e5340a822d71d8e64454714381d2bce7316e2 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 29 Jan 2026 12:16:38 -0500 Subject: [PATCH 02/16] Increase pytest timeout to 45 seconds (#1586) --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a16132881..7f816880d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ dev = [ "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", + "pytest-timeout>=2.0.0,<3.0.0", "pytest-xdist>=3.0.0,<4.0.0", "ruff>=0.13.0,<0.15.0", "tenacity>=9.0.0,<10.0.0", @@ -146,6 +147,7 @@ dependencies = [ "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", + "pytest-timeout>=2.0.0,<3.0.0", "pytest-xdist>=3.0.0,<4.0.0", "moto>=5.1.0,<6.0.0", ] @@ -239,6 +241,7 @@ convention = "google" testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi" +timeout = 45 [tool.coverage.run] From 623f40e69e602ea52de08d70c8e436a52b939eb4 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 29 Jan 2026 15:07:17 -0500 Subject: [PATCH 03/16] Publish integ tests results to cloudwatch (#1587) --- .github/scripts/upload-integ-test-metrics.py | 147 +++++++++++++++++++ .github/workflows/integration-test.yml | 7 + pyproject.toml | 5 +- 3 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 .github/scripts/upload-integ-test-metrics.py diff --git a/.github/scripts/upload-integ-test-metrics.py b/.github/scripts/upload-integ-test-metrics.py new file mode 100644 index 000000000..28595d647 --- /dev/null +++ b/.github/scripts/upload-integ-test-metrics.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +import sys +import xml.etree.ElementTree as ET +from datetime import datetime +from dataclasses import dataclass +from typing import Any, Literal, TypedDict +import os +import boto3 + +STRANDS_METRIC_NAMESPACE = 'Strands/Tests' + + + +class Dimension(TypedDict): + Name: str + Value: str + + +class MetricDatum(TypedDict): + MetricName: str + Dimensions: list[Dimension] + Value: float + Unit: str + Timestamp: datetime + + +@dataclass +class TestResult: + name: str + classname: str + duration: float + outcome: Literal['failed', 'skipped', 'passed'] + + +def parse_junit_xml(xml_file_path: str) -> list[TestResult]: + try: + tree = ET.parse(xml_file_path) + except FileNotFoundError: + print(f"Warning: XML file not found: {xml_file_path}") + return [] + except ET.ParseError as e: + print(f"Warning: Failed to parse XML: {e}") + return [] + + results = [] + root = tree.getroot() + + for testcase in root.iter('testcase'): + name = testcase.get('name') + classname = testcase.get('classname') + duration = float(testcase.get('time', 0.0)) + + if not name or not classname: + continue + + if testcase.find('failure') is not None or testcase.find('error') is not None: + outcome = 'failed' + elif testcase.find('skipped') is 
not None:
+            outcome = 'skipped'
+        else:
+            outcome = 'passed'
+
+        results.append(TestResult(name, classname, duration, outcome))
+
+    return results
+
+
+def build_metric_data(test_results: list[TestResult], repository: str) -> list[MetricDatum]:
+    metrics: list[MetricDatum] = []
+    timestamp = datetime.utcnow()
+
+    for test in test_results:
+        test_name = f"{test.classname}.{test.name}"
+        dimensions: list[Dimension] = [
+            Dimension(Name='TestName', Value=test_name),
+            Dimension(Name='Repository', Value=repository)
+        ]
+
+        metrics.append(MetricDatum(
+            MetricName='TestPassed',
+            Dimensions=dimensions,
+            Value=1.0 if test.outcome == 'passed' else 0.0,
+            Unit='Count',
+            Timestamp=timestamp
+        ))
+
+        metrics.append(MetricDatum(
+            MetricName='TestFailed',
+            Dimensions=dimensions,
+            Value=1.0 if test.outcome == 'failed' else 0.0,
+            Unit='Count',
+            Timestamp=timestamp
+        ))
+
+        metrics.append(MetricDatum(
+            MetricName='TestSkipped',
+            Dimensions=dimensions,
+            Value=1.0 if test.outcome == 'skipped' else 0.0,
+            Unit='Count',
+            Timestamp=timestamp
+        ))
+
+        metrics.append(MetricDatum(
+            MetricName='TestDuration',
+            Dimensions=dimensions,
+            Value=test.duration,
+            Unit='Seconds',
+            Timestamp=timestamp
+        ))
+
+    return metrics
+
+
+def publish_metrics(metric_data: list[dict[str, Any]], region: str):
+    cloudwatch = boto3.client('cloudwatch', region_name=region)
+
+    batch_size = 1000
+    for i in range(0, len(metric_data), batch_size):
+        batch = metric_data[i:i + batch_size]
+        try:
+            cloudwatch.put_metric_data(Namespace=STRANDS_METRIC_NAMESPACE, MetricData=batch)
+            print(f"Published {len(batch)} metrics to CloudWatch")
+        except Exception as e:
+            print(f"Warning: Failed to publish metrics batch: {e}")
+
+
+def main():
+    if len(sys.argv) != 3:
+        print("Usage: python upload-integ-test-metrics.py <xml_file> <repository>")
+        sys.exit(1)
+
+    xml_file = sys.argv[1]
+    repository = sys.argv[2]
+    region = os.environ.get('AWS_REGION', 'us-east-1')
+
+    test_results = parse_junit_xml(xml_file)
+    if not test_results:
+        print("No test results found")
+        sys.exit(1)
+
+    print(f"Found {len(test_results)} test results")
+    metric_data = build_metric_data(test_results, repository)
+    publish_metrics(metric_data, region)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 65c785f30..bbcdfde25 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -37,6 +37,7 @@ jobs:
           role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }}
           aws-region: us-east-1
           mask-aws-account-id: true
+
       - name: Checkout head commit
         uses: actions/checkout@v6
         with:
@@ -57,3 +58,9 @@ jobs:
         id: tests
         run: |
           hatch test tests_integ
+
+      - name: Publish test metrics to CloudWatch
+        if: always()
+        run: |
+          pip install --no-cache-dir boto3
+          python .github/scripts/upload-integ-test-metrics.py ./build/test-results.xml ${{ github.event.repository.name }}
diff --git a/pyproject.toml b/pyproject.toml
index 7f816880d..ba635cc48 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -149,6 +149,7 @@ dependencies = [
     "pytest-asyncio>=1.0.0,<1.4.0",
     "pytest-timeout>=2.0.0,<3.0.0",
     "pytest-xdist>=3.0.0,<4.0.0",
+    "pytest-timeout>=2.0.0,<3.0.0",
     "moto>=5.1.0,<6.0.0",
 ]
 
@@ -240,7 +241,7 @@ convention = "google"
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 asyncio_default_fixture_loop_scope = "function"
-addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi"
+addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi 
--junit-xml=build/test-results.xml" timeout = 45 @@ -298,7 +299,7 @@ prepare = [ "hatch run bidi-test:test-cov", ] -[tools.hatch.envs.bidi-lint] +[tool.hatch.envs.bidi-lint] template = "bidi" [tool.hatch.envs.bidi-lint.scripts] From c007cc0a456c06adde62106b3d4a306be3479c21 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 29 Jan 2026 16:30:09 -0500 Subject: [PATCH 04/16] feat(a2a): add A2AAgent class (#1441) Co-authored-by: Arron Bailiss --- AGENTS.md | 7 +- src/strands/agent/__init__.py | 11 + src/strands/agent/a2a_agent.py | 262 +++++++++++ src/strands/multiagent/a2a/_converters.py | 130 ++++++ src/strands/types/a2a.py | 38 ++ tests/strands/agent/test_a2a_agent.py | 414 ++++++++++++++++++ .../strands/multiagent/a2a/test_converters.py | 205 +++++++++ tests_integ/a2a/__init__.py | 0 tests_integ/a2a/a2a_server.py | 15 + tests_integ/a2a/test_multiagent_a2a.py | 72 +++ 10 files changed, 1153 insertions(+), 1 deletion(-) create mode 100644 src/strands/agent/a2a_agent.py create mode 100644 src/strands/multiagent/a2a/_converters.py create mode 100644 src/strands/types/a2a.py create mode 100644 tests/strands/agent/test_a2a_agent.py create mode 100644 tests/strands/multiagent/a2a/test_converters.py create mode 100644 tests_integ/a2a/__init__.py create mode 100644 tests_integ/a2a/a2a_server.py create mode 100644 tests_integ/a2a/test_multiagent_a2a.py diff --git a/AGENTS.md b/AGENTS.md index 71e83835d..a57286941 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,6 +25,8 @@ strands-agents/ │ ├── agent/ # Core agent implementation │ │ ├── agent.py # Main Agent class │ │ ├── agent_result.py # Agent execution results +│ │ ├── base.py # AgentBase protocol (agent interface) +│ │ ├── a2a_agent.py # A2AAgent client for remote A2A agents │ │ ├── state.py # Agent state management │ │ └── conversation_manager/ # Message history strategies │ │ ├── conversation_manager.py # Base conversation manager @@ -82,7 +84,8 @@ strands-agents/ │ │ ├── swarm.py # Swarm pattern │ │ ├── a2a/ # Agent-to-agent protocol │ │ │ ├── executor.py # A2A executor -│ │ │ └── server.py # A2A server +│ │ │ ├── server.py # A2A server +│ │ │ └── converters.py # Strands/A2A type converters │ │ └── nodes/ # Graph node implementations │ │ │ ├── types/ # Type definitions @@ -102,6 +105,7 @@ strands-agents/ │ │ ├── json_dict.py # JSON dict utilities │ │ ├── collections.py # Collection types │ │ ├── _events.py # Internal event types +│ │ ├── a2a.py # A2A protocol types │ │ └── models/ # Model-specific types │ │ │ ├── session/ # Session management @@ -188,6 +192,7 @@ strands-agents/ │ ├── interrupts/ # Interrupt tests │ ├── steering/ # Steering tests │ ├── bidi/ # Bidirectional streaming tests +│ ├── a2a/ # A2A agent integration tests │ ├── test_multiagent_graph.py │ ├── test_multiagent_swarm.py │ ├── test_stream_agent.py diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 2e40866a9..c901e800f 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -7,6 +7,8 @@ - Retry Strategies: Configurable retry behavior for model calls """ +from typing import Any + from ..event_loop._retry import ModelRetryStrategy from .agent import Agent from .agent_result import AgentResult @@ -28,3 +30,12 @@ "SummarizingConversationManager", "ModelRetryStrategy", ] + + +def __getattr__(name: str) -> Any: + """Lazy load A2AAgent to defer import of optional a2a dependency.""" + if name == "A2AAgent": + from .a2a_agent import A2AAgent + + return A2AAgent + raise AttributeError(f"cannot import name '{name}' from 
'{__name__}' ({__file__})") diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py new file mode 100644 index 000000000..e18da2f4a --- /dev/null +++ b/src/strands/agent/a2a_agent.py @@ -0,0 +1,262 @@ +"""A2A Agent client for Strands Agents. + +This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents, +allowing them to be used standalone or as part of multi-agent patterns. + +A2AAgent can be used to get the Agent Card and interact with the agent. +""" + +import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import httpx +from a2a.client import A2ACardResolver, ClientConfig, ClientFactory +from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent + +from .._async import run_async +from ..multiagent.a2a._converters import convert_input_to_message, convert_response_to_agent_result +from ..types._events import AgentResultEvent +from ..types.a2a import A2AResponse, A2AStreamEvent +from ..types.agent import AgentInput +from .agent_result import AgentResult +from .base import AgentBase + +logger = logging.getLogger(__name__) + +_DEFAULT_TIMEOUT = 300 + + +class A2AAgent(AgentBase): + """Client wrapper for remote A2A agents.""" + + def __init__( + self, + endpoint: str, + *, + name: str | None = None, + description: str | None = None, + timeout: int = _DEFAULT_TIMEOUT, + a2a_client_factory: ClientFactory | None = None, + ): + """Initialize A2A agent. + + Args: + endpoint: The base URL of the remote A2A agent. + name: Agent name. If not provided, will be populated from agent card. + description: Agent description. If not provided, will be populated from agent card. + timeout: Timeout for HTTP operations in seconds (defaults to 300). + a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided, + it will be used to create the A2A client after discovering the agent card. + Note: When providing a custom factory, you are responsible for managing + the lifecycle of any httpx client it uses. + """ + self.endpoint = endpoint + self.name = name + self.description = description + self.timeout = timeout + self._agent_card: AgentCard | None = None + self._a2a_client_factory: ClientFactory | None = a2a_client_factory + + def __call__( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Synchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. + """ + return run_async(lambda: self.invoke_async(prompt, **kwargs)) + + async def invoke_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Asynchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. 
+        """
+        result: AgentResult | None = None
+        async for event in self.stream_async(prompt, **kwargs):
+            if "result" in event:
+                result = event["result"]
+
+        if result is None:
+            raise RuntimeError("No response received from A2A agent")
+
+        return result
+
+    async def stream_async(
+        self,
+        prompt: AgentInput = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[Any]:
+        """Stream remote agent execution asynchronously.
+
+        This method provides an asynchronous interface for streaming A2A protocol events.
+        Unlike Agent.stream_async() which yields text deltas and tool events, this method
+        yields raw A2A protocol events wrapped in A2AStreamEvent dictionaries.
+
+        Args:
+            prompt: Input to the agent (string, message list, or content blocks).
+            **kwargs: Additional arguments (ignored).
+
+        Yields:
+            An async iterator that yields events. Each event is a dictionary:
+            - A2AStreamEvent: {"type": "a2a_stream", "event": <A2A event>}
+              where the A2A object can be a Message, or a tuple of
+              (Task, TaskStatusUpdateEvent) or (Task, TaskArtifactUpdateEvent).
+            - AgentResultEvent: {"result": AgentResult} - always emitted last.
+
+        Raises:
+            ValueError: If prompt is None.
+
+        Example:
+            ```python
+            async for event in a2a_agent.stream_async("Hello"):
+                if event.get("type") == "a2a_stream":
+                    print(f"A2A event: {event['event']}")
+                elif "result" in event:
+                    print(f"Final result: {event['result'].message}")
+            ```
+        """
+        last_event = None
+        last_complete_event = None
+
+        async for event in self._send_message(prompt):
+            last_event = event
+            if self._is_complete_event(event):
+                last_complete_event = event
+            yield A2AStreamEvent(event)
+
+        # Use the last complete event if available, otherwise fall back to last event
+        final_event = last_complete_event or last_event
+
+        if final_event is not None:
+            result = convert_response_to_agent_result(final_event)
+            yield AgentResultEvent(result)
+
+    async def get_agent_card(self) -> AgentCard:
+        """Fetch and return the remote agent's card.
+
+        This method eagerly fetches the agent card from the remote endpoint,
+        populating name and description if not already set. The card is cached
+        after the first fetch.
+
+        Returns:
+            The remote agent's AgentCard containing name, description, capabilities, skills, etc.
+        """
+        if self._agent_card is not None:
+            return self._agent_card
+
+        async with httpx.AsyncClient(timeout=self.timeout) as client:
+            resolver = A2ACardResolver(httpx_client=client, base_url=self.endpoint)
+            self._agent_card = await resolver.get_agent_card()
+
+        # Populate name from card if not set
+        if self.name is None and self._agent_card.name:
+            self.name = self._agent_card.name
+
+        # Populate description from card if not set
+        if self.description is None and self._agent_card.description:
+            self.description = self._agent_card.description
+
+        logger.debug("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint)
+        return self._agent_card
+
+    @asynccontextmanager
+    async def _get_a2a_client(self) -> AsyncIterator[Any]:
+        """Get A2A client for sending messages.
+
+        If a custom factory was provided, uses that (caller manages httpx lifecycle).
+        Otherwise creates a per-call httpx client with proper cleanup.
+
+        Yields:
+            Configured A2A client instance.
+ """ + agent_card = await self.get_agent_card() + + if self._a2a_client_factory is not None: + yield self._a2a_client_factory.create(agent_card) + return + + async with httpx.AsyncClient(timeout=self.timeout) as httpx_client: + config = ClientConfig(httpx_client=httpx_client, streaming=True) + yield ClientFactory(config).create(agent_card) + + async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: + """Send message to A2A agent. + + Args: + prompt: Input to send to the agent. + + Yields: + A2A response events. + + Raises: + ValueError: If prompt is None. + """ + if prompt is None: + raise ValueError("prompt is required for A2AAgent") + + message = convert_input_to_message(prompt) + logger.debug("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint) + + async with self._get_a2a_client() as client: + async for event in client.send_message(message): + yield event + + def _is_complete_event(self, event: A2AResponse) -> bool: + """Check if an A2A event represents a complete response. + + Args: + event: A2A event. + + Returns: + True if the event represents a complete response. + """ + # Direct Message is always complete + if isinstance(event, Message): + return True + + # Handle tuple responses (Task, UpdateEvent | None) + if isinstance(event, tuple) and len(event) == 2: + task, update_event = event + + # Initial task response (no update event) + if update_event is None: + return True + + # Artifact update with last_chunk flag + if isinstance(update_event, TaskArtifactUpdateEvent): + if hasattr(update_event, "last_chunk") and update_event.last_chunk is not None: + return update_event.last_chunk + return False + + # Status update with completed state + if isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "state"): + return update_event.status.state == TaskState.completed + + return False diff --git a/src/strands/multiagent/a2a/_converters.py b/src/strands/multiagent/a2a/_converters.py new file mode 100644 index 000000000..b818c824b --- /dev/null +++ b/src/strands/multiagent/a2a/_converters.py @@ -0,0 +1,130 @@ +"""Conversion functions between Strands and A2A types.""" + +from typing import cast +from uuid import uuid4 + +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart + +from ...agent.agent_result import AgentResult +from ...telemetry.metrics import EventLoopMetrics +from ...types.a2a import A2AResponse +from ...types.agent import AgentInput +from ...types.content import ContentBlock, Message + + +def convert_input_to_message(prompt: AgentInput) -> A2AMessage: + """Convert AgentInput to A2A Message. + + Args: + prompt: Input in various formats (string, message list, or content blocks). + + Returns: + A2AMessage ready to send to the remote agent. + + Raises: + ValueError: If prompt format is unsupported. 
+ """ + message_id = uuid4().hex + + if isinstance(prompt, str): + return A2AMessage( + kind="message", + role=Role.user, + parts=[Part(TextPart(kind="text", text=prompt))], + message_id=message_id, + ) + + if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)): + # Check for interrupt responses - not supported in A2A + if "interruptResponse" in prompt[0]: + raise ValueError("InterruptResponseContent is not supported for A2AAgent") + + if "role" in prompt[0]: + for msg in reversed(prompt): + if msg.get("role") == "user": + content = cast(list[ContentBlock], msg.get("content", [])) + parts = convert_content_blocks_to_parts(content) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + else: + parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt)) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + + raise ValueError(f"Unsupported input type: {type(prompt)}") + + +def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]: + """Convert Strands ContentBlocks to A2A Parts. + + Args: + content_blocks: List of Strands content blocks. + + Returns: + List of A2A Part objects. + """ + parts = [] + for block in content_blocks: + if "text" in block: + parts.append(Part(TextPart(kind="text", text=block["text"]))) + return parts + + +def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: + """Convert A2A response to AgentResult. + + Args: + response: A2A response (either A2AMessage or tuple of task and update event). + + Returns: + AgentResult with extracted content and metadata. + """ + content: list[ContentBlock] = [] + + if isinstance(response, tuple) and len(response) == 2: + task, update_event = response + + # Handle artifact updates + if isinstance(update_event, TaskArtifactUpdateEvent): + if update_event.artifact and hasattr(update_event.artifact, "parts"): + for part in update_event.artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + # Handle status updates with messages + elif isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "message") and update_event.status.message: + for part in update_event.status.message.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + # Handle initial task or task without update event + elif update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None: + for artifact in task.artifacts: + if hasattr(artifact, "parts"): + for part in artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + elif isinstance(response, A2AMessage): + for part in response.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + + message: Message = { + "role": "assistant", + "content": content, + } + + return AgentResult( + stop_reason="end_turn", + message=message, + metrics=EventLoopMetrics(), + state={}, + ) diff --git a/src/strands/types/a2a.py b/src/strands/types/a2a.py new file mode 100644 index 000000000..2ca444cb0 --- /dev/null +++ b/src/strands/types/a2a.py @@ -0,0 +1,38 @@ +"""Additional A2A types.""" + +from typing import Any, TypeAlias + +from a2a.types import Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent + +from ._events import TypedEvent + +A2AResponse: 
TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message | Any + + +class A2AStreamEvent(TypedEvent): + """Event emitted for every update received from the remote A2A server. + + This event wraps all A2A response types during streaming, including: + - Partial task updates (TaskArtifactUpdateEvent) + - Status updates (TaskStatusUpdateEvent) + - Complete messages (Message) + - Final task completions + + The event is emitted for EVERY update from the server, regardless of whether + it represents a complete or partial response. When streaming completes, an + AgentResultEvent containing the final AgentResult is also emitted after all + A2AStreamEvents. + """ + + def __init__(self, a2a_event: A2AResponse) -> None: + """Initialize with A2A event. + + Args: + a2a_event: The original A2A event (Task tuple or Message) + """ + super().__init__( + { + "type": "a2a_stream", + "event": a2a_event, # Nest A2A event to avoid field conflicts + } + ) diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py new file mode 100644 index 000000000..26a34476d --- /dev/null +++ b/tests/strands/agent/test_a2a_agent.py @@ -0,0 +1,414 @@ +"""Tests for A2AAgent class.""" + +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from a2a.types import AgentCard, Message, Part, Role, TextPart + +from strands.agent.a2a_agent import A2AAgent +from strands.agent.agent_result import AgentResult + + +@pytest.fixture +def mock_agent_card(): + """Mock AgentCard for testing.""" + return AgentCard( + name="test-agent", + description="Test agent", + url="http://localhost:8000", + version="1.0.0", + capabilities={}, + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + skills=[], + ) + + +@pytest.fixture +def a2a_agent(): + """Create A2AAgent instance for testing.""" + return A2AAgent(endpoint="http://localhost:8000") + + +@pytest.fixture +def mock_httpx_client(): + """Create a mock httpx.AsyncClient that works as async context manager.""" + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + return mock_client + + +@asynccontextmanager +async def mock_a2a_client_context(send_message_func): + """Helper to create mock A2A client setup for _send_message tests.""" + mock_client = MagicMock() + mock_client.send_message = send_message_func + with patch("strands.agent.a2a_agent.httpx.AsyncClient") as mock_httpx_class: + mock_httpx = AsyncMock() + mock_httpx.__aenter__.return_value = mock_httpx + mock_httpx.__aexit__.return_value = None + mock_httpx_class.return_value = mock_httpx + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = mock_client + mock_factory_class.return_value = mock_factory + yield mock_httpx_class, mock_factory_class + + +def test_init_with_defaults(): + """Test initialization with default parameters.""" + agent = A2AAgent(endpoint="http://localhost:8000") + assert agent.endpoint == "http://localhost:8000" + assert agent.timeout == 300 + assert agent._agent_card is None + assert agent.name is None + assert agent.description is None + + +def test_init_with_name_and_description(): + """Test initialization with custom name and description.""" + agent = A2AAgent(endpoint="http://localhost:8000", name="my-agent", description="My custom agent") + assert agent.name == "my-agent" + assert 
agent.description == "My custom agent" + + +def test_init_with_custom_timeout(): + """Test initialization with custom timeout.""" + agent = A2AAgent(endpoint="http://localhost:8000", timeout=600) + assert agent.timeout == 600 + + +def test_init_with_external_a2a_client_factory(): + """Test initialization with external A2A client factory.""" + external_factory = MagicMock() + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + assert agent._a2a_client_factory is external_factory + + +@pytest.mark.asyncio +async def test_get_agent_card(a2a_agent, mock_agent_card, mock_httpx_client): + """Test agent card discovery.""" + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + card = await a2a_agent.get_agent_card() + + assert card == mock_agent_card + assert a2a_agent._agent_card == mock_agent_card + + +@pytest.mark.asyncio +async def test_get_agent_card_cached(a2a_agent, mock_agent_card): + """Test that agent card is cached after first discovery.""" + a2a_agent._agent_card = mock_agent_card + + card = await a2a_agent.get_agent_card() + + assert card == mock_agent_card + + +@pytest.mark.asyncio +async def test_get_agent_card_populates_name_and_description(mock_agent_card, mock_httpx_client): + """Test that agent card populates name and description if not set.""" + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + assert agent.name == mock_agent_card.name + assert agent.description == mock_agent_card.description + + +@pytest.mark.asyncio +async def test_get_agent_card_preserves_custom_name_and_description(mock_agent_card, mock_httpx_client): + """Test that custom name and description are not overridden by agent card.""" + agent = A2AAgent(endpoint="http://localhost:8000", name="custom-name", description="Custom description") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + assert agent.name == "custom-name" + assert agent.description == "Custom description" + + +@pytest.mark.asyncio +async def test_invoke_async_success(a2a_agent, mock_agent_card): + """Test successful async invocation.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + result = await a2a_agent.invoke_async("Hello") + + assert isinstance(result, AgentResult) + assert result.message["content"][0]["text"] == "Response" + + +@pytest.mark.asyncio +async def 
test_invoke_async_no_prompt(a2a_agent): + """Test that invoke_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + await a2a_agent.invoke_async(None) + + +@pytest.mark.asyncio +async def test_invoke_async_no_response(a2a_agent, mock_agent_card): + """Test that invoke_async raises RuntimeError when no response received.""" + + async def mock_send_message(*args, **kwargs): + return + yield # Make it an async generator + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + with pytest.raises(RuntimeError, match="No response received"): + await a2a_agent.invoke_async("Hello") + + +def test_call_sync(a2a_agent): + """Test synchronous call method.""" + mock_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=MagicMock(), + state={}, + ) + + with patch("strands.agent.a2a_agent.run_async") as mock_run_async: + mock_run_async.return_value = mock_result + + result = a2a_agent("Hello") + + assert result == mock_result + mock_run_async.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_async_success(a2a_agent, mock_agent_card): + """Test successful async streaming.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + assert len(events) == 2 + # First event is A2A stream event + assert events[0]["type"] == "a2a_stream" + assert events[0]["event"] == mock_response + # Final event is AgentResult + assert "result" in events[1] + assert isinstance(events[1]["result"], AgentResult) + assert events[1]["result"].message["content"][0]["text"] == "Response" + + +@pytest.mark.asyncio +async def test_stream_async_no_prompt(a2a_agent): + """Test that stream_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + async for _ in a2a_agent.stream_async(None): + pass + + +@pytest.mark.asyncio +async def test_send_message_uses_provided_factory(mock_agent_card): + """Test _send_message uses provided factory instead of creating per-call client.""" + external_factory = MagicMock() + mock_a2a_client = MagicMock() + + async def mock_send_message(*args, **kwargs): + yield MagicMock() + + mock_a2a_client.send_message = mock_send_message + external_factory.create.return_value = mock_a2a_client + + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + # Consume the async iterator + async for _ in agent._send_message("Hello"): + pass + + external_factory.create.assert_called_once_with(mock_agent_card) + + +@pytest.mark.asyncio +async def test_send_message_creates_per_call_client(a2a_agent, mock_agent_card): + """Test _send_message creates a fresh httpx client for each call when no factory provided.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, 
"get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message) as (mock_httpx_class, _): + # Consume the async iterator + async for _ in a2a_agent._send_message("Hello"): + pass + + # Verify httpx client was created with timeout + mock_httpx_class.assert_called_once_with(timeout=300) + + +def test_is_complete_event_message(a2a_agent): + """Test _is_complete_event returns True for Message.""" + mock_message = MagicMock(spec=Message) + + assert a2a_agent._is_complete_event(mock_message) is True + + +def test_is_complete_event_tuple_with_none_update(a2a_agent): + """Test _is_complete_event returns True for tuple with None update event.""" + mock_task = MagicMock() + + assert a2a_agent._is_complete_event((mock_task, None)) is True + + +def test_is_complete_event_artifact_last_chunk(a2a_agent): + """Test _is_complete_event handles TaskArtifactUpdateEvent last_chunk flag.""" + from a2a.types import TaskArtifactUpdateEvent + + mock_task = MagicMock() + + # last_chunk=True -> complete + event_complete = MagicMock(spec=TaskArtifactUpdateEvent) + event_complete.last_chunk = True + assert a2a_agent._is_complete_event((mock_task, event_complete)) is True + + # last_chunk=False -> not complete + event_incomplete = MagicMock(spec=TaskArtifactUpdateEvent) + event_incomplete.last_chunk = False + assert a2a_agent._is_complete_event((mock_task, event_incomplete)) is False + + # last_chunk=None -> not complete + event_none = MagicMock(spec=TaskArtifactUpdateEvent) + event_none.last_chunk = None + assert a2a_agent._is_complete_event((mock_task, event_none)) is False + + +def test_is_complete_event_status_update(a2a_agent): + """Test _is_complete_event handles TaskStatusUpdateEvent state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + mock_task = MagicMock() + + # completed state -> complete + event_completed = MagicMock(spec=TaskStatusUpdateEvent) + event_completed.status = MagicMock() + event_completed.status.state = TaskState.completed + assert a2a_agent._is_complete_event((mock_task, event_completed)) is True + + # working state -> not complete + event_working = MagicMock(spec=TaskStatusUpdateEvent) + event_working.status = MagicMock() + event_working.status.state = TaskState.working + assert a2a_agent._is_complete_event((mock_task, event_working)) is False + + # no status -> not complete + event_no_status = MagicMock(spec=TaskStatusUpdateEvent) + event_no_status.status = None + assert a2a_agent._is_complete_event((mock_task, event_no_status)) is False + + +def test_is_complete_event_unknown_type(a2a_agent): + """Test _is_complete_event returns False for unknown event types.""" + assert a2a_agent._is_complete_event("unknown") is False + + +@pytest.mark.asyncio +async def test_stream_async_tracks_complete_events(a2a_agent, mock_agent_card): + """Test stream_async uses last complete event for final result.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + mock_task = MagicMock() + mock_task.artifacts = None + + # First event: incomplete + incomplete_event = MagicMock(spec=TaskStatusUpdateEvent) + incomplete_event.status = MagicMock() + incomplete_event.status.state = TaskState.working + incomplete_event.status.message = None + + # Second event: complete + complete_event = MagicMock(spec=TaskStatusUpdateEvent) + complete_event.status = MagicMock() + complete_event.status.state = TaskState.completed + complete_event.status.message = MagicMock() + complete_event.status.message.parts = [] + + async def mock_send_message(*args, **kwargs): + 
yield (mock_task, incomplete_event) + yield (mock_task, complete_event) + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + # Should have 2 stream events + 1 result event + assert len(events) == 3 + assert "result" in events[2] + + +@pytest.mark.asyncio +async def test_stream_async_falls_back_to_last_event(a2a_agent, mock_agent_card): + """Test stream_async falls back to last event when no complete event.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + mock_task = MagicMock() + mock_task.artifacts = None + + incomplete_event = MagicMock(spec=TaskStatusUpdateEvent) + incomplete_event.status = MagicMock() + incomplete_event.status.state = TaskState.working + incomplete_event.status.message = None + + async def mock_send_message(*args, **kwargs): + yield (mock_task, incomplete_event) + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + # Should have 1 stream event + 1 result event (falls back to last) + assert len(events) == 2 + assert "result" in events[1] diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py new file mode 100644 index 000000000..002ebf6a6 --- /dev/null +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -0,0 +1,205 @@ +"""Tests for A2A converter functions.""" + +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart + +from strands.agent.agent_result import AgentResult +from strands.multiagent.a2a._converters import ( + convert_content_blocks_to_parts, + convert_input_to_message, + convert_response_to_agent_result, +) + + +def test_convert_string_input(): + """Test converting string input to A2A message.""" + message = convert_input_to_message("Hello") + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + assert message.parts[0].root.text == "Hello" + + +def test_convert_message_list_input(): + """Test converting message list input to A2A message.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + message = convert_input_to_message(messages) + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + + +def test_convert_content_blocks_input(): + """Test converting content blocks input to A2A message.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + message = convert_input_to_message(content_blocks) + + assert isinstance(message, A2AMessage) + assert len(message.parts) == 2 + + +def test_convert_unsupported_input(): + """Test that unsupported input types raise ValueError.""" + with pytest.raises(ValueError, match="Unsupported input type"): + convert_input_to_message(123) + + +def test_convert_interrupt_response_raises_error(): + """Test that InterruptResponseContent raises explicit error.""" + interrupt_responses = [{"interruptResponse": {"interruptId": "123", "response": "A"}}] + + with pytest.raises(ValueError, match="InterruptResponseContent is not supported for A2AAgent"): + convert_input_to_message(interrupt_responses) + + +def 
test_convert_content_blocks_to_parts(): + """Test converting content blocks to A2A parts.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + parts = convert_content_blocks_to_parts(content_blocks) + + assert len(parts) == 2 + assert parts[0].root.text == "Hello" + assert parts[1].root.text == "World" + + +def test_convert_a2a_message_response(): + """Test converting A2A message response to AgentResult.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + result = convert_response_to_agent_result(a2a_message) + + assert isinstance(result, AgentResult) + assert result.message["role"] == "assistant" + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Response" + + +def test_convert_task_response(): + """Test converting task response to AgentResult.""" + mock_task = MagicMock() + mock_artifact = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Task response" + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + result = convert_response_to_agent_result((mock_task, None)) + + assert isinstance(result, AgentResult) + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Task response" + + +def test_convert_multiple_parts_response(): + """Test converting response with multiple parts to separate content blocks.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[ + Part(TextPart(kind="text", text="First")), + Part(TextPart(kind="text", text="Second")), + ], + ) + + result = convert_response_to_agent_result(a2a_message) + + assert len(result.message["content"]) == 2 + assert result.message["content"][0]["text"] == "First" + assert result.message["content"][1]["text"] == "Second" + + +# --- New tests for coverage --- + + +def test_convert_message_list_finds_last_user_message(): + """Test that message list conversion finds the last user message.""" + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Second"}]}, + ] + + message = convert_input_to_message(messages) + + assert message.parts[0].root.text == "Second" + + +def test_convert_content_blocks_skips_non_text(): + """Test that non-text content blocks are skipped.""" + content_blocks = [{"text": "Hello"}, {"image": "data"}, {"text": "World"}] + + parts = convert_content_blocks_to_parts(content_blocks) + + assert len(parts) == 2 + + +def test_convert_task_artifact_update_event(): + """Test converting TaskArtifactUpdateEvent to AgentResult.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Streamed artifact" + mock_artifact = MagicMock() + mock_artifact.parts = [mock_part] + + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + mock_event.artifact = mock_artifact + + result = convert_response_to_agent_result((mock_task, mock_event)) + + assert result.message["content"][0]["text"] == "Streamed artifact" + + +def test_convert_task_status_update_event(): + """Test converting TaskStatusUpdateEvent to AgentResult.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Status message" + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_status = MagicMock() + mock_status.message = mock_message + + mock_event = MagicMock(spec=TaskStatusUpdateEvent) + mock_event.status = mock_status + + result = 
convert_response_to_agent_result((mock_task, mock_event)) + + assert result.message["content"][0]["text"] == "Status message" + + +def test_convert_response_handles_missing_data(): + """Test that response conversion handles missing/malformed data gracefully.""" + # TaskArtifactUpdateEvent with no artifact + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + mock_event.artifact = None + result = convert_response_to_agent_result((MagicMock(), mock_event)) + assert len(result.message["content"]) == 0 + + # TaskStatusUpdateEvent with no status + mock_event = MagicMock(spec=TaskStatusUpdateEvent) + mock_event.status = None + result = convert_response_to_agent_result((MagicMock(), mock_event)) + assert len(result.message["content"]) == 0 + + # Task artifact without parts attribute + mock_task = MagicMock() + mock_artifact = MagicMock(spec=[]) + del mock_artifact.parts + mock_task.artifacts = [mock_artifact] + result = convert_response_to_agent_result((mock_task, None)) + assert len(result.message["content"]) == 0 diff --git a/tests_integ/a2a/__init__.py b/tests_integ/a2a/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/a2a/a2a_server.py b/tests_integ/a2a/a2a_server.py new file mode 100644 index 000000000..047edc3ba --- /dev/null +++ b/tests_integ/a2a/a2a_server.py @@ -0,0 +1,15 @@ +from strands import Agent +from strands.multiagent.a2a import A2AServer + +# Create an agent and serve it over A2A +agent = Agent( + name="Test agent", + description="Test description here", + callback_handler=None, +) +a2a_server = A2AServer( + agent=agent, + host="localhost", + port=9000, +) +a2a_server.serve() diff --git a/tests_integ/a2a/test_multiagent_a2a.py b/tests_integ/a2a/test_multiagent_a2a.py new file mode 100644 index 000000000..60cbc9ce5 --- /dev/null +++ b/tests_integ/a2a/test_multiagent_a2a.py @@ -0,0 +1,72 @@ +import os +import subprocess +import time + +import httpx +import pytest +from a2a.client import ClientConfig, ClientFactory + +from strands.agent.a2a_agent import A2AAgent + + +@pytest.fixture +def a2a_server(): + """Start A2A server as subprocess fixture.""" + server_path = os.path.join(os.path.dirname(__file__), "a2a_server.py") + process = subprocess.Popen(["python", server_path]) + time.sleep(5) # Wait for A2A server to start + + yield "http://localhost:9000" + + # Cleanup + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + + +def test_a2a_agent_invoke_sync(a2a_server): + """Test synchronous invocation via __call__.""" + a2a_agent = A2AAgent(endpoint=a2a_server) + result = a2a_agent("Hello there!") + assert result.stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_a2a_agent_invoke_async(a2a_server): + """Test async invocation.""" + a2a_agent = A2AAgent(endpoint=a2a_server) + result = await a2a_agent.invoke_async("Hello there!") + assert result.stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_a2a_agent_stream_async(a2a_server): + """Test async streaming.""" + a2a_agent = A2AAgent(endpoint=a2a_server) + + events = [] + async for event in a2a_agent.stream_async("Hello there!"): + events.append(event) + + # Should have at least one A2A stream event and one final result event + assert len(events) >= 2 + assert events[0]["type"] == "a2a_stream" + assert "result" in events[-1] + assert events[-1]["result"].stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_a2a_agent_with_non_streaming_client_config(a2a_server): + """Test with streaming=False client 
configuration (non-default).""" + httpx_client = httpx.AsyncClient(timeout=300) + config = ClientConfig(httpx_client=httpx_client, streaming=False) + factory = ClientFactory(config) + + try: + a2a_agent = A2AAgent(endpoint=a2a_server, a2a_client_factory=factory) + result = await a2a_agent.invoke_async("Hello there!") + assert result.stop_reason == "end_turn" + finally: + await httpx_client.aclose() From db735570b4cfb52eb83497348ebec45b4a75a6df Mon Sep 17 00:00:00 2001 From: Charles Duffy Date: Thu, 29 Jan 2026 15:40:44 -0600 Subject: [PATCH 05/16] fix(tools): preserve nullable semantics for required Union[T, None] params (#1584) Co-authored-by: Dean Schmigelski --- src/strands/tools/decorator.py | 11 +++- tests/strands/tools/test_decorator.py | 78 +++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index de7b968f9..70552d6ba 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -326,13 +326,20 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None: del schema[key] # Process properties to clean up anyOf and similar structures + required_fields = schema.get("required", []) if "properties" in schema: - for _prop_name, prop_schema in schema["properties"].items(): + for prop_name, prop_schema in schema["properties"].items(): # Handle anyOf constructs (common for Optional types) if "anyOf" in prop_schema: any_of = prop_schema["anyOf"] # Handle Optional[Type] case (represented as anyOf[Type, null]) - if len(any_of) == 2 and any(item.get("type") == "null" for item in any_of): + # Only simplify when the field is not required; required nullable + # fields need anyOf preserved so the model can pass null. + if ( + prop_name not in required_fields + and len(any_of) == 2 + and any(item.get("type") == "null" for item in any_of) + ): # Find the non-null type for item in any_of: if item.get("type") != "null": diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index bb1431c9d..f3d6eda02 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1900,3 +1900,81 @@ def assertion_error_tool(): assert isinstance(result_event.exception, AssertionError) assert "unexpected assertion failure" in str(result_event.exception) assert result_event.tool_result["status"] == "error" + + +def test_tool_nullable_required_field_preserves_anyof(): + """Test that a required nullable field preserves anyOf so the model can pass null. + + Regression test for https://github.com/strands-agents/sdk-python/issues/1525 + """ + from enum import Enum + + class Priority(str, Enum): + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + @strands.tool + def prioritized_task(description: str, priority: Priority | None) -> str: + """Create a task with optional priority. 
+ + Args: + description: Task description + priority: Optional priority level + """ + return f"{description}: {priority}" + + spec = prioritized_task.tool_spec + schema = spec["inputSchema"]["json"] + + expected_schema = { + "$defs": { + "Priority": { + "enum": ["high", "medium", "low"], + "title": "Priority", + "type": "string", + }, + }, + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Task description", + }, + "priority": { + "anyOf": [ + {"$ref": "#/$defs/Priority"}, + {"type": "null"}, + ], + "description": "Optional priority level", + }, + }, + "required": ["description", "priority"], + } + + assert schema == expected_schema + + +def test_tool_nullable_optional_field_simplifies_anyof(): + """Test that a non-required nullable field still gets anyOf simplified.""" + + @strands.tool + def my_tool(name: str, tag: str | None = None) -> str: + """A tool. + + Args: + name: The name + tag: An optional tag + """ + return f"{name}: {tag}" + + spec = my_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # tag has a default, so it should NOT be required + assert "name" in schema["required"] + assert "tag" not in schema["required"] + + # Since tag is not required, anyOf should be simplified away + assert "anyOf" not in schema["properties"]["tag"] + assert schema["properties"]["tag"]["type"] == "string" From 8c0cb43ec036c8649b71a3b1a2f1f19173c4534e Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 29 Jan 2026 18:47:14 -0500 Subject: [PATCH 06/16] Feature: Allow s3Location as Document, Image, and Video location source (#1572) --- src/strands/models/bedrock.py | 61 +++++-- src/strands/types/media.py | 54 ++++++- tests/strands/models/test_bedrock.py | 151 +++++++++++++++++- tests/strands/tools/mcp/test_mcp_client.py | 2 +- tests/strands/types/test_media.py | 99 ++++++++++++ tests_integ/conftest.py | 11 +- tests_integ/mcp/echo_server.py | 2 +- tests_integ/mcp/test_mcp_client.py | 2 +- tests_integ/resources/blue.mp4 | Bin 0 -> 5200 bytes tests_integ/{ => resources}/letter.pdf | Bin tests_integ/{ => resources}/yellow.png | Bin tests_integ/test_a2a_executor.py | 4 +- tests_integ/test_bedrock_s3_location.py | 177 +++++++++++++++++++++ 13 files changed, 531 insertions(+), 32 deletions(-) create mode 100644 tests/strands/types/test_media.py create mode 100644 tests_integ/resources/blue.mp4 rename tests_integ/{ => resources}/letter.pdf (100%) rename tests_integ/{ => resources}/yellow.png (100%) create mode 100644 tests_integ/test_bedrock_s3_location.py diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index a3cea7cfe..b053b70fb 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -17,6 +17,8 @@ from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override +from strands.types.media import S3Location, SourceLocation + from .._exception_notes import add_exception_note from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec @@ -407,6 +409,8 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: # Format content blocks for Bedrock API compatibility formatted_content = self._format_request_message_content(content_block) + if formatted_content is None: + continue # Wrap text or image content in guardrailContent if this is the last user message if ( @@ -459,7 +463,19 @@ def _should_include_tool_result_status(self) -> bool: else: # "auto" return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS) - def 
_format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
+    def _handle_location(self, location: SourceLocation) -> dict[str, Any] | None:
+        """Convert a location content block to Bedrock format if it's an S3Location."""
+        if location["type"] == "s3":
+            s3_location = cast(S3Location, location)
+            formatted_document_s3: dict[str, Any] = {"uri": s3_location["uri"]}
+            if "bucketOwner" in s3_location:
+                formatted_document_s3["bucketOwner"] = s3_location["bucketOwner"]
+            return {"s3Location": formatted_document_s3}
+        else:
+            logger.warning("Non s3 location sources are not supported by Bedrock, skipping content block")
+            return None
+
+    def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any] | None:
         """Format a Bedrock content block.
 
         Bedrock strictly validates content blocks and throws exceptions for unknown fields.
@@ -489,9 +505,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
             if "format" in document:
                 result["format"] = document["format"]
 
-            # Handle source
+            # Handle source - supports bytes or location
             if "source" in document:
-                result["source"] = {"bytes": document["source"]["bytes"]}
+                source = document["source"]
+                formatted_document_source: dict[str, Any] | None
+                if "location" in source:
+                    formatted_document_source = self._handle_location(source["location"])
+                    if formatted_document_source is None:
+                        return None
+                elif "bytes" in source:
+                    formatted_document_source = {"bytes": source["bytes"]}
+                result["source"] = formatted_document_source
 
             # Handle optional fields
             if "citations" in document and document["citations"] is not None:
@@ -512,10 +536,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         if "image" in content:
             image = content["image"]
             source = image["source"]
-            formatted_source = {}
-            if "bytes" in source:
-                formatted_source = {"bytes": source["bytes"]}
-            result = {"format": image["format"], "source": formatted_source}
+            formatted_image_source: dict[str, Any] | None
+            if "location" in source:
+                formatted_image_source = self._handle_location(source["location"])
+                if formatted_image_source is None:
+                    return None
+            elif "bytes" in source:
+                formatted_image_source = {"bytes": source["bytes"]}
+            result = {"format": image["format"], "source": formatted_image_source}
             return {"image": result}
 
         # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
@@ -550,9 +578,12 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
                     # Handle json field since not in ContentBlock but valid in ToolResultContent
                     formatted_content.append({"json": tool_result_content["json"]})
                 else:
-                    formatted_content.append(
-                        self._format_request_message_content(cast(ContentBlock, tool_result_content))
+                    formatted_message_content = self._format_request_message_content(
+                        cast(ContentBlock, tool_result_content)
                     )
+                    if formatted_message_content is None:
+                        continue
+                    formatted_content.append(formatted_message_content)
 
             result = {
                 "content": formatted_content,
@@ -577,10 +608,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         if "video" in content:
             video = content["video"]
             source = video["source"]
-            formatted_source = {}
-            if "bytes" in source:
-                formatted_source = {"bytes": source["bytes"]}
-            result = {"format": video["format"], "source": formatted_source}
+            formatted_video_source: dict[str, Any] | None
+            if "location" in source:
+                formatted_video_source = self._handle_location(source["location"])
+                if 
formatted_video_source is None:
+                    return None
+            elif "bytes" in source:
+                formatted_video_source = {"bytes": source["bytes"]}
+            result = {"format": video["format"], "source": formatted_video_source}
             return {"video": result}
 
         # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
diff --git a/src/strands/types/media.py b/src/strands/types/media.py
index 462d8af34..b1240dffb 100644
--- a/src/strands/types/media.py
+++ b/src/strands/types/media.py
@@ -5,9 +5,9 @@
 - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
 """
 
-from typing import Literal
+from typing import Literal, TypeAlias
 
-from typing_extensions import TypedDict
+from typing_extensions import Required, TypedDict
 
 from .citations import CitationsConfig
 
@@ -15,14 +15,50 @@
 """Supported document formats."""
 
 
-class DocumentSource(TypedDict):
+class Location(TypedDict, total=False):
+    """A generic location for a content source.
+
+    This type is a generic location for document, image, or video content. Its usage is determined by the underlying model provider.
+    """
+
+    type: Required[str]
+
+
+class S3Location(Location, total=False):
+    """A storage location in an Amazon S3 bucket.
+
+    Used by Bedrock to reference media files stored in S3 instead of passing raw bytes.
+
+    - Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html
+
+    Attributes:
+        type: s3
+        uri: An object URI starting with `s3://`. Required.
+        bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. Optional.
+    """
+
+    # mypy doesn't like this field being overridden in a subclass, but since it's just a literal string, this is fine.
+
+    type: Literal["s3"]  # type: ignore[misc]
+    uri: Required[str]
+    bucketOwner: str
+
+
+SourceLocation: TypeAlias = Location | S3Location
+
+
+class DocumentSource(TypedDict, total=False):
     """Contains the content of a document.
 
+    Only one of `bytes` or `location` should be specified.
+
     Attributes:
         bytes: The binary content of the document.
+        location: Location of the document.
     """
 
     bytes: bytes
+    location: SourceLocation
 
 
 class DocumentContent(TypedDict, total=False):
@@ -45,14 +81,18 @@ class DocumentContent(TypedDict, total=False):
 """Supported image formats."""
 
 
-class ImageSource(TypedDict):
+class ImageSource(TypedDict, total=False):
     """Contains the content of an image.
 
+    Only one of `bytes` or `location` should be specified.
+
     Attributes:
         bytes: The binary content of the image.
+        location: Location of the image.
     """
 
     bytes: bytes
+    location: SourceLocation
 
 
 class ImageContent(TypedDict):
@@ -71,14 +111,18 @@ class ImageContent(TypedDict):
 """Supported video formats."""
 
 
-class VideoSource(TypedDict):
+class VideoSource(TypedDict, total=False):
     """Contains the content of a video.
 
+    Only one of `bytes` or `location` should be specified.
+
     Attributes:
         bytes: The binary content of the video.
+        location: Location of the video. 
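+
+        Example (illustrative; the bucket and key are placeholders, not values used in this repo)::
+
+            source: VideoSource = {"location": {"type": "s3", "uri": "s3://my-bucket/videos/clip.mp4"}}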
""" bytes: bytes + location: SourceLocation class VideoContent(TypedDict): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index e92018f35..761434258 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,3 +1,5 @@ +import copy +import logging import os import sys import traceback @@ -1519,7 +1521,6 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model @pytest.mark.asyncio async def test_stream_logging(bedrock_client, model, messages, caplog, alist): """Test that stream method logs debug messages at the expected stages.""" - import logging # Set the logger to debug level to capture debug messages caplog.set_level(logging.DEBUG, logger="strands.models.bedrock") @@ -1787,8 +1788,8 @@ def test_format_request_filters_image_content_blocks(model, model_id): assert "metadata" not in image_block -def test_format_request_filters_nested_image_s3_fields(model, model_id): - """Test that s3Location is filtered out and only bytes source is preserved.""" +def test_format_request_image_s3_location_only(model, model_id): + """Test that image with only s3Location is properly formatted.""" messages = [ { "role": "user", @@ -1797,8 +1798,7 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id): "image": { "format": "png", "source": { - "bytes": b"image_data", - "s3Location": {"bucket": "my-bucket", "key": "image.png", "extraField": "filtered"}, + "location": {"type": "s3", "uri": "s3://my-bucket/image.png"}, }, } } @@ -1809,8 +1809,146 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id): formatted_request = model._format_request(messages) image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] + assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png"}} + + +def test_format_request_image_bytes_only(model, model_id): + """Test that image with only bytes source is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "png", + "source": {"bytes": b"image_data"}, + } + } + ], + } + ] + + formatted_request = model._format_request(messages) + image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] + assert image_source == {"bytes": b"image_data"} - assert "s3Location" not in image_source + + +def test_format_request_document_s3_location(model, model_id): + """Test that document with s3Location is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}, + }, + } + }, + { + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "location": { + "type": "s3", + "uri": "s3://my-bucket/report.pdf", + "bucketOwner": "123456789012", + }, + }, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + document = formatted_request["messages"][0]["content"][0]["document"] + document_with_bucket_owner = formatted_request["messages"][0]["content"][1]["document"] + + assert document["source"] == {"s3Location": {"uri": "s3://my-bucket/report.pdf"}} + + assert document_with_bucket_owner["source"] == { + "s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"} + } + + +def test_format_request_unsupported_location(model, caplog): + """Test that document with s3Location is properly formatted.""" + + caplog.set_level(logging.WARNING, 
logger="strands.models.bedrock") + + messages = [ + { + "role": "user", + "content": [ + {"text": "Hello!"}, + { + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "location": { + "type": "other", + }, + }, + } + }, + { + "video": { + "format": "mp4", + "source": { + "location": { + "type": "other", + }, + }, + } + }, + { + "image": { + "format": "png", + "source": { + "location": { + "type": "other", + }, + }, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + assert len(formatted_request["messages"][0]["content"]) == 1 + assert "Non s3 location sources are not supported by Bedrock, skipping content block" in caplog.text + + +def test_format_request_video_s3_location(model, model_id): + """Test that video with s3Location is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "video": { + "format": "mp4", + "source": { + "location": {"type": "s3", "uri": "s3://my-bucket/video.mp4"}, + }, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + video_source = formatted_request["messages"][0]["content"][0]["video"]["source"] + + assert video_source == {"s3Location": {"uri": "s3://my-bucket/video.mp4"}} def test_format_request_filters_document_content_blocks(model, model_id): @@ -2310,7 +2448,6 @@ def test_inject_cache_point_skipped_for_non_claude(bedrock_client): def test_format_bedrock_messages_does_not_mutate_original(bedrock_client): """Test that _format_bedrock_messages does not mutate original messages.""" - import copy model = BedrockModel( model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index f784da414..a2ef369ea 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -632,7 +632,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): """EmbeddedResource.resource (blob with image MIME) should map to image content.""" # Read yellow.png file - with open("tests_integ/yellow.png", "rb") as image_file: + with open("tests_integ/resources/yellow.png", "rb") as image_file: png_data = image_file.read() payload = base64.b64encode(png_data).decode() diff --git a/tests/strands/types/test_media.py b/tests/strands/types/test_media.py new file mode 100644 index 000000000..2fa8c3621 --- /dev/null +++ b/tests/strands/types/test_media.py @@ -0,0 +1,99 @@ +"""Tests for media type definitions.""" + +from strands.types.media import ( + DocumentSource, + ImageSource, + S3Location, + VideoSource, +) + + +class TestS3Location: + """Tests for S3Location TypedDict.""" + + def test_s3_location_with_uri_only(self): + """Test S3Location with only uri field.""" + s3_loc: S3Location = {"uri": "s3://my-bucket/path/to/file.pdf"} + + assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf" + assert "bucketOwner" not in s3_loc + + def test_s3_location_with_bucket_owner(self): + """Test S3Location with both uri and bucketOwner fields.""" + s3_loc: S3Location = { + "uri": "s3://my-bucket/path/to/file.pdf", + "bucketOwner": "123456789012", + } + + assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf" + assert s3_loc["bucketOwner"] == "123456789012" + + +class TestDocumentSource: + """Tests for DocumentSource TypedDict.""" + + def test_document_source_with_bytes(self): + """Test DocumentSource with bytes 
content.""" + doc_source: DocumentSource = {"bytes": b"document content"} + + assert doc_source["bytes"] == b"document content" + assert "s3Location" not in doc_source + + def test_document_source_with_s3_location(self): + """Test DocumentSource with s3Location.""" + doc_source: DocumentSource = { + "s3Location": { + "uri": "s3://my-bucket/docs/report.pdf", + "bucketOwner": "123456789012", + } + } + + assert "bytes" not in doc_source + assert doc_source["s3Location"]["uri"] == "s3://my-bucket/docs/report.pdf" + assert doc_source["s3Location"]["bucketOwner"] == "123456789012" + + +class TestImageSource: + """Tests for ImageSource TypedDict.""" + + def test_image_source_with_bytes(self): + """Test ImageSource with bytes content.""" + img_source: ImageSource = {"bytes": b"image content"} + + assert img_source["bytes"] == b"image content" + assert "s3Location" not in img_source + + def test_image_source_with_s3_location(self): + """Test ImageSource with s3Location.""" + img_source: ImageSource = { + "s3Location": { + "uri": "s3://my-bucket/images/photo.png", + } + } + + assert "bytes" not in img_source + assert img_source["s3Location"]["uri"] == "s3://my-bucket/images/photo.png" + + +class TestVideoSource: + """Tests for VideoSource TypedDict.""" + + def test_video_source_with_bytes(self): + """Test VideoSource with bytes content.""" + vid_source: VideoSource = {"bytes": b"video content"} + + assert vid_source["bytes"] == b"video content" + assert "s3Location" not in vid_source + + def test_video_source_with_s3_location(self): + """Test VideoSource with s3Location.""" + vid_source: VideoSource = { + "s3Location": { + "uri": "s3://my-bucket/videos/clip.mp4", + "bucketOwner": "987654321098", + } + } + + assert "bytes" not in vid_source + assert vid_source["s3Location"]["uri"] == "s3://my-bucket/videos/clip.mp4" + assert vid_source["s3Location"]["bucketOwner"] == "987654321098" diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index 9de00089b..dbe25d685 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -133,14 +133,21 @@ def pytest_sessionstart(session): @pytest.fixture def yellow_img(pytestconfig): - path = pytestconfig.rootdir / "tests_integ/yellow.png" + path = pytestconfig.rootdir / "tests_integ/resources/yellow.png" with open(path, "rb") as fp: return fp.read() @pytest.fixture def letter_pdf(pytestconfig): - path = pytestconfig.rootdir / "tests_integ/letter.pdf" + path = pytestconfig.rootdir / "tests_integ/resources/letter.pdf" + with open(path, "rb") as fp: + return fp.read() + + +@pytest.fixture +def blue_video(pytestconfig): + path = pytestconfig.rootdir / "tests_integ/resources/blue.mp4" with open(path, "rb") as fp: return fp.read() diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index 8fa1fb2b2..363c588ee 100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -90,7 +90,7 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): ] elif location.lower() == "tokyo": # Read yellow.png file for weather icon - with open("tests_integ/yellow.png", "rb") as image_file: + with open("tests_integ/resources/yellow.png", "rb") as image_file: png_data = image_file.read() return [ EmbeddedResource( diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 298272df5..4e192c935 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -43,7 +43,7 @@ def calculator(x: int, y: int) -> int: @mcp.tool(description="Generates a 
custom image") def generate_custom_image() -> MCPImageContent: try: - with open("tests_integ/yellow.png", "rb") as image_file: + with open("tests_integ/resources/yellow.png", "rb") as image_file: encoded_image = base64.b64encode(image_file.read()) return MCPImageContent(type="image", data=encoded_image, mimeType="image/png") except Exception as e: diff --git a/tests_integ/resources/blue.mp4 b/tests_integ/resources/blue.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5989bb4b02d85ad96d9985acdfafd0125096acb5 GIT binary patch literal 5200 zcmeHL|7#pY6rXFfNo!luCao2MjG7-5lFRO1a;?$DTtaGuh*h*j5Vmu>b9ZZYZ#TP> zJFIi~ccr>QIVoEM2+u^@D5wuCHU~UYGeUP>chg{=D#_Mc%#&GPd2m2oiui`W(Rx z2(2IHg^9ry_3v8OZ~B5C>LE!1-5k-Dj3U|tETA2GxE`JL3D*I)M`wTBA>YRUoCSK2 zu^?x`c@Ta78+yQYII;E3-sbczx|FCl@3$g@i>*UW7NwiibC_5 zv8*)4z%Y{rhmoiEPd_<4N^=LMz|-J57^WPzYVm@giX>%*6-gNbWl0Ekd}L&4X(^3& z498;SwBr>=aFldO*cSLWt}valKTdU)XSym=xJRfNYVf?}=yR$(E{#i+m6=ubxhhpM z<5ESIGt}m4iC3tj9#-xbV;Nk}&^>tq6`hrkLB@EMJxTGHUOVHiZwHwn#yQizV zSD-fBpEynn1XanTB|49jQKfViSQmi<$|`F1QBe4TyXq)4T}Tpa2*@E|v3bZpW|JI+ z9h~DQkCDfkjZ1H@_g~omyqNv`;l}az^=H2Af8F}CZ_mKHjqW33hsY3H{_Bx1W?$R) zU22Gs9c$m5kXBZ|eEizAgU}0MCRF2nY~o0m6zw ztc4I?P48@jxZD=Sl{Sd035YO?`nGn6`fxmo`bZq&^k@PijH3Qr0%ATMMcr?Ms3ahw zD3)6gM`2={Q}qwpqBs{q;L7ynPJf($h@$wu1raV_{hziduD3z_lz<4MsNLU!2&1T} zafsRzafp?{1Vk7`ZL$RsM)AM1^8yOd9e-Xp z&BmEyg#3sfHzf78>&TAW=~f*nHXF5G(p3x*ZhKn*LaU4%Y&P~zkgfQK5yWtNRpdXN CcM9hK literal 0 HcmV?d00001 diff --git a/tests_integ/letter.pdf b/tests_integ/resources/letter.pdf similarity index 100% rename from tests_integ/letter.pdf rename to tests_integ/resources/letter.pdf diff --git a/tests_integ/yellow.png b/tests_integ/resources/yellow.png similarity index 100% rename from tests_integ/yellow.png rename to tests_integ/resources/yellow.png diff --git a/tests_integ/test_a2a_executor.py b/tests_integ/test_a2a_executor.py index ddca0bfa6..43a6026bf 100644 --- a/tests_integ/test_a2a_executor.py +++ b/tests_integ/test_a2a_executor.py @@ -17,7 +17,7 @@ async def test_a2a_executor_with_real_image(): """Test A2A server processes a real image file correctly via HTTP.""" # Read the test image file - test_image_path = os.path.join(os.path.dirname(__file__), "yellow.png") + test_image_path = os.path.join(os.path.dirname(__file__), "resources/yellow.png") with open(test_image_path, "rb") as f: original_image_bytes = f.read() @@ -80,7 +80,7 @@ async def test_a2a_executor_with_real_image(): def test_a2a_executor_image_roundtrip(): """Test that image data survives the A2A base64 encoding/decoding roundtrip.""" # Read the test image - test_image_path = os.path.join(os.path.dirname(__file__), "yellow.png") + test_image_path = os.path.join(os.path.dirname(__file__), "resources/yellow.png") with open(test_image_path, "rb") as f: original_bytes = f.read() diff --git a/tests_integ/test_bedrock_s3_location.py b/tests_integ/test_bedrock_s3_location.py new file mode 100644 index 000000000..9b28e88be --- /dev/null +++ b/tests_integ/test_bedrock_s3_location.py @@ -0,0 +1,177 @@ +"""Integration tests for S3 location support in media content types.""" + +import time + +import boto3 +import pytest + +from strands import Agent +from strands.models.bedrock import BedrockModel + + +@pytest.fixture +def boto_session(): + """Create a boto3 session for testing.""" + return boto3.Session(region_name="us-west-2") + + +@pytest.fixture +def account_id(boto_session): + """Get the current AWS account ID.""" + sts_client = boto_session.client("sts") + return sts_client.get_caller_identity()["Account"] + + 
+@pytest.fixture
+def s3_client(boto_session):
+    """Create an S3 client."""
+    return boto_session.client("s3")
+
+
+@pytest.fixture
+def test_bucket(s3_client, account_id):
+    """Create a test S3 bucket for the tests.
+
+    Creates a bucket with an account-specific name; the bucket is kept and reused across test runs.
+    """
+    bucket_name = f"strands-integ-tests-resources-{account_id}"
+
+    # Create the bucket if it doesn't exist
+    try:
+        s3_client.head_bucket(Bucket=bucket_name)
+        print(f"Bucket {bucket_name} already exists")
+    except s3_client.exceptions.ClientError:
+        try:
+            s3_client.create_bucket(
+                Bucket=bucket_name,
+                CreateBucketConfiguration={"LocationConstraint": "us-west-2"},
+            )
+            print(f"Created test bucket: {bucket_name}")
+            # Wait for bucket to be available
+            time.sleep(2)
+        except s3_client.exceptions.BucketAlreadyOwnedByYou:
+            print(f"Bucket {bucket_name} already exists")
+
+    yield bucket_name
+
+    # Note: We don't delete the bucket to allow reuse across test runs
+    # Objects will be overwritten on subsequent runs
+
+
+@pytest.fixture
+def s3_document(s3_client, test_bucket, letter_pdf):
+    """Upload a test document to S3 and return its URI."""
+    document_key = "test-documents/letter.pdf"
+
+    # Upload the document using existing letter_pdf fixture
+    s3_client.put_object(
+        Bucket=test_bucket,
+        Key=document_key,
+        Body=letter_pdf,
+        ContentType="application/pdf",
+    )
+    print(f"Uploaded test document to s3://{test_bucket}/{document_key}")
+
+    return f"s3://{test_bucket}/{document_key}"
+
+
+@pytest.fixture
+def s3_image(s3_client, test_bucket, yellow_img):
+    """Upload a test image to S3 and return its URI."""
+    image_key = "test-images/yellow.png"
+
+    # Upload the image using existing yellow_img fixture
+    s3_client.put_object(
+        Bucket=test_bucket,
+        Key=image_key,
+        Body=yellow_img,
+        ContentType="image/png",
+    )
+    print(f"Uploaded test image to s3://{test_bucket}/{image_key}")
+
+    return f"s3://{test_bucket}/{image_key}"
+
+
+@pytest.fixture
+def s3_video(s3_client, test_bucket, blue_video):
+    """Upload a test video to S3 and return its URI."""
+    video_key = "test-videos/blue.mp4"
+
+    # Upload the video using existing blue_video fixture
+    s3_client.put_object(
+        Bucket=test_bucket,
+        Key=video_key,
+        Body=blue_video,
+        ContentType="video/mp4",
+    )
+    print(f"Uploaded test video to s3://{test_bucket}/{video_key}")
+
+    return f"s3://{test_bucket}/{video_key}"
+
+
+def test_document_s3_location(s3_document, account_id):
+    """Test that Bedrock correctly formats a document with S3 location."""
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"text": "Please tell me about this document?"},
+                {
+                    "document": {
+                        "format": "pdf",
+                        "name": "letter",
+                        "source": {"location": {"type": "s3", "uri": s3_document, "bucketOwner": account_id}},
+                    },
+                },
+            ],
+        },
+    ]
+
+    agent = Agent(model=BedrockModel(model_id="us.amazon.nova-2-lite-v1:0", region_name="us-west-2"))
+    result = agent(messages)
+
+    # The actual recognition capabilities of these models are not great, so just asserting that the call actually worked. 
+    assert len(str(result)) > 0
+
+
+def test_image_s3_location(s3_image):
+    """Test that Bedrock correctly formats an image with S3 location."""
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"text": "Please tell me about this image?"},
+                {
+                    "image": {
+                        "format": "png",
+                        "source": {"location": {"type": "s3", "uri": s3_image}},
+                    },
+                },
+            ],
+        },
+    ]
+
+    agent = Agent(model=BedrockModel(model_id="us.amazon.nova-2-lite-v1:0", region_name="us-west-2"))
+    result = agent(messages)
+
+    # The actual recognition capabilities of these models are not great, so just asserting that the call actually worked.
+    assert len(str(result)) > 0
+
+
+def test_video_s3_location(s3_video):
+    """Test that Bedrock correctly formats a video with S3 location."""
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"text": "Describe the colors in this video?"},
+                {"video": {"format": "mp4", "source": {"location": {"type": "s3", "uri": s3_video}}}},
+            ],
+        },
+    ]
+
+    agent = Agent(model=BedrockModel(model_id="us.amazon.nova-pro-v1:0", region_name="us-west-2"))
+    result = agent(messages)
+
+    # The actual recognition capabilities of these models are not great, so just asserting that the call actually worked.
+    assert len(str(result)) > 0

From 522b7d29efee937bff48a9262e6fc6c38faee8f4 Mon Sep 17 00:00:00 2001
From: Clare Liguori
Date: Fri, 30 Jan 2026 06:02:42 -0800
Subject: [PATCH 07/16] fix: LedgerProvider handles parallel tool calls (#1559)

---
 .../context_providers/ledger_provider.py      |  26 ++-
 .../steering/handlers/llm/mappers.py          |  16 +-
 .../context_providers/test_ledger_provider.py | 191 +++++++++++++++++-
 3 files changed, 221 insertions(+), 12 deletions(-)

diff --git a/src/strands/experimental/steering/context_providers/ledger_provider.py b/src/strands/experimental/steering/context_providers/ledger_provider.py
index 0e7bde529..43f56717a 100644
--- a/src/strands/experimental/steering/context_providers/ledger_provider.py
+++ b/src/strands/experimental/steering/context_providers/ledger_provider.py
@@ -46,6 +46,7 @@ def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext
 
         tool_call_entry = {
             "timestamp": datetime.now().isoformat(),
+            "tool_use_id": event.tool_use.get("toolUseId"),
             "tool_name": event.tool_use.get("name"),
             "tool_args": event.tool_use.get("input", {}),
             "status": "pending",
@@ -62,16 +63,21 @@ def __call__(self, event: AfterToolCallEvent, steering_context: SteeringContext,
         ledger = steering_context.data.get("ledger") or {}
 
         if ledger.get("tool_calls"):
-            last_call = ledger["tool_calls"][-1]
-            last_call.update(
-                {
-                    "completion_timestamp": datetime.now().isoformat(),
-                    "status": event.result["status"],
-                    "result": event.result["content"],
-                    "error": str(event.exception) if event.exception else None,
-                }
-            )
-            steering_context.data.set("ledger", ledger)
+            tool_use_id = event.tool_use.get("toolUseId")
+
+            # Search for the matching tool call in the ledger to update it
+            for call in reversed(ledger["tool_calls"]):
+                if call.get("tool_use_id") == tool_use_id and call.get("status") == "pending":
+                    call.update(
+                        {
+                            "completion_timestamp": datetime.now().isoformat(),
+                            "status": event.result["status"],
+                            "result": event.result["content"],
+                            "error": str(event.exception) if event.exception else None,
+                        }
+                    )
+                    steering_context.data.set("ledger", ledger)
+                    break
 
 
 class LedgerProvider(SteeringContextProvider):
diff --git a/src/strands/experimental/steering/handlers/llm/mappers.py b/src/strands/experimental/steering/handlers/llm/mappers.py
index 9901da7d4..ade018d32 100644
--- 
a/src/strands/experimental/steering/handlers/llm/mappers.py +++ b/src/strands/experimental/steering/handlers/llm/mappers.py @@ -23,7 +23,7 @@ **CRITICAL CONSTRAINTS:** - Base decisions ONLY on the context data provided below -- Do NOT use external knowledge about domains, URLs, or tool purposes +- Do NOT use external knowledge about domains, URLs, or tool purposes - Do NOT make assumptions about what tools "should" or "shouldn't" do - Focus ONLY on patterns in the context data @@ -31,6 +31,20 @@ {context_str} +### Understanding Ledger Tool States + +If the context includes a ledger with tool_calls, the "status" field indicates: + +- **"pending"**: The tool is CURRENTLY being evaluated by you (the steering agent). +This is NOT a duplicate call - it's the tool you're deciding whether to approve. +The tool has NOT started executing yet. +- **"success"**: The tool completed successfully in a previous turn +- **"error"**: The tool failed or was cancelled in a previous turn + +**IMPORTANT**: When you see a tool with status="pending" that matches the tool you're evaluating, +that IS the current tool being evaluated. +It is NOT already executing or a duplicate. + ## Event to Evaluate {event_description} diff --git a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py b/tests/strands/experimental/steering/context_providers/test_ledger_provider.py index 1d280f7c1..c3cde475b 100644 --- a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py +++ b/tests/strands/experimental/steering/context_providers/test_ledger_provider.py @@ -87,11 +87,19 @@ def test_ledger_after_tool_call_success(mock_datetime): # Set up existing ledger with pending call existing_ledger = { - "tool_calls": [{"tool_name": "test_tool", "status": "pending", "timestamp": "2024-01-01T12:00:00"}] + "tool_calls": [ + { + "tool_use_id": "test-id", + "tool_name": "test_tool", + "status": "pending", + "timestamp": "2024-01-01T12:00:00", + } + ] } steering_context.data.set("ledger", existing_ledger) event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "test-id"} event.result = {"status": "success", "content": ["success_result"]} event.exception = None @@ -133,3 +141,184 @@ def test_session_start_persistence(): callback = LedgerBeforeToolCall() assert callback.session_start == "2024-01-01T10:00:00" + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_all_pending(mock_datetime): + """Test multiple tool calls added as pending before any execute.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + # Add three tool calls in sequence (simulating parallel proposal) + for i, tool_name in enumerate(["tool_a", "tool_b", "tool_c"]): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": tool_name, "input": {}} + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert len(ledger["tool_calls"]) == 3 + assert all(call["status"] == "pending" for call in ledger["tool_calls"]) + assert [call["tool_name"] for call in ledger["tool_calls"]] == ["tool_a", "tool_b", "tool_c"] + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_complete_by_id(mock_datetime): + """Test tool calls complete in any order by matching toolUseId.""" + # Need timestamps for: session_start + 3 tool calls + 1 
completion + mock_datetime.now.return_value.isoformat.side_effect = [ + "2024-01-01T11:00:00", # session_start + "2024-01-01T12:00:00", # tool_a + "2024-01-01T12:01:00", # tool_b + "2024-01-01T12:02:00", # tool_c + "2024-01-01T12:03:00", # completion + ] + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add three pending tool calls + for i, tool_name in enumerate(["tool_a", "tool_b", "tool_c"]): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": tool_name, "input": {}} + before_callback(event, steering_context) + + # Complete middle tool first (out of order) + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_1"} + event.result = {"status": "success", "content": ["result_b"]} + event.exception = None + after_callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["status"] == "pending" + assert ledger["tool_calls"][1]["status"] == "success" + assert ledger["tool_calls"][1]["result"] == ["result_b"] + assert ledger["tool_calls"][2]["status"] == "pending" + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_complete_all_out_of_order(mock_datetime): + """Test all parallel tool calls complete in reverse order.""" + # Need timestamps for: session_start + 3 tool calls + 3 completions + mock_datetime.now.return_value.isoformat.side_effect = [ + "2024-01-01T11:00:00", # session_start + "2024-01-01T12:00:00", # tool_0 + "2024-01-01T12:01:00", # tool_1 + "2024-01-01T12:02:00", # tool_2 + "2024-01-01T12:03:00", # completion tool_2 + "2024-01-01T12:04:00", # completion tool_1 + "2024-01-01T12:05:00", # completion tool_0 + ] + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add three pending tool calls + for i in range(3): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": f"tool_{i}", "input": {}} + before_callback(event, steering_context) + + # Complete in reverse order: 2, 1, 0 + for i in [2, 1, 0]: + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}"} + event.result = {"status": "success", "content": [f"result_{i}"]} + event.exception = None + after_callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert all(call["status"] == "success" for call in ledger["tool_calls"]) + assert ledger["tool_calls"][0]["result"] == ["result_0"] + assert ledger["tool_calls"][1]["result"] == ["result_1"] + assert ledger["tool_calls"][2]["result"] == ["result_2"] + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_with_failure(mock_datetime): + """Test parallel tool calls where one fails.""" + # Need timestamps for: session_start + 2 tool calls + 2 completions + mock_datetime.now.return_value.isoformat.side_effect = [ + "2024-01-01T11:00:00", # session_start + "2024-01-01T12:00:00", # tool_0 + "2024-01-01T12:01:00", # tool_1 + "2024-01-01T12:02:00", # completion tool_0 + "2024-01-01T12:03:00", # completion tool_1 + ] + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add two pending tool calls + for i in range(2): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": f"tool_{i}", 
"input": {}} + before_callback(event, steering_context) + + # First succeeds + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_0"} + event.result = {"status": "success", "content": ["result_0"]} + event.exception = None + after_callback(event, steering_context) + + # Second fails + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_1"} + event.result = {"status": "error", "content": []} + event.exception = ValueError("test error") + after_callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["status"] == "success" + assert ledger["tool_calls"][0]["error"] is None + assert ledger["tool_calls"][1]["status"] == "error" + assert ledger["tool_calls"][1]["error"] == "test error" + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_after_tool_call_no_matching_id(mock_datetime): + """Test AfterToolCallEvent when tool_use_id doesn't match any pending call.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add a pending tool call + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": "id_1", "name": "tool_1", "input": {}} + before_callback(event, steering_context) + + # Try to complete a different tool_use_id that doesn't exist + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_999"} + event.result = {"status": "success", "content": ["result"]} + event.exception = None + after_callback(event, steering_context) + + # Original tool should still be pending (no match found) + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["status"] == "pending" + assert "completion_timestamp" not in ledger["tool_calls"][0] + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_tool_use_id_stored_in_ledger(mock_datetime): + """Test that toolUseId is stored in ledger entries.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": "test-id-123", "name": "test_tool", "input": {}} + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["tool_use_id"] == "test-id-123" From 353c15c2f0c324408b629662b03d1c891e096d2f Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 30 Jan 2026 10:22:46 -0500 Subject: [PATCH 08/16] Clone main metrics upload script for integ tests (#1600) --- .github/workflows/integration-test.yml | 38 ++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index bbcdfde25..00fda1262 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -59,8 +59,42 @@ jobs: run: | hatch test tests_integ - - name: Publish test metrics to CloudWatch + - name: Upload test results if: always() + uses: actions/upload-artifact@v4 + with: + name: test-results + path: ./build/test-results.xml + + upload-metrics: + runs-on: ubuntu-latest + needs: check-access-and-checkout + if: always() + permissions: + id-token: write + contents: read + steps: + - name: Configure Credentials + uses: 
aws-actions/configure-aws-credentials@v5
+        with:
+          role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }}
+          aws-region: us-east-1
+          mask-aws-account-id: true
+
+      - name: Checkout main
+        uses: actions/checkout@v6
+        with:
+          ref: main
+          sparse-checkout: |
+            .github/scripts
+          persist-credentials: false
+
+      - name: Download test results
+        uses: actions/download-artifact@v4
+        with:
+          name: test-results
+
+      - name: Publish test metrics to CloudWatch
         run: |
           pip install --no-cache-dir boto3
-          python .github/scripts/upload-integ-test-metrics.py ./build/test-results.xml ${{ github.event.repository.name }}
+          python .github/scripts/upload-integ-test-metrics.py test-results.xml ${{ github.event.repository.name }}

From 7a4de6b36e04558058099f96d006737c9c43e428 Mon Sep 17 00:00:00 2001
From: Nick Clegg
Date: Fri, 30 Jan 2026 11:40:42 -0500
Subject: [PATCH 09/16] Skip location for non bedrock model providers (#1602)

---
 src/strands/models/_validation.py        | 21 ++++++++
 src/strands/models/anthropic.py          |  7 ++-
 src/strands/models/bedrock.py            |  2 +-
 src/strands/models/gemini.py             | 27 ++++++----
 src/strands/models/llamaapi.py           | 18 ++++---
 src/strands/models/llamacpp.py           | 18 ++++---
 src/strands/models/mistral.py            |  7 ++-
 src/strands/models/ollama.py             | 18 ++++---
 src/strands/models/openai.py             | 18 ++++---
 src/strands/models/writer.py             | 16 ++++--
 tests/strands/models/test__validation.py | 67 +++++++++++++++++++++++
 tests/strands/models/test_anthropic.py   | 67 +++++++++++++++++++++++
 tests/strands/models/test_bedrock.py     |  2 +-
 tests/strands/models/test_gemini.py      | 64 ++++++++++++++++++++++
 tests/strands/models/test_llamaapi.py    | 67 +++++++++++++++++++++++
 tests/strands/models/test_llamacpp.py    | 69 ++++++++++++++++++++++++
 tests/strands/models/test_mistral.py     | 63 ++++++++++++++++++++++
 tests/strands/models/test_ollama.py      | 66 +++++++++++++++++++++++
 tests/strands/models/test_openai.py      | 65 ++++++++++++++++++++++
 tests/strands/models/test_writer.py      | 67 +++++++++++++++++++++++
 20 files changed, 708 insertions(+), 41 deletions(-)
 create mode 100644 tests/strands/models/test__validation.py

diff --git a/src/strands/models/_validation.py b/src/strands/models/_validation.py
index 1e82bca73..9d4d8b178 100644
--- a/src/strands/models/_validation.py
+++ b/src/strands/models/_validation.py
@@ -6,6 +6,7 @@
 
 from typing_extensions import get_type_hints
 
+from ..types.content import ContentBlock
 from ..types.tools import ToolChoice
 
 
@@ -41,3 +42,23 @@ def warn_on_tool_choice_not_supported(tool_choice: ToolChoice | None) -> None:
         "A ToolChoice was provided to this provider but is not supported and will be ignored",
         stacklevel=4,
     )
+
+
+def _has_location_source(content: ContentBlock) -> bool:
+    """Check if a content block contains a location source.
+
+    Providers need to explicitly define an implementation to support content locations.
+
+    Args:
+        content: Content block to check.
+
+    Returns:
+        True if the content block contains a location source, False otherwise. 
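+
+    Example (illustrative; the block shapes mirror this function's unit tests, and the URI and bytes are placeholders)::
+
+        _has_location_source({"image": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}})  # True
+        _has_location_source({"image": {"source": {"bytes": b"raw-bytes"}}})  # False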
+ """ + if "image" in content: + return "location" in content["image"].get("source", {}) + if "document" in content: + return "location" in content["document"].get("source", {}) + if "video" in content: + return "location" in content["video"].get("source", {}) + return False diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 535c820ee..b5f6fcf91 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -20,7 +20,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec -from ._validation import validate_config_keys +from ._validation import _has_location_source, validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -189,6 +189,11 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: formatted_contents[-1]["cache_control"] = {"type": "ephemeral"} continue + # Check for location sources in image, document, or video content + if _has_location_source(content): + logger.warning("Location sources are not supported by Anthropic | skipping content block") + continue + formatted_contents.append(self._format_request_message_content(content)) if formatted_contents: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index b053b70fb..596936e6f 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -472,7 +472,7 @@ def _handle_location(self, location: SourceLocation) -> dict[str, Any] | None: formatted_document_s3["bucketOwner"] = s3_location["bucketOwner"] return {"s3Location": formatted_document_s3} else: - logger.warning("Non s3 location sources are not supported by Bedrock, skipping content block") + logger.warning("Non s3 location sources are not supported by Bedrock | skipping content block") return None def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any] | None: diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 192a363d3..6a6535999 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -18,7 +18,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec -from ._validation import validate_config_keys +from ._validation import _has_location_source, validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -229,15 +229,24 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten # available in tool result blocks, hence the mapping. 
tool_use_id_to_name: dict[str, str] = {} - return [ - genai.types.Content( - parts=[ - self._format_request_content_part(content, tool_use_id_to_name) for content in message["content"] - ], - role="user" if message["role"] == "user" else "model", + contents = [] + for message in messages: + parts = [] + for content in message["content"]: + # Check for location sources and skip with warning + if _has_location_source(content): + logger.warning("Location sources are not supported by Gemini | skipping content block") + continue + parts.append(self._format_request_content_part(content, tool_use_id_to_name)) + + contents.append( + genai.types.Content( + parts=parts, + role="user" if message["role"] == "user" else "model", + ) ) - for message in messages - ] + + return contents def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai.types.Tool | Any]: """Format tool specs into Gemini tools. diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index ce0367bf5..b1ed4563a 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -20,7 +20,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -176,12 +176,18 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None for message in messages: contents = message["content"] + # Filter out location sources and unsupported block types + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse"]): + continue + if _has_location_source(content): + logger.warning("Location sources are not supported by LlamaAPI | skipping content block") + continue + filtered_contents.append(content) + formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = "" - formatted_contents = [ - self._format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] + formatted_contents = [self._format_request_message_content(content) for content in filtered_contents] formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) for content in contents diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index ca838f3d7..c52509816 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -30,7 +30,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -299,11 +299,17 @@ def _format_messages(self, messages: Messages, system_prompt: str | None = None) for message in messages: contents = message["content"] - formatted_contents = [ - self._format_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] + # Filter out location sources and unsupported 
block types + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse"]): + continue + if _has_location_source(content): + logger.warning("Location sources are not supported by llama.cpp | skipping content block") + continue + filtered_contents.append(content) + + formatted_contents = [self._format_message_content(content) for content in filtered_contents] formatted_tool_calls = [ self._format_tool_call( { diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 4ec77ccfe..504e81c92 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -17,7 +17,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -212,6 +212,11 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None tool_messages: list[dict[str, Any]] = [] for content in contents: + # Check for location sources and skip with warning + if _has_location_source(content): + logger.warning("Location sources are not supported by Mistral | skipping content block") + continue + if "text" in content: formatted_content = self._format_request_message_content(content) if isinstance(formatted_content, str): diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 8d72aa534..68aba59d4 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -15,7 +15,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolSpec -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -160,12 +160,16 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None """ system_message = [{"role": "system", "content": system_prompt}] if system_prompt else [] - return system_message + [ - formatted_message - for message in messages - for content in message["content"] - for formatted_message in self._format_request_message_contents(message["role"], content) - ] + formatted_messages = [] + for message in messages: + for content in message["content"]: + # Check for location sources and skip with warning + if _has_location_source(content): + logger.warning("Location sources are not supported by Ollama | skipping content block") + continue + formatted_messages.extend(self._format_request_message_contents(message["role"], content)) + + return system_message + formatted_messages def format_request( self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index d9266212b..51e98c8c2 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -20,7 +20,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import 
validate_config_keys +from ._validation import _has_location_source, validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -338,11 +338,17 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic "reasoningContent is not supported in multi-turn conversations with the Chat Completions API." ) - formatted_contents = [ - cls.format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]) - ] + # Filter out content blocks that shouldn't be formatted + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]): + continue + if _has_location_source(content): + logger.warning("Location sources are not supported by OpenAI | skipping content block") + continue + filtered_contents.append(content) + + formatted_contents = [cls.format_request_message_content(content) for content in filtered_contents] formatted_tool_calls = [ cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content ] diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index f306d649b..94774b363 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -18,7 +18,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -218,11 +218,21 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None for message in messages: contents = message["content"] + # Filter out location sources + filtered_contents = [] + for content in contents: + if _has_location_source(content): + logger.warning("Location sources are not supported by Writer | skipping content block") + continue + filtered_contents.append(content) + # Only palmyra V5 support multiple content. 
Other models support only '{"content": "text_content"}' if self.get_config().get("model_id", "") == "palmyra-x5": - formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision(contents) + formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision( + filtered_contents + ) else: - formatted_contents = self._format_request_message_contents(contents) + formatted_contents = self._format_request_message_contents(filtered_contents) formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) diff --git a/tests/strands/models/test__validation.py b/tests/strands/models/test__validation.py new file mode 100644 index 000000000..e8a451494 --- /dev/null +++ b/tests/strands/models/test__validation.py @@ -0,0 +1,67 @@ +"""Tests for model validation helper functions.""" + +from strands.models._validation import _has_location_source + + +class TestHasLocationSource: + """Tests for _has_location_source helper function.""" + + def test_image_with_location_source(self): + """Test detection of location source in image content.""" + content = {"image": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}} + assert _has_location_source(content) + + def test_image_with_bytes_source(self): + """Test that bytes source is not detected as location.""" + content = {"image": {"source": {"bytes": b"data"}}} + assert not _has_location_source(content) + + def test_document_with_location_source(self): + """Test detection of location source in document content.""" + content = {"document": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}} + assert _has_location_source(content) + + def test_document_with_bytes_source(self): + """Test that bytes source is not detected as location.""" + content = {"document": {"source": {"bytes": b"data"}}} + assert not _has_location_source(content) + + def test_video_with_location_source(self): + """Test detection of location source in video content.""" + content = {"video": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}} + assert _has_location_source(content) + + def test_video_with_bytes_source(self): + """Test that bytes source is not detected as location.""" + content = {"video": {"source": {"bytes": b"data"}}} + assert not _has_location_source(content) + + def test_text_content(self): + """Test that text content is not detected as location source.""" + content = {"text": "hello"} + assert not _has_location_source(content) + + def test_tool_use_content(self): + """Test that toolUse content is not detected as location source.""" + content = {"toolUse": {"name": "test", "input": {}, "toolUseId": "123"}} + assert not _has_location_source(content) + + def test_tool_result_content(self): + """Test that toolResult content is not detected as location source.""" + content = {"toolResult": {"toolUseId": "123", "content": [{"text": "result"}]}} + assert not _has_location_source(content) + + def test_image_without_source(self): + """Test that image without source is not detected as location.""" + content = {"image": {"format": "png"}} + assert not _has_location_source(content) + + def test_document_without_source(self): + """Test that document without source is not detected as location.""" + content = {"document": {"format": "pdf", "name": "test.pdf"}} + assert not _has_location_source(content) + + def test_video_without_source(self): + """Test that video without source is not detected as location.""" + content = {"video": {"format": "mp4"}} + assert not 
_has_location_source(content) diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 74bbb8d45..c5aff8062 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1,3 +1,4 @@ +import logging import unittest.mock import anthropic @@ -866,3 +867,69 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): model.format_request(messages, tool_choice=None) assert len(captured_warnings) == 0 + + +def test_format_request_filters_s3_source_image(model, model_id, max_tokens, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + exp_messages = [ + {"role": "user", "content": [{"type": "text", "text": "look at this image"}]}, + ] + assert tru_request["messages"] == exp_messages + assert "Location sources are not supported by Anthropic" in caplog.text + + +def test_format_request_filters_location_source_document(model, model_id, max_tokens, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + exp_messages = [ + {"role": "user", "content": [{"type": "text", "text": "analyze this document"}]}, + ] + assert tru_request["messages"] == exp_messages + assert "Location sources are not supported by Anthropic" in caplog.text diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 761434258..aac791214 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1924,7 +1924,7 @@ def test_format_request_unsupported_location(model, caplog): formatted_request = model._format_request(messages) assert len(formatted_request["messages"][0]["content"]) == 1 - assert "Non s3 location sources are not supported by Bedrock, skipping content block" in caplog.text + assert "Non s3 location sources are not supported by Bedrock | skipping content block" in caplog.text def test_format_request_video_s3_location(model, model_id): diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 86ab2fea5..d62c5a7c8 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -934,3 +934,67 @@ def test_init_with_both_client_and_client_args_raises_error(): with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model") + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with 
Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.gemini") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages, None, None, None) + + # Image with S3 source should be filtered, text should remain + formatted_content = request["contents"][0]["parts"] + assert len(formatted_content) == 1 + assert "text" in formatted_content[0] + assert "Location sources are not supported by Gemini" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.gemini") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages, None, None, None) + + # Document with S3 source should be filtered, text should remain + formatted_content = request["contents"][0]["parts"] + assert len(formatted_content) == 1 + assert "text" in formatted_content[0] + assert "Location sources are not supported by Gemini" in caplog.text diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index a6bbf5673..2bf12d055 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates +import logging import unittest.mock import pytest @@ -414,3 +415,69 @@ async def test_tool_choice_none_no_warning(model, messages, captured_warnings, a await alist(response) assert len(captured_warnings) == 0 + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.llamaapi") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by LlamaAPI" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.llamaapi") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by LlamaAPI" in caplog.text diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index e5b2614c0..fa784de5c 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -2,6 +2,7 @@ import base64 import json +import logging from unittest.mock import AsyncMock, patch import httpx @@ -637,3 +638,71 @@ def test_format_messages_with_mixed_content() -> None: assert result[0]["content"][2]["type"] == "image_url" assert "image_url" in result[0]["content"][2] assert result[0]["content"][2]["image_url"]["url"].startswith("data:image/jpeg;base64,") + + +def test_format_request_filters_s3_source_image(caplog) -> None: + """Test that images with Location sources are filtered out with warning.""" + model = LlamaCppModel() + caplog.set_level(logging.WARNING, logger="strands.models.llamacpp") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by llama.cpp" in caplog.text + + +def test_format_request_filters_location_source_document(caplog) -> None: + """Test that documents with Location sources are 
filtered out with warning.""" + model = LlamaCppModel() + caplog.set_level(logging.WARNING, logger="strands.models.llamacpp") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by llama.cpp" in caplog.text diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 7808336f2..ad74bae89 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -1,3 +1,4 @@ +import logging import unittest.mock import pydantic @@ -592,3 +593,65 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.mistral") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + formatted_messages = model._format_request_messages(messages) + + # Image with S3 source should be filtered, text should remain + user_content = formatted_messages[0]["content"] + assert user_content == "look at this image" + assert "Location sources are not supported by Mistral" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.mistral") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + formatted_messages = model._format_request_messages(messages) + + # Document with S3 source should be filtered, text should remain + user_content = formatted_messages[0]["content"] + assert user_content == "analyze this document" + assert "Location sources are not supported by Mistral" in caplog.text diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index 14db63a24..d17894028 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -1,4 +1,5 @@ import json +import logging import unittest.mock import pydantic @@ -559,3 +560,68 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert 
"Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.ollama") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_message = formatted_messages[0] + assert user_message["content"] == "look at this image" + assert "images" not in user_message or user_message.get("images") == [] + assert "Location sources are not supported by Ollama" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.ollama") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_message = formatted_messages[0] + assert user_message["content"] == "analyze this document" + assert "Location sources are not supported by Ollama" in caplog.text diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 7c1d18998..6eeb477d9 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -1,3 +1,4 @@ +import logging import unittest.mock import openai @@ -1246,3 +1247,67 @@ def test_init_with_both_client_and_client_args_raises_error(): with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): OpenAIModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model") + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.openai") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_content = request["messages"][0]["content"] + assert len(formatted_content) == 1 + assert formatted_content[0]["type"] == "text" + assert "Location sources are not supported by OpenAI" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.openai") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { 
+ "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_content = request["messages"][0]["content"] + assert len(formatted_content) == 1 + assert formatted_content[0]["type"] == "text" + assert "Location sources are not supported by OpenAI" in caplog.text diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py index 963904002..81745f412 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -1,3 +1,4 @@ +import logging import unittest.mock from typing import Any @@ -435,3 +436,69 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.writer") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by Writer" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.writer") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by Writer" in caplog.text From a82b9059b42a33f4888017beb93aeaf5cce4b74b Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 2 Feb 2026 12:41:18 -0500 Subject: [PATCH 10/16] Add conditional execution for finalize step (#1605) --- .github/workflows/strands-command.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/strands-command.yml b/.github/workflows/strands-command.yml index 6c3328192..6cd43c5c0 100644 --- a/.github/workflows/strands-command.yml +++ b/.github/workflows/strands-command.yml @@ -79,6 +79,7 @@ jobs: write_permission: 'false' finalize: + if: always() needs: 
[setup-and-process, execute-readonly-agent] permissions: contents: write From 0669bf25e14e2b28148d296d2a96face42877f14 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 2 Feb 2026 16:51:18 -0500 Subject: [PATCH 11/16] interrupts - graph - multiagent nodes (#1606) --- src/strands/multiagent/graph.py | 70 +++++++------ tests/strands/multiagent/test_graph.py | 97 ++++++++++++++++++- .../{test_agent.py => test_node.py} | 12 ++- .../interrupts/multiagent/test_session.py | 28 ++++-- 4 files changed, 166 insertions(+), 41 deletions(-) rename tests_integ/interrupts/multiagent/{test_agent.py => test_node.py} (93%) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index d296753c0..6b135d1a7 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -603,17 +603,20 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent: + def _activate_interrupt( + self, node: GraphNode, interrupts: list[Interrupt], from_hook: bool = False + ) -> MultiAgentNodeInterruptEvent: """Activate the interrupt state. Args: node: The interrupted node. interrupts: The interrupts raised by the user. + from_hook: Whether the interrupt originated from a hook (e.g., BeforeNodeCallEvent). Returns: MultiAgentNodeInterruptEvent """ - logger.debug("node=<%s> | node interrupted", node.node_id) + logger.debug("node=<%s>, from_hook=<%s> | node interrupted", node.node_id, from_hook) node.execution_status = Status.INTERRUPTED @@ -622,13 +625,20 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) self._interrupt_state.activate() + + self._interrupt_state.context[node.node_id] = { + "from_hook": from_hook, + "interrupt_ids": [interrupt.id for interrupt in interrupts], + } + if isinstance(node.executor, Agent): - self._interrupt_state.context[node.node_id] = { - "activated": node.executor._interrupt_state.activated, - "interrupt_state": node.executor._interrupt_state.to_dict(), - "state": node.executor.state.get(), - "messages": node.executor.messages, - } + self._interrupt_state.context[node.node_id].update( + { + "interrupt_state": node.executor._interrupt_state.to_dict(), + "state": node.executor.state.get(), + "messages": node.executor.messages, + } + ) return MultiAgentNodeInterruptEvent(node.node_id, interrupts) @@ -866,7 +876,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) start_time = time.time() try: if interrupts: - yield self._activate_interrupt(node, interrupts) + yield self._activate_interrupt(node, interrupts, from_hook=True) return if before_event.cancel_node: @@ -896,20 +906,14 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if multi_agent_result is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") - if multi_agent_result.status == Status.INTERRUPTED: - raise NotImplementedError( - f"node_id=<{node.node_id}>, " - "issue= " - "| user raised interrupt from a multi agent node" - ) - node_result = NodeResult( result=multi_agent_result, execution_time=multi_agent_result.execution_time, - status=Status.COMPLETED, + status=multi_agent_result.status, accumulated_usage=multi_agent_result.accumulated_usage, 
accumulated_metrics=multi_agent_result.accumulated_metrics, execution_count=multi_agent_result.execution_count, + interrupts=multi_agent_result.interrupts, ) elif isinstance(node.executor, Agent): @@ -1040,18 +1044,26 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: """ if self._interrupt_state.activated: context = self._interrupt_state.context - if node.node_id in context and context[node.node_id]["activated"]: - agent_context = context[node.node_id] - agent = cast(Agent, node.executor) - agent.messages = agent_context["messages"] - agent.state = AgentState(agent_context["state"]) - agent._interrupt_state = _InterruptState.from_dict(agent_context["interrupt_state"]) - - responses = context["responses"] - interrupts = agent._interrupt_state.interrupts - return [ - response for response in responses if response["interruptResponse"]["interruptId"] in interrupts - ] + if node.node_id in context: + node_context = context[node.node_id] + + # Only route responses if the interrupt originated from the node's execution + if not node_context["from_hook"]: + # Filter responses to only those for this node's interrupts + node_responses = [ + response + for response in context["responses"] + if response["interruptResponse"]["interruptId"] in node_context["interrupt_ids"] + ] + + if isinstance(node.executor, MultiAgentBase): + return node_responses + + agent = node.executor + agent.messages = node_context["messages"] + agent.state = AgentState(node_context["state"]) + agent._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"]) + return node_responses # Get satisfied dependencies dependency_results = {} diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index c511328d4..0fbb102a4 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2228,7 +2228,8 @@ def test_graph_interrupt_on_agent(agenerator): ], ) graph._interrupt_state.context["test_agent"] = { - "activated": True, + "from_hook": False, + "interrupt_ids": [interrupt.id], "interrupt_state": { "activated": True, "context": {}, @@ -2259,3 +2260,97 @@ def test_graph_interrupt_on_agent(agenerator): assert len(multiagent_result.results) == 1 agent.stream_async.assert_called_once_with(responses, invocation_state={}) + + +def test_graph_interrupt_on_multiagent(agenerator): + exp_interrupts = [ + Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ) + ] + + multiagent = create_mock_multi_agent("test_multiagent", "Multi-agent completed") + multiagent.stream_async = Mock() + multiagent.stream_async.return_value = agenerator( + [ + { + "result": MultiAgentResult( + results={}, + status=Status.INTERRUPTED, + interrupts=exp_interrupts, + ), + }, + ], + ) + + builder = GraphBuilder() + builder.add_node(multiagent, "test_multiagent") + graph = builder.build() + + multiagent_result = graph("Test task") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_multiagent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + multiagent.stream_async = Mock() + 
multiagent.stream_async.return_value = agenerator( + [ + { + "result": MultiAgentResult( + results={ + "inner_node": NodeResult( + result=AgentResult( + message={"role": "assistant", "content": [{"text": "Inner completed"}]}, + stop_reason="end_turn", + state={}, + metrics={}, + ) + ) + }, + status=Status.COMPLETED, + ), + }, + ], + ) + graph._interrupt_state.context["test_multiagent"] = { + "from_hook": False, + "interrupt_ids": [interrupt.id], + } + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + + multiagent.stream_async.assert_called_once_with(responses, {}) diff --git a/tests_integ/interrupts/multiagent/test_agent.py b/tests_integ/interrupts/multiagent/test_node.py similarity index 93% rename from tests_integ/interrupts/multiagent/test_agent.py rename to tests_integ/interrupts/multiagent/test_node.py index 1a6ad87c6..23e7a62bc 100644 --- a/tests_integ/interrupts/multiagent/test_agent.py +++ b/tests_integ/interrupts/multiagent/test_node.py @@ -65,13 +65,13 @@ def swarm(weather_agent): @pytest.fixture -def graph(info_agent, day_agent, time_agent, weather_agent): +def graph(info_agent, day_agent, time_agent, swarm): builder = GraphBuilder() builder.add_node(info_agent, "info") builder.add_node(day_agent, "day") builder.add_node(time_agent, "time") - builder.add_node(weather_agent, "weather") + builder.add_node(swarm, "weather") builder.add_edge("info", "day") builder.add_edge("info", "time") @@ -82,7 +82,7 @@ def graph(info_agent, day_agent, time_agent, weather_agent): return builder.build() -def test_swarm_interrupt_agent(swarm): +def test_swarm_interrupt_node(swarm): multiagent_result = swarm("What is the weather?") tru_status = multiagent_result.status @@ -122,7 +122,7 @@ def test_swarm_interrupt_agent(swarm): assert "sunny" in weather_message -def test_graph_interrupt_agent(graph): +def test_graph_interrupt_node(graph): multiagent_result = graph("What is the day, time, and weather?") tru_result_status = multiagent_result.status @@ -180,7 +180,9 @@ def test_graph_interrupt_agent(graph): day_message = json.dumps(multiagent_result.results["day"].result.message).lower() time_message = json.dumps(multiagent_result.results["time"].result.message).lower() - weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower() assert "monday" in day_message assert "12:01" in time_message + + nested_multiagent_result = multiagent_result.results["weather"].result + weather_message = json.dumps(nested_multiagent_result.results["weather"].result.message).lower() assert "sunny" in weather_message diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index 96b9844bf..8a5979d63 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -72,15 +72,23 @@ def test_swarm_interrupt_session(weather_tool, tmpdir): def test_graph_interrupt_session(weather_tool, tmpdir): + parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent") + child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child") 
+ weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") - session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + weather_builder = GraphBuilder() + weather_builder.add_node(weather_agent, "weather") + weather_builder.set_entry_point("weather") + weather_builder.set_session_manager(child_sm) + weather_graph = weather_builder.build() builder = GraphBuilder() - builder.add_node(weather_agent, "weather") + builder.add_node(weather_graph, "weather") builder.add_node(summarizer_agent, "summarizer") builder.add_edge("weather", "summarizer") - builder.set_session_manager(session_manager) + builder.set_session_manager(parent_sm) graph = builder.build() multiagent_result = graph("Can you check the weather and then summarize the results?") @@ -105,15 +113,23 @@ def test_graph_interrupt_session(weather_tool, tmpdir): interrupt = multiagent_result.interrupts[0] + parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent") + child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child") + weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") - session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + weather_builder = GraphBuilder() + weather_builder.add_node(weather_agent, "weather") + weather_builder.set_entry_point("weather") + weather_builder.set_session_manager(child_sm) + weather_graph = weather_builder.build() builder = GraphBuilder() - builder.add_node(weather_agent, "weather") + builder.add_node(weather_graph, "weather") builder.add_node(summarizer_agent, "summarizer") builder.add_edge("weather", "summarizer") - builder.set_session_manager(session_manager) + builder.set_session_manager(parent_sm) graph = builder.build() responses = [ From cbff9c62d141e008fc1b10cd4b053d1faaa2627f Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 2 Feb 2026 16:58:30 -0500 Subject: [PATCH 12/16] fix various test warnings (#1613) --- .../strands/agent/hooks/test_agent_events.py | 7 +- tests/strands/agent/test_agent_hooks.py | 3 + tests/strands/event_loop/test_streaming.py | 1 + .../experimental/hooks/test_hook_aliases.py | 16 ++-- .../tools/test_tool_provider_alias.py | 6 +- tests/strands/models/test_bedrock.py | 87 +++++++++---------- tests/strands/models/test_llamacpp.py | 4 +- tests/strands/multiagent/a2a/test_executor.py | 3 + tests/strands/multiagent/test_base.py | 1 + tests/strands/tools/test_loader.py | 8 ++ tests/strands/tools/test_registry.py | 1 + 11 files changed, 81 insertions(+), 56 deletions(-) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index f511c7019..02c367ccc 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -84,7 +84,7 @@ async def test_stream_e2e_success(alist): mock_callback = unittest.mock.Mock() agent = Agent(model=mock_provider, tools=[async_tool, normal_tool, streaming_tool], callback_handler=mock_callback) - stream = agent.stream_async("Do the stuff", arg1=1013) + stream = agent.stream_async("Do the stuff", invocation_state={"arg1": 1013}) tool_config = { "toolChoice": {"auto": {}}, @@ -344,7 +344,7 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): mock_callback = unittest.mock.Mock() agent = Agent(model=model, tools=[normal_tool], callback_handler=mock_callback) - stream = agent.stream_async("Do the 
stuff", arg1=1013) + stream = agent.stream_async("Do the stuff", invocation_state={"arg1": 1013}) # Base object with common properties throttle_props = { @@ -492,7 +492,7 @@ async def test_event_loop_cycle_text_response_throttling_early_end( # Because we're throwing an exception, we manually collect the items here tru_events = [] - stream = agent.stream_async("Do the stuff", arg1=1013) + stream = agent.stream_async("Do the stuff", invocation_state={"arg1": 1013}) async for event in stream: tru_events.append(event) @@ -525,6 +525,7 @@ async def test_event_loop_cycle_text_response_throttling_early_end( assert typed_events == [] +@pytest.mark.filterwarnings("ignore:Agent.structured_output_async method is deprecated:DeprecationWarning") @pytest.mark.asyncio async def test_structured_output(agenerator): # we use bedrock here as it uses the tool implementation diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 8ff81295a..4397b9628 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -284,6 +284,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert len(agent.messages) == 4 +@pytest.mark.filterwarnings("ignore:Agent.structured_output method is deprecated:DeprecationWarning") def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): """Verify that the correct hook events are emitted as part of structured_output.""" @@ -300,6 +301,7 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): assert len(agent.messages) == 0 # no new messages added +@pytest.mark.filterwarnings("ignore:Agent.structured_output_async method is deprecated:DeprecationWarning") @pytest.mark.asyncio async def test_agent_structured_async_output_hooks(agent, hook_provider, user, agenerator): """Verify that the correct hook events are emitted as part of structured_output_async.""" @@ -667,6 +669,7 @@ async def overwrite_input_hook(event: BeforeInvocationEvent): assert agent.messages[0]["content"][0]["text"] == "GOODBYE" +@pytest.mark.filterwarnings("ignore:Agent.structured_output_async method is deprecated:DeprecationWarning") @pytest.mark.asyncio async def test_before_invocation_event_messages_none_in_structured_output(agenerator): """Test that BeforeInvocationEvent.messages is None when called from deprecated structured_output.""" diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index b2cc152cb..0fe04f4b2 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -48,6 +48,7 @@ def moto_autouse(moto_env, moto_mock_aws): ), ], ) +@pytest.mark.filterwarnings("ignore:remove_blank_messages_content_text is deprecated:DeprecationWarning") def test_remove_blank_messages_content_text(messages, exp_result): tru_result = strands.event_loop.streaming.remove_blank_messages_content_text(messages) diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index b229c1c2d..ed7adba8a 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -7,16 +7,20 @@ import importlib import sys +import warnings from unittest.mock import Mock import pytest -from strands.experimental.hooks import ( - AfterModelInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - BeforeToolInvocationEvent, -) +# Suppress deprecation 
warnings from imports since we're testing the aliases themselves +with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from strands.experimental.hooks import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, + ) from strands.hooks import ( AfterModelCallEvent, AfterToolCallEvent, diff --git a/tests/strands/experimental/tools/test_tool_provider_alias.py b/tests/strands/experimental/tools/test_tool_provider_alias.py index 58a2b9e20..3b3055bc6 100644 --- a/tests/strands/experimental/tools/test_tool_provider_alias.py +++ b/tests/strands/experimental/tools/test_tool_provider_alias.py @@ -6,6 +6,7 @@ """ import sys +import warnings import pytest @@ -14,7 +15,10 @@ def test_experimental_alias_is_same_type(): """Verify that experimental ToolProvider alias is identical to the actual type.""" - from strands.experimental.tools import ToolProvider as ExperimentalToolProvider + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from strands.experimental.tools import ToolProvider as ExperimentalToolProvider assert ExperimentalToolProvider is ToolProvider diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index aac791214..1410e129b 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -31,7 +31,9 @@ def session_cls(): # Mock the creation of a Session so that we don't depend on environment variables or profiles with unittest.mock.patch.object(strands.models.bedrock.boto3, "Session") as mock_session_cls: - mock_session_cls.return_value.region_name = None + mock_session = unittest.mock.Mock() + mock_session.region_name = None + mock_session_cls.return_value = mock_session yield mock_session_cls @@ -216,66 +218,63 @@ def test__init__with_region_and_session_raises_value_error(): _ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1")) -def test__init__default_user_agent(bedrock_client): +def test__init__default_user_agent(session_cls, bedrock_client): """Set user agent when no boto_client_config is provided.""" - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel() + _ = BedrockModel() - # Verify the client was created with the correct config - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert kwargs["service_name"] == "bedrock-runtime" - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].user_agent_extra == "strands-agents" - assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT + # Verify the client was created with the correct config + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT -def test__init__default_read_timeout(bedrock_client): +def test__init__default_read_timeout(session_cls, bedrock_client): """Set default read timeout when no boto_client_config is provided.""" - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel() - # Verify the client was created with the correct read timeout 
- mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT + _ = BedrockModel() + # Verify the client was created with the correct read timeout + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT -def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client): + +def test__init__with_custom_boto_client_config_no_user_agent(session_cls, bedrock_client): """Set user agent when boto_client_config is provided without user_agent_extra.""" custom_config = BotocoreConfig(read_timeout=900) - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel(boto_client_config=custom_config) + _ = BedrockModel(boto_client_config=custom_config) - # Verify the client was created with the correct config - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert kwargs["service_name"] == "bedrock-runtime" - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].user_agent_extra == "strands-agents" - assert kwargs["config"].read_timeout == 900 + # Verify the client was created with the correct config + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == 900 -def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client): +def test__init__with_custom_boto_client_config_with_user_agent(session_cls, bedrock_client): """Append to existing user agent when boto_client_config is provided with user_agent_extra.""" custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900) - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel(boto_client_config=custom_config) - - # Verify the client was created with the correct config - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert kwargs["service_name"] == "bedrock-runtime" - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].user_agent_extra == "existing-agent strands-agents" - assert kwargs["config"].read_timeout == 900 + _ = BedrockModel(boto_client_config=custom_config) + + # Verify the client was created with the correct config + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "existing-agent strands-agents" + assert kwargs["config"].read_timeout == 900 def test__init__model_config(bedrock_client): diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index fa784de5c..3e023dfce 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -3,7 +3,7 @@ import base64 import json import logging -from unittest.mock import AsyncMock, patch +from 
unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -248,7 +248,7 @@ async def mock_aiter_lines(): mock_response = AsyncMock() mock_response.aiter_lines = mock_aiter_lines - mock_response.raise_for_status = AsyncMock() + mock_response.raise_for_status = MagicMock() with patch.object(model.client, "post", return_value=mock_response): messages = [{"role": "user", "content": [{"text": "Hi"}]}] diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 73ade574e..bb039bdce 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -11,6 +11,9 @@ from strands.multiagent.a2a.executor import StrandsA2AExecutor from strands.types.content import ContentBlock +# Suppress A2A compliance warnings for legacy streaming mode tests +pytestmark = pytest.mark.filterwarnings("ignore:The default A2A response stream.*:UserWarning") + # Test data constants VALID_PNG_BYTES = b"fake_png_data" VALID_MP4_BYTES = b"fake_mp4_data" diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 4e8a5dd06..2fb2cc617 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -156,6 +156,7 @@ def deserialize_state(self, payload: dict) -> None: assert isinstance(agent, MultiAgentBase) +@pytest.mark.filterwarnings("ignore:`\\*\\*kwargs` parameter is deprecating:UserWarning") def test_multi_agent_base_call_method(): """Test that __call__ method properly delegates to invoke_async.""" diff --git a/tests/strands/tools/test_loader.py b/tests/strands/tools/test_loader.py index 1c665b42a..121ebed2d 100644 --- a/tests/strands/tools/test_loader.py +++ b/tests/strands/tools/test_loader.py @@ -10,6 +10,14 @@ from strands.tools.loader import _TOOL_MODULE_PREFIX, ToolLoader, load_tools_from_file_path from strands.tools.tools import PythonAgentTool +# Suppress deprecation warnings for deprecated ToolLoader methods being tested +pytestmark = pytest.mark.filterwarnings( + "ignore:ToolLoader.load_python_tool is deprecated:DeprecationWarning", + "ignore:ToolLoader.load_python_tools is deprecated:DeprecationWarning", + "ignore:ToolLoader.load_tool is deprecated:DeprecationWarning", + "ignore:ToolLoader.load_tools is deprecated:DeprecationWarning", +) + @pytest.fixture def tool_path(request, tmp_path, monkeypatch): diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index ed96f2b6a..73141beb6 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -13,6 +13,7 @@ from strands.tools.registry import ToolRegistry +@pytest.mark.filterwarnings("ignore:load_tool_from_filepath is deprecated:DeprecationWarning") def test_load_tool_from_filepath_failure(): """Test error handling when load_tool fails.""" tool_registry = ToolRegistry() From 7c8279e1e52e33baa25f25f3752bcd3a16afc3a0 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 3 Feb 2026 10:06:04 -0500 Subject: [PATCH 13/16] Fix bedrock file warnings (#1603) From 5a1224630f6a2c1f35c51a652101e8a590bf46b8 Mon Sep 17 00:00:00 2001 From: Charles Duffy Date: Mon, 26 Jan 2026 17:03:50 -0600 Subject: [PATCH 14/16] feat: Propagate exceptions to AfterToolCallEvent for decorated tools (#1565) --- tests/strands/tools/test_decorator.py | 77 +++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index f3d6eda02..5e39047b6 100644 --- 
a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1978,3 +1978,80 @@ def my_tool(name: str, tag: str | None = None) -> str: # Since tag is not required, anyOf should be simplified away assert "anyOf" not in schema["properties"]["tag"] assert schema["properties"]["tag"]["type"] == "string" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_runtime_error(alist): + """Test that ToolResultEvent carries exception when tool raises RuntimeError.""" + + @strands.tool + def error_tool(): + """Tool that raises a RuntimeError.""" + raise RuntimeError("test runtime error") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, RuntimeError) + assert str(result_event.exception) == "test runtime error" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_value_error(alist): + """Test that ToolResultEvent carries exception when tool raises ValueError.""" + + @strands.tool + def validation_error_tool(): + """Tool that raises a ValueError.""" + raise ValueError("validation failed") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(validation_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, ValueError) + assert str(result_event.exception) == "validation failed" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_no_exception_on_success(alist): + """Test that ToolResultEvent.exception is None when tool succeeds.""" + + @strands.tool + def success_tool(): + """Tool that succeeds.""" + return "success" + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(success_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.exception is None + assert result_event.tool_result["status"] == "success" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_assertion_error(alist): + """Test that ToolResultEvent carries AssertionError for unexpected failures.""" + + @strands.tool + def assertion_error_tool(): + """Tool that raises an AssertionError.""" + raise AssertionError("unexpected assertion failure") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(assertion_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert isinstance(result_event.exception, AssertionError) + assert "unexpected assertion failure" in str(result_event.exception) + assert result_event.tool_result["status"] == "error" From 60dad8a4bf2fc0565e55a14273ef35fb3e9a56d1 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 3 Feb 2026 14:42:07 -0500 Subject: [PATCH 15/16] address test comments --- tests/strands/tools/test_decorator.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 5e39047b6..382c568a6 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1983,40 +1983,38 @@ def 
---
 tests/strands/tools/test_decorator.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py
index 5e39047b6..382c568a6 100644
--- a/tests/strands/tools/test_decorator.py
+++ b/tests/strands/tools/test_decorator.py
@@ -1983,40 +1983,38 @@ def my_tool(name: str, tag: str | None = None) -> str:
 @pytest.mark.asyncio
 async def test_tool_result_event_carries_exception_runtime_error(alist):
     """Test that ToolResultEvent carries exception when tool raises RuntimeError."""
+    exception = RuntimeError("test runtime error")
 
     @strands.tool
     def error_tool():
         """Tool that raises a RuntimeError."""
-        raise RuntimeError("test runtime error")
+        raise exception
 
     tool_use = {"toolUseId": "test-id", "input": {}}
     events = await alist(error_tool.stream(tool_use, {}))
 
     result_event = events[-1]
     assert isinstance(result_event, ToolResultEvent)
-    assert hasattr(result_event, "exception")
-    assert isinstance(result_event.exception, RuntimeError)
-    assert str(result_event.exception) == "test runtime error"
+    assert result_event.exception is exception
     assert result_event.tool_result["status"] == "error"
 
 
 @pytest.mark.asyncio
 async def test_tool_result_event_carries_exception_value_error(alist):
     """Test that ToolResultEvent carries exception when tool raises ValueError."""
+    exception = ValueError("validation failed")
 
     @strands.tool
     def validation_error_tool():
         """Tool that raises a ValueError."""
-        raise ValueError("validation failed")
+        raise exception
 
     tool_use = {"toolUseId": "test-id", "input": {}}
     events = await alist(validation_error_tool.stream(tool_use, {}))
 
     result_event = events[-1]
     assert isinstance(result_event, ToolResultEvent)
-    assert hasattr(result_event, "exception")
-    assert isinstance(result_event.exception, ValueError)
-    assert str(result_event.exception) == "validation failed"
+    assert result_event.exception is exception
     assert result_event.tool_result["status"] == "error"
 
 
@@ -2041,17 +2039,17 @@ def success_tool():
 @pytest.mark.asyncio
 async def test_tool_result_event_carries_exception_assertion_error(alist):
     """Test that ToolResultEvent carries AssertionError for unexpected failures."""
+    exception = AssertionError("unexpected assertion failure")
 
     @strands.tool
     def assertion_error_tool():
         """Tool that raises an AssertionError."""
-        raise AssertionError("unexpected assertion failure")
+        raise exception
 
     tool_use = {"toolUseId": "test-id", "input": {}}
     events = await alist(assertion_error_tool.stream(tool_use, {}))
 
     result_event = events[-1]
     assert isinstance(result_event, ToolResultEvent)
-    assert isinstance(result_event.exception, AssertionError)
-    assert "unexpected assertion failure" in str(result_event.exception)
+    assert result_event.exception is exception
     assert result_event.tool_result["status"] == "error"

From 827407170419c8364dc90abb8f407ee28347c63b Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 4 Feb 2026 10:10:47 -0500
Subject: [PATCH 16/16] Remove implementation details from ToolResultEvent docstring

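Keep the class docstring to what the event is; guidance on consuming the
stored exception belongs in user-facing documentation rather than in the
type's docstring. For orientation only, a hypothetical hook-side consumer
of the property this series adds (illustrative, not part of this change;
assumes the private module path shown in the diff below):

    from strands.types._events import ToolResultEvent

    def on_tool_result(event: ToolResultEvent) -> None:
        # exception is None on success and holds the original object
        # raised by the decorated tool on failure.
        if event.exception is None:
            return
        if isinstance(event.exception, ValueError):
            print(f"bad tool input: {event.exception}")
            return
        # Re-raise to preserve the original traceback.
        raise event.exception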
- """ + """Event emitted when a tool execution completes.""" def __init__(self, tool_result: ToolResult, exception: Exception | None = None) -> None: """Initialize tool result event."""