diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py
index d5f88d09a..be6a96549 100644
--- a/src/strands/models/__init__.py
+++ b/src/strands/models/__init__.py
@@ -7,12 +7,13 @@
 from . import bedrock, model
 from .bedrock import BedrockModel
-from .model import Model
+from .model import CacheConfig, Model
 
 __all__ = [
     "bedrock",
     "model",
     "BedrockModel",
+    "CacheConfig",
     "Model",
 ]
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 8e1558ca7..45987387f 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -28,7 +28,7 @@
 from ..types.streaming import CitationsDelta, StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
 from ._validation import validate_config_keys
-from .model import Model
+from .model import CacheConfig, Model
 
 logger = logging.getLogger(__name__)
 
@@ -72,7 +72,8 @@ class BedrockConfig(TypedDict, total=False):
         additional_args: Any additional arguments to include in the request
         additional_request_fields: Additional fields to include in the Bedrock request
         additional_response_field_paths: Additional response field paths to extract
-        cache_prompt: Cache point type for the system prompt
+        cache_prompt: Cache point type for the system prompt (deprecated, use cache_config)
+        cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching.
         cache_tools: Cache point type for tools
         guardrail_id: ID of the guardrail to apply
         guardrail_trace: Guardrail trace mode. Defaults to enabled.
@@ -98,6 +99,7 @@
     additional_request_fields: Optional[dict[str, Any]]
     additional_response_field_paths: Optional[list[str]]
     cache_prompt: Optional[str]
+    cache_config: Optional[CacheConfig]
     cache_tools: Optional[str]
     guardrail_id: Optional[str]
     guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]]
@@ -171,6 +173,15 @@ def __init__(
 
         logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name)
 
+    @property
+    def _supports_caching(self) -> bool:
+        """Whether this model supports prompt caching.
+
+        Returns True for Anthropic Claude models on Bedrock, detected from the model ID.
+        """
+        model_id = self.config.get("model_id", "").lower()
+        return "claude" in model_id or "anthropic" in model_id
+
     @override
     def update_config(self, **model_config: Unpack[BedrockConfig]) -> None:  # type: ignore
         """Update the Bedrock Model configuration with the provided arguments.
@@ -297,6 +308,63 @@ def _format_request(
             ),
         }
 
+    def _inject_cache_point(self, messages: Messages) -> None:
+        """Inject a cache point at the end of the last assistant message.
+
+        This automatically manages cache point placement in the messages array during
+        agent loop execution. The cache point is moved to the latest assistant message
+        on each model call to maximize cache hits.
+
+        Args:
+            messages: List of messages to inject cache point into (modified in place).
+        """
+        if not messages:
+            return
+
+        # Loop backwards through messages:
+        # 1. Find the last assistant message (the first one seen while scanning backwards) and add a cache point
+        # 2. Remove any other cache points along the way
+        last_assistant_idx: int | None = None
+
+        for msg_idx in range(len(messages) - 1, -1, -1):
+            msg = messages[msg_idx]
+            content = msg.get("content", [])
+
+            # Remove any cache points in this message's content (iterate backwards to avoid index issues)
+            for block_idx in range(len(content) - 1, -1, -1):
+                if "cachePoint" in content[block_idx]:
+                    # If this is the last assistant message and the cache point is at the end, keep it
+                    if (
+                        last_assistant_idx is None
+                        and msg.get("role") == "assistant"
+                        and block_idx == len(content) - 1
+                    ):
+                        # This is where we want the cache point - mark and continue
+                        last_assistant_idx = msg_idx
+                        logger.debug("msg_idx=<%s> | cache point already at end of last assistant message", msg_idx)
+                        continue
+
+                    # Remove cache points that aren't at the target position
+                    del content[block_idx]
+                    logger.warning("msg_idx=<%s>, block_idx=<%s> | removed existing cache point", msg_idx, block_idx)
+
+            # If we haven't found an assistant message yet, check if this is one
+            if last_assistant_idx is None and msg.get("role") == "assistant":
+                last_assistant_idx = msg_idx
+
+        # If no assistant message found, nothing to cache
+        if last_assistant_idx is None:
+            logger.debug("No assistant message in conversation - skipping cache point")
+            return
+
+        last_assistant_content = messages[last_assistant_idx]["content"]
+        if last_assistant_content:
+            if "cachePoint" in last_assistant_content[-1]:
+                return
+            cache_block: ContentBlock = {"cachePoint": {"type": "default"}}
+            last_assistant_content.append(cache_block)
+            logger.debug("msg_idx=<%s> | added cache point at end of assistant message", last_assistant_idx)
+
     def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
         """Format messages for Bedrock API compatibility.
@@ -305,6 +373,7 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
         - Eagerly filtering content blocks to only include Bedrock-supported fields
         - Ensuring all message content blocks are properly formatted for the Bedrock API
         - Optionally wrapping the last user message in guardrailConverseContent blocks
+        - Injecting cache points when cache_config is set with strategy="auto"
 
         Args:
             messages: List of messages to format
@@ -319,6 +388,17 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
             content blocks to remove any additional fields before sending to Bedrock.
             https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html
         """
+        # Inject cache point if cache_config is set with strategy="auto"
+        cache_config = self.config.get("cache_config")
+        if cache_config and cache_config.strategy == "auto":
+            if self._supports_caching:
+                self._inject_cache_point(messages)
+            else:
+                logger.warning(
+                    "model_id=<%s> | cache_config is enabled but this model does not support caching",
+                    self.config.get("model_id"),
+                )
+
         cleaned_messages: list[dict[str, Any]] = []
 
         filtered_unknown_members = False
diff --git a/src/strands/models/model.py b/src/strands/models/model.py
index b2fa73802..ac5807a49 100644
--- a/src/strands/models/model.py
+++ b/src/strands/models/model.py
@@ -2,7 +2,8 @@
 
 import abc
 import logging
-from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union
+from dataclasses import dataclass
+from typing import Any, AsyncGenerator, AsyncIterable, Literal, Optional, Type, TypeVar, Union
 
 from pydantic import BaseModel
 
@@ -15,6 +16,18 @@
 T = TypeVar("T", bound=BaseModel)
 
 
+@dataclass
+class CacheConfig:
+    """Configuration for prompt caching.
+
+    Attributes:
+        strategy: Caching strategy to use.
+            - "auto": Automatically inject cachePoint at optimal positions.
+    """
+
+    strategy: Literal["auto"] = "auto"
+
+
 class Model(abc.ABC):
     """Abstract base class for Agent model providers.
 
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index f133400a8..0ac6b907c 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -19,6 +19,7 @@
 from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
 from strands.hooks import BeforeToolCallEvent
 from strands.interrupt import Interrupt
+from strands.models import CacheConfig
 from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
 from strands.session.repository_session_manager import RepositorySessionManager
 from strands.telemetry.tracer import serialize
@@ -2182,3 +2183,44 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator):
     # Should not have added any toolResult messages
     # Only the new user message and assistant response should be added
     assert len(agent.messages) == original_length + 2
+
+
+def test_cache_config_does_not_mutate_original_messages(mock_model, agenerator):
+    """Test that cache_config injection does not mutate the original agent.messages."""
+    mock_model.mock_stream.return_value = agenerator(
+        [
+            {"messageStart": {"role": "assistant"}},
+            {"contentBlockStart": {"start": {"text": ""}}},
+            {"contentBlockDelta": {"delta": {"text": "Response"}}},
+            {"contentBlockStop": {}},
+            {"messageStop": {"stopReason": "end_turn"}},
+        ]
+    )
+
+    # Configure the mock to return a BedrockModel-style config with cache_config enabled
+    mock_model.get_config = unittest.mock.MagicMock(
+        return_value={"cache_config": CacheConfig(strategy="auto"), "model_id": "us.anthropic.claude-sonnet-4-v1:0"}
+    )
+
+    # Initial messages with an assistant response (no cache point)
+    initial_messages = [
+        {"role": "user", "content": [{"text": "Hello"}]},
+        {"role": "assistant", "content": [{"text": "Hi there!"}]},
+    ]
+
+    agent = Agent(model=mock_model, messages=copy.deepcopy(initial_messages))
+
+    # Store a deep copy of the messages before invocation
+    messages_before = copy.deepcopy(agent.messages)
+
+    # Invoke agent
+    agent("Follow up question")
+
+    # Check that the original assistant message content was not mutated with a cachePoint
+    # The assistant message at index 1 should still only have the text block
+    original_assistant_content = messages_before[1]["content"]
+    current_assistant_content = agent.messages[1]["content"]
+
+    # Both should have the same structure (no cache point added to agent.messages)
+    assert len(original_assistant_content) == len(current_assistant_content)
+    assert "cachePoint" not in current_assistant_content[-1]
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 7697c5e03..7b923c292 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -12,7 +12,7 @@
 import strands
 from strands import _exception_notes
-from strands.models import BedrockModel
+from strands.models import BedrockModel, CacheConfig
 from strands.models.bedrock import (
     _DEFAULT_BEDROCK_MODEL_ID,
     DEFAULT_BEDROCK_MODEL_ID,
@@ -2240,3 +2240,135 @@ async def test_format_request_with_guardrail_latest_message(model):
     # Latest user message image should also be wrapped
     assert "guardContent" in formatted_messages[2]["content"][1]
     assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png"
+
+
+def test_supports_caching_true_for_claude(bedrock_client):
+    """Test that _supports_caching returns True for Claude models."""
+    model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0")
+    assert model._supports_caching is True
+
+    model2 = BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0")
+    assert model2._supports_caching is True
+
+
+def test_supports_caching_false_for_non_claude(bedrock_client):
+    """Test that _supports_caching returns False for non-Claude models."""
+    model = BedrockModel(model_id="amazon.nova-pro-v1:0")
+    assert model._supports_caching is False
+
+
+def test_inject_cache_point_adds_to_last_assistant(bedrock_client):
+    """Test that _inject_cache_point adds a cache point to the last assistant message."""
+
+    model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto"))
+
+    messages = [
+        {"role": "user", "content": [{"text": "Hello"}]},
+        {"role": "assistant", "content": [{"text": "Hi there!"}]},
+        {"role": "user", "content": [{"text": "How are you?"}]},
+    ]
+
+    model._inject_cache_point(messages)
+
+    # Cache point should be added to the assistant message (index 1)
+    assert len(messages[1]["content"]) == 2
+    assert "cachePoint" in messages[1]["content"][-1]
+    assert messages[1]["content"][-1]["cachePoint"]["type"] == "default"
+
+
+def test_inject_cache_point_moves_existing_cache_point(bedrock_client):
+    """Test that _inject_cache_point moves the cache point from the old to the new position."""
+
+    model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto"))
+
+    messages = [
+        {"role": "user", "content": [{"text": "Hello"}]},
+        {"role": "assistant", "content": [{"text": "First response"}, {"cachePoint": {"type": "default"}}]},
+        {"role": "user", "content": [{"text": "Follow up"}]},
+        {"role": "assistant", "content": [{"text": "Second response"}]},
+    ]
+
+    model._inject_cache_point(messages)
+
+    # Old cache point should be removed from the first assistant message (index 1)
+    assert len(messages[1]["content"]) == 1
+    assert "cachePoint" not in messages[1]["content"][0]
+
+    # New cache point should be at the end of the last assistant message (index 3)
+    assert len(messages[3]["content"]) == 2
+    assert "cachePoint" in messages[3]["content"][-1]
+
+
+def test_inject_cache_point_removes_multiple_cache_points(bedrock_client):
+    """Test that _inject_cache_point removes all existing cache points except the target."""
target.""" + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}, {"cachePoint": {"type": "default"}}]}, + {"role": "assistant", "content": [{"cachePoint": {"type": "default"}}, {"text": "Response"}]}, + {"role": "user", "content": [{"text": "Follow up"}]}, + {"role": "assistant", "content": [{"text": "Final response"}]}, + ] + + model._inject_cache_point(messages) + + # All old cache points should be removed + assert len(messages[0]["content"]) == 1 # user message: only text + assert len(messages[1]["content"]) == 1 # first assistant: only text (cache point removed) + + # New cache point at end of last assistant + assert len(messages[3]["content"]) == 2 + assert "cachePoint" in messages[3]["content"][-1] + + +def test_inject_cache_point_no_assistant_message(bedrock_client): + """Test that _inject_cache_point does nothing when no assistant message exists.""" + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(messages) + + # No changes should be made + assert len(messages) == 1 + assert len(messages[0]["content"]) == 1 + + +def test_inject_cache_point_already_at_correct_position(bedrock_client): + """Test that _inject_cache_point keeps cache point if already at correct position.""" + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}, {"cachePoint": {"type": "default"}}]}, + ] + + model._inject_cache_point(messages) + + # Cache point should remain unchanged + assert len(messages[1]["content"]) == 2 + assert "cachePoint" in messages[1]["content"][-1] + + +def test_inject_cache_point_skipped_for_non_claude(bedrock_client): + """Test that cache point injection is skipped for non-Claude models.""" + + model = BedrockModel(model_id="amazon.nova-pro-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + ] + + # _format_bedrock_messages checks supports_caching before injecting + # Since Nova doesn't support caching, no cache point should be added + formatted = model._format_bedrock_messages(messages) + + # No cache point should be added + assert len(formatted[1]["content"]) == 1 + assert "cachePoint" not in formatted[1]["content"][0]