Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 48 additions & 13 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from pydantic import BaseModel
from typing_extensions import TypedDict, Unpack, override

from strands.types.media import S3Location, SourceLocation

from .._exception_notes import add_exception_note
from ..event_loop import streaming
from ..tools import convert_pydantic_to_tool_spec
Expand Down Expand Up @@ -407,6 +409,8 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:

# Format content blocks for Bedrock API compatibility
formatted_content = self._format_request_message_content(content_block)
if formatted_content is None:
continue

# Wrap text or image content in guardrailContent if this is the last user message
if (
Expand Down Expand Up @@ -459,7 +463,19 @@ def _should_include_tool_result_status(self) -> bool:
else: # "auto"
return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS)

def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
def _handle_location(self, location: SourceLocation) -> dict[str, Any] | None:
"""Convert location content block to Bedrock format if its an S3Location."""
if location["type"] == "s3":
s3_location = cast(S3Location, location)
formatted_document_s3: dict[str, Any] = {"uri": s3_location["uri"]}
if "bucketOwner" in s3_location:
formatted_document_s3["bucketOwner"] = s3_location["bucketOwner"]
return {"s3Location": formatted_document_s3}
else:
logger.warning("Non s3 location sources are not supported by Bedrock, skipping content block")
return None

def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any] | None:
"""Format a Bedrock content block.

Bedrock strictly validates content blocks and throws exceptions for unknown fields.
Expand Down Expand Up @@ -489,9 +505,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
if "format" in document:
result["format"] = document["format"]

# Handle source
# Handle source - supports bytes or location
if "source" in document:
result["source"] = {"bytes": document["source"]["bytes"]}
source = document["source"]
formatted_document_source: dict[str, Any] | None
if "location" in source:
formatted_document_source = self._handle_location(source["location"])
if formatted_document_source is None:
return None
elif "bytes" in source:
formatted_document_source = {"bytes": source["bytes"]}
result["source"] = formatted_document_source

# Handle optional fields
if "citations" in document and document["citations"] is not None:
Expand All @@ -512,10 +536,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
if "image" in content:
image = content["image"]
source = image["source"]
formatted_source = {}
if "bytes" in source:
formatted_source = {"bytes": source["bytes"]}
result = {"format": image["format"], "source": formatted_source}
formatted_image_source: dict[str, Any] | None
if "location" in source:
formatted_image_source = self._handle_location(source["location"])
if formatted_image_source is None:
return None
elif "bytes" in source:
formatted_image_source = {"bytes": source["bytes"]}
result = {"format": image["format"], "source": formatted_image_source}
return {"image": result}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
Expand Down Expand Up @@ -550,9 +578,12 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
# Handle json field since not in ContentBlock but valid in ToolResultContent
formatted_content.append({"json": tool_result_content["json"]})
else:
formatted_content.append(
self._format_request_message_content(cast(ContentBlock, tool_result_content))
formatted_message_content = self._format_request_message_content(
cast(ContentBlock, tool_result_content)
)
if formatted_message_content is None:
continue
formatted_content.append(formatted_message_content)

result = {
"content": formatted_content,
Expand All @@ -577,10 +608,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
if "video" in content:
video = content["video"]
source = video["source"]
formatted_source = {}
if "bytes" in source:
formatted_source = {"bytes": source["bytes"]}
result = {"format": video["format"], "source": formatted_source}
formatted_video_source: dict[str, Any] | None
if "location" in source:
formatted_video_source = self._handle_location(source["location"])
if formatted_video_source is None:
return None
elif "bytes" in source:
formatted_video_source = {"bytes": source["bytes"]}
result = {"format": video["format"], "source": formatted_video_source}
return {"video": result}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
Expand Down
54 changes: 49 additions & 5 deletions src/strands/types/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,60 @@
- Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
"""

from typing import Literal
from typing import Literal, TypeAlias

from typing_extensions import TypedDict
from typing_extensions import Required, TypedDict

from .citations import CitationsConfig

DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
"""Supported document formats."""


class DocumentSource(TypedDict):
class Location(TypedDict, total=False):
"""A location for a document.

This type is a generic location for a document. Its usage is determined by the underlying model provider.
"""

type: Required[str]


class S3Location(Location, total=False):
"""A storage location in an Amazon S3 bucket.

Used by Bedrock to reference media files stored in S3 instead of passing raw bytes.

- Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html

Attributes:
type: s3
uri: An object URI starting with `s3://`. Required.
bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. Optional.
"""

# mypy doesn't like overriding this field since its a subclass, but since its just a literal string, this is fine.

type: Literal["s3"] # type: ignore[misc]
uri: Required[str]
bucketOwner: str


SourceLocation: TypeAlias = Location | S3Location


class DocumentSource(TypedDict, total=False):
"""Contains the content of a document.

Only one of `bytes` or `s3Location` should be specified.

Attributes:
bytes: The binary content of the document.
location: Location of the document.
"""

bytes: bytes
location: SourceLocation


class DocumentContent(TypedDict, total=False):
Expand All @@ -45,14 +81,18 @@ class DocumentContent(TypedDict, total=False):
"""Supported image formats."""


class ImageSource(TypedDict):
class ImageSource(TypedDict, total=False):
"""Contains the content of an image.

Only one of `bytes` or `s3Location` should be specified.

Attributes:
bytes: The binary content of the image.
location: Location of the image.
"""

bytes: bytes
location: SourceLocation


class ImageContent(TypedDict):
Expand All @@ -71,14 +111,18 @@ class ImageContent(TypedDict):
"""Supported video formats."""


class VideoSource(TypedDict):
class VideoSource(TypedDict, total=False):
"""Contains the content of a video.

Only one of `bytes` or `s3Location` should be specified.

Attributes:
bytes: The binary content of the video.
location: Location of the video.
"""

bytes: bytes
location: SourceLocation


class VideoContent(TypedDict):
Expand Down
2 changes: 0 additions & 2 deletions tests/strands/agent/hooks/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,6 @@ def test_invocation_state_is_available_in_model_call_events(agent):
assert after_event.invocation_state["request_id"] == "req-456"




def test_before_invocation_event_messages_default_none(agent):
"""Test that BeforeInvocationEvent.messages defaults to None for backward compatibility."""
event = BeforeInvocationEvent(agent=agent)
Expand Down
Loading
Loading