From f9bea2c561bd9615c2fd245c341488aa3dec7bad Mon Sep 17 00:00:00 2001 From: Kihyeon Myung Date: Thu, 8 Jan 2026 14:55:01 -0700 Subject: [PATCH 1/3] feat(bedrock): add automatic prompt caching support Add CacheConfig with strategy="auto" for BedrockModel to automatically inject cache points at the end of assistant messages in multi-turn conversations. - Add CacheConfig dataclass in model.py with strategy field - Add supports_caching property to check Claude model compatibility - Implement _inject_cache_point() for automatic cache point management - Export CacheConfig from models/__init__.py Closes #1432 --- src/strands/models/__init__.py | 3 +- src/strands/models/bedrock.py | 92 +++++++++++++++++++++++++++- src/strands/models/model.py | 29 ++++++++- tests/strands/models/test_bedrock.py | 19 ++++++ 4 files changed, 138 insertions(+), 5 deletions(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index d5f88d09a..be6a96549 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -7,12 +7,13 @@ from . import bedrock, model from .bedrock import BedrockModel -from .model import Model +from .model import CacheConfig, Model __all__ = [ "bedrock", "model", "BedrockModel", + "CacheConfig", "Model", ] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 8e1558ca7..923e12cd4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -28,7 +28,7 @@ from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys -from .model import Model +from .model import CacheConfig, Model logger = logging.getLogger(__name__) @@ -72,8 +72,9 @@ class BedrockConfig(TypedDict, total=False): additional_args: Any additional arguments to include in the request additional_request_fields: Additional fields to include in the Bedrock request additional_response_field_paths: Additional response field paths to extract - cache_prompt: Cache point type for the system prompt - cache_tools: Cache point type for tools + cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) + cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. + cache_tools: Cache point type for tools (deprecated, use cache_config) guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -98,6 +99,7 @@ class BedrockConfig(TypedDict, total=False): additional_request_fields: Optional[dict[str, Any]] additional_response_field_paths: Optional[list[str]] cache_prompt: Optional[str] + cache_config: Optional[CacheConfig] cache_tools: Optional[str] guardrail_id: Optional[str] guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]] @@ -146,6 +148,9 @@ def __init__( ) self.update_config(**model_config) + # Set cache_config on base Model class for Agent to detect + self.cache_config = self.config.get("cache_config") + logger.debug("config=<%s> | initializing", self.config) # Add strands-agents to the request user agent @@ -171,6 +176,16 @@ def __init__( logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name) + @property + @override + def supports_caching(self) -> bool: + """Whether this model supports prompt caching. + + Returns True for Claude models on Bedrock. 
+ """ + model_id = self.config.get("model_id", "").lower() + return "claude" in model_id or "anthropic" in model_id + @override def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore """Update the Bedrock Model configuration with the provided arguments. @@ -297,6 +312,71 @@ def _format_request( ), } + def _inject_cache_point(self, messages: Messages) -> None: + """Inject a cache point at the end of the last assistant message. + + This enables prompt caching for multi-turn conversations by placing a single + cache point that covers system prompt, tools, and conversation history. + + The cache point is automatically moved to the latest assistant message on each + model call, ensuring optimal cache utilization with minimal write overhead. + + Args: + messages: List of messages to inject cache point into (modified in place). + """ + if not messages: + return + + # Step 1: Find all existing cache points and the last assistant message + cache_point_positions: list[tuple[int, int]] = [] # [(msg_idx, block_idx), ...] + last_assistant_idx: int | None = None + + for msg_idx, msg in enumerate(messages): + # Track last assistant message + if msg.get("role") == "assistant": + last_assistant_idx = msg_idx + + content = msg.get("content", []) + if not isinstance(content, list): + continue + + for block_idx, block in enumerate(content): + if isinstance(block, dict) and "cachePoint" in block: + cache_point_positions.append((msg_idx, block_idx)) + + # Step 2: If no assistant message yet, nothing to cache + if last_assistant_idx is None: + logger.debug("No assistant message in conversation - skipping cache point") + return + + last_assistant_content = messages[last_assistant_idx].get("content", []) + if not isinstance(last_assistant_content, list) or len(last_assistant_content) == 0: + logger.debug("Last assistant message has no content - skipping cache point") + return + + # Step 3: Check if cache point already exists at the end of last assistant message + last_block = last_assistant_content[-1] + if isinstance(last_block, dict) and "cachePoint" in last_block: + logger.debug("Cache point already exists at end of last assistant message") + return + + # Step 4: Remove ALL existing cache points (we only want 1 at the end) + # Process in reverse order to avoid index shifting issues + for msg_idx, block_idx in reversed(cache_point_positions): + msg_content = messages[msg_idx].get("content", []) + if isinstance(msg_content, list) and block_idx < len(msg_content): + del msg_content[block_idx] + logger.debug(f"Removed old cache point at msg {msg_idx} block {block_idx}") + + # Step 5: Add single cache point at the end of the last assistant message + cache_block: ContentBlock = {"cachePoint": {"type": "default"}} + + # Re-fetch content in case it was modified by deletion + last_assistant_content = messages[last_assistant_idx].get("content", []) + if isinstance(last_assistant_content, list): + last_assistant_content.append(cache_block) + logger.debug(f"Added cache point at end of assistant message {last_assistant_idx}") + def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format messages for Bedrock API compatibility. 
@@ -305,6 +385,7 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: - Eagerly filtering content blocks to only include Bedrock-supported fields - Ensuring all message content blocks are properly formatted for the Bedrock API - Optionally wrapping the last user message in guardrailConverseContent blocks + - Injecting cache points when cache_config is set with strategy="auto" Args: messages: List of messages to format @@ -319,6 +400,11 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: content blocks to remove any additional fields before sending to Bedrock. https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html """ + # Inject cache point if cache_config is set with strategy="auto" + cache_config = self.config.get("cache_config") + if cache_config and cache_config.strategy == "auto" and self.supports_caching: + self._inject_cache_point(messages) + cleaned_messages: list[dict[str, Any]] = [] filtered_unknown_members = False diff --git a/src/strands/models/model.py b/src/strands/models/model.py index b2fa73802..bbc6ccacc 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -2,7 +2,8 @@ import abc import logging -from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union +from dataclasses import dataclass +from typing import Any, AsyncGenerator, AsyncIterable, Literal, Optional, Type, TypeVar, Union from pydantic import BaseModel @@ -15,13 +16,39 @@ T = TypeVar("T", bound=BaseModel) +@dataclass +class CacheConfig: + """Configuration for prompt caching. + + Attributes: + strategy: Caching strategy to use. + - "auto": Automatically inject cachePoint at optimal positions + """ + + strategy: Literal["auto"] = "auto" + + class Model(abc.ABC): """Abstract base class for Agent model providers. This class defines the interface for all model implementations in the Strands Agents SDK. It provides a standardized way to configure and process requests for different AI model providers. + + Attributes: + cache_config: Optional configuration for prompt caching. """ + cache_config: Optional[CacheConfig] = None + + @property + def supports_caching(self) -> bool: + """Whether this model supports prompt caching. + + Override in subclasses to indicate caching support. + Returns False by default. 
+ """ + return False + @abc.abstractmethod # pragma: no cover def update_config(self, **model_config: Any) -> None: diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 7697c5e03..10c6355b0 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2240,3 +2240,22 @@ async def test_format_request_with_guardrail_latest_message(model): # Latest user message image should also be wrapped assert "guardContent" in formatted_messages[2]["content"][1] assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png" + + +def test_cache_config_auto_sets_model_attribute(bedrock_client): + """Test that cache_config with strategy='auto' sets the cache_config attribute on the model.""" + from strands.models import CacheConfig + + model = BedrockModel(model_id="test-model", cache_config=CacheConfig(strategy="auto")) + + assert model.cache_config is not None + assert model.cache_config.strategy == "auto" + assert model.get_config().get("cache_config") is not None + + +def test_cache_config_none_by_default(bedrock_client): + """Test that cache_config is None by default.""" + model = BedrockModel(model_id="test-model") + + assert model.cache_config is None + assert model.get_config().get("cache_config") is None From 92e2a5945dd1218177660fb263eee3ead5f2bd7f Mon Sep 17 00:00:00 2001 From: Kihyeon Myung Date: Fri, 9 Jan 2026 11:48:13 -0700 Subject: [PATCH 2/3] refactor: simplify cache point injection logic and add tests --- src/strands/models/bedrock.py | 76 +++++++------- src/strands/models/model.py | 14 --- tests/strands/agent/test_agent.py | 43 ++++++++ tests/strands/models/test_bedrock.py | 146 ++++++++++++++++++++++++++- 4 files changed, 219 insertions(+), 60 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 923e12cd4..11d7b4828 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -74,7 +74,7 @@ class BedrockConfig(TypedDict, total=False): additional_response_field_paths: Additional response field paths to extract cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. - cache_tools: Cache point type for tools (deprecated, use cache_config) + cache_tools: Cache point type for tools guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -148,9 +148,6 @@ def __init__( ) self.update_config(**model_config) - # Set cache_config on base Model class for Agent to detect - self.cache_config = self.config.get("cache_config") - logger.debug("config=<%s> | initializing", self.config) # Add strands-agents to the request user agent @@ -177,7 +174,6 @@ def __init__( logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name) @property - @override def supports_caching(self) -> bool: """Whether this model supports prompt caching. @@ -327,53 +323,51 @@ def _inject_cache_point(self, messages: Messages) -> None: if not messages: return - # Step 1: Find all existing cache points and the last assistant message - cache_point_positions: list[tuple[int, int]] = [] # [(msg_idx, block_idx), ...] + # Loop backwards through messages: + # 1. Find first assistant message and add cache point there + # 2. 
Remove any other cache points along the way last_assistant_idx: int | None = None - for msg_idx, msg in enumerate(messages): - # Track last assistant message - if msg.get("role") == "assistant": - last_assistant_idx = msg_idx - + for msg_idx in range(len(messages) - 1, -1, -1): + msg = messages[msg_idx] content = msg.get("content", []) - if not isinstance(content, list): - continue - for block_idx, block in enumerate(content): - if isinstance(block, dict) and "cachePoint" in block: - cache_point_positions.append((msg_idx, block_idx)) + # Remove any cache points in this message's content (iterate backwards to avoid index issues) + for block_idx in range(len(content) - 1, -1, -1): + if "cachePoint" in content[block_idx]: + # If this is the last assistant message and cache point is at the end, keep it + if ( + last_assistant_idx is None + and msg.get("role") == "assistant" + and block_idx == len(content) - 1 + ): + # This is where we want the cache point - mark and continue + last_assistant_idx = msg_idx + logger.debug(f"Cache point already at end of last assistant message {msg_idx}") + continue + + # Remove cache points that aren't at the target position + del content[block_idx] + logger.warning(f"Removed existing cache point at msg {msg_idx} block {block_idx}") + + # If we haven't found an assistant message yet, check if this is one + if last_assistant_idx is None and msg.get("role") == "assistant": + last_assistant_idx = msg_idx - # Step 2: If no assistant message yet, nothing to cache + # If no assistant message found, nothing to cache if last_assistant_idx is None: logger.debug("No assistant message in conversation - skipping cache point") return - last_assistant_content = messages[last_assistant_idx].get("content", []) - if not isinstance(last_assistant_content, list) or len(last_assistant_content) == 0: - logger.debug("Last assistant message has no content - skipping cache point") + # Check if cache point was already found at the right position + last_assistant_content = messages[last_assistant_idx]["content"] + if last_assistant_content and "cachePoint" in last_assistant_content[-1]: + # Already has cache point at the end return - # Step 3: Check if cache point already exists at the end of last assistant message - last_block = last_assistant_content[-1] - if isinstance(last_block, dict) and "cachePoint" in last_block: - logger.debug("Cache point already exists at end of last assistant message") - return - - # Step 4: Remove ALL existing cache points (we only want 1 at the end) - # Process in reverse order to avoid index shifting issues - for msg_idx, block_idx in reversed(cache_point_positions): - msg_content = messages[msg_idx].get("content", []) - if isinstance(msg_content, list) and block_idx < len(msg_content): - del msg_content[block_idx] - logger.debug(f"Removed old cache point at msg {msg_idx} block {block_idx}") - - # Step 5: Add single cache point at the end of the last assistant message - cache_block: ContentBlock = {"cachePoint": {"type": "default"}} - - # Re-fetch content in case it was modified by deletion - last_assistant_content = messages[last_assistant_idx].get("content", []) - if isinstance(last_assistant_content, list): + # Add cache point at the end of the last assistant message + if last_assistant_content: + cache_block: ContentBlock = {"cachePoint": {"type": "default"}} last_assistant_content.append(cache_block) logger.debug(f"Added cache point at end of assistant message {last_assistant_idx}") diff --git a/src/strands/models/model.py 
b/src/strands/models/model.py index bbc6ccacc..ac5807a49 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -33,22 +33,8 @@ class Model(abc.ABC): This class defines the interface for all model implementations in the Strands Agents SDK. It provides a standardized way to configure and process requests for different AI model providers. - - Attributes: - cache_config: Optional configuration for prompt caching. """ - cache_config: Optional[CacheConfig] = None - - @property - def supports_caching(self) -> bool: - """Whether this model supports prompt caching. - - Override in subclasses to indicate caching support. - Returns False by default. - """ - return False - @abc.abstractmethod # pragma: no cover def update_config(self, **model_config: Any) -> None: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index f133400a8..f5bbc3ce4 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2182,3 +2182,46 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): # Should not have added any toolResult messages # Only the new user message and assistant response should be added assert len(agent.messages) == original_length + 2 + + +def test_cache_config_does_not_mutate_original_messages(mock_model, agenerator): + """Test that cache_config injection does not mutate the original agent.messages.""" + from strands.models import CacheConfig + + mock_model.mock_stream.return_value = agenerator( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "Response"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + # Simulate a mock BedrockModel with cache_config + mock_model.get_config = unittest.mock.MagicMock( + return_value={"cache_config": CacheConfig(strategy="auto"), "model_id": "us.anthropic.claude-sonnet-4-v1:0"} + ) + + # Initial messages with assistant response (no cache point) + initial_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]}, + ] + + agent = Agent(model=mock_model, messages=copy.deepcopy(initial_messages)) + + # Store deep copy of messages before invocation + messages_before = copy.deepcopy(agent.messages) + + # Invoke agent + agent("Follow up question") + + # Check that original assistant message content was not mutated with cachePoint + # The assistant message at index 1 should still only have the text block + original_assistant_content = messages_before[1]["content"] + current_assistant_content = agent.messages[1]["content"] + + # Both should have the same structure (no cache point added to agent.messages) + assert len(original_assistant_content) == len(current_assistant_content) + assert "cachePoint" not in current_assistant_content[-1] diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 10c6355b0..0d1bed31e 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2242,20 +2242,156 @@ async def test_format_request_with_guardrail_latest_message(model): assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png" -def test_cache_config_auto_sets_model_attribute(bedrock_client): - """Test that cache_config with strategy='auto' sets the cache_config attribute on the model.""" +def test_cache_config_auto_sets_config(bedrock_client): + """Test that cache_config with 
strategy='auto' is stored in config.""" from strands.models import CacheConfig model = BedrockModel(model_id="test-model", cache_config=CacheConfig(strategy="auto")) - assert model.cache_config is not None - assert model.cache_config.strategy == "auto" assert model.get_config().get("cache_config") is not None + assert model.get_config().get("cache_config").strategy == "auto" def test_cache_config_none_by_default(bedrock_client): """Test that cache_config is None by default.""" model = BedrockModel(model_id="test-model") - assert model.cache_config is None assert model.get_config().get("cache_config") is None + + +def test_supports_caching_true_for_claude(bedrock_client): + """Test that supports_caching returns True for Claude models.""" + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") + assert model.supports_caching is True + + model2 = BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0") + assert model2.supports_caching is True + + +def test_supports_caching_false_for_non_claude(bedrock_client): + """Test that supports_caching returns False for non-Claude models.""" + model = BedrockModel(model_id="amazon.nova-pro-v1:0") + assert model.supports_caching is False + + +def test_inject_cache_point_adds_to_last_assistant(bedrock_client): + """Test that _inject_cache_point adds cache point to last assistant message.""" + from strands.models import CacheConfig + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]}, + {"role": "user", "content": [{"text": "How are you?"}]}, + ] + + model._inject_cache_point(messages) + + # Cache point should be added to assistant message (index 1) + assert len(messages[1]["content"]) == 2 + assert "cachePoint" in messages[1]["content"][-1] + assert messages[1]["content"][-1]["cachePoint"]["type"] == "default" + + +def test_inject_cache_point_moves_existing_cache_point(bedrock_client): + """Test that _inject_cache_point moves cache point from old to new position.""" + from strands.models import CacheConfig + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "First response"}, {"cachePoint": {"type": "default"}}]}, + {"role": "user", "content": [{"text": "Follow up"}]}, + {"role": "assistant", "content": [{"text": "Second response"}]}, + ] + + model._inject_cache_point(messages) + + # Old cache point should be removed from first assistant (index 1) + assert len(messages[1]["content"]) == 1 + assert "cachePoint" not in messages[1]["content"][0] + + # New cache point should be at end of last assistant (index 3) + assert len(messages[3]["content"]) == 2 + assert "cachePoint" in messages[3]["content"][-1] + + +def test_inject_cache_point_removes_multiple_cache_points(bedrock_client): + """Test that _inject_cache_point removes all existing cache points except the target.""" + from strands.models import CacheConfig + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}, {"cachePoint": {"type": "default"}}]}, + {"role": "assistant", "content": [{"cachePoint": {"type": "default"}}, {"text": "Response"}]}, + {"role": "user", 
"content": [{"text": "Follow up"}]}, + {"role": "assistant", "content": [{"text": "Final response"}]}, + ] + + model._inject_cache_point(messages) + + # All old cache points should be removed + assert len(messages[0]["content"]) == 1 # user message: only text + assert len(messages[1]["content"]) == 1 # first assistant: only text (cache point removed) + + # New cache point at end of last assistant + assert len(messages[3]["content"]) == 2 + assert "cachePoint" in messages[3]["content"][-1] + + +def test_inject_cache_point_no_assistant_message(bedrock_client): + """Test that _inject_cache_point does nothing when no assistant message exists.""" + from strands.models import CacheConfig + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(messages) + + # No changes should be made + assert len(messages) == 1 + assert len(messages[0]["content"]) == 1 + + +def test_inject_cache_point_already_at_correct_position(bedrock_client): + """Test that _inject_cache_point keeps cache point if already at correct position.""" + from strands.models import CacheConfig + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}, {"cachePoint": {"type": "default"}}]}, + ] + + model._inject_cache_point(messages) + + # Cache point should remain unchanged + assert len(messages[1]["content"]) == 2 + assert "cachePoint" in messages[1]["content"][-1] + + +def test_inject_cache_point_skipped_for_non_claude(bedrock_client): + """Test that cache point injection is skipped for non-Claude models.""" + from strands.models import CacheConfig + + model = BedrockModel(model_id="amazon.nova-pro-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + ] + + # _format_bedrock_messages checks supports_caching before injecting + # Since Nova doesn't support caching, no cache point should be added + formatted = model._format_bedrock_messages(messages) + + # No cache point should be added + assert len(formatted[1]["content"]) == 1 + assert "cachePoint" not in formatted[1]["content"][0] From cc661afb22e11ff6c8d26bddb82dcf90d60e9f2a Mon Sep 17 00:00:00 2001 From: Kihyeon Myung Date: Mon, 19 Jan 2026 14:53:44 -0700 Subject: [PATCH 3/3] refactor: address PR review feedback for prompt caching - Add warning when cache_config enabled but model doesn't support caching - Make supports_caching private (_supports_caching) - Fix log formatting to follow style guide - Clean up tests and imports --- src/strands/models/bedrock.py | 34 ++++++++++++++-------------- tests/strands/agent/test_agent.py | 3 +-- tests/strands/models/test_bedrock.py | 31 ++++--------------------- 3 files changed, 22 insertions(+), 46 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 11d7b4828..45987387f 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -174,7 +174,7 @@ def __init__( logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name) @property - def supports_caching(self) -> bool: + def _supports_caching(self) -> bool: """Whether this model supports prompt caching. 
Returns True for Claude models on Bedrock. @@ -311,11 +311,9 @@ def _format_request( def _inject_cache_point(self, messages: Messages) -> None: """Inject a cache point at the end of the last assistant message. - This enables prompt caching for multi-turn conversations by placing a single - cache point that covers system prompt, tools, and conversation history. - - The cache point is automatically moved to the latest assistant message on each - model call, ensuring optimal cache utilization with minimal write overhead. + This automatically manages cache point placement in the messages array during + agent loop execution. The cache point is moved to the latest assistant message + on each model call to maximize cache hits. Args: messages: List of messages to inject cache point into (modified in place). @@ -343,12 +341,12 @@ def _inject_cache_point(self, messages: Messages) -> None: ): # This is where we want the cache point - mark and continue last_assistant_idx = msg_idx - logger.debug(f"Cache point already at end of last assistant message {msg_idx}") + logger.debug("msg_idx=<%s> | cache point already at end of last assistant message", msg_idx) continue # Remove cache points that aren't at the target position del content[block_idx] - logger.warning(f"Removed existing cache point at msg {msg_idx} block {block_idx}") + logger.warning("msg_idx=<%s>, block_idx=<%s> | removed existing cache point", msg_idx, block_idx) # If we haven't found an assistant message yet, check if this is one if last_assistant_idx is None and msg.get("role") == "assistant": @@ -359,17 +357,13 @@ def _inject_cache_point(self, messages: Messages) -> None: logger.debug("No assistant message in conversation - skipping cache point") return - # Check if cache point was already found at the right position last_assistant_content = messages[last_assistant_idx]["content"] - if last_assistant_content and "cachePoint" in last_assistant_content[-1]: - # Already has cache point at the end - return - - # Add cache point at the end of the last assistant message if last_assistant_content: + if "cachePoint" in last_assistant_content[-1]: + return cache_block: ContentBlock = {"cachePoint": {"type": "default"}} last_assistant_content.append(cache_block) - logger.debug(f"Added cache point at end of assistant message {last_assistant_idx}") + logger.debug("msg_idx=<%s> | added cache point at end of assistant message", last_assistant_idx) def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format messages for Bedrock API compatibility. 
@@ -396,8 +390,14 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """ # Inject cache point if cache_config is set with strategy="auto" cache_config = self.config.get("cache_config") - if cache_config and cache_config.strategy == "auto" and self.supports_caching: - self._inject_cache_point(messages) + if cache_config and cache_config.strategy == "auto": + if self._supports_caching: + self._inject_cache_point(messages) + else: + logger.warning( + "model_id=<%s> | cache_config is enabled but this model does not support caching", + self.config.get("model_id"), + ) cleaned_messages: list[dict[str, Any]] = [] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index f5bbc3ce4..0ac6b907c 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -19,6 +19,7 @@ from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from strands.hooks import BeforeToolCallEvent from strands.interrupt import Interrupt +from strands.models import CacheConfig from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize @@ -2186,8 +2187,6 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): def test_cache_config_does_not_mutate_original_messages(mock_model, agenerator): """Test that cache_config injection does not mutate the original agent.messages.""" - from strands.models import CacheConfig - mock_model.mock_stream.return_value = agenerator( [ {"messageStart": {"role": "assistant"}}, diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 0d1bed31e..7b923c292 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -12,7 +12,7 @@ import strands from strands import _exception_notes -from strands.models import BedrockModel +from strands.models import BedrockModel, CacheConfig from strands.models.bedrock import ( _DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_MODEL_ID, @@ -2242,41 +2242,23 @@ async def test_format_request_with_guardrail_latest_message(model): assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png" -def test_cache_config_auto_sets_config(bedrock_client): - """Test that cache_config with strategy='auto' is stored in config.""" - from strands.models import CacheConfig - - model = BedrockModel(model_id="test-model", cache_config=CacheConfig(strategy="auto")) - - assert model.get_config().get("cache_config") is not None - assert model.get_config().get("cache_config").strategy == "auto" - - -def test_cache_config_none_by_default(bedrock_client): - """Test that cache_config is None by default.""" - model = BedrockModel(model_id="test-model") - - assert model.get_config().get("cache_config") is None - - def test_supports_caching_true_for_claude(bedrock_client): """Test that supports_caching returns True for Claude models.""" model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") - assert model.supports_caching is True + assert model._supports_caching is True model2 = BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0") - assert model2.supports_caching is True + assert model2._supports_caching is True def test_supports_caching_false_for_non_claude(bedrock_client): """Test that supports_caching returns False for non-Claude models.""" model = BedrockModel(model_id="amazon.nova-pro-v1:0") 
- assert model.supports_caching is False + assert model._supports_caching is False def test_inject_cache_point_adds_to_last_assistant(bedrock_client): """Test that _inject_cache_point adds cache point to last assistant message.""" - from strands.models import CacheConfig model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) @@ -2296,7 +2278,6 @@ def test_inject_cache_point_adds_to_last_assistant(bedrock_client): def test_inject_cache_point_moves_existing_cache_point(bedrock_client): """Test that _inject_cache_point moves cache point from old to new position.""" - from strands.models import CacheConfig model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) @@ -2320,7 +2301,6 @@ def test_inject_cache_point_moves_existing_cache_point(bedrock_client): def test_inject_cache_point_removes_multiple_cache_points(bedrock_client): """Test that _inject_cache_point removes all existing cache points except the target.""" - from strands.models import CacheConfig model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) @@ -2344,7 +2324,6 @@ def test_inject_cache_point_removes_multiple_cache_points(bedrock_client): def test_inject_cache_point_no_assistant_message(bedrock_client): """Test that _inject_cache_point does nothing when no assistant message exists.""" - from strands.models import CacheConfig model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) @@ -2361,7 +2340,6 @@ def test_inject_cache_point_no_assistant_message(bedrock_client): def test_inject_cache_point_already_at_correct_position(bedrock_client): """Test that _inject_cache_point keeps cache point if already at correct position.""" - from strands.models import CacheConfig model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")) @@ -2379,7 +2357,6 @@ def test_inject_cache_point_already_at_correct_position(bedrock_client): def test_inject_cache_point_skipped_for_non_claude(bedrock_client): """Test that cache point injection is skipped for non-Claude models.""" - from strands.models import CacheConfig model = BedrockModel(model_id="amazon.nova-pro-v1:0", cache_config=CacheConfig(strategy="auto"))
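
For reference, a minimal usage sketch of the feature this series adds (not part of the patches themselves): it assumes the SDK's usual "from strands import Agent" entry point, and reuses the BedrockModel/CacheConfig names and the Claude model_id that appear in the tests above.

    from strands import Agent  # assumed top-level entry point, not shown in this diff
    from strands.models import BedrockModel, CacheConfig

    # Claude model IDs support caching; with strategy="auto" a single cachePoint is
    # injected and moved to the latest assistant message on each model call.
    # For non-Claude model IDs the config is skipped and a warning is logged.
    model = BedrockModel(
        model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
        cache_config=CacheConfig(strategy="auto"),
    )

    agent = Agent(model=model)
    agent("First question")       # no assistant message yet, so no cache point is injected
    agent("Follow up question")   # cache point now sits at the end of the latest assistant message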