3 changes: 2 additions & 1 deletion src/strands/models/__init__.py
@@ -7,12 +7,13 @@

from . import bedrock, model
from .bedrock import BedrockModel
from .model import Model
from .model import CacheConfig, Model

__all__ = [
"bedrock",
"model",
"BedrockModel",
"CacheConfig",
"Model",
]

84 changes: 82 additions & 2 deletions src/strands/models/bedrock.py
@@ -28,7 +28,7 @@
from ..types.streaming import CitationsDelta, StreamEvent
from ..types.tools import ToolChoice, ToolSpec
from ._validation import validate_config_keys
from .model import Model
from .model import CacheConfig, Model

logger = logging.getLogger(__name__)

@@ -72,7 +72,8 @@ class BedrockConfig(TypedDict, total=False):
additional_args: Any additional arguments to include in the request
additional_request_fields: Additional fields to include in the Bedrock request
additional_response_field_paths: Additional response field paths to extract
cache_prompt: Cache point type for the system prompt
cache_prompt: Cache point type for the system prompt (deprecated, use cache_config)
cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching.
cache_tools: Cache point type for tools
guardrail_id: ID of the guardrail to apply
guardrail_trace: Guardrail trace mode. Defaults to enabled.
@@ -98,6 +99,7 @@ class BedrockConfig(TypedDict, total=False):
additional_request_fields: Optional[dict[str, Any]]
additional_response_field_paths: Optional[list[str]]
cache_prompt: Optional[str]
cache_config: Optional[CacheConfig]
cache_tools: Optional[str]
guardrail_id: Optional[str]
guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]]
@@ -171,6 +173,15 @@ def __init__(

logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name)

@property
def _supports_caching(self) -> bool:
"""Whether this model supports prompt caching.

Returns True for Claude models on Bedrock.
"""
model_id = self.config.get("model_id", "").lower()
return "claude" in model_id or "anthropic" in model_id

@override
def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore
"""Update the Bedrock Model configuration with the provided arguments.
@@ -297,6 +308,63 @@ def _format_request(
),
}

def _inject_cache_point(self, messages: Messages) -> None:
"""Inject a cache point at the end of the last assistant message.

This automatically manages cache point placement in the messages array during
agent loop execution. The cache point is moved to the latest assistant message
on each model call to maximize cache hits.

Args:
messages: List of messages to inject cache point into (modified in place).
"""
if not messages:
return

# Loop backwards through messages:
# 1. Find the last assistant message (the first one encountered while
#    scanning backwards) and add the cache point there
# 2. Remove any other cache points along the way
last_assistant_idx: int | None = None

for msg_idx in range(len(messages) - 1, -1, -1):
msg = messages[msg_idx]
content = msg.get("content", [])

# Remove any cache points in this message's content (iterate backwards to avoid index issues)
for block_idx in range(len(content) - 1, -1, -1):
if "cachePoint" in content[block_idx]:
# If this is the last assistant message and cache point is at the end, keep it
if (
last_assistant_idx is None
and msg.get("role") == "assistant"
and block_idx == len(content) - 1
):
# This is where we want the cache point - mark and continue
last_assistant_idx = msg_idx
logger.debug("msg_idx=<%s> | cache point already at end of last assistant message", msg_idx)
continue

# Remove cache points that aren't at the target position
del content[block_idx]
logger.warning("msg_idx=<%s>, block_idx=<%s> | removed existing cache point", msg_idx, block_idx)

# If we haven't found an assistant message yet, check if this is one
if last_assistant_idx is None and msg.get("role") == "assistant":
last_assistant_idx = msg_idx

# If no assistant message found, nothing to cache
if last_assistant_idx is None:
logger.debug("No assistant message in conversation - skipping cache point")
return

last_assistant_content = messages[last_assistant_idx]["content"]
if last_assistant_content:
if "cachePoint" in last_assistant_content[-1]:
return
cache_block: ContentBlock = {"cachePoint": {"type": "default"}}
last_assistant_content.append(cache_block)
logger.debug("msg_idx=<%s> | added cache point at end of assistant message", last_assistant_idx)

def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
"""Format messages for Bedrock API compatibility.

@@ -305,6 +373,7 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
- Eagerly filtering content blocks to only include Bedrock-supported fields
- Ensuring all message content blocks are properly formatted for the Bedrock API
- Optionally wrapping the last user message in guardrailConverseContent blocks
- Injecting cache points when cache_config is set with strategy="auto"

Args:
messages: List of messages to format
@@ -319,6 +388,17 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
content blocks to remove any additional fields before sending to Bedrock.
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html
"""
# Inject cache point if cache_config is set with strategy="auto"
cache_config = self.config.get("cache_config")
if cache_config and cache_config.strategy == "auto":
if self._supports_caching:
self._inject_cache_point(messages)
else:
logger.warning(
"model_id=<%s> | cache_config is enabled but this model does not support caching",
self.config.get("model_id"),
)

cleaned_messages: list[dict[str, Any]] = []

filtered_unknown_members = False
15 changes: 14 additions & 1 deletion src/strands/models/model.py
@@ -2,7 +2,8 @@

import abc
import logging
from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union
from dataclasses import dataclass
from typing import Any, AsyncGenerator, AsyncIterable, Literal, Optional, Type, TypeVar, Union

from pydantic import BaseModel

@@ -15,6 +16,18 @@
T = TypeVar("T", bound=BaseModel)


@dataclass
class CacheConfig:
"""Configuration for prompt caching.

Attributes:
strategy: Caching strategy to use.
- "auto": Automatically inject cachePoint at optimal positions
"""

strategy: Literal["auto"] = "auto"


class Model(abc.ABC):
"""Abstract base class for Agent model providers.

42 changes: 42 additions & 0 deletions tests/strands/agent/test_agent.py
@@ -19,6 +19,7 @@
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from strands.hooks import BeforeToolCallEvent
from strands.interrupt import Interrupt
from strands.models import CacheConfig
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
from strands.session.repository_session_manager import RepositorySessionManager
from strands.telemetry.tracer import serialize
@@ -2182,3 +2183,44 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator):
# Should not have added any toolResult messages
# Only the new user message and assistant response should be added
assert len(agent.messages) == original_length + 2


def test_cache_config_does_not_mutate_original_messages(mock_model, agenerator):
"""Test that cache_config injection does not mutate the original agent.messages."""
mock_model.mock_stream.return_value = agenerator(
[
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {"text": ""}}},
{"contentBlockDelta": {"delta": {"text": "Response"}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "end_turn"}},
]
)

# Simulate a mock BedrockModel with cache_config
mock_model.get_config = unittest.mock.MagicMock(
return_value={"cache_config": CacheConfig(strategy="auto"), "model_id": "us.anthropic.claude-sonnet-4-v1:0"}
)

# Initial messages with assistant response (no cache point)
initial_messages = [
{"role": "user", "content": [{"text": "Hello"}]},
{"role": "assistant", "content": [{"text": "Hi there!"}]},
]

agent = Agent(model=mock_model, messages=copy.deepcopy(initial_messages))

# Store deep copy of messages before invocation
messages_before = copy.deepcopy(agent.messages)

# Invoke agent
agent("Follow up question")

# Check that original assistant message content was not mutated with cachePoint
# The assistant message at index 1 should still only have the text block
original_assistant_content = messages_before[1]["content"]
current_assistant_content = agent.messages[1]["content"]

# Both should have the same structure (no cache point added to agent.messages)
assert len(original_assistant_content) == len(current_assistant_content)
assert "cachePoint" not in current_assistant_content[-1]
134 changes: 133 additions & 1 deletion tests/strands/models/test_bedrock.py
@@ -12,7 +12,7 @@

import strands
from strands import _exception_notes
from strands.models import BedrockModel
from strands.models import BedrockModel, CacheConfig
from strands.models.bedrock import (
_DEFAULT_BEDROCK_MODEL_ID,
DEFAULT_BEDROCK_MODEL_ID,
@@ -2240,3 +2240,135 @@ async def test_format_request_with_guardrail_latest_message(model):
# Latest user message image should also be wrapped
assert "guardContent" in formatted_messages[2]["content"][1]
assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png"


def test_supports_caching_true_for_claude(bedrock_client):
"""Test that supports_caching returns True for Claude models."""
model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0")
assert model._supports_caching is True

model2 = BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0")
assert model2._supports_caching is True


def test_supports_caching_false_for_non_claude(bedrock_client):
"""Test that supports_caching returns False for non-Claude models."""
model = BedrockModel(model_id="amazon.nova-pro-v1:0")
assert model._supports_caching is False


def test_inject_cache_point_adds_to_last_assistant(bedrock_client):
"""Test that _inject_cache_point adds cache point to last assistant message."""

model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto"))

messages = [
{"role": "user", "content": [{"text": "Hello"}]},
{"role": "assistant", "content": [{"text": "Hi there!"}]},
{"role": "user", "content": [{"text": "How are you?"}]},
]

model._inject_cache_point(messages)

# Cache point should be added to assistant message (index 1)
assert len(messages[1]["content"]) == 2
assert "cachePoint" in messages[1]["content"][-1]
assert messages[1]["content"][-1]["cachePoint"]["type"] == "default"


def test_inject_cache_point_moves_existing_cache_point(bedrock_client):
"""Test that _inject_cache_point moves cache point from old to new position."""

model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto"))

messages = [
{"role": "user", "content": [{"text": "Hello"}]},
{"role": "assistant", "content": [{"text": "First response"}, {"cachePoint": {"type": "default"}}]},
{"role": "user", "content": [{"text": "Follow up"}]},
{"role": "assistant", "content": [{"text": "Second response"}]},
]

model._inject_cache_point(messages)

# Old cache point should be removed from first assistant (index 1)
assert len(messages[1]["content"]) == 1
assert "cachePoint" not in messages[1]["content"][0]

# New cache point should be at end of last assistant (index 3)
assert len(messages[3]["content"]) == 2
assert "cachePoint" in messages[3]["content"][-1]


def test_inject_cache_point_removes_multiple_cache_points(bedrock_client):
"""Test that _inject_cache_point removes all existing cache points except the target."""

model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto"))

messages = [
{"role": "user", "content": [{"text": "Hello"}, {"cachePoint": {"type": "default"}}]},
{"role": "assistant", "content": [{"cachePoint": {"type": "default"}}, {"text": "Response"}]},
{"role": "user", "content": [{"text": "Follow up"}]},
{"role": "assistant", "content": [{"text": "Final response"}]},
]

model._inject_cache_point(messages)

# All old cache points should be removed
assert len(messages[0]["content"]) == 1 # user message: only text
assert len(messages[1]["content"]) == 1 # first assistant: only text (cache point removed)

# New cache point at end of last assistant
assert len(messages[3]["content"]) == 2
assert "cachePoint" in messages[3]["content"][-1]


def test_inject_cache_point_no_assistant_message(bedrock_client):
"""Test that _inject_cache_point does nothing when no assistant message exists."""

model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto"))

messages = [
{"role": "user", "content": [{"text": "Hello"}]},
]

model._inject_cache_point(messages)

# No changes should be made
assert len(messages) == 1
assert len(messages[0]["content"]) == 1


def test_inject_cache_point_already_at_correct_position(bedrock_client):
"""Test that _inject_cache_point keeps cache point if already at correct position."""

model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto"))

messages = [
{"role": "user", "content": [{"text": "Hello"}]},
{"role": "assistant", "content": [{"text": "Response"}, {"cachePoint": {"type": "default"}}]},
]

model._inject_cache_point(messages)

# Cache point should remain unchanged
assert len(messages[1]["content"]) == 2
assert "cachePoint" in messages[1]["content"][-1]


def test_inject_cache_point_skipped_for_non_claude(bedrock_client):
"""Test that cache point injection is skipped for non-Claude models."""

model = BedrockModel(model_id="amazon.nova-pro-v1:0", cache_config=CacheConfig(strategy="auto"))

messages = [
{"role": "user", "content": [{"text": "Hello"}]},
{"role": "assistant", "content": [{"text": "Response"}]},
]

# _format_bedrock_messages checks _supports_caching before injecting
# Since Nova doesn't support caching, no cache point should be added
formatted = model._format_bedrock_messages(messages)

# No cache point should be added
assert len(formatted[1]["content"]) == 1
assert "cachePoint" not in formatted[1]["content"][0]