diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 04c14e452..70552d6ba 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -620,6 +620,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw "status": "error", "content": [{"text": f"Error: {error_msg}"}], }, + exception=e, ) except Exception as e: # Return error result with exception details for any other error @@ -632,14 +633,15 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw "status": "error", "content": [{"text": f"Error: {error_type} - {error_msg}"}], }, + exception=e, ) - def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: + def _wrap_tool_result(self, tool_use_d: str, result: Any, exception: Exception | None = None) -> ToolResultEvent: # FORMAT THE RESULT for Strands Agent if isinstance(result, dict) and "status" in result and "content" in result: # Result is already in the expected format, just add toolUseId result["toolUseId"] = tool_use_d - return ToolResultEvent(cast(ToolResult, result)) + return ToolResultEvent(cast(ToolResult, result), exception=exception) else: # Wrap any other return value in the standard format # Always include at least one content item for consistency @@ -648,7 +650,8 @@ def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: "toolUseId": tool_use_d, "status": "success", "content": [{"text": str(result)}], - } + }, + exception=exception, ) @property diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index ef000fbd6..0da6b5715 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -215,6 +215,9 @@ async def _stream( return if structured_output_context.is_enabled: kwargs["structured_output_context"] = structured_output_context + + exception: Exception | None = None + async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. @@ -227,6 +230,8 @@ async def _stream( return if isinstance(event, ToolResultEvent): + # Preserve exception from decorated tools before extracting tool_result + exception = event.exception # below the last "event" must point to the tool_result event = event.tool_result break @@ -239,7 +244,7 @@ async def _stream( result = cast(ToolResult, event) after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result + agent, selected_tool, tool_use, invocation_state, result, exception=exception ) # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 0896d48e1..5b0ae78f6 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -276,13 +276,18 @@ def prepare(self, invocation_state: dict) -> None: class ToolResultEvent(TypedEvent): """Event emitted when a tool execution completes.""" - def __init__(self, tool_result: ToolResult) -> None: - """Initialize with the completed tool result. + def __init__(self, tool_result: ToolResult, exception: Exception | None = None) -> None: + """Initialize tool result event.""" + super().__init__({"type": "tool_result", "tool_result": tool_result}) + self._exception = exception - Args: - tool_result: Final result from the tool execution + @property + def exception(self) -> Exception | None: + """The original exception that occurred, if any. + + Can be used for re-raising or type-based error handling. """ - super().__init__({"type": "tool_result", "tool_result": tool_result}) + return self._exception @property def tool_use_id(self) -> str: diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 78e35c2aa..4a5479503 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -482,6 +482,98 @@ async def test_executor_stream_updates_invocation_state_with_agent( assert empty_invocation_state["agent"] is agent +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_exception_in_hook( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that exceptions from @tool-decorated functions reach AfterToolCallEvent.""" + exception = ValueError("decorated tool error") + + @strands.tool(name="decorated_error_tool") + def failing_tool(): + """A tool that raises an exception.""" + raise exception + + agent.tool_registry.register_tool(failing_tool) + tool_use = {"name": "decorated_error_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is exception + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_runtime_error_in_hook( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that RuntimeError from @tool-decorated functions reach AfterToolCallEvent.""" + exception = RuntimeError("runtime error from decorated tool") + + @strands.tool(name="runtime_error_tool") + def runtime_error_tool(): + """A tool that raises a RuntimeError.""" + raise exception + + agent.tool_registry.register_tool(runtime_error_tool) + tool_use = {"name": "runtime_error_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is exception + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_no_exception_on_success( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that AfterToolCallEvent.exception is None when decorated tool succeeds.""" + + @strands.tool(name="success_decorated_tool") + def success_tool(): + """A tool that succeeds.""" + return "success" + + agent.tool_registry.register_tool(success_tool) + tool_use = {"name": "success_decorated_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.result["status"] == "success" + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_error_result_without_exception( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that exception is None when a tool returns an error result without throwing.""" + + @strands.tool(name="error_result_tool") + def error_result_tool(): + """A tool that returns an error result dict without raising.""" + return {"status": "error", "content": [{"text": "something went wrong"}]} + + agent.tool_registry.register_tool(error_result_tool) + tool_use = {"name": "error_result_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.result["status"] == "error" + + @pytest.mark.asyncio async def test_executor_stream_no_retry_set(executor, agent, tool_results, invocation_state, alist): """Test default behavior when retry is not set - tool executes once.""" diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 42213fcb8..f3d6eda02 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1825,6 +1825,83 @@ def inner_default_tool(name: str, level: Annotated[int, Field(description="A lev return f"{name} is at level {level}" +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_runtime_error(alist): + """Test that ToolResultEvent carries exception when tool raises RuntimeError.""" + + @strands.tool + def error_tool(): + """Tool that raises a RuntimeError.""" + raise RuntimeError("test runtime error") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, RuntimeError) + assert str(result_event.exception) == "test runtime error" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_value_error(alist): + """Test that ToolResultEvent carries exception when tool raises ValueError.""" + + @strands.tool + def validation_error_tool(): + """Tool that raises a ValueError.""" + raise ValueError("validation failed") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(validation_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, ValueError) + assert str(result_event.exception) == "validation failed" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_no_exception_on_success(alist): + """Test that ToolResultEvent.exception is None when tool succeeds.""" + + @strands.tool + def success_tool(): + """Tool that succeeds.""" + return "success" + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(success_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.exception is None + assert result_event.tool_result["status"] == "success" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_assertion_error(alist): + """Test that ToolResultEvent carries AssertionError for unexpected failures.""" + + @strands.tool + def assertion_error_tool(): + """Tool that raises an AssertionError.""" + raise AssertionError("unexpected assertion failure") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(assertion_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert isinstance(result_event.exception, AssertionError) + assert "unexpected assertion failure" in str(result_event.exception) + assert result_event.tool_result["status"] == "error" + + def test_tool_nullable_required_field_preserves_anyof(): """Test that a required nullable field preserves anyOf so the model can pass null.