From 695ca66541853e7983b9b39c8649801cd7018875 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Mon, 5 Jan 2026 19:21:57 +0200
Subject: [PATCH 01/47] docs: update github agent action to reference
 AGENT_SESSIONS_BUCKET secret (#1418)

---
 .github/actions/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/actions/README.md b/.github/actions/README.md
index a3ec3fa2d..6559462cb 100644
--- a/.github/actions/README.md
+++ b/.github/actions/README.md
@@ -198,7 +198,7 @@ Your IAM role must have these permissions in order to execute:
 3. **Create S3 Bucket** for session storage
 4. **Add GitHub Secrets**:
    - `AWS_ROLE_ARN`: The created role ARN
-   - `STRANDS_SESSION_BUCKET`: The S3 bucket name
+   - `AGENT_SESSIONS_BUCKET`: The S3 bucket name
 
 ## Security
 

From 50e5e74c53de58c7de2b9ecbfbca488fcdbf456e Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Tue, 6 Jan 2026 15:47:26 -0500
Subject: [PATCH 02/47] feat: provide extra command content as the prompt to
 the agent (#1419)

Previously triggering the agent would always provide the prompt of
"review and continue" to the agent; this meant that if you gave the
agent explicit commands in the comment it wouldn't necessarily
receive/act on those.

For example:

/strands you didn't do X, please do it

It would not actually receive the extra text; this updates it so that
everything after the "strands command" is added as the prompt,
defaulting to "review and continue" if none is provided

---------

Co-authored-by: Mackenzie Zastrow
---
 .../actions/strands-agent-runner/action.yml  |  2 +-
 .github/scripts/javascript/process-input.cjs | 20 +++++++++++++++++--
 .github/scripts/python/agent_runner.py       | 13 ++++++------
 3 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/.github/actions/strands-agent-runner/action.yml b/.github/actions/strands-agent-runner/action.yml
index 6d4c2d7fb..d0e93effe 100644
--- a/.github/actions/strands-agent-runner/action.yml
+++ b/.github/actions/strands-agent-runner/action.yml
@@ -149,7 +149,7 @@ runs:
         STRANDS_TOOL_CONSOLE_MODE: 'enabled'
         BYPASS_TOOL_CONSENT: 'true'
       run: |
-        uv run --no-project ${{ runner.temp }}/strands-agent-runner/.github/scripts/python/agent_runner.py "$INPUT_TASK"
+        uv run --no-project ${{ runner.temp }}/strands-agent-runner/.github/scripts/python/agent_runner.py
 
     - name: Capture repository state
       shell: bash
diff --git a/.github/scripts/javascript/process-input.cjs b/.github/scripts/javascript/process-input.cjs
index b7ed29263..395e37b64 100644
--- a/.github/scripts/javascript/process-input.cjs
+++ b/.github/scripts/javascript/process-input.cjs
@@ -8,9 +8,10 @@ async function getIssueInfo(github, context, inputs) {
   const issueId = context.eventName === 'workflow_dispatch'
     ? inputs.issue_id
     : context.payload.issue.number.toString();
+  const commentBody = context.payload.comment?.body || '';
   const command = context.eventName === 'workflow_dispatch'
     ? inputs.command
-    : (context.payload.comment.body.match(/^\/strands\s*(.*?)$/m)?.[1]?.trim() || '');
+    : (commentBody.startsWith('/strands') ? commentBody.slice('/strands'.length).trim() : '');
 
   console.log(`Event: ${context.eventName}, Issue ID: ${issueId}, Command: "${command}"`);
 
@@ -76,10 +77,25 @@ function buildPrompts(mode, issueId, isPullRequest, command, branchName, inputs)
   const scriptFile = scriptFiles[mode] || scriptFiles['refiner'];
   const systemPrompt = fs.readFileSync(scriptFile, 'utf8');
 
+  // Extract the user's feedback/instructions after the mode keyword
+  // e.g., "release-notes Move #123 to Major Features" -> "Move #123 to Major Features"
+  const modeKeywords = {
+    'release-notes': /^(?:release-notes|release notes)\s*/i,
+    'implementer': /^implement\s*/i,
+    'refiner': /^refine\s*/i
+  };
+
+  const modePattern = modeKeywords[mode];
+  const userFeedback = modePattern ? command.replace(modePattern, '').trim() : command.trim();
+
   let prompt = (isPullRequest) ? 'The pull request id is:' : 'The issue id is:';
-  prompt += `${issueId}\n${command}\nreview and continue`;
+  prompt += `${issueId}\n`;
+
+  // If there's any user feedback beyond the command keyword, include it as the main instruction,
+  // otherwise default to "review and continue"
+  prompt += userFeedback || 'review and continue';
 
   return { sessionId, systemPrompt, prompt };
 }
diff --git a/.github/scripts/python/agent_runner.py b/.github/scripts/python/agent_runner.py
index db10ceadb..9d92c2ac4 100644
--- a/.github/scripts/python/agent_runner.py
+++ b/.github/scripts/python/agent_runner.py
@@ -142,13 +142,12 @@ def run_agent(query: str):
 def main() -> None:
     """Main entry point for the agent runner."""
     try:
-        # Read task from command line arguments
-        if len(sys.argv) < 2:
-            raise ValueError("Task argument is required")
-
-        task = " ".join(sys.argv[1:])
-        if not task.strip():
-            raise ValueError("Task cannot be empty")
+        # Prefer INPUT_TASK env var (avoids shell escaping issues), fall back to CLI args
+        task = os.getenv("INPUT_TASK", "").strip()
+        if not task and len(sys.argv) > 1:
+            task = " ".join(sys.argv[1:]).strip()
+        if not task:
+            raise ValueError("Task is required (via INPUT_TASK env var or CLI argument)")
 
         print(f"🤖 Running agent with task: {task}")
         run_agent(task)

From 3bc34acc76373c6465354fa57ced0498342cde32 Mon Sep 17 00:00:00 2001
From: AI Ape Wisdom
Date: Wed, 7 Jan 2026 05:12:23 +0800
Subject: [PATCH 03/47] [FEATURE] add MCP resource operations in MCP Tools
 (#1117)

* feat(tools): Add MCP resource operations

* feat(tools): Add MCP resource operations

* tests: add integ tests for mcp resources

* fix: broken merge

---------

Co-authored-by: Dean Schmigelski
---
 src/strands/tools/mcp/mcp_client.py        |  87 +++++++++++-
 tests/strands/tools/mcp/test_mcp_client.py | 154 ++++++++++++++++++++-
 tests_integ/mcp/echo_server.py             |  19 +++
 tests_integ/mcp/test_mcp_resources.py      | 130 +++++++++++++++++
 4 files changed, 388 insertions(+), 2 deletions(-)
 create mode 100644 tests_integ/mcp/test_mcp_resources.py

diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py
index 6ce591bc5..37b99d021 100644
--- a/src/strands/tools/mcp/mcp_client.py
+++ b/src/strands/tools/mcp/mcp_client.py
@@ -21,11 +21,20 @@
 import anyio
 from mcp import ClientSession, ListToolsResult
 from mcp.client.session import ElicitationFnT
-from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents
+from mcp.types import (
+    BlobResourceContents,
+    GetPromptResult,
+    ListPromptsResult,
+    ListResourcesResult,
+    ListResourceTemplatesResult,
+    ReadResourceResult,
+    TextResourceContents,
+)
 from mcp.types import CallToolResult as MCPCallToolResult
 from mcp.types import EmbeddedResource as MCPEmbeddedResource
 from mcp.types import ImageContent as MCPImageContent
 from mcp.types import TextContent as MCPTextContent
+from pydantic import AnyUrl
 from typing_extensions import Protocol, TypedDict
 
 from ...experimental.tools import ToolProvider
@@ -449,6 +458,82 @@ async def _get_prompt_async() -> GetPromptResult:
 
         return get_prompt_result
 
+    def list_resources_sync(self, pagination_token: Optional[str] = None) -> ListResourcesResult:
+        """Synchronously retrieves the list of available resources from the MCP server.
+
+        This method calls the asynchronous list_resources method on the MCP session
+        and returns the raw ListResourcesResult with pagination support.
+
+        Args:
+            pagination_token: Optional token for pagination
+
+        Returns:
+            ListResourcesResult: The raw MCP response containing resources and pagination info
+        """
+        self._log_debug_with_thread("listing MCP resources synchronously")
+        if not self._is_session_active():
+            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
+
+        async def _list_resources_async() -> ListResourcesResult:
+            return await cast(ClientSession, self._background_thread_session).list_resources(cursor=pagination_token)
+
+        list_resources_result: ListResourcesResult = self._invoke_on_background_thread(_list_resources_async()).result()
+        self._log_debug_with_thread("received %d resources from MCP server", len(list_resources_result.resources))
+
+        return list_resources_result
+
+    def read_resource_sync(self, uri: AnyUrl | str) -> ReadResourceResult:
+        """Synchronously reads a resource from the MCP server.
+
+        Args:
+            uri: The URI of the resource to read
+
+        Returns:
+            ReadResourceResult: The resource content from the MCP server
+        """
+        self._log_debug_with_thread("reading MCP resource synchronously: %s", uri)
+        if not self._is_session_active():
+            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
+
+        async def _read_resource_async() -> ReadResourceResult:
+            # Convert string to AnyUrl if needed
+            resource_uri = AnyUrl(uri) if isinstance(uri, str) else uri
+            return await cast(ClientSession, self._background_thread_session).read_resource(resource_uri)
+
+        read_resource_result: ReadResourceResult = self._invoke_on_background_thread(_read_resource_async()).result()
+        self._log_debug_with_thread("received resource content from MCP server")
+
+        return read_resource_result
+
+    def list_resource_templates_sync(self, pagination_token: Optional[str] = None) -> ListResourceTemplatesResult:
+        """Synchronously retrieves the list of available resource templates from the MCP server.
+
+        Resource templates define URI patterns that can be used to access resources dynamically.
+
+        Args:
+            pagination_token: Optional token for pagination
+
+        Returns:
+            ListResourceTemplatesResult: The raw MCP response containing resource templates and pagination info
+        """
+        self._log_debug_with_thread("listing MCP resource templates synchronously")
+        if not self._is_session_active():
+            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
+
+        async def _list_resource_templates_async() -> ListResourceTemplatesResult:
+            return await cast(ClientSession, self._background_thread_session).list_resource_templates(
+                cursor=pagination_token
+            )
+
+        list_resource_templates_result: ListResourceTemplatesResult = self._invoke_on_background_thread(
+            _list_resource_templates_async()
+        ).result()
+        self._log_debug_with_thread(
+            "received %d resource templates from MCP server", len(list_resource_templates_result.resourceTemplates)
+        )
+
+        return list_resource_templates_result
+
     def call_tool_sync(
         self,
         tool_use_id: str,
diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py
index f5040de1b..35f11f47f 100644
--- a/tests/strands/tools/mcp/test_mcp_client.py
+++ b/tests/strands/tools/mcp/test_mcp_client.py
@@ -5,9 +5,21 @@
 import pytest
 from mcp import ListToolsResult
 from mcp.types import CallToolResult as MCPCallToolResult
-from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage
+from mcp.types import (
+    GetPromptResult,
+    ListPromptsResult,
+    ListResourcesResult,
+    ListResourceTemplatesResult,
+    Prompt,
+    PromptMessage,
+    ReadResourceResult,
+    Resource,
+    ResourceTemplate,
+    TextResourceContents,
+)
 from mcp.types import TextContent as MCPTextContent
 from mcp.types import Tool as MCPTool
+from pydantic import AnyUrl
 
 from strands.tools.mcp import MCPClient
 from strands.tools.mcp.mcp_types import MCPToolResult
@@ -772,3 +784,143 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se
     assert result["metadata"] == metadata
     assert "structuredContent" in result
     assert result["structuredContent"] == structured_content
+
+
+# Resource Tests - Sync Methods
+
+
+def test_list_resources_sync(mock_transport, mock_session):
+    """Test that list_resources_sync correctly retrieves resources."""
+    mock_resource = Resource(
+        uri=AnyUrl("file://documents/test.txt"), name="test.txt", description="A test document", mimeType="text/plain"
+    )
+    mock_session.list_resources.return_value = ListResourcesResult(resources=[mock_resource])
+
+    with MCPClient(mock_transport["transport_callable"]) as client:
+        result = client.list_resources_sync()
+
+    mock_session.list_resources.assert_called_once_with(cursor=None)
+    assert len(result.resources) == 1
+    assert result.resources[0].name == "test.txt"
+    assert str(result.resources[0].uri) == "file://documents/test.txt"
+    assert result.nextCursor is None
+
+
+def test_list_resources_sync_with_pagination_token(mock_transport, mock_session):
+    """Test that list_resources_sync correctly passes pagination token and returns next cursor."""
+    mock_resource = Resource(
+        uri=AnyUrl("file://documents/test.txt"), name="test.txt", description="A test document", mimeType="text/plain"
+    )
+    mock_session.list_resources.return_value = ListResourcesResult(resources=[mock_resource], nextCursor="next_page")
+
+    with MCPClient(mock_transport["transport_callable"]) as client:
+        result = client.list_resources_sync(pagination_token="current_page")
+
+    mock_session.list_resources.assert_called_once_with(cursor="current_page")
+    assert len(result.resources) == 1
+    assert result.resources[0].name == "test.txt"
+    assert result.nextCursor == "next_page"
+
+
+def test_list_resources_sync_session_not_active():
+    """Test that list_resources_sync raises an error when session is not active."""
+    client = MCPClient(MagicMock())
+
+    with pytest.raises(MCPClientInitializationError, match="client session is not running"):
+        client.list_resources_sync()
+
+
+def test_read_resource_sync(mock_transport, mock_session):
+    """Test that read_resource_sync correctly reads a resource."""
+    mock_content = TextResourceContents(
+        uri=AnyUrl("file://documents/test.txt"), text="Resource content", mimeType="text/plain"
+    )
+    mock_session.read_resource.return_value = ReadResourceResult(contents=[mock_content])
+
+    with MCPClient(mock_transport["transport_callable"]) as client:
+        result = client.read_resource_sync("file://documents/test.txt")
+
+    # Verify the session method was called
+    mock_session.read_resource.assert_called_once()
+    # Check the URI argument (it will be wrapped as AnyUrl)
+    call_args = mock_session.read_resource.call_args[0]
+    assert str(call_args[0]) == "file://documents/test.txt"
+
+    assert len(result.contents) == 1
+    assert result.contents[0].text == "Resource content"
+
+
+def test_read_resource_sync_with_anyurl(mock_transport, mock_session):
+    """Test that read_resource_sync correctly handles AnyUrl input."""
+    mock_content = TextResourceContents(
+        uri=AnyUrl("file://documents/test.txt"), text="Resource content", mimeType="text/plain"
+    )
+    mock_session.read_resource.return_value = ReadResourceResult(contents=[mock_content])
+
+    with MCPClient(mock_transport["transport_callable"]) as client:
+        uri = AnyUrl("file://documents/test.txt")
+        result = client.read_resource_sync(uri)
+
+    mock_session.read_resource.assert_called_once()
+    call_args = mock_session.read_resource.call_args[0]
+    assert str(call_args[0]) == "file://documents/test.txt"
+
+    assert len(result.contents) == 1
+    assert result.contents[0].text == "Resource content"
+
+
+def test_read_resource_sync_session_not_active():
+    """Test that read_resource_sync raises an error when session is not active."""
+    client = MCPClient(MagicMock())
+
+    with pytest.raises(MCPClientInitializationError, match="client session is not running"):
+        client.read_resource_sync("file://documents/test.txt")
+
+
+def test_list_resource_templates_sync(mock_transport, mock_session):
+    """Test that list_resource_templates_sync correctly retrieves resource templates."""
+    mock_template = ResourceTemplate(
+        uriTemplate="file://documents/{name}",
+        name="document_template",
+        description="Template for documents",
+        mimeType="text/plain",
+    )
+    mock_session.list_resource_templates.return_value = ListResourceTemplatesResult(resourceTemplates=[mock_template])
+
+    with MCPClient(mock_transport["transport_callable"]) as client:
+        result = client.list_resource_templates_sync()
+
+    mock_session.list_resource_templates.assert_called_once_with(cursor=None)
+    assert len(result.resourceTemplates) == 1
+    assert result.resourceTemplates[0].name == "document_template"
+    assert result.resourceTemplates[0].uriTemplate == "file://documents/{name}"
+    assert result.nextCursor is None
+
+
+def test_list_resource_templates_sync_with_pagination_token(mock_transport, mock_session):
+    """Test that list_resource_templates_sync correctly passes pagination token and returns next cursor."""
+    mock_template = ResourceTemplate(
+        uriTemplate="file://documents/{name}",
+        name="document_template",
+        description="Template for documents",
+        mimeType="text/plain",
+    )
+    mock_session.list_resource_templates.return_value = ListResourceTemplatesResult(
+        resourceTemplates=[mock_template], nextCursor="next_page"
+    )
+
+    with MCPClient(mock_transport["transport_callable"]) as client:
+        result = client.list_resource_templates_sync(pagination_token="current_page")
+
+    mock_session.list_resource_templates.assert_called_once_with(cursor="current_page")
+    assert len(result.resourceTemplates) == 1
+    assert result.resourceTemplates[0].name == "document_template"
+    assert result.nextCursor == "next_page"
+
+
+def test_list_resource_templates_sync_session_not_active():
+    """Test that list_resource_templates_sync raises an error when session is not active."""
+    client = MCPClient(MagicMock())
+
+    with pytest.raises(MCPClientInitializationError, match="client session is not running"):
+        client.list_resource_templates_sync()
diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py
index a23a87b5c..151f913d6 100644
--- a/tests_integ/mcp/echo_server.py
+++ b/tests_integ/mcp/echo_server.py
@@ -16,12 +16,15 @@
 """
 
 import base64
+import json
 from typing import Literal
 
 from mcp.server import FastMCP
 from mcp.types import BlobResourceContents, CallToolResult, EmbeddedResource, TextContent, TextResourceContents
 from pydantic import BaseModel
 
+TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
+
 
 class EchoResponse(BaseModel):
     """Response model for echo with structured content."""
@@ -102,6 +105,22 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"):
             )
         ]
 
+    # Resources
+    @mcp.resource("test://static-text")
+    def static_text_resource() -> str:
+        """A static text resource for testing"""
+        return "This is the content of the static text resource."
+
+    @mcp.resource("test://static-binary")
+    def static_binary_resource() -> bytes:
+        """A static binary resource (image) for testing"""
+        return base64.b64decode(TEST_IMAGE_BASE64)
+
+    @mcp.resource("test://template/{id}/data")
+    def template_resource(id: str) -> str:
+        """A resource template with parameter substitution"""
+        return json.dumps({"id": id, "templateTest": True, "data": f"Data for ID: {id}"})
+
     mcp.run(transport="stdio")
diff --git a/tests_integ/mcp/test_mcp_resources.py b/tests_integ/mcp/test_mcp_resources.py
new file mode 100644
index 000000000..dccf3b808
--- /dev/null
+++ b/tests_integ/mcp/test_mcp_resources.py
@@ -0,0 +1,130 @@
+"""
+Integration tests for MCP client resource functionality.
+
+This module tests the resource-related methods in MCPClient:
+- list_resources_sync()
+- read_resource_sync()
+- list_resource_templates_sync()
+
+The tests use the echo server which has been extended with resource functionality.
+"""
+
+import base64
+import json
+
+import pytest
+from mcp import StdioServerParameters, stdio_client
+from mcp.shared.exceptions import McpError
+from mcp.types import BlobResourceContents, TextResourceContents
+from pydantic import AnyUrl
+
+from strands.tools.mcp.mcp_client import MCPClient
+
+
+def test_mcp_resources_list_and_read():
+    """Test listing and reading various types of resources."""
+    mcp_client = MCPClient(
+        lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
+    )
+
+    with mcp_client:
+        # Test list_resources_sync
+        resources_result = mcp_client.list_resources_sync()
+        assert len(resources_result.resources) >= 2  # At least our 2 static resources
+
+        # Verify resource URIs exist (only static resources, not templates)
+        resource_uris = [str(r.uri) for r in resources_result.resources]
+        assert "test://static-text" in resource_uris
+        assert "test://static-binary" in resource_uris
+        # Template resources are not listed in static resources
+
+        # Test reading text resource
+        text_resource = mcp_client.read_resource_sync("test://static-text")
+        assert len(text_resource.contents) == 1
+        content = text_resource.contents[0]
+        assert isinstance(content, TextResourceContents)
+        assert "This is the content of the static text resource." in content.text
+
+        # Test reading binary resource
+        binary_resource = mcp_client.read_resource_sync("test://static-binary")
+        assert len(binary_resource.contents) == 1
+        binary_content = binary_resource.contents[0]
+        assert isinstance(binary_content, BlobResourceContents)
+        # Verify it's valid base64 encoded data
+        decoded_data = base64.b64decode(binary_content.blob)
+        assert len(decoded_data) > 0
+
+
+def test_mcp_resources_templates():
+    """Test listing resource templates and reading from template resources."""
+    mcp_client = MCPClient(
+        lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
+    )
+
+    with mcp_client:
+        # Test list_resource_templates_sync
+        templates_result = mcp_client.list_resource_templates_sync()
+        assert len(templates_result.resourceTemplates) >= 1
+
+        # Verify template URIs exist
+        template_uris = [t.uriTemplate for t in templates_result.resourceTemplates]
+        assert "test://template/{id}/data" in template_uris
+
+        # Test reading from template resource
+        template_resource = mcp_client.read_resource_sync("test://template/123/data")
+        assert len(template_resource.contents) == 1
+        template_content = template_resource.contents[0]
+        assert isinstance(template_content, TextResourceContents)
+
+        # Parse the JSON response
+        parsed_json = json.loads(template_content.text)
+        assert parsed_json["id"] == "123"
+        assert parsed_json["templateTest"] is True
+        assert "Data for ID: 123" in parsed_json["data"]
+
+
+def test_mcp_resources_pagination():
+    """Test pagination support for resources."""
+    mcp_client = MCPClient(
+        lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
+    )
+
+    with mcp_client:
+        # Test with pagination token (should work even if server doesn't implement pagination)
+        resources_result = mcp_client.list_resources_sync(pagination_token=None)
+        assert len(resources_result.resources) >= 0
+
+        # Test resource templates pagination
+        templates_result = mcp_client.list_resource_templates_sync(pagination_token=None)
+        assert len(templates_result.resourceTemplates) >= 0
+
+
+def test_mcp_resources_error_handling():
+    """Test error handling for resource operations."""
+    mcp_client = MCPClient(
+        lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
+    )
+
+    with mcp_client:
+        # Test reading non-existent resource
+        with pytest.raises(McpError, match="Unknown resource"):
+            mcp_client.read_resource_sync("test://nonexistent")
+
+
+def test_mcp_resources_uri_types():
+    """Test that both string and AnyUrl types work for read_resource_sync."""
+    mcp_client = MCPClient(
+        lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"]))
+    )
+
+    with mcp_client:
+        # Test with string URI
+        text_resource_str = mcp_client.read_resource_sync("test://static-text")
+        assert len(text_resource_str.contents) == 1
+
+        # Test with AnyUrl URI
+        text_resource_url = mcp_client.read_resource_sync(AnyUrl("test://static-text"))
+        assert len(text_resource_url.contents) == 1
+
+        # Both should return the same content
+        assert text_resource_str.contents[0].text == text_resource_url.contents[0].text

From 514f40243b280f0645f31a0c48a8bac9e2f1354f Mon Sep 17 00:00:00 2001
From: mehtarac
Date: Tue, 6 Jan 2026 13:26:31 -0800
Subject: [PATCH 04/47] fix: import errors for models with optional imports
 (#1384)

* fix: import errors for models with optional imports

* Addressed comments: added return type, changed error message

* Addressed comments: updated imports

---
 src/strands/models/__init__.py | 57 +++++++++++++++++++++++++++++++++-
 1 file changed, 56 insertions(+), 1 deletion(-)

diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py
index ead290a35..d5f88d09a 100644
--- a/src/strands/models/__init__.py
+++ b/src/strands/models/__init__.py
@@ -3,8 +3,63 @@
 This package includes an abstract base Model class along with concrete implementations for specific providers.
 """
 
+from typing import Any
+
 from . import bedrock, model
 from .bedrock import BedrockModel
 from .model import Model
 
-__all__ = ["bedrock", "model", "BedrockModel", "Model"]
+__all__ = [
+    "bedrock",
+    "model",
+    "BedrockModel",
+    "Model",
+]
+
+
+def __getattr__(name: str) -> Any:
+    """Lazy load model implementations only when accessed.
+
+    This defers the import of optional dependencies until actually needed.
+    """
+    if name == "AnthropicModel":
+        from .anthropic import AnthropicModel
+
+        return AnthropicModel
+    if name == "GeminiModel":
+        from .gemini import GeminiModel
+
+        return GeminiModel
+    if name == "LiteLLMModel":
+        from .litellm import LiteLLMModel
+
+        return LiteLLMModel
+    if name == "LlamaAPIModel":
+        from .llamaapi import LlamaAPIModel
+
+        return LlamaAPIModel
+    if name == "LlamaCppModel":
+        from .llamacpp import LlamaCppModel
+
+        return LlamaCppModel
+    if name == "MistralModel":
+        from .mistral import MistralModel
+
+        return MistralModel
+    if name == "OllamaModel":
+        from .ollama import OllamaModel
+
+        return OllamaModel
+    if name == "OpenAIModel":
+        from .openai import OpenAIModel
+
+        return OpenAIModel
+    if name == "SageMakerAIModel":
+        from .sagemaker import SageMakerAIModel
+
+        return SageMakerAIModel
+    if name == "WriterModel":
+        from .writer import WriterModel
+
+        return WriterModel
+    raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})")

From 9fd22d18adbb8f55e0a3cf3450db8d39492d4ea2 Mon Sep 17 00:00:00 2001
From: mehtarac
Date: Tue, 6 Jan 2026 13:26:53 -0800
Subject: [PATCH 05/47] add BidiGeminiLiveModel and BidiOpenAIRealtimeModel to
 the init (#1383)

* add BidiGeminiLiveModel and BidiOpenAIRealtimeModel to the init

* Address comments - re-word error message, add return type

* Addressed comments: updated imports

---
 .../experimental/bidi/models/__init__.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py
index cc62c9987..6e5817046 100644
--- a/src/strands/experimental/bidi/models/__init__.py
+++ b/src/strands/experimental/bidi/models/__init__.py
@@ -1,5 +1,7 @@
 """Bidirectional model interfaces and implementations."""
 
+from typing import Any
+
 from .model import BidiModel, BidiModelTimeoutError
 from .nova_sonic import BidiNovaSonicModel
 
@@ -8,3 +10,22 @@
     "BidiModelTimeoutError",
     "BidiNovaSonicModel",
 ]
+
+
+def __getattr__(name: str) -> Any:
+    """
+    Lazy load bidi model implementations only when accessed.
+
+    This defers the import of optional dependencies until actually needed:
+    - BidiGeminiLiveModel requires google-generativeai (lazy loaded)
+    - BidiOpenAIRealtimeModel requires openai (lazy loaded)
+    """
+    if name == "BidiGeminiLiveModel":
+        from .gemini_live import BidiGeminiLiveModel
+
+        return BidiGeminiLiveModel
+    if name == "BidiOpenAIRealtimeModel":
+        from .openai_realtime import BidiOpenAIRealtimeModel
+
+        return BidiOpenAIRealtimeModel
+    raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})")

From 2b1cf6bde77f701f9c6185651ea932e564607e09 Mon Sep 17 00:00:00 2001
From: Patrick Gray
Date: Wed, 7 Jan 2026 10:35:47 -0500
Subject: [PATCH 06/47] bidi - async - remove cancelling call (#1357)

---
 .../experimental/bidi/_async/_task_group.py | 38 +++++++++++--------
 .../bidi/_async/test_task_group.py          | 16 +++++++-
 2 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/src/strands/experimental/bidi/_async/_task_group.py b/src/strands/experimental/bidi/_async/_task_group.py
index 26c67326d..33cf63dca 100644
--- a/src/strands/experimental/bidi/_async/_task_group.py
+++ b/src/strands/experimental/bidi/_async/_task_group.py
@@ -6,17 +6,17 @@
 """
 
 import asyncio
-from typing import Any, Coroutine
+from typing import Any, Coroutine, cast
 
 
 class _TaskGroup:
     """Shim of asyncio.TaskGroup for use in Python 3.10.
 
     Attributes:
-        _tasks: List of tasks in group.
+        _tasks: Set of tasks in group.
     """
 
-    _tasks: list[asyncio.Task]
+    _tasks: set[asyncio.Task]
 
     def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
         """Create an async task and add to group.
@@ -25,12 +25,12 @@ def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
             The created task.
         """
         task = asyncio.create_task(coro)
-        self._tasks.append(task)
+        self._tasks.add(task)
         return task
 
     async def __aenter__(self) -> "_TaskGroup":
         """Setup self managed task group context."""
-        self._tasks = []
+        self._tasks = set()
         return self
 
     async def __aexit__(self, *_: Any) -> None:
@@ -42,20 +42,28 @@ async def __aexit__(self, *_: Any) -> None:
         - The context re-raises CancelledErrors to the caller only if the context itself was cancelled.
         """
         try:
-            await asyncio.gather(*self._tasks)
+            pending_tasks = self._tasks
+            while pending_tasks:
+                done_tasks, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_EXCEPTION)
 
-        except (Exception, asyncio.CancelledError) as error:
+                if any(exception := done_task.exception() for done_task in done_tasks if not done_task.cancelled()):
+                    break
+
+            else:  # all tasks completed/cancelled successfully
+                return
+
+            for pending_task in pending_tasks:
+                pending_task.cancel()
+
+            await asyncio.gather(*pending_tasks, return_exceptions=True)
+            raise cast(BaseException, exception)
+
+        except asyncio.CancelledError:  # context itself was cancelled
             for task in self._tasks:
                 task.cancel()
 
             await asyncio.gather(*self._tasks, return_exceptions=True)
-
-            if not isinstance(error, asyncio.CancelledError):
-                raise
-
-            context_task = asyncio.current_task()
-            if context_task and context_task.cancelling() > 0:  # context itself was cancelled
-                raise
+            raise
 
         finally:
-            self._tasks = []
+            self._tasks = set()
diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py
index 23ff821f9..b9a30ef5b 100644
--- a/tests/strands/experimental/bidi/_async/test_task_group.py
+++ b/tests/strands/experimental/bidi/_async/test_task_group.py
@@ -17,7 +17,7 @@ async def test_task_group__aexit__():
 
 
 @pytest.mark.asyncio
-async def test_task_group__aexit__exception():
+async def test_task_group__aexit__task_exception():
     wait_event = asyncio.Event()
 
     async def wait():
         await wait_event.wait()
@@ -35,7 +35,19 @@ async def fail():
 
 
 @pytest.mark.asyncio
-async def test_task_group__aexit__cancelled():
+async def test_task_group__aexit__task_cancelled():
+    async def wait():
+        asyncio.current_task().cancel()
+        await asyncio.sleep(0)
+
+    async with _TaskGroup() as task_group:
+        wait_task = task_group.create_task(wait())
+
+    assert wait_task.cancelled()
+
+
+@pytest.mark.asyncio
+async def test_task_group__aexit__context_cancelled():
     wait_event = asyncio.Event()
 
     async def wait():
         await wait_event.wait()

From 08bf5638ee9a5cdf78b0eeabcf40a6a33f6c7659 Mon Sep 17 00:00:00 2001
From: Aleksei Iancheruk <113924163+aiancheruk@users.noreply.github.com>
Date: Wed, 7 Jan 2026 16:54:00 +0100
Subject: [PATCH 07/47] feat(bedrock): add guardrail_latest_message option
 (#1224)

* feat(bedrock): add guardrail_last_turn_only option

* fix(bedrock): include assistant response in guardrail_last_turn_only context

* fix: optimize code

* feat: rewrite the logic, include last user message in guardContent when feature flag is true

* fix: remove unnecessary integ tests and simplify guardrail logic

* fix: rename feature flag, remove unnecessary tests, add image to guardcontent block

* fix: simplify logic and make tests more reliable

---------

Co-authored-by: Aleksei Iancheruk
Co-authored-by: Jack Yuan
---
 src/strands/models/bedrock.py          | 22 +++++++++++--
 tests/strands/models/test_bedrock.py   | 44 +++++++++++++++++++++++++
 tests_integ/test_bedrock_guardrails.py | 45 ++++++++++++++++++++++++++
 3 files changed, 109 insertions(+), 2 deletions(-)

diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 08d8f400c..8e1558ca7 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -82,6 +82,8 @@ class BedrockConfig(TypedDict, total=False):
         guardrail_redact_input_message: If a Bedrock Input guardrail triggers, replace the input with this message.
         guardrail_redact_output: Flag to redact output if guardrail is triggered. Defaults to False.
         guardrail_redact_output_message: If a Bedrock Output guardrail triggers, replace output with this message.
+        guardrail_latest_message: Flag to send only the latest user message to guardrails.
+            Defaults to False.
         max_tokens: Maximum number of tokens to generate in the response
         model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0")
         include_tool_result_status: Flag to include status field in tool results.
@@ -105,6 +107,7 @@
     guardrail_redact_input_message: Optional[str]
     guardrail_redact_output: Optional[bool]
     guardrail_redact_output_message: Optional[str]
+    guardrail_latest_message: Optional[bool]
     max_tokens: Optional[int]
     model_id: str
    include_tool_result_status: Optional[Literal["auto"] | bool]
@@ -199,7 +202,6 @@ def _format_request(
         Args:
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
-            system_prompt: System prompt to provide context to the model.
             tool_choice: Selection strategy for tool invocation.
             system_prompt_content: System prompt content blocks to provide context to the model.
@@ -302,6 +304,7 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
         - Filtering out SDK_UNKNOWN_MEMBER content blocks
         - Eagerly filtering content blocks to only include Bedrock-supported fields
         - Ensuring all message content blocks are properly formatted for the Bedrock API
+        - Optionally wrapping the last user message in guardrailConverseContent blocks
 
         Args:
            messages: List of messages to format
@@ -321,7 +324,9 @@
         filtered_unknown_members = False
         dropped_deepseek_reasoning_content = False
 
-        for message in messages:
+        guardrail_latest_message = self.config.get("guardrail_latest_message", False)
+
+        for idx, message in enumerate(messages):
             cleaned_content: list[dict[str, Any]] = []
 
             for content_block in message["content"]:
@@ -338,6 +343,19 @@
 
                 # Format content blocks for Bedrock API compatibility
                 formatted_content = self._format_request_message_content(content_block)
+
+                # Wrap text or image content in guardContent if this is the last user message
+                if (
+                    guardrail_latest_message
+                    and idx == len(messages) - 1
+                    and message["role"] == "user"
+                    and ("text" in formatted_content or "image" in formatted_content)
+                ):
+                    if "text" in formatted_content:
+                        formatted_content = {"guardContent": {"text": {"text": formatted_content["text"]}}}
+                    elif "image" in formatted_content:
+                        formatted_content = {"guardContent": {"image": formatted_content["image"]}}
+
                 cleaned_content.append(formatted_content)
 
             # Create new message with cleaned content (skip if empty)
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 33be44b1b..7697c5e03 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -2196,3 +2196,47 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client
         "(documentChar, documentPage, documentChunk, searchResultLocation, or web) "
         "with the location fields nested inside."
     )
+
+
+@pytest.mark.asyncio
+async def test_format_request_with_guardrail_latest_message(model):
+    """Test that guardrail_latest_message wraps the latest user message with text and image."""
+    model.update_config(
+        guardrail_id="test-guardrail",
+        guardrail_version="DRAFT",
+        guardrail_latest_message=True,
+    )
+
+    messages = [
+        {"role": "user", "content": [{"text": "First message"}]},
+        {"role": "assistant", "content": [{"text": "First response"}]},
+        {
+            "role": "user",
+            "content": [
+                {"text": "Look at this image"},
+                {"image": {"format": "png", "source": {"bytes": b"fake_image_data"}}},
+            ],
+        },
+    ]
+
+    request = model._format_request(messages)
+    formatted_messages = request["messages"]
+
+    # All messages should be in the request
+    assert len(formatted_messages) == 3
+
+    # First user message should NOT be wrapped
+    assert "text" in formatted_messages[0]["content"][0]
+    assert formatted_messages[0]["content"][0]["text"] == "First message"
+
+    # Assistant message should NOT be wrapped
+    assert "text" in formatted_messages[1]["content"][0]
+    assert formatted_messages[1]["content"][0]["text"] == "First response"
+
+    # Latest user message text should be wrapped
+    assert "guardContent" in formatted_messages[2]["content"][0]
+    assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "Look at this image"
+
+    # Latest user message image should also be wrapped
+    assert "guardContent" in formatted_messages[2]["content"][1]
+    assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png"
diff --git a/tests_integ/test_bedrock_guardrails.py b/tests_integ/test_bedrock_guardrails.py
index 37fa6028c..058597026 100644
--- a/tests_integ/test_bedrock_guardrails.py
+++ b/tests_integ/test_bedrock_guardrails.py
@@ -289,6 +289,51 @@ def list_users() -> str:
     assert tool_result["content"][0]["text"] == INPUT_REDACT_MESSAGE
 
 
+def test_guardrail_latest_message(boto_session, bedrock_guardrail, yellow_img):
+    """Test that guardrail_latest_message wraps both text and image in the latest user message."""
+    bedrock_model = BedrockModel(
+        guardrail_id=bedrock_guardrail,
+        guardrail_version="DRAFT",
+        guardrail_latest_message=True,
+        boto_session=boto_session,
+    )
+
+    # Create agent with valid content
+    agent1 = Agent(
+        model=bedrock_model,
+        system_prompt="You are a helpful assistant.",
+        callback_handler=None,
+        messages=[
+            {"role": "user", "content": [{"text": "First message"}]},
+            {"role": "assistant", "content": [{"text": "Hello!"}]},
+        ],
+    )
+
+    response = agent1("What do you see?")
+    assert response.stop_reason != "guardrail_intervened"
+
+    # Create agent with multimodal content in latest user message
+    agent2 = Agent(
+        model=bedrock_model,
+        system_prompt="You are a helpful assistant.",
+        callback_handler=None,
+        messages=[
+            {"role": "user", "content": [{"text": "First message"}]},
+            {"role": "assistant", "content": [{"text": "Hello!"}]},
+            {
+                "role": "user",
+                "content": [
+                    {"text": "CACTUS"},
+                    {"image": {"format": "png", "source": {"bytes": yellow_img}}},
+                ],
+            },
+        ],
+    )
+
+    response = agent2("What do you see?")
+    assert response.stop_reason == "guardrail_intervened"
+
+
 def test_guardrail_input_intervention_properly_redacts_in_session(boto_session, bedrock_guardrail, temp_dir):
     bedrock_model = BedrockModel(
         guardrail_id=bedrock_guardrail,

From 1e27d79bd8e6e13f7a77ccd391e4a581415ab90d Mon Sep 17 00:00:00 2001
From: Evan Mattiza
Date: Wed, 7 Jan 2026 11:44:28 -0600
Subject: [PATCH 08/47] fix(gemini): Gemini UnboundLocal Exception raised during
 stream (#1420)

---
 src/strands/models/gemini.py        |  5 ++++-
 tests/strands/models/test_gemini.py | 19 +++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py
index cf7cc604a..45f7f4e18 100644
--- a/src/strands/models/gemini.py
+++ b/src/strands/models/gemini.py
@@ -426,6 +426,8 @@ async def stream(
                 yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"})
 
                 tool_used = False
+                candidate = None
+                event = None
                 async for event in response:
                     candidates = event.candidates
                     candidate = candidates[0] if candidates else None
@@ -455,7 +457,8 @@ async def stream(
                         "data": "TOOL_USE" if tool_used else (candidate.finish_reason if candidate else "STOP"),
                     }
                 )
-                yield self._format_chunk({"chunk_type": "metadata", "data": event.usage_metadata})
+                if event:
+                    yield self._format_chunk({"chunk_type": "metadata", "data": event.usage_metadata})
 
         except genai.errors.ClientError as error:
             if not error.message:
diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py
index c552a892a..08be9188d 100644
--- a/tests/strands/models/test_gemini.py
+++ b/tests/strands/models/test_gemini.py
@@ -566,6 +566,25 @@ async def test_stream_response_none_candidates(gemini_client, model, messages, a
     assert tru_chunks == exp_chunks
 
 
+@pytest.mark.asyncio
+async def test_stream_response_empty_stream(gemini_client, model, messages, agenerator, alist):
+    """Test that empty stream doesn't raise UnboundLocalError.
+
+    When the stream yields no events, the candidate variable must be initialized
+    to None to avoid UnboundLocalError when referenced in message_stop chunk.
+    """
+    gemini_client.aio.models.generate_content_stream.return_value = agenerator([])
+
+    tru_chunks = await alist(model.stream(messages))
+    exp_chunks = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockStart": {"start": {}}},
+        {"contentBlockStop": {}},
+        {"messageStop": {"stopReason": "end_turn"}},
+    ]
+    assert tru_chunks == exp_chunks
+
+
 @pytest.mark.asyncio
 async def test_stream_response_throttled_exception(gemini_client, model, messages):
     gemini_client.aio.models.generate_content_stream.side_effect = genai.errors.ClientError(

From 2f04bc0f9c786e6afa0837819d55accaa68b6896 Mon Sep 17 00:00:00 2001
From: schleidl
Date: Thu, 8 Jan 2026 16:50:24 +0100
Subject: [PATCH 09/47] feat(litellm): handle litellm non streaming responses
 (#512)

---------

Co-authored-by: Daniel Schleicher
Co-authored-by: Dean Schmigelski
---
 src/strands/models/litellm.py            | 253 ++++++++++++++++------
 tests/strands/models/test_litellm.py     | 255 ++++++++++++++++++++++-
 tests_integ/models/test_model_litellm.py |  30 ++-
 3 files changed, 461 insertions(+), 77 deletions(-)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 1f1e999d2..c120b0eda 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -269,75 +269,29 @@ async def stream(
         )
         logger.debug("request=<%s>", request)
 
-        logger.debug("invoking model")
-        try:
-            if kwargs.get("stream") is False:
-                raise ValueError("stream parameter cannot be explicitly set to False")
-            response = await litellm.acompletion(**self.client_args, **request)
-        except ContextWindowExceededError as e:
-            logger.warning("litellm client raised context window overflow")
-            raise ContextWindowOverflowException(e) from e
+        # Check if streaming is disabled in the params
+        config = self.get_config()
+        params = config.get("params") or {}
+        is_streaming = params.get("stream", True)
 
-        logger.debug("got response from model")
-        yield self.format_chunk({"chunk_type": "message_start"})
+        litellm_request = {**request}
 
-        tool_calls: dict[int, list[Any]] = {}
-        data_type: str | None = None
+        litellm_request["stream"] = is_streaming
 
-        async for event in response:
-            # Defensive: skip events with empty or missing choices
-            if not getattr(event, "choices", None):
-                continue
-            choice = event.choices[0]
+        logger.debug("invoking model with stream=%s", litellm_request.get("stream"))
 
-            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
-                chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
-                for chunk in chunks:
+        try:
+            if is_streaming:
+                async for chunk in self._handle_streaming_response(litellm_request):
                     yield chunk
-
-                yield self.format_chunk(
-                    {
-                        "chunk_type": "content_delta",
-                        "data_type": data_type,
-                        "data": choice.delta.reasoning_content,
-                    }
-                )
-
-            if choice.delta.content:
-                chunks, data_type = self._stream_switch_content("text", data_type)
-                for chunk in chunks:
+            else:
+                async for chunk in self._handle_non_streaming_response(litellm_request):
                     yield chunk
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow")
+            raise ContextWindowOverflowException(e) from e
 
-                yield self.format_chunk(
-                    {"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
-                )
-
-            for tool_call in choice.delta.tool_calls or []:
-                tool_calls.setdefault(tool_call.index, []).append(tool_call)
-
-            if choice.finish_reason:
-                if data_type:
-                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
-                break
-
-        for tool_deltas in tool_calls.values():
-            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
-
-            for tool_delta in tool_deltas:
-                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
-
-            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
-
-        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
-
-        # Skip remaining events as we don't have use for anything except the final usage payload
-        async for event in response:
-            _ = event
-
-        if event.usage:
-            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
-
-        logger.debug("finished streaming response from model")
+        logger.debug("finished processing response from model")
 
     @override
     async def structured_output(
@@ -422,6 +376,181 @@ async def _structured_output_using_tool(
         except (json.JSONDecodeError, TypeError, ValueError) as e:
             raise ValueError(f"Failed to parse or load content into model: {e}") from e
 
+    async def _process_choice_content(
+        self, choice: Any, data_type: str | None, tool_calls: dict[int, list[Any]], is_streaming: bool = True
+    ) -> AsyncGenerator[tuple[str | None, StreamEvent], None]:
+        """Process content from a choice object (streaming or non-streaming).
+
+        Args:
+            choice: The choice object from the response.
+            data_type: Current data type being processed.
+            tool_calls: Dictionary to collect tool calls.
+            is_streaming: Whether this is from a streaming response.
+
+        Yields:
+            Tuples of (updated_data_type, stream_event).
+        """
+        # Get the content source - this is the only difference between streaming/non-streaming
+        # We use duck typing here: both choice.delta and choice.message have the same interface
+        # (reasoning_content, content, tool_calls attributes) but different object structures
+        content_source = choice.delta if is_streaming else choice.message
+
+        # Process reasoning content
+        if hasattr(content_source, "reasoning_content") and content_source.reasoning_content:
+            chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
+            for chunk in chunks:
+                yield data_type, chunk
+            chunk = self.format_chunk(
+                {
+                    "chunk_type": "content_delta",
+                    "data_type": "reasoning_content",
+                    "data": content_source.reasoning_content,
+                }
+            )
+            yield data_type, chunk
+
+        # Process text content
+        if hasattr(content_source, "content") and content_source.content:
+            chunks, data_type = self._stream_switch_content("text", data_type)
+            for chunk in chunks:
+                yield data_type, chunk
+            chunk = self.format_chunk(
+                {
+                    "chunk_type": "content_delta",
+                    "data_type": "text",
+                    "data": content_source.content,
+                }
+            )
+            yield data_type, chunk
+
+        # Process tool calls
+        if hasattr(content_source, "tool_calls") and content_source.tool_calls:
+            if is_streaming:
+                # Streaming: tool calls have index attribute for out-of-order delivery
+                for tool_call in content_source.tool_calls:
+                    tool_calls.setdefault(tool_call.index, []).append(tool_call)
+            else:
+                # Non-streaming: tool calls arrive in order, use enumerated index
+                for i, tool_call in enumerate(content_source.tool_calls):
+                    tool_calls.setdefault(i, []).append(tool_call)
+
+    async def _process_tool_calls(self, tool_calls: dict[int, list[Any]]) -> AsyncGenerator[StreamEvent, None]:
+        """Process and yield tool call events.
+
+        Args:
+            tool_calls: Dictionary of tool calls indexed by their position.
+
+        Yields:
+            Formatted tool call chunks.
+        """
+        for tool_deltas in tool_calls.values():
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
+
+            for tool_delta in tool_deltas:
+                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+    async def _handle_non_streaming_response(
+        self, litellm_request: dict[str, Any]
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Handle non-streaming response from LiteLLM.
+
+        Args:
+            litellm_request: The formatted request for LiteLLM.
+
+        Yields:
+            Formatted message chunks from the model.
+        """
+        response = await litellm.acompletion(**self.client_args, **litellm_request)
+
+        logger.debug("got non-streaming response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+
+        tool_calls: dict[int, list[Any]] = {}
+        data_type: str | None = None
+        finish_reason: str | None = None
+
+        if hasattr(response, "choices") and response.choices and len(response.choices) > 0:
+            choice = response.choices[0]
+
+            if hasattr(choice, "message") and choice.message:
+                # Process content using shared logic
+                async for updated_data_type, chunk in self._process_choice_content(
+                    choice, data_type, tool_calls, is_streaming=False
+                ):
+                    data_type = updated_data_type
+                    yield chunk
+
+            if hasattr(choice, "finish_reason"):
+                finish_reason = choice.finish_reason
+
+        # Stop the current content block if we have one
+        if data_type:
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
+
+        # Process tool calls
+        async for chunk in self._process_tool_calls(tool_calls):
+            yield chunk
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+        # Add usage information if available
+        if hasattr(response, "usage"):
+            yield self.format_chunk({"chunk_type": "metadata", "data": response.usage})
+
+    async def _handle_streaming_response(self, litellm_request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
+        """Handle streaming response from LiteLLM.
+
+        Args:
+            litellm_request: The formatted request for LiteLLM.
+
+        Yields:
+            Formatted message chunks from the model.
+        """
+        # For streaming, use the streaming API
+        response = await litellm.acompletion(**self.client_args, **litellm_request)
+
+        logger.debug("got response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+
+        tool_calls: dict[int, list[Any]] = {}
+        data_type: str | None = None
+        finish_reason: str | None = None
+
+        async for event in response:
+            # Defensive: skip events with empty or missing choices
+            if not getattr(event, "choices", None):
+                continue
+            choice = event.choices[0]
+
+            # Process content using shared logic
+            async for updated_data_type, chunk in self._process_choice_content(
+                choice, data_type, tool_calls, is_streaming=True
+            ):
+                data_type = updated_data_type
+                yield chunk
+
+            if choice.finish_reason:
+                finish_reason = choice.finish_reason
+                if data_type:
+                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
+                break
+
+        # Process tool calls
+        async for chunk in self._process_tool_calls(tool_calls):
+            yield chunk
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+        # Skip remaining events as we don't have use for anything except the final usage payload
+        async for event in response:
+            _ = event
+            if event.usage:
+                yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+
+        logger.debug("finished streaming response from model")
+
     def _apply_proxy_prefix(self) -> None:
         """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
 
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 832b5c836..99df22a3f 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -285,7 +285,7 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene
 
     mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
     mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
-    mock_event_3 = unittest.mock.Mock()
+    mock_event_3 = unittest.mock.Mock(usage=None)
     mock_event_4 = unittest.mock.Mock(usage=None)
 
     litellm_acompletion.side_effect = unittest.mock.AsyncMock(
@@ -408,16 +408,6 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model
             pass
 
 
-@pytest.mark.asyncio
-async def test_stream_raises_error_when_stream_is_false(model):
-    """Test that stream raises ValueError when stream parameter is explicitly False."""
-    messages = [{"role": "user", "content": [{"text": "test"}]}]
-
-    with pytest.raises(ValueError, match="stream parameter cannot be explicitly set to False"):
-        async for _ in model.stream(messages, stream=False):
-            pass
-
-
 def test_format_request_messages_with_system_prompt_content():
     """Test format_request_messages with system_prompt_content parameter."""
     messages = [{"role": "user", "content": [{"text": "Hello"}]}]
@@ -478,3 +468,246 @@ def test_format_request_messages_cache_point_support():
     ]
 
     assert result == expected
+
+
+@pytest.mark.asyncio
+async def test_stream_non_streaming(litellm_acompletion, api_key, model_id, alist):
+    """Test LiteLLM model with streaming disabled (stream=False).
+
+    This test verifies that the LiteLLM model works correctly when streaming is disabled,
+    which was the issue reported in GitHub issue #477.
+    """
+
+    mock_function = unittest.mock.Mock()
+    mock_function.name = "calculator"
+    mock_function.arguments = '{"expression": "123981723 + 234982734"}'
+
+    mock_tool_call = unittest.mock.Mock(index=0, function=mock_function, id="tool_call_id_123")
+
+    mock_message = unittest.mock.Mock()
+    mock_message.content = "I'll calculate that for you"
+    mock_message.reasoning_content = "Let me think about this calculation"
+    mock_message.tool_calls = [mock_tool_call]
+
+    mock_choice = unittest.mock.Mock()
+    mock_choice.message = mock_message
+    mock_choice.finish_reason = "tool_calls"
+
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+
+    # Create a more explicit usage mock that doesn't have cache-related attributes
+    mock_usage = unittest.mock.Mock()
+    mock_usage.prompt_tokens = 10
+    mock_usage.completion_tokens = 20
+    mock_usage.total_tokens = 30
+    mock_usage.prompt_tokens_details = None
+    mock_usage.cache_creation_input_tokens = None
+    mock_response.usage = mock_usage
+
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response)
+
+    model = LiteLLMModel(
+        client_args={"api_key": api_key},
+        model_id=model_id,
+        params={"stream": False},  # This is the key setting that was causing the #477 issue
+    )
+
+    messages = [{"role": "user", "content": [{"type": "text", "text": "What is 123981723 + 234982734?"}]}]
+    response = model.stream(messages)
+
+    tru_events = await alist(response)
+
+    exp_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockStart": {"start": {}}},
+        {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Let me think about this calculation"}}}},
+        {"contentBlockStop": {}},
+        {"contentBlockStart": {"start": {}}},
+        {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}},
+        {"contentBlockStop": {}},
+        {
+            "contentBlockStart": {
+                "start": {"toolUse": {"name": "calculator", "toolUseId": mock_message.tool_calls[0].id}}
+            }
+        },
+        {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "123981723 + 234982734"}'}}}},
+        {"contentBlockStop": {}},
+        {"messageStop": {"stopReason": "tool_use"}},
+        {
+            "metadata": {
+                "usage": {
+                    "inputTokens": 10,
+                    "outputTokens": 20,
+                    "totalTokens": 30,
+                },
+                "metrics": {"latencyMs": 0},
+            }
+        },
+    ]
+
+    assert len(tru_events) == len(exp_events)
+
+    for i, (tru, exp) in enumerate(zip(tru_events, exp_events, strict=False)):
+        assert tru == exp, f"Event {i} mismatch: {tru} != {exp}"
+
+    expected_request = {
+        "api_key": api_key,
+        "model": model_id,
+        "messages": [{"role": "user", "content": [{"text": "What is 123981723 + 234982734?", "type": "text"}]}],
+        "stream": False,  # Verify that stream=False was passed to litellm
+        "stream_options": {"include_usage": True},
+        "tools": [],
+    }
+    litellm_acompletion.assert_called_once_with(**expected_request)
+
+
+@pytest.mark.asyncio
+async def test_stream_path_validation(litellm_acompletion, api_key, model_id, model, agenerator, alist):
+    """Test that we're taking the correct streaming path and validate stream parameter."""
+    mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
+    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
+    mock_event_2 = unittest.mock.Mock(usage=None)
+
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=agenerator([mock_event_1, mock_event_2]))
+
+    messages = [{"role": "user", "content": []}]
+    response = model.stream(messages)
+
+    # Consume the response
+    await alist(response)
+
+    # Validate that litellm.acompletion was called with the expected parameters
+    call_args = litellm_acompletion.call_args
+    assert call_args is not None, "litellm.acompletion should have been called"
+
+    # Check if stream parameter is being set
+    called_kwargs = call_args.kwargs
+
+    # Validate we're going down the streaming path (should have stream=True)
+    assert called_kwargs.get("stream") is True, f"Expected stream=True, got {called_kwargs.get('stream')}"
+
+
+def test_format_request_message_content_reasoning():
+    """Test formatting reasoning content."""
+    content = {"reasoningContent": {"reasoningText": {"signature": "test_sig", "text": "test_thinking"}}}
+
+    result = LiteLLMModel.format_request_message_content(content)
+    expected = {"signature": "test_sig", "thinking": "test_thinking", "type": "thinking"}
+
+    assert result == expected
+
+
+def test_format_request_message_content_video():
+    """Test formatting video content."""
+    content = {"video": {"source": {"bytes": "base64videodata"}}}
+
+    result = LiteLLMModel.format_request_message_content(content)
+    expected = {"type": "video_url", "video_url": {"detail": "auto", "url": "base64videodata"}}
+
+    assert result == expected
+
+
+def test_apply_proxy_prefix_with_use_litellm_proxy():
+    """Test _apply_proxy_prefix when use_litellm_proxy is True."""
+    model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="openai/gpt-4")
+
+    assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4"
+
+
+def test_apply_proxy_prefix_already_has_prefix():
+    """Test _apply_proxy_prefix when model_id already has prefix."""
+    model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="litellm_proxy/openai/gpt-4")
+
+    # Should not add another prefix
+    assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4"
+
+
+def test_apply_proxy_prefix_disabled():
+    """Test _apply_proxy_prefix when use_litellm_proxy is False."""
+    model = LiteLLMModel(client_args={"use_litellm_proxy": False}, model_id="openai/gpt-4")
+
+    assert model.get_config()["model_id"] == "openai/gpt-4"
+
+
+def test_format_chunk_metadata_with_cache_tokens():
+    """Test format_chunk for metadata with cache tokens."""
+    model = LiteLLMModel(model_id="test")
+
+    # Mock usage data with cache tokens
+    mock_usage = unittest.mock.Mock()
+    mock_usage.prompt_tokens = 100
+    mock_usage.completion_tokens = 50
+    mock_usage.total_tokens = 150
+
+    # Mock cache-related attributes
+    mock_tokens_details = unittest.mock.Mock()
+    mock_tokens_details.cached_tokens = 25
+    mock_usage.prompt_tokens_details = mock_tokens_details
+    mock_usage.cache_creation_input_tokens = 10
+
+    event = {"chunk_type": "metadata", "data": mock_usage}
+
+    result = model.format_chunk(event)
+
+    assert result["metadata"]["usage"]["inputTokens"] == 100
+    assert result["metadata"]["usage"]["outputTokens"] == 50
+    assert result["metadata"]["usage"]["totalTokens"] == 150
+    assert result["metadata"]["usage"]["cacheReadInputTokens"] == 25
+    assert result["metadata"]["usage"]["cacheWriteInputTokens"] == 10
+
+
+def test_format_chunk_metadata_without_cache_tokens():
+    """Test format_chunk for metadata without cache tokens."""
+    model = LiteLLMModel(model_id="test")
+
+    # Mock usage data without cache tokens
+    mock_usage = unittest.mock.Mock()
+    mock_usage.prompt_tokens = 100
+    mock_usage.completion_tokens = 50
+    mock_usage.total_tokens = 150
+    mock_usage.prompt_tokens_details = None
+    mock_usage.cache_creation_input_tokens = None
+
+    event = {"chunk_type": "metadata", "data": mock_usage}
+
+    result = model.format_chunk(event)
+
+    assert result["metadata"]["usage"]["inputTokens"] == 100
+    assert result["metadata"]["usage"]["outputTokens"] == 50
+    assert result["metadata"]["usage"]["totalTokens"] == 150
+    assert "cacheReadInputTokens" not in result["metadata"]["usage"]
+    assert "cacheWriteInputTokens" not in result["metadata"]["usage"]
+
+
+def test_stream_switch_content_same_type():
+    """Test _stream_switch_content when data_type is the same as prev_data_type."""
+    model = LiteLLMModel(model_id="test")
+
+    chunks, data_type = model._stream_switch_content("text", "text")
+
+    assert chunks == []
+    assert data_type == "text"
+
+
+def test_stream_switch_content_different_type_with_prev():
+    """Test _stream_switch_content when switching from one type to another."""
+    model = LiteLLMModel(model_id="test")
+
+    chunks, data_type = model._stream_switch_content("text", "reasoning_content")
+
+    assert len(chunks) == 2
+    assert chunks[0]["contentBlockStop"] == {}
+    assert chunks[1]["contentBlockStart"] == {"start": {}}
+    assert data_type == "text"
+
+
+def test_stream_switch_content_different_type_no_prev():
+    """Test _stream_switch_content when switching to a type with no previous type."""
+    model = LiteLLMModel(model_id="test")
+
+    chunks, data_type = model._stream_switch_content("text", None)
+
+    assert len(chunks) == 1
+    assert chunks[0]["contentBlockStart"] == {"start": {}}
+    assert data_type == "text"
diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py
index d72937641..80e21bdfd 100644
--- a/tests_integ/models/test_model_litellm.py
+++ b/tests_integ/models/test_model_litellm.py
@@ -14,6 +14,16 @@ def model():
     return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
 
 
+@pytest.fixture
+def streaming_model():
+    return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0", params={"stream": True})
+
+
+@pytest.fixture
+def non_streaming_model():
+    return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0", params={"stream": False})
+
+
 @pytest.fixture
 def tools():
     @strands.tool
@@ -95,15 +105,21 @@ def lower(_, value):
     return Color(simple_color_name="yellow")
 
 
-def test_agent_invoke(agent):
+@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"])
+def test_agent_invoke(model_fixture, tools, request):
+    model = request.getfixturevalue(model_fixture)
+    agent = Agent(model=model, tools=tools)
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
 
     assert all(string in text for string in ["12:00", "sunny"])
 
 
+@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"])
 @pytest.mark.asyncio
-async def test_agent_invoke_async(agent):
+async def test_agent_invoke_async(model_fixture, tools, request):
+    model = request.getfixturevalue(model_fixture)
+    agent = Agent(model=model, tools=tools)
     result = await agent.invoke_async("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
 
@@ -138,14 +154,20 @@ def test_agent_invoke_reasoning(agent, model):
     assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"]
 
 
-def test_structured_output(agent, weather):
+@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"])
+def test_structured_output(model_fixture, weather, request):
+    model = request.getfixturevalue(model_fixture)
+    agent = Agent(model=model)
     tru_weather = 
agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") exp_weather = weather assert tru_weather == exp_weather +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) @pytest.mark.asyncio -async def test_agent_structured_output_async(agent, weather): +async def test_agent_structured_output_async(model_fixture, weather, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") exp_weather = weather assert tru_weather == exp_weather From 0ef228878bde27f361f4aa9ecb1d962986ee4cbf Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Fri, 9 Jan 2026 17:19:06 +0100 Subject: [PATCH 10/47] feat(agent): introduce AgentBase Protocol as the interface for agent classes to implement (#1126) --- src/strands/__init__.py | 2 ++ src/strands/agent/__init__.py | 2 ++ src/strands/agent/agent.py | 2 +- src/strands/agent/base.py | 66 +++++++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 src/strands/agent/base.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 3718a29c5..bc17497a0 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -2,11 +2,13 @@ from . import agent, models, telemetry, types from .agent.agent import Agent +from .agent.base import AgentBase from .tools.decorator import tool from .types.tools import ToolContext __all__ = [ "Agent", + "AgentBase", "agent", "models", "tool", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 6618d3328..c00623dc2 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -8,6 +8,7 @@ from .agent import Agent from .agent_result import AgentResult +from .base import AgentBase from .conversation_manager import ( ConversationManager, NullConversationManager, @@ -17,6 +18,7 @@ __all__ = [ "Agent", + "AgentBase", "AgentResult", "ConversationManager", "NullConversationManager", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9e726ca0b..c4ebc0b54 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -87,7 +87,7 @@ class _DefaultCallbackHandlerSentinel: class Agent: - """Core Agent interface. + """Core Agent implementation. An agent orchestrates the following workflow: diff --git a/src/strands/agent/base.py b/src/strands/agent/base.py new file mode 100644 index 000000000..b35ade8c4 --- /dev/null +++ b/src/strands/agent/base.py @@ -0,0 +1,66 @@ +"""Agent Interface. + +Defines the minimal interface that all agent types must implement. +""" + +from typing import Any, AsyncIterator, Protocol, runtime_checkable + +from ..types.agent import AgentInput +from .agent_result import AgentResult + + +@runtime_checkable +class AgentBase(Protocol): + """Protocol defining the interface for all agent types in Strands. + + This protocol defines the minimal contract that all agent implementations + must satisfy. + """ + + async def invoke_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Asynchronously invoke the agent with the given prompt. + + Args: + prompt: Input to the agent. + **kwargs: Additional arguments. + + Returns: + AgentResult containing the agent's response. + """ + ... + + def __call__( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Synchronously invoke the agent with the given prompt. 
+ + Args: + prompt: Input to the agent. + **kwargs: Additional arguments. + + Returns: + AgentResult containing the agent's response. + """ + ... + + def stream_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + """Stream agent execution asynchronously. + + Args: + prompt: Input to the agent. + **kwargs: Additional arguments. + + Yields: + Events representing the streaming execution. + """ + ... From 10a8e4a1fdfabebe7341ab5d1584731657e15ddc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 12:04:39 -0500 Subject: [PATCH 11/47] ci: update pytest requirement from <9.0.0,>=8.0.0 to >=8.0.0,<10.0.0 in the dev-dependencies group (#1161) * ci: update pytest requirement in the dev-dependencies group Updates the requirements on [pytest](https://github.com/pytest-dev/pytest) to permit the latest version. Updates `pytest` to 9.0.0 - [Release notes](https://github.com/pytest-dev/pytest/releases) - [Changelog](https://github.com/pytest-dev/pytest/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pytest-dev/pytest/compare/8.0.0...9.0.0) --- updated-dependencies: - dependency-name: pytest dependency-version: 9.0.0 dependency-type: direct:development dependency-group: dev-dependencies ... Signed-off-by: dependabot[bot] * bump pytest version floor --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Dean Schmigelski --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 040babe67..05a385ca9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ dev = [ "moto>=5.1.0,<6.0.0", "mypy>=1.15.0,<2.0.0", "pre-commit>=3.2.0,<4.6.0", - "pytest>=8.0.0,<9.0.0", + "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", "pytest-xdist>=3.0.0,<4.0.0", @@ -142,7 +142,7 @@ installer = "uv" features = ["all"] extra-args = ["-n", "auto", "-vv"] dependencies = [ - "pytest>=8.0.0,<9.0.0", + "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", "pytest-xdist>=3.0.0,<4.0.0", From cd6570b197a13957a5dc9d42ce367c20bbd6875d Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Mon, 12 Jan 2026 10:27:02 -0800 Subject: [PATCH 12/47] feat(models): pass invocation_state to model providers (#1414) --------- Co-authored-by: Tirth Patel Co-authored-by: Dean Schmigelski --- src/strands/event_loop/event_loop.py | 1 + src/strands/event_loop/streaming.py | 3 +++ src/strands/models/model.py | 2 ++ tests/strands/agent/test_agent.py | 10 +++++++++- tests/strands/event_loop/test_event_loop.py | 1 + tests/strands/event_loop/test_streaming.py | 5 +++++ .../event_loop/test_streaming_structured_output.py | 2 ++ 7 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index fcb530a0d..231cfa56a 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -345,6 +345,7 @@ async def _handle_model_execution( tool_specs, system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, + invocation_state=invocation_state, ): yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 804f90a1d..7840bfcef 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -425,6 +425,7 @@ async def 
stream_messages( *, tool_choice: Optional[Any] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, + invocation_state: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -437,6 +438,7 @@ async def stream_messages( tool_choice: Optional tool choice constraint for forcing specific tool usage. system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -453,6 +455,7 @@ async def stream_messages( system_prompt, tool_choice=tool_choice, system_prompt_content=system_prompt_content, + invocation_state=invocation_state, ) async for event in process_stream(chunks, start_time): diff --git a/src/strands/models/model.py b/src/strands/models/model.py index b2fa73802..6b7dd78d7 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -73,6 +73,7 @@ def stream( *, tool_choice: ToolChoice | None = None, system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: """Stream conversation with the model. @@ -89,6 +90,7 @@ def stream( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. system_prompt_content: System prompt content blocks for advanced features like caching. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index f133400a8..351eadc84 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -36,7 +36,11 @@ @pytest.fixture def mock_model(request): async def stream(*args, **kwargs): - result = mock.mock_stream(*copy.deepcopy(args), **copy.deepcopy(kwargs)) + # Skip deep copy of invocation_state which contains non-serializable objects (agent, spans, etc.) 
+ copied_kwargs = { + key: value if key == "invocation_state" else copy.deepcopy(value) for key, value in kwargs.items() + } + result = mock.mock_stream(*copy.deepcopy(args), **copied_kwargs) # If result is already an async generator, yield from it if hasattr(result, "__aiter__"): async for item in result: @@ -325,6 +329,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, ), unittest.mock.call( [ @@ -363,6 +368,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, ), ], ) @@ -484,6 +490,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -629,6 +636,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) assert conversation_manager_spy.reduce_context.call_count == 2 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 6b23bd592..639e60ea0 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -383,6 +383,7 @@ async def test_event_loop_cycle_tool_result( "p1", tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c6e44b78a..b2cc152cb 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -1117,6 +1117,7 @@ async def test_stream_messages(agenerator, alist): "test prompt", tool_choice=None, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) @@ -1150,6 +1151,7 @@ async def test_stream_messages_with_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1183,6 +1185,7 @@ async def test_stream_messages_single_text_block_backwards_compatibility(agenera "You are a helpful assistant.", tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1214,6 +1217,7 @@ async def test_stream_messages_empty_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=[], + invocation_state=None, ) @@ -1245,6 +1249,7 @@ async def test_stream_messages_none_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=None, + invocation_state=None, ) # Ensure that we're getting typed events coming out of process_stream diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py index 4645e1724..4c4082c00 100644 --- a/tests/strands/event_loop/test_streaming_structured_output.py +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -66,6 +66,7 @@ async def test_stream_messages_with_tool_choice(agenerator, alist): "test prompt", tool_choice=tool_choice, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) # Verify we get the expected events @@ -131,6 +132,7 @@ async def test_stream_messages_with_forced_structured_output(agenerator, alist): "Extract 
user information", tool_choice=tool_choice, system_prompt_content=[{"text": "Extract user information"}], + invocation_state=None, ) assert len(tru_events) > 0 From 37d0e470e6dcc796cb1c40a388f3b2ba432cc5f2 Mon Sep 17 00:00:00 2001 From: Jonathan Segev Date: Mon, 12 Jan 2026 13:30:48 -0500 Subject: [PATCH 13/47] Add Security.md file (#1454) * Create SECURITY.md --- SECURITY.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..b520ee1fb --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,20 @@ +# Security Policy + +## Supported Versions + +| Version | Supported | +| ------- | ------------------ | +| 1.x.x | :white_check_mark: | +| < 1.0 | :x: | + +## Reporting Security Issues + +Amazon Web Services (AWS) is dedicated to the responsible disclosure of security vulnerabilities. + +We kindly ask that you **do not** open a public GitHub issue to report security concerns. + +Instead, please submit the issue to the AWS Vulnerability Disclosure Program via [HackerOne](https://hackerone.com/aws_vdp) or send your report via [email](mailto:aws-security@amazon.com). + +For more details, visit the [AWS Vulnerability Reporting Page](http://aws.amazon.com/security/vulnerability-reporting/). + +Thank you in advance for collaborating with us to help protect our customers. From 845c6f76f84975f20d17a32ad02fbb86c3d695d1 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Mon, 12 Jan 2026 13:58:53 -0500 Subject: [PATCH 14/47] chore: Update release notes sop (#1456) Update the release notes SOP that was most effective after several iterations of having it generate the v1.21.0 release notes - final flow is at https://github.com/zastrowm/sdk-python/issues/19. The flow on this was tweaking the SOP and then running the agent, and attempting to have it update in response. The biggest problems to solve was that: - It blindly trusted examples in PRs, resulting in examples that were wrong because of stale PRs - It would not generate useful tests I do not have the step-by-steps results with the different prompt variants as I deleted the older issues to avoid too many references being added to old PRs, but viewing the [final issue with results](https://github.com/zastrowm/sdk-python/issues/19) shows the SOP results with one shot run after verifying which features should go in. --------- Co-authored-by: Mackenzie Zastrow --- .github/agent-sops/task-release-notes.sop.md | 462 +++++++++++++------ 1 file changed, 324 insertions(+), 138 deletions(-) diff --git a/.github/agent-sops/task-release-notes.sop.md b/.github/agent-sops/task-release-notes.sop.md index 5f024da82..e32a0f2eb 100644 --- a/.github/agent-sops/task-release-notes.sop.md +++ b/.github/agent-sops/task-release-notes.sop.md @@ -8,6 +8,22 @@ You analyze merged pull requests between two git references (tags or branches), **Important**: You are executing in an ephemeral environment. Any files you create (test files, notes, etc.) will be discarded after execution. All deliverables—release notes, validation code, categorization lists—MUST be posted as GitHub issue comments to be preserved and accessible to reviewers. +## Key Principles + +These principles apply throughout the entire workflow and are referenced by name in later sections. + +### Principle 1: Ephemeral Environment +You are executing in an ephemeral environment. All deliverables MUST be posted as GitHub issue comments to be preserved. 
+ +### Principle 2: PR Descriptions May Be Stale +PR descriptions are written at PR creation and may become outdated after code review. Reviewers often request structural changes, API modifications, or feature adjustments that are implemented but NOT reflected in the original description. You MUST cross-reference descriptions with review comments and treat merged code as the source of truth. + +### Principle 3: Validation Is Mandatory +You MUST attempt to validate EVERY code example with behavioral tests. The engineer review fallback is only for cases where you have genuinely tried and failed with documented evidence. + +### Principle 4: Never Remove Features +You MUST NOT remove a feature from release notes because validation failed. Always include a code sample—either validated or marked for engineer review. + ## Steps ### 1. Setup and Input Processing @@ -62,10 +78,10 @@ For each PR identified (from release or API query), fetch additional metadata ne - You MUST retrieve additional metadata for PRs being considered for Major Features or Major Bug Fixes: - PR description/body (essential for understanding the change) - PR labels (if any) + - PR review comments and conversation threads (per **Principle 2**) - You SHOULD retrieve for Major Feature candidates: - Files changed in the PR (to find code examples) -- You MAY retrieve: - - PR review comments if helpful for understanding the change +- You MUST retrieve PR review comments for Major Feature and Major Bug Fix candidates to identify post-description changes - You SHOULD minimize API calls by only fetching detailed metadata for PRs that appear significant based on title/prefix - You MUST track this data for use in categorization and release notes generation @@ -89,18 +105,24 @@ Extract categorization signals from PR titles using conventional commit prefixes - You SHOULD record the prefix-based category for each PR - You MAY encounter PRs without conventional commit prefixes -#### 2.2 Analyze PR Descriptions +#### 2.2 Analyze PR Descriptions and Review Comments Use LLM analysis to understand the significance and user impact of each change. **Constraints:** - You MUST read and analyze the PR description for each PR +- Per **Principle 2**, you MUST also review PR comments and review threads to identify changes made after the initial description: + - Look for reviewer comments requesting changes to the implementation + - Look for author responses confirming changes were made + - Look for "LGTM" or approval comments that reference specific modifications + - Pay special attention to comments about API changes, renamed methods, or restructured code +- You MUST treat the actual merged code as the source of truth when descriptions conflict with review feedback - You MUST assess the user-facing impact of the change: - Does it introduce new functionality users will interact with? - Does it fix a bug that users experienced? - Is it purely internal with no user-visible changes? - You MUST identify if the change introduces breaking changes -- You SHOULD identify if the PR includes code examples in its description +- You SHOULD identify if the PR includes code examples in its description (but verify they match the final implementation) - You SHOULD note any links to documentation or related issues - You MAY consider the size and complexity of the change @@ -152,6 +174,10 @@ Present the categorized PRs to the user for review and confirmation. 
- You MUST wait for user confirmation or recategorization before proceeding - You SHOULD update your categorization based on user feedback - You MAY iterate on categorization if the user requests changes +- When the user promotes a PR to "Major Features" that was not previously in that category: + - You MUST perform Step 3 (Code Snippet Extraction) for the newly promoted PR + - You MUST perform Step 4 (Code Validation) for any code snippets extracted or generated + - You MUST include the validation code for newly promoted features in the Validation Comment (Step 6.1) ### 3. Code Snippet Extraction and Generation @@ -163,12 +189,16 @@ Search merged PRs for existing code that demonstrates the new feature. **Constraints:** - You MUST search each Major Feature PR for existing code examples in: - - Test files (especially integration tests or example tests) + - Test files (especially integration tests or example tests) - these are most reliable as they reflect the final implementation - Example applications or scripts in `examples/` directory - - Code snippets in the PR description + - Code snippets in the PR description (but verify per **Principle 2**) - Documentation updates that include code examples - README updates with usage examples -- You MUST prioritize test files that show real usage of the feature +- You MUST cross-reference any examples from PR descriptions with: + - Review comments that may have requested API changes + - The actual merged code to ensure the example is still accurate + - Test files which reflect the working implementation +- You MUST prioritize test files that show real usage of the feature (these are validated against the final code) - You SHOULD look for the simplest, most focused examples - You SHOULD prefer examples that are already validated (from test files) - You MAY examine multiple PRs if a feature spans several PRs @@ -208,60 +238,178 @@ When existing examples are insufficient, generate new code snippets. ### 4. Code Validation -**Note**: This phase is REQUIRED for all code snippets (extracted or generated) that will appear in Major Features sections. Validation must occur AFTER snippets have been extracted or generated in Step 3. +**Note**: This phase is REQUIRED for all code snippets (extracted or generated) that will appear in Major Features sections. Per **Principle 3**, you MUST attempt validation for every example. -#### 4.1 Create Temporary Test Files +#### 4.1 Validation Requirements -Create temporary test files to validate the code snippets. +Validation tests MUST verify the actual behavior of the feature, not just syntax correctness. A test that only checks whether code parses or imports succeed is NOT valid validation. + +**Available Testing Resources:** +- **Amazon Bedrock**: You have access to Bedrock models for testing. Use Bedrock when a feature requires a real model provider. 
+- **Project test fixtures**: The project includes mocked model providers and test utilities (commonly in `tests/fixtures/`, `__mocks__/`, or similar) +- **Integration test patterns**: Examine integration test directories (commonly in `tests_integ/` or `test/integ`) for patterns that test real model interactions + +**Features that genuinely cannot be validated (rare):** +- Features requiring paid third-party API credentials with no mock option AND no Bedrock alternative +- Features requiring specific hardware (GPU, TPU) +- Features requiring live network access to specific external services that cannot be mocked **Constraints:** - You MUST create a temporary test file for each code snippet - You MUST place test files in an appropriate test directory based on the project structure - You MUST include all necessary imports and setup code in the test file - You MUST wrap the snippet in a proper test case +- You MUST include assertions that verify the feature's actual behavior: + - Assert that outputs match expected values + - Assert that state changes occur as expected + - Assert that callbacks/hooks are invoked correctly + - Assert that return types and structures are correct +- You MUST NOT write tests that only verify: + - Code parses without syntax errors + - Imports succeed + - Objects can be instantiated without checking behavior + - Functions can be called without checking results - You SHOULD use the project's testing framework -- You MAY need to mock dependencies or setup test fixtures +- You SHOULD mock external dependencies (APIs, databases) but still verify behavior with mocks +- You MAY need to setup test fixtures that enable behavioral verification - You MAY include additional test code that doesn't appear in the release notes -**Example test file structure** (language-specific format will vary): +**Example of GOOD validation** (verifies behavior) - adapt syntax to project language: +```python +def test_structured_output_validation(): + """Verify that structured output actually validates against the schema.""" + from pydantic import BaseModel + + class UserResponse(BaseModel): + name: str + age: int + + agent = Agent(model=mock_model, output_schema=UserResponse) + result = agent("Get user info") + + # Behavioral assertions - verify the feature works + assert isinstance(result.output, UserResponse) + assert hasattr(result.output, 'name') + assert hasattr(result.output, 'age') + assert isinstance(result.output.age, int) ``` -# Test structure depends on the project's testing framework -# Include necessary imports, setup, and the snippet being validated -# Add assertions to verify the code works correctly + +**Example of BAD validation** (only verifies syntax) - adapt syntax to project language: +```python +def test_structured_output_syntax(): + """BAD: This only verifies the code runs without errors.""" + from pydantic import BaseModel + + class UserResponse(BaseModel): + name: str + age: int + + # BAD: No assertions about behavior + agent = Agent(model=mock_model, output_schema=UserResponse) + # BAD: Just calling without checking results proves nothing + agent("Get user info") ``` -#### 4.2 Run Validation Tests +#### 4.2 Validation Workflow -Execute tests to ensure code snippets are valid and functional. +For each Major Feature, follow this workflow in order: + +1. **Write a test file** with behavioral assertions +2. **Run the test** using the project's test framework +3. 
**If it fails**, try these approaches in order: + - Try using Bedrock instead of other model providers + - Try installing missing dependencies + - Try mocking external services + - Try using project test fixtures (e.g., mocked model providers) + - Try simplifying the example +4. **Document each attempt** and its result in the Validation Comment +5. **Only after documented failures** can you use the engineer review fallback **Constraints:** - You MUST run the appropriate test command for the project (e.g., `npm test`, `pytest`, `go test`) - You MUST verify that the test passes successfully +- You MUST verify that assertions actually executed (not skipped or short-circuited) - You MUST check that the code compiles without errors in compiled languages +- You MUST ensure tests include meaningful assertions about feature behavior - You SHOULD run type checking if applicable (e.g., `npm run type-check`, `mypy`) +- You SHOULD review test output to confirm behavioral assertions passed - You MAY need to adjust imports or setup code if tests fail -- You MAY need to install additional dependencies if required -**Fallback validation** (if test execution fails or is not possible): -- You MUST at minimum validate syntax using the appropriate language tools -- You MUST ensure the code is syntactically correct -- You MUST verify all referenced types and modules exist +**Installing Dependencies:** +- You MUST attempt to install missing dependencies when tests fail due to import errors +- You SHOULD check the project's dependency manifest (`pyproject.toml`, `package.json`, `Cargo.toml`, etc.) for optional dependency groups +- You SHOULD use the project's package manager to install dependencies (e.g., `pip install`, `npm install`, `cargo add`) +- For projects with optional extras, use the appropriate syntax (e.g., `pip install -e ".[extra]"` for Python, `npm install --save-dev` for Node.js) +- You SHOULD only fall back to mocking if the dependency cannot be installed (e.g., requires paid API keys, proprietary software) + +**Example of mocking external dependencies** - adapt syntax to project language: +```python +def test_custom_http_client(): + """Verify custom HTTP client is passed to the provider.""" + from unittest.mock import Mock, patch + + custom_client = Mock() + + with patch('strands.models.openai.OpenAI') as mock_openai: + from strands.models.openai import OpenAIModel + model = OpenAIModel(http_client=custom_client) + + # Verify the custom client was passed + mock_openai.assert_called_once() + call_kwargs = mock_openai.call_args[1] + assert call_kwargs.get('http_client') == custom_client +``` + +#### 4.3 Engineer Review Fallback -#### 4.3 Handle Validation Failures +When validation genuinely fails after documented attempts, use this fallback. Per **Principle 4**, you MUST still include the feature with a code sample. -Address any validation failures before including snippets in release notes. +**Required proof before using this fallback:** +1. Created an actual test file (show the code in the validation comment) +2. Ran the test and received an actual error (show the error message) +3. Tried at least ONE alternative approach (Bedrock, mocking, simplified example) +4. 
Documented each attempt and its failure reason **Constraints:** -- You MUST NOT include unvalidated code snippets in release notes -- You MUST revise the code snippet if validation fails -- You MUST re-run validation after making changes -- You SHOULD examine the actual implementation in the PR if generated code fails -- You SHOULD simplify the example if complexity is causing validation issues -- You MAY extract a different example from the PR if the current one cannot be validated -- You MAY seek clarification if you cannot create a valid example -- You MUST preserve the test file content to include in the GitHub issue comment (Step 6.2) +- You MUST NOT mark examples as needing validation without actually attempting validation first +- You MUST NOT use vague reasons like "complex setup required" - be specific about what you tried and what error you got +- You MUST show your test code and error messages in the Validation Comment +- You MUST try Bedrock for any feature that works with multiple model providers before giving up +- You MUST try mocking for provider-specific features before giving up +- You MUST document all validation attempts (successful AND failed) in the Validation Comment +- You MUST preserve the test file content to include in the GitHub issue comment (Step 6.1) +- You MUST note in the validation comment what specific behavior each test verifies - You MAY delete temporary test files after capturing their content, as the environment is ephemeral +**Process when validation genuinely fails:** +1. **Extract a code sample from the PR** - Use code from: + - The PR description's code examples + - Test files added in the PR + - The actual implementation (simplified for readability) + - Documentation updates in the PR +2. **Include the sample in the release notes** with a clear callout that it needs engineer validation +3. **Document the validation attempts and failures** in the Validation Comment (Step 6.1) + +**Format for unvalidated code examples:** +```markdown +### Feature Name - [PR#123](link) + +Description of the feature and its impact. + +\`\`\`python +# ⚠️ NEEDS ENGINEER VALIDATION +# Validation attempted: [describe test created and error received] +# Alternative attempts: [what else you tried and why it failed] + +# Code sample extracted from PR description/tests +from strands import Agent +from strands.models.openai import OpenAIModel + +model = OpenAIModel(http_client=custom_client) +agent = Agent(model=model) +\`\`\` +``` + ### 5. Release Notes Formatting #### 5.1 Format Major Features Section @@ -289,9 +437,16 @@ Create the Major Features section with concise descriptions and code examples. Agents can now validate responses against predefined schemas with configurable retry behavior for non-conforming outputs. -\`\`\`[language] -# Code example in the project's programming language -# Show the feature in action with clear, focused code +\`\`\`python +from strands import Agent +from pydantic import BaseModel + +class Response(BaseModel): + answer: str + +agent = Agent(output_schema=Response) +result = agent("What is 2+2?") +print(result.output.answer) \`\`\` See the [Structured Output docs](https://docs.example.com/structured-output) for configuration options. 
@@ -336,63 +491,82 @@ Add a horizontal rule to separate your content from GitHub's auto-generated sect
 - This visually separates your curated content from GitHub's auto-generated "What's Changed" and "New Contributors" sections
 - You MUST NOT include a "Full Changelog" link—GitHub adds this automatically
 
-**Example format**:
-```markdown
-## Major Bug Fixes
-
-- **Critical Fix** - [PR#124](https://github.com/owner/repo/pull/124)
-  Description of what was fixed.
-
----
-```
-
 ### 6. Output Delivery
 
-**Critical**: You are running in an ephemeral environment. All files created during execution (test files, temporary notes, etc.) will be deleted when the workflow completes. You MUST post all deliverables as GitHub issue comments—this is the only way to preserve your work and make it accessible to reviewers.
+Per **Principle 1**, all deliverables must be posted as GitHub issue comments.
 
-**Comment Structure**: Post exactly two comments on the GitHub issue:
+**Comment Structure**: Post exactly three comments on the GitHub issue:
 1. **Validation Comment** (first): Contains all validation code for all features in one batched comment
 2. **Release Notes Comment** (second): Contains the final formatted release notes
+3. **Exclusions Comment** (third): Documents any features that were excluded and why
+
+This ordering allows reviewers to see the validation evidence, review the release notes, and understand any exclusion decisions.
 
-This ordering allows reviewers to see the validation evidence before reviewing the release notes.
+**Iteration Comments**: If the user requests changes after the initial comments are posted:
+- Post additional validation comments for any re-validated code
+- Post updated release notes as new comments (do not edit previous comments)
+- This creates an audit trail of changes and validations
 
 #### 6.1 Post Validation Code Comment
 
 Batch all validation code into a single GitHub issue comment.
 
 **Constraints:**
-- You MUST post ONE comment containing ALL validation code for ALL features
+- You MUST post ONE comment containing validation attempts for ALL Major Features
+- You MUST show test code for EVERY feature - both successful and failed attempts
 - You MUST NOT post separate comments for each feature's validation
 - You MUST post this comment BEFORE the release notes comment
 - You MUST include all test files created during validation (Step 4) in this single comment
+- You MUST document what specific behavior each test verifies (not just "validates the code works")
 - You MUST NOT reference local file paths—the ephemeral environment will be destroyed
 - You MUST clearly label this comment as "Code Validation Tests"
-- You MUST include a note explaining that this code was used to validate the snippets in the release notes
-- You SHOULD use collapsible `<details>` sections to organize validation code by feature:
-  ```markdown
-  ## Code Validation Tests
-
-  The following test code was used to validate the code examples in the release notes.
-
-  <details>
-  <summary>Validation: Feature Name 1</summary>
-
-  \`\`\`typescript
-  [Full test file for feature 1]
-  \`\`\`
-
-  </details>
-
-  <details>
-  <summary>Validation: Feature Name 2</summary>
-
-  \`\`\`typescript
-  [Full test file for feature 2]
-  \`\`\`
-
-  </details>
-  ```
-- This allows reviewers to copy and run the validation code themselves
+- You SHOULD use collapsible `<details>` sections to organize validation code by feature
+- You SHOULD include a brief description of what behavior is being verified for each test
+
+**Format:**
+```markdown
+## Code Validation Tests
+
+The following test code was used to validate the code examples in the release notes.
+
+<details>
+<summary>✅ Validated: Feature Name 1</summary>
+
+**Behavior verified:** This test confirms that the new `output_schema` parameter causes the agent to return a validated Pydantic model instance with the correct field types.
+
+\`\`\`python
+[Full test file for feature 1 with behavioral assertions]
+\`\`\`
+
+**Test output:** PASSED
+
+</details>
+
+<details>
+``` #### 6.2 Post Release Notes Comment @@ -408,95 +582,117 @@ Post the formatted release notes as a single GitHub issue comment. - You MAY use markdown formatting in the comment - If comment posting is deferred, continue with the workflow and note the deferred status -## Examples +#### 6.3 Post Exclusions Comment -### Example 1: Major Features Section with Code +Document any features with unvalidated code samples and any other notable decisions. +**Constraints:** +- You MUST post this comment as the FINAL comment on the GitHub issue +- You MUST include this comment if ANY of the following occurred: + - A Major Feature has an unvalidated code sample (marked for engineer review) + - A feature's scope or description was significantly different from the PR description + - You relied on review comments rather than the PR description to understand a feature +- You MUST clearly explain the reasoning for each unvalidated sample +- You SHOULD include this comment even if all code samples were validated, with a simple note: "All code samples were successfully validated. No engineer review required." +- You MUST NOT skip this comment—it provides critical transparency for reviewers + +**Format:** ```markdown -## Major Features - -### Managed MCP Connections - [PR#895](https://github.com/org/repo/pull/895) - -MCP Connections via ToolProviders allow the Agent to manage connection lifecycles automatically, eliminating the need for manual context managers. This experimental interface simplifies MCP tool integration significantly. +## Release Notes Review Notes -\`\`\`[language] -# Code example in the project's programming language -# Demonstrate the key feature usage -# Keep it focused and concise -\`\`\` +The following items require attention during review: -See the [MCP docs](https://docs.example.com/mcp) for details. +### ⚠️ Features with Unvalidated Code Samples -### Async Streaming for Multi-Agent Systems - [PR#961](https://github.com/org/repo/pull/961) +These features have code samples extracted from PRs but could not be automatically validated. An engineer must verify these examples before publishing: -Multi-agent systems now support async streaming, enabling real-time event streaming from agent teams as they collaborate. +- **PR#123 - Feature Title**: + - Code source: PR description / test files / implementation + - Validation attempted: [what you tried] + - Failure reason: [why it failed, e.g., "requires OpenAI API credentials", "complex multi-service integration"] + - Action needed: Engineer should verify the code sample works as shown -\`\`\`[language] -# Another code example -# Show the feature in action -# Include only essential code -\`\`\` +### Description vs. Implementation Discrepancies +- **PR#101 - Feature Title**: PR description stated [X] but review comments and final implementation show [Y]. Release notes reflect the actual merged behavior. ``` -### Example 2: Major Bug Fixes Section +#### 6.4 Handle User Feedback on Release Notes -```markdown ---- +When the user requests changes to the release notes after they have been posted, re-validate as needed. -## Major Bug Fixes - -- **Guardrails Redaction Fix** - [PR#1072](https://github.com/strands-agents/sdk-python/pull/1072) - Fixed input/output message redaction when `guardrails_trace="enabled_full"`, ensuring sensitive data is properly protected in traces. 
- -- **Tool Result Block Redaction** - [PR#1080](https://github.com/strands-agents/sdk-python/pull/1080) - Properly redact tool result blocks to prevent conversation corruption when using content filtering or PII redaction. +**Constraints:** +- You MUST re-run validation (Step 4) when the user requests changes that affect code examples: + - Modified code snippets + - New code examples for features that previously had none + - Replacement examples for features +- You MUST perform full extraction (Step 3) and validation (Step 4) when the user requests: + - Adding a new feature to the release notes that wasn't previously included + - Promoting a bug fix to include a code example +- You MUST NOT make changes to code examples without re-validating them +- You MUST post updated validation code as a new comment when re-validation occurs +- You MUST post the revised release notes as a new comment (do not edit previous comments) +- You SHOULD note in the updated release notes comment what changed from the previous version +- You MAY skip re-validation only for changes that do not affect code: + - Wording changes to descriptions + - Fixing typos + - Reordering features + - Removing features (no validation needed for removal) -- **Orphaned Tool Use Fix** - [PR#1123](https://github.com/strands-agents/sdk-python/pull/1123) - Fixed broken conversations caused by orphaned `toolUse` blocks, improving reliability when tools fail or are interrupted. -``` +## Examples -### Example 3: Complete Release Notes Structure +### Example 1: Complete Release Notes ```markdown ## Major Features -### Feature Name - [PR#123](https://github.com/owner/repo/pull/123) +### Managed MCP Connections - [PR#895](https://github.com/org/repo/pull/895) -Description of the feature and its impact. +MCP Connections via ToolProviders allow the Agent to manage connection lifecycles automatically, eliminating the need for manual context managers. This experimental interface simplifies MCP tool integration significantly. + +\`\`\`python +from strands import Agent +from strands.tools import MCPToolProvider -\`\`\`[language] -# Code example demonstrating the feature +provider = MCPToolProvider(server_config) +agent = Agent(tools=[provider]) +result = agent("Use the MCP tools") \`\`\` ---- +See the [MCP docs](https://docs.example.com/mcp) for details. -## Major Bug Fixes +### Custom HTTP Client Support - [PR#1366](https://github.com/org/repo/pull/1366) -- **Critical Fix** - [PR#124](https://github.com/owner/repo/pull/124) - Description of what was fixed and why it matters. +OpenAI model provider now accepts a custom HTTP client, enabling proxy configuration, custom timeouts, and request logging. ---- -``` +\`\`\`python +# ⚠️ NEEDS ENGINEER VALIDATION +# Validation attempted: mocked OpenAI client, received import error +# Alternative attempts: Bedrock (not applicable - OpenAI-specific) -Note: The trailing `---` separates your content from GitHub's auto-generated "What's Changed" and "New Contributors" sections that follow. 
+from strands.models.openai import OpenAIModel +import httpx -### Example 4: Issue Comment with Release Notes +custom_client = httpx.Client(proxy="http://proxy.example.com:8080") +model = OpenAIModel(client_args={"http_client": custom_client}) +\`\`\` -```markdown -Release notes for v1.15.0: +--- -## Major Features +## Major Bug Fixes -### Managed MCP Connections - [PR#895](https://github.com/strands-agents/sdk-typescript/pull/895) +- **Guardrails Redaction Fix** - [PR#1072](https://github.com/strands-agents/sdk-python/pull/1072) + Fixed input/output message redaction when `guardrails_trace="enabled_full"`, ensuring sensitive data is properly protected in traces. -We've introduced MCP Connections via ToolProviders... +- **Tool Result Block Redaction** - [PR#1080](https://github.com/strands-agents/sdk-python/pull/1080) + Properly redact tool result blocks to prevent conversation corruption when using content filtering or PII redaction. -[... rest of release notes ...] +- **Orphaned Tool Use Fix** - [PR#1123](https://github.com/strands-agents/sdk-python/pull/1123) + Fixed broken conversations caused by orphaned `toolUse` blocks, improving reliability when tools fail or are interrupted. --- ``` -When this content is added to the GitHub release, GitHub will automatically append the "What's Changed" and "New Contributors" sections below the separator. +Note: The trailing `---` separates your content from GitHub's auto-generated "What's Changed" and "New Contributors" sections that follow. ## Troubleshooting @@ -519,14 +715,7 @@ If you encounter GitHub API rate limit errors: ### Code Validation Failures -If code validation fails for a snippet: -1. Review the test output to understand the failure reason -2. Check if the feature requires additional dependencies or setup -3. Examine the actual implementation in the PR to understand correct usage -4. Try simplifying the example to focus on core functionality -5. Consider using a different example from the PR -6. If unable to validate, note the issue in the release notes comment and skip the code example for that feature -7. Leave a comment on the issue noting which features couldn't include validated code examples +Follow the validation workflow in Section 4.2. If all attempts fail, use the engineer review fallback per Section 4.3. Per **Principle 4**, always include a code sample. ### Large PR Sets (>100 PRs) @@ -561,22 +750,19 @@ When GitHub tools or git operations are deferred (GITHUB_WRITE=false): - The operations will be executed after agent completion - Do not retry or attempt alternative approaches for deferred operations -### Unable to Extract Suitable Code Examples +### Stale PR Descriptions -If no suitable code examples can be found or generated for a feature: -1. Examine the PR description more carefully for usage information -2. Look at related documentation changes -3. Consider whether the feature actually needs a code example (some features are self-explanatory) -4. Generate a minimal example based on the API changes, even if you can't fully validate it -5. Mark the example as "conceptual" if validation isn't possible -6. Consider omitting the code example if it would be misleading +Per **Principle 2**: Review PR comments for context on what changed, examine merged code (especially test files), and use test files as the authoritative source for code examples. 
## Desired Outcome * Focused release notes highlighting Major Features and Major Bug Fixes with concise descriptions (2-3 sentences, no bullet points) -* Working, validated code examples for all major features +* Code examples for ALL major features - either validated or marked for engineer review +* Validated code examples have passing behavioral tests +* Unvalidated code examples are clearly marked with the engineer validation warning and extracted from PR sources * Well-formatted markdown that renders properly on GitHub * Release notes posted as a comment on the GitHub issue for review +* Review notes comment documenting any features with unvalidated code samples that need engineer attention **Important**: Your generated release notes will be prepended to GitHub's auto-generated release notes. GitHub automatically generates: - "What's Changed" section listing all PRs with authors and links From 3ffc327071396fa24f805c03524da6b71e5f73cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Mon, 12 Jan 2026 14:08:49 -0500 Subject: [PATCH 15/47] fix(integ): make calculator tool more robust to LLM output variations (#1445) The test_tool_use_with_structured_output test was flaky because the LLM sometimes uses '+' instead of 'add' as the operation string. The calculator tool now accepts both formats for all operations. Changes: - Accept both word and symbol forms: add/+, subtract/-, multiply/*, divide//, power/** - Also accept common abbreviations: sub, mul, div, pow - Normalize input with lower() and strip() - Fix divide operation (was b/a, now a/b) - Improve docstring with Args section This makes the integ tests more resilient to LLM output variations. Co-authored-by: Strands Coder --- .../test_structured_output_agent_loop.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests_integ/test_structured_output_agent_loop.py b/tests_integ/test_structured_output_agent_loop.py index 188f57777..390bd3cff 100644 --- a/tests_integ/test_structured_output_agent_loop.py +++ b/tests_integ/test_structured_output_agent_loop.py @@ -132,16 +132,23 @@ def validate_first_name(cls, value: str) -> str: @tool def calculator(operation: str, a: float, b: float) -> float: - """Simple calculator tool for testing.""" - if operation == "add": + """Simple calculator tool for testing. + + Args: + operation: The operation to perform. 
One of: add, subtract, multiply, divide, power + a: The first number + b: The second number + """ + op = operation.lower().strip() + if op in ("add", "+"): return a + b - elif operation == "subtract": + elif op in ("subtract", "-", "sub"): return a - b - elif operation == "multiply": + elif op in ("multiply", "*", "mul"): return a * b - elif operation == "divide": - return b / a if a != 0 else 0 - elif operation == "power": + elif op in ("divide", "/", "div"): + return a / b if b != 0 else 0 + elif op in ("power", "**", "pow"): return a**b else: return 0 From 56676c19297b95e0396f74c0fb9c2afc0a96c25e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Mon, 12 Jan 2026 14:16:43 -0500 Subject: [PATCH 16/47] fix(mcp): resolve string formatting error in MCP client error handling (#1446) --- src/strands/tools/mcp/mcp_client.py | 2 +- tests/strands/tools/mcp/test_mcp_client.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 37b99d021..db21b9ef2 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -713,7 +713,7 @@ async def _handle_error_message(self, message: Exception | Any) -> None: if isinstance(message, Exception): error_msg = str(message).lower() if any(pattern in error_msg for pattern in _NON_FATAL_ERROR_PATTERNS): - self._log_debug_with_thread("ignoring non-fatal MCP session error", message) + self._log_debug_with_thread("ignoring non-fatal MCP session error: %s", message) else: raise message await anyio.lowlevel.checkpoint() diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 35f11f47f..f784da414 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -924,3 +924,21 @@ def test_list_resource_templates_sync_session_not_active(): with pytest.raises(MCPClientInitializationError, match="client session is not running"): client.list_resource_templates_sync() + + +@pytest.mark.asyncio +async def test_handle_error_message_with_percent_in_message(): + """Test that _handle_error_message handles messages containing % characters without string formatting errors. + + This is a regression test for issue #1244 where MCP error messages containing '%' characters + (e.g., from URLs like "https://example.com/path?param=value%20encoded") would cause a + TypeError: not all arguments converted during string formatting. 
+ """ + client = MCPClient(MagicMock()) + + # Test with a message that contains % characters (like URL-encoded strings) + # This simulates the error that occurs when MCP servers return messages with % in them + error_with_percent = Exception("unknown request id: abc%20123%30def") + + # This should not raise TypeError and should not raise the exception (since it's non-fatal) + await client._handle_error_message(error_with_percent) From 318573d0618c283f6f16ace1117d5e3d76279568 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 12 Jan 2026 14:34:30 -0500 Subject: [PATCH 17/47] bidi - move 3.12 check to nova sonic module (#1439) --- src/strands/experimental/bidi/__init__.py | 8 --- src/strands/experimental/bidi/agent/agent.py | 15 +++--- .../experimental/bidi/models/__init__.py | 15 +++--- src/strands/experimental/bidi/models/model.py | 3 +- .../experimental/bidi/models/nova_sonic.py | 9 ++++ .../bidi/_async/test_task_group.py | 3 ++ .../experimental/bidi/agent/test_agent.py | 20 +++++--- .../experimental/bidi/agent/test_loop.py | 4 +- .../bidi/models/test_nova_sonic.py | 50 ++++++++++++------- .../bidi/models/test_openai_realtime.py | 3 +- 10 files changed, 75 insertions(+), 55 deletions(-) diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 57986062e..1c0e74aae 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -1,10 +1,5 @@ """Bidirectional streaming package.""" -import sys - -if sys.version_info < (3, 12): - raise ImportError("bidi only supported for >= Python 3.12") - # Main components - Primary user interface # Re-export standard agent events for tool handling from ...types._events import ( @@ -19,7 +14,6 @@ # Model interface (for custom implementations) from .models.model import BidiModel -from .models.nova_sonic import BidiNovaSonicModel # Built-in tools from .tools import stop_conversation @@ -48,8 +42,6 @@ "BidiAgent", # IO channels "BidiAudioIO", - # Model providers - "BidiNovaSonicModel", # Built-in tools "stop_conversation", # Input Event types diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 5ddb181ea..11bea96e5 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -32,7 +32,6 @@ from ...tools import ToolProvider from .._async import _TaskGroup, stop_all from ..models.model import BidiModel -from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput from ..types.events import ( BidiAudioInputEvent, @@ -100,13 +99,13 @@ def __init__( ValueError: If model configuration is invalid or state is invalid type. TypeError: If model type is unsupported. 
""" - self.model = ( - BidiNovaSonicModel() - if not model - else BidiNovaSonicModel(model_id=model) - if isinstance(model, str) - else model - ) + if isinstance(model, BidiModel): + self.model = model + else: + from ..models.nova_sonic import BidiNovaSonicModel + + self.model = BidiNovaSonicModel(model_id=model) if isinstance(model, str) else BidiNovaSonicModel() + self.system_prompt = system_prompt self.messages = messages or [] diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index 6e5817046..7b87e09fe 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -3,27 +3,26 @@ from typing import Any from .model import BidiModel, BidiModelTimeoutError -from .nova_sonic import BidiNovaSonicModel __all__ = [ "BidiModel", "BidiModelTimeoutError", - "BidiNovaSonicModel", ] def __getattr__(name: str) -> Any: - """ - Lazy load bidi model implementations only when accessed. - - This defers the import of optional dependencies until actually needed: - - BidiGeminiLiveModel requires google-generativeai (lazy loaded) - - BidiOpenAIRealtimeModel requires openai (lazy loaded) + """Lazy load bidi model implementations only when accessed. + + This defers the import of optional dependencies until actually needed. """ if name == "BidiGeminiLiveModel": from .gemini_live import BidiGeminiLiveModel return BidiGeminiLiveModel + if name == "BidiNovaSonicModel": + from .nova_sonic import BidiNovaSonicModel + + return BidiNovaSonicModel if name == "BidiOpenAIRealtimeModel": from .openai_realtime import BidiOpenAIRealtimeModel diff --git a/src/strands/experimental/bidi/models/model.py b/src/strands/experimental/bidi/models/model.py index f5e34aa50..5941d7e41 100644 --- a/src/strands/experimental/bidi/models/model.py +++ b/src/strands/experimental/bidi/models/model.py @@ -14,7 +14,7 @@ """ import logging -from typing import Any, AsyncIterable, Protocol +from typing import Any, AsyncIterable, Protocol, runtime_checkable from ....types._events import ToolResultEvent from ....types.content import Messages @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) +@runtime_checkable class BidiModel(Protocol): """Protocol for bidirectional streaming models. diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 6a2477e22..1c946220d 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -11,8 +11,15 @@ - Tool execution with content containers and identifier tracking - 8-minute connection limits with proper cleanup sequences - Interruption detection through stopReason events + +Note, BidiNovaSonicModel is only supported for Python 3.12+ """ +import sys + +if sys.version_info < (3, 12): + raise ImportError("BidiNovaSonicModel is only supported for Python 3.12+") + import asyncio import base64 import json @@ -93,6 +100,8 @@ class BidiNovaSonicModel(BidiModel): Manages Nova Sonic's complex event sequencing, audio format conversion, and tool execution patterns while providing the standard BidiModel interface. + Note, BidiNovaSonicModel is only supported for Python 3.12+. + Attributes: _stream: open bedrock stream to nova sonic. 
""" diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py index b9a30ef5b..255ead15e 100644 --- a/tests/strands/experimental/bidi/_async/test_task_group.py +++ b/tests/strands/experimental/bidi/_async/test_task_group.py @@ -19,6 +19,7 @@ async def test_task_group__aexit__(): @pytest.mark.asyncio async def test_task_group__aexit__task_exception(): wait_event = asyncio.Event() + async def wait(): await wait_event.wait() @@ -49,12 +50,14 @@ async def wait(): @pytest.mark.asyncio async def test_task_group__aexit__context_cancelled(): wait_event = asyncio.Event() + async def wait(): await wait_event.wait() tasks = [] run_event = asyncio.Event() + async def run(): async with _TaskGroup() as task_group: tasks.append(task_group.create_task(wait())) diff --git a/tests/strands/experimental/bidi/agent/test_agent.py b/tests/strands/experimental/bidi/agent/test_agent.py index 7b03ab717..50c9afef9 100644 --- a/tests/strands/experimental/bidi/agent/test_agent.py +++ b/tests/strands/experimental/bidi/agent/test_agent.py @@ -1,13 +1,13 @@ """Unit tests for BidiAgent.""" import asyncio +import sys import unittest.mock from uuid import uuid4 import pytest from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -125,13 +125,6 @@ def test_bidi_agent_init_with_various_configurations(): assert agent_with_config.system_prompt == system_prompt assert agent_with_config.agent_id == "test_agent" - # Test with string model ID - model_id = "amazon.nova-sonic-v1:0" - agent_with_string = BidiAgent(model=model_id) - - assert isinstance(agent_with_string.model, BidiNovaSonicModel) - assert agent_with_string.model.model_id == model_id - # Test model config access config = agent.model.config assert config["audio"]["input_rate"] == 16000 @@ -139,6 +132,17 @@ def test_bidi_agent_init_with_various_configurations(): assert config["audio"]["channels"] == 1 +@pytest.mark.skipif(sys.version_info < (3, 12), reason="BidiNovaSonicModel is only supported for Python 3.12+") +def test_bidi_agent_init_with_model_id(): + from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel + + model_id = "amazon.nova-sonic-v1:0" + agent = BidiAgent(model=model_id) + + assert isinstance(agent.model, BidiNovaSonicModel) + assert agent.model.model_id == model_id + + @pytest.mark.asyncio async def test_bidi_agent_start_stop_lifecycle(agent): """Test agent start/stop lifecycle and state management.""" diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index da8578f55..fac52658e 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -5,7 +5,7 @@ from strands import tool from strands.experimental.bidi import BidiAgent -from strands.experimental.bidi.models import BidiModelTimeoutError +from strands.experimental.bidi.models import BidiModel, BidiModelTimeoutError from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent @@ -21,7 +21,7 @@ async def func(): @pytest.fixture def agent(time_tool): - return BidiAgent(model=unittest.mock.AsyncMock(), tools=[time_tool]) + return BidiAgent(model=unittest.mock.AsyncMock(spec=BidiModel), 
tools=[time_tool]) @pytest_asyncio.fixture diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py index 933fd2088..7435d4ad2 100644 --- a/tests/strands/experimental/bidi/models/test_nova_sonic.py +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -4,10 +4,17 @@ covering connection lifecycle, event conversion, audio streaming, and tool execution. """ +import sys + +if sys.version_info < (3, 12): + import pytest + + pytest.skip(reason="BidiNovaSonicModel is only supported for Python 3.12+", allow_module_level=True) + import asyncio import base64 import json -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio @@ -39,9 +46,8 @@ def model_id(): @pytest.fixture -def region(): - """AWS region.""" - return "us-east-1" +def boto_session(): + return Mock(region_name="us-east-1") @pytest.fixture @@ -67,11 +73,11 @@ def mock_client(mock_stream): @pytest_asyncio.fixture -def nova_model(model_id, region, mock_client): +def nova_model(model_id, boto_session, mock_client): """Create Nova Sonic model instance.""" _ = mock_client - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) yield model @@ -79,12 +85,12 @@ def nova_model(model_id, region, mock_client): @pytest.mark.asyncio -async def test_model_initialization(model_id, region): +async def test_model_initialization(model_id, boto_session): """Test model initialization with configuration.""" - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) assert model.model_id == model_id - assert model.region == region + assert model.region == "us-east-1" assert model._connection_id is None @@ -92,9 +98,9 @@ async def test_model_initialization(model_id, region): @pytest.mark.asyncio -async def test_audio_config_defaults(model_id, region): +async def test_audio_config_defaults(model_id, boto_session): """Test default audio configuration.""" - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) assert model.config["audio"]["input_rate"] == 16000 assert model.config["audio"]["output_rate"] == 16000 @@ -104,10 +110,12 @@ async def test_audio_config_defaults(model_id, region): @pytest.mark.asyncio -async def test_audio_config_partial_override(model_id, region): +async def test_audio_config_partial_override(model_id, boto_session): """Test partial audio configuration override.""" provider_config = {"audio": {"output_rate": 24000, "voice": "ruth"}} - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + model = BidiNovaSonicModel( + model_id=model_id, client_config={"boto_session": boto_session}, provider_config=provider_config + ) # Overridden values assert model.config["audio"]["output_rate"] == 24000 @@ -120,7 +128,7 @@ async def test_audio_config_partial_override(model_id, region): @pytest.mark.asyncio -async def test_audio_config_full_override(model_id, region): +async def test_audio_config_full_override(model_id, boto_session): """Test full audio configuration override.""" provider_config = { "audio": { @@ -131,7 +139,9 @@ async def test_audio_config_full_override(model_id, 
region): "voice": "stephen", } } - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + model = BidiNovaSonicModel( + model_id=model_id, client_config={"boto_session": boto_session}, provider_config=provider_config + ) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -527,11 +537,13 @@ async def test_message_history_empty_and_edge_cases(nova_model): @pytest.mark.asyncio -async def test_custom_audio_rates_in_events(model_id, region): +async def test_custom_audio_rates_in_events(model_id, boto_session): """Test that audio events use configured sample rates.""" # Create model with custom audio configuration provider_config = {"audio": {"output_rate": 48000, "channels": 2}} - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + model = BidiNovaSonicModel( + model_id=model_id, client_config={"boto_session": boto_session}, provider_config=provider_config + ) # Test audio output event uses custom configuration audio_bytes = b"test audio data" @@ -548,10 +560,10 @@ async def test_custom_audio_rates_in_events(model_id, region): @pytest.mark.asyncio -async def test_default_audio_rates_in_events(model_id, region): +async def test_default_audio_rates_in_events(model_id, boto_session): """Test that audio events use default sample rates when no custom config.""" # Create model without custom audio configuration - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) # Test audio output event uses defaults audio_bytes = b"test audio data" diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 1cabbc92b..09f4c8bc8 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -9,6 +9,7 @@ """ import base64 +import itertools import json import unittest.mock @@ -522,7 +523,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): @unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.time.time") @pytest.mark.asyncio async def test_receive_timeout(mock_time, model): - mock_time.side_effect = [1, 2] + mock_time.side_effect = itertools.count() model.timeout_s = 1 await model.start() From 68257a3e9f792b06d4be771677c6e38c10e99da7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 12 Jan 2026 17:12:19 -0500 Subject: [PATCH 18/47] ci: update sphinx requirement from <9.0.0,>=5.0.0 to >=5.0.0,<10.0.0 (#1426) Updates the requirements on [sphinx](https://github.com/sphinx-doc/sphinx) to permit the latest version. - [Release notes](https://github.com/sphinx-doc/sphinx/releases) - [Changelog](https://github.com/sphinx-doc/sphinx/blob/master/CHANGES.rst) - [Commits](https://github.com/sphinx-doc/sphinx/compare/v5.0.0...v9.1.0) --- updated-dependencies: - dependency-name: sphinx dependency-version: 9.1.0 dependency-type: direct:development ... 
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 05a385ca9..62e0e04b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,7 @@ sagemaker = [
 ]
 otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"]
 docs = [
-    "sphinx>=5.0.0,<9.0.0",
+    "sphinx>=5.0.0,<10.0.0",
     "sphinx-rtd-theme>=1.0.0,<2.0.0",
     "sphinx-autodoc-typehints>=1.12.0,<4.0.0",
 ]

From 02738013252a3c39c9dfc3f972ba88f3cfb02afe Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Tue, 13 Jan 2026 10:30:27 -0500
Subject: [PATCH 19/47] fix: add concurrency protection to prevent parallel
 invocations from corrupting agent state (#1453)

When multiple invocations occur concurrently on the same Agent instance,
the internal agent state can become corrupted, causing subsequent
invocations to fail. The most common result is that the number of toolUse
blocks ends up out of sync with the subsequent toolResult blocks, resulting
in ValidationExceptions, as reported in #1176.

To block multiple concurrent agent invocations, we'll raise a new
ConcurrencyException before any state modification occurs.

---------

Co-authored-by: Strands Agent
Co-authored-by: Mackenzie Zastrow
---
 src/strands/agent/agent.py        |  88 ++++++----
 src/strands/tools/_caller.py      |  79 +++++----
 src/strands/types/exceptions.py   |  11 ++
 tests/strands/agent/test_agent.py | 260 +++++++++++++++++++++++++++++-
 4 files changed, 372 insertions(+), 66 deletions(-)

diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index c4ebc0b54..7126644e6 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -10,6 +10,7 @@
 """

 import logging
+import threading
 import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -59,7 +60,7 @@
 from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
 from ..types.agent import AgentInput
 from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
-from ..types.exceptions import ContextWindowOverflowException
+from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException
 from ..types.traces import AttributeValue
 from .agent_result import AgentResult
 from .conversation_manager import (
@@ -245,6 +246,11 @@ def __init__(

         self._interrupt_state = _InterruptState()

+        # Initialize lock for guarding concurrent invocations
+        # Using threading.Lock instead of asyncio.Lock because run_async() creates
+        # separate event loops in different threads, so asyncio.Lock wouldn't work
+        self._invocation_lock = threading.Lock()
+
         # Initialize session management functionality
         self._session_manager = session_manager
         if self._session_manager:
@@ -554,6 +560,7 @@ async def stream_async(
             - And other event data provided by the callback handler

         Raises:
+            ConcurrencyException: If another invocation is already in progress on this agent instance.
             Exception: Any exceptions from the agent invocation will be propagated to the caller.
Example: @@ -563,50 +570,63 @@ async def stream_async( yield event["data"] ``` """ - self._interrupt_state.resume(prompt) + # Acquire lock to prevent concurrent invocations + # Using threading.Lock instead of asyncio.Lock because run_async() creates + # separate event loops in different threads + acquired = self._invocation_lock.acquire(blocking=False) + if not acquired: + raise ConcurrencyException( + "Agent is already processing a request. Concurrent invocations are not supported." + ) - self.event_loop_metrics.reset_usage_metrics() + try: + self._interrupt_state.resume(prompt) - merged_state = {} - if kwargs: - warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) - merged_state.update(kwargs) - if invocation_state is not None: - merged_state["invocation_state"] = invocation_state - else: - if invocation_state is not None: - merged_state = invocation_state + self.event_loop_metrics.reset_usage_metrics() - callback_handler = self.callback_handler - if kwargs: - callback_handler = kwargs.get("callback_handler", self.callback_handler) + merged_state = {} + if kwargs: + warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) + merged_state.update(kwargs) + if invocation_state is not None: + merged_state["invocation_state"] = invocation_state + else: + if invocation_state is not None: + merged_state = invocation_state - # Process input and get message to add (if any) - messages = await self._convert_prompt_to_messages(prompt) + callback_handler = self.callback_handler + if kwargs: + callback_handler = kwargs.get("callback_handler", self.callback_handler) - self.trace_span = self._start_agent_trace_span(messages) + # Process input and get message to add (if any) + messages = await self._convert_prompt_to_messages(prompt) - with trace_api.use_span(self.trace_span): - try: - events = self._run_loop(messages, merged_state, structured_output_model) + self.trace_span = self._start_agent_trace_span(messages) - async for event in events: - event.prepare(invocation_state=merged_state) + with trace_api.use_span(self.trace_span): + try: + events = self._run_loop(messages, merged_state, structured_output_model) + + async for event in events: + event.prepare(invocation_state=merged_state) - if event.is_callback_event: - as_dict = event.as_dict() - callback_handler(**as_dict) - yield as_dict + if event.is_callback_event: + as_dict = event.as_dict() + callback_handler(**as_dict) + yield as_dict - result = AgentResult(*event["stop"]) - callback_handler(result=result) - yield AgentResultEvent(result=result).as_dict() + result = AgentResult(*event["stop"]) + callback_handler(result=result) + yield AgentResultEvent(result=result).as_dict() - self._end_agent_trace_span(response=result) + self._end_agent_trace_span(response=result) - except Exception as e: - self._end_agent_trace_span(error=e) - raise + except Exception as e: + self._end_agent_trace_span(error=e) + raise + + finally: + self._invocation_lock.release() async def _run_loop( self, diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 97485d068..bfec5886d 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -15,6 +15,7 @@ from ..tools.executors._executor import ToolExecutor from ..types._events import ToolInterruptEvent from ..types.content import ContentBlock, Message +from ..types.exceptions import ConcurrencyException from ..types.tools import ToolResult, ToolUse if TYPE_CHECKING: @@ -73,46 +74,64 @@ def 
caller( if self._agent._interrupt_state.activated: raise RuntimeError("cannot directly call tool during interrupt") - normalized_name = self._find_normalized_tool_name(name) + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call - # Create unique tool ID and set up the tool request - tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" - tool_use: ToolUse = { - "toolUseId": tool_id, - "name": normalized_name, - "input": kwargs.copy(), - } - tool_results: list[ToolResult] = [] - invocation_state = kwargs + should_lock = should_record_direct_tool_call - async def acall() -> ToolResult: - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - if isinstance(event, ToolInterruptEvent): - self._agent._interrupt_state.deactivate() - raise RuntimeError("cannot raise interrupt in direct tool call") + from ..agent import Agent # Locally imported to avoid circular reference - tool_result = tool_results[0] + acquired_lock = ( + should_lock + and isinstance(self._agent, Agent) + and self._agent._invocation_lock.acquire_lock(blocking=False) + ) + if should_lock and not acquired_lock: + raise ConcurrencyException( + "Direct tool call cannot be made while the agent is in the middle of an invocation. " + "Set record_direct_tool_call=False to allow direct tool calls during agent invocation." + ) - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call + try: + normalized_name = self._find_normalized_tool_name(name) - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - await self._record_tool_execution(tool_use, tool_result, user_message_override) + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs - return tool_result + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") + + tool_result = tool_results[0] + + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + await self._record_tool_execution(tool_use, tool_result, user_message_override) - tool_result = run_async(acall) + return tool_result - # TODO: https://github.com/strands-agents/sdk-python/issues/1311 - from ..agent import Agent + tool_result = run_async(acall) - if isinstance(self._agent, Agent): - self._agent.conversation_manager.apply_management(self._agent) + # TODO: https://github.com/strands-agents/sdk-python/issues/1311 + if isinstance(self._agent, Agent): + self._agent.conversation_manager.apply_management(self._agent) + + return tool_result - return tool_result + finally: + if acquired_lock and isinstance(self._agent, Agent): + self._agent._invocation_lock.release() return caller diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index b9c5bc769..1d1983abd 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -94,3 
+94,14 @@ def __init__(self, message: str): """ self.message = message super().__init__(message) + + +class ConcurrencyException(Exception): + """Exception raised when concurrent invocations are attempted on an agent instance. + + Agent instances maintain internal state that cannot be safely accessed concurrently. + This exception is raised when an invocation is attempted while another invocation + is already in progress on the same agent instance. + """ + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 351eadc84..81ce65989 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1,17 +1,21 @@ +import asyncio import copy import importlib import json import os import textwrap +import threading +import time import unittest.mock import warnings +from typing import Any, AsyncGenerator from uuid import uuid4 import pytest from pydantic import BaseModel import strands -from strands import Agent +from strands import Agent, ToolContext from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager @@ -24,7 +28,7 @@ from strands.telemetry.tracer import serialize from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.content import Messages -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException +from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -189,6 +193,15 @@ class User(BaseModel): return User(name="Jane Doe", age=30, email="jane@doe.com") +class SlowMockedModel(MockedModelProvider): + async def stream( + self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs + ) -> AsyncGenerator[Any, None]: + await asyncio.sleep(0.15) # Add async delay to ensure concurrency + async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs): + yield event + + def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry): _ = tool_registry @@ -2190,3 +2203,246 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): # Should not have added any toolResult messages # Only the new user message and assistant response should be added assert len(agent.messages) == original_length + 2 + + +# ============================================================================ +# Concurrency Exception Tests +# ============================================================================ + + +def test_agent_concurrent_call_raises_exception(): + """Test that concurrent __call__() calls raise ConcurrencyException.""" + model = SlowMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model) + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Create two threads that will try to invoke concurrently + t1 = threading.Thread(target=invoke) + 
t2 = threading.Thread(target=invoke) + + t1.start() + t2.start() + t1.join() + t2.join() + + # One should succeed, one should raise ConcurrencyException + assert len(results) == 1, f"Expected 1 success, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() + + +def test_agent_concurrent_structured_output_raises_exception(): + """Test that concurrent structured_output() calls raise ConcurrencyException. + + Note: This test validates that the sync invocation path is protected. + The concurrent __call__() test already validates the core functionality. + """ + model = SlowMockedModel( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + ] + ) + agent = Agent(model=model) + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Create two threads that will try to invoke concurrently + t1 = threading.Thread(target=invoke) + t2 = threading.Thread(target=invoke) + + t1.start() + time.sleep(0.05) # Small delay to ensure first thread acquires lock + t2.start() + t1.join() + t2.join() + + # One should succeed, one should raise ConcurrencyException + assert len(results) == 1, f"Expected 1 success, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() + + +@pytest.mark.asyncio +async def test_agent_sequential_invocations_work(): + """Test that sequential invocations work correctly after lock is released.""" + model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + {"role": "assistant", "content": [{"text": "response3"}]}, + ] + ) + agent = Agent(model=model) + + # All sequential calls should succeed + result1 = await agent.invoke_async("test1") + assert result1.message["content"][0]["text"] == "response1" + + result2 = await agent.invoke_async("test2") + assert result2.message["content"][0]["text"] == "response2" + + result3 = await agent.invoke_async("test3") + assert result3.message["content"][0]["text"] == "response3" + + +@pytest.mark.asyncio +async def test_agent_lock_released_on_exception(): + """Test that lock is released when an exception occurs during invocation.""" + + # Create a mock model that raises an explicit error + mock_model = unittest.mock.Mock() + + async def failing_stream(*args, **kwargs): + raise RuntimeError("Simulated model failure") + yield # Make this an async generator + + mock_model.stream = failing_stream + + agent = Agent(model=mock_model) + + # First call will fail due to the simulated error + with pytest.raises(RuntimeError, match="Simulated model failure"): + await agent.invoke_async("test") + + # Lock should be released, so this should not raise ConcurrencyException + # It will still raise RuntimeError, but that's expected + with pytest.raises(RuntimeError, match="Simulated model failure"): + await agent.invoke_async("test") + + +def test_agent_direct_tool_call_during_invocation_raises_exception(tool_decorated): + """Test that direct tool call during agent invocation raises ConcurrencyException.""" + + tool_calls = [] + + @strands.tool + def tool_to_invoke(): + 
tool_calls.append("tool_to_invoke") + return "called" + + @strands.tool(context=True) + def agent_tool(tool_context: ToolContext) -> str: + tool_context.agent.tool.tool_to_invoke(record_direct_tool_call=True) + return "tool result" + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "test-123", + "name": "agent_tool", + "input": {}, + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Done"}]}, + ] + ) + agent = Agent(model=model, tools=[agent_tool, tool_to_invoke]) + agent("Hi") + + # Tool call should have not succeeded + assert len(tool_calls) == 0 + + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [ + { + "text": "Error: ConcurrencyException - Direct tool call cannot be made while the agent is " + "in the middle of an invocation. Set record_direct_tool_call=False to allow direct tool " + "calls during agent invocation." + } + ], + "status": "error", + "toolUseId": "test-123", + } + } + ], + "role": "user", + } + + +def test_agent_direct_tool_call_during_invocation_succeeds_with_record_false(tool_decorated): + """Test that direct tool call during agent invocation succeeds when record_direct_tool_call=False.""" + tool_calls = [] + + @strands.tool + def tool_to_invoke(): + tool_calls.append("tool_to_invoke") + return "called" + + @strands.tool(context=True) + def agent_tool(tool_context: ToolContext) -> str: + tool_context.agent.tool.tool_to_invoke(record_direct_tool_call=False) + return "tool result" + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "test-123", + "name": "agent_tool", + "input": {}, + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Done"}]}, + ] + ) + agent = Agent(model=model, tools=[agent_tool, tool_to_invoke]) + agent("Hi") + + # Tool call should have succeeded + assert len(tool_calls) == 1 + + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [{"text": "tool result"}], + "status": "success", + "toolUseId": "test-123", + } + } + ], + "role": "user", + } From c098b3df9da3ef848eeb5fee066912e30c3b797d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Tue, 13 Jan 2026 11:52:27 -0500 Subject: [PATCH 20/47] fix(mcp): propagate contextvars to background thread (#1444) Fixes #1440 The MCP client creates a background thread for connection management. Previously, context variables set in the main thread were not accessible in this background thread. This change copies the context from the main thread when starting the background thread, ensuring that contextvars are properly propagated. This is consistent with the fix in PR #1146 which addressed the same issue for tool invocations. 
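For illustration, the propagation pattern in isolation — a minimal standalone
sketch with hypothetical names, not the client's actual code:

    import contextvars
    import threading

    request_id = contextvars.ContextVar("request_id", default="unset")

    def worker():
        # A freshly started Thread begins with an empty context and would
        # print "unset"; running the target inside the copied context
        # preserves the value set in the main thread.
        print(request_id.get())

    request_id.set("abc123")
    ctx = contextvars.copy_context()
    threading.Thread(target=ctx.run, args=(worker,)).start()  # prints "abc123"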
Changes: - Add contextvars import - Use contextvars.copy_context() and ctx.run() when creating background thread - Add test to verify context propagation Co-authored-by: Strands Coder --- src/strands/tools/mcp/mcp_client.py | 7 +- .../tools/mcp/test_mcp_client_contextvar.py | 88 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 tests/strands/tools/mcp/test_mcp_client_contextvar.py diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index db21b9ef2..ea11627b9 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -9,6 +9,7 @@ import asyncio import base64 +import contextvars import logging import threading import uuid @@ -179,7 +180,11 @@ def start(self) -> "MCPClient": raise MCPClientInitializationError("the client session is currently running") self._log_debug_with_thread("entering MCPClient context") - self._background_thread = threading.Thread(target=self._background_task, args=[], daemon=True) + # Copy context vars to propagate to the background thread + # This ensures that context set in the main thread is accessible in the background thread + # See: https://github.com/strands-agents/sdk-python/issues/1440 + ctx = contextvars.copy_context() + self._background_thread = threading.Thread(target=ctx.run, args=(self._background_task,), daemon=True) self._background_thread.start() self._log_debug_with_thread("background thread started, waiting for ready event") try: diff --git a/tests/strands/tools/mcp/test_mcp_client_contextvar.py b/tests/strands/tools/mcp/test_mcp_client_contextvar.py new file mode 100644 index 000000000..d95929b02 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_contextvar.py @@ -0,0 +1,88 @@ +"""Test for MCP client context variable propagation. + +This test verifies that context variables set in the main thread are +properly propagated to the MCP client's background thread. + +Related: https://github.com/strands-agents/sdk-python/issues/1440 +""" + +import contextvars +import threading +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from strands.tools.mcp import MCPClient + + +@pytest.fixture +def mock_transport(): + """Create mock MCP transport.""" + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_transport_cm = AsyncMock() + mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_transport_callable = MagicMock(return_value=mock_transport_cm) + + return { + "read_stream": mock_read_stream, + "write_stream": mock_write_stream, + "transport_cm": mock_transport_cm, + "transport_callable": mock_transport_callable, + } + + +@pytest.fixture +def mock_session(): + """Create mock MCP session.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__.return_value = mock_session + + with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + yield mock_session + + +# Context variable for testing +test_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar("test_contextvar", default="default_value") + + +def test_mcp_client_propagates_contextvars_to_background_thread(mock_transport, mock_session): + """Test that context variables are propagated to the MCP client background thread. 
+ + This verifies the fix for https://github.com/strands-agents/sdk-python/issues/1440 + where context variables set in the main thread were not accessible in the + MCP client's background thread. + """ + # Store the value seen in the background thread + background_thread_value = {} + + # Patch _background_task to capture the contextvar value + original_background_task = MCPClient._background_task + + def capturing_background_task(self): + # Capture the contextvar value in the background thread + background_thread_value["contextvar"] = test_contextvar.get() + background_thread_value["thread_id"] = threading.current_thread().ident + # Call the original background task + return original_background_task(self) + + # Set a specific value in the main thread + test_contextvar.set("main_thread_value") + main_thread_id = threading.current_thread().ident + + with patch.object(MCPClient, "_background_task", capturing_background_task): + with MCPClient(mock_transport["transport_callable"]) as client: + # Verify the client started successfully + assert client._background_thread is not None + + # Verify context was propagated to background thread + assert "contextvar" in background_thread_value, "Background task should have run and captured contextvar" + assert background_thread_value["contextvar"] == "main_thread_value", ( + f"Context variable should be propagated to background thread. " + f"Expected 'main_thread_value', got '{background_thread_value['contextvar']}'" + ) + # Verify it was indeed a different thread + assert background_thread_value["thread_id"] != main_thread_id, "Background task should run in a different thread" From 06c32974f914a0f181dcd54c33b8be690a526e9b Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 13 Jan 2026 12:28:21 -0500 Subject: [PATCH 21/47] Update to opus 4.5 (#1471) --- .github/scripts/python/agent_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/python/agent_runner.py b/.github/scripts/python/agent_runner.py index 9d92c2ac4..1f772241c 100644 --- a/.github/scripts/python/agent_runner.py +++ b/.github/scripts/python/agent_runner.py @@ -39,7 +39,7 @@ from str_replace_based_edit_tool import str_replace_based_edit_tool # Strands configuration constants -STRANDS_MODEL_ID = "global.anthropic.claude-sonnet-4-5-20250929-v1:0" +STRANDS_MODEL_ID = "global.anthropic.claude-opus-4-5-20251101-v1:0" STRANDS_MAX_TOKENS = 64000 STRANDS_BUDGET_TOKENS = 8000 STRANDS_REGION = "us-west-2" From c43dfa930e1f87b3a42dc3ba1dc09046321c865a Mon Sep 17 00:00:00 2001 From: Ratish P <114130421+Ratish1@users.noreply.github.com> Date: Thu, 15 Jan 2026 00:02:58 +0530 Subject: [PATCH 22/47] fix(mcp): prevent agent hang by checking session closure state (#1396) --- src/strands/tools/mcp/mcp_client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index ea11627b9..c36811c17 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -891,4 +891,10 @@ def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolMatcher]) -> return False def _is_session_active(self) -> bool: - return self._background_thread is not None and self._background_thread.is_alive() + if self._background_thread is None or not self._background_thread.is_alive(): + return False + + if self._close_future is not None and self._close_future.done(): + return False + + return True From 368bb0f719bed57741e77f8e2dc4aeab62bdca5f Mon Sep 17 00:00:00 2001 From: 
"dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 23:37:11 -0500 Subject: [PATCH 23/47] ci: update sphinx-rtd-theme requirement (#1466) Updates the requirements on [sphinx-rtd-theme](https://github.com/readthedocs/sphinx_rtd_theme) to permit the latest version. - [Changelog](https://github.com/readthedocs/sphinx_rtd_theme/blob/master/docs/changelog.rst) - [Commits](https://github.com/readthedocs/sphinx_rtd_theme/compare/1.0.0...3.1.0) --- updated-dependencies: - dependency-name: sphinx-rtd-theme dependency-version: 3.1.0 dependency-type: direct:development ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 62e0e04b3..7a6b02d53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ sagemaker = [ otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<10.0.0", - "sphinx-rtd-theme>=1.0.0,<2.0.0", + "sphinx-rtd-theme>=1.0.0,<4.0.0", "sphinx-autodoc-typehints>=1.12.0,<4.0.0", ] From c0298319ee00ab7c88ce7087b702a544395e1e3a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 23:39:04 -0500 Subject: [PATCH 24/47] ci: update websockets requirement (#1451) Updates the requirements on [websockets](https://github.com/python-websockets/websockets) to permit the latest version. - [Release notes](https://github.com/python-websockets/websockets/releases) - [Commits](https://github.com/python-websockets/websockets/compare/15.0...16.0) --- updated-dependencies: - dependency-name: websockets dependency-version: '16.0' dependency-type: direct:development ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7a6b02d53..aa5f773c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ bidi = [ "smithy-aws-core>=0.0.1; python_version>='3.12'", ] bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] -bidi-openai = ["websockets>=15.0.0,<16.0.0"] +bidi-openai = ["websockets>=15.0.0,<17.0.0"] all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"] From 2546aa0d3e7081b4ca6e45340789343539f81d27 Mon Sep 17 00:00:00 2001 From: Max Rabin <927792+maxrabin@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:51:09 +0200 Subject: [PATCH 25/47] style: update ruff configuration to apply pyupgrade to modernize syntax (#1336) --------- Co-authored-by: Dean Schmigelski --- pyproject.toml | 1 + src/strands/_async.py | 3 +- src/strands/agent/agent.py | 59 ++++++------- src/strands/agent/agent_result.py | 3 +- src/strands/agent/base.py | 3 +- .../conversation_manager.py | 6 +- .../null_conversation_manager.py | 4 +- .../sliding_window_conversation_manager.py | 8 +- .../summarizing_conversation_manager.py | 14 +-- src/strands/event_loop/event_loop.py | 3 +- src/strands/event_loop/streaming.py | 11 +-- src/strands/experimental/agent_config.py | 2 +- .../steering/handlers/__init__.py | 2 +- .../steering/handlers/llm/llm_handler.py | 2 +- .../experimental/tools/tool_provider.py | 3 +- src/strands/hooks/events.py | 12 +-- src/strands/hooks/registry.py | 7 +- src/strands/models/_validation.py | 5 +- src/strands/models/anthropic.py | 19 +++-- src/strands/models/bedrock.py | 77 ++++++++--------- src/strands/models/gemini.py | 29 ++++--- src/strands/models/litellm.py | 29 ++++--- src/strands/models/llamaapi.py | 27 +++--- src/strands/models/llamacpp.py | 36 ++++---- src/strands/models/mistral.py | 29 ++++--- src/strands/models/model.py | 11 +-- src/strands/models/ollama.py | 33 +++---- src/strands/models/openai.py | 29 +++---- src/strands/models/sagemaker.py | 45 +++++----- src/strands/models/writer.py | 27 +++--- src/strands/multiagent/a2a/executor.py | 6 +- src/strands/multiagent/base.py | 5 +- src/strands/multiagent/graph.py | 35 ++++---- src/strands/multiagent/swarm.py | 11 +-- src/strands/session/file_session_manager.py | 16 ++-- .../session/repository_session_manager.py | 4 +- src/strands/session/s3_session_manager.py | 26 +++--- src/strands/session/session_repository.py | 12 +-- src/strands/telemetry/metrics.py | 49 +++++------ src/strands/telemetry/tracer.py | 85 +++++++++---------- src/strands/tools/_caller.py | 3 +- src/strands/tools/decorator.py | 25 +++--- src/strands/tools/executors/_executor.py | 3 +- src/strands/tools/executors/concurrent.py | 3 +- src/strands/tools/executors/sequential.py | 3 +- src/strands/tools/loader.py | 12 +-- src/strands/tools/mcp/mcp_client.py | 22 ++--- src/strands/tools/mcp/mcp_instrumentation.py | 9 +- src/strands/tools/registry.py | 29 ++++--- .../_structured_output_context.py | 10 +-- .../structured_output_tool.py | 10 +-- .../structured_output_utils.py | 28 +++--- src/strands/tools/watcher.py | 6 +- src/strands/types/_events.py | 3 +- src/strands/types/citations.py | 22 ++--- src/strands/types/collections.py | 4 +- src/strands/types/content.py | 14 +-- src/strands/types/guardrails.py | 24 +++--- 
src/strands/types/media.py | 6 +- src/strands/types/session.py | 4 +- src/strands/types/streaming.py | 22 +++-- src/strands/types/tools.py | 16 +--- src/strands/types/traces.py | 32 +++---- tests/fixtures/mock_hook_provider.py | 7 +- .../fixtures/mock_multiagent_hook_provider.py | 7 +- tests/fixtures/mocked_model_provider.py | 19 +++-- .../strands/agent/hooks/test_hook_registry.py | 3 +- tests/strands/agent/test_agent_result.py | 4 +- .../agent/test_agent_structured_output.py | 3 +- tests/strands/models/test_sagemaker.py | 16 ++-- tests/strands/models/test_writer.py | 6 +- .../session/test_file_session_manager.py | 8 +- .../test_structured_output_context.py | 4 +- .../test_structured_output_tool.py | 5 +- tests/strands/tools/test_decorator.py | 23 ++--- tests/strands/tools/test_structured_output.py | 20 ++--- tests_integ/mcp/echo_server.py | 4 +- tests_integ/mcp/test_mcp_client.py | 6 +- tests_integ/models/providers.py | 4 +- tests_integ/test_function_tools.py | 3 +- tests_integ/test_multiagent_graph.py | 3 +- .../test_structured_output_agent_loop.py | 12 ++- 82 files changed, 626 insertions(+), 629 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aa5f773c4..b49c74d1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,6 +224,7 @@ select = [ "G", # logging format "I", # isort "LOG", # logging + "UP" # pyupgrade ] [tool.ruff.lint.per-file-ignores] diff --git a/src/strands/_async.py b/src/strands/_async.py index 141ca71b7..0ceb038f3 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -2,8 +2,9 @@ import asyncio import contextvars +from collections.abc import Awaitable, Callable from concurrent.futures import ThreadPoolExecutor -from typing import Awaitable, Callable, TypeVar +from typing import TypeVar T = TypeVar("T") diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 7126644e6..b58b55f24 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,15 +12,10 @@ import logging import threading import warnings +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, - AsyncIterator, - Callable, - Mapping, - Optional, - Type, TypeVar, Union, cast, @@ -105,26 +100,24 @@ class Agent: def __init__( self, - model: Union[Model, str, None] = None, - messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], "ToolProvider", Any]]] = None, - system_prompt: Optional[str | list[SystemContentBlock]] = None, - structured_output_model: Optional[Type[BaseModel]] = None, - callback_handler: Optional[ - Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] - ] = _DEFAULT_CALLBACK_HANDLER, - conversation_manager: Optional[ConversationManager] = None, + model: Model | str | None = None, + messages: Messages | None = None, + tools: list[Union[str, dict[str, str], "ToolProvider", Any]] | None = None, + system_prompt: str | list[SystemContentBlock] | None = None, + structured_output_model: type[BaseModel] | None = None, + callback_handler: Callable[..., Any] | _DefaultCallbackHandlerSentinel | None = _DEFAULT_CALLBACK_HANDLER, + conversation_manager: ConversationManager | None = None, record_direct_tool_call: bool = True, load_tools_from_directory: bool = False, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, *, - agent_id: Optional[str] = None, - name: Optional[str] = None, - description: Optional[str] = None, - state: Optional[Union[AgentState, 
dict]] = None, - hooks: Optional[list[HookProvider]] = None, - session_manager: Optional[SessionManager] = None, - tool_executor: Optional[ToolExecutor] = None, + agent_id: str | None = None, + name: str | None = None, + description: str | None = None, + state: AgentState | dict | None = None, + hooks: list[HookProvider] | None = None, + session_manager: SessionManager | None = None, + tool_executor: ToolExecutor | None = None, ): """Initialize the Agent with the specified configuration. @@ -190,7 +183,7 @@ def __init__( # If not provided, create a new PrintingCallbackHandler instance # If explicitly set to None, use null_callback_handler # Otherwise use the passed callback_handler - self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler] + self.callback_handler: Callable[..., Any] | PrintingCallbackHandler if isinstance(callback_handler, _DefaultCallbackHandlerSentinel): self.callback_handler = PrintingCallbackHandler() elif callback_handler is None: @@ -227,7 +220,7 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() - self.trace_span: Optional[trace_api.Span] = None + self.trace_span: trace_api.Span | None = None # Initialize agent state management if state is not None: @@ -325,7 +318,7 @@ def __call__( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -366,7 +359,7 @@ async def invoke_async( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -403,7 +396,7 @@ async def invoke_async( return cast(AgentResult, event["result"]) - def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: + def structured_output(self, output_model: type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -434,7 +427,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> return run_async(lambda: self.structured_output_async(output_model, prompt)) - async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: + async def structured_output_async(self, output_model: type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -529,7 +522,7 @@ async def stream_async( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. 
@@ -632,7 +625,7 @@ async def _run_loop( self, messages: Messages, invocation_state: dict[str, Any], - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. @@ -794,8 +787,8 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: def _end_agent_trace_span( self, - response: Optional[AgentResult] = None, - error: Optional[Exception] = None, + response: AgentResult | None = None, + error: Exception | None = None, ) -> None: """Ends a trace span for the agent. diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index ef8a11029..2ab95e5b5 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -3,8 +3,9 @@ This module defines the AgentResult class which encapsulates the complete response from an agent's processing cycle. """ +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Sequence, cast +from typing import Any, cast from pydantic import BaseModel diff --git a/src/strands/agent/base.py b/src/strands/agent/base.py index b35ade8c4..ae8a14e75 100644 --- a/src/strands/agent/base.py +++ b/src/strands/agent/base.py @@ -3,7 +3,8 @@ Defines the minimal interface that all agent types must implement. """ -from typing import Any, AsyncIterator, Protocol, runtime_checkable +from collections.abc import AsyncIterator +from typing import Any, Protocol, runtime_checkable from ..types.agent import AgentInput from .agent_result import AgentResult diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 47b761abc..690ecbde5 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,7 +1,7 @@ """Abstract interface for conversation history management.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ...hooks.registry import HookProvider, HookRegistry from ...types.content import Message @@ -62,7 +62,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """ pass - def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: + def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restore the Conversation Manager's state from a session. Args: @@ -98,7 +98,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: pass @abstractmethod - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Called when the model's context window is exceeded. This method should implement the specific strategy for reducing the window size when a context overflow occurs. 
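The hunks in this patch are mechanical applications of the ruff UP rules. As a
minimal before/after sketch of the pattern (hypothetical function, for
illustration only):

    # Before: pre-PEP 585/604 spellings imported from typing
    from typing import Dict, List, Optional

    def lookup(keys: List[str], table: Dict[str, int]) -> Optional[int]: ...

    # After: the builtin generics and union syntax pyupgrade rewrites to
    def lookup(keys: list[str], table: dict[str, int]) -> int | None: ...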
diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 5ff6874e5..11632525d 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -1,6 +1,6 @@ """Null implementation of conversation management.""" -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...agent.agent import Agent @@ -28,7 +28,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """ pass - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Does not reduce context and raises an exception. Args: diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index a063e55eb..709c876e7 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,7 +1,7 @@ """Sliding window conversation history management.""" import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...agent.agent import Agent @@ -103,7 +103,7 @@ def get_state(self) -> dict[str, Any]: state["model_call_count"] = self._model_call_count return state - def restore_from_session(self, state: dict[str, Any]) -> Optional[list]: + def restore_from_session(self, state: dict[str, Any]) -> list | None: """Restore the conversation manager's state from a session. Args: @@ -136,7 +136,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: return self.reduce_context(agent) - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Trim the oldest messages to reduce the conversation context size. The method handles special cases where trimming the messages leads to: @@ -235,7 +235,7 @@ def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: return changes_made - def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]: + def _find_last_message_with_tool_results(self, messages: Messages) -> int | None: """Find the index of the last message containing tool results. This is useful for identifying messages that might need to be truncated to reduce context size. 
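The List/Dict/Tuple/Type removals in the surrounding files follow PEP 585
(Python 3.9+): the builtins list, dict, tuple, and type are subscriptable
directly, and the capitalized typing aliases are deprecated. A toy sketch
of the resulting style (hypothetical helpers, not SDK code):

    from typing import TypeVar

    T = TypeVar("T")

    def keep_recent(messages: list[dict[str, str]], limit: int | None = None) -> list[dict[str, str]]:
        # Trailing window over a message list; None keeps everything.
        return list(messages) if limit is None else messages[-limit:]

    def build(model: type[T], **fields: str) -> T:
        # type[T] accepts the class object itself, mirroring the
        # structured_output(output_model=...) signatures in this patch.
        return model(**fields)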
diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 12185c286..cc71e4d88 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -1,7 +1,7 @@ """Summarizing conversation history management with configurable options.""" import logging -from typing import TYPE_CHECKING, Any, List, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from typing_extensions import override @@ -62,7 +62,7 @@ def __init__( summary_ratio: float = 0.3, preserve_recent_messages: int = 10, summarization_agent: Optional["Agent"] = None, - summarization_system_prompt: Optional[str] = None, + summarization_system_prompt: str | None = None, ): """Initialize the summarizing conversation manager. @@ -87,10 +87,10 @@ def __init__( self.preserve_recent_messages = preserve_recent_messages self.summarization_agent = summarization_agent self.summarization_system_prompt = summarization_system_prompt - self._summary_message: Optional[Message] = None + self._summary_message: Message | None = None @override - def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: + def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restores the Summarizing Conversation manager from its previous state in a session. Args: @@ -121,7 +121,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: # No proactive management - summarization only happens on context overflow pass - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Reduce context using summarization. Args: @@ -173,7 +173,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs logger.error("Summarization failed: %s", summarization_error) raise summarization_error from e - def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: + def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. Args: @@ -224,7 +224,7 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: summarization_agent.messages = original_messages summarization_agent.tool_registry = original_tool_registry - def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: + def _adjust_split_point_for_tool_pairs(self, messages: list[Message], split_point: int) -> int: """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. Uses the same logic as SlidingWindowConversationManager for consistency. 
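The import moves that follow (event_loop.py, streaming.py, and most model
providers) are the container-ABC half of the same cleanup: since Python 3.9
the collections.abc types such as AsyncGenerator, AsyncIterable, Sequence,
Callable, and Mapping are subscriptable, and their typing counterparts are
deprecated aliases. A standalone example of the target style (not taken
from the SDK):

    from collections.abc import AsyncGenerator

    async def stream_chunks(text: str, size: int = 8) -> AsyncGenerator[str, None]:
        # Yield text in fixed-size chunks, mimicking a token stream.
        for start in range(0, len(text), size):
            yield text[start : start + size]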
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 231cfa56a..99c8f5179 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -11,7 +11,8 @@ import asyncio import logging import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from opentelemetry import trace as trace_api diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 7840bfcef..954633807 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -4,7 +4,8 @@ import logging import time import warnings -from typing import Any, AsyncGenerator, AsyncIterable, Optional +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any from ..models.model import Model from ..tools import InvalidToolUseNameException @@ -419,13 +420,13 @@ async def process_stream( async def stream_messages( model: Model, - system_prompt: Optional[str], + system_prompt: str | None, messages: Messages, tool_specs: list[ToolSpec], *, - tool_choice: Optional[Any] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, - invocation_state: Optional[dict[str, Any]] = None, + tool_choice: Any | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index f65afb57d..e6fb94118 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -98,7 +98,7 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A if not config_path.exists(): raise FileNotFoundError(f"Configuration file not found: {file_path}") - with open(config_path, "r") as f: + with open(config_path) as f: config_dict = json.load(f) elif isinstance(config, dict): config_dict = config.copy() diff --git a/src/strands/experimental/steering/handlers/__init__.py b/src/strands/experimental/steering/handlers/__init__.py index 542126ab5..fe364a5a2 100644 --- a/src/strands/experimental/steering/handlers/__init__.py +++ b/src/strands/experimental/steering/handlers/__init__.py @@ -1,5 +1,5 @@ """Steering handler implementations.""" -from typing import Sequence +from collections.abc import Sequence __all__: Sequence[str] = [] diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index 9d9b34911..4d90f46c9 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -58,7 +58,7 @@ def __init__( self.prompt_mapper = prompt_mapper or DefaultPromptMapper() self.model = model - async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: + async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> SteeringAction: """Provide contextual guidance for tool usage. 
Args: diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 2c79ceafc..c40d1b572 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -1,7 +1,8 @@ """Tool provider interface.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...types.tools import AgentTool diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 5e11524d1..340b6d3d2 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -5,7 +5,7 @@ import uuid from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -116,7 +116,7 @@ class BeforeToolCallEvent(HookEvent, _Interruptible): the tool call and use a default cancel message. """ - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] cancel_tool: bool | str = False @@ -157,11 +157,11 @@ class AfterToolCallEvent(HookEvent): cancel_message: The cancellation message if the user cancelled the tool call. """ - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] result: ToolResult - exception: Optional[Exception] = None + exception: Exception | None = None cancel_message: str | None = None def _can_write(self, name: str) -> bool: @@ -232,8 +232,8 @@ class ModelStopResponse: message: Message stop_reason: StopReason - stop_response: Optional[ModelStopResponse] = None - exception: Optional[Exception] = None + stop_response: ModelStopResponse | None = None + exception: Exception | None = None retry: bool = False def _can_write(self, name: str) -> bool: diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 9edf7ffa7..309e3ba76 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,8 +9,9 @@ import inspect import logging +from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable from ..interrupt import Interrupt, InterruptException @@ -154,9 +155,9 @@ class HookRegistry: def __init__(self) -> None: """Initialize an empty hook registry.""" - self._registered_callbacks: dict[Type, list[HookCallback]] = {} + self._registered_callbacks: dict[type, list[HookCallback]] = {} - def add_callback(self, event_type: Type[TEvent], callback: HookCallback[TEvent]) -> None: + def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: """Register a callback function for a specific event type. 
Args: diff --git a/src/strands/models/_validation.py b/src/strands/models/_validation.py index 9eabe28a1..1e82bca73 100644 --- a/src/strands/models/_validation.py +++ b/src/strands/models/_validation.py @@ -1,14 +1,15 @@ """Configuration validation utilities for model providers.""" import warnings -from typing import Any, Mapping, Type +from collections.abc import Mapping +from typing import Any from typing_extensions import get_type_hints from ..types.tools import ToolChoice -def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None: +def validate_config_keys(config_dict: Mapping[str, Any], config_class: type) -> None: """Validate that config keys match the TypedDict fields. Args: diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 68b234729..535c820ee 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import anthropic from pydantic import BaseModel @@ -59,9 +60,9 @@ class AnthropicConfig(TypedDict, total=False): max_tokens: Required[int] model_id: Required[str] - params: Optional[dict[str, Any]] + params: dict[str, Any] | None - def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[AnthropicConfig]): + def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]): """Initialize provider instance. Args: @@ -198,8 +199,8 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an Anthropic streaming request. @@ -369,8 +370,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -419,8 +420,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. 
Args: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 8e1558ca7..dfcd133c6 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -8,7 +8,8 @@ import logging import os import warnings -from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, ValuesView, cast +from collections.abc import AsyncGenerator, Callable, Iterable, ValuesView +from typing import Any, Literal, TypeVar, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -94,35 +95,35 @@ class BedrockConfig(TypedDict, total=False): top_p: Controls diversity via nucleus sampling (alternative to temperature) """ - additional_args: Optional[dict[str, Any]] - additional_request_fields: Optional[dict[str, Any]] - additional_response_field_paths: Optional[list[str]] - cache_prompt: Optional[str] - cache_tools: Optional[str] - guardrail_id: Optional[str] - guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]] - guardrail_stream_processing_mode: Optional[Literal["sync", "async"]] - guardrail_version: Optional[str] - guardrail_redact_input: Optional[bool] - guardrail_redact_input_message: Optional[str] - guardrail_redact_output: Optional[bool] - guardrail_redact_output_message: Optional[str] - guardrail_latest_message: Optional[bool] - max_tokens: Optional[int] + additional_args: dict[str, Any] | None + additional_request_fields: dict[str, Any] | None + additional_response_field_paths: list[str] | None + cache_prompt: str | None + cache_tools: str | None + guardrail_id: str | None + guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None + guardrail_stream_processing_mode: Literal["sync", "async"] | None + guardrail_version: str | None + guardrail_redact_input: bool | None + guardrail_redact_input_message: str | None + guardrail_redact_output: bool | None + guardrail_redact_output_message: str | None + guardrail_latest_message: bool | None + max_tokens: int | None model_id: str - include_tool_result_status: Optional[Literal["auto"] | bool] - stop_sequences: Optional[list[str]] - streaming: Optional[bool] - temperature: Optional[float] - top_p: Optional[float] + include_tool_result_status: Literal["auto"] | bool | None + stop_sequences: list[str] | None + streaming: bool | None + temperature: float | None + top_p: float | None def __init__( self, *, - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, - region_name: Optional[str] = None, - endpoint_url: Optional[str] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, + endpoint_url: str | None = None, **model_config: Unpack[BedrockConfig], ): """Initialize provider instance. @@ -193,8 +194,8 @@ def get_config(self) -> BedrockConfig: def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format a Bedrock converse stream request. 
@@ -603,11 +604,11 @@ def _generate_redaction_events(self) -> list[StreamEvent]: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Bedrock model. @@ -631,13 +632,13 @@ async def stream( ModelThrottledException: If the model service is throttling requests. """ - def callback(event: Optional[StreamEvent] = None) -> None: + def callback(event: StreamEvent | None = None) -> None: loop.call_soon_threadsafe(queue.put_nowait, event) if event is None: return loop = asyncio.get_event_loop() - queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue() + queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue() # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None if system_prompt and system_prompt_content is None: @@ -659,8 +660,8 @@ def _stream( self, callback: Callable[..., None], messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, tool_choice: ToolChoice | None = None, ) -> None: """Stream conversation with the Bedrock model. @@ -913,11 +914,11 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: @override async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: @@ -962,7 +963,7 @@ async def structured_output( yield {"output": output_model(**output_response)} @staticmethod - def _get_default_model_with_warning(region_name: str, model_config: Optional[BedrockConfig] = None) -> str: + def _get_default_model_with_warning(region_name: str, model_config: BedrockConfig | None = None) -> str: """Get the default Bedrock modelId based on region. If the region is not **known** to support inference then we show a helpful warning diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 45f7f4e18..52d45b649 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -6,7 +6,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import pydantic from google import genai @@ -54,8 +55,8 @@ class GeminiConfig(TypedDict, total=False): def __init__( self, *, - client: Optional[genai.Client] = None, - client_args: Optional[dict[str, Any]] = None, + client: genai.Client | None = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[GeminiConfig], ) -> None: """Initialize provider instance. 
@@ -219,7 +220,7 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten for message in messages ] - def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[genai.types.Tool | Any]: + def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai.types.Tool | Any]: """Format tool specs into Gemini tools. - Docs: https://googleapis.github.io/python-genai/genai.html#genai.types.Tool @@ -248,9 +249,9 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge def _format_request_config( self, - tool_specs: Optional[list[ToolSpec]], - system_prompt: Optional[str], - params: Optional[dict[str, Any]], + tool_specs: list[ToolSpec] | None, + system_prompt: str | None, + params: dict[str, Any] | None, ) -> genai.types.GenerateContentConfig: """Format Gemini request config. @@ -273,9 +274,9 @@ def _format_request_config( def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]], - system_prompt: Optional[str], - params: Optional[dict[str, Any]], + tool_specs: list[ToolSpec] | None, + system_prompt: str | None, + params: dict[str, Any] | None, ) -> dict[str, Any]: """Format a Gemini streaming request. @@ -394,8 +395,8 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: @@ -483,8 +484,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model using Gemini's native structured output. - Docs: https://ai.google.dev/gemini-api/docs/structured-output diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index c120b0eda..ae71cc668 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -5,7 +5,8 @@ import json import logging -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import litellm from litellm.exceptions import ContextWindowExceededError @@ -42,9 +43,9 @@ class LiteLLMConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None: + def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Unpack[LiteLLMConfig]) -> None: """Initialize provider instance. Args: @@ -137,9 +138,9 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> @classmethod def _format_system_messages( cls, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format system messages for LiteLLM with cache point support. 
@@ -174,9 +175,9 @@ def _format_system_messages( def format_request_messages( cls, messages: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format a LiteLLM compatible messages array with cache point support. @@ -243,11 +244,11 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -295,8 +296,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Some models do not support native structured output via response_format. @@ -322,7 +323,7 @@ async def structured_output( yield {"output": result} async def _structured_output_using_response_schema( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None ) -> T: """Get structured output using native response_format support.""" response = await litellm.acompletion( @@ -350,7 +351,7 @@ async def _structured_output_using_response_schema( raise ValueError(f"Failed to parse or load content into model: {e}") from e async def _structured_output_using_tool( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None ) -> T: """Get structured output using tool calling fallback.""" tool_spec = convert_pydantic_to_tool_spec(output_model) diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 013cd2c7d..ce0367bf5 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,7 +8,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import llama_api_client from llama_api_client import LlamaAPIClient @@ -43,16 +44,16 @@ class LlamaConfig(TypedDict, total=False): """ model_id: str - repetition_penalty: Optional[float] - temperature: Optional[float] - top_p: Optional[float] - max_completion_tokens: Optional[int] - top_k: Optional[int] + repetition_penalty: float | None + temperature: float | None + top_p: float | None + max_completion_tokens: int | None + top_k: int | None def __init__( self, *, - client_args: Optional[dict[str, Any]] = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[LlamaConfig], ) -> None: """Initialize provider instance. 
@@ -159,7 +160,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "content": [self._format_request_message_content(content) for content in contents], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a LlamaAPI compatible messages array. Args: @@ -206,7 +207,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format a Llama API chat streaming request. @@ -328,8 +329,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -416,8 +417,8 @@ async def stream( @override def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 22a3a3873..ca838f3d7 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -14,15 +14,11 @@ import logging import mimetypes import time +from collections.abc import AsyncGenerator from typing import ( Any, - AsyncGenerator, - Dict, - Optional, - Type, TypedDict, TypeVar, - Union, cast, ) @@ -133,12 +129,12 @@ class LlamaCppConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None def __init__( self, base_url: str = "http://localhost:8080", - timeout: Optional[Union[float, tuple[float, float]]] = None, + timeout: float | tuple[float, float] | None = None, **model_config: Unpack[LlamaCppConfig], ) -> None: """Initialize llama.cpp provider instance. @@ -196,7 +192,7 @@ def get_config(self) -> LlamaCppConfig: """ return self.config # type: ignore[return-value] - def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) -> dict[str, Any]: + def _format_message_content(self, content: ContentBlock | dict[str, Any]) -> dict[str, Any]: """Format a content block for llama.cpp. 
Args: @@ -233,7 +229,7 @@ def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) # Handle audio content (not in standard ContentBlock but supported by llama.cpp) if "audio" in content: - audio_content = cast(Dict[str, Any], content) + audio_content = cast(dict[str, Any], content) audio_data = base64.b64encode(audio_content["audio"]["source"]["bytes"]).decode("utf-8") audio_format = audio_content["audio"].get("format", "wav") return { @@ -284,7 +280,7 @@ def _format_tool_message(self, tool_result: dict[str, Any]) -> dict[str, Any]: "content": [self._format_message_content(content) for content in contents], } - def _format_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format messages for llama.cpp. Args: @@ -343,8 +339,8 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, ) -> dict[str, Any]: """Format a request for the llama.cpp server. @@ -428,7 +424,7 @@ def _format_request( request[param] = value # Collect llama.cpp-specific parameters for extra_body - extra_body: Dict[str, Any] = {} + extra_body: dict[str, Any] = {} for param, value in params.items(): if param in llamacpp_specific_params: extra_body[param] = value @@ -511,8 +507,8 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -552,7 +548,7 @@ async def stream( yield self._format_chunk({"chunk_type": "message_start"}) yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) - tool_calls: Dict[int, list] = {} + tool_calls: dict[int, list] = {} usage_data = None finish_reason = None @@ -706,11 +702,11 @@ async def stream( @override async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output using llama.cpp's native JSON schema support. 
This implementation uses llama.cpp's json_schema parameter to constrain @@ -753,7 +749,7 @@ async def structured_output( if "text" in delta: response_text += delta["text"] # Forward events to caller - yield cast(Dict[str, Union[T, Any]], event) + yield cast(dict[str, T | Any], event) # Parse and validate the JSON response data = json.loads(response_text.strip()) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index b6459d63f..4ec77ccfe 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -6,7 +6,8 @@ import base64 import json import logging -from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union +from collections.abc import AsyncGenerator, Iterable +from typing import Any, TypeVar import mistralai from pydantic import BaseModel @@ -47,16 +48,16 @@ class MistralConfig(TypedDict, total=False): """ model_id: str - max_tokens: Optional[int] - temperature: Optional[float] - top_p: Optional[float] - stream: Optional[bool] + max_tokens: int | None + temperature: float | None + top_p: float | None + stream: bool | None def __init__( self, - api_key: Optional[str] = None, + api_key: str | None = None, *, - client_args: Optional[dict[str, Any]] = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[MistralConfig], ) -> None: """Initialize provider instance. @@ -115,7 +116,7 @@ def get_config(self) -> MistralConfig: """ return self.config - def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]: + def _format_request_message_content(self, content: ContentBlock) -> str | dict[str, Any]: """Format a Mistral content block. Args: @@ -187,7 +188,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "tool_call_id": tool_result["toolUseId"], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a Mistral compatible messages array. Args: @@ -236,7 +237,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return formatted_messages def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format a Mistral chat streaming request. @@ -395,8 +396,8 @@ def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, An async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -502,8 +503,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. 
Args: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index 6b7dd78d7..e6630f807 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -2,7 +2,8 @@ import abc import logging -from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any, TypeVar from pydantic import BaseModel @@ -45,8 +46,8 @@ def get_config(self) -> Any: @abc.abstractmethod # pragma: no cover def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: @@ -68,8 +69,8 @@ def structured_output( def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, system_prompt_content: list[SystemContentBlock] | None = None, diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 574b24200..8d72aa534 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,7 +5,8 @@ import json import logging -from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import ollama from pydantic import BaseModel @@ -46,20 +47,20 @@ class OllamaConfig(TypedDict, total=False): top_p: Controls diversity via nucleus sampling (alternative to temperature). """ - additional_args: Optional[dict[str, Any]] - keep_alive: Optional[str] - max_tokens: Optional[int] + additional_args: dict[str, Any] | None + keep_alive: str | None + max_tokens: int | None model_id: str - options: Optional[dict[str, Any]] - stop_sequences: Optional[list[str]] - temperature: Optional[float] - top_p: Optional[float] + options: dict[str, Any] | None + stop_sequences: list[str] | None + temperature: float | None + top_p: float | None def __init__( self, - host: Optional[str], + host: str | None, *, - ollama_client_args: Optional[dict[str, Any]] = None, + ollama_client_args: dict[str, Any] | None = None, **model_config: Unpack[OllamaConfig], ) -> None: """Initialize provider instance. @@ -147,7 +148,7 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format an Ollama compatible messages array. Args: @@ -167,7 +168,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s ] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format an Ollama chat streaming request. 
@@ -285,8 +286,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -339,8 +340,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index c381201e4..d9266212b 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -7,8 +7,9 @@ import json import logging import mimetypes +from collections.abc import AsyncGenerator, AsyncIterator from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, AsyncIterator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast +from typing import Any, Protocol, TypedDict, TypeVar, cast import openai from openai.types.chat.parsed_chat_completion import ParsedChatCompletion @@ -54,12 +55,12 @@ class OpenAIConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None def __init__( self, - client: Optional[Client] = None, - client_args: Optional[dict[str, Any]] = None, + client: Client | None = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIConfig], ) -> None: """Initialize provider instance. @@ -201,9 +202,7 @@ def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> } @classmethod - def _split_tool_message_images( - cls, tool_message: dict[str, Any] - ) -> tuple[dict[str, Any], Optional[dict[str, Any]]]: + def _split_tool_message_images(cls, tool_message: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]: """Split a tool message into text-only tool message and optional user message with images. OpenAI API restricts images to user role messages only. This method extracts any image @@ -291,9 +290,9 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str @classmethod def _format_system_messages( cls, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format system messages for OpenAI-compatible providers. @@ -374,9 +373,9 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic def format_request_messages( cls, messages: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format an OpenAI compatible messages array. 
@@ -549,8 +548,8 @@ async def _get_client(self) -> AsyncIterator[Any]:
     async def stream(
         self,
         messages: Messages,
-        tool_specs: Optional[list[ToolSpec]] = None,
-        system_prompt: Optional[str] = None,
+        tool_specs: list[ToolSpec] | None = None,
+        system_prompt: str | None = None,
         *,
         tool_choice: ToolChoice | None = None,
         **kwargs: Any,
@@ -679,8 +678,8 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) ->
 
     @override
     async def structured_output(
-        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
-    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, T | Any], None]:
         """Get structured output from the model.
 
         Args:
diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py
index 1fe630fdc..775969290 100644
--- a/src/strands/models/sagemaker.py
+++ b/src/strands/models/sagemaker.py
@@ -3,8 +3,9 @@
 import json
 import logging
 import os
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass
-from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union
+from typing import Any, Literal, TypedDict, TypeVar
 
 import boto3
 from botocore.config import Config as BotocoreConfig
@@ -37,7 +38,7 @@ class UsageMetadata:
     total_tokens: int
     completion_tokens: int
     prompt_tokens: int
-    prompt_tokens_details: Optional[int] = 0
+    prompt_tokens_details: int | None = 0
 
 
 @dataclass
@@ -49,8 +50,8 @@ class FunctionCall:
         arguments: Arguments to pass to the function
     """
 
-    name: Union[str, dict[Any, Any]]
-    arguments: Union[str, dict[Any, Any]]
+    name: str | dict[Any, Any]
+    arguments: str | dict[Any, Any]
 
     def __init__(self, **kwargs: dict[str, str]):
         """Initialize function call.
@@ -108,12 +109,12 @@ class SageMakerAIPayloadSchema(TypedDict, total=False):
 
         max_tokens: int
         stream: bool
-        temperature: Optional[float]
-        top_p: Optional[float]
-        top_k: Optional[int]
-        stop: Optional[list[str]]
-        tool_results_as_user_messages: Optional[bool]
-        additional_args: Optional[dict[str, Any]]
+        temperature: float | None
+        top_p: float | None
+        top_k: int | None
+        stop: list[str] | None
+        tool_results_as_user_messages: bool | None
+        additional_args: dict[str, Any] | None
 
     class SageMakerAIEndpointConfig(TypedDict, total=False):
         """Configuration options for SageMaker models.
@@ -127,17 +128,17 @@ class SageMakerAIEndpointConfig(TypedDict, total=False):
 
         endpoint_name: str
         region_name: str
-        inference_component_name: Union[str, None]
-        target_model: Union[Optional[str], None]
-        target_variant: Union[Optional[str], None]
-        additional_args: Optional[dict[str, Any]]
+        inference_component_name: str | None
+        target_model: str | None
+        target_variant: str | None
+        additional_args: dict[str, Any] | None
 
     def __init__(
         self,
         endpoint_config: SageMakerAIEndpointConfig,
         payload_config: SageMakerAIPayloadSchema,
-        boto_session: Optional[boto3.Session] = None,
-        boto_client_config: Optional[BotocoreConfig] = None,
+        boto_session: boto3.Session | None = None,
+        boto_client_config: BotocoreConfig | None = None,
     ):
         """Initialize provider instance.
@@ -199,8 +200,8 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> dict[str, Any]: @@ -300,8 +301,8 @@ def format_request( async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -572,8 +573,8 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index a54fc44c3..f306d649b 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import writerai from pydantic import BaseModel @@ -41,13 +42,13 @@ class WriterConfig(TypedDict, total=False): """ model_id: str - max_tokens: Optional[int] - stop: Optional[Union[str, List[str]]] - stream_options: Dict[str, Any] - temperature: Optional[float] - top_p: Optional[float] + max_tokens: int | None + stop: str | list[str] | None + stream_options: dict[str, Any] + temperature: float | None + top_p: float | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[WriterConfig]): + def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Unpack[WriterConfig]): """Initialize provider instance. Args: @@ -201,7 +202,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "content": formatted_contents, } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a Writer compatible messages array. Args: @@ -245,7 +246,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> Any: """Format a streaming request to the underlying model. 
@@ -353,8 +354,8 @@ def format_chunk(self, event: Any) -> StreamEvent:
     async def stream(
         self,
         messages: Messages,
-        tool_specs: Optional[list[ToolSpec]] = None,
-        system_prompt: Optional[str] = None,
+        tool_specs: list[ToolSpec] | None = None,
+        system_prompt: str | None = None,
         *,
         tool_choice: ToolChoice | None = None,
         **kwargs: Any,
@@ -431,8 +432,8 @@ async def stream(
 
     @override
     async def structured_output(
-        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
-    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, T | Any], None]:
         """Get structured output from the model.
 
         Args:
diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py
index 52b6d2ef1..f02b8c6cc 100644
--- a/src/strands/multiagent/a2a/executor.py
+++ b/src/strands/multiagent/a2a/executor.py
@@ -313,15 +313,13 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten
                 elif uri_data:
                     # For URI files, create a text representation since Strands ContentBlocks expect bytes
                     content_blocks.append(
-                        ContentBlock(
-                            text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data)
-                        )
+                        ContentBlock(text=f"[File: {file_name} ({mime_type})] - Referenced file at: {uri_data}")
                     )
             elif isinstance(part_root, DataPart):
                 # Handle DataPart - convert structured data to JSON text
                 try:
                     data_text = json.dumps(part_root.data, indent=2)
-                    content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text))
+                    content_blocks.append(ContentBlock(text=f"[Structured Data]\n{data_text}"))
                 except Exception:
                     logger.exception("Failed to serialize data part")
         except Exception:
diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py
index f163d05b5..dc3258f68 100644
--- a/src/strands/multiagent/base.py
+++ b/src/strands/multiagent/base.py
@@ -6,9 +6,10 @@
 import logging
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator, Mapping
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, AsyncIterator, Mapping, Union
+from typing import Any, Union
 
 from .._async import run_async
 from ..agent import AgentResult
@@ -95,7 +96,7 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
             raise TypeError("NodeResult.from_dict: missing 'result'")
 
         raw = data["result"]
-        result: Union[AgentResult, "MultiAgentResult", Exception]
+        result: AgentResult | MultiAgentResult | Exception
         if isinstance(raw, dict) and raw.get("type") == "agent_result":
             result = AgentResult.from_dict(raw)
         elif isinstance(raw, dict) and raw.get("type") == "exception":
diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py
index 6156d332c..19504ad73 100644
--- a/src/strands/multiagent/graph.py
+++ b/src/strands/multiagent/graph.py
@@ -18,8 +18,9 @@
 import copy
 import logging
 import time
+from collections.abc import AsyncIterator, Callable, Mapping
 from dataclasses import dataclass, field
-from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast
+from typing import Any, cast
 
 from opentelemetry import trace as trace_api
 
@@ -90,14 +91,14 @@ class GraphState:
 
     # Graph structure info
     total_nodes: int = 0
-    edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list)
+    edges: list[tuple["GraphNode", "GraphNode"]] = field(default_factory=list)
     entry_points: list["GraphNode"] = field(default_factory=list)
 
     def should_continue(
         self,
-        max_node_executions: Optional[int],
-        execution_timeout: Optional[float],
-    ) -> Tuple[bool, str]:
+        max_node_executions: int | None,
+        execution_timeout: float | None,
+    ) -> tuple[bool, str]:
         """Check if the graph should continue execution.
 
         Returns: (should_continue, reason)
@@ -123,7 +124,7 @@ class GraphResult(MultiAgentResult):
     completed_nodes: int = 0
     failed_nodes: int = 0
     execution_order: list["GraphNode"] = field(default_factory=list)
-    edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list)
+    edges: list[tuple["GraphNode", "GraphNode"]] = field(default_factory=list)
     entry_points: list["GraphNode"] = field(default_factory=list)
 
 
@@ -233,13 +234,13 @@ def __init__(self) -> None:
         self.entry_points: set[GraphNode] = set()
 
         # Configuration options
-        self._max_node_executions: Optional[int] = None
-        self._execution_timeout: Optional[float] = None
-        self._node_timeout: Optional[float] = None
+        self._max_node_executions: int | None = None
+        self._execution_timeout: float | None = None
+        self._node_timeout: float | None = None
         self._reset_on_revisit: bool = False
         self._id: str = _DEFAULT_GRAPH_ID
-        self._session_manager: Optional[SessionManager] = None
-        self._hooks: Optional[list[HookProvider]] = None
+        self._session_manager: SessionManager | None = None
+        self._hooks: list[HookProvider] | None = None
 
     def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
         """Add an Agent or MultiAgentBase instance as a node to the graph."""
@@ -408,14 +409,14 @@ def __init__(
         nodes: dict[str, GraphNode],
         edges: set[GraphEdge],
         entry_points: set[GraphNode],
-        max_node_executions: Optional[int] = None,
-        execution_timeout: Optional[float] = None,
-        node_timeout: Optional[float] = None,
+        max_node_executions: int | None = None,
+        execution_timeout: float | None = None,
+        node_timeout: float | None = None,
         reset_on_revisit: bool = False,
-        session_manager: Optional[SessionManager] = None,
-        hooks: Optional[list[HookProvider]] = None,
+        session_manager: SessionManager | None = None,
+        hooks: list[HookProvider] | None = None,
         id: str = _DEFAULT_GRAPH_ID,
-        trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
+        trace_attributes: Mapping[str, AttributeValue] | None = None,
     ) -> None:
         """Initialize Graph with execution limits and reset behavior.
diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py
index 7eec49649..6c1149624 100644
--- a/src/strands/multiagent/swarm.py
+++ b/src/strands/multiagent/swarm.py
@@ -18,8 +18,9 @@
 import json
 import logging
 import time
+from collections.abc import AsyncIterator, Callable, Mapping
 from dataclasses import dataclass, field
-from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast
+from typing import Any, Optional, cast
 
 from opentelemetry import trace as trace_api
 
@@ -184,7 +185,7 @@ def should_continue(
         execution_timeout: float,
         repetitive_handoff_detection_window: int,
         repetitive_handoff_min_unique_agents: int,
-    ) -> Tuple[bool, str]:
+    ) -> tuple[bool, str]:
         """Check if the swarm should continue.
Returns: (should_continue, reason) @@ -239,10 +240,10 @@ def __init__( node_timeout: float = 300.0, repetitive_handoff_detection_window: int = 0, repetitive_handoff_min_unique_agents: int = 0, - session_manager: Optional[SessionManager] = None, - hooks: Optional[list[HookProvider]] = None, + session_manager: SessionManager | None = None, + hooks: list[HookProvider] | None = None, id: str = _DEFAULT_SWARM_ID, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> None: """Initialize Swarm with agents and configuration. diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index fc80fc520..0b25d4b5d 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -5,7 +5,7 @@ import os import shutil import tempfile -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from .. import _identifier from ..types.exceptions import SessionException @@ -44,7 +44,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): def __init__( self, session_id: str, - storage_dir: Optional[str] = None, + storage_dir: str | None = None, **kwargs: Any, ): """Initialize FileSession with filesystem storage. @@ -108,7 +108,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> def _read_file(self, path: str) -> dict[str, Any]: """Read JSON file.""" try: - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return cast(dict[str, Any], json.load(f)) except json.JSONDecodeError as e: raise SessionException(f"Invalid JSON in file {path}: {str(e)}") from e @@ -140,7 +140,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: return session - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read session data.""" session_file = os.path.join(self._get_session_path(session_id), "session.json") if not os.path.exists(session_file): @@ -169,7 +169,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A session_data = session_agent.to_dict() self._write_file(agent_file, session_data) - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read agent data.""" agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json") if not os.path.exists(agent_file): @@ -199,7 +199,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio session_dict = session_message.to_dict() self._write_file(message_file, session_dict) - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read message data.""" message_path = self._get_message_path(session_id, agent_id, message_id) if not os.path.exists(message_path): @@ -220,7 +220,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio self._write_file(message_file, session_message.to_dict()) def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + self, session_id: str, agent_id: str, limit: int 
| None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List messages for an agent with pagination.""" messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages") @@ -269,7 +269,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k session_data = multi_agent.serialize_state() self._write_file(multi_agent_file, session_data) - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read multi-agent state from filesystem.""" multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") if not os.path.exists(multi_agent_file): diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index a8ac099d9..d23c4a94f 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -1,7 +1,7 @@ """Repository session manager implementation.""" import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..agent.state import AgentState from ..tools._tool_helpers import generate_missing_tool_result_content @@ -57,7 +57,7 @@ def __init__( self.session = session # Keep track of the latest message of each agent in case we need to redact it. - self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} + self._latest_agent_message: dict[str, SessionMessage | None] = {} def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: """Append a message to the agent's session. diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 7d081cf09..e5713e5b7 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,7 +2,7 @@ import json import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -47,9 +47,9 @@ def __init__( session_id: str, bucket: str, prefix: str = "", - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, - region_name: Optional[str] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, **kwargs: Any, ): """Initialize S3SessionManager with S3 storage. 
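The conversion running through these session-manager hunks is mechanical: Optional[X] becomes X | None (PEP 604) and the typing aliases Dict, List, Tuple, and Set become the builtin generics (PEP 585). A minimal before/after sketch with a hypothetical function, not one from this codebase; note that annotations evaluated at runtime, such as X | None in a signature, need Python 3.10 unless from __future__ import annotations is in effect:

    # Before: typing-module spellings
    from typing import Dict, List, Optional

    def load_messages(limit: Optional[int] = None) -> List[Dict[str, str]]:
        return []

    # After: builtin generics (Python 3.9+) and PEP 604 unions (Python 3.10+)
    def load_messages(limit: int | None = None) -> list[dict[str, str]]:
        return []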
@@ -130,7 +130,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> agent_path = self._get_agent_path(session_id, agent_id) return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" - def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: + def _read_s3_object(self, key: str) -> dict[str, Any] | None: """Read JSON object from S3.""" try: response = self.client.get_object(Bucket=self.bucket, Key=key) @@ -144,7 +144,7 @@ def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: except json.JSONDecodeError as e: raise SessionException(f"Invalid JSON in S3 object {key}: {e}") from e - def _write_s3_object(self, key: str, data: Dict[str, Any]) -> None: + def _write_s3_object(self, key: str, data: dict[str, Any]) -> None: """Write JSON object to S3.""" try: content = json.dumps(data, indent=2, ensure_ascii=False) @@ -171,7 +171,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: self._write_s3_object(session_key, session_dict) return session - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read session data from S3.""" session_key = f"{self._get_session_path(session_id)}session.json" session_data = self._read_s3_object(session_key) @@ -209,7 +209,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" self._write_s3_object(agent_key, agent_dict) - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read agent data from S3.""" agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" agent_data = self._read_s3_object(agent_key) @@ -236,7 +236,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio message_key = self._get_message_path(session_id, agent_id, message_id) self._write_s3_object(message_key, message_dict) - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read message data from S3.""" message_key = self._get_message_path(session_id, agent_id, message_id) message_data = self._read_s3_object(message_key) @@ -257,8 +257,8 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio self._write_s3_object(message_key, session_message.to_dict()) def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any - ) -> List[SessionMessage]: + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any + ) -> list[SessionMessage]: """List messages for an agent with pagination from S3.""" messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" try: @@ -288,7 +288,7 @@ def list_messages( message_keys = message_keys[offset:] # Load only the required message objects - messages: List[SessionMessage] = [] + messages: list[SessionMessage] = [] for key in message_keys: message_data = self._read_s3_object(key) if message_data: @@ -312,7 +312,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k session_data = multi_agent.serialize_state() self._write_s3_object(multi_agent_key, session_data) - def 
read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read multi-agent state from S3.""" multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" return self._read_s3_object(multi_agent_key) diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py index 3f5476bdf..0b6f2c705 100644 --- a/src/strands/session/session_repository.py +++ b/src/strands/session/session_repository.py @@ -1,7 +1,7 @@ """Session repository interface for agent session management.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..types.session import Session, SessionAgent, SessionMessage @@ -17,7 +17,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new Session.""" @abstractmethod - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read a Session.""" @abstractmethod @@ -25,7 +25,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A """Create a new Agent in a Session.""" @abstractmethod - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read an Agent.""" @abstractmethod @@ -37,7 +37,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio """Create a new Message for the Agent.""" @abstractmethod - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read a Message.""" @abstractmethod @@ -49,7 +49,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio @abstractmethod def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List Messages from an Agent with pagination.""" @@ -57,7 +57,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k """Create a new MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read the MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index 8f3ee1ea1..163df803a 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -3,8 +3,9 @@ import logging import time import uuid +from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Optional import opentelemetry.metrics as metrics_api from opentelemetry.metrics import Counter, Histogram, Meter @@ -23,11 
+24,11 @@ class Trace: def __init__( self, name: str, - parent_id: Optional[str] = None, - start_time: Optional[float] = None, - raw_name: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - message: Optional[Message] = None, + parent_id: str | None = None, + start_time: float | None = None, + raw_name: str | None = None, + metadata: dict[str, Any] | None = None, + message: Message | None = None, ) -> None: """Initialize a new trace. @@ -42,15 +43,15 @@ def __init__( """ self.id: str = str(uuid.uuid4()) self.name: str = name - self.raw_name: Optional[str] = raw_name - self.parent_id: Optional[str] = parent_id + self.raw_name: str | None = raw_name + self.parent_id: str | None = parent_id self.start_time: float = start_time if start_time is not None else time.time() - self.end_time: Optional[float] = None - self.children: List["Trace"] = [] - self.metadata: Dict[str, Any] = metadata or {} - self.message: Optional[Message] = message + self.end_time: float | None = None + self.children: list[Trace] = [] + self.metadata: dict[str, Any] = metadata or {} + self.message: Message | None = message - def end(self, end_time: Optional[float] = None) -> None: + def end(self, end_time: float | None = None) -> None: """Mark the trace as complete with the given or current timestamp. Args: @@ -67,7 +68,7 @@ def add_child(self, child: "Trace") -> None: """ self.children.append(child) - def duration(self) -> Optional[float]: + def duration(self) -> float | None: """Calculate the duration of this trace. Returns: @@ -83,7 +84,7 @@ def add_message(self, message: Message) -> None: """ self.message = message - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert the trace to a dictionary representation. Returns: @@ -127,7 +128,7 @@ def add_call( duration: float, success: bool, metrics_client: "MetricsClient", - attributes: Optional[Dict[str, Any]] = None, + attributes: dict[str, Any] | None = None, ) -> None: """Record a new tool call with its outcome. @@ -207,7 +208,7 @@ def _metrics_client(self) -> "MetricsClient": return MetricsClient() @property - def latest_agent_invocation(self) -> Optional[AgentInvocation]: + def latest_agent_invocation(self) -> AgentInvocation | None: """Get the most recent agent invocation. Returns: @@ -217,8 +218,8 @@ def latest_agent_invocation(self) -> Optional[AgentInvocation]: def start_cycle( self, - attributes: Dict[str, Any], - ) -> Tuple[float, Trace]: + attributes: dict[str, Any], + ) -> tuple[float, Trace]: """Start a new event loop cycle and create a trace for it. Args: @@ -243,7 +244,7 @@ def start_cycle( return start_time, cycle_trace - def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: Optional[Dict[str, Any]] = None) -> None: + def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: dict[str, Any] | None = None) -> None: """End the current event loop cycle and record its duration. Args: @@ -358,7 +359,7 @@ def update_metrics(self, metrics: Metrics) -> None: self._metrics_client.model_time_to_first_token.record(metrics["timeToFirstByteMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] - def get_summary(self) -> Dict[str, Any]: + def get_summary(self) -> dict[str, Any]: """Generate a comprehensive summary of all collected metrics. 
Returns: @@ -404,7 +405,7 @@ def get_summary(self) -> Dict[str, Any]: return summary -def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_names: Set[str]) -> Iterable[str]: +def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_names: set[str]) -> Iterable[str]: """Convert event loop metrics to a series of formatted text lines. Args: @@ -465,7 +466,7 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name yield from _trace_to_lines(trace.to_dict(), allowed_names=allowed_names, indent=1) -def _trace_to_lines(trace: Dict, allowed_names: Set[str], indent: int) -> Iterable[str]: +def _trace_to_lines(trace: dict, allowed_names: set[str], indent: int) -> Iterable[str]: """Convert a trace to a series of formatted text lines. Args: @@ -497,7 +498,7 @@ def _trace_to_lines(trace: Dict, allowed_names: Set[str], indent: int) -> Iterab yield from _trace_to_lines(child, allowed_names, indent + 1) -def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: Optional[Set[str]] = None) -> str: +def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: set[str] | None = None) -> str: """Convert event loop metrics to a human-readable string representation. Args: diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index d16b37fc8..d73ea3c39 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -7,8 +7,9 @@ import json import logging import os +from collections.abc import Mapping from datetime import date, datetime, timezone -from typing import Any, Dict, Mapping, Optional, cast +from typing import Any, cast import opentelemetry.trace as trace_api from opentelemetry.instrumentation.threading import ThreadingInstrumentor @@ -89,7 +90,7 @@ class Tracer: def __init__(self) -> None: """Initialize the tracer.""" self.service_name = __name__ - self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer_provider: trace_api.TracerProvider | None = None self.tracer_provider = trace_api.get_tracer_provider() self.tracer = self.tracer_provider.get_tracer(self.service_name) ThreadingInstrumentor().instrument() @@ -112,8 +113,8 @@ def _parse_semconv_opt_in(self) -> set[str]: def _start_span( self, span_name: str, - parent_span: Optional[Span] = None, - attributes: Optional[Dict[str, AttributeValue]] = None, + parent_span: Span | None = None, + attributes: dict[str, AttributeValue] | None = None, span_kind: trace_api.SpanKind = trace_api.SpanKind.INTERNAL, ) -> Span: """Generic helper method to start a span with common attributes. @@ -145,7 +146,7 @@ def _start_span( return span - def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: + def _set_attributes(self, span: Span, attributes: dict[str, AttributeValue]) -> None: """Set attributes on a span, handling different value types appropriately. Args: @@ -159,7 +160,7 @@ def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> span.set_attribute(key, value) def _add_optional_usage_and_metrics_attributes( - self, attributes: Dict[str, AttributeValue], usage: Usage, metrics: Metrics + self, attributes: dict[str, AttributeValue], usage: Usage, metrics: Metrics ) -> None: """Add optional usage and metrics attributes if they have values. 
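The two tracer hunks just above narrow to the same idiom: span attributes are collected into a plain dict[str, AttributeValue] and applied one entry at a time via Span.set_attribute. A self-contained OpenTelemetry sketch of that shape (the tracer is a no-op unless an SDK is configured, and the helper name here is illustrative, not this module's):

    from opentelemetry import trace as trace_api


    def set_span_attributes(span: trace_api.Span, attributes: dict[str, object]) -> None:
        # one set_attribute call per entry, mirroring the helper pattern above
        for key, value in attributes.items():
            span.set_attribute(key, value)


    tracer = trace_api.get_tracer("example")
    with tracer.start_as_current_span("chat") as span:
        set_span_attributes(span, {"gen_ai.operation.name": "chat", "gen_ai.usage.input_tokens": 42})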
@@ -183,8 +184,8 @@ def _add_optional_usage_and_metrics_attributes( def _end_span( self, span: Span, - attributes: Optional[Dict[str, AttributeValue]] = None, - error: Optional[Exception] = None, + attributes: dict[str, AttributeValue] | None = None, + error: Exception | None = None, ) -> None: """Generic helper method to end a span. @@ -221,7 +222,7 @@ def _end_span( except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: """End a span with error status. Args: @@ -235,7 +236,7 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Optiona error = exception or Exception(error_message) self._end_span(span, error=error) - def _add_event(self, span: Optional[Span], event_name: str, event_attributes: Attributes) -> None: + def _add_event(self, span: Span | None, event_name: str, event_attributes: Attributes) -> None: """Add an event with attributes to a span. Args: @@ -275,9 +276,9 @@ def _get_event_name_for_message(self, message: Message) -> str: def start_model_invoke_span( self, messages: Messages, - parent_span: Optional[Span] = None, - model_id: Optional[str] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + model_id: str | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, ) -> Span: """Start a new span for a model invocation. @@ -292,7 +293,7 @@ def start_model_invoke_span( Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat") if custom_trace_attributes: attributes.update(custom_trace_attributes) @@ -315,7 +316,7 @@ def end_model_invoke_span( usage: Usage, metrics: Metrics, stop_reason: StopReason, - error: Optional[Exception] = None, + error: Exception | None = None, ) -> None: """End a model invocation span with results and metrics. @@ -327,7 +328,7 @@ def end_model_invoke_span( stop_reason (StopReason): The reason the model stopped generating. error: Optional exception if the model call failed. """ - attributes: Dict[str, AttributeValue] = { + attributes: dict[str, AttributeValue] = { "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.input_tokens": usage["inputTokens"], "gen_ai.usage.completion_tokens": usage["outputTokens"], @@ -366,8 +367,8 @@ def end_model_invoke_span( def start_tool_call_span( self, tool: ToolUse, - parent_span: Optional[Span] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, ) -> Span: """Start a new span for a tool call. @@ -381,7 +382,7 @@ def start_tool_call_span( Returns: The created span, or None if tracing is not enabled. 
""" - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_tool") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_tool") attributes.update( { "gen_ai.tool.name": tool["name"], @@ -432,9 +433,7 @@ def start_tool_call_span( return span - def end_tool_call_span( - self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None - ) -> None: + def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None: """End a tool call span with results. Args: @@ -442,7 +441,7 @@ def end_tool_call_span( tool_result: The result from the tool execution. error: Optional exception if the tool call failed. """ - attributes: Dict[str, AttributeValue] = {} + attributes: dict[str, AttributeValue] = {} if tool_result is not None: status = tool_result.get("status") status_str = str(status) if status is not None else "" @@ -490,10 +489,10 @@ def start_event_loop_cycle_span( self, invocation_state: Any, messages: Messages, - parent_span: Optional[Span] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, - ) -> Optional[Span]: + ) -> Span | None: """Start a new span for an event loop cycle. Args: @@ -509,7 +508,7 @@ def start_event_loop_cycle_span( event_loop_cycle_id = str(invocation_state.get("event_loop_cycle_id")) parent_span = parent_span if parent_span else invocation_state.get("event_loop_parent_span") - attributes: Dict[str, AttributeValue] = { + attributes: dict[str, AttributeValue] = { "event_loop.cycle_id": event_loop_cycle_id, } @@ -532,8 +531,8 @@ def end_event_loop_cycle_span( self, span: Span, message: Message, - tool_result_message: Optional[Message] = None, - error: Optional[Exception] = None, + tool_result_message: Message | None = None, + error: Exception | None = None, ) -> None: """End an event loop cycle span with results. @@ -543,8 +542,8 @@ def end_event_loop_cycle_span( tool_result_message: Optional tool result message if a tool was called. error: Optional exception if the cycle failed. """ - attributes: Dict[str, AttributeValue] = {} - event_attributes: Dict[str, AttributeValue] = {"message": serialize(message["content"])} + attributes: dict[str, AttributeValue] = {} + event_attributes: dict[str, AttributeValue] = {"message": serialize(message["content"])} if tool_result_message: event_attributes["tool.result"] = serialize(tool_result_message["content"]) @@ -572,10 +571,10 @@ def start_agent_span( self, messages: Messages, agent_name: str, - model_id: Optional[str] = None, - tools: Optional[list] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, - tools_config: Optional[dict] = None, + model_id: str | None = None, + tools: list | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, + tools_config: dict | None = None, **kwargs: Any, ) -> Span: """Start a new span for an agent invocation. @@ -592,7 +591,7 @@ def start_agent_span( Returns: The created span, or None if tracing is not enabled. 
""" - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="invoke_agent") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="invoke_agent") attributes.update( { "gen_ai.agent.name": agent_name, @@ -630,8 +629,8 @@ def start_agent_span( def end_agent_span( self, span: Span, - response: Optional[AgentResult] = None, - error: Optional[Exception] = None, + response: AgentResult | None = None, + error: Exception | None = None, ) -> None: """End an agent span with results and metrics. @@ -640,7 +639,7 @@ def end_agent_span( response: The response from the agent. error: Any error that occurred. """ - attributes: Dict[str, AttributeValue] = {} + attributes: dict[str, AttributeValue] = {} if response: if self.use_latest_genai_conventions: @@ -702,11 +701,11 @@ def start_multiagent_span( self, task: MultiAgentInput, instance: str, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> Span: """Start a new span for swarm invocation.""" operation = f"invoke_{instance}" - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation) + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation) attributes.update( { "gen_ai.agent.name": instance, @@ -741,7 +740,7 @@ def start_multiagent_span( def end_swarm_span( self, span: Span, - result: Optional[str] = None, + result: str | None = None, ) -> None: """End a swarm span with results.""" if result: @@ -770,7 +769,7 @@ def end_swarm_span( def _get_common_attributes( self, operation_name: str, - ) -> Dict[str, AttributeValue]: + ) -> dict[str, AttributeValue]: """Returns a dictionary of common attributes based on the convention version used. Args: diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index bfec5886d..8ca6138fc 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -9,7 +9,8 @@ import json import random -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from .._async import run_async from ..tools.executors._executor import ToolExecutor diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 8dc933f51..f64c17ee9 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -44,16 +44,13 @@ def my_tool(param1: str, param2: int = 42) -> dict: import functools import inspect import logging +from collections.abc import Callable from typing import ( Annotated, Any, - Callable, Generic, - Optional, ParamSpec, - Type, TypeVar, - Union, cast, get_args, get_origin, @@ -183,7 +180,7 @@ def _validate_signature(self) -> None: # Found the parameter, no need to check further break - def _create_input_model(self) -> Type[BaseModel]: + def _create_input_model(self) -> type[BaseModel]: """Create a Pydantic model from function signature for input validation. This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can @@ -463,7 +460,7 @@ def __init__( functools.update_wrapper(wrapper=self, wrapped=self._tool_func) - def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]": + def __get__(self, instance: Any, obj_type: type | None = None) -> "DecoratedFunctionTool[P, R]": """Descriptor protocol implementation for proper method binding. 
This method enables the decorated function to work correctly when used as a class method. @@ -666,20 +663,20 @@ def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ... # Handle @decorator() @overload def tool( - description: Optional[str] = None, - inputSchema: Optional[JSONSchema] = None, - name: Optional[str] = None, + description: str | None = None, + inputSchema: JSONSchema | None = None, + name: str | None = None, context: bool | str = False, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system def tool( # type: ignore - func: Optional[Callable[P, R]] = None, - description: Optional[str] = None, - inputSchema: Optional[JSONSchema] = None, - name: Optional[str] = None, + func: Callable[P, R] | None = None, + description: str | None = None, + inputSchema: JSONSchema | None = None, + name: str | None = None, context: bool | str = False, -) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]: +) -> DecoratedFunctionTool[P, R] | Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: """Decorator that transforms a Python function into a Strands tool. This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool. diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 5d01c5d48..6d58c5c75 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -7,7 +7,8 @@ import abc import logging import time -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any, cast from opentelemetry import trace as trace_api diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 216eee379..7fa34eff0 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -1,7 +1,8 @@ """Concurrent tool executor implementation.""" import asyncio -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from typing_extensions import override diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index f78e60872..dc5b9a5d9 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -1,6 +1,7 @@ """Sequential tool executor implementation.""" -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from typing_extensions import override diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 6f745b728..2115cdee8 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -9,7 +9,7 @@ from pathlib import Path from posixpath import expanduser from types import ModuleType -from typing import List, cast +from typing import cast from ..types.tools import AgentTool from .decorator import DecoratedFunctionTool @@ -20,7 +20,7 @@ _TOOL_MODULE_PREFIX = "_strands_tool_" -def load_tool_from_string(tool_string: str) -> List[AgentTool]: +def load_tool_from_string(tool_string: str) -> list[AgentTool]: """Load tools follows strands supported input string formats. 
This function can load a tool based on a string in the following ways: @@ -42,7 +42,7 @@ def load_tool_from_string(tool_string: str) -> List[AgentTool]: return load_tools_from_module_path(tool_string) -def load_tools_from_file_path(tool_path: str) -> List[AgentTool]: +def load_tools_from_file_path(tool_path: str) -> list[AgentTool]: """Load module from specified path, and then load tools from that module. This function attempts to load the passed in path as a python module, and if it succeeds, @@ -116,7 +116,7 @@ def load_tools_from_module(module: ModuleType, module_name: str) -> list[AgentTo # Try and see if any of the attributes in the module are function-based tools decorated with @tool # This means that there may be more than one tool available in this module, so we load them all - function_tools: List[AgentTool] = [] + function_tools: list[AgentTool] = [] # Function tools will appear as attributes in the module for attr_name in dir(module): attr = getattr(module, attr_name) @@ -153,7 +153,7 @@ class ToolLoader: """Handles loading of tools from different sources.""" @staticmethod - def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: + def load_python_tools(tool_path: str, tool_name: str) -> list[AgentTool]: """DEPRECATED: Load a Python tool module and return all discovered function-based tools as a list. This method always returns a list of AgentTool (possibly length 1). It is the @@ -206,7 +206,7 @@ def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: spec.loader.exec_module(module) # Collect function-based tools decorated with @tool - function_tools: List[AgentTool] = [] + function_tools: list[AgentTool] = [] for attr_name in dir(module): attr = getattr(module, attr_name) if isinstance(attr, DecoratedFunctionTool): diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index c36811c17..1aff22a1e 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -14,10 +14,12 @@ import threading import uuid from asyncio import AbstractEventLoop +from collections.abc import Callable, Coroutine, Sequence from concurrent import futures from datetime import timedelta +from re import Pattern from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast +from typing import Any, TypeVar, cast import anyio from mcp import ClientSession, ListToolsResult @@ -71,7 +73,7 @@ class ToolFilters(TypedDict, total=False): rejected: list[_ToolMatcher] -MIME_TO_FORMAT: Dict[str, ImageFormat] = { +MIME_TO_FORMAT: dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", "image/png": "png", @@ -117,7 +119,7 @@ def __init__( startup_timeout: int = 30, tool_filters: ToolFilters | None = None, prefix: str | None = None, - elicitation_callback: Optional[ElicitationFnT] = None, + elicitation_callback: ElicitationFnT | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -300,9 +302,7 @@ def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: # MCP-specific methods - def stop( - self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] - ) -> None: + def stop(self, exc_type: BaseException | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources. 
This method is defensive and can handle partial initialization states that may occur @@ -415,7 +415,7 @@ async def _list_tools_async() -> ListToolsResult: self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) - def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult: + def list_prompts_sync(self, pagination_token: str | None = None) -> ListPromptsResult: """Synchronously retrieves the list of available prompts from the MCP server. This method calls the asynchronous list_prompts method on the MCP session @@ -463,7 +463,7 @@ async def _get_prompt_async() -> GetPromptResult: return get_prompt_result - def list_resources_sync(self, pagination_token: Optional[str] = None) -> ListResourcesResult: + def list_resources_sync(self, pagination_token: str | None = None) -> ListResourcesResult: """Synchronously retrieves the list of available resources from the MCP server. This method calls the asynchronous list_resources method on the MCP session @@ -510,7 +510,7 @@ async def _read_resource_async() -> ReadResourceResult: return read_resource_result - def list_resource_templates_sync(self, pagination_token: Optional[str] = None) -> ListResourceTemplatesResult: + def list_resource_templates_sync(self, pagination_token: str | None = None) -> ListResourceTemplatesResult: """Synchronously retrieves the list of available resource templates from the MCP server. Resource templates define URI patterns that can be used to access resources dynamically. @@ -739,7 +739,7 @@ def _background_task(self) -> None: def _map_mcp_content_to_tool_result_content( self, content: MCPTextContent | MCPImageContent | MCPEmbeddedResource | Any, - ) -> Union[ToolResultContent, None]: + ) -> ToolResultContent | None: """Maps MCP content types to tool result content types. 
This method converts MCP-specific content types to the generic @@ -859,7 +859,7 @@ def _should_include_tool(self, tool: MCPAgentTool) -> bool: """Check if a tool should be included based on constructor filters.""" return self._should_include_tool_with_filters(tool, self._tool_filters) - def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: Optional[ToolFilters]) -> bool: + def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: ToolFilters | None) -> bool: """Check if a tool should be included based on provided filters.""" if not filters: return True diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py index f8ab3bc80..d1750daa3 100644 --- a/src/strands/tools/mcp/mcp_instrumentation.py +++ b/src/strands/tools/mcp/mcp_instrumentation.py @@ -9,9 +9,10 @@ Related issue: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/246 """ +from collections.abc import AsyncGenerator, Callable from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from dataclasses import dataclass -from typing import Any, AsyncGenerator, Callable, Tuple +from typing import Any from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest @@ -129,7 +130,7 @@ def transport_wrapper() -> Callable[ @asynccontextmanager async def traced_method( wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any - ) -> AsyncGenerator[Tuple[Any, Any], None]: + ) -> AsyncGenerator[tuple[Any, Any], None]: async with wrapped(*args, **kwargs) as result: try: read_stream, write_stream = result @@ -139,7 +140,7 @@ async def traced_method( return traced_method - def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any]], None]: + def session_init_wrapper() -> Callable[[Any, Any, tuple[Any, ...], dict[str, Any]], None]: """Create a wrapper for MCP session initialization. Wraps session message streams to enable bidirectional context flow. 
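The traced_method wrapper above uses a standard shape: re-enter an existing async context manager under @asynccontextmanager, unpack the (read, write) pair it yields, and yield the pair onward so instrumentation can observe it in the middle. A runnable sketch of the same shape, with stand-in stream values rather than the MCP types:

    import asyncio
    from collections.abc import AsyncGenerator
    from contextlib import asynccontextmanager
    from typing import Any


    @asynccontextmanager
    async def open_transport() -> AsyncGenerator[tuple[Any, Any], None]:
        yield "read_stream", "write_stream"  # stand-ins for real stream objects


    @asynccontextmanager
    async def traced_transport(inner: Any) -> AsyncGenerator[tuple[Any, Any], None]:
        async with inner as (read_stream, write_stream):
            # observe or wrap the pair here, then hand it through unchanged
            yield read_stream, write_stream


    async def main() -> None:
        async with traced_transport(open_transport()) as (read_stream, write_stream):
            print(read_stream, write_stream)


    asyncio.run(main())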
@@ -151,7 +152,7 @@ def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any """ def traced_method( - wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: dict[str, Any] + wrapped: Callable[..., Any], instance: Any, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> None: wrapped(*args, **kwargs) reader = getattr(instance, "_incoming_message_stream_reader", None) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 2547aabcc..f9787a182 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -10,12 +10,13 @@ import sys import uuid import warnings +from collections.abc import Iterable, Sequence from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import Any, cast -from typing_extensions import TypedDict, cast +from typing_extensions import TypedDict from .._async import run_async from ..experimental.tools import ToolProvider @@ -35,13 +36,13 @@ class ToolRegistry: def __init__(self) -> None: """Initialize the tool registry.""" - self.registry: Dict[str, AgentTool] = {} - self.dynamic_tools: Dict[str, AgentTool] = {} - self.tool_config: Optional[Dict[str, Any]] = None - self._tool_providers: List[ToolProvider] = [] + self.registry: dict[str, AgentTool] = {} + self.dynamic_tools: dict[str, AgentTool] = {} + self.tool_config: dict[str, Any] | None = None + self._tool_providers: list[ToolProvider] = [] self._registry_id = str(uuid.uuid4()) - def process_tools(self, tools: List[Any]) -> List[str]: + def process_tools(self, tools: list[Any]) -> list[str]: """Process tools list. Process list of tools that can contain local file path string, module import path string, @@ -186,7 +187,7 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: logger.exception("tool_name=<%s> | failed to load tool", tool_name) raise ValueError(f"Failed to load tool {tool_name}: {exception_str}") from e - def get_all_tools_config(self) -> Dict[str, Any]: + def get_all_tools_config(self) -> dict[str, Any]: """Dynamically generate tool configuration by combining built-in and dynamic tools. Returns: @@ -305,7 +306,7 @@ def replace(self, new_tool: AgentTool) -> None: elif tool_name in self.dynamic_tools: del self.dynamic_tools[tool_name] - def get_tools_dirs(self) -> List[Path]: + def get_tools_dirs(self) -> list[Path]: """Get all tool directory paths. Returns: @@ -325,7 +326,7 @@ def get_tools_dirs(self) -> List[Path]: return tool_dirs - def discover_tool_modules(self) -> Dict[str, Path]: + def discover_tool_modules(self) -> dict[str, Path]: """Discover available tool modules in all tools directories. Returns: @@ -568,7 +569,7 @@ def get_all_tool_specs(self) -> list[ToolSpec]: A list of ToolSpecs. """ all_tools = self.get_all_tools_config() - tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] + tools: list[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] return tools def register_dynamic_tool(self, tool: AgentTool) -> None: @@ -645,7 +646,7 @@ class NewToolDict(TypedDict): spec: ToolSpec - def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict) -> None: + def _update_tool_config(self, tool_config: dict[str, Any], new_tool: NewToolDict) -> None: """Update tool configuration with a new tool. 
Args: @@ -682,7 +683,7 @@ def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict tool_config["tools"].append(new_tool_entry) logger.debug("tool_name=<%s> | added new tool", new_tool_name) - def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: + def _scan_module_for_tools(self, module: Any) -> list[AgentTool]: """Scan a module for function-based tools. Args: @@ -691,7 +692,7 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: Returns: List of FunctionTool instances found in the module. """ - tools: List[AgentTool] = [] + tools: list[AgentTool] = [] for name, obj in inspect.getmembers(module): if isinstance(obj, DecoratedFunctionTool): diff --git a/src/strands/tools/structured_output/_structured_output_context.py b/src/strands/tools/structured_output/_structured_output_context.py index f33a06915..2f8dd8ca0 100644 --- a/src/strands/tools/structured_output/_structured_output_context.py +++ b/src/strands/tools/structured_output/_structured_output_context.py @@ -1,7 +1,7 @@ """Context management for structured output in the event loop.""" import logging -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import BaseModel @@ -17,20 +17,20 @@ class StructuredOutputContext: """Per-invocation context for structured output execution.""" - def __init__(self, structured_output_model: Type[BaseModel] | None = None): + def __init__(self, structured_output_model: type[BaseModel] | None = None): """Initialize a new structured output context. Args: structured_output_model: Optional Pydantic model type for structured output. """ self.results: dict[str, BaseModel] = {} - self.structured_output_model: Type[BaseModel] | None = structured_output_model + self.structured_output_model: type[BaseModel] | None = structured_output_model self.structured_output_tool: StructuredOutputTool | None = None self.forced_mode: bool = False self.force_attempted: bool = False self.tool_choice: ToolChoice | None = None self.stop_loop: bool = False - self.expected_tool_name: Optional[str] = None + self.expected_tool_name: str | None = None if structured_output_model: self.structured_output_tool = StructuredOutputTool(structured_output_model) @@ -91,7 +91,7 @@ def has_structured_output_tool(self, tool_uses: list[ToolUse]) -> bool: return False return any(tool_use.get("name") == self.expected_tool_name for tool_use in tool_uses) - def get_tool_spec(self) -> Optional[ToolSpec]: + def get_tool_spec(self) -> ToolSpec | None: """Get the tool specification for structured output. 
Returns: diff --git a/src/strands/tools/structured_output/structured_output_tool.py b/src/strands/tools/structured_output/structured_output_tool.py index 25173d048..fa20f526c 100644 --- a/src/strands/tools/structured_output/structured_output_tool.py +++ b/src/strands/tools/structured_output/structured_output_tool.py @@ -6,7 +6,7 @@ import logging from copy import deepcopy -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ValidationError from typing_extensions import override @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -_TOOL_SPEC_CACHE: dict[Type[BaseModel], ToolSpec] = {} +_TOOL_SPEC_CACHE: dict[type[BaseModel], ToolSpec] = {} if TYPE_CHECKING: from ._structured_output_context import StructuredOutputContext @@ -26,7 +26,7 @@ class StructuredOutputTool(AgentTool): """Tool implementation for structured output validation.""" - def __init__(self, structured_output_model: Type[BaseModel]) -> None: + def __init__(self, structured_output_model: type[BaseModel]) -> None: """Initialize a structured output tool. Args: @@ -43,7 +43,7 @@ def __init__(self, structured_output_model: Type[BaseModel]) -> None: self._tool_name = self._tool_spec.get("name", "StructuredOutputTool") @classmethod - def _get_tool_spec(cls, structured_output_model: Type[BaseModel]) -> ToolSpec: + def _get_tool_spec(cls, structured_output_model: type[BaseModel]) -> ToolSpec: """Get a cached tool spec for the given output type. Args: @@ -84,7 +84,7 @@ def tool_type(self) -> str: return "structured_output" @property - def structured_output_model(self) -> Type[BaseModel]: + def structured_output_model(self) -> type[BaseModel]: """Get the Pydantic model type for this tool. Returns: diff --git a/src/strands/tools/structured_output/structured_output_utils.py b/src/strands/tools/structured_output/structured_output_utils.py index 093d67f7c..a78ec6195 100644 --- a/src/strands/tools/structured_output/structured_output_utils.py +++ b/src/strands/tools/structured_output/structured_output_utils.py @@ -1,13 +1,13 @@ """Tools for converting Pydantic models to Bedrock tools.""" -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Union from pydantic import BaseModel from ...types.tools import ToolSpec -def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: +def _flatten_schema(schema: dict[str, Any]) -> dict[str, Any]: """Flattens a JSON schema by removing $defs and resolving $ref references. Handles required vs optional fields properly. @@ -80,11 +80,11 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: def _process_property( - prop: Dict[str, Any], - defs: Dict[str, Any], + prop: dict[str, Any], + defs: dict[str, Any], is_required: bool = False, fully_expand: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Process a property in a schema, resolving any references. Args: @@ -174,8 +174,8 @@ def _process_property( def _process_schema_object( - schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True -) -> Dict[str, Any]: + schema_obj: dict[str, Any], defs: dict[str, Any], fully_expand: bool = True +) -> dict[str, Any]: """Process a schema object, typically from $defs, to resolve all nested properties. 
Args: @@ -218,7 +218,7 @@ def _process_schema_object( return result -def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: +def _process_nested_dict(d: dict[str, Any], defs: dict[str, Any]) -> dict[str, Any]: """Recursively processes nested dictionaries and resolves $ref references. Args: @@ -228,7 +228,7 @@ def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, A Returns: Processed dictionary """ - result: Dict[str, Any] = {} + result: dict[str, Any] = {} # Handle direct reference if "$ref" in d: @@ -258,8 +258,8 @@ def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, A def convert_pydantic_to_tool_spec( - model: Type[BaseModel], - description: Optional[str] = None, + model: type[BaseModel], + description: str | None = None, ) -> ToolSpec: """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. @@ -302,7 +302,7 @@ def convert_pydantic_to_tool_spec( ) -def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _expand_nested_properties(schema: dict[str, Any], model: type[BaseModel]) -> None: """Expand the properties of nested models in the schema to include their full structure. This updates the schema in place. @@ -348,7 +348,7 @@ def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> schema["properties"][prop_name] = expanded_object -def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_referenced_models(schema: dict[str, Any], model: type[BaseModel]) -> None: """Process referenced models to ensure their docstrings are included. This updates the schema in place. @@ -388,7 +388,7 @@ def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) - _process_properties(ref_def, field_type) -def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_properties(schema_def: dict[str, Any], model: type[BaseModel]) -> None: """Process properties in a schema definition to add descriptions from field metadata. Args: diff --git a/src/strands/tools/watcher.py b/src/strands/tools/watcher.py index 44f2ed512..c7f50fccd 100644 --- a/src/strands/tools/watcher.py +++ b/src/strands/tools/watcher.py @@ -6,7 +6,7 @@ import logging from pathlib import Path -from typing import Any, Dict, Set +from typing import Any from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer @@ -25,9 +25,9 @@ class ToolWatcher: # design pattern avoids conflicts when multiple tool registries are watching the same directories. _shared_observer = None - _watched_dirs: Set[str] = set() + _watched_dirs: set[str] = set() _observer_started = False - _registry_handlers: Dict[str, Dict[int, "ToolWatcher.ToolChangeHandler"]] = {} + _registry_handlers: dict[str, dict[int, "ToolWatcher.ToolChangeHandler"]] = {} def __init__(self, tool_registry: ToolRegistry) -> None: """Initialize a tool watcher for the given tool registry. diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index d64357cf8..0896d48e1 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,7 +5,8 @@ agent lifecycle. 
""" -from typing import TYPE_CHECKING, Any, Sequence, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, cast from pydantic import BaseModel from typing_extensions import override diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index 623f6ddc7..2b3714ce1 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -3,7 +3,7 @@ These types are modeled after the Bedrock API. """ -from typing import List, Literal, Union +from typing import Literal from typing_extensions import TypedDict @@ -120,13 +120,13 @@ class WebLocation(TypedDict, total=False): WebLocationDict = dict[Literal["web"], WebLocation] # Union type for citation locations - tagged union format matching AWS Bedrock API -CitationLocation = Union[ - DocumentCharLocationDict, - DocumentPageLocationDict, - DocumentChunkLocationDict, - SearchResultLocationDict, - WebLocationDict, -] +CitationLocation = ( + DocumentCharLocationDict + | DocumentPageLocationDict + | DocumentChunkLocationDict + | SearchResultLocationDict + | WebLocationDict +) class CitationSourceContent(TypedDict, total=False): @@ -178,7 +178,7 @@ class Citation(TypedDict, total=False): """ location: CitationLocation - sourceContent: List[CitationSourceContent] + sourceContent: list[CitationSourceContent] title: str @@ -196,5 +196,5 @@ class CitationsContentBlock(TypedDict, total=False): citations. """ - citations: List[Citation] - content: List[CitationGeneratedContent] + citations: list[Citation] + content: list[CitationGeneratedContent] diff --git a/src/strands/types/collections.py b/src/strands/types/collections.py index df857ace0..28b4a1891 100644 --- a/src/strands/types/collections.py +++ b/src/strands/types/collections.py @@ -1,6 +1,6 @@ """Generic collection types for the Strands SDK.""" -from typing import Generic, List, Optional, TypeVar +from typing import Generic, TypeVar T = TypeVar("T") @@ -12,7 +12,7 @@ class PaginatedList(list, Generic[T]): so existing code that expects List[T] will continue to work. """ - def __init__(self, data: List[T], token: Optional[str] = None): + def __init__(self, data: list[T], token: str | None = None): """Initialize a PaginatedList with data and an optional pagination token. Args: diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 4d0bbe412..d75dbb87f 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -6,7 +6,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -23,7 +23,7 @@ class GuardContentText(TypedDict): text: The input text details to be evaluated by the guardrail. """ - qualifiers: List[Literal["grounding_source", "query", "guard_content"]] + qualifiers: list[Literal["grounding_source", "query", "guard_content"]] text: str @@ -45,7 +45,7 @@ class ReasoningTextBlock(TypedDict, total=False): text: The reasoning that the model used to return the output. """ - signature: Optional[str] + signature: str | None text: str @@ -120,7 +120,7 @@ class DeltaContent(TypedDict, total=False): """ text: str - toolUse: Dict[Literal["input"], str] + toolUse: dict[Literal["input"], str] class ContentBlockStartToolUse(TypedDict): @@ -142,7 +142,7 @@ class ContentBlockStart(TypedDict, total=False): toolUse: Information about a tool that the model is requesting to use. 
""" - toolUse: Optional[ContentBlockStartToolUse] + toolUse: ContentBlockStartToolUse | None class ContentBlockDelta(TypedDict): @@ -183,9 +183,9 @@ class Message(TypedDict): role: The role of the message sender. """ - content: List[ContentBlock] + content: list[ContentBlock] role: Role -Messages = List[Message] +Messages = list[Message] """A list of messages representing a conversation.""" diff --git a/src/strands/types/guardrails.py b/src/strands/types/guardrails.py index c15ba1bea..70a7aedd5 100644 --- a/src/strands/types/guardrails.py +++ b/src/strands/types/guardrails.py @@ -5,7 +5,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -22,7 +22,7 @@ class GuardrailConfig(TypedDict, total=False): guardrailIdentifier: str guardrailVersion: str - streamProcessingMode: Optional[Literal["sync", "async"]] + streamProcessingMode: Literal["sync", "async"] | None trace: Literal["enabled", "disabled"] @@ -47,7 +47,7 @@ class TopicPolicy(TypedDict): topics: The topics in the assessment. """ - topics: List[Topic] + topics: list[Topic] class ContentFilter(TypedDict): @@ -71,7 +71,7 @@ class ContentPolicy(TypedDict): filters: List of content filters to apply. """ - filters: List[ContentFilter] + filters: list[ContentFilter] class CustomWord(TypedDict): @@ -108,8 +108,8 @@ class WordPolicy(TypedDict): managedWordLists: List of managed word lists to filter. """ - customWords: List[CustomWord] - managedWordLists: List[ManagedWord] + customWords: list[CustomWord] + managedWordLists: list[ManagedWord] class PIIEntity(TypedDict): @@ -182,8 +182,8 @@ class SensitiveInformationPolicy(TypedDict): regexes: The regex queries in the assessment. """ - piiEntities: List[PIIEntity] - regexes: List[Regex] + piiEntities: list[PIIEntity] + regexes: list[Regex] class ContextualGroundingFilter(TypedDict): @@ -209,7 +209,7 @@ class ContextualGroundingPolicy(TypedDict): filters: The filter details for the guardrails contextual grounding filter. """ - filters: List[ContextualGroundingFilter] + filters: list[ContextualGroundingFilter] class GuardrailAssessment(TypedDict): @@ -239,9 +239,9 @@ class GuardrailTrace(TypedDict): outputAssessments: Assessments of output content against guardrail policies, keyed by output identifier. 
""" - inputAssessment: Dict[str, GuardrailAssessment] - modelOutput: List[str] - outputAssessments: Dict[str, List[GuardrailAssessment]] + inputAssessment: dict[str, GuardrailAssessment] + modelOutput: list[str] + outputAssessments: dict[str, list[GuardrailAssessment]] class Trace(TypedDict): diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 69cd60cf3..462d8af34 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -5,7 +5,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -37,8 +37,8 @@ class DocumentContent(TypedDict, total=False): format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] name: str source: DocumentSource - citations: Optional[CitationsConfig] - context: Optional[str] + citations: CitationsConfig | None + context: str | None ImageFormat = Literal["png", "jpeg", "gif", "webp"] diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 5da3dcde8..29453f4b7 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..interrupt import _InterruptState from .content import Message @@ -69,7 +69,7 @@ class SessionMessage: message: Message message_id: int - redact_message: Optional[Message] = None + redact_message: Message | None = None created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index dcfd541a8..8ec2e8d7b 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -5,8 +5,6 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Optional, Union - from typing_extensions import TypedDict from .citations import CitationLocation @@ -34,7 +32,7 @@ class ContentBlockStartEvent(TypedDict, total=False): start: Information about the content block being started. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None start: ContentBlockStart @@ -102,9 +100,9 @@ class ReasoningContentBlockDelta(TypedDict, total=False): text: The reasoning that the model used to return the output. """ - redactedContent: Optional[bytes] - signature: Optional[str] - text: Optional[str] + redactedContent: bytes | None + signature: str | None + text: str | None class ContentBlockDelta(TypedDict, total=False): @@ -131,7 +129,7 @@ class ContentBlockDeltaEvent(TypedDict, total=False): delta: The incremental content update for the content block. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None delta: ContentBlockDelta @@ -143,7 +141,7 @@ class ContentBlockStopEvent(TypedDict, total=False): This is optional to accommodate different model providers. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None class MessageStopEvent(TypedDict, total=False): @@ -154,7 +152,7 @@ class MessageStopEvent(TypedDict, total=False): stopReason: The reason why the model stopped generating content. 
""" - additionalModelResponseFields: Optional[Union[dict, list, int, float, str, bool, None]] + additionalModelResponseFields: dict | list | int | float | str | bool | None | None stopReason: StopReason @@ -168,7 +166,7 @@ class MetadataEvent(TypedDict, total=False): """ metrics: Metrics - trace: Optional[Trace] + trace: Trace | None usage: Usage @@ -203,8 +201,8 @@ class RedactContentEvent(TypedDict, total=False): """ - redactUserContentMessage: Optional[str] - redactAssistantContentMessage: Optional[str] + redactUserContentMessage: str | None + redactAssistantContentMessage: str | None class StreamEvent(TypedDict, total=False): diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 8f4dba6b1..6fc0d703c 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -7,8 +7,9 @@ import uuid from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass -from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union +from typing import Any, Literal, Protocol from typing_extensions import NotRequired, TypedDict @@ -164,11 +165,7 @@ def _interrupt_id(self, name: str) -> str: ToolChoiceAnyDict = dict[Literal["any"], ToolChoiceAny] ToolChoiceToolDict = dict[Literal["tool"], ToolChoiceTool] -ToolChoice = Union[ - ToolChoiceAutoDict, - ToolChoiceAnyDict, - ToolChoiceToolDict, -] +ToolChoice = ToolChoiceAutoDict | ToolChoiceAnyDict | ToolChoiceToolDict """ Configuration for how the model should choose tools. @@ -201,12 +198,7 @@ class ToolFunc(Protocol): __name__: str - def __call__( - self, *args: Any, **kwargs: Any - ) -> Union[ - ToolResult, - Awaitable[ToolResult], - ]: + def __call__(self, *args: Any, **kwargs: Any) -> ToolResult | Awaitable[ToolResult]: """Function signature for Python decorated and module based tools. 
Returns: diff --git a/src/strands/types/traces.py b/src/strands/types/traces.py index af6188adb..c5c3aaa64 100644 --- a/src/strands/types/traces.py +++ b/src/strands/types/traces.py @@ -1,20 +1,20 @@ """Tracing type definitions for the SDK.""" -from typing import List, Mapping, Optional, Sequence, Union +from collections.abc import Mapping, Sequence -AttributeValue = Union[ - str, - bool, - float, - int, - List[str], - List[bool], - List[float], - List[int], - Sequence[str], - Sequence[bool], - Sequence[int], - Sequence[float], -] +AttributeValue = ( + str + | bool + | float + | int + | list[str] + | list[bool] + | list[float] + | list[int] + | Sequence[str] + | Sequence[bool] + | Sequence[int] + | Sequence[float] +) -Attributes = Optional[Mapping[str, AttributeValue]] +Attributes = Mapping[str, AttributeValue] | None diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py index 091f44d06..cf17bb470 100644 --- a/tests/fixtures/mock_hook_provider.py +++ b/tests/fixtures/mock_hook_provider.py @@ -1,4 +1,5 @@ -from typing import Iterator, Literal, Tuple, Type +from collections.abc import Iterator +from typing import Literal from strands import Agent from strands.hooks import ( @@ -17,7 +18,7 @@ class MockHookProvider(HookProvider): - def __init__(self, event_types: list[Type] | Literal["all"]): + def __init__(self, event_types: list[type] | Literal["all"]): if event_types == "all": event_types = [ AgentInitializedEvent, @@ -37,7 +38,7 @@ def __init__(self, event_types: list[Type] | Literal["all"]): def event_types_received(self): return [type(event) for event in self.events_received] - def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + def get_events(self) -> tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) def register_hooks(self, registry: HookRegistry) -> None: diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py index 727d28a48..4d18297a2 100644 --- a/tests/fixtures/mock_multiagent_hook_provider.py +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -1,4 +1,5 @@ -from typing import Iterator, Literal, Tuple, Type +from collections.abc import Iterator +from typing import Literal from strands.experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, @@ -14,7 +15,7 @@ class MockMultiAgentHookProvider(HookProvider): - def __init__(self, event_types: list[Type] | Literal["all"]): + def __init__(self, event_types: list[type] | Literal["all"]): if event_types == "all": event_types = [ MultiAgentInitializedEvent, @@ -30,7 +31,7 @@ def __init__(self, event_types: list[Type] | Literal["all"]): def event_types_received(self): return [type(event) for event in self.events_received] - def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + def get_events(self) -> tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) def register_hooks(self, registry: HookRegistry) -> None: diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index 24de958bc..f1c5cae77 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -1,5 +1,6 @@ import json -from typing import Any, AsyncGenerator, Iterable, Optional, Sequence, Type, TypedDict, TypeVar, Union +from collections.abc import AsyncGenerator, Iterable, Sequence +from typing import Any, TypedDict, TypeVar from pydantic import BaseModel @@ -25,7 +26,7 @@ class 
MockedModelProvider(Model): to stream mock responses as events. """ - def __init__(self, agent_responses: Sequence[Union[Message, RedactionMessage]]): + def __init__(self, agent_responses: Sequence[Message | RedactionMessage]): self.agent_responses = [*agent_responses] self.index = 0 @@ -33,7 +34,7 @@ def format_chunk(self, event: Any) -> StreamEvent: return event def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> Any: return None @@ -45,9 +46,9 @@ def update_config(self, **model_config: Any) -> None: async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, ) -> AsyncGenerator[Any, None]: pass @@ -55,9 +56,9 @@ async def structured_output( async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, - tool_choice: Optional[Any] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: Any | None = None, *, system_prompt_content=None, **kwargs: Any, @@ -68,7 +69,7 @@ async def stream( self.index += 1 - def map_agent_message_to_events(self, agent_message: Union[Message, RedactionMessage]) -> Iterable[dict[str, Any]]: + def map_agent_message_to_events(self, agent_message: Message | RedactionMessage) -> Iterable[dict[str, Any]]: stop_reason: StopReason = "end_turn" yield {"messageStart": {"role": "assistant"}} if agent_message.get("redactedAssistantContent"): diff --git a/tests/strands/agent/hooks/test_hook_registry.py b/tests/strands/agent/hooks/test_hook_registry.py index ad1415f22..12b5af42c 100644 --- a/tests/strands/agent/hooks/test_hook_registry.py +++ b/tests/strands/agent/hooks/test_hook_registry.py @@ -1,6 +1,5 @@ import unittest.mock from dataclasses import dataclass -from typing import List from unittest.mock import MagicMock, Mock import pytest @@ -139,7 +138,7 @@ async def test_invoke_callbacks_async_no_registered_callbacks(hook_registry, nor @pytest.mark.asyncio async def test_invoke_callbacks_async_after_event(hook_registry, after_event): """Test that invoke_callbacks_async calls callbacks in reverse order for after events.""" - call_order: List[str] = [] + call_order: list[str] = [] def callback1(_event): call_order.append("callback1") diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 5d1f02089..1ec0a8407 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Optional, cast +from typing import cast import pytest from pydantic import BaseModel @@ -150,7 +150,7 @@ class StructuredOutputModel(BaseModel): name: str value: int - optional_field: Optional[str] = None + optional_field: str | None = None def test__init__with_structured_output(mock_metrics, simple_message: Message): diff --git a/tests/strands/agent/test_agent_structured_output.py b/tests/strands/agent/test_agent_structured_output.py index b679faed0..7341c714e 100644 --- a/tests/strands/agent/test_agent_structured_output.py +++ b/tests/strands/agent/test_agent_structured_output.py @@ -1,6 +1,5 @@ """Tests for Agent structured output functionality.""" -from typing import Optional from unittest import mock from unittest.mock import Mock, patch 
@@ -28,7 +27,7 @@ class ProductModel(BaseModel): title: str price: float - description: Optional[str] = None + description: str | None = None @pytest.fixture diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index 72ebf01c6..5d6d6869a 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -2,7 +2,7 @@ import json import unittest.mock -from typing import Any, Dict, List +from typing import Any import boto3 import pytest @@ -32,7 +32,7 @@ def sagemaker_client(boto_session): @pytest.fixture -def endpoint_config() -> Dict[str, Any]: +def endpoint_config() -> dict[str, Any]: """Default endpoint configuration for tests.""" return { "endpoint_name": "test-endpoint", @@ -42,7 +42,7 @@ def endpoint_config() -> Dict[str, Any]: @pytest.fixture -def payload_config() -> Dict[str, Any]: +def payload_config() -> dict[str, Any]: """Default payload configuration for tests.""" return { "max_tokens": 1024, @@ -64,7 +64,7 @@ def messages() -> Messages: @pytest.fixture -def tool_specs() -> List[ToolSpec]: +def tool_specs() -> list[ToolSpec]: """Sample tool specifications for testing.""" return [ { @@ -405,8 +405,8 @@ async def test_stream_with_partial_json(self, sagemaker_client, model, messages, # Mock the response from SageMaker with split JSON mock_response = { "Body": [ - {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, - {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": b'{"choices": [{"delta": {"content": "Paris is'}}, + {"PayloadPart": {"Bytes": b' the capital of France."}, "finish_reason": "stop"}]}'}}, ] } sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response @@ -444,8 +444,8 @@ async def test_tool_choice_not_supported_warns(self, sagemaker_client, model, me # Mock the response from SageMaker with split JSON mock_response = { "Body": [ - {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, - {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": b'{"choices": [{"delta": {"content": "Paris is'}}, + {"PayloadPart": {"Bytes": b' the capital of France."}, "finish_reason": "stop"}]}'}}, ] } sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py index 8cf64a39a..963904002 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Any, List +from typing import Any import pytest @@ -266,7 +266,7 @@ def test_format_request_with_unsupported_type(model, content, content_type): class AsyncStreamWrapper: - def __init__(self, items: List[Any]): + def __init__(self, items: list[Any]): self.items = items def __aiter__(self): @@ -277,7 +277,7 @@ async def _generator(self): yield item -async def mock_streaming_response(items: List[Any]): +async def mock_streaming_response(items: list[Any]): return AsyncStreamWrapper(items) diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 7e28be998..8e14c9adc 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -82,7 +82,7 @@ def test_create_session(file_manager, sample_session): assert 
os.path.exists(session_file) # Verify content - with open(session_file, "r") as f: + with open(session_file) as f: data = json.load(f) assert data["session_id"] == sample_session.session_id assert data["session_type"] == sample_session.session_type @@ -144,7 +144,7 @@ def test_create_agent(file_manager, sample_session, sample_agent): assert os.path.exists(agent_file) # Verify content - with open(agent_file, "r") as f: + with open(agent_file) as f: data = json.load(f) assert data["agent_id"] == sample_agent.agent_id assert data["state"] == sample_agent.state @@ -210,7 +210,7 @@ def test_create_message(file_manager, sample_session, sample_agent, sample_messa assert os.path.exists(message_path) # Verify content - with open(message_path, "r") as f: + with open(message_path) as f: data = json.load(f) assert data["message_id"] == sample_message.message_id @@ -439,7 +439,7 @@ def test_create_multi_agent(multi_agent_manager, sample_session, mock_multi_agen assert os.path.exists(multi_agent_file) # Verify content - with open(multi_agent_file, "r") as f: + with open(multi_agent_file) as f: data = json.load(f) assert data["id"] == mock_multi_agent.id assert data["state"] == mock_multi_agent.state diff --git a/tests/strands/tools/structured_output/test_structured_output_context.py b/tests/strands/tools/structured_output/test_structured_output_context.py index a7eb27ca5..0f1c7ffff 100644 --- a/tests/strands/tools/structured_output/test_structured_output_context.py +++ b/tests/strands/tools/structured_output/test_structured_output_context.py @@ -1,7 +1,5 @@ """Tests for StructuredOutputContext class.""" -from typing import Optional - from pydantic import BaseModel, Field from strands.tools.structured_output._structured_output_context import StructuredOutputContext @@ -13,7 +11,7 @@ class SampleModel(BaseModel): name: str = Field(..., description="Name field") age: int = Field(..., description="Age field", ge=0) - email: Optional[str] = Field(None, description="Optional email field") + email: str | None = Field(None, description="Optional email field") class AnotherSampleModel(BaseModel): diff --git a/tests/strands/tools/structured_output/test_structured_output_tool.py b/tests/strands/tools/structured_output/test_structured_output_tool.py index 66f1d465d..784a508bd 100644 --- a/tests/strands/tools/structured_output/test_structured_output_tool.py +++ b/tests/strands/tools/structured_output/test_structured_output_tool.py @@ -1,6 +1,5 @@ """Tests for StructuredOutputTool class.""" -from typing import List, Optional from unittest.mock import MagicMock import pytest @@ -23,8 +22,8 @@ class ComplexModel(BaseModel): title: str = Field(..., description="Title field") count: int = Field(..., ge=0, le=100, description="Count between 0 and 100") - tags: List[str] = Field(default_factory=list, description="List of tags") - metadata: Optional[dict] = Field(None, description="Optional metadata") + tags: list[str] = Field(default_factory=list, description="List of tags") + metadata: dict | None = Field(None, description="Optional metadata") class ValidationTestModel(BaseModel): diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index a2a4c6213..4757e5587 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,7 +3,8 @@ """ from asyncio import Queue -from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator +from typing import Annotated, Any from unittest.mock import 
MagicMock import pytest @@ -267,7 +268,7 @@ async def test_tool_with_optional_params(alist): """Test tool decorator with optional parameters.""" @strands.tool - def test_tool(required: str, optional: Optional[int] = None) -> str: + def test_tool(required: str, optional: int | None = None) -> str: """Test with optional param. Args: @@ -864,7 +865,7 @@ def int_return_tool(param: str) -> int: # Define tool with Union return type @strands.tool - def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: + def union_return_tool(param: str) -> dict[str, Any] | str | None: """Tool with Union return type. Args: @@ -936,7 +937,7 @@ async def test_complex_parameter_types(alist): """Test handling of complex parameter types like nested dictionaries.""" @strands.tool - def complex_type_tool(config: Dict[str, Any]) -> str: + def complex_type_tool(config: dict[str, Any]) -> str: """Tool with complex parameter type. Args: @@ -965,7 +966,7 @@ async def test_custom_tool_result_handling(alist): """Test that a function returning a properly formatted tool result dictionary is handled correctly.""" @strands.tool - def custom_result_tool(param: str) -> Dict[str, Any]: + def custom_result_tool(param: str) -> dict[str, Any]: """Tool that returns a custom tool result dictionary. Args: @@ -1079,11 +1080,11 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: @pytest.mark.asyncio async def test_tool_complex_validation_edge_cases(alist): """Test validation of complex schema edge cases.""" - from typing import Any, Dict, Union + from typing import Any # Define a tool with a complex anyOf type that could trigger edge case handling @strands.tool - def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: + def edge_case_tool(param: dict[str, Any] | None) -> str: """Tool with complex anyOf structure. Args: @@ -1236,10 +1237,10 @@ def failing_tool(param: str) -> str: @pytest.mark.asyncio async def test_tool_with_complex_anyof_schema(alist): """Test handling of complex anyOf structures in the schema.""" - from typing import Any, Dict, List, Union + from typing import Any @strands.tool - def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]) -> str: + def complex_schema_tool(union_param: list[int] | dict[str, Any] | str | None) -> str: """Tool with a complex Union type that creates anyOf in schema. 
Args: @@ -1680,7 +1681,7 @@ def test_tool_decorator_annotated_optional_type(): @strands.tool def optional_annotated_tool( - required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None + required: Annotated[str, "Required parameter"], optional: Annotated[str | None, "Optional parameter"] = None ) -> str: """Tool with optional annotated parameter.""" return f"{required}, {optional}" @@ -1702,7 +1703,7 @@ def test_tool_decorator_annotated_complex_types(): @strands.tool def complex_annotated_tool( - tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"] + tags: Annotated[list[str], "List of tag strings"], config: Annotated[dict[str, Any], "Configuration dictionary"] ) -> str: """Tool with complex annotated types.""" return f"Tags: {len(tags)}, Config: {len(config)}" diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index fe9b55334..72a53bfe6 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Optional +from typing import Literal, Optional import pytest from pydantic import BaseModel, Field @@ -27,7 +27,7 @@ class TwoUsersWithPlanet(BaseModel): """Two users model with planet.""" user1: UserWithPlanet = Field(description="The first user") - user2: Optional[UserWithPlanet] = Field(description="The second user", default=None) + user2: UserWithPlanet | None = Field(description="The second user", default=None) # Test model with list of same type fields @@ -250,8 +250,8 @@ class NodeWithCircularRef(BaseModel): def test_conversion_works_with_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 class Family(BaseModel): - ages: List[str] = Field(default_factory=list) - names: List[str] = Field(default_factory=list) + ages: list[str] = Field(default_factory=list) + names: list[str] = Field(default_factory=list) converted_output = convert_pydantic_to_tool_spec(Family) expected_output = { @@ -281,8 +281,8 @@ class Family(BaseModel): def test_marks_fields_as_optional_for_model_w_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 class Family(BaseModel): - ages: List[str] = Field(default_factory=list) - names: List[str] = Field(default_factory=list) + ages: list[str] = Field(default_factory=list) + names: list[str] = Field(default_factory=list) converted_output = convert_pydantic_to_tool_spec(Family) assert "null" in converted_output["inputSchema"]["json"]["properties"]["ages"]["type"] @@ -312,14 +312,14 @@ def test_convert_pydantic_with_items_refs(): """Test that no $refs exist after lists of different components.""" class Address(BaseModel): - postal_code: Optional[str] = None + postal_code: str | None = None class Person(BaseModel): """Complete person information.""" list_of_items: list[Address] - list_of_items_nullable: Optional[list[Address]] - list_of_item_or_nullable: list[Optional[Address]] + list_of_items_nullable: list[Address] | None + list_of_item_or_nullable: list[Address | None] tool_spec = convert_pydantic_to_tool_spec(Person) @@ -378,7 +378,7 @@ class Address(BaseModel): street: str city: str country: str - postal_code: Optional[str] = None + postal_code: str | None = None class Contact(BaseModel): address: Address diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index 151f913d6..8fa1fb2b2 
100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -84,9 +84,7 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): resource=BlobResourceContents( uri="https://weather.api/data/london.json", mimeType="application/json", - blob=base64.b64encode( - '{"temperature": 18, "condition": "rainy", "humidity": 85}'.encode() - ).decode(), + blob=base64.b64encode(b'{"temperature": 18, "condition": "rainy", "humidity": 85}').decode(), ), ) ] diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 5c3baeba8..298272df5 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -3,7 +3,7 @@ import os import threading import time -from typing import List, Literal +from typing import Literal import pytest from mcp import StdioServerParameters, stdio_client @@ -47,7 +47,7 @@ def generate_custom_image() -> MCPImageContent: encoded_image = base64.b64encode(image_file.read()) return MCPImageContent(type="image", data=encoded_image, mimeType="image/png") except Exception as e: - print("Error while generating custom image: {}".format(e)) + print(f"Error while generating custom image: {e}") # Prompts @mcp.prompt(description="A greeting prompt template") @@ -366,7 +366,7 @@ def test_mcp_client_embedded_resources_with_agent(): assert any(["72" in response_text, "partly cloudy" in response_text, "weather" in response_text]) -def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: +def _messages_to_content_blocks(messages: list[Message]) -> list[ToolUse]: return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 75cc58f74..57614b97f 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -3,7 +3,7 @@ """ import os -from typing import Callable, Optional +from collections.abc import Callable import requests from pytest import mark @@ -26,7 +26,7 @@ def __init__( self, id: str, factory: Callable[[], Model], - environment_variable: Optional[str] = None, + environment_variable: str | None = None, ) -> None: self.id = id self.model_factory = factory diff --git a/tests_integ/test_function_tools.py b/tests_integ/test_function_tools.py index 835dccf5d..6c72bdddb 100644 --- a/tests_integ/test_function_tools.py +++ b/tests_integ/test_function_tools.py @@ -4,7 +4,6 @@ """ import logging -from typing import Optional from strands import Agent, tool @@ -25,7 +24,7 @@ def word_counter(text: str) -> str: @tool(name="count_chars", description="Count characters in text") -def count_chars(text: str, include_spaces: Optional[bool] = True) -> str: +def count_chars(text: str, include_spaces: bool | None = True) -> str: """ Count characters in text. 
diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index 08343a554..b80a0f82d 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any from unittest.mock import patch from uuid import uuid4 diff --git a/tests_integ/test_structured_output_agent_loop.py b/tests_integ/test_structured_output_agent_loop.py index 390bd3cff..01d3c80b2 100644 --- a/tests_integ/test_structured_output_agent_loop.py +++ b/tests_integ/test_structured_output_agent_loop.py @@ -2,8 +2,6 @@ Comprehensive integration tests for structured output passed into the agent functionality. """ -from typing import List, Optional - import pytest from pydantic import BaseModel, Field, field_validator @@ -42,7 +40,7 @@ class Contact(BaseModel): """Contact information.""" email: str - phone: Optional[str] = None + phone: str | None = None preferred_method: str = "email" @@ -54,7 +52,7 @@ class Employee(BaseModel): department: str address: Address contact: Contact - skills: List[str] + skills: list[str] hire_date: str salary_range: str @@ -65,7 +63,7 @@ class ProductReview(BaseModel): product_name: str rating: int = Field(ge=1, le=5, description="Rating from 1-5 stars") sentiment: str = Field(pattern="^(positive|negative|neutral)$") - key_points: List[str] + key_points: list[str] would_recommend: bool @@ -84,7 +82,7 @@ class TaskList(BaseModel): """Task management structure.""" project_name: str - tasks: List[str] + tasks: list[str] priority: str = Field(pattern="^(high|medium|low)$") due_date: str estimated_hours: int @@ -102,7 +100,7 @@ class Company(BaseModel): name: str = Field(description="Company name") address: Address = Field(description="Company address") - employees: List[Person] = Field(description="list of persons") + employees: list[Person] = Field(description="list of persons") class Task(BaseModel): From c23090f013d475bd93f49ae290899fcda5538f53 Mon Sep 17 00:00:00 2001 From: Masashi Tomooka Date: Fri, 16 Jan 2026 01:41:15 +0900 Subject: [PATCH 26/47] fix(agent): extract text from citationsContent in AgentResult.__str__ (#1489) AgentResult.__str__ now correctly extracts text from citationsContent blocks. Previously, only plain text blocks were processed, causing citation responses to return empty strings when converted to str(). 
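For illustration, a minimal sketch of the fixed behavior (the import path and
the Mock() metrics stand-in are assumptions for the sketch; the message shape
is abbreviated):

    from unittest.mock import Mock

    from strands.agent.agent_result import AgentResult

    result = AgentResult(
        stop_reason="end_turn",
        message={
            "role": "assistant",
            "content": [{"citationsContent": {"content": [{"text": "Cited text."}]}}],
        },
        metrics=Mock(),  # placeholder for the real metrics object
        state={},
    )
    # Each extracted text block is appended with a trailing newline;
    # before this fix, str(result) here would have been "".
    assert str(result) == "Cited text.\n"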
Co-authored-by: Claude Opus 4.5 --- src/strands/agent/agent_result.py | 11 ++++- tests/strands/agent/test_agent_result.py | 58 ++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index 2ab95e5b5..8f9241a67 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -49,8 +49,15 @@ def __str__(self) -> str: result = "" for item in content_array: - if isinstance(item, dict) and "text" in item: - result += item.get("text", "") + "\n" + if isinstance(item, dict): + if "text" in item: + result += item.get("text", "") + "\n" + elif "citationsContent" in item: + citations_block = item["citationsContent"] + if "content" in citations_block: + for content in citations_block["content"]: + if isinstance(content, dict) and "text" in content: + result += content.get("text", "") + "\n" if not result and self.structured_output: result = self.structured_output.model_dump_json() diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 1ec0a8407..6e4c2c91a 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -225,3 +225,61 @@ def test__str__empty_message_with_structured_output(mock_metrics, empty_message: assert "example" in message_string assert "123" in message_string assert "optional" in message_string + + +@pytest.fixture +def citations_message(): + """Message with citationsContent block.""" + return { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "title": "Source Document", + "location": {"document": {"pageNumber": 1}}, + "sourceContent": [{"text": "source text"}], + } + ], + "content": [{"text": "This is cited text from the document."}], + } + } + ], + } + + +@pytest.fixture +def mixed_text_and_citations_message(): + """Message with both plain text and citationsContent blocks.""" + return { + "role": "assistant", + "content": [ + {"text": "Introduction paragraph"}, + { + "citationsContent": { + "citations": [{"title": "Doc", "location": {}, "sourceContent": []}], + "content": [{"text": "Cited content here."}], + } + }, + {"text": "Conclusion paragraph"}, + ], + } + + +def test__str__with_citations_content(mock_metrics, citations_message: Message): + """Test that str() extracts text from citationsContent blocks.""" + result = AgentResult(stop_reason="end_turn", message=citations_message, metrics=mock_metrics, state={}) + + message_string = str(result) + assert message_string == "This is cited text from the document.\n" + + +def test__str__mixed_text_and_citations_content(mock_metrics, mixed_text_and_citations_message: Message): + """Test that str() works with both plain text and citationsContent blocks.""" + result = AgentResult( + stop_reason="end_turn", message=mixed_text_and_citations_message, metrics=mock_metrics, state={} + ) + + message_string = str(result) + assert message_string == "Introduction paragraph\nCited content here.\nConclusion paragraph\n" From dfe3ec75d7e414b27a13798edcb51edff1e82f21 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 15 Jan 2026 14:52:35 -0500 Subject: [PATCH 27/47] Expose input messages to BeforeInvocationEvent hook (#1474) * feat(hooks): expose input messages to BeforeInvocationEvent Add messages attribute to BeforeInvocationEvent to enable input-side guardrails for PII detection, content moderation, and prompt attack prevention. 
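As a quick sketch of the intended use (mirroring the tests added below; the
redaction rule is illustrative only, and `agent` is assumed to be an existing
Agent instance):

    from strands.hooks import BeforeInvocationEvent

    async def redact_input(event: BeforeInvocationEvent):
        """Blank out a known sensitive token before it enters history."""
        if event.messages is None:
            return
        for message in event.messages:
            for block in message.get("content", []):
                if "text" in block:
                    block["text"] = block["text"].replace("SECRET", "[REDACTED]")

    agent.hooks.add_callback(BeforeInvocationEvent, redact_input)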
Hooks can now inspect and modify messages before they are added to the agent's conversation history. - Add writable messages attribute to BeforeInvocationEvent (None default) - Pass messages parameter from _run_loop() to BeforeInvocationEvent - Add unit tests for new messages attribute and writability - Add integration tests for message modification use case - Update docs/HOOKS.md with input guardrails documentation Resolves #8 * refactor: address review feedback - Remove detailed Input Guardrails section from docs/HOOKS.md - Simplify BeforeInvocationEvent docstring per review - Remove backward compatibility note from messages attribute - Remove no-op test for messages initialization * refactor: simplify test assertions per review Use concise equality comparison for BeforeInvocationEvent assertions instead of verbose instance checks and property assertions. * Use overwritten messages array for the agent * Fix mypy issue --------- Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- src/strands/agent/agent.py | 5 +- src/strands/hooks/events.py | 11 ++- tests/strands/agent/hooks/test_events.py | 38 +++++++++- tests/strands/agent/test_agent_hooks.py | 96 +++++++++++++++++++++++- 4 files changed, 143 insertions(+), 7 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index b58b55f24..7b9e9c914 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -637,7 +637,10 @@ async def _run_loop( Yields: Events from the event loop cycle. """ - await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) + before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + BeforeInvocationEvent(agent=self, messages=messages) + ) + messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages agent_result: AgentResult | None = None try: diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 340b6d3d2..8aa8a68d6 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from ..agent.agent_result import AgentResult -from ..types.content import Message +from ..types.content import Message, Messages from ..types.interrupt import _Interruptible from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse @@ -43,9 +43,16 @@ class BeforeInvocationEvent(HookEvent): - Agent.__call__ - Agent.stream_async - Agent.structured_output + + Attributes: + messages: The input messages for this invocation. Can be modified by hooks + to redact or transform content before processing. 
""" - pass + messages: Messages | None = None + + def _can_write(self, name: str) -> bool: + return name == "messages" @dataclass diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 9203478b2..83cb1af24 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -11,7 +11,7 @@ BeforeToolCallEvent, MessageAddedEvent, ) -from strands.types.content import Message +from strands.types.content import Message, Messages from strands.types.tools import ToolResult, ToolUse @@ -20,6 +20,11 @@ def agent(): return Mock() +@pytest.fixture +def sample_messages() -> Messages: + return [{"role": "user", "content": [{"text": "Hello, agent!"}]}] + + @pytest.fixture def tool(): tool = Mock() @@ -52,6 +57,11 @@ def start_request_event(agent): return BeforeInvocationEvent(agent=agent) +@pytest.fixture +def start_request_event_with_messages(agent, sample_messages): + return BeforeInvocationEvent(agent=agent, messages=sample_messages) + + @pytest.fixture def messaged_added_event(agent): return MessageAddedEvent(agent=agent, message=Mock()) @@ -159,3 +169,29 @@ def test_after_invocation_event_properties_not_writable(agent): with pytest.raises(AttributeError, match="Property agent is not writable"): event.agent = Mock() + + +def test_before_invocation_event_messages_default_none(agent): + """Test that BeforeInvocationEvent.messages defaults to None for backward compatibility.""" + event = BeforeInvocationEvent(agent=agent) + assert event.messages is None + + +def test_before_invocation_event_messages_writable(agent, sample_messages): + """Test that BeforeInvocationEvent.messages can be modified in-place for guardrail redaction.""" + event = BeforeInvocationEvent(agent=agent, messages=sample_messages) + + # Should be able to modify the messages list in-place + event.messages[0]["content"] = [{"text": "[REDACTED]"}] + assert event.messages[0]["content"] == [{"text": "[REDACTED]"}] + + # Should be able to reassign messages entirely + new_messages: Messages = [{"role": "user", "content": [{"text": "Different message"}]}] + event.messages = new_messages + assert event.messages == new_messages + + +def test_before_invocation_event_agent_not_writable(start_request_event_with_messages): + """Test that BeforeInvocationEvent.agent is not writable.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + start_request_event_with_messages.agent = Mock() diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 00b9d368a..be71b5fcf 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -160,7 +160,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], @@ -214,7 +214,11 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m """Verify that the correct hook events are emitted as part of stream_async.""" iterator = agent.stream_async("test message") await anext(iterator) - assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)] + + # Verify first event is BeforeInvocationEvent with messages + assert len(hook_provider.events_received) == 1 + assert hook_provider.events_received[0].messages is not 
None
+    assert hook_provider.events_received[0].messages[0]["role"] == "user"
 
     # iterate the rest
     result = None
@@ -226,7 +230,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
 
     assert length == 12
 
-    assert next(events) == BeforeInvocationEvent(agent=agent)
+    assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1])
     assert next(events) == MessageAddedEvent(
         agent=agent,
         message=agent.messages[0],
@@ -596,3 +600,89 @@ async def handle_after_model_call(event: AfterModelCallEvent):
     # Should succeed after: custom retry + 2 throttle retries
     assert result.stop_reason == "end_turn"
     assert result.message["content"][0]["text"] == "Success after mixed retries"
+
+
+def test_before_invocation_event_message_modification():
+    """Test that hooks can modify messages in BeforeInvocationEvent for input guardrails."""
+    mock_provider = MockedModelProvider(
+        [
+            {
+                "role": "assistant",
+                "content": [{"text": "I received your redacted message"}],
+            },
+        ]
+    )
+
+    modified_content = None
+
+    async def input_guardrail_hook(event: BeforeInvocationEvent):
+        """Simulates a guardrail that redacts sensitive content."""
+        nonlocal modified_content
+        if event.messages is not None:
+            for message in event.messages:
+                if message.get("role") == "user":
+                    content = message.get("content", [])
+                    for block in content:
+                        if "text" in block and "SECRET" in block["text"]:
+                            # Redact sensitive content in-place
+                            block["text"] = block["text"].replace("SECRET", "[REDACTED]")
+            modified_content = event.messages[0]["content"][0]["text"]
+
+    agent = Agent(model=mock_provider)
+    agent.hooks.add_callback(BeforeInvocationEvent, input_guardrail_hook)
+
+    agent("My password is SECRET123")
+
+    # Verify the message was modified before being processed
+    assert modified_content == "My password is [REDACTED]123"
+    # Verify the modified message was added to the agent's conversation history
+    assert agent.messages[0]["content"][0]["text"] == "My password is [REDACTED]123"
+
+
+def test_before_invocation_event_message_overwrite():
+    """Test that hooks can overwrite messages in BeforeInvocationEvent."""
+    mock_provider = MockedModelProvider(
+        [
+            {
+                "role": "assistant",
+                "content": [{"text": "I received your message"}],
+            },
+        ]
+    )
+
+    async def overwrite_input_hook(event: BeforeInvocationEvent):
+        event.messages = [{"role": "user", "content": [{"text": "GOODBYE"}]}]
+
+    agent = Agent(model=mock_provider)
+    agent.hooks.add_callback(BeforeInvocationEvent, overwrite_input_hook)
+
+    agent("HELLO")
+
+    # Verify the overwritten message was added to the agent's conversation history
+    assert agent.messages[0]["content"][0]["text"] == "GOODBYE"
+
+
+@pytest.mark.asyncio
+async def test_before_invocation_event_messages_none_in_structured_output(agenerator):
+    """Test that BeforeInvocationEvent.messages is None when called from deprecated structured_output."""
+
+    class Person(BaseModel):
+        name: str
+        age: int
+
+    mock_provider = MockedModelProvider([])
+    mock_provider.structured_output = Mock(return_value=agenerator([{"output": Person(name="Test", age=30)}]))
+
+    received_messages = "not_set"
+
+    async def capture_messages_hook(event: BeforeInvocationEvent):
+        nonlocal received_messages
+        received_messages = event.messages
+
+    agent = Agent(model=mock_provider)
+    agent.hooks.add_callback(BeforeInvocationEvent, capture_messages_hook)
+
+    await agent.structured_output_async(Person, "Test prompt")
+
+    # structured_output_async uses deprecated path that doesn't pass messages
+    assert
received_messages is None

From 058c03a487e53e4162136376232cbe5724e7c90c Mon Sep 17 00:00:00 2001
From: Patrick Gray
Date: Thu, 15 Jan 2026 14:59:47 -0500
Subject: [PATCH 28/47] interrupts - graph - hook based (#1478)

---
 src/strands/multiagent/graph.py               | 135 +++++++++++--
 tests/strands/multiagent/test_graph.py        |  88 ++++++++-
 .../interrupts/multiagent/test_hook.py        | 187 +++++++++++++++++-
 .../interrupts/multiagent/test_session.py     |  98 ++++++++-
 4 files changed, 475 insertions(+), 33 deletions(-)

diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py
index 19504ad73..97435ad4a 100644
--- a/src/strands/multiagent/graph.py
+++ b/src/strands/multiagent/graph.py
@@ -35,11 +35,13 @@
     MultiAgentInitializedEvent,
 )
 from ..hooks import HookProvider, HookRegistry
+from ..interrupt import Interrupt, _InterruptState
 from ..session import SessionManager
 from ..telemetry import get_tracer
 from ..types._events import (
     MultiAgentHandoffEvent,
     MultiAgentNodeCancelEvent,
+    MultiAgentNodeInterruptEvent,
     MultiAgentNodeStartEvent,
     MultiAgentNodeStopEvent,
     MultiAgentNodeStreamEvent,
@@ -64,10 +66,15 @@ class GraphState:
         status: Current execution status of the graph.
         completed_nodes: Set of nodes that have completed execution.
         failed_nodes: Set of nodes that failed during execution.
+        interrupted_nodes: Set of nodes that the user interrupted during execution.
         execution_order: List of nodes in the order they were executed.
         task: The original input prompt/query provided to the graph execution.
            This represents the actual work to be performed by the graph as a whole.
            Entry point nodes receive this task as their input if they have no dependencies.
+        start_time: Timestamp when the current invocation started.
+            Resets on each invocation, even when resuming from interrupt.
+        execution_time: Execution time of the current invocation in milliseconds.
+            Excludes time spent waiting for interrupt responses.
- - The execution_status tracks the node's lifecycle within graph orchestration: - - PENDING: Node hasn't started executing yet - - EXECUTING: Node is currently running - - COMPLETED/FAILED: Node finished executing (regardless of result quality) - """ + """Represents a node in the graph.""" node_id: str executor: Agent | MultiAgentBase @@ -446,6 +449,7 @@ def __init__( self.node_timeout = node_timeout self.reset_on_revisit = reset_on_revisit self.state = GraphState() + self._interrupt_state = _InterruptState() self.tracer = get_tracer() self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes) self.session_manager = session_manager @@ -520,6 +524,8 @@ async def stream_async( - multi_agent_node_stop: When a node stops execution - result: Final graph result """ + self._interrupt_state.resume(task) + if invocation_state is None: invocation_state = {} @@ -529,7 +535,7 @@ async def stream_async( # Initialize state start_time = time.time() - if not self._resume_from_session: + if not self._resume_from_session and not self._interrupt_state.activated: # Initialize state self.state = GraphState( status=Status.EXECUTING, @@ -545,6 +551,8 @@ async def stream_async( span = self.tracer.start_multiagent_span(task, "graph", custom_trace_attributes=self.trace_attributes) with trace_api.use_span(span, end_on_exit=True): + interrupts = [] + try: logger.debug( "max_node_executions=<%s>, execution_timeout=<%s>s, node_timeout=<%s>s | graph execution config", @@ -554,6 +562,9 @@ async def stream_async( ) async for event in self._execute_graph(invocation_state): + if isinstance(event, MultiAgentNodeInterruptEvent): + interrupts.extend(event.interrupts) + yield event.as_dict() # Set final status based on execution results @@ -565,7 +576,7 @@ async def stream_async( logger.debug("status=<%s> | graph execution completed", self.state.status) # Yield final result (consistent with Agent's AgentResultEvent format) - result = self._build_result() + result = self._build_result(interrupts) # Use the same event format as Agent for consistency yield MultiAgentResultEvent(result=result).as_dict() @@ -575,7 +586,7 @@ async def stream_async( self.state.status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - start_time) * 1000) + self.state.execution_time += round((time.time() - start_time) * 1000) await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self)) self._resume_from_session = False self._resume_next_nodes.clear() @@ -592,9 +603,41 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) + def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent: + """Activate the interrupt state. + + Args: + node: The interrupted node. + interrupts: The interrupts raised by the user. 
+ + Returns: + MultiAgentNodeInterruptEvent + """ + logger.debug("node=<%s> | node interrupted", node.node_id) + + node.execution_status = Status.INTERRUPTED + + self.state.status = Status.INTERRUPTED + self.state.interrupted_nodes.add(node) + + self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) + self._interrupt_state.activate() + + return MultiAgentNodeInterruptEvent(node.node_id, interrupts) + async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute graph and yield TypedEvent objects.""" - ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) + if self._interrupt_state.activated: + ready_nodes = [self.nodes[node_id] for node_id in self._interrupt_state.context["completed_nodes"]] + ready_nodes.extend(self.state.interrupted_nodes) + + self.state.interrupted_nodes.clear() + + elif self._resume_from_session: + ready_nodes = self._resume_next_nodes + + else: + ready_nodes = list(self.entry_points) while ready_nodes: # Check execution limits before continuing @@ -614,6 +657,14 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato async for event in self._execute_nodes_parallel(current_batch, invocation_state): yield event + if self.state.status == Status.INTERRUPTED: + self._interrupt_state.context["completed_nodes"] = [ + node.node_id for node in current_batch if node.execution_status == Status.COMPLETED + ] + return + + self._interrupt_state.deactivate() + # Find newly ready nodes after batch execution # We add all nodes in current batch as completed batch, # because a failure would throw exception and code would not make it here @@ -642,6 +693,9 @@ async def _execute_nodes_parallel( Uses a shared queue where each node's stream runs independently and pushes events as they occur, enabling true real-time event propagation without round-robin delays. 
""" + if self._interrupt_state.activated: + nodes = [node for node in nodes if node.execution_status == Status.INTERRUPTED] + event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue() # Start all node streams as independent tasks @@ -798,12 +852,16 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) ) yield start_event - before_event, _ = await self.hooks.invoke_callbacks_async( + before_event, interrupts = await self.hooks.invoke_callbacks_async( BeforeNodeCallEvent(self, node.node_id, invocation_state) ) start_time = time.time() try: + if interrupts: + yield self._activate_interrupt(node, interrupts) + return + if before_event.cancel_node: cancel_message = ( before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user" @@ -831,6 +889,13 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if multi_agent_result is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") + if multi_agent_result.status == Status.INTERRUPTED: + raise NotImplementedError( + f"node_id=<{node.node_id}>, " + "issue= " + "| user raised interrupt from a multi agent node" + ) + node_result = NodeResult( result=multi_agent_result, execution_time=multi_agent_result.execution_time, @@ -855,12 +920,15 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if agent_response is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") - # Check for interrupt (from main branch) if agent_response.stop_reason == "interrupt": node.executor.messages.pop() # remove interrupted tool use message node.executor._interrupt_state.deactivate() - raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs") + raise NotImplementedError( + f"node_id=<{node.node_id}>, " + "issue= " + "| user raised interrupt from an agent node" + ) # Extract metrics with defaults response_metrics = getattr(agent_response, "metrics", None) @@ -1007,8 +1075,15 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: return node_input - def _build_result(self) -> GraphResult: - """Build graph result from current state.""" + def _build_result(self, interrupts: list[Interrupt]) -> GraphResult: + """Build graph result from current state. + + Args: + interrupts: List of interrupts collected during execution. + + Returns: + GraphResult with current state. 
+ """ return GraphResult( status=self.state.status, results=self.state.results, @@ -1019,9 +1094,11 @@ def _build_result(self) -> GraphResult: total_nodes=self.state.total_nodes, completed_nodes=len(self.state.completed_nodes), failed_nodes=len(self.state.failed_nodes), + interrupted_nodes=len(self.state.interrupted_nodes), execution_order=self.state.execution_order, edges=self.state.edges, entry_points=self.state.entry_points, + interrupts=interrupts, ) def serialize_state(self) -> dict[str, Any]: @@ -1034,10 +1111,14 @@ def serialize_state(self) -> dict[str, Any]: "status": self.state.status.value, "completed_nodes": [n.node_id for n in self.state.completed_nodes], "failed_nodes": [n.node_id for n in self.state.failed_nodes], + "interrupted_nodes": [n.node_id for n in self.state.interrupted_nodes], "node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()}, "next_nodes_to_execute": next_nodes, "current_task": self.state.task, "execution_order": [n.node_id for n in self.state.execution_order], + "_internal_state": { + "interrupt_state": self._interrupt_state.to_dict(), + }, } def deserialize_state(self, payload: dict[str, Any]) -> None: @@ -1053,6 +1134,10 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: payload: Dictionary containing persisted state data including status, completed nodes, results, and next nodes to execute. """ + if "_internal_state" in payload: + internal_state = payload["_internal_state"] + self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) + if not payload.get("next_nodes_to_execute"): # Reset all nodes for node in self.nodes.values(): @@ -1099,10 +1184,20 @@ def _from_dict(self, payload: dict[str, Any]) -> None: self.state.failed_nodes = set( self.nodes[node_id] for node_id in (payload.get("failed_nodes") or []) if node_id in self.nodes ) + for node in self.state.failed_nodes: + node.execution_status = Status.FAILED - # Restore completed nodes from persisted data - completed_node_ids = payload.get("completed_nodes") or [] - self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes} + self.state.interrupted_nodes = set( + self.nodes[node_id] for node_id in (payload.get("interrupted_nodes") or []) if node_id in self.nodes + ) + for node in self.state.interrupted_nodes: + node.execution_status = Status.INTERRUPTED + + self.state.completed_nodes = set( + self.nodes[node_id] for node_id in (payload.get("completed_nodes") or []) if node_id in self.nodes + ) + for node in self.state.completed_nodes: + node.execution_status = Status.COMPLETED # Execution order (only nodes that still exist) order_node_ids = payload.get("execution_order") or [] diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 4875d1bec..ab2d86e70 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1,6 +1,6 @@ import asyncio import time -from unittest.mock import AsyncMock, MagicMock, Mock, call, patch +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch import pytest @@ -9,6 +9,7 @@ from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import AgentInitializedEvent from strands.hooks.registry import HookProvider, HookRegistry +from strands.interrupt import Interrupt, _InterruptState from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, 
GraphResult, GraphState, Status from strands.session.file_session_manager import FileSessionManager @@ -2004,6 +2005,9 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): state = graph.serialize_state() assert state["type"] == "graph" assert state["id"] == "default_graph" + assert state["_internal_state"] == { + "interrupt_state": {"activated": False, "context": {}, "interrupts": {}}, + } assert "status" in state assert "completed_nodes" in state assert "node_results" in state @@ -2013,14 +2017,33 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): "status": "executing", "completed_nodes": [], "failed_nodes": [], + "interrupted_nodes": [], "node_results": {}, "current_task": "persisted task", "execution_order": [], "next_nodes_to_execute": ["test_node"], + "_internal_state": { + "interrupt_state": { + "activated": False, + "context": {"a": 1}, + "interrupts": { + "i1": { + "id": "i1", + "name": "test_name", + "reason": "test_reason", + }, + }, + }, + }, } graph.deserialize_state(persisted_state) assert graph.state.task == "persisted task" + assert graph._interrupt_state == _InterruptState( + activated=False, + context={"a": 1}, + interrupts={"i1": Interrupt(id="i1", name="test_name", reason="test_reason")}, + ) # Execute graph to test persistence integration result = await graph.invoke_async("Test persistence") @@ -2068,3 +2091,66 @@ def cancel_callback(event): tru_status = graph.state.status exp_status = Status.FAILED assert tru_status == exp_status + + +def test_graph_interrupt_on_before_node_call_event(interrupt_hook): + agent = create_mock_agent("test_agent", "Task completed") + + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + builder.set_hook_providers([interrupt_hook]) + graph = builder.build() + + multiagent_result = graph("Test task") + + first_execution_time = multiagent_result.execution_time + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_agent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_name", + reason="test_reason", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + agent_result = multiagent_result.results["test_agent"] + + tru_message = agent_result.result.message["content"][0]["text"] + exp_message = "Task completed" + assert tru_message == exp_message + + assert multiagent_result.execution_time >= first_execution_time diff --git a/tests_integ/interrupts/multiagent/test_hook.py b/tests_integ/interrupts/multiagent/test_hook.py index be7682082..9350b3535 100644 --- a/tests_integ/interrupts/multiagent/test_hook.py +++ b/tests_integ/interrupts/multiagent/test_hook.py @@ -7,7 +7,7 @@ from 
strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import HookProvider from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status @@ -18,16 +18,34 @@ def register_hooks(self, registry): registry.add_callback(BeforeNodeCallEvent, self.interrupt) def interrupt(self, event): - if event.node_id == "info": + if event.node_id == "info" or event.node_id == "time": return - response = event.interrupt("test_interrupt", reason="need approval") + response = event.interrupt(f"{event.node_id}_interrupt", reason="need approval") if response != "APPROVE": event.cancel_node = "node rejected" return Hook() +@pytest.fixture +def day_tool(): + @tool(name="day_tool") + def func(): + return "monday" + + return func + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:01" + + return func + + @pytest.fixture def weather_tool(): @tool(name="weather_tool") @@ -38,13 +56,49 @@ def func(): @pytest.fixture -def swarm(interrupt_hook, weather_tool): - info_agent = Agent(name="info") - weather_agent = Agent(name="weather", tools=[weather_tool]) +def info_agent(): + return Agent(name="info") + +@pytest.fixture +def day_agent(day_tool): + return Agent(name="day", tools=[day_tool]) + + +@pytest.fixture +def time_agent(time_tool): + return Agent(name="time", tools=[time_tool]) + + +@pytest.fixture +def weather_agent(weather_tool): + return Agent(name="weather", tools=[weather_tool]) + + +@pytest.fixture +def swarm(interrupt_hook, info_agent, weather_agent): return Swarm([info_agent, weather_agent], hooks=[interrupt_hook]) +@pytest.fixture +def graph(interrupt_hook, info_agent, day_agent, time_agent, weather_agent): + builder = GraphBuilder() + + builder.add_node(info_agent, "info") + builder.add_node(day_agent, "day") + builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") + + builder.add_edge("info", "day") + builder.add_edge("info", "time") + builder.add_edge("info", "weather") + + builder.set_entry_point("info") + builder.set_hook_providers([interrupt_hook]) + + return builder.build() + + def test_swarm_interrupt(swarm): multiagent_result = swarm("What is the weather?") @@ -56,7 +110,7 @@ def test_swarm_interrupt(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need approval", ), ] @@ -97,7 +151,7 @@ async def test_swarm_interrupt_reject(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need approval", ), ] @@ -131,3 +185,120 @@ async def test_swarm_interrupt_reject(swarm): tru_node_id = multiagent_result.node_history[0].node_id exp_node_id = "info" assert tru_node_id == exp_node_id + + +def test_graph_interrupt(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = sorted([node.node_id for node in graph.state.interrupted_nodes]) + exp_node_ids = ["day", "weather"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need approval", + ), + 
Interrupt( + id=ANY, + name="weather_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + } + for interrupt in multiagent_result.interrupts + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 4 + + day_message = json.dumps(multiagent_result.results["day"].result.message).lower() + time_message = json.dumps(multiagent_result.results["time"].result.message).lower() + weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower() + assert "monday" in day_message + assert "12:01" in time_message + assert "sunny" in weather_message + + +@pytest.mark.asyncio +async def test_graph_interrupt_reject(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need approval", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": tru_interrupts[0].id, + "response": "APPROVE", + }, + }, + { + "interruptResponse": { + "interruptId": tru_interrupts[1].id, + "response": "REJECT", + }, + }, + ] + + try: + async for event in graph.stream_async(responses): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_id = event["node_id"] + + except RuntimeError as e: + assert "node rejected" in str(e) + + exp_cancel_id = "weather" + assert tru_cancel_id == exp_cancel_id + + tru_state_status = graph.state.status + exp_state_status = Status.FAILED + assert tru_state_status == exp_state_status diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index d6e8cdbf8..bab4b428f 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -4,13 +4,30 @@ import pytest from strands import Agent, tool +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import HookProvider from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.session import FileSessionManager from strands.types.tools import ToolContext +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.interrupt) + + def interrupt(self, event): + if event.node_id == "time": + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_node = "node rejected" + + return Hook() + + @pytest.fixture def weather_tool(): @tool(name="weather_tool", context=True) @@ -22,9 +39,12 @@ def 
func(tool_context: ToolContext) -> str:
 
 
 @pytest.fixture
-def swarm(weather_tool):
-    weather_agent = Agent(name="weather", tools=[weather_tool])
-    return Swarm([weather_agent])
+def time_tool():
+    @tool(name="time_tool")
+    def func():
+        return "12:01"
+
+    return func
 
 
 def test_swarm_interrupt_session(weather_tool, tmpdir):
@@ -75,3 +95,73 @@ def test_swarm_interrupt_session(weather_tool, tmpdir):
 
     summarizer_message = json.dumps(summarizer_result.result.message).lower()
     assert "sunny" in summarizer_message
+
+
+def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir):
+    time_agent = Agent(name="time", tools=[time_tool])
+    summarizer_agent = Agent(name="summarizer")
+    session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir)
+
+    builder = GraphBuilder()
+    builder.add_node(time_agent, "time")
+    builder.add_node(summarizer_agent, "summarizer")
+    builder.add_edge("time", "summarizer")
+    builder.set_hook_providers([interrupt_hook])
+    builder.set_session_manager(session_manager)
+    graph = builder.build()
+
+    multiagent_result = graph("Can you check the time and then summarize the results?")
+
+    tru_result_status = multiagent_result.status
+    exp_result_status = Status.INTERRUPTED
+    assert tru_result_status == exp_result_status
+
+    tru_state_status = graph.state.status
+    exp_state_status = Status.INTERRUPTED
+    assert tru_state_status == exp_state_status
+
+    tru_interrupts = multiagent_result.interrupts
+    exp_interrupts = [
+        Interrupt(
+            id=ANY,
+            name="test_interrupt",
+            reason="need approval",
+        ),
+    ]
+    assert tru_interrupts == exp_interrupts
+
+    interrupt = multiagent_result.interrupts[0]
+
+    time_agent = Agent(name="time", tools=[time_tool])
+    summarizer_agent = Agent(name="summarizer")
+    session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir)
+
+    builder = GraphBuilder()
+    builder.add_node(time_agent, "time")
+    builder.add_node(summarizer_agent, "summarizer")
+    builder.add_edge("time", "summarizer")
+    builder.set_hook_providers([interrupt_hook])
+    builder.set_session_manager(session_manager)
+    graph = builder.build()
+
+    responses = [
+        {
+            "interruptResponse": {
+                "interruptId": interrupt.id,
+                "response": "APPROVE",
+            },
+        },
+    ]
+    multiagent_result = graph(responses)
+
+    tru_result_status = multiagent_result.status
+    exp_result_status = Status.COMPLETED
+    assert tru_result_status == exp_result_status
+
+    tru_state_status = graph.state.status
+    exp_state_status = Status.COMPLETED
+    assert tru_state_status == exp_state_status
+
+    assert len(multiagent_result.results) == 2
+    summarizer_message = json.dumps(multiagent_result.results["summarizer"].result.message).lower()
+    assert "12:01" in summarizer_message

From bb3052b20f4534cb523219fb7146810227e2ea21 Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Thu, 15 Jan 2026 15:34:15 -0500
Subject: [PATCH 29/47] fix: Swap sleeps with explicit signaling (#1497)

So that unit tests are deterministic

Co-authored-by: Mackenzie Zastrow
---
 tests/strands/agent/test_agent.py | 59 +++++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 15 deletions(-)

diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 81ce65989..eb039185c 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -1,14 +1,13 @@
-import asyncio
 import copy
 import importlib
 import json
 import os
 import textwrap
 import threading
-import time
 import unittest.mock
 import warnings
-from 
typing import Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from uuid import uuid4 import pytest @@ -193,11 +192,25 @@ class User(BaseModel): return User(name="Jane Doe", age=30, email="jane@doe.com") -class SlowMockedModel(MockedModelProvider): +class SyncEventMockedModel(MockedModelProvider): + """A mock model that uses events to synchronize concurrent threads. + + This model signals when it starts streaming and waits for a proceed signal, + allowing deterministic testing of concurrent behavior without relying on sleeps. + """ + + def __init__(self, agent_responses): + super().__init__(agent_responses) + self.started_event = threading.Event() + self.proceed_event = threading.Event() + async def stream( self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs ) -> AsyncGenerator[Any, None]: - await asyncio.sleep(0.15) # Add async delay to ensure concurrency + # Signal that streaming has started + self.started_event.set() + # Wait for signal to proceed + self.proceed_event.wait() async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs): yield event @@ -2212,7 +2225,7 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): def test_agent_concurrent_call_raises_exception(): """Test that concurrent __call__() calls raise ConcurrencyException.""" - model = SlowMockedModel( + model = SyncEventMockedModel( [ {"role": "assistant", "content": [{"text": "hello"}]}, {"role": "assistant", "content": [{"text": "world"}]}, @@ -2233,12 +2246,20 @@ def invoke(): with lock: errors.append(e) - # Create two threads that will try to invoke concurrently + # Start first thread and wait for it to begin streaming t1 = threading.Thread(target=invoke) - t2 = threading.Thread(target=invoke) - t1.start() + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) t2.start() + + # Give second thread time to attempt invocation and fail + t2.join(timeout=1.0) + + # Now let first thread complete + model.proceed_event.set() t1.join() t2.join() @@ -2254,11 +2275,12 @@ def test_agent_concurrent_structured_output_raises_exception(): Note: This test validates that the sync invocation path is protected. The concurrent __call__() test already validates the core functionality. 
""" - model = SlowMockedModel( + # Events for synchronization + model = SyncEventMockedModel( [ {"role": "assistant", "content": [{"text": "response1"}]}, {"role": "assistant", "content": [{"text": "response2"}]}, - ] + ], ) agent = Agent(model=model) @@ -2275,13 +2297,20 @@ def invoke(): with lock: errors.append(e) - # Create two threads that will try to invoke concurrently + # Start first thread and wait for it to begin streaming t1 = threading.Thread(target=invoke) - t2 = threading.Thread(target=invoke) - t1.start() - time.sleep(0.05) # Small delay to ensure first thread acquires lock + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) t2.start() + + # Give second thread time to attempt invocation and fail + t2.join(timeout=1.0) + + # Now let first thread complete + model.proceed_event.set() t1.join() t2.join() From 25c46a1011cf296342d9a5855f3bf631147601b5 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:41:03 -0500 Subject: [PATCH 30/47] Fix PEP 563 incompatibility with @tool decorated tools (#1494) Fixes the incompatibility between strands-agents 1.16.0+ and Pydantic 2.12+ when tools use modules with from __future__ import annotations (PEP 563) which causes type annotations to be strings --------- Co-authored-by: strands-coder Co-authored-by: Mackenzie Zastrow --- src/strands/tools/decorator.py | 16 ++- tests/strands/tools/test_decorator_pep563.py | 142 +++++++++++++++++++ 2 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 tests/strands/tools/test_decorator_pep563.py diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index f64c17ee9..f72a8ccf1 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -98,7 +98,7 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - """ self.func = func self.signature = inspect.signature(func) - self.type_hints = get_type_hints(func) + self.type_hints = get_type_hints(func, include_extras=True) self._context_param = context_param self._validate_signature() @@ -198,9 +198,17 @@ def _create_input_model(self) -> type[BaseModel]: if self._is_special_parameter(name): continue - # Use param.annotation directly to get the raw type hint. Using get_type_hints() - # can cause inconsistent behavior across Python versions for complex Annotated types. - param_type = param.annotation + # Handle PEP 563 (from __future__ import annotations): + # - When PEP 563 is active, param.annotation is a string literal that needs resolution + # - When PEP 563 is not active, param.annotation is the actual type object (may include Annotated) + # We check if param.annotation is a string to determine if we need type hint resolution. + # This preserves Annotated metadata correctly in both cases and is consistent across Python versions. + if isinstance(param.annotation, str): + # PEP 563 active: resolve string annotation + param_type = self.type_hints.get(name, param.annotation) + else: + # PEP 563 not active: use the actual type object directly + param_type = param.annotation if param_type is inspect.Parameter.empty: param_type = Any default = ... 
if param.default is inspect.Parameter.empty else param.default diff --git a/tests/strands/tools/test_decorator_pep563.py b/tests/strands/tools/test_decorator_pep563.py new file mode 100644 index 000000000..07ec8f2ba --- /dev/null +++ b/tests/strands/tools/test_decorator_pep563.py @@ -0,0 +1,142 @@ +"""Tests for PEP 563 (from __future__ import annotations) compatibility. + +This module tests that the @tool decorator works correctly when modules use +`from __future__ import annotations` (PEP 563), which causes all annotations +to be stored as string literals rather than evaluated types. + +This is a regression test for issue #1208: +https://github.com/strands-agents/sdk-python/issues/1208 +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from typing_extensions import Literal, TypedDict + +from strands import tool + +# Define types at module level (simulating nova-act pattern) +CLICK_TYPE = Literal["left", "right", "middle", "double"] +EXTRA_TYPE = Literal["extra"] + + +class ClickOptions(TypedDict): + """Options for click operation.""" + + blur_field: bool | None + + +@tool +def simple_literal_tool(click_type: CLICK_TYPE) -> dict[str, Any]: + return {"status": "success", "content": [{"text": f"Clicked: {click_type}"}]} + + +@tool +def complex_literal_tool( + box: str, + extra: EXTRA_TYPE, + click_type: CLICK_TYPE | None = None, + click_options: ClickOptions | None = None, +) -> Any: + return "Done" + + +@tool +def union_literal_tool(mode: Literal["fast", "slow"] | None = None) -> str: + return f"Mode: {mode}" + + +def test_simple_literal_type_tool_spec(): + """Test that simple Literal type parameters work with __future__ annotations.""" + spec = simple_literal_tool.tool_spec + assert spec["name"] == "simple_literal_tool" + + schema = spec["inputSchema"]["json"] + assert "click_type" in schema["properties"] + # Verify Literal values are present in schema + click_type_schema = schema["properties"]["click_type"] + assert "enum" in click_type_schema or "anyOf" in click_type_schema + + +def test_complex_literal_type_tool_spec(): + """Test that complex type hints with Literal work with __future__ annotations.""" + spec = complex_literal_tool.tool_spec + assert spec["name"] == "complex_literal_tool" + + schema = spec["inputSchema"]["json"] + # Ensure schema is correct and contains the expected shape + assert schema == { + "$defs": { + "ClickOptions": { + "description": "Options for click operation.", + "properties": {"blur_field": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "title": "Blur Field"}}, + "required": ["blur_field"], + "title": "ClickOptions", + "type": "object", + } + }, + "properties": { + "box": {"description": "Parameter box", "type": "string"}, + "click_options": { + "$ref": "#/$defs/ClickOptions", + "default": None, + "description": "Parameter click_options", + }, + "click_type": { + "default": None, + "description": "Parameter click_type", + "enum": ["left", "right", "middle", "double"], + "type": "string", + }, + "extra": {"const": "extra", "description": "Parameter extra", "type": "string"}, + }, + "required": ["box", "extra"], + "type": "object", + } + + +def test_union_literal_tool_spec(): + """Test that inline Literal in Union works with __future__ annotations.""" + spec = union_literal_tool.tool_spec + assert spec["name"] == "union_literal_tool" + + schema = spec["inputSchema"]["json"] + assert "mode" in schema["properties"] + + +def test_simple_literal_tool_invocation(): + """Test that tools with Literal types can be invoked.""" + 
result = simple_literal_tool(click_type="left") + assert result["status"] == "success" + assert "left" in result["content"][0]["text"] + + +def test_complex_literal_tool_invocation(): + """Test that tools with complex types can be invoked.""" + result = complex_literal_tool( + box="box1", + extra="extra", + click_type="double", + click_options={"blur_field": True}, + ) + assert result == "Done" + + +def test_tool_spec_no_pydantic_error(): + """Verify no PydanticUserError is raised when accessing tool_spec. + + This is the specific error from issue #1208: + PydanticUserError: `Agent_clickTool` is not fully defined; + you should define `EXTRA_TYPE`, then call `Agent_clickTool.model_rebuild()`. + """ + # This should not raise PydanticUserError + try: + _ = simple_literal_tool.tool_spec + _ = complex_literal_tool.tool_spec + _ = union_literal_tool.tool_spec + except Exception as e: + if "not fully defined" in str(e): + pytest.fail(f"PydanticUserError raised - PEP 563 compatibility broken: {e}") + raise From 5e733ef00b5162ed97b662ad6a9ed4f9c72ced21 Mon Sep 17 00:00:00 2001 From: okamototk Date: Fri, 16 Jan 2026 06:40:39 -0800 Subject: [PATCH 31/47] feat: override service name by OTEL_SERVICE_NAME env (#1400) --- src/strands/telemetry/config.py | 6 +++++- tests/strands/telemetry/test_config.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/strands/telemetry/config.py b/src/strands/telemetry/config.py index 0509c7440..93225335d 100644 --- a/src/strands/telemetry/config.py +++ b/src/strands/telemetry/config.py @@ -5,6 +5,7 @@ """ import logging +import os from importlib.metadata import version from typing import Any @@ -29,9 +30,11 @@ def get_otel_resource() -> Resource: Returns: Resource object with standard service information. 
""" + service_name = os.getenv("OTEL_SERVICE_NAME", "strands-agents").strip() + resource = Resource.create( { - "service.name": "strands-agents", + "service.name": service_name, "service.version": version("strands-agents"), "telemetry.sdk.name": "opentelemetry", "telemetry.sdk.language": "python", @@ -56,6 +59,7 @@ class StrandsTelemetry: Environment variables are handled by the underlying OpenTelemetry SDK: - OTEL_EXPORTER_OTLP_ENDPOINT: OTLP endpoint URL - OTEL_EXPORTER_OTLP_HEADERS: Headers for OTLP requests + - OTEL_SERVICE_NAME: Overrides resource service name Examples: Quick setup with method chaining: diff --git a/tests/strands/telemetry/test_config.py b/tests/strands/telemetry/test_config.py index 658d4d08a..cc08c295c 100644 --- a/tests/strands/telemetry/test_config.py +++ b/tests/strands/telemetry/test_config.py @@ -2,6 +2,7 @@ import pytest +import strands.telemetry.config as telemetry_config from strands.telemetry import StrandsTelemetry @@ -212,3 +213,21 @@ def test_setup_otlp_exporter_exception(mock_resource, mock_tracer_provider, mock telemetry.setup_otlp_exporter() mock_otlp_exporter.assert_called_once() + + +def test_get_otel_resource_uses_default_service_name(monkeypatch): + monkeypatch.delenv("OTEL_SERVICE_NAME", raising=False) + monkeypatch.setattr(telemetry_config, "version", lambda _: "0.0.0") + + resource = telemetry_config.get_otel_resource() + + assert resource.attributes.get("service.name") == "strands-agents" + + +def test_get_otel_resource_respects_otel_service_name(monkeypatch): + monkeypatch.setenv("OTEL_SERVICE_NAME", "my-service") + monkeypatch.setattr(telemetry_config, "version", lambda _: "0.0.0") + + resource = telemetry_config.get_otel_resource() + + assert resource.attributes.get("service.name") == "my-service" From bce2464b4aaf6699eaa5fc1d0f78ac7cfcbc6e73 Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Fri, 16 Jan 2026 10:04:56 -0500 Subject: [PATCH 32/47] fix(bedrock): disable thinking mode when forcing tool_choice (#1495) --------- Co-authored-by: Dean Schmigelski --- src/strands/models/bedrock.py | 34 ++++++-- tests/strands/models/test_bedrock_thinking.py | 84 +++++++++++++++++++ tests_integ/models/test_model_bedrock.py | 37 ++++++++ 3 files changed, 150 insertions(+), 5 deletions(-) create mode 100644 tests/strands/models/test_bedrock_thinking.py diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index dfcd133c6..567a2e147 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -255,11 +255,7 @@ def _format_request( if tool_specs else {} ), - **( - {"additionalModelRequestFields": self.config["additional_request_fields"]} - if self.config.get("additional_request_fields") - else {} - ), + **(self._get_additional_request_fields(tool_choice)), **( {"additionalModelResponseFieldPaths": self.config["additional_response_field_paths"]} if self.config.get("additional_response_field_paths") @@ -298,6 +294,34 @@ def _format_request( ), } + def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict[str, Any]: + """Get additional request fields, removing thinking if tool_choice forces tool use. + + Bedrock's API does not allow thinking mode when tool_choice forces tool use. + When forcing a tool (e.g., for structured_output retry), we temporarily disable thinking. + + Args: + tool_choice: The tool choice configuration. + + Returns: + A dict containing additionalModelRequestFields if configured, or empty dict. 
+ """ + additional_fields = self.config.get("additional_request_fields") + if not additional_fields: + return {} + + # Check if tool_choice is forcing tool use ("any" or specific "tool") + is_forcing_tool = tool_choice is not None and ("any" in tool_choice or "tool" in tool_choice) + + if is_forcing_tool and "thinking" in additional_fields: + # Create a copy without the thinking key + fields_without_thinking = {k: v for k, v in additional_fields.items() if k != "thinking"} + if fields_without_thinking: + return {"additionalModelRequestFields": fields_without_thinking} + return {} + + return {"additionalModelRequestFields": additional_fields} + def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format messages for Bedrock API compatibility. diff --git a/tests/strands/models/test_bedrock_thinking.py b/tests/strands/models/test_bedrock_thinking.py new file mode 100644 index 000000000..10b53cb03 --- /dev/null +++ b/tests/strands/models/test_bedrock_thinking.py @@ -0,0 +1,84 @@ +"""Tests for thinking mode behavior in BedrockModel.""" + +import pytest + +from strands.models.bedrock import BedrockModel + + +@pytest.fixture +def model_with_thinking(): + """Create a BedrockModel with thinking enabled.""" + return BedrockModel( + model_id="anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={"thinking": {"type": "enabled", "budget_tokens": 5000}}, + ) + + +@pytest.fixture +def model_without_thinking(): + """Create a BedrockModel without thinking.""" + return BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0") + + +@pytest.fixture +def model_with_thinking_and_other_fields(): + """Create a BedrockModel with thinking and other additional fields.""" + return BedrockModel( + model_id="anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={ + "thinking": {"type": "enabled", "budget_tokens": 5000}, + "some_other_field": "value", + }, + ) + + +def test_thinking_removed_when_forcing_tool_any(model_with_thinking): + """Thinking should be removed when tool_choice forces tool use with 'any'.""" + tool_choice = {"any": {}} + result = model_with_thinking._get_additional_request_fields(tool_choice) + assert result == {} # thinking removed, no other fields + + +def test_thinking_removed_when_forcing_specific_tool(model_with_thinking): + """Thinking should be removed when tool_choice forces a specific tool.""" + tool_choice = {"tool": {"name": "structured_output_tool"}} + result = model_with_thinking._get_additional_request_fields(tool_choice) + assert result == {} # thinking removed, no other fields + + +def test_thinking_preserved_with_auto_tool_choice(model_with_thinking): + """Thinking should be preserved when tool_choice is 'auto'.""" + tool_choice = {"auto": {}} + result = model_with_thinking._get_additional_request_fields(tool_choice) + assert result == {"additionalModelRequestFields": {"thinking": {"type": "enabled", "budget_tokens": 5000}}} + + +def test_thinking_preserved_with_none_tool_choice(model_with_thinking): + """Thinking should be preserved when tool_choice is None.""" + result = model_with_thinking._get_additional_request_fields(None) + assert result == {"additionalModelRequestFields": {"thinking": {"type": "enabled", "budget_tokens": 5000}}} + + +def test_other_fields_preserved_when_thinking_removed(model_with_thinking_and_other_fields): + """Other additional fields should be preserved when thinking is removed.""" + tool_choice = {"any": {}} + result = 
model_with_thinking_and_other_fields._get_additional_request_fields(tool_choice)
+    assert result == {"additionalModelRequestFields": {"some_other_field": "value"}}
+
+
+def test_no_fields_when_model_has_no_additional_fields(model_without_thinking):
+    """Should return empty dict when model has no additional_request_fields."""
+    tool_choice = {"any": {}}
+    result = model_without_thinking._get_additional_request_fields(tool_choice)
+    assert result == {}
+
+
+def test_fields_preserved_when_no_thinking_and_forcing_tool():
+    """Additional fields without thinking should be preserved when forcing tool."""
+    model = BedrockModel(
+        model_id="anthropic.claude-sonnet-4-20250514-v1:0",
+        additional_request_fields={"some_field": "value"},
+    )
+    tool_choice = {"any": {}}
+    result = model._get_additional_request_fields(tool_choice)
+    assert result == {"additionalModelRequestFields": {"some_field": "value"}}
diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py
index b31f23663..0b3aa7b47 100644
--- a/tests_integ/models/test_model_bedrock.py
+++ b/tests_integ/models/test_model_bedrock.py
@@ -275,6 +275,43 @@ def test_redacted_content_handling():
     assert isinstance(result.message["content"][0]["reasoningContent"]["redactedContent"], bytes)
 
 
+def test_reasoning_content_in_messages_with_thinking_disabled():
+    """Test that messages with reasoningContent are accepted when thinking is explicitly disabled."""
+    # First, get a real reasoning response with thinking enabled
+    thinking_model = BedrockModel(
+        model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
+        additional_request_fields={
+            "thinking": {
+                "type": "enabled",
+                "budget_tokens": 1024,
+            }
+        },
+    )
+    agent_with_thinking = Agent(model=thinking_model)
+    result_with_thinking = agent_with_thinking("What is 2+2?")
+
+    # Verify we got reasoning content
+    assert "reasoningContent" in result_with_thinking.message["content"][0]
+
+    # Now create a model with thinking disabled and use the messages from the thinking session
+    disabled_model = BedrockModel(
+        model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
+        additional_request_fields={
+            "thinking": {
+                "type": "disabled",
+            }
+        },
+    )
+
+    # Use the conversation history that includes reasoning content
+    messages = agent_with_thinking.messages
+
+    agent_disabled = Agent(model=disabled_model, messages=messages)
+    result = agent_disabled("What about 3+3?")
+
+    assert result.stop_reason == "end_turn"
+
+
 def test_multi_prompt_system_content():
     """Test multi-prompt system content blocks."""
     system_prompt_content = [

From e4bd3bc9d77b9bf40b11c42f92775b17fe0c618e Mon Sep 17 00:00:00 2001
From: Bryce Cole
Date: Fri, 16 Jan 2026 13:48:51 -0500
Subject: [PATCH 33/47] fix: a2a use artifact update event (#1401)

fix: update tests

fix: simplify code by storing in class

fix: remove unneeded code change

fix: hide a2a artifact streaming under feature flag

fix: use walrus operator

fix: use star to signify end of unnamed

fix: add check for walrus legacy

fix: clarify enable_a2a_compliant_streaming parameter in StrandsA2AExecutor initialization

fix: update tests

refactor: streamline artifact addition logic in StrandsA2AExecutor
---
 src/strands/multiagent/a2a/executor.py        | 84 ++++++++++++++---
 src/strands/multiagent/a2a/server.py          |  8 +-
 tests/strands/multiagent/a2a/test_executor.py | 93 +++++++++++++++++++
 3 files changed, 172 insertions(+), 13 deletions(-)

diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py
index f02b8c6cc..58dfcc045 100644
--- 
a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -12,6 +12,8 @@ import json import logging import mimetypes +import uuid +import warnings from typing import Any, Literal from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -49,13 +51,21 @@ class StrandsA2AExecutor(AgentExecutor): # Handle special cases where format differs from extension FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"} - def __init__(self, agent: SAAgent): + # A2A-compliant streaming mode + _current_artifact_id: str | None + _is_first_chunk: bool + + def __init__(self, agent: SAAgent, *, enable_a2a_compliant_streaming: bool = False): """Initialize a StrandsA2AExecutor. Args: agent: The Strands Agent instance to adapt to the A2A protocol. + enable_a2a_compliant_streaming: If True, uses A2A-compliant streaming with + artifact updates. If False, uses legacy status updates streaming behavior + for backwards compatibility. Defaults to False. """ self.agent = agent + self.enable_a2a_compliant_streaming = enable_a2a_compliant_streaming async def execute( self, @@ -104,12 +114,30 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater else: raise ValueError("No content blocks available") + if not self.enable_a2a_compliant_streaming: + warnings.warn( + "The default A2A response stream implemented in the strands sdk does not conform to " + "what is expected in the A2A spec. Please set the `enable_a2a_compliant_streaming` " + "boolean to `True` on your `A2AServer` class to properly conform to the spec. " + "In the next major version release, this will be the default behavior.", + UserWarning, + stacklevel=3, + ) + + if self.enable_a2a_compliant_streaming: + self._current_artifact_id = str(uuid.uuid4()) + self._is_first_chunk = True + try: async for event in self.agent.stream_async(content_blocks): await self._handle_streaming_event(event, updater) except Exception: logger.exception("Error in streaming execution") raise + finally: + if self.enable_a2a_compliant_streaming: + self._current_artifact_id = None + self._is_first_chunk = True async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None: """Handle a single streaming event from the Strands Agent. @@ -125,28 +153,60 @@ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpda logger.debug("Streaming event: %s", event) if "data" in event: if text_content := event["data"]: - await updater.update_status( - TaskState.working, - new_agent_text_message( - text_content, - updater.context_id, - updater.task_id, - ), - ) + if self.enable_a2a_compliant_streaming: + await updater.add_artifact( + [Part(root=TextPart(text=text_content))], + artifact_id=self._current_artifact_id, + name="agent_response", + append=not self._is_first_chunk, + ) + self._is_first_chunk = False + else: + # Legacy use update_status with agent message + await updater.update_status( + TaskState.working, + new_agent_text_message( + text_content, + updater.context_id, + updater.task_id, + ), + ) elif "result" in event: await self._handle_agent_result(event["result"], updater) async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None: """Handle the final result from the Strands Agent. - Processes the agent's final result, extracts text content from the response, - and adds it as an artifact to the task before marking the task as complete. 
+ For A2A-compliant streaming: sends the final artifact chunk marker and marks + the task as complete. If no data chunks were previously sent, includes the + result content. + + For legacy streaming: adds the final result as a simple artifact without + artifact_id tracking. Args: result: The agent result object containing the final response, or None if no result. updater: The task updater for managing task state and adding the final artifact. """ - if final_content := str(result): + if self.enable_a2a_compliant_streaming: + if self._is_first_chunk: + final_content = str(result) if result else "" + parts = [Part(root=TextPart(text=final_content))] if final_content else [] + await updater.add_artifact( + parts, + artifact_id=self._current_artifact_id, + name="agent_response", + last_chunk=True, + ) + else: + await updater.add_artifact( + [], + artifact_id=self._current_artifact_id, + name="agent_response", + append=True, + last_chunk=True, + ) + elif final_content := str(result): await updater.add_artifact( [Part(root=TextPart(text=final_content))], name="agent_response", diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index a9093742f..7b4c4c73a 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -42,6 +42,7 @@ def __init__( queue_manager: QueueManager | None = None, push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, + enable_a2a_compliant_streaming: bool = False, ): """Initialize an A2A-compatible server from a Strands agent. @@ -66,6 +67,9 @@ def __init__( no push notification configuration is used. push_sender: Custom push notification sender implementation. If None, no push notifications are sent. + enable_a2a_compliant_streaming: If True, uses A2A-compliant streaming with + artifact updates. If False, uses legacy status updates streaming behavior + for backwards compatibility. Defaults to False. 
""" self.host = host self.port = port @@ -90,7 +94,9 @@ def __init__( self.description = self.strands_agent.description self.capabilities = AgentCapabilities(streaming=True) self.request_handler = DefaultRequestHandler( - agent_executor=StrandsA2AExecutor(self.strands_agent), + agent_executor=StrandsA2AExecutor( + self.strands_agent, enable_a2a_compliant_streaming=enable_a2a_compliant_streaming + ), task_store=task_store or InMemoryTaskStore(), queue_manager=queue_manager, push_config_store=push_config_store, diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 1463d3f48..73ade574e 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -1020,3 +1020,96 @@ def test_default_formats_modularization(): assert executor._get_file_format_from_mime_type("", "document") == "txt" assert executor._get_file_format_from_mime_type("", "image") == "png" assert executor._get_file_format_from_mime_type("", "video") == "mp4" + + +# Tests for enable_a2a_compliant_streaming parameter + + +@pytest.mark.asyncio +async def test_legacy_mode_emits_deprecation_warning(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that legacy streaming (default) emits deprecation warning.""" + from a2a.types import TextPart + + executor = StrandsA2AExecutor(mock_strands_agent) # Default is False + + # Mock stream_async + async def mock_stream(content_blocks): + yield {"result": None} + + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) + + # Mock task + mock_task = MagicMock() + mock_task.id = "test-task-id" + mock_task.context_id = "test-context-id" + mock_request_context.current_task = mock_task + + # Mock message + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + with pytest.warns(UserWarning, match="does not conform to what is expected in the A2A spec"): + await executor.execute(mock_request_context, mock_event_queue) + + +@pytest.mark.asyncio +async def test_a2a_compliant_mode_no_warning(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that A2A-compliant mode does not emit warning.""" + import warnings + + from a2a.types import TextPart + + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + + # Mock stream_async + async def mock_stream(content_blocks): + yield {"result": None} + + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) + + # Mock task + mock_task = MagicMock() + mock_task.id = "test-task-id" + mock_task.context_id = "test-context-id" + mock_request_context.current_task = mock_task + + # Mock message + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + with warnings.catch_warnings(): + warnings.simplefilter("error") + try: + await executor.execute(mock_request_context, mock_event_queue) + except UserWarning: + pytest.fail("Should not emit warning") + + +@pytest.mark.asyncio +async def test_a2a_compliant_mode_uses_add_artifact(mock_strands_agent): + """Test that A2A-compliant mode uses add_artifact with artifact_id.""" + executor = StrandsA2AExecutor(mock_strands_agent, 
enable_a2a_compliant_streaming=True) + executor._current_artifact_id = "artifact-123" + executor._is_first_chunk = True + + mock_updater = MagicMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.update_status = AsyncMock() + + event = {"data": "content"} + await executor._handle_streaming_event(event, mock_updater) + + mock_updater.add_artifact.assert_called_once() + assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-123" + assert mock_updater.add_artifact.call_args[1]["append"] is False + mock_updater.update_status.assert_not_called() From 51cbe7b6e9450f91cc8862120e69f6c1ac8bc96d Mon Sep 17 00:00:00 2001 From: Zezhen Xu <32421101+CrysisDeu@users.noreply.github.com> Date: Tue, 20 Jan 2026 07:03:37 -0800 Subject: [PATCH 34/47] Add parallel reading support to S3SessionManager.list_messages() (#1186) Co-authored-by: Jack Yuan Co-authored-by: Nicholas Clegg --- src/strands/session/s3_session_manager.py | 51 +++++++++++++++++-- tests/strands/models/test_bedrock.py | 9 ++-- .../session/test_s3_session_manager.py | 34 +++++++++++++ 3 files changed, 86 insertions(+), 8 deletions(-) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index e5713e5b7..8d557e81c 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,6 +2,7 @@ import json import logging +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import TYPE_CHECKING, Any, cast import boto3 @@ -259,7 +260,21 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio def list_messages( self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: - """List messages for an agent with pagination from S3.""" + """List messages for an agent with pagination from S3. + + Args: + session_id: ID of the session + agent_id: ID of the agent + limit: Optional limit on number of messages to return + offset: Optional offset for pagination + **kwargs: Additional keyword arguments + + Returns: + List of SessionMessage objects, sorted by message_id. + + Raises: + SessionException: If S3 error occurs during message retrieval. 
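+
+        Example (sketch; the session and agent IDs are placeholders):
+            messages = session_manager.list_messages("session-1", "agent-1", limit=10)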
+        """
         messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/"
         try:
             paginator = self.client.get_paginator("list_objects_v2")
@@ -287,10 +302,38 @@ def list_messages(
             else:
                 message_keys = message_keys[offset:]
 
-            # Load only the required message objects
+            # Load message objects in parallel for better performance
             messages: list[SessionMessage] = []
-            for key in message_keys:
-                message_data = self._read_s3_object(key)
+            if not message_keys:
+                return messages
+
+            # Optimize for single worker case - avoid thread pool overhead
+            if len(message_keys) == 1:
+                for key in message_keys:
+                    message_data = self._read_s3_object(key)
+                    if message_data:
+                        messages.append(SessionMessage.from_dict(message_data))
+                return messages
+
+            with ThreadPoolExecutor() as executor:
+                # Submit all read tasks
+                future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys}
+
+                # Create a mapping from key to index to maintain order
+                key_to_index = {key: idx for idx, key in enumerate(message_keys)}
+
+                # Initialize results list with None placeholders to maintain order
+                results: list[dict[str, Any] | None] = [None] * len(message_keys)
+
+                # Process results as they complete
+                for future in as_completed(future_to_key):
+                    key = future_to_key[future]
+                    message_data = future.result()
+                    # Store result at the correct index to maintain order
+                    results[key_to_index[key]] = message_data
+
+            # Convert results to SessionMessage objects, filtering out None values
+            for message_data in results:
                 if message_data:
                     messages.append(SessionMessage.from_dict(message_data))
 
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 7697c5e03..833b14729 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -201,10 +201,11 @@ def test__init__region_precedence(mock_client_method, session_cls):
 def test__init__with_endpoint_url(mock_client_method):
     """Test that BedrockModel uses the provided endpoint_url for VPC endpoints."""
     custom_endpoint = "https://vpce-12345-abcde.bedrock-runtime.us-west-2.vpce.amazonaws.com"
-    BedrockModel(endpoint_url=custom_endpoint)
-    mock_client_method.assert_called_with(
-        region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint
-    )
+    with unittest.mock.patch.object(os, "environ", {}):
+        BedrockModel(endpoint_url=custom_endpoint)
+        mock_client_method.assert_called_with(
+            region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint
+        )
 
 
 def test__init__with_region_and_session_raises_value_error():
diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py
index 719fbc2c9..c1c89da5b 100644
--- a/tests/strands/session/test_s3_session_manager.py
+++ b/tests/strands/session/test_s3_session_manager.py
@@ -282,6 +282,40 @@ def test_list_messages_all(s3_manager, sample_session, sample_agent):
     assert len(result) == 5
 
 
+def test_list_messages_single_message(s3_manager, sample_session, sample_agent):
+    """Test listing a single message from S3."""
+    # Create session and agent
+    s3_manager.create_session(sample_session)
+    s3_manager.create_agent(sample_session.session_id, sample_agent)
+
+    # Create single message
+    message = SessionMessage(
+        {
+            "role": "user",
+            "content": [ContentBlock(text="Single Message")],
+        },
+        0,
+    )
+    s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, message)
+
+    # List all messages
+    result = 
s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
+
+    assert len(result) == 1
+
+
+def test_list_no_messages(s3_manager, sample_session, sample_agent):
+    """Test listing messages when none exist in S3."""
+    # Create session and agent
+    s3_manager.create_session(sample_session)
+    s3_manager.create_agent(sample_session.session_id, sample_agent)
+
+    # List all messages
+    result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
+
+    assert len(result) == 0
+
+
 def test_list_messages_with_pagination(s3_manager, sample_session, sample_agent):
     """Test listing messages with pagination in S3."""
     # Create session and agent

From 8b7f6ccfd483c1120d22871b9fb4434d8783282c Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Tue, 20 Jan 2026 17:17:57 +0200
Subject: [PATCH 35/47] feat(steering): allow steering on AfterModelCallEvents (#1429)

---
 .gitignore                                    |   1 +
 src/strands/experimental/steering/__init__.py |   7 +-
 .../experimental/steering/core/__init__.py    |   4 +-
 .../experimental/steering/core/action.py      |  51 ++--
 .../experimental/steering/core/handler.py     | 141 ++++++++--
 .../steering/handlers/llm/llm_handler.py      |   8 +-
 .../steering/core/test_handler.py             | 245 ++++++++++++++++--
 .../steering/handlers/llm/test_llm_handler.py |  12 +-
 tests/strands/tools/test_decorator_pep563.py  |   4 +-
 tests_integ/steering/test_model_steering.py   | 204 +++++++++++++++
 ...t_llm_handler.py => test_tool_steering.py} |  10 +-
 11 files changed, 597 insertions(+), 90 deletions(-)
 create mode 100644 tests_integ/steering/test_model_steering.py
 rename tests_integ/steering/{test_llm_handler.py => test_tool_steering.py} (91%)

diff --git a/.gitignore b/.gitignore
index 8b0fd989c..0b1375b50 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,4 @@ repl_state
 .kiro
 uv.lock
 .audio_cache
+CLAUDE.md
diff --git a/src/strands/experimental/steering/__init__.py b/src/strands/experimental/steering/__init__.py
index 4d0775873..be04a9ddb 100644
--- a/src/strands/experimental/steering/__init__.py
+++ b/src/strands/experimental/steering/__init__.py
@@ -9,7 +9,7 @@
 - SteeringHandler: Base class for guidance logic with local context
 - SteeringContextCallback: Protocol for context update functions
 - SteeringContextProvider: Protocol for multi-event context providers
-- SteeringAction: Proceed/Guide/Interrupt decisions
+- ToolSteeringAction/ModelSteeringAction: Proceed/Guide/Interrupt decisions
 
 Usage:
     handler = LLMSteeringHandler(system_prompt="...")
@@ -23,7 +23,7 @@
     LedgerBeforeToolCall,
     LedgerProvider,
 )
-from .core.action import Guide, Interrupt, Proceed, SteeringAction
+from .core.action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction
 from .core.context import SteeringContextCallback, SteeringContextProvider
 from .core.handler import SteeringHandler
 
@@ -31,7 +31,8 @@
 from .handlers.llm import LLMPromptMapper, LLMSteeringHandler
 
 __all__ = [
-    "SteeringAction",
+    "ToolSteeringAction",
+    "ModelSteeringAction",
     "Proceed",
     "Guide",
     "Interrupt",
diff --git a/src/strands/experimental/steering/core/__init__.py b/src/strands/experimental/steering/core/__init__.py
index a3efe0dbc..cdd0d8269 100644
--- a/src/strands/experimental/steering/core/__init__.py
+++ b/src/strands/experimental/steering/core/__init__.py
@@ -1,6 +1,6 @@
 """Core steering system interfaces and base classes."""
 
-from .action import Guide, Interrupt, Proceed, SteeringAction
+from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction
 from .handler import SteeringHandler
 
-__all__ = ["SteeringAction", 
"Proceed", "Guide", "Interrupt", "SteeringHandler"] +__all__ = ["ToolSteeringAction", "ModelSteeringAction", "Proceed", "Guide", "Interrupt", "SteeringHandler"] diff --git a/src/strands/experimental/steering/core/action.py b/src/strands/experimental/steering/core/action.py index 8b4ec141d..b1f124b40 100644 --- a/src/strands/experimental/steering/core/action.py +++ b/src/strands/experimental/steering/core/action.py @@ -1,18 +1,18 @@ """SteeringAction types for steering evaluation results. -Defines structured outcomes from steering handlers that determine how tool calls +Defines structured outcomes from steering handlers that determine how agent actions should be handled. SteeringActions enable modular prompting by providing just-in-time feedback rather than front-loading all instructions in monolithic prompts. Flow: - SteeringHandler.steer() → SteeringAction → BeforeToolCallEvent handling - ↓ ↓ ↓ - Evaluate context Action type Tool execution modified + SteeringHandler.steer_*() → SteeringAction → Event handling + ↓ ↓ ↓ + Evaluate context Action type Execution modified SteeringAction types: - Proceed: Tool executes immediately (no intervention needed) - Guide: Tool cancelled, agent receives contextual feedback to explore alternatives - Interrupt: Tool execution paused for human input via interrupt system + Proceed: Allow execution to continue without intervention + Guide: Provide contextual guidance to redirect the agent + Interrupt: Pause execution for human input Extensibility: New action types can be added to the union. Always handle the default @@ -25,9 +25,9 @@ class Proceed(BaseModel): - """Allow tool to execute immediately without intervention. + """Allow execution to continue without intervention. - The tool call proceeds as planned. The reason provides context + The action proceeds as planned. The reason provides context for logging and debugging purposes. """ @@ -36,11 +36,11 @@ class Proceed(BaseModel): class Guide(BaseModel): - """Cancel tool and provide contextual feedback for agent to explore alternatives. + """Provide contextual guidance to redirect the agent. - The tool call is cancelled and the agent receives the reason as contextual - feedback to help them consider alternative approaches while maintaining - adaptive reasoning capabilities. + The agent receives the reason as contextual feedback to help guide + its behavior. The specific handling depends on the steering context + (e.g., tool call vs. model response). """ type: Literal["guide"] = "guide" @@ -48,18 +48,29 @@ class Guide(BaseModel): class Interrupt(BaseModel): - """Pause tool execution for human input via interrupt system. + """Pause execution for human input via interrupt system. - The tool call is paused and human input is requested through Strands' + Execution is paused and human input is requested through Strands' interrupt system. The human can approve or deny the operation, and their - decision determines whether the tool executes or is cancelled. + decision determines whether execution continues or is cancelled. 
""" type: Literal["interrupt"] = "interrupt" reason: str -# SteeringAction union - extensible for future action types -# IMPORTANT: Always handle the default case when pattern matching -# to maintain backward compatibility as new action types are added -SteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] +# Context-specific steering action types +ToolSteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] +"""Steering actions valid for tool steering (steer_before_tool). + +- Proceed: Allow tool execution to continue +- Guide: Cancel tool and provide feedback for alternative approaches +- Interrupt: Pause for human input before tool execution +""" + +ModelSteeringAction = Annotated[Proceed | Guide, Field(discriminator="type")] +"""Steering actions valid for model steering (steer_after_model). + +- Proceed: Accept model response without modification +- Guide: Discard model response and retry with guidance +""" diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 4a0bcaa6a..fd00a27fc 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -2,38 +2,48 @@ Provides modular prompting through contextual guidance that appears when relevant, rather than front-loading all instructions. Handlers integrate with the Strands hook -system to intercept tool calls and provide just-in-time feedback based on local context. +system to intercept actions and provide just-in-time feedback based on local context. Architecture: - BeforeToolCallEvent → Context Callbacks → Update steering_context → steer() → SteeringAction - ↓ ↓ ↓ ↓ ↓ - Hook triggered Populate context Handler evaluates Handler decides Action taken + Hook Event → Context Callbacks → Update steering_context → steer_*() → SteeringAction + ↓ ↓ ↓ ↓ ↓ + Hook triggered Populate context Handler evaluates Handler decides Action taken Lifecycle: 1. Context callbacks update handler's steering_context on hook events - 2. BeforeToolCallEvent triggers steering evaluation via steer() method - 3. Handler accesses self.steering_context for guidance decisions - 4. SteeringAction determines tool execution: Proceed/Guide/Interrupt + 2. BeforeToolCallEvent triggers steer_before_tool() for tool steering + 3. AfterModelCallEvent triggers steer_after_model() for model steering + 4. Handler accesses self.steering_context for guidance decisions + 5. SteeringAction determines execution flow Implementation: - Subclass SteeringHandler and implement steer() method. - Pass context_callbacks in constructor to register context update functions. + Subclass SteeringHandler and override steer_before_tool() and/or steer_after_model(). + Both methods have default implementations that return Proceed, so you only need to + override the methods you want to customize. + Pass context_providers in constructor to register context update functions. Each handler maintains isolated steering_context that persists across calls. 
-SteeringAction handling:
+SteeringAction handling for steer_before_tool:
     Proceed: Tool executes immediately
     Guide: Tool cancelled, agent receives contextual feedback to explore alternatives
     Interrupt: Tool execution paused for human input via interrupt system
+
+SteeringAction handling for steer_after_model:
+    Proceed: Model response accepted without modification
+    Guide: Discard model response and retry (message is dropped, model is called again)
+    (Interrupt is not supported for model steering; the model has already responded)
 """
 
 import logging
-from abc import ABC, abstractmethod
+from abc import ABC
 from typing import TYPE_CHECKING, Any
 
-from ....hooks.events import BeforeToolCallEvent
+from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent
 from ....hooks.registry import HookProvider, HookRegistry
+from ....types.content import Message
+from ....types.streaming import StopReason
 from ....types.tools import ToolUse
-from .action import Guide, Interrupt, Proceed, SteeringAction
+from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction
 from .context import SteeringContext, SteeringContextProvider
 
 if TYPE_CHECKING:
@@ -73,24 +83,29 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
                 callback.event_type, lambda event, callback=callback: callback(event, self.steering_context)
             )
 
-        # Register steering guidance
-        registry.add_callback(BeforeToolCallEvent, self._provide_steering_guidance)
+        # Register tool steering guidance
+        registry.add_callback(BeforeToolCallEvent, self._provide_tool_steering_guidance)
+
+        # Register model steering guidance
+        registry.add_callback(AfterModelCallEvent, self._provide_model_steering_guidance)
 
-    async def _provide_steering_guidance(self, event: BeforeToolCallEvent) -> None:
+    async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None:
         """Provide steering guidance for tool call."""
         tool_name = event.tool_use["name"]
 
-        logger.debug("tool_name=<%s> | providing steering guidance", tool_name)
+        logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name)
 
         try:
-            action = await self.steer(event.agent, event.tool_use)
+            action = await self.steer_before_tool(agent=event.agent, tool_use=event.tool_use)
         except Exception as e:
-            logger.debug("tool_name=<%s>, error=<%s> | steering handler guidance failed", tool_name, e)
+            logger.debug("tool_name=<%s>, error=<%s> | tool steering handler guidance failed", tool_name, e)
             return
 
-        self._handle_steering_action(action, event, tool_name)
+        self._handle_tool_steering_action(action, event, tool_name)
 
-    def _handle_steering_action(self, action: SteeringAction, event: BeforeToolCallEvent, tool_name: str) -> None:
-        """Handle the steering action by modifying tool execution flow.
+    def _handle_tool_steering_action(
+        self, action: ToolSteeringAction, event: BeforeToolCallEvent, tool_name: str
+    ) -> None:
+        """Handle the steering action for tool calls by modifying tool execution flow.
 
Proceed: Tool executes normally Guide: Tool cancelled with contextual feedback for agent to consider alternatives @@ -114,11 +129,52 @@ def _handle_steering_action(self, action: SteeringAction, event: BeforeToolCallE else: logger.debug("tool_name=<%s> | tool call approved manually", tool_name) else: - raise ValueError(f"Unknown steering action type: {action}") + raise ValueError(f"Unknown steering action type for tool call: {action}") + + async def _provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: + """Provide steering guidance for model response.""" + logger.debug("providing model steering guidance") + + # Only steer on successful model responses + if event.stop_response is None: + logger.debug("no stop response available | skipping model steering") + return + + try: + action = await self.steer_after_model( + agent=event.agent, message=event.stop_response.message, stop_reason=event.stop_response.stop_reason + ) + except Exception as e: + logger.debug("error=<%s> | model steering handler guidance failed", e) + return + + await self._handle_model_steering_action(action, event) + + async def _handle_model_steering_action(self, action: ModelSteeringAction, event: AfterModelCallEvent) -> None: + """Handle the steering action for model responses by modifying response handling flow. - @abstractmethod - async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: - """Provide contextual guidance to help agent navigate complex workflows. + Proceed: Model response accepted without modification + Guide: Discard model response and retry with guidance message added to conversation + """ + if isinstance(action, Proceed): + logger.debug("model response proceeding") + elif isinstance(action, Guide): + logger.debug("model response guided (retrying): %s", action.reason) + # Set retry flag to discard current response + event.retry = True + # Add guidance message to agent's conversation so model sees it on retry + await event.agent._append_messages({"role": "user", "content": [{"text": action.reason}]}) + logger.debug("added guidance message to conversation for model retry") + else: + raise ValueError(f"Unknown steering action type for model response: {action}") + + async def steer_before_tool(self, *, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: + """Provide contextual guidance before tool execution. + + This method is called before a tool is executed, allowing the handler to: + - Proceed: Allow tool execution to continue + - Guide: Cancel tool and provide feedback for alternative approaches + - Interrupt: Pause for human input before tool execution Args: agent: The agent instance @@ -126,9 +182,38 @@ async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> Steer **kwargs: Additional keyword arguments for guidance evaluation Returns: - SteeringAction indicating how to guide the agent's next action + ToolSteeringAction indicating how to guide the tool execution + + Note: + Access steering context via self.steering_context + Default implementation returns Proceed (allow tool execution) + Override this method to implement custom tool steering logic + """ + return Proceed(reason="Default implementation: allowing tool execution") + + async def steer_after_model( + self, *, agent: "Agent", message: Message, stop_reason: StopReason, **kwargs: Any + ) -> ModelSteeringAction: + """Provide contextual guidance after model response. 
+ + This method is called after the model generates a response, allowing the handler to: + - Proceed: Accept the model response without modification + - Guide: Discard the response and retry (message is dropped, model is called again) + + Note: Interrupt is not supported for model steering as the model has already responded. + + Args: + agent: The agent instance + message: The model's generated message + stop_reason: The reason the model stopped generating + **kwargs: Additional keyword arguments for guidance evaluation + + Returns: + ModelSteeringAction indicating how to handle the model response Note: Access steering context via self.steering_context + Default implementation returns Proceed (accept response as-is) + Override this method to implement custom model steering logic """ - ... + return Proceed(reason="Default implementation: accepting model response") diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index 4d90f46c9..379dc684a 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -10,7 +10,7 @@ from .....models import Model from .....types.tools import ToolUse from ...context_providers.ledger_provider import LedgerProvider -from ...core.action import Guide, Interrupt, Proceed, SteeringAction +from ...core.action import Guide, Interrupt, Proceed, ToolSteeringAction from ...core.context import SteeringContextProvider from ...core.handler import SteeringHandler from .mappers import DefaultPromptMapper, LLMPromptMapper @@ -58,7 +58,7 @@ def __init__( self.prompt_mapper = prompt_mapper or DefaultPromptMapper() self.model = model - async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> SteeringAction: + async def steer_before_tool(self, *, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: """Provide contextual guidance for tool usage. 
Args: @@ -67,7 +67,7 @@ async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> Steerin **kwargs: Additional keyword arguments for steering evaluation Returns: - SteeringAction indicating how to guide the agent's next action + SteeringAction indicating how to guide the tool execution """ # Generate steering prompt prompt = self.prompt_mapper.create_steering_prompt(self.steering_context, tool_use=tool_use) @@ -91,5 +91,5 @@ async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> Steerin case "interrupt": return Interrupt(reason=llm_result.reason) case _: - logger.warning("decision=<%s> | uŹknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] + logger.warning("decision=<%s> | unknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] return Proceed(reason="Unknown LLM decision, defaulting to proceed") diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 8d5ef6884..a16208e5b 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -1,20 +1,20 @@ """Unit tests for steering handler base class.""" -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest from strands.experimental.steering.core.action import Guide, Interrupt, Proceed from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider from strands.experimental.steering.core.handler import SteeringHandler -from strands.hooks.events import BeforeToolCallEvent +from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent from strands.hooks.registry import HookRegistry class TestSteeringHandler(SteeringHandler): """Test implementation of SteeringHandler.""" - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Test proceed") @@ -31,9 +31,9 @@ def test_register_hooks(): handler.register_hooks(registry) - # Verify hooks were registered - assert registry.add_callback.call_count >= 1 - registry.add_callback.assert_any_call(BeforeToolCallEvent, handler._provide_steering_guidance) + # Verify hooks were registered (tool and model steering hooks) + assert registry.add_callback.call_count >= 2 + registry.add_callback.assert_any_call(BeforeToolCallEvent, handler._provide_tool_steering_guidance) def test_steering_context_initialization(): @@ -65,7 +65,7 @@ async def test_proceed_action_flow(): """Test complete flow with Proceed action.""" class ProceedHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Test proceed") handler = ProceedHandler() @@ -73,7 +73,7 @@ async def steer(self, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) # Should not modify event for Proceed assert not event.cancel_tool @@ -84,7 +84,7 @@ async def test_guide_action_flow(): """Test complete flow with Guide action.""" class GuideHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Guide(reason="Test 
guidance") handler = GuideHandler() @@ -92,7 +92,7 @@ async def steer(self, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) # Should set cancel_tool with guidance message expected_message = "Tool call cancelled given new guidance. Test guidance. Consider this approach and continue" @@ -104,7 +104,7 @@ async def test_interrupt_action_approved_flow(): """Test complete flow with Interrupt action when approved.""" class InterruptHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Interrupt(reason="Need approval") handler = InterruptHandler() @@ -113,7 +113,7 @@ async def steer(self, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=True) # Approved - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) event.interrupt.assert_called_once() @@ -123,7 +123,7 @@ async def test_interrupt_action_denied_flow(): """Test complete flow with Interrupt action when denied.""" class InterruptHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Interrupt(reason="Need approval") handler = InterruptHandler() @@ -132,7 +132,7 @@ async def steer(self, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=False) # Denied - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) event.interrupt.assert_called_once() assert event.cancel_tool.startswith("Manual approval denied:") @@ -143,7 +143,7 @@ async def test_unknown_action_flow(): """Test complete flow with unknown action type raises error.""" class UnknownActionHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Mock() # Not a valid SteeringAction handler = UnknownActionHandler() @@ -152,14 +152,14 @@ async def steer(self, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) with pytest.raises(ValueError, match="Unknown steering action type"): - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) def test_register_steering_hooks_override(): """Test that _register_steering_hooks can be overridden.""" class CustomHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Custom") def register_hooks(self, registry, **kwargs): @@ -200,7 +200,7 @@ def __init__(self, context_callbacks=None): providers = [MockContextProvider(context_callbacks)] if context_callbacks else None super().__init__(context_providers=providers) - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Test proceed") @@ -260,8 +260,8 @@ def test_multiple_context_callbacks_registered(): handler.register_hooks(registry) - # Should register one callback for each context provider plus steering guidance - expected_calls = 2 + 1 # 2 callbacks + 1 for steering guidance + # Should register one callback for 
each context provider plus tool and model steering guidance + expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) assert registry.add_callback.call_count >= expected_calls @@ -276,3 +276,208 @@ def test_handler_initialization_with_callbacks(): assert len(handler._context_callbacks) == 2 assert callback1 in handler._context_callbacks assert callback2 in handler._context_callbacks + + +# Model steering tests +@pytest.mark.asyncio +async def test_model_steering_proceed_action_flow(): + """Test model steering with Proceed action.""" + + class ModelProceedHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Proceed(reason="Model response accepted") + + handler = ModelProceedHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + await handler._provide_model_steering_guidance(event) + + # Should not set retry for Proceed + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_steering_guide_action_flow(): + """Test model steering with Guide action sets retry and adds message.""" + + class ModelGuideHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Guide(reason="Please improve your response") + + handler = ModelGuideHandler() + agent = AsyncMock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + await handler._provide_model_steering_guidance(event) + + # Should set retry flag + assert event.retry is True + # Should add guidance message to conversation + agent._append_messages.assert_called_once() + call_args = agent._append_messages.call_args[0][0] + assert call_args["role"] == "user" + assert "Please improve your response" in call_args["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_model_steering_skips_when_no_stop_response(): + """Test model steering skips when stop_response is None.""" + + class ModelProceedHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.steer_called = False + + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + self.steer_called = True + return Proceed(reason="Should not be called") + + handler = ModelProceedHandler() + event = Mock(spec=AfterModelCallEvent) + event.stop_response = None + + await handler._provide_model_steering_guidance(event) + + # steer_after_model should not have been called + assert handler.steer_called is False + + +@pytest.mark.asyncio +async def test_model_steering_unknown_action_raises_error(): + """Test model steering with unknown action type raises error.""" + + class UnknownModelActionHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Mock() # Not a valid ModelSteeringAction + + handler = UnknownModelActionHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + + with 
pytest.raises(ValueError, match="Unknown steering action type for model response"): + await handler._provide_model_steering_guidance(event) + + +@pytest.mark.asyncio +async def test_model_steering_interrupt_raises_error(): + """Test model steering with Interrupt action raises error (not supported for model steering).""" + + class InterruptModelHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Interrupt(reason="Should not be allowed") + + handler = InterruptModelHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + + with pytest.raises(ValueError, match="Unknown steering action type for model response"): + await handler._provide_model_steering_guidance(event) + + +@pytest.mark.asyncio +async def test_model_steering_exception_handling(): + """Test model steering handles exceptions gracefully.""" + + class ExceptionModelHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + raise RuntimeError("Test exception") + + handler = ExceptionModelHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + # Should not raise, just return early + await handler._provide_model_steering_guidance(event) + + # retry should not be set since exception occurred + assert event.retry is False + + +@pytest.mark.asyncio +async def test_tool_steering_exception_handling(): + """Test tool steering handles exceptions gracefully.""" + + class ExceptionToolHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + raise RuntimeError("Test exception") + + handler = ExceptionToolHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + # Should not raise, just return early + await handler._provide_tool_steering_guidance(event) + + # cancel_tool should not be set since exception occurred + assert not event.cancel_tool + + +# Default implementation tests +@pytest.mark.asyncio +async def test_default_steer_before_tool_returns_proceed(): + """Test default steer_before_tool returns Proceed.""" + handler = TestSteeringHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + + # Call the parent's default implementation + result = await SteeringHandler.steer_before_tool(handler, agent=agent, tool_use=tool_use) + + assert isinstance(result, Proceed) + assert "Default implementation" in result.reason + + +@pytest.mark.asyncio +async def test_default_steer_after_model_returns_proceed(): + """Test default steer_after_model returns Proceed.""" + handler = TestSteeringHandler() + agent = Mock() + message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_reason = "end_turn" + + # Call the parent's default implementation + result = await SteeringHandler.steer_after_model(handler, agent=agent, message=message, stop_reason=stop_reason) + + assert isinstance(result, Proceed) + assert "Default implementation" in result.reason + + +def test_register_hooks_registers_model_steering(): + """Test that register_hooks 
registers model steering callback.""" + handler = TestSteeringHandler() + registry = Mock(spec=HookRegistry) + + handler.register_hooks(registry) + + # Verify model steering hook was registered + registry.add_callback.assert_any_call(AfterModelCallEvent, handler._provide_model_steering_guidance) diff --git a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py b/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py index f780088b5..f10254e50 100644 --- a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py +++ b/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py @@ -59,7 +59,7 @@ async def test_steer_proceed_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Proceed) assert result.reason == "Tool call is safe" @@ -82,7 +82,7 @@ async def test_steer_guide_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Guide) assert result.reason == "Consider security implications" @@ -105,7 +105,7 @@ async def test_steer_interrupt_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Interrupt) assert result.reason == "Human approval required" @@ -133,7 +133,7 @@ async def test_steer_unknown_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Proceed) assert "Unknown LLM decision, defaulting to proceed" in result.reason @@ -158,7 +158,7 @@ async def test_steer_uses_custom_model(mock_agent_class): agent.model = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - await handler.steer(agent, tool_use) + await handler.steer_before_tool(agent=agent, tool_use=tool_use) mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=custom_model, callback_handler=None) @@ -181,7 +181,7 @@ async def test_steer_uses_agent_model_when_no_custom_model(mock_agent_class): agent.model = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - await handler.steer(agent, tool_use) + await handler.steer_before_tool(agent=agent, tool_use=tool_use) mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=agent.model, callback_handler=None) diff --git a/tests/strands/tools/test_decorator_pep563.py b/tests/strands/tools/test_decorator_pep563.py index 07ec8f2ba..44d9a626a 100644 --- a/tests/strands/tools/test_decorator_pep563.py +++ b/tests/strands/tools/test_decorator_pep563.py @@ -10,10 +10,10 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal import pytest -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict from strands import tool diff --git a/tests_integ/steering/test_model_steering.py b/tests_integ/steering/test_model_steering.py new file mode 100644 index 000000000..e867ea033 --- /dev/null +++ 
b/tests_integ/steering/test_model_steering.py @@ -0,0 +1,204 @@ +"""Integration tests for model steering (steer_after_model).""" + +from strands import Agent, tool +from strands.experimental.steering.core.action import Guide, ModelSteeringAction, Proceed +from strands.experimental.steering.core.handler import SteeringHandler +from strands.types.content import Message +from strands.types.streaming import StopReason + + +class SimpleModelSteeringHandler(SteeringHandler): + """Simple handler that steers only on model responses.""" + + def __init__(self, should_guide: bool = False, guidance_message: str = ""): + """Initialize handler. + + Args: + should_guide: If True, guide (retry) on first model response + guidance_message: The guidance message to provide on retry + """ + super().__init__() + self.should_guide = should_guide + self.guidance_message = guidance_message + self.call_count = 0 + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + """Steer after model response.""" + self.call_count += 1 + + # On first call, guide to retry if configured + if self.should_guide and self.call_count == 1: + return Guide(reason=self.guidance_message) + + return Proceed(reason="Model response accepted") + + +def test_model_steering_proceeds_without_intervention(): + """Test that model steering can accept responses without modification.""" + handler = SimpleModelSteeringHandler(should_guide=False) + agent = Agent(hooks=[handler]) + + response = agent("What is 2+2?") + + # Handler should have been called once + assert handler.call_count >= 1 + # Response should be generated successfully + response_text = str(response) + assert response_text is not None + assert len(response_text) > 0 + + +def test_model_steering_guide_triggers_retry(): + """Test that Guide action triggers model retry.""" + handler = SimpleModelSteeringHandler(should_guide=True, guidance_message="Please provide a more detailed response.") + agent = Agent(hooks=[handler]) + + response = agent("What is the capital of France?") + + # Handler should have been called at least twice (first response + retry) + assert handler.call_count >= 2, "Handler should be called on initial response and retry" + + # Response should be generated successfully after retry + response_text = str(response) + assert response_text is not None + assert len(response_text) > 0 + + +def test_model_steering_guide_influences_retry_response(): + """Test that guidance message influences the retry response.""" + + class SpecificGuidanceHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.retry_done = False + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + if not self.retry_done: + self.retry_done = True + # Provide very specific guidance that should appear in retry + return Guide(reason="Please mention that Paris is also known as the 'City of Light'.") + return Proceed(reason="Response is good now") + + handler = SpecificGuidanceHandler() + agent = Agent(hooks=[handler]) + + response = agent("What is the capital of France?") + + # Verify retry happened + assert handler.retry_done, "Retry should have occurred" + + # Check that the response likely incorporated the guidance + output = str(response).lower() + assert "paris" in output, "Response should mention Paris" + + # The guidance should have influenced the retry (check for "light" or that retry happened) + # We can't guarantee the 
model will include it, but we verify the mechanism worked + assert handler.retry_done, "Guidance mechanism should have executed" + + +def test_model_steering_multiple_retries(): + """Test that model steering can guide multiple times before proceeding.""" + + class MultiRetryHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.call_count = 0 + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + self.call_count += 1 + + # Retry twice + if self.call_count == 1: + return Guide(reason="Please provide more context.") + if self.call_count == 2: + return Guide(reason="Please add specific examples.") + return Proceed(reason="Response is good now") + + handler = MultiRetryHandler() + agent = Agent(hooks=[handler]) + + response = agent("Explain machine learning.") + + # Should have been called 3 times (2 guides + 1 proceed) + assert handler.call_count >= 3, "Handler should be called multiple times for multiple retries" + + # Response should still complete successfully + assert str(response) is not None + assert len(str(response)) > 0 + + +@tool +def log_activity(activity: str) -> str: + """Log an activity for audit purposes.""" + return f"Activity logged: {activity}" + + +def test_model_steering_forces_tool_usage_on_unrelated_prompt(): + """Test that steering forces tool usage even when prompt doesn't need the tool. + + This test verifies the flow: + 1. Agent has a logging tool available + 2. User asks an unrelated question (math problem) + 3. Model tries to answer directly without using the tool + 4. Steering intercepts and forces tool usage before termination + 5. Model uses the tool and then completes + """ + + class ForceToolUsageHandler(SteeringHandler): + """Handler that forces a specific tool to be used before allowing termination.""" + + def __init__(self, required_tool: str): + super().__init__() + self.required_tool = required_tool + self.tool_was_used = False + self.guidance_given = False + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + # Only check when model is trying to end the turn + if stop_reason != "end_turn": + return Proceed(reason="Model still processing") + + # Check if the required tool was used in this message + content_blocks = message.get("content", []) + for block in content_blocks: + if "toolUse" in block and block["toolUse"].get("name") == self.required_tool: + self.tool_was_used = True + return Proceed(reason="Required tool was used") + + # If tool wasn't used and we haven't guided yet, force its usage + if not self.tool_was_used and not self.guidance_given: + self.guidance_given = True + return Guide( + reason=f"Before completing your response, you MUST use the {self.required_tool} tool " + "to log this interaction. Call the tool with a brief description of what you did." 
+ ) + + # Allow completion after guidance was given (model may have used tool in retry) + return Proceed(reason="Guidance was provided") + + handler = ForceToolUsageHandler(required_tool="log_activity") + agent = Agent(tools=[log_activity], hooks=[handler]) + + # Ask a question that clearly doesn't need the logging tool + response = agent("What is 2 + 2?") + + # Verify the steering mechanism worked + assert handler.guidance_given, "Handler should have provided guidance to use the tool" + + # Verify tool was actually called by checking metrics + tool_metrics = response.metrics.tool_metrics + assert "log_activity" in tool_metrics, "log_activity tool should have been called" + assert tool_metrics["log_activity"].call_count >= 1, "log_activity should have been called at least once" + assert tool_metrics["log_activity"].success_count >= 1, "log_activity should have succeeded" + + # Verify the response still answers the original question + output = str(response).lower() + assert "4" in output, "Response should contain the answer to 2+2" diff --git a/tests_integ/steering/test_llm_handler.py b/tests_integ/steering/test_tool_steering.py similarity index 91% rename from tests_integ/steering/test_llm_handler.py rename to tests_integ/steering/test_tool_steering.py index 8a8cebea2..eced94ba0 100644 --- a/tests_integ/steering/test_llm_handler.py +++ b/tests_integ/steering/test_tool_steering.py @@ -1,4 +1,4 @@ -"""Integration tests for LLM steering handler.""" +"""Integration tests for tool steering (steer_before_tool).""" import pytest @@ -30,7 +30,7 @@ async def test_llm_steering_handler_proceed(): agent = Agent(tools=[send_notification]) tool_use = {"name": "send_notification", "input": {"recipient": "user", "message": "hello"}} - effect = await handler.steer(agent, tool_use) + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(effect, Proceed) @@ -48,7 +48,7 @@ async def test_llm_steering_handler_guide(): agent = Agent(tools=[send_email, send_notification]) tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} - effect = await handler.steer(agent, tool_use) + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(effect, Guide) @@ -64,12 +64,12 @@ async def test_llm_steering_handler_interrupt(): agent = Agent(tools=[send_email]) tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} - effect = await handler.steer(agent, tool_use) + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(effect, Interrupt) -def test_agent_with_steering_e2e(): +def test_agent_with_tool_steering_e2e(): """End-to-end test of agent with steering handler guiding tool choice.""" handler = LLMSteeringHandler( system_prompt=( From 63e58aa83dbb63fab06c405cbeb2acaa500c9e32 Mon Sep 17 00:00:00 2001 From: Qian Zhang Date: Tue, 20 Jan 2026 16:18:52 +0100 Subject: [PATCH 36/47] fix: provide unique toolUseId for gemini models (#1201) Co-authored-by: spicadust Co-authored-by: Patrick Gray --- src/strands/models/gemini.py | 28 ++++++++---- tests/strands/models/test_gemini.py | 67 ++++++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 9 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 52d45b649..5417f20b3 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -6,6 +6,7 @@ import json import logging import mimetypes +import secrets from collections.abc import AsyncGenerator from typing 
import Any, TypedDict, TypeVar, cast @@ -86,6 +87,7 @@ def __init__( self._custom_client = client self.client_args = client_args or {} + self._tool_use_id_to_name: dict[str, str] = {} # Validate gemini_tools if provided if "gemini_tools" in self.config: @@ -173,10 +175,13 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par return genai.types.Part(text=content["text"]) if "toolResult" in content: + tool_use_id = content["toolResult"]["toolUseId"] + function_name = self._tool_use_id_to_name.get(tool_use_id, tool_use_id) + return genai.types.Part( function_response=genai.types.FunctionResponse( - id=content["toolResult"]["toolUseId"], - name=content["toolResult"]["toolUseId"], + id=tool_use_id, + name=function_name, response={ "output": [ tool_result_content @@ -191,6 +196,12 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ) if "toolUse" in content: + # Store the mapping from toolUseId to name for later use in toolResult formatting. + # This mapping is built as we format the request, ensuring that when we encounter + # toolResult blocks (which come after toolUse blocks in the message history), + # we can look up the function name. + self._tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"] + return genai.types.Part( function_call=genai.types.FunctionCall( args=content["toolUse"]["input"], @@ -317,16 +328,16 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: case "content_start": match event["data_type"]: case "tool": - # Note: toolUseId is the only identifier available in a tool result. However, Gemini requires - # that name be set in the equivalent FunctionResponse type. Consequently, we assign - # function name to toolUseId in our tool use block. And another reason, function_call is - # not guaranteed to have id populated. + function_call = event["data"].function_call + # Use Gemini's provided ID or generate one if missing + tool_use_id = function_call.id or f"tooluse_{secrets.token_urlsafe(16)}" + return { "contentBlockStart": { "start": { "toolUse": { - "name": event["data"].function_call.name, - "toolUseId": event["data"].function_call.name, + "name": function_call.name, + "toolUseId": tool_use_id, }, }, }, @@ -417,6 +428,7 @@ async def stream( ModelThrottledException: If the request is throttled by Gemini. 
""" request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params")) + self._tool_use_id_to_name.clear() client = self._get_client().aio diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 08be9188d..70f5032d8 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -360,6 +360,71 @@ async def test_stream_request_with_tool_results(gemini_client, model, model_id): gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) +@pytest.mark.asyncio +async def test_stream_request_with_tool_results_preserving_name(gemini_client, model, model_id): + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_1", + "input": {}, + }, + }, + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "done"}], + }, + }, + ], + }, + ] + await anext(model.stream(messages)) + + exp_request = { + "config": { + "tools": [{"function_declarations": []}], + }, + "contents": [ + { + "parts": [ + { + "function_call": { + "args": {}, + "id": "t1", + "name": "tool_1", + }, + }, + ], + "role": "model", + }, + { + "parts": [ + { + "function_response": { + "id": "t1", + "name": "tool_1", + "response": {"output": [{"text": "done"}]}, + }, + }, + ], + "role": "user", + }, + ], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + @pytest.mark.asyncio async def test_stream_request_with_empty_content(gemini_client, model, model_id): messages = [ @@ -459,7 +524,7 @@ async def test_stream_response_tool_use(gemini_client, model, messages, agenerat exp_chunks = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, - {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}}, + {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, {"contentBlockStop": {}}, {"contentBlockStop": {}}, From 456b70a0c14b255eafb49442e9663905b8ba5eba Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 20 Jan 2026 13:02:31 -0500 Subject: [PATCH 37/47] gemini - tool_use_id_to_name - local (#1521) --- src/strands/models/gemini.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 5417f20b3..855e1ef5c 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -87,7 +87,6 @@ def __init__( self._custom_client = client self.client_args = client_args or {} - self._tool_use_id_to_name: dict[str, str] = {} # Validate gemini_tools if provided if "gemini_tools" in self.config: @@ -135,13 +134,19 @@ def _get_client(self) -> genai.Client: # Create a new client from client_args return genai.Client(**self.client_args) - def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part: + def _format_request_content_part( + self, content: ContentBlock, tool_use_id_to_name: dict[str, str] + ) -> genai.types.Part: """Format content block into a Gemini part instance. - Docs: https://googleapis.github.io/python-genai/genai.html#genai.types.Part Args: content: Message content to format. + tool_use_id_to_name: Mapping of tool use id to tool name. + Store the mapping from toolUseId to name for later use in toolResult formatting. 
This mapping is built + as we format the request, ensuring that when we encounter toolResult blocks (which come after toolUse + blocks in the message history), we can look up the function name. Returns: Gemini part. @@ -176,7 +181,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par if "toolResult" in content: tool_use_id = content["toolResult"]["toolUseId"] - function_name = self._tool_use_id_to_name.get(tool_use_id, tool_use_id) + function_name = tool_use_id_to_name.get(tool_use_id, tool_use_id) return genai.types.Part( function_response=genai.types.FunctionResponse( @@ -187,7 +192,8 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par tool_result_content if "json" in tool_result_content else self._format_request_content_part( - cast(ContentBlock, tool_result_content) + cast(ContentBlock, tool_result_content), + tool_use_id_to_name, ).to_json_dict() for tool_result_content in content["toolResult"]["content"] ], @@ -196,11 +202,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ) if "toolUse" in content: - # Store the mapping from toolUseId to name for later use in toolResult formatting. - # This mapping is built as we format the request, ensuring that when we encounter - # toolResult blocks (which come after toolUse blocks in the message history), - # we can look up the function name. - self._tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"] + tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"] return genai.types.Part( function_call=genai.types.FunctionCall( @@ -223,9 +225,15 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten Returns: Gemini content list. """ + # Gemini FunctionResponses are constructed from tool result blocks. Function name is required but is not + # available in tool result blocks, hence the mapping. + tool_use_id_to_name: dict[str, str] = {} + return [ genai.types.Content( - parts=[self._format_request_content_part(content) for content in message["content"]], + parts=[ + self._format_request_content_part(content, tool_use_id_to_name) for content in message["content"] + ], role="user" if message["role"] == "user" else "model", ) for message in messages @@ -428,7 +436,6 @@ async def stream( ModelThrottledException: If the request is throttled by Gemini. 
""" request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params")) - self._tool_use_id_to_name.clear() client = self._get_client().aio From 6dcd24739d7a153eed8eb778d795bb9df6cd3fc3 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 20 Jan 2026 21:12:36 +0200 Subject: [PATCH 38/47] fix(litellm): handle missing usage attribute on ModelResponseStream (#1520) --- src/strands/models/litellm.py | 4 +- tests/strands/models/test_litellm.py | 101 +++++++++++++++++++++++ tests_integ/models/test_model_litellm.py | 21 +++++ 3 files changed, 124 insertions(+), 2 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index ae71cc668..ec6579c58 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -547,8 +547,8 @@ async def _handle_streaming_response(self, litellm_request: dict[str, Any]) -> A # Skip remaining events as we don't have use for anything except the final usage payload async for event in response: _ = event - if event.usage: - yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + if usage := getattr(event, "usage", None): + yield self.format_chunk({"chunk_type": "metadata", "data": usage}) logger.debug("finished streaming response from model") diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 99df22a3f..f5e1837bf 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -711,3 +711,104 @@ def test_stream_switch_content_different_type_no_prev(): assert len(chunks) == 1 assert chunks[0]["contentBlockStart"] == {"start": {}} assert data_type == "text" + + +@pytest.mark.asyncio +async def test_stream_with_events_missing_usage_attribute( + litellm_acompletion, api_key, model_id, model, agenerator, alist +): + """Test streaming handles events that don't have a usage attribute. + + This test verifies the fix for a bug where ModelResponseStream objects + (which don't have a 'usage' attribute) would cause an AttributeError + when the code tried to access event.usage directly instead of using getattr. + + The bug occurred because: + 1. ModelResponse (non-streaming) has a 'usage' attribute + 2. ModelResponseStream (streaming chunks) does NOT have a 'usage' attribute + 3. 
The code assumed all events would have the 'usage' attribute + + Regression test for: 'ModelResponseStream' object has no attribute 'usage' + """ + + # Use spec to ensure mock objects only have specified attributes + # This mimics the real ModelResponseStream which doesn't have 'usage' + class MockStreamChunk: + """Mock that mimics ModelResponseStream - no usage attribute.""" + + def __init__(self, choices=None): + self.choices = choices or [] + + mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None) + mock_event_1 = MockStreamChunk(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = MockStreamChunk(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + # After finish_reason is received, remaining events in the stream also don't have 'usage' + mock_event_3 = MockStreamChunk(choices=[]) + mock_event_4 = MockStreamChunk(choices=[]) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + ) + + messages = [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}] + response = model.stream(messages) + + # This should NOT raise AttributeError: 'MockStreamChunk' object has no attribute 'usage' + tru_events = await alist(response) + + # Verify we got the expected events (no metadata since no usage was available) + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + assert {"messageStop": {"stopReason": "end_turn"}} in tru_events + # No metadata event since mock events don't have usage + assert not any("metadata" in event for event in tru_events) + + +@pytest.mark.asyncio +async def test_stream_with_usage_in_final_event(litellm_acompletion, api_key, model_id, model, agenerator, alist): + """Test streaming correctly extracts usage when it IS present in final events. + + This test ensures that when usage data IS available (e.g., with stream_options.include_usage=True), + it is correctly extracted and included in the metadata event. 
+ """ + + class MockStreamChunkWithoutUsage: + """Mock streaming chunk without usage.""" + + def __init__(self, choices=None): + self.choices = choices or [] + + class MockStreamChunkWithUsage: + """Mock streaming chunk with usage (final event).""" + + def __init__(self, usage): + self.choices = [] + self.usage = usage + + mock_delta = unittest.mock.Mock(content="Hi", tool_calls=None, reasoning_content=None) + mock_event_1 = MockStreamChunkWithoutUsage(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = MockStreamChunkWithoutUsage(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + + # Final event with usage data + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 5 + mock_usage.total_tokens = 15 + mock_usage.prompt_tokens_details = None + mock_usage.cache_creation_input_tokens = None + mock_event_3 = MockStreamChunkWithUsage(usage=mock_usage) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + messages = [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}] + response = model.stream(messages) + + tru_events = await alist(response) + + # Verify metadata event is present with correct usage + metadata_events = [e for e in tru_events if "metadata" in e] + assert len(metadata_events) == 1 + assert metadata_events[0]["metadata"]["usage"]["inputTokens"] == 10 + assert metadata_events[0]["metadata"]["usage"]["outputTokens"] == 5 + assert metadata_events[0]["metadata"]["usage"]["totalTokens"] == 15 diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index 80e21bdfd..eb0737e0f 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -236,6 +236,27 @@ def test_structured_output_unsupported_model(model, nested_weather): mock_schema.assert_not_called() +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) +def test_streaming_returns_usage_metrics(model_fixture, request): + """Test that streaming returns usage metrics. + + This test verifies that the streaming flow correctly extracts and returns + usage data from the model response. This is a regression test for the bug + where accessing 'usage' attribute on ModelResponseStream raised AttributeError. + + Regression test for: 'ModelResponseStream' object has no attribute 'usage' + """ + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) + result = agent("Say hello") + + # Verify usage metrics are returned - this would fail if streaming breaks + assert result.metrics.accumulated_usage is not None + assert result.metrics.accumulated_usage["inputTokens"] > 0 + assert result.metrics.accumulated_usage["outputTokens"] > 0 + assert result.metrics.accumulated_usage["totalTokens"] > 0 + + @pytest.mark.asyncio async def test_cache_read_tokens_multi_turn(model): """Integration test for cache read tokens in multi-turn conversation.""" From 64e1bb25e2e462d2bd42b66d048a8782d674223a Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 21 Jan 2026 09:41:08 -0500 Subject: [PATCH 39/47] feat(agent): add configurable retry_strategy for model calls (#1424) The current retry logic for handling ModelThrottledException is hardcoded in event_loop.py with fixed values (6 attempts, exponential backoff starting at 4s). 
This makes it impossible for users to customize retry behavior for their
specific use cases.

This refactors the hardcoded retry logic into a `ModelRetryStrategy` class
so that folks can customize the parameters.

**Not Included**: This does not introduce a `RetryStrategy` base class. I
started to do so, but am deferring it because:

1. It requires some additional design work to accommodate the
tool-retries, which I anticipate should be accounted for in the design
2. It simplifies this review which refactors how the default retries work
internally
3. `ModelRetryStrategy` provides enough benefit to allow folks to
customize the agent loop without blocking on a more extensible design

----

Co-authored-by: Strands Agent
Co-authored-by: Mackenzie Zastrow
---
 src/strands/__init__.py                       |   2 +
 src/strands/agent/__init__.py                 |   3 +
 src/strands/agent/agent.py                    |  21 +-
 src/strands/event_loop/_retry.py              | 157 +++++++++
 src/strands/event_loop/event_loop.py          |  48 +--
 tests/strands/agent/conftest.py               |  22 ++
 .../strands/agent/hooks/test_agent_events.py  |  10 +-
 tests/strands/agent/test_agent_hooks.py       |   2 +-
 tests/strands/agent/test_agent_retry.py       | 161 +++++++++
 tests/strands/agent/test_retry.py             | 328 ++++++++++++++++++
 tests/strands/event_loop/test_event_loop.py   |  27 +-
 11 files changed, 736 insertions(+), 45 deletions(-)
 create mode 100644 src/strands/event_loop/_retry.py
 create mode 100644 tests/strands/agent/conftest.py
 create mode 100644 tests/strands/agent/test_agent_retry.py
 create mode 100644 tests/strands/agent/test_retry.py

diff --git a/src/strands/__init__.py b/src/strands/__init__.py
index bc17497a0..6026d4240 100644
--- a/src/strands/__init__.py
+++ b/src/strands/__init__.py
@@ -3,6 +3,7 @@
 from . import agent, models, telemetry, types
 from .agent.agent import Agent
 from .agent.base import AgentBase
+from .event_loop._retry import ModelRetryStrategy
 from .tools.decorator import tool
 from .types.tools import ToolContext
 
@@ -11,6 +12,7 @@
     "AgentBase",
     "agent",
     "models",
+    "ModelRetryStrategy",
     "tool",
     "ToolContext",
     "types",
diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py
index c00623dc2..2e40866a9 100644
--- a/src/strands/agent/__init__.py
+++ b/src/strands/agent/__init__.py
@@ -4,8 +4,10 @@
 
 - Agent: The main interface for interacting with AI models and tools
 - ConversationManager: Classes for managing conversation history and context windows
+- Retry Strategies: Configurable retry behavior for model calls
 """
 
+from ..event_loop._retry import ModelRetryStrategy
 from .agent import Agent
 from .agent_result import AgentResult
 from .base import AgentBase
@@ -24,4 +26,5 @@
     "NullConversationManager",
     "SlidingWindowConversationManager",
     "SummarizingConversationManager",
+    "ModelRetryStrategy",
 ]
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 7b9e9c914..cacc69ece 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -26,7 +26,8 @@
 from .. import _identifier
 from .._async import run_async
-from ..event_loop.event_loop import event_loop_cycle
+from ..event_loop._retry import ModelRetryStrategy
+from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle
 from ..tools._tool_helpers import generate_missing_tool_result_content
 
 if TYPE_CHECKING:
@@ -118,6 +119,7 @@
         hooks: list[HookProvider] | None = None,
         session_manager: SessionManager | None = None,
         tool_executor: ToolExecutor | None = None,
+        retry_strategy: ModelRetryStrategy | None = None,
     ):
         """Initialize the Agent with the specified configuration.
 
@@ -167,6 +169,9 @@
             session_manager: Manager for handling agent sessions including conversation history and state.
                 If provided, enables session-based persistence and state management.
             tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
+            retry_strategy: Strategy for retrying model calls on throttling or other transient errors.
+                If None, defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s.
+                To disable retries, pass ModelRetryStrategy(max_attempts=1); for fully custom retry logic, implement a HookProvider.
 
         Raises:
             ValueError: If agent id contains path separators.
@@ -244,6 +249,17 @@
         # separate event loops in different threads, so asyncio.Lock wouldn't work
         self._invocation_lock = threading.Lock()
 
+        # In the future, we'll have a RetryStrategy base class but until
+        # that API is determined we only allow ModelRetryStrategy
+        if retry_strategy and type(retry_strategy) is not ModelRetryStrategy:
+            raise ValueError("retry_strategy must be an instance of ModelRetryStrategy")
+
+        self._retry_strategy = (
+            retry_strategy
+            if retry_strategy is not None
+            else ModelRetryStrategy(max_attempts=MAX_ATTEMPTS, max_delay=MAX_DELAY, initial_delay=INITIAL_DELAY)
+        )
+
         # Initialize session management functionality
         self._session_manager = session_manager
         if self._session_manager:
@@ -252,6 +268,9 @@
         # Allow conversation_managers to subscribe to hooks
         self.hooks.add_hook(self.conversation_manager)
 
+        # Register retry strategy as a hook
+        self.hooks.add_hook(self._retry_strategy)
+
         self.tool_executor = tool_executor or ConcurrentToolExecutor()
 
         if hooks:
diff --git a/src/strands/event_loop/_retry.py b/src/strands/event_loop/_retry.py
new file mode 100644
index 000000000..04a6101b8
--- /dev/null
+++ b/src/strands/event_loop/_retry.py
@@ -0,0 +1,157 @@
+"""Retry strategy implementations for handling model throttling and other retry scenarios.
+
+This module provides hook-based retry strategies that can be configured on the Agent
+to control retry behavior for model invocations. Retry strategies implement the
+HookProvider protocol and register callbacks for AfterModelCallEvent to determine
+when and how to retry failed model calls.
+"""
+
+import asyncio
+import logging
+from typing import Any
+
+from ..hooks.events import AfterInvocationEvent, AfterModelCallEvent
+from ..hooks.registry import HookProvider, HookRegistry
+from ..types._events import EventLoopThrottleEvent, TypedEvent
+from ..types.exceptions import ModelThrottledException
+
+logger = logging.getLogger(__name__)
+
+
+class ModelRetryStrategy(HookProvider):
+    """Default retry strategy for model throttling with exponential backoff.
+
+    Retries model calls on ModelThrottledException using exponential backoff.
+    Delay doubles after each attempt: initial_delay, initial_delay*2, initial_delay*4,
+    etc., capped at max_delay. State resets after successful calls.
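[Editor's note — a sketch of how this strategy is configured from the Agent side, not part of this patch; the parameter values are illustrative, not the defaults.]

```python
from strands import Agent, ModelRetryStrategy

# With max_attempts=3 and initial_delay=2, a throttled model call is retried
# twice, sleeping 2s and then 4s, before the ModelThrottledException is
# re-raised on the third failed attempt.
agent = Agent(
    retry_strategy=ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=30),
)
```

Both `Agent` and `ModelRetryStrategy` are exported from the package root by the `__init__.py` changes above.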
+
+    With defaults (initial_delay=4, max_delay=240, max_attempts=6), delays are:
+    4s → 8s → 16s → 32s → 64s (5 retries before giving up on the 6th attempt).
+
+    Args:
+        max_attempts: Total model attempts before re-raising the exception.
+        initial_delay: Base delay in seconds; used for the first retry, then doubled for each subsequent retry.
+        max_delay: Upper bound in seconds for the exponential backoff.
+    """
+
+    def __init__(
+        self,
+        *,
+        max_attempts: int = 6,
+        initial_delay: int = 4,
+        max_delay: int = 240,
+    ):
+        """Initialize the retry strategy.
+
+        Args:
+            max_attempts: Total model attempts before re-raising the exception. Defaults to 6.
+            initial_delay: Base delay in seconds; used for the first retry, then doubled for each subsequent retry.
+                Defaults to 4.
+            max_delay: Upper bound in seconds for the exponential backoff. Defaults to 240.
+        """
+        self._max_attempts = max_attempts
+        self._initial_delay = initial_delay
+        self._max_delay = max_delay
+        self._current_attempt = 0
+        self._backwards_compatible_event_to_yield: TypedEvent | None = None
+
+    def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
+        """Register callbacks for AfterModelCallEvent and AfterInvocationEvent.
+
+        Args:
+            registry: The hook registry to register callbacks with.
+            **kwargs: Additional keyword arguments for future extensibility.
+        """
+        registry.add_callback(AfterModelCallEvent, self._handle_after_model_call)
+        registry.add_callback(AfterInvocationEvent, self._handle_after_invocation)
+
+    def _calculate_delay(self, attempt: int) -> int:
+        """Calculate retry delay using exponential backoff.
+
+        Args:
+            attempt: The attempt number (0-indexed) to calculate delay for.
+
+        Returns:
+            Delay in seconds for the given attempt.
+        """
+        delay: int = self._initial_delay * (2**attempt)
+        return min(delay, self._max_delay)
+
+    def _reset_retry_state(self) -> None:
+        """Reset retry state to initial values."""
+        self._current_attempt = 0
+
+    async def _handle_after_invocation(self, event: AfterInvocationEvent) -> None:
+        """Reset retry state after invocation completes.
+
+        Args:
+            event: The AfterInvocationEvent signaling invocation completion.
+        """
+        self._reset_retry_state()
+
+    async def _handle_after_model_call(self, event: AfterModelCallEvent) -> None:
+        """Handle model call completion and determine if retry is needed.
+
+        This callback is invoked after each model call. If the call failed with
+        a ModelThrottledException and we haven't exceeded max_attempts, it sets
+        event.retry to True and sleeps for the current delay before returning.
+
+        On successful calls, it resets the retry state to prepare for future calls.
+
+        Args:
+            event: The AfterModelCallEvent containing call results or exception.
+ """ + delay = self._calculate_delay(self._current_attempt) + + self._backwards_compatible_event_to_yield = None + + # If already retrying, skip processing (another hook may have triggered retry) + if event.retry: + return + + # If model call succeeded, reset retry state + if event.stop_response is not None: + logger.debug( + "stop_reason=<%s> | model call succeeded, resetting retry state", + event.stop_response.stop_reason, + ) + self._reset_retry_state() + return + + # Check if we have an exception and reset state if no exception + if event.exception is None: + self._reset_retry_state() + return + + # Only retry on ModelThrottledException + if not isinstance(event.exception, ModelThrottledException): + return + + # Increment attempt counter first + self._current_attempt += 1 + + # Check if we've exceeded max attempts + if self._current_attempt >= self._max_attempts: + logger.debug( + "current_attempt=<%d>, max_attempts=<%d> | max retry attempts reached, not retrying", + self._current_attempt, + self._max_attempts, + ) + return + + self._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=delay) + + # Retry the model call + logger.debug( + "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " + "| throttling exception encountered | delaying before next retry", + delay, + self._max_attempts, + self._current_attempt, + ) + + # Sleep for current delay + await asyncio.sleep(delay) + + # Set retry flag and track that this strategy triggered it + event.retry = True diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 99c8f5179..f5d00a201 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -8,7 +8,6 @@ 4. Manage recursive execution cycles """ -import asyncio import logging import uuid from collections.abc import AsyncGenerator @@ -23,7 +22,6 @@ from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..types._events import ( EventLoopStopEvent, - EventLoopThrottleEvent, ForceStopEvent, ModelMessageEvent, ModelStopReason, @@ -39,12 +37,12 @@ ContextWindowOverflowException, EventLoopException, MaxTokensReachedException, - ModelThrottledException, StructuredOutputException, ) from ..types.streaming import StopReason from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached +from ._retry import ModelRetryStrategy from .streaming import stream_messages if TYPE_CHECKING: @@ -316,9 +314,9 @@ async def _handle_model_execution( stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) cycle_trace.add_child(stream_trace) - # Retry loop for handling throttling exceptions - current_delay = INITIAL_DELAY - for attempt in range(MAX_ATTEMPTS): + # Retry loop - actual retry logic is handled by retry_strategy hook + # Hooks control when to stop retrying via the event.retry flag + while True: model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None model_invoke_span = tracer.start_model_invoke_span( messages=agent.messages, @@ -366,9 +364,8 @@ async def _handle_model_execution( # Check if hooks want to retry the model call if after_model_call_event.retry: logger.debug( - "stop_reason=<%s>, retry_requested=, attempt=<%d> | hook requested model retry", + "stop_reason=<%s>, retry_requested= | hook requested model retry", stop_reason, - attempt + 1, ) continue # Retry the model call @@ -389,34 +386,27 @@ async def _handle_model_execution( ) await 
agent.hooks.invoke_callbacks_async(after_model_call_event) + # Emit backwards-compatible events if retry strategy supports it + # (prior to making the retry strategy configurable, this is what we emitted) + + if ( + isinstance(agent._retry_strategy, ModelRetryStrategy) + and agent._retry_strategy._backwards_compatible_event_to_yield + ): + yield agent._retry_strategy._backwards_compatible_event_to_yield + # Check if hooks want to retry the model call if after_model_call_event.retry: logger.debug( - "exception=<%s>, retry_requested=, attempt=<%d> | hook requested model retry", + "exception=<%s>, retry_requested= | hook requested model retry", type(e).__name__, - attempt + 1, ) - continue # Retry the model call - if isinstance(e, ModelThrottledException): - if attempt + 1 == MAX_ATTEMPTS: - yield ForceStopEvent(reason=e) - raise e - - logger.debug( - "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " - "| throttling exception encountered " - "| delaying before next retry", - current_delay, - MAX_ATTEMPTS, - attempt + 1, - ) - await asyncio.sleep(current_delay) - current_delay = min(current_delay * 2, MAX_DELAY) + continue # Retry the model call - yield EventLoopThrottleEvent(delay=current_delay) - else: - raise e + # No retry requested, raise the exception + yield ForceStopEvent(reason=e) + raise e try: # Add message in trace and mark the end of the stream messages trace diff --git a/tests/strands/agent/conftest.py b/tests/strands/agent/conftest.py new file mode 100644 index 000000000..d3af90dc8 --- /dev/null +++ b/tests/strands/agent/conftest.py @@ -0,0 +1,22 @@ +"""Fixtures for agent tests.""" + +import asyncio +from unittest.mock import AsyncMock + +import pytest + + +@pytest.fixture +def mock_sleep(monkeypatch): + """Mock asyncio.sleep to avoid delays in tests and track sleep calls.""" + sleep_calls = [] + + async def _mock_sleep(delay): + sleep_calls.append(delay) + + mock = AsyncMock(side_effect=_mock_sleep) + monkeypatch.setattr(asyncio, "sleep", mock) + + # Return both the mock and the sleep_calls list for verification + mock.sleep_calls = sleep_calls + return mock diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 7b189a5c6..f511c7019 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -1,6 +1,6 @@ import asyncio import unittest.mock -from unittest.mock import ANY, MagicMock, call +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest from pydantic import BaseModel @@ -34,9 +34,7 @@ async def streaming_tool(): @pytest.fixture def mock_sleep(): - with unittest.mock.patch.object( - strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock - ) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock @@ -359,8 +357,8 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): {"arg1": 1013, "init_event_loop": True}, {"start": True}, {"start_event_loop": True}, + {"event_loop_throttled_delay": 4, **throttle_props}, {"event_loop_throttled_delay": 8, **throttle_props}, - {"event_loop_throttled_delay": 16, **throttle_props}, {"event": {"messageStart": {"role": "assistant"}}}, {"event": {"redactContent": {"redactUserContentMessage": "BLOCKED!"}}}, {"event": {"contentBlockStart": {"start": {}}}}, @@ -508,11 +506,11 @@ async def test_event_loop_cycle_text_response_throttling_early_end( {"init_event_loop": True, "arg1": 1013}, 
{"start": True}, {"start_event_loop": True}, + {"event_loop_throttled_delay": 4, **common_props}, {"event_loop_throttled_delay": 8, **common_props}, {"event_loop_throttled_delay": 16, **common_props}, {"event_loop_throttled_delay": 32, **common_props}, {"event_loop_throttled_delay": 64, **common_props}, - {"event_loop_throttled_delay": 128, **common_props}, {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"}, ] diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index be71b5fcf..e8b7e5077 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -104,7 +104,7 @@ class User(BaseModel): @pytest.fixture def mock_sleep(): - with patch.object(strands.event_loop.event_loop.asyncio, "sleep", new_callable=AsyncMock) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock diff --git a/tests/strands/agent/test_agent_retry.py b/tests/strands/agent/test_agent_retry.py new file mode 100644 index 000000000..1b3bc5e9c --- /dev/null +++ b/tests/strands/agent/test_agent_retry.py @@ -0,0 +1,161 @@ +"""Integration tests for Agent retry_strategy parameter.""" + +from unittest.mock import Mock + +import pytest + +from strands import Agent, ModelRetryStrategy +from strands.event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY +from strands.hooks import AfterModelCallEvent +from strands.types.exceptions import ModelThrottledException +from tests.fixtures.mocked_model_provider import MockedModelProvider + +# Agent Retry Strategy Initialization Tests + + +def test_agent_with_default_retry_strategy(): + """Test that Agent uses ModelRetryStrategy by default when retry_strategy=None.""" + agent = Agent() + + # Should have a retry_strategy + assert agent._retry_strategy is not None + + # Should be ModelRetryStrategy with default parameters + assert isinstance(agent._retry_strategy, ModelRetryStrategy) + assert agent._retry_strategy._max_attempts == 6 + assert agent._retry_strategy._initial_delay == 4 + assert agent._retry_strategy._max_delay == 240 + + +def test_agent_with_custom_model_retry_strategy(): + """Test Agent initialization with custom ModelRetryStrategy parameters.""" + custom_strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + agent = Agent(retry_strategy=custom_strategy) + + assert agent._retry_strategy is custom_strategy + assert agent._retry_strategy._max_attempts == 3 + assert agent._retry_strategy._initial_delay == 2 + assert agent._retry_strategy._max_delay == 60 + + +def test_agent_rejects_invalid_retry_strategy_type(): + """Test that Agent raises ValueError for non-ModelRetryStrategy retry_strategy.""" + + class FakeRetryStrategy: + pass + + with pytest.raises(ValueError, match="retry_strategy must be an instance of ModelRetryStrategy"): + Agent(retry_strategy=FakeRetryStrategy()) + + +def test_agent_rejects_subclass_of_model_retry_strategy(): + """Test that Agent rejects subclasses of ModelRetryStrategy (strict type check).""" + + class CustomRetryStrategy(ModelRetryStrategy): + pass + + with pytest.raises(ValueError, match="retry_strategy must be an instance of ModelRetryStrategy"): + Agent(retry_strategy=CustomRetryStrategy()) + + +def test_agent_default_retry_strategy_uses_event_loop_constants(): + """Test that default retry strategy uses constants from event_loop module.""" + agent = Agent() + + assert agent._retry_strategy._max_attempts == MAX_ATTEMPTS + assert 
agent._retry_strategy._initial_delay == INITIAL_DELAY + assert agent._retry_strategy._max_delay == MAX_DELAY + + +def test_retry_strategy_registered_as_hook(): + """Test that retry_strategy is registered with the hook system.""" + custom_strategy = ModelRetryStrategy(max_attempts=3) + agent = Agent(retry_strategy=custom_strategy) + + # Verify retry strategy callback is registered + callbacks = list(agent.hooks.get_callbacks_for(AfterModelCallEvent(agent=agent, exception=None))) + + # Should have at least one callback (from retry strategy) + assert len(callbacks) > 0 + + # Verify one of the callbacks is from the retry strategy + assert any( + callback.__self__ is custom_strategy if hasattr(callback, "__self__") else False for callback in callbacks + ) + + +# Agent Retry Behavior Tests + + +@pytest.mark.asyncio +async def test_agent_retries_with_default_strategy(mock_sleep): + """Test that Agent retries on throttling with default ModelRetryStrategy.""" + # Create a model that fails twice with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException"), + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success after retries"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have succeeded after retries - just check we got events + assert len(events) > 0 + + # Should have slept twice (for two retries) + assert len(mock_sleep.sleep_calls) == 2 + # First retry: 4 seconds + assert mock_sleep.sleep_calls[0] == 4 + # Second retry: 8 seconds (exponential backoff) + assert mock_sleep.sleep_calls[1] == 8 + + +@pytest.mark.asyncio +async def test_agent_respects_max_attempts(mock_sleep): + """Test that Agent respects max_attempts in retry strategy.""" + # Create a model that always fails + model = Mock() + model.stream.side_effect = ModelThrottledException("ThrottlingException") + + # Use custom strategy with max 2 attempts + custom_strategy = ModelRetryStrategy(max_attempts=2, initial_delay=1, max_delay=60) + agent = Agent(model=model, retry_strategy=custom_strategy) + + with pytest.raises(ModelThrottledException): + result = agent.stream_async("test prompt") + _ = [event async for event in result] + + # Should have attempted max_attempts times, which means (max_attempts - 1) sleeps + # Attempt 0: fail, sleep + # Attempt 1: fail, no more attempts + assert len(mock_sleep.sleep_calls) == 1 + + +# Backwards Compatibility Tests + + +@pytest.mark.asyncio +async def test_event_loop_throttle_event_emitted(mock_sleep): + """Test that EventLoopThrottleEvent is still emitted for backwards compatibility.""" + # Create a model that fails once with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have EventLoopThrottleEvent in the stream + throttle_events = [e for e in events if "event_loop_throttled_delay" in e] + assert len(throttle_events) > 0 + + # Should have the correct delay value + assert throttle_events[0]["event_loop_throttled_delay"] > 0 diff --git a/tests/strands/agent/test_retry.py b/tests/strands/agent/test_retry.py new file mode 100644 index 000000000..830c1b5b8 --- 
/dev/null +++ b/tests/strands/agent/test_retry.py @@ -0,0 +1,328 @@ +"""Unit tests for retry strategy implementations.""" + +from unittest.mock import Mock + +import pytest + +from strands import ModelRetryStrategy +from strands.hooks import AfterInvocationEvent, AfterModelCallEvent, HookRegistry +from strands.types._events import EventLoopThrottleEvent +from strands.types.exceptions import ModelThrottledException + +# ModelRetryStrategy Tests + + +def test_model_retry_strategy_init_with_defaults(): + """Test ModelRetryStrategy initialization with default parameters.""" + strategy = ModelRetryStrategy() + assert strategy._max_attempts == 6 + assert strategy._initial_delay == 4 + assert strategy._max_delay == 240 + assert strategy._current_attempt == 0 + + +def test_model_retry_strategy_init_with_custom_parameters(): + """Test ModelRetryStrategy initialization with custom parameters.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + assert strategy._max_attempts == 3 + assert strategy._initial_delay == 2 + assert strategy._max_delay == 60 + assert strategy._current_attempt == 0 + + +def test_model_retry_strategy_calculate_delay_with_different_attempts(): + """Test _calculate_delay returns correct exponential backoff for different attempt numbers.""" + strategy = ModelRetryStrategy(initial_delay=2, max_delay=32) + + # Test exponential backoff: 2 * (2^attempt) + assert strategy._calculate_delay(0) == 2 # 2 * 2^0 = 2 + assert strategy._calculate_delay(1) == 4 # 2 * 2^1 = 4 + assert strategy._calculate_delay(2) == 8 # 2 * 2^2 = 8 + assert strategy._calculate_delay(3) == 16 # 2 * 2^3 = 16 + assert strategy._calculate_delay(4) == 32 # 2 * 2^4 = 32 (at max) + assert strategy._calculate_delay(5) == 32 # 2 * 2^5 = 64, capped at 32 + assert strategy._calculate_delay(10) == 32 # Large attempt, still capped + + +def test_model_retry_strategy_calculate_delay_respects_max_delay(): + """Test _calculate_delay respects max_delay cap.""" + strategy = ModelRetryStrategy(initial_delay=10, max_delay=50) + + assert strategy._calculate_delay(0) == 10 # 10 * 2^0 = 10 + assert strategy._calculate_delay(1) == 20 # 10 * 2^1 = 20 + assert strategy._calculate_delay(2) == 40 # 10 * 2^2 = 40 + assert strategy._calculate_delay(3) == 50 # 10 * 2^3 = 80, capped at 50 + assert strategy._calculate_delay(4) == 50 # 10 * 2^4 = 160, capped at 50 + + +def test_model_retry_strategy_register_hooks(): + """Test that ModelRetryStrategy registers AfterModelCallEvent and AfterInvocationEvent callbacks.""" + strategy = ModelRetryStrategy() + registry = HookRegistry() + + strategy.register_hooks(registry) + + # Verify AfterModelCallEvent callback was registered + assert AfterModelCallEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + # Verify AfterInvocationEvent callback was registered + assert AfterInvocationEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterInvocationEvent]) == 1 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_retry_on_throttle_exception_first_attempt(mock_sleep): + """Test retry behavior on first ModelThrottledException.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should set retry to True + assert event.retry is True + # Should sleep for 
initial_delay (attempt 0: 2 * 2^0 = 2) + assert mock_sleep.sleep_calls == [2] + assert mock_sleep.sleep_calls[0] == strategy._calculate_delay(0) + # Should increment attempt + assert strategy._current_attempt == 1 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_exponential_backoff(mock_sleep): + """Test exponential backoff calculation.""" + strategy = ModelRetryStrategy(max_attempts=5, initial_delay=2, max_delay=16) + mock_agent = Mock() + + # Simulate multiple retries + for _ in range(4): + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event) + assert event.retry is True + + # Verify exponential backoff with max_delay cap + # attempt 0: 2*2^0=2, attempt 1: 2*2^1=4, attempt 2: 2*2^2=8, attempt 3: 2*2^3=16 (capped) + assert mock_sleep.sleep_calls == [2, 4, 8, 16] + for i, sleep_delay in enumerate(mock_sleep.sleep_calls): + assert sleep_delay == strategy._calculate_delay(i) + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_after_max_attempts(mock_sleep): + """Test that retry is not set after reaching max_attempts.""" + strategy = ModelRetryStrategy(max_attempts=2, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First attempt + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + + # Second attempt (at max_attempts) + event2 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event2) + # Should NOT retry after reaching max_attempts + assert event2.retry is False + assert strategy._current_attempt == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_non_throttle_exception(): + """Test that retry is not set for non-throttling exceptions.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ValueError("Some other error"), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on non-throttling exceptions + assert event.retry is False + assert strategy._current_attempt == 0 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_success(): + """Test that retry is not set when model call succeeds.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on success + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_retry_strategy_reset_on_success(mock_sleep): + """Test that strategy resets attempt counter on successful call.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First failure + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + # Should sleep for initial_delay (attempt 0: 2 * 2^0 = 2) + assert mock_sleep.sleep_calls == [2] + assert mock_sleep.sleep_calls[0] == strategy._calculate_delay(0) + + # Success - should 
reset + event2 = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + await strategy._handle_after_model_call(event2) + assert event2.retry is False + # Should reset to initial state + assert strategy._current_attempt == 0 + assert strategy._calculate_delay(0) == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_skips_if_already_retrying(): + """Test that strategy skips processing if event.retry is already True.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + # Simulate another hook already set retry to True + event.retry = True + + await strategy._handle_after_model_call(event) + + # Should not modify state since another hook already triggered retry + assert strategy._current_attempt == 0 + assert event.retry is True + + +@pytest.mark.asyncio +async def test_model_retry_strategy_reset_on_after_invocation(): + """Test that strategy resets state on AfterInvocationEvent.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # Simulate some retry attempts + strategy._current_attempt = 3 + + event = AfterInvocationEvent(agent=mock_agent, result=Mock()) + await strategy._handle_after_invocation(event) + + # Should reset to initial state + assert strategy._current_attempt == 0 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_set_on_retry(mock_sleep): + """Test that _backwards_compatible_event_to_yield is set when retrying.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should have set the backwards compatible event + assert strategy._backwards_compatible_event_to_yield is not None + assert isinstance(strategy._backwards_compatible_event_to_yield, EventLoopThrottleEvent) + assert strategy._backwards_compatible_event_to_yield["event_loop_throttled_delay"] == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_cleared_on_success(): + """Test that _backwards_compatible_event_to_yield is cleared on success.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # Set a previous backwards compatible event + strategy._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=2) + + event = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + + await strategy._handle_after_model_call(event) + + # Should have cleared the backwards compatible event + assert strategy._backwards_compatible_event_to_yield is None + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_not_set_on_max_attempts(mock_sleep): + """Test that _backwards_compatible_event_to_yield is not set when max attempts reached.""" + strategy = ModelRetryStrategy(max_attempts=1, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) 
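+    # With max_attempts=1, the first throttled call already exhausts the retry
+    # budget, so no retry and no backwards-compatible throttle event is expected.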
+ + await strategy._handle_after_model_call(event) + + # Should not have set the backwards compatible event since max attempts reached + assert strategy._backwards_compatible_event_to_yield is None + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_when_no_exception_and_no_stop_response(): + """Test that retry is not set when there's no exception and no stop_response.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + # Event with neither exception nor stop_response + event = AfterModelCallEvent( + agent=mock_agent, + exception=None, + stop_response=None, + ) + + await strategy._handle_after_model_call(event) + + # Should not retry and should reset state + assert event.retry is False + assert strategy._current_attempt == 0 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 639e60ea0..d4afd579b 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import unittest.mock from unittest.mock import ANY, AsyncMock, MagicMock, call, patch @@ -7,6 +8,7 @@ import strands import strands.telemetry from strands import Agent +from strands.event_loop._retry import ModelRetryStrategy from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -31,9 +33,7 @@ @pytest.fixture def mock_sleep(): - with unittest.mock.patch.object( - strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock - ) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock @@ -116,7 +116,11 @@ def tool_stream(tool): @pytest.fixture def hook_registry(): - return HookRegistry() + registry = HookRegistry() + # Register default retry strategy + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(registry) + return registry @pytest.fixture @@ -147,6 +151,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.tool_executor = tool_executor mock._interrupt_state = _InterruptState() mock.trace_attributes = {} + mock.retry_strategy = ModelRetryStrategy() return mock @@ -693,7 +698,7 @@ async def test_event_loop_tracing_with_throttling_exception( ] # Mock the time.sleep function to speed up the test - with patch("strands.event_loop.event_loop.asyncio.sleep", new_callable=unittest.mock.AsyncMock): + with patch.object(asyncio, "sleep", new_callable=unittest.mock.AsyncMock): stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -856,15 +861,21 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, # 1st call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 2nd call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 3rd call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, 
stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 4th call - successful assert next(events) == BeforeModelCallEvent(agent=agent) From 7604e98bece0fe3fb0e0fcb5baa8055d69dcc422 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 21 Jan 2026 09:48:05 -0500 Subject: [PATCH 40/47] fix(swarm): accumulate execution_time across interrupt/resume cycles (#1502) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- src/strands/multiagent/swarm.py | 4 ++-- tests/strands/multiagent/test_swarm.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 6c1149624..8368f5936 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -199,7 +199,7 @@ def should_continue( return False, f"Max iterations reached: {max_iterations}" # Check timeout - elapsed = time.time() - self.start_time + elapsed = self.execution_time / 1000 + time.time() - self.start_time if elapsed > execution_timeout: return False, f"Execution timed out: {execution_timeout}s" @@ -406,7 +406,7 @@ async def stream_async( self.state.completion_status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - self.state.start_time) * 1000) + self.state.execution_time += round((time.time() - self.state.start_time) * 1000) await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self, invocation_state)) self._resume_from_session = False diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index f2abed9f7..aae11b709 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1243,6 +1243,8 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): multiagent_result = swarm("Test task") + first_execution_time = multiagent_result.execution_time + tru_status = multiagent_result.status exp_status = Status.INTERRUPTED assert tru_status == exp_status @@ -1279,6 +1281,8 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): exp_message = "Task completed" assert tru_message == exp_message + assert multiagent_result.execution_time >= first_execution_time + def test_swarm_interrupt_on_agent(agenerator): exp_interrupts = [ From 2e23d755ecd438c92082103ec941200e011cadc8 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Wed, 21 Jan 2026 10:45:48 -0500 Subject: [PATCH 41/47] Feat: graduate multiagent hook events from experimental (#1498) --- .../experimental/hooks/multiagent/__init__.py | 4 +- .../experimental/hooks/multiagent/events.py | 138 +++--------------- src/strands/hooks/__init__.py | 11 ++ src/strands/hooks/events.py | 106 +++++++++++++- src/strands/multiagent/graph.py | 4 +- src/strands/multiagent/swarm.py | 4 +- src/strands/session/session_manager.py | 6 +- .../fixtures/mock_multiagent_hook_provider.py | 6 +- .../experimental/hooks/multiagent/__init__.py | 0 .../hooks/multiagent => hooks}/test_events.py | 4 +- .../test_multi_agent_hooks.py | 2 +- tests/strands/multiagent/conftest.py | 3 +- tests/strands/multiagent/test_graph.py | 3 +- tests/strands/multiagent/test_swarm.py | 2 +- tests_integ/hooks/multiagent/test_cancel.py | 3 +- tests_integ/hooks/multiagent/test_events.py | 4 +- .../interrupts/multiagent/test_hook.py | 3 +- 
.../interrupts/multiagent/test_session.py | 3 +-
 tests_integ/test_multiagent_swarm.py | 2 +-
 19 files changed, 164 insertions(+), 144 deletions(-)
 delete mode 100644 tests/strands/experimental/hooks/multiagent/__init__.py
 rename tests/strands/{experimental/hooks/multiagent => hooks}/test_events.py (97%)
 rename tests/strands/{experimental/hooks/multiagent => hooks}/test_multi_agent_hooks.py (98%)

diff --git a/src/strands/experimental/hooks/multiagent/__init__.py b/src/strands/experimental/hooks/multiagent/__init__.py
index d059d0da5..6755db7e4 100644
--- a/src/strands/experimental/hooks/multiagent/__init__.py
+++ b/src/strands/experimental/hooks/multiagent/__init__.py
@@ -1,6 +1,6 @@
-"""Multi-agent hook events and utilities.
+"""Multi-agent hook events.
 
-Provides event classes for hooking into multi-agent orchestrator lifecycle.
+Deprecated: Use strands.hooks instead.
 """
 
 from .events import (
diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py
index fa881bf32..2c65c53e3 100644
--- a/src/strands/experimental/hooks/multiagent/events.py
+++ b/src/strands/experimental/hooks/multiagent/events.py
@@ -1,118 +1,28 @@
 """Multi-agent execution lifecycle events for hook system integration.
 
-These events are fired by orchestrators (Graph/Swarm) at key points so
-hooks can persist, monitor, or debug execution. No intermediate state model
-is used—hooks read from the orchestrator directly.
+Deprecated: Use strands.hooks instead.
 """
 
-import uuid
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
-
-from typing_extensions import override
-
-from ....hooks import BaseHookEvent
-from ....types.interrupt import _Interruptible
-
-if TYPE_CHECKING:
-    from ....multiagent.base import MultiAgentBase
-
-
-@dataclass
-class MultiAgentInitializedEvent(BaseHookEvent):
-    """Event triggered when multi-agent orchestrator initialized.
-
-    Attributes:
-        source: The multi-agent orchestrator instance
-        invocation_state: Configuration that user passes in
-    """
-
-    source: "MultiAgentBase"
-    invocation_state: dict[str, Any] | None = None
-
-
-@dataclass
-class BeforeNodeCallEvent(BaseHookEvent, _Interruptible):
-    """Event triggered before individual node execution starts.
-
-    Attributes:
-        source: The multi-agent orchestrator instance
-        node_id: ID of the node about to execute
-        invocation_state: Configuration that user passes in
-        cancel_node: A user defined message that when set, will cancel the node execution with status FAILED.
-            The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the
-            node using a default cancel message.
-    """
-
-    source: "MultiAgentBase"
-    node_id: str
-    invocation_state: dict[str, Any] | None = None
-    cancel_node: bool | str = False
-
-    def _can_write(self, name: str) -> bool:
-        return name in ["cancel_node"]
-
-    @override
-    def _interrupt_id(self, name: str) -> str:
-        """Unique id for the interrupt.
-
-        Args:
-            name: User defined name for the interrupt.
-
-        Returns:
-            Interrupt id.
-        """
-        node_id = uuid.uuid5(uuid.NAMESPACE_OID, self.node_id)
-        call_id = uuid.uuid5(uuid.NAMESPACE_OID, name)
-        return f"v1:before_node_call:{node_id}:{call_id}"
-
-
-@dataclass
-class AfterNodeCallEvent(BaseHookEvent):
-    """Event triggered after individual node execution completes.
- - Attributes: - source: The multi-agent orchestrator instance - node_id: ID of the node that just completed execution - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - node_id: str - invocation_state: dict[str, Any] | None = None - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True - - -@dataclass -class BeforeMultiAgentInvocationEvent(BaseHookEvent): - """Event triggered before orchestrator execution starts. - - Attributes: - source: The multi-agent orchestrator instance - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - invocation_state: dict[str, Any] | None = None - - -@dataclass -class AfterMultiAgentInvocationEvent(BaseHookEvent): - """Event triggered after orchestrator execution completes. - - Attributes: - source: The multi-agent orchestrator instance - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - invocation_state: dict[str, Any] | None = None - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True +import warnings + +from ....hooks import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) + +warnings.warn( + "strands.experimental.hooks.multiagent is deprecated. Use strands.hooks instead.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", +] diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 30163f207..96c7f577b 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -32,12 +32,18 @@ def log_end(self, event: AfterInvocationEvent) -> None: from .events import ( AfterInvocationEvent, AfterModelCallEvent, + # Multiagent hook events + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, BeforeToolCallEvent, MessageAddedEvent, + MultiAgentInitializedEvent, ) from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry @@ -56,4 +62,9 @@ def log_end(self, event: AfterInvocationEvent) -> None: "HookRegistry", "HookEvent", "BaseHookEvent", + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", ] diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8aa8a68d6..1faa8a917 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -16,7 +16,10 @@ from ..types.interrupt import _Interruptible from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse -from .registry import HookEvent +from .registry import BaseHookEvent, HookEvent + +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase @dataclass @@ -250,3 +253,104 @@ def _can_write(self, name: str) -> bool: def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True + + +# Multiagent hook events start here +@dataclass +class MultiAgentInitializedEvent(BaseHookEvent): + """Event triggered when multi-agent orchestrator initialized. 
+ + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class BeforeNodeCallEvent(BaseHookEvent, _Interruptible): + """Event triggered before individual node execution starts. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node about to execute + invocation_state: Configuration that user passes in + cancel_node: A user defined message that when set, will cancel the node execution with status FAILED. + The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the + node using a default cancel message. + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + cancel_node: bool | str = False + + def _can_write(self, name: str) -> bool: + return name in ["cancel_node"] + + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + node_id = uuid.uuid5(uuid.NAMESPACE_OID, self.node_id) + call_id = uuid.uuid5(uuid.NAMESPACE_OID, name) + return f"v1:before_node_call:{node_id}:{call_id}" + + +@dataclass +class AfterNodeCallEvent(BaseHookEvent): + """Event triggered after individual node execution completes. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node that just completed execution + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BeforeMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered before orchestrator execution starts. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered after orchestrator execution completes. 
+ + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 97435ad4a..32eca00ff 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -27,14 +27,14 @@ from .._async import run_async from ..agent import Agent from ..agent.state import AgentState -from ..experimental.hooks.multiagent import ( +from ..hooks.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks import HookProvider, HookRegistry +from ..hooks.registry import HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState from ..session import SessionManager from ..telemetry import get_tracer diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 8368f5936..9a4ce5494 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -27,14 +27,14 @@ from .._async import run_async from ..agent import Agent from ..agent.state import AgentState -from ..experimental.hooks.multiagent import ( +from ..hooks.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks import HookProvider, HookRegistry +from ..hooks.registry import HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState from ..session import SessionManager from ..telemetry import get_tracer diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index ba4356089..cc954e17d 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -9,12 +9,14 @@ BidiAgentInitializedEvent, BidiMessageAddedEvent, ) -from ..experimental.hooks.multiagent.events import ( +from ..hooks.events import ( + AfterInvocationEvent, AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + AgentInitializedEvent, + MessageAddedEvent, MultiAgentInitializedEvent, ) -from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry from ..types.content import Message diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py index 4d18297a2..a89d5aca8 100644 --- a/tests/fixtures/mock_multiagent_hook_provider.py +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -1,16 +1,14 @@ from collections.abc import Iterator from typing import Literal -from strands.experimental.hooks.multiagent.events import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeNodeCallEvent, - MultiAgentInitializedEvent, -) -from strands.hooks import ( HookEvent, HookProvider, HookRegistry, + MultiAgentInitializedEvent, ) diff --git a/tests/strands/experimental/hooks/multiagent/__init__.py b/tests/strands/experimental/hooks/multiagent/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/strands/experimental/hooks/multiagent/test_events.py b/tests/strands/hooks/test_events.py similarity index 97% rename from tests/strands/experimental/hooks/multiagent/test_events.py rename to 
tests/strands/hooks/test_events.py index 6c4d7c4e7..90ab205a9 100644 --- a/tests/strands/experimental/hooks/multiagent/test_events.py +++ b/tests/strands/hooks/test_events.py @@ -4,14 +4,14 @@ import pytest -from strands.experimental.hooks.multiagent.events import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + BaseHookEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from strands.hooks import BaseHookEvent @pytest.fixture diff --git a/tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py b/tests/strands/hooks/test_multi_agent_hooks.py similarity index 98% rename from tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py rename to tests/strands/hooks/test_multi_agent_hooks.py index 4e97a9217..3f6e0c940 100644 --- a/tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py +++ b/tests/strands/hooks/test_multi_agent_hooks.py @@ -1,7 +1,7 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent.events import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, diff --git a/tests/strands/multiagent/conftest.py b/tests/strands/multiagent/conftest.py index 85e0ef7fc..e5dd1b4f9 100644 --- a/tests/strands/multiagent/conftest.py +++ b/tests/strands/multiagent/conftest.py @@ -1,7 +1,6 @@ import pytest -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import BeforeNodeCallEvent, HookProvider @pytest.fixture diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index ab2d86e70..cd750865e 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -6,8 +6,7 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import AgentInitializedEvent +from strands.hooks import AgentInitializedEvent, BeforeNodeCallEvent from strands.hooks.registry import HookProvider, HookRegistry from strands.interrupt import Interrupt, _InterruptState from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index aae11b709..75ef97a25 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -6,7 +6,7 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import BeforeNodeCallEvent from strands.hooks.registry import HookRegistry from strands.interrupt import Interrupt, _InterruptState from strands.multiagent.base import Status diff --git a/tests_integ/hooks/multiagent/test_cancel.py b/tests_integ/hooks/multiagent/test_cancel.py index 9267330b7..ae3008861 100644 --- a/tests_integ/hooks/multiagent/test_cancel.py +++ b/tests_integ/hooks/multiagent/test_cancel.py @@ -1,8 +1,7 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.types._events import MultiAgentNodeCancelEvent diff --git 
a/tests_integ/hooks/multiagent/test_events.py b/tests_integ/hooks/multiagent/test_events.py index e8039444f..3a10b74c1 100644 --- a/tests_integ/hooks/multiagent/test_events.py +++ b/tests_integ/hooks/multiagent/test_events.py @@ -1,14 +1,14 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, + HookProvider, MultiAgentInitializedEvent, ) -from strands.hooks import HookProvider from strands.multiagent import GraphBuilder, Swarm diff --git a/tests_integ/interrupts/multiagent/test_hook.py b/tests_integ/interrupts/multiagent/test_hook.py index 9350b3535..53305b4e8 100644 --- a/tests_integ/interrupts/multiagent/test_hook.py +++ b/tests_integ/interrupts/multiagent/test_hook.py @@ -4,8 +4,7 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.interrupt import Interrupt from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index bab4b428f..2ccff2c12 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -4,8 +4,7 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.interrupt import Interrupt from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index e8e969af1..e9738d3d9 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -3,13 +3,13 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import ( AfterInvocationEvent, AfterModelCallEvent, AfterToolCallEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeNodeCallEvent, BeforeToolCallEvent, MessageAddedEvent, ) From b41a99bedaca93b55cb57262aeb5c109f6b2a688 Mon Sep 17 00:00:00 2001 From: Lana Zhang Date: Wed, 21 Jan 2026 11:07:18 -0500 Subject: [PATCH 42/47] Nova Sonic 2 support for BidiAgent (#1476) --- README.md | 28 ++- .../experimental/bidi/models/nova_sonic.py | 142 +++++++++++++-- .../bidi/models/test_nova_sonic.py | 165 ++++++++++++++++++ tests_integ/bidi/test_bidirectional_agent.py | 9 +- 4 files changed, 328 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index e7d1b2a7e..8e4d9d0e8 100644 --- a/README.md +++ b/README.md @@ -204,9 +204,9 @@ It's also available on GitHub via [strands-agents/tools](https://github.com/stra Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart](https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart) guide. 
**Supported Model Providers:** -- Amazon Nova Sonic (`amazon.nova-sonic-v1:0`) -- Google Gemini Live (`gemini-2.5-flash-native-audio-preview-09-2025`) -- OpenAI Realtime API (`gpt-realtime`) +- Amazon Nova Sonic (v1, v2) +- Google Gemini Live +- OpenAI Realtime API **Quick Example:** @@ -219,7 +219,7 @@ from strands.experimental.bidi.tools import stop_conversation from strands_tools import calculator async def main(): - # Create bidirectional agent with audio model + # Create bidirectional agent with Nova Sonic v2 model = BidiNovaSonicModel() agent = BidiAgent(model=model, tools=[calculator, stop_conversation]) @@ -241,7 +241,9 @@ if __name__ == "__main__": **Configuration Options:** ```python -# Configure audio settings +from strands.experimental.bidi.models import BidiNovaSonicModel + +# Configure audio settings and turn detection (v2 only) model = BidiNovaSonicModel( provider_config={ "audio": { @@ -249,6 +251,9 @@ model = BidiNovaSonicModel( "output_rate": 16000, "voice": "matthew" }, + "turn_detection": { + "endpointingSensitivity": "MEDIUM" # HIGH, MEDIUM, or LOW + }, "inference": { "max_tokens": 2048, "temperature": 0.7 @@ -263,6 +268,19 @@ audio_io = BidiAudioIO( input_buffer_size=10, output_buffer_size=10 ) + +# Text input mode (type messages instead of speaking) +text_io = BidiTextIO() +await agent.run( + inputs=[text_io.input()], # Use text input + outputs=[audio_io.output(), text_io.output()] +) + +# Multi-modal: Both audio and text input +await agent.run( + inputs=[audio_io.input(), text_io.input()], # Speak OR type + outputs=[audio_io.output(), text_io.output()] +) ``` ## Documentation diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 1c946220d..d836bde49 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -64,6 +64,10 @@ logger = logging.getLogger(__name__) +# Nova Sonic model identifiers +NOVA_SONIC_V1_MODEL_ID = "amazon.nova-sonic-v1:0" +NOVA_SONIC_V2_MODEL_ID = "amazon.nova-2-sonic-v1:0" + _NOVA_INFERENCE_CONFIG_KEYS = { "max_tokens": "maxTokens", "temperature": "temperature", @@ -110,7 +114,7 @@ class BidiNovaSonicModel(BidiModel): def __init__( self, - model_id: str = "amazon.nova-sonic-v1:0", + model_id: str = NOVA_SONIC_V2_MODEL_ID, provider_config: dict[str, Any] | None = None, client_config: dict[str, Any] | None = None, **kwargs: Any, @@ -118,19 +122,41 @@ def __init__( """Initialize Nova Sonic bidirectional model. Args: - model_id: Model identifier (default: amazon.nova-sonic-v1:0) - provider_config: Model behavior (audio, inference settings) + model_id: Model identifier (default: amazon.nova-2-sonic-v1:0) + provider_config: Model behavior configuration including: + - audio: Audio input/output settings (sample rate, voice, etc.) + - inference: Model inference settings (max_tokens, temperature, top_p) + - turn_detection: Turn detection configuration (v2 only feature) + - endpointingSensitivity: "HIGH" | "MEDIUM" | "LOW" (optional) client_config: AWS authentication (boto_session OR region, not both) **kwargs: Reserved for future parameters. + + Raises: + ValueError: If turn_detection is used with v1 model. + ValueError: If endpointingSensitivity is not HIGH, MEDIUM, or LOW. 
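+
+        Example (illustrative sketch, mirroring the README snippet):
+            model = BidiNovaSonicModel(
+                provider_config={"turn_detection": {"endpointingSensitivity": "MEDIUM"}}
+            )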
""" # Store model ID self.model_id = model_id + # Validate turn_detection configuration + provider_config = provider_config or {} + if "turn_detection" in provider_config and provider_config["turn_detection"]: + if model_id == NOVA_SONIC_V1_MODEL_ID: + raise ValueError( + f"turn_detection is only supported in Nova Sonic v2. " + f"Current model_id: {model_id}. Use {NOVA_SONIC_V2_MODEL_ID} instead." + ) + + # Validate endpointingSensitivity value if provided + sensitivity = provider_config["turn_detection"].get("endpointingSensitivity") + if sensitivity and sensitivity not in ["HIGH", "MEDIUM", "LOW"]: + raise ValueError(f"Invalid endpointingSensitivity: {sensitivity}. Must be HIGH, MEDIUM, or LOW") + # Resolve client config with defaults self._client_config = self._resolve_client_config(client_config or {}) # Resolve provider config with defaults - self.config = self._resolve_provider_config(provider_config or {}) + self.config = self._resolve_provider_config(provider_config) # Store session and region for later use self._session = self._client_config["boto_session"] @@ -182,6 +208,7 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: **config.get("audio", {}), }, "inference": config.get("inference", {}), + "turn_detection": config.get("turn_detection", {}), } return resolved @@ -269,21 +296,57 @@ def _build_initialization_events( def _log_event_type(self, nova_event: dict[str, Any]) -> None: """Log specific Nova Sonic event types for debugging.""" + # Log the full event structure for detailed debugging + event_keys = list(nova_event.keys()) + logger.debug("event_keys=<%s> | nova sonic event received", event_keys) + if "usageEvent" in nova_event: - logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"]) + usage = nova_event["usageEvent"] + logger.debug( + "input_tokens=<%s>, output_tokens=<%s>, usage_details=<%s> | nova usage event", + usage.get("totalInputTokens", 0), + usage.get("totalOutputTokens", 0), + json.dumps(usage, indent=2), + ) elif "textOutput" in nova_event: - logger.debug("nova text output received") + text_content = nova_event["textOutput"].get("content", "") + logger.debug( + "text_length=<%d>, text_preview=<%s>, text_output_details=<%s> | nova text output", + len(text_content), + text_content[:100], + json.dumps(nova_event["textOutput"], indent=2)[:500], + ) elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | nova tool use received", + "tool_name=<%s>, tool_use_id=<%s>, tool_use_details=<%s> | nova tool use received", tool_use["toolName"], tool_use["toolUseId"], + json.dumps(tool_use, indent=2)[:500], ) elif "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) + elif "completionStart" in nova_event: + completion_id = nova_event["completionStart"].get("completionId", "unknown") + logger.debug("completion_id=<%s> | nova completion started", completion_id) + elif "completionEnd" in nova_event: + completion_data = nova_event["completionEnd"] + logger.debug( + "completion_id=<%s>, stop_reason=<%s> | nova completion ended", + completion_data.get("completionId", "unknown"), + completion_data.get("stopReason", "unknown"), + ) + elif "stopReason" in nova_event: + logger.debug("stop_reason=<%s> | nova stop reason event", nova_event["stopReason"]) + else: + # Log any other event types + audio_metadata = 
self._get_audio_metadata_for_logging({"event": nova_event}) + if audio_metadata: + logger.debug("audio_byte_count=<%d> | nova sonic event with audio", audio_metadata["audio_byte_count"]) + else: + logger.debug("event_payload=<%s> | nova sonic event details", json.dumps(nova_event, indent=2)[:500]) async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive Nova Sonic events and convert to provider-agnostic format. @@ -312,14 +375,25 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: raise BidiModelTimeoutError(error.message) from error if not event_data: + logger.debug("received empty event data, continuing") continue - nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] + # Decode and parse the event + raw_bytes = event_data.value.bytes_.decode("utf-8") + logger.debug("raw_event_size=<%d> | received nova sonic event", len(raw_bytes)) + + nova_event = json.loads(raw_bytes)["event"] self._log_event_type(nova_event) model_event = self._convert_nova_event(nova_event) if model_event: + event_type = ( + model_event.get("type", "unknown") if isinstance(model_event, dict) else type(model_event).__name__ + ) + logger.debug("converted_event_type=<%s> | yielding converted event", event_type) yield model_event + else: + logger.debug("event_not_converted | nova event did not produce output event") async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: """Unified send method for all content types. Sends the given content to Nova Sonic. @@ -336,14 +410,24 @@ async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: raise RuntimeError("model not started | call start before sending") if isinstance(content, BidiTextInputEvent): + text_preview = content.text[:100] if len(content.text) > 100 else content.text + logger.debug("text_length=<%d>, text_preview=<%s> | sending text content", len(content.text), text_preview) await self._send_text_content(content.text) elif isinstance(content, BidiAudioInputEvent): + audio_size = len(base64.b64decode(content.audio)) if content.audio else 0 + logger.debug("audio_bytes=<%d>, format=<%s> | sending audio content", audio_size, content.format) await self._send_audio_content(content) elif isinstance(content, ToolResultEvent): tool_result = content.get("tool_result") if tool_result: + logger.debug( + "tool_use_id=<%s>, content_blocks=<%d> | sending tool result", + tool_result.get("toolUseId", "unknown"), + len(tool_result.get("content", [])), + ) await self._send_tool_result(tool_result) else: + logger.error("content_type=<%s> | unsupported content type", type(content)) raise ValueError(f"content_type={type(content)} | content not supported") async def _start_audio_connection(self) -> None: @@ -583,7 +667,15 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" inference_config = {_NOVA_INFERENCE_CONFIG_KEYS[key]: value for key, value in self.config["inference"].items()} - return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": inference_config}}}) + + session_start_event: dict[str, Any] = {"event": {"sessionStart": {"inferenceConfiguration": inference_config}}} + + # Add turn detection configuration if provided (v2 feature) + turn_detection_config = self.config.get("turn_detection", {}) + if turn_detection_config: + session_start_event["event"]["sessionStart"]["turnDetectionConfiguration"] = turn_detection_config + + return 
json.dumps(session_start_event) def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" @@ -749,6 +841,37 @@ def _get_connection_end_event(self) -> str: """Generate connection end event.""" return json.dumps({"event": {"connectionEnd": {}}}) + def _get_audio_metadata_for_logging(self, event_dict: dict[str, Any]) -> dict[str, Any]: + """Extract audio metadata from event dict for logging. + + Instead of logging large base64-encoded audio data, this extracts metadata + like byte count to verify audio presence without bloating logs. + + Args: + event_dict: The event dictionary to process. + + Returns: + A dict with audio metadata (byte_count) if audio is present, empty dict otherwise. + """ + metadata: dict[str, Any] = {} + + if "event" in event_dict: + event_data = event_dict["event"] + + # Handle contentStart events with audio + if "contentStart" in event_data and "content" in event_data["contentStart"]: + content = event_data["contentStart"]["content"] + if "audio" in content and "bytes" in content["audio"]: + metadata["audio_byte_count"] = len(content["audio"]["bytes"]) + + # Handle content events with audio + if "content" in event_data and "content" in event_data["content"]: + content = event_data["content"]["content"] + if "audio" in content and "bytes" in content["audio"]: + metadata["audio_byte_count"] = len(content["audio"]["bytes"]) + + return metadata + async def _send_nova_events(self, events: list[str]) -> None: """Send event JSON string to Nova Sonic stream. @@ -764,4 +887,3 @@ async def _send_nova_events(self, events: list[str]) -> None: value=BidirectionalInputPayloadPart(bytes_=bytes_data) ) await self._stream.input_stream.send(chunk) - logger.debug("nova sonic event sent successfully") diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py index 7435d4ad2..14630875b 100644 --- a/tests/strands/experimental/bidi/models/test_nova_sonic.py +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -23,6 +23,8 @@ from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.nova_sonic import ( BidiNovaSonicModel, + NOVA_SONIC_V1_MODEL_ID, + NOVA_SONIC_V2_MODEL_ID, ) from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, @@ -579,6 +581,169 @@ async def test_default_audio_rates_in_events(model_id, boto_session): assert result.format == "pcm" +# Nova Sonic v2 Support Tests + + +def test_nova_sonic_model_constants(): + """Test that Nova Sonic model ID constants are correctly defined.""" + assert NOVA_SONIC_V1_MODEL_ID == "amazon.nova-sonic-v1:0" + assert NOVA_SONIC_V2_MODEL_ID == "amazon.nova-2-sonic-v1:0" + + +@pytest.mark.asyncio +async def test_nova_sonic_v1_instantiation(boto_session, mock_client): + """Test direct instantiation with Nova Sonic v1 model ID.""" + _ = mock_client # Ensure mock is active + + # Test default creation + model = BidiNovaSonicModel(model_id=NOVA_SONIC_V1_MODEL_ID, client_config={"boto_session": boto_session}) + assert model.model_id == NOVA_SONIC_V1_MODEL_ID + assert model.region == "us-east-1" + + # Test with custom config + provider_config = {"audio": {"voice": "joanna", "output_rate": 24000}} + client_config = {"boto_session": boto_session} + model_custom = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + + assert model_custom.model_id == 
NOVA_SONIC_V1_MODEL_ID + assert model_custom.config["audio"]["voice"] == "joanna" + assert model_custom.config["audio"]["output_rate"] == 24000 + + +@pytest.mark.asyncio +async def test_nova_sonic_v2_instantiation(boto_session, mock_client): + """Test direct instantiation with Nova Sonic v2 model ID.""" + _ = mock_client # Ensure mock is active + + # Test default creation + model = BidiNovaSonicModel(model_id=NOVA_SONIC_V2_MODEL_ID, client_config={"boto_session": boto_session}) + assert model.model_id == NOVA_SONIC_V2_MODEL_ID + assert model.region == "us-east-1" + + # Test with custom config + provider_config = {"audio": {"voice": "ruth", "input_rate": 48000}, "inference": {"temperature": 0.8}} + client_config = {"boto_session": boto_session} + model_custom = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + + assert model_custom.model_id == NOVA_SONIC_V2_MODEL_ID + assert model_custom.config["audio"]["voice"] == "ruth" + assert model_custom.config["audio"]["input_rate"] == 48000 + assert model_custom.config["inference"]["temperature"] == 0.8 + + +@pytest.mark.asyncio +async def test_nova_sonic_v1_v2_compatibility(boto_session, mock_client): + """Test that v1 and v2 models have the same config structure and behavior.""" + _ = mock_client # Ensure mock is active + + # Create both models with same config + provider_config = {"audio": {"voice": "matthew"}} + client_config = {"boto_session": boto_session} + + model_v1 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + model_v2 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + + # Both should have the same config structure + assert model_v1.config["audio"]["voice"] == model_v2.config["audio"]["voice"] + assert model_v1.region == model_v2.region + + # Only model_id should differ + assert model_v1.model_id != model_v2.model_id + assert model_v1.model_id == NOVA_SONIC_V1_MODEL_ID + assert model_v2.model_id == NOVA_SONIC_V2_MODEL_ID + + +@pytest.mark.asyncio +async def test_backward_compatibility(boto_session, mock_client): + """Test that existing code continues to work (backward compatibility).""" + _ = mock_client # Ensure mock is active + + # Test that default behavior now uses v2 (updated default) + model_default = BidiNovaSonicModel(client_config={"boto_session": boto_session}) + assert model_default.model_id == NOVA_SONIC_V2_MODEL_ID + + # Test that existing explicit v1 usage still works + model_explicit_v1 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, client_config={"boto_session": boto_session} + ) + assert model_explicit_v1.model_id == NOVA_SONIC_V1_MODEL_ID + + # Test that explicit v2 usage works + model_explicit_v2 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, client_config={"boto_session": boto_session} + ) + assert model_explicit_v2.model_id == NOVA_SONIC_V2_MODEL_ID + + +@pytest.mark.asyncio +async def test_turn_detection_v1_validation(boto_session, mock_client): + """Test that turn_detection raises error when used with v1 model.""" + _ = mock_client # Ensure mock is active + + # Test that turn_detection with v1 raises ValueError + with pytest.raises(ValueError, match="turn_detection is only supported in Nova Sonic v2"): + BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": "MEDIUM"}}, + client_config={"boto_session": 
boto_session}, + ) + + # Test that turn_detection with v2 works fine + model_v2 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": "MEDIUM"}}, + client_config={"boto_session": boto_session}, + ) + assert model_v2.config["turn_detection"]["endpointingSensitivity"] == "MEDIUM" + + # Test that empty turn_detection dict doesn't raise error for v1 + model_v1_empty = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, + provider_config={"turn_detection": {}}, + client_config={"boto_session": boto_session}, + ) + assert model_v1_empty.model_id == NOVA_SONIC_V1_MODEL_ID + + +@pytest.mark.asyncio +async def test_turn_detection_sensitivity_validation(boto_session, mock_client): + """Test that endpointingSensitivity is validated at initialization.""" + _ = mock_client # Ensure mock is active + + # Test invalid sensitivity value raises ValueError at init + with pytest.raises(ValueError, match="Invalid endpointingSensitivity.*Must be HIGH, MEDIUM, or LOW"): + BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": "INVALID"}}, + client_config={"boto_session": boto_session}, + ) + + # Test valid sensitivity values work + for sensitivity in ["HIGH", "MEDIUM", "LOW"]: + model = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": sensitivity}}, + client_config={"boto_session": boto_session}, + ) + assert model.config["turn_detection"]["endpointingSensitivity"] == sensitivity + + # Test that turn_detection without sensitivity works (sensitivity is optional) + model_no_sensitivity = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {}}, + client_config={"boto_session": boto_session}, + ) + assert "endpointingSensitivity" not in model_no_sensitivity.config["turn_detection"] + + # Error Handling Tests @pytest.mark.asyncio async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py index 61cf78723..243db46ac 100644 --- a/tests_integ/bidi/test_bidirectional_agent.py +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -55,11 +55,18 @@ def calculator(operation: str, x: float, y: float) -> float: PROVIDER_CONFIGS = { "nova_sonic": { "model_class": BidiNovaSonicModel, - "model_kwargs": {"region": "us-east-1"}, + "model_kwargs": {"region": "us-east-1"}, # Uses v2 by default "silence_duration": 2.5, # Nova Sonic needs 2+ seconds of silence "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], "skip_reason": "AWS credentials not available", }, + "nova_sonic_v1": { + "model_class": BidiNovaSonicModel, + "model_kwargs": {"model_id": "amazon.nova-sonic-v1:0", "region": "us-east-1"}, + "silence_duration": 2.5, # Nova Sonic v1 needs 2+ seconds of silence + "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + "skip_reason": "AWS credentials not available", + }, "openai": { "model_class": BidiOpenAIRealtimeModel, "model_kwargs": { From f87925b9383a8ead59f0138e55e8412478c70928 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 21 Jan 2026 18:09:29 +0200 Subject: [PATCH 43/47] fix(tests): reduce flakiness in guardrail redact output test (#1505) --- pyproject.toml | 1 + tests_integ/conftest.py | 116 +++++++++++++++++++++++++ tests_integ/test_bedrock_guardrails.py | 13 ++- 3 files changed, 126 insertions(+), 4 deletions(-) diff --git 
a/pyproject.toml b/pyproject.toml index b49c74d1b..a16132881 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ dev = [ "pytest-asyncio>=1.0.0,<1.4.0", "pytest-xdist>=3.0.0,<4.0.0", "ruff>=0.13.0,<0.15.0", + "tenacity>=9.0.0,<10.0.0", ] [project.urls] diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index 26453e1f7..9de00089b 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -1,13 +1,129 @@ +import functools import json import logging import os +from collections.abc import Callable, Sequence import boto3 import pytest +from tenacity import RetryCallState, RetryError, Retrying, stop_after_attempt, wait_exponential logger = logging.getLogger(__name__) +# Type alias for retry conditions +RetryCondition = type[BaseException] | Callable[[BaseException], bool] | str + + +def _should_retry_exception(exc: BaseException, conditions: Sequence[RetryCondition]) -> bool: + """Check if exception matches any of the given retry conditions. + + Args: + exc: The exception to check + conditions: Sequence of conditions, each can be: + - Exception type: retry if isinstance(exc, condition) + - Callable: retry if condition(exc) returns True + - str: retry if string is in str(exc) + """ + for condition in conditions: + if isinstance(condition, type) and issubclass(condition, BaseException): + if isinstance(exc, condition): + return True + elif callable(condition): + if condition(exc): + return True + elif isinstance(condition, str): + if condition in str(exc): + return True + return False + + +_RETRY_ON_ANY: Sequence[RetryCondition] = (lambda _: True,) + + +def retry_on_flaky( + reason: str, + *, + max_attempts: int = 3, + wait_multiplier: float = 1, + wait_max: float = 10, + retry_on: Sequence[RetryCondition] = _RETRY_ON_ANY, +) -> Callable: + """Decorator to retry flaky integration tests that fail due to external factors. + + WHEN TO USE: + - External service instability (API rate limits, transient network errors) + - Non-deterministic LLM responses that occasionally fail assertions + - Resource contention in shared test environments + - Known intermittent issues with third-party dependencies + + WHEN NOT TO USE: + - Actual bugs in the code under test (fix the bug instead) + - Deterministic failures (these indicate real problems) + - Unit tests (flakiness in unit tests usually indicates a design issue) + - To mask consistently failing tests (investigate root cause first) + + Prefer using specific retry_on conditions over retrying on any exception + to avoid masking real bugs. + + Args: + reason: Required explanation of why this test is flaky and needs retries. + This should describe the source of non-determinism (e.g., "LLM responses + may vary" or "External API has intermittent rate limits"). + max_attempts: Maximum number of retry attempts (default: 3) + wait_multiplier: Multiplier for exponential backoff in seconds (default: 1) + wait_max: Maximum wait time between retries in seconds (default: 10) + retry_on: Conditions for when to retry. Defaults to retrying on any exception. + Each condition can be: + - Exception type: e.g., ValueError, TimeoutError + - Callable: e.g., lambda e: "timeout" in str(e).lower() + - str: substring to match in exception message + + Usage: + # Retry on any failure + @retry_on_flaky("LLM responses are non-deterministic") + def test_something(): + ... + + # Retry only on specific exception types + @retry_on_flaky("Network calls may fail transiently", retry_on=[TimeoutError, ConnectionError]) + def test_network_call(): + ... 
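+
+    # (Illustrative) Widen the retry budget for slower external services
+    @retry_on_flaky("Provider occasionally throttles", max_attempts=5, wait_max=30)
+    def test_throttled_call():
+        ...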
+ + # Retry on string patterns in exception message + @retry_on_flaky("Service has intermittent availability", retry_on=["Service unavailable", "Status 503"]) + def test_service_call(): + ... + """ + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + def should_retry(retry_state: RetryCallState) -> bool: + if retry_state.outcome is None or not retry_state.outcome.failed: + return False + exc = retry_state.outcome.exception() + if exc is None: + return False + return _should_retry_exception(exc, retry_on) + + try: + for attempt in Retrying( + stop=stop_after_attempt(max_attempts), + wait=wait_exponential(multiplier=wait_multiplier, max=wait_max), + retry=should_retry, + reraise=True, + ): + with attempt: + return func(*args, **kwargs) + except RetryError: + raise + + return wrapper + + return decorator + + def pytest_sessionstart(session): _load_api_keys_from_secrets_manager() diff --git a/tests_integ/test_bedrock_guardrails.py b/tests_integ/test_bedrock_guardrails.py index 058597026..56edc3fc4 100644 --- a/tests_integ/test_bedrock_guardrails.py +++ b/tests_integ/test_bedrock_guardrails.py @@ -8,6 +8,7 @@ from strands import Agent, tool from strands.models.bedrock import BedrockModel from strands.session.file_session_manager import FileSessionManager +from tests_integ.conftest import retry_on_flaky BLOCKED_INPUT = "BLOCKED_INPUT" BLOCKED_OUTPUT = "BLOCKED_OUTPUT" @@ -170,9 +171,11 @@ def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processi ) +@retry_on_flaky("LLM may mention CACTUS unprompted, triggering guardrail on response2") @pytest.mark.parametrize("guardrail_trace", ["enabled", "enabled_full"]) @pytest.mark.parametrize("processing_mode", ["sync", "async"]) def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processing_mode, guardrail_trace): + """Test guardrail output intervention with redaction.""" REDACT_MESSAGE = "Redacted." bedrock_model = BedrockModel( guardrail_id=bedrock_guardrail, @@ -182,23 +185,25 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi guardrail_redact_output=True, guardrail_redact_output_message=REDACT_MESSAGE, region_name="us-east-1", + temperature=0, # Use deterministic responses to reduce flakiness ) agent = Agent( model=bedrock_model, - system_prompt="When asked to say the word, say CACTUS.", + system_prompt="When asked to say the word, say CACTUS. Otherwise, respond normally.", callback_handler=None, load_tools_from_directory=False, ) response1 = agent("Say the word.") - response2 = agent("Hello!") + # Use a completely unrelated prompt to reduce likelihood of model volunteering CACTUS + response2 = agent("What is 2+2? Reply with only the number.") assert response1.stop_reason == "guardrail_intervened" """ - In async streaming: The buffering is non-blocking. - Tokens are streamed while Guardrails processes the buffered content in the background. + In async streaming: The buffering is non-blocking. + Tokens are streamed while Guardrails processes the buffered content in the background. This means the response may be returned before Guardrails has finished processing. As a result, we cannot guarantee that the REDACT_MESSAGE is in the response. 
""" From d851d0604bbf9af9e3157a6a117526904e5f6b7e Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 9 Dec 2025 10:04:45 -0500 Subject: [PATCH 44/47] ci: add workflow for lambda layer publish and yank --- .github/workflows/LAMDBA_LAYERS_SOP.md | 43 +++++ .github/workflows/publish-lambda-layer.yml | 202 +++++++++++++++++++++ .github/workflows/yank-lambda-layer.yml | 81 +++++++++ 3 files changed, 326 insertions(+) create mode 100644 .github/workflows/LAMDBA_LAYERS_SOP.md create mode 100644 .github/workflows/publish-lambda-layer.yml create mode 100644 .github/workflows/yank-lambda-layer.yml diff --git a/.github/workflows/LAMDBA_LAYERS_SOP.md b/.github/workflows/LAMDBA_LAYERS_SOP.md new file mode 100644 index 000000000..4ac96a77d --- /dev/null +++ b/.github/workflows/LAMDBA_LAYERS_SOP.md @@ -0,0 +1,43 @@ +# Lambda Layers Standard Operating Procedures (SOP) + +## Overview + +This document defines the standard operating procedures for managing Strands Agents Lambda layers across all AWS regions, Python versions, and architectures. + +**Total: 136 individual Lambda layers** (17 regions × 2 architectures × 4 Python versions). All variants must maintain the same layer version number for each PyPI package version, with only one row per PyPI version appearing in documentation. + +## Deployment Process + +### 1. Initial Deployment +1. Run workflow with ALL options selected (default) +2. Specify PyPI package version +3. Type "Create Lambda Layer {package_version}" to confirm +4. All 136 individual layers deploy in parallel (4 Python × 2 arch × 17 regions) +5. Each layer gets its own unique name: `strands-agents-py{PYTHON_VERSION}-{ARCH}` + +### 2. Version Buffering for New Variants +When adding new variants (new Python version, architecture, or region): + +1. **Determine target layer version**: Check existing variants to find the highest layer version +2. **Buffer deployment**: Deploy new variants multiple times until layer version matches existing variants +3. **Example**: If existing variants are at layer version 5, deploy new variant 5 times to reach version 5 + +### 3. Handling Transient Failures +When some regions fail during deployment: + +1. **Identify failed regions**: Check which combinations didn't complete successfully +2. **Targeted redeployment**: Use specific region/arch/Python inputs to redeploy failed combinations +3. **Version alignment**: Continue deploying until all variants reach the same layer version +4. **Verification**: Confirm all combinations have identical layer versions before updating docs + +## Yank Process + +### Yank Procedure +1. Use the `yank_lambda_layer` GitHub action workflow +2. Specify the layer version to yank +3. Type "Yank Lambda Layer {layer_version}" to confirm +4. **Full yank**: Run with ALL options selected (default) to yank all 136 variants OR **Partial yank**: Specify Python versions, architectures, and regions for targeted yanking +6. Update documentation +7. **Communication**: Notify users through appropriate channels + +**Note**: Yanking deletes layer versions completely. Existing Lambda functions using the layer continue to work, but new functions cannot use the yanked version. 
\ No newline at end of file diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml new file mode 100644 index 000000000..9e2702819 --- /dev/null +++ b/.github/workflows/publish-lambda-layer.yml @@ -0,0 +1,202 @@ +name: Publish PyPI Package to Lambda Layer + +on: + workflow_dispatch: + inputs: + package_version: + description: 'Package version to download' + required: true + type: string + layer_version: + description: 'Layer version' + required: true + type: string + python_version: + description: 'Python version' + required: true + default: 'ALL' + type: choice + options: ['ALL', '3.10', '3.11', '3.12', '3.13'] + architecture: + description: 'Architecture' + required: true + default: 'ALL' + type: choice + options: ['ALL', 'x86_64', 'aarch64'] + region: + description: 'AWS region' + required: true + default: 'ALL' + type: choice + # Only non opt-in regions included for now + options: ['ALL', 'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', 'ap-south-1', 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-southeast-1', 'ap-southeast-2', 'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-west-3', 'eu-north-1', 'sa-east-1'] + confirm: + description: 'Type "Create Lambda Layer {PyPI version}-layer{layer version}" to confirm publishing the layer' + required: true + type: string + +env: + BUCKET_NAME: strands-agents-lambda-layer + +jobs: + validate: + runs-on: ubuntu-latest + steps: + - name: Validate confirmation + run: | + CONFIRM="${{ inputs.confirm }}" + EXPECTED="Create Lambda Layer ${{ inputs.package_version }}-layer${{ inputs.layer_version }}" + if [ "$CONFIRM" != "$EXPECTED" ]; then + echo "Confirmation failed. You must type exactly '$EXPECTED' to proceed." + exit 1 + fi + echo "Confirmation validated" + + create-buckets: + needs: validate + runs-on: ubuntu-latest + strategy: + matrix: + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + permissions: + id-token: write + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Create S3 bucket + run: | + REGION="${{ matrix.region }}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + REGIONAL_BUCKET="${{ env.BUCKET_NAME }}-${ACCOUNT_ID}-${REGION}" + + if ! 
aws s3api head-bucket --bucket "$REGIONAL_BUCKET" 2>/dev/null; then + if [ "$REGION" = "us-east-1" ]; then + aws s3api create-bucket --bucket "$REGIONAL_BUCKET" --region "$REGION" 2>/dev/null || echo "Bucket $REGIONAL_BUCKET already exists" + else + aws s3api create-bucket --bucket "$REGIONAL_BUCKET" --region "$REGION" --create-bucket-configuration LocationConstraint="$REGION" 2>/dev/null || echo "Bucket $REGIONAL_BUCKET already exists" + fi + echo "S3 bucket ready: $REGIONAL_BUCKET" + else + echo "S3 bucket already exists: $REGIONAL_BUCKET" + fi + + package-and-upload: + needs: create-buckets + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + + steps: + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Create layer directory structure + run: | + mkdir -p layer/python + + - name: Download and install package + run: | + pip install strands-agents==${{ inputs.package_version }} \ + --python-version ${{ matrix.python-version }} \ + --platform manylinux2014_${{ matrix.architecture }} \ + -t layer/python/ \ + --only-binary=:all: + + - name: Create layer zip + run: | + cd layer + zip -r ../lambda-layer.zip . 
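+          # The resulting zip has python/ at its root, which is the layout
+          # Lambda expects for Python layers: layers are extracted under /opt,
+          # and /opt/python is added to the Python import path at runtime.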
+ + - name: Upload to S3 + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" + REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + BUCKET_NAME="${{ env.BUCKET_NAME }}-${ACCOUNT_ID}-${REGION}" + LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" + + aws s3 cp lambda-layer.zip "s3://$BUCKET_NAME/$LAYER_KEY" --region "$REGION" + echo "Uploaded layer to s3://$BUCKET_NAME/$LAYER_KEY" + + publish-layer: + needs: package-and-upload + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Publish layer + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" + REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + REGION_BUCKET="${{ env.BUCKET_NAME }}-${ACCOUNT_ID}-${REGION}" + LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" + + DESCRIPTION="PyPI package: strands-agents v${{ inputs.package_version }} (Python $PYTHON_VERSION, $ARCH)" + + # Set compatible architecture based on matrix architecture + if [ "$ARCH" = "x86_64" ]; then + COMPATIBLE_ARCH="x86_64" + else + COMPATIBLE_ARCH="arm64" + fi + + LAYER_OUTPUT=$(aws lambda publish-layer-version \ + --layer-name $LAYER_NAME \ + --description "$DESCRIPTION" \ + --content S3Bucket=$REGION_BUCKET,S3Key=$LAYER_KEY \ + --compatible-runtimes python${{ matrix.python-version }} \ + --compatible-architectures $COMPATIBLE_ARCH \ + --region "$REGION" \ + --license-info Apache-2.0 \ + --output json) + + LAYER_ARN=$(echo "$LAYER_OUTPUT" | jq -r '.LayerArn') + LAYER_VERSION=$(echo "$LAYER_OUTPUT" | jq -r '.Version') + + echo "Published layer version $LAYER_VERSION with ARN: $LAYER_ARN in region $REGION" + + aws lambda add-layer-version-permission \ + --layer-name $LAYER_NAME \ + --version-number $LAYER_VERSION \ + --statement-id public \ + --action lambda:GetLayerVersion \ + --principal '*' \ + --region "$REGION" + + echo "Successfully published layer version $LAYER_VERSION in region $REGION" \ No newline at end of file diff --git a/.github/workflows/yank-lambda-layer.yml b/.github/workflows/yank-lambda-layer.yml new file mode 100644 index 000000000..27927a862 --- /dev/null +++ b/.github/workflows/yank-lambda-layer.yml @@ -0,0 +1,81 @@ +name: Yank Lambda Layer + +on: + workflow_dispatch: + inputs: + layer_version: + description: 'Layer version to yank' + required: true + 
type: string + python_version: + description: 'Python version' + required: true + default: 'ALL' + type: choice + options: ['ALL', '3.10', '3.11', '3.12', '3.13'] + architecture: + description: 'Architecture' + required: true + default: 'ALL' + type: choice + options: ['ALL', 'x86_64', 'aarch64'] + region: + description: 'AWS region' + required: true + default: 'ALL' + type: choice + # Only non opt-in regions included for now + options: ['ALL', 'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', 'ap-south-1', 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-southeast-1', 'ap-southeast-2', 'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-west-3', 'eu-north-1', 'sa-east-1'] + confirm: + description: 'Type "Yank Lambda Layer {layer version}" to confirm yanking the layer' + required: true + type: string + +jobs: + yank-layer: + runs-on: ubuntu-latest + continue-on-error: true + strategy: + fail-fast: false + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + + steps: + - name: Validate confirmation + run: | + CONFIRM="${{ inputs.confirm }}" + EXPECTED="Yank Lambda Layer ${{ inputs.layer_version }}" + if [ "$CONFIRM" != "$EXPECTED" ]; then + echo "Confirmation failed. You must type exactly '$EXPECTED' to proceed." 
+ exit 1 + fi + echo "Confirmation validated" + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Yank layer + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" + REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + LAYER_VERSION="${{ inputs.layer_version }}" + + echo "Attempting to yank layer $LAYER_NAME version $LAYER_VERSION in region $REGION" + + # Delete the layer version completely + aws lambda delete-layer-version \ + --layer-name $LAYER_NAME \ + --version-number $LAYER_VERSION \ + --region "$REGION" + + echo "Completed yank attempt for layer $LAYER_NAME version $LAYER_VERSION in region $REGION" \ No newline at end of file From 8f6e82fd6597c70c4c552b37ae65531da3b4a7c7 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 16 Jan 2026 15:21:15 -0500 Subject: [PATCH 45/47] fix: update bucket name to match deployed infra --- .github/workflows/publish-lambda-layer.yml | 41 +---------- .github/workflows/yank-lambda-layer.yml | 81 ---------------------- 2 files changed, 3 insertions(+), 119 deletions(-) delete mode 100644 .github/workflows/yank-lambda-layer.yml diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml index 9e2702819..f21e88fd7 100644 --- a/.github/workflows/publish-lambda-layer.yml +++ b/.github/workflows/publish-lambda-layer.yml @@ -35,9 +35,6 @@ on: required: true type: string -env: - BUCKET_NAME: strands-agents-lambda-layer - jobs: validate: runs-on: ubuntu-latest @@ -52,40 +49,8 @@ jobs: fi echo "Confirmation validated" - create-buckets: - needs: validate - runs-on: ubuntu-latest - strategy: - matrix: - region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} - permissions: - id-token: write - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} - aws-region: ${{ matrix.region }} - - - name: Create S3 bucket - run: | - REGION="${{ matrix.region }}" - ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) - REGIONAL_BUCKET="${{ env.BUCKET_NAME }}-${ACCOUNT_ID}-${REGION}" - - if ! 
aws s3api head-bucket --bucket "$REGIONAL_BUCKET" 2>/dev/null; then - if [ "$REGION" = "us-east-1" ]; then - aws s3api create-bucket --bucket "$REGIONAL_BUCKET" --region "$REGION" 2>/dev/null || echo "Bucket $REGIONAL_BUCKET already exists" - else - aws s3api create-bucket --bucket "$REGIONAL_BUCKET" --region "$REGION" --create-bucket-configuration LocationConstraint="$REGION" 2>/dev/null || echo "Bucket $REGIONAL_BUCKET already exists" - fi - echo "S3 bucket ready: $REGIONAL_BUCKET" - else - echo "S3 bucket already exists: $REGIONAL_BUCKET" - fi - package-and-upload: - needs: create-buckets + needs: validate runs-on: ubuntu-latest strategy: matrix: @@ -132,7 +97,7 @@ jobs: REGION="${{ matrix.region }}" LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) - BUCKET_NAME="${{ env.BUCKET_NAME }}-${ACCOUNT_ID}-${REGION}" + BUCKET_NAME="strands-layer-${ACCOUNT_ID}-${{ vars.BUCKET_SALT }}-${REGION}" LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" aws s3 cp lambda-layer.zip "s3://$BUCKET_NAME/$LAYER_KEY" --region "$REGION" @@ -164,7 +129,7 @@ jobs: REGION="${{ matrix.region }}" LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) - REGION_BUCKET="${{ env.BUCKET_NAME }}-${ACCOUNT_ID}-${REGION}" + REGION_BUCKET="strands-layer-${ACCOUNT_ID}-${{ vars.BUCKET_SALT }}-${REGION}" LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" DESCRIPTION="PyPI package: strands-agents v${{ inputs.package_version }} (Python $PYTHON_VERSION, $ARCH)" diff --git a/.github/workflows/yank-lambda-layer.yml b/.github/workflows/yank-lambda-layer.yml deleted file mode 100644 index 27927a862..000000000 --- a/.github/workflows/yank-lambda-layer.yml +++ /dev/null @@ -1,81 +0,0 @@ -name: Yank Lambda Layer - -on: - workflow_dispatch: - inputs: - layer_version: - description: 'Layer version to yank' - required: true - type: string - python_version: - description: 'Python version' - required: true - default: 'ALL' - type: choice - options: ['ALL', '3.10', '3.11', '3.12', '3.13'] - architecture: - description: 'Architecture' - required: true - default: 'ALL' - type: choice - options: ['ALL', 'x86_64', 'aarch64'] - region: - description: 'AWS region' - required: true - default: 'ALL' - type: choice - # Only non opt-in regions included for now - options: ['ALL', 'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', 'ap-south-1', 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-southeast-1', 'ap-southeast-2', 'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-west-3', 'eu-north-1', 'sa-east-1'] - confirm: - description: 'Type "Yank Lambda Layer {layer version}" to confirm yanking the layer' - required: true - type: string - -jobs: - yank-layer: - runs-on: ubuntu-latest - continue-on-error: true - strategy: - fail-fast: false - matrix: - python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} - architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} - region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", 
"eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} - - permissions: - id-token: write - - steps: - - name: Validate confirmation - run: | - CONFIRM="${{ inputs.confirm }}" - EXPECTED="Yank Lambda Layer ${{ inputs.layer_version }}" - if [ "$CONFIRM" != "$EXPECTED" ]; then - echo "Confirmation failed. You must type exactly '$EXPECTED' to proceed." - exit 1 - fi - echo "Confirmation validated" - - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} - aws-region: ${{ matrix.region }} - - - name: Yank layer - run: | - PYTHON_VERSION="${{ matrix.python-version }}" - ARCH="${{ matrix.architecture }}" - REGION="${{ matrix.region }}" - LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" - LAYER_VERSION="${{ inputs.layer_version }}" - - echo "Attempting to yank layer $LAYER_NAME version $LAYER_VERSION in region $REGION" - - # Delete the layer version completely - aws lambda delete-layer-version \ - --layer-name $LAYER_NAME \ - --version-number $LAYER_VERSION \ - --region "$REGION" - - echo "Completed yank attempt for layer $LAYER_NAME version $LAYER_VERSION in region $REGION" \ No newline at end of file From 112453491ae6ee37b6ed2642b26069380f066f15 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 16 Jan 2026 15:27:13 -0500 Subject: [PATCH 46/47] remove yank from SOP --- .github/workflows/LAMDBA_LAYERS_SOP.md | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/.github/workflows/LAMDBA_LAYERS_SOP.md b/.github/workflows/LAMDBA_LAYERS_SOP.md index 4ac96a77d..1cf58a614 100644 --- a/.github/workflows/LAMDBA_LAYERS_SOP.md +++ b/.github/workflows/LAMDBA_LAYERS_SOP.md @@ -28,16 +28,4 @@ When some regions fail during deployment: 1. **Identify failed regions**: Check which combinations didn't complete successfully 2. **Targeted redeployment**: Use specific region/arch/Python inputs to redeploy failed combinations 3. **Version alignment**: Continue deploying until all variants reach the same layer version -4. **Verification**: Confirm all combinations have identical layer versions before updating docs - -## Yank Process - -### Yank Procedure -1. Use the `yank_lambda_layer` GitHub action workflow -2. Specify the layer version to yank -3. Type "Yank Lambda Layer {layer_version}" to confirm -4. **Full yank**: Run with ALL options selected (default) to yank all 136 variants OR **Partial yank**: Specify Python versions, architectures, and regions for targeted yanking -6. Update documentation -7. **Communication**: Notify users through appropriate channels - -**Note**: Yanking deletes layer versions completely. Existing Lambda functions using the layer continue to work, but new functions cannot use the yanked version. \ No newline at end of file +4. 
**Verification**: Confirm all combinations have identical layer versions before updating docs \ No newline at end of file From d7aa0fbf8f606d6229b0da0cc252a186d9e3666e Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 16 Jan 2026 15:54:35 -0500 Subject: [PATCH 47/47] rename vars.BUCKET_SALT to secrets.STRANDS_LAMBDA_LAYER_BUCKET_SALT --- .github/workflows/publish-lambda-layer.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml index f21e88fd7..4211d715f 100644 --- a/.github/workflows/publish-lambda-layer.yml +++ b/.github/workflows/publish-lambda-layer.yml @@ -97,7 +97,7 @@ jobs: REGION="${{ matrix.region }}" LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) - BUCKET_NAME="strands-layer-${ACCOUNT_ID}-${{ vars.BUCKET_SALT }}-${REGION}" + BUCKET_NAME="strands-layer-${ACCOUNT_ID}-${{ secrets.STRANDS_LAMBDA_LAYER_BUCKET_SALT }}-${REGION}" LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" aws s3 cp lambda-layer.zip "s3://$BUCKET_NAME/$LAYER_KEY" --region "$REGION" @@ -129,7 +129,7 @@ jobs: REGION="${{ matrix.region }}" LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) - REGION_BUCKET="strands-layer-${ACCOUNT_ID}-${{ vars.BUCKET_SALT }}-${REGION}" + REGION_BUCKET="strands-layer-${ACCOUNT_ID}-${{ secrets.STRANDS_LAMBDA_LAYER_BUCKET_SALT }}-${REGION}" LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" DESCRIPTION="PyPI package: strands-agents v${{ inputs.package_version }} (Python $PYTHON_VERSION, $ARCH)"