From 7f76f2cd7e05fd3a221134f7784acc65e6206685 Mon Sep 17 00:00:00 2001 From: Akarsha Sehwag Date: Mon, 24 Nov 2025 10:28:51 -0500 Subject: [PATCH 1/3] feat(memory): update session manager. --- .../memory/integrations/strands/config.py | 50 +- .../integrations/strands/session_manager.py | 874 ++++++++++-------- 2 files changed, 552 insertions(+), 372 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index d2d5cef..7265e8e 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -1,8 +1,27 @@ """Configuration for AgentCore Memory Session Manager.""" -from typing import Dict, Optional +from typing import Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator + +from bedrock_agentcore.memory.constants import MessageRole +from bedrock_agentcore.memory.models import StringValue + + +class BranchConfig(BaseModel): + """Configuration for AgentCore Memory branching. + + Attributes: + name: Descriptive name for the branch + root_event_id: ID of the event from which this branch originates + """ + + name: str = Field(min_length=1) + root_event_id: Optional[str] = "" + + def to_agentcore_format(self) -> dict: + """Convert to AgentCore Memory API format.""" + return {"name": self.name, "rootEventId": self.root_event_id} class RetrievalConfig(BaseModel): @@ -21,6 +40,13 @@ class RetrievalConfig(BaseModel): initialization_query: Optional[str] = None +class ShortTermRetrievalConfig(BaseModel): + """Configuration for Short term memory retrieval operations""" + + branch_filter: Optional[bool] = True + metadata: Optional[Dict[str, StringValue]] = None + + class AgentCoreMemoryConfig(BaseModel): """Configuration for AgentCore Memory Session Manager. @@ -29,9 +55,29 @@ class AgentCoreMemoryConfig(BaseModel): session_id: Required unique ID for the session actor_id: Required unique ID for the agent instance/user retrieval_config: Optional dictionary mapping namespaces to retrieval configurations + default_branch: Optional default branch configuration for the session + message_types: Optional list of message types to filter + metadata: Optional dictionary of metadata to include with events """ memory_id: str = Field(min_length=1) session_id: str = Field(min_length=1) actor_id: str = Field(min_length=1) retrieval_config: Optional[Dict[str, RetrievalConfig]] = None + default_branch: Optional[BranchConfig] = Field( + default_factory=lambda: BranchConfig(name="main", root_event_id="") + ) + short_term_retrieval_config: Optional[ShortTermRetrievalConfig] = ( + ShortTermRetrievalConfig() + ) + message_types: Optional[List[str]] = Field(default=["user", "assistant"]) + metadata: Optional[Dict[str, StringValue]] = ( + None # Currently only supports agent_id. Will be extended further. + ) + + @field_validator("memory_id", "session_id", "actor_id") + @classmethod + def validate_non_empty_strings(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("must be a non-empty string") + return v diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index d77db53..39fa716 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -2,6 +2,7 @@ import json import logging +import re import threading from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone @@ -9,7 +10,8 @@ import boto3 from botocore.config import Config as BotocoreConfig -from strands.hooks import MessageAddedEvent +from strands.agent.state import AgentState +from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from strands.hooks.registry import HookRegistry from strands.session.repository_session_manager import RepositorySessionManager from strands.session.session_repository import SessionRepository @@ -18,7 +20,15 @@ from strands.types.session import Session, SessionAgent, SessionMessage from typing_extensions import override -from bedrock_agentcore.memory.client import MemoryClient +from bedrock_agentcore.memory.constants import ConversationalMessage, MessageRole +from bedrock_agentcore.memory.models import ( + EventMetadataFilter, + LeftExpression, + OperatorType, + RightExpression, + StringValue, +) +from bedrock_agentcore.memory.session import MemorySessionManager from .bedrock_converter import AgentCoreMemoryConverter from .config import AgentCoreMemoryConfig @@ -28,10 +38,6 @@ logger = logging.getLogger(__name__) -SESSION_PREFIX = "session_" -AGENT_PREFIX = "agent_" -MESSAGE_PREFIX = "message_" - class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository): """AgentCore Memory-based session manager for Bedrock AgentCore Memory integration. @@ -52,8 +58,62 @@ class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository) _timestamp_lock = threading.Lock() _last_timestamp: Optional[datetime] = None + def _validate_session_id(self, session_id: str) -> None: + """Validate session ID matches configuration.""" + if session_id != self.config.session_id: + raise SessionException( + f"Session ID mismatch: expected {self.config.session_id}, got {session_id}" + ) + + def _prepare_event_params( + self, + payload: Any, + metadata: Optional[dict] = None, + branch=True, + ) -> dict: + """Prepare common event parameters.""" + event_params = { + "memoryId": self.config.memory_id, + "actorId": self.config.actor_id, + "sessionId": self.config.session_id, + "payload": payload, + "eventTimestamp": self._get_monotonic_timestamp(), + } + + if metadata: + event_params["metadata"] = metadata + + if ( + branch + and self.config.default_branch + and self.config.default_branch.name != "main" + ): + event_params["branch"] = self.config.default_branch.to_agentcore_format() + + return event_params + + def _prepare_list_params(self, branch=True, **kwargs) -> dict: + """Prepare common list_events parameters.""" + params = { + "actor_id": self.config.actor_id, + "session_id": self.config.session_id, + **kwargs, + } + + if ( + self.config.default_branch + and branch + and self.config.short_term_retrieval_config.branch_filter + and self.config.default_branch.name != "main" + ): + params["branch_name"] = self.config.default_branch.name + + return params + @classmethod - def _get_monotonic_timestamp(cls, desired_timestamp: Optional[datetime] = None) -> datetime: + def _get_monotonic_timestamp( + cls, desired_timestamp: Optional[datetime] = None + ) -> datetime: """Get a monotonically increasing timestamp. Args: @@ -98,368 +158,415 @@ def __init__( Defaults to None. **kwargs (Any): Additional keyword arguments. """ - self.config = agentcore_memory_config - self.memory_client = MemoryClient(region_name=region_name) - session = boto_session or boto3.Session(region_name=region_name) - self.has_existing_agent = False - - # Override the clients if custom boto session or config is provided - # Add strands-agents to the request user agent - if boto_client_config: - existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) - if existing_user_agent: - new_user_agent = f"{existing_user_agent} strands-agents" - else: - new_user_agent = "strands-agents" - client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) - else: - client_config = BotocoreConfig(user_agent_extra="strands-agents") - - # Override the memory client's boto3 clients - self.memory_client.gmcp_client = session.client( - "bedrock-agentcore-control", region_name=region_name or session.region_name, config=client_config - ) - self.memory_client.gmdp_client = session.client( - "bedrock-agentcore", region_name=region_name or session.region_name, config=client_config - ) - super().__init__(session_id=self.config.session_id, session_repository=self) - - def _get_full_session_id(self, session_id: str) -> str: - """Get the full session ID with the configured prefix. - - Args: - session_id (str): The session ID. - - Returns: - str: The full session ID with the prefix. - """ - full_session_id = f"{SESSION_PREFIX}{session_id}" - if full_session_id == self.config.actor_id: - raise SessionException( - f"Cannot have session [ {full_session_id} ] with the same ID as the actor ID: {self.config.actor_id}" - ) - return full_session_id - - def _get_full_agent_id(self, agent_id: str) -> str: - """Get the full agent ID with the configured prefix. - Args: - agent_id (str): The agent ID. - - Returns: - str: The full agent ID with the prefix. - """ - full_agent_id = f"{AGENT_PREFIX}{agent_id}" - if full_agent_id == self.config.actor_id: + self.config = agentcore_memory_config + self._session_cache: Optional[SessionAgent] = None + self._branch_root_events: dict[str, str] = {} + # Validate session_id length + if len(self.config.session_id) < 33: raise SessionException( - f"Cannot create agent [ {full_agent_id} ] with the same ID as the actor ID: {self.config.actor_id}" + f"Session ID must be at least 33 characters long to ensure uniqueness: {self.config.session_id}" ) - return full_agent_id - # region SessionRepository interface implementation - def create_session(self, session: Session, **kwargs: Any) -> Session: - """Create a new session in AgentCore Memory. + # Initialize the new MemorySessionManager + self.memory_session_manager = MemorySessionManager( + memory_id=self.config.memory_id, + region_name=region_name, + boto3_session=boto_session, + boto_client_config=boto_client_config, + ) + self.agent_id = None + + def _get_or_create_branch_root( + self, branch_name: str, session_agent: SessionAgent + ) -> str: + """Get branch root event ID and cache it.""" + if self._branch_root_events.get(branch_name): + return self._branch_root_events[branch_name] + if branch_name == "main": + return None + # Check if branch exists + list_params = self._prepare_list_params(max_results=1) + branch_events = self.memory_session_manager.list_events(**list_params) + + if branch_events: + root_event_id = branch_events[0].get("eventId") + self._branch_root_events[branch_name] = root_event_id + return root_event_id + + # Branch doesn't exist - get main branch root + if branch_name != "main": + main_params = self._prepare_list_params(max_results=1, branch=False) + + main_events = self.memory_session_manager.list_events(**main_params) + + if not main_events: + # Create event in main first + metadata = { + "agent_id": StringValue.build(session_agent.agent_id), + "event_type": StringValue.build("session_state"), + } + if self.config.metadata: + metadata.update(self.config.metadata) - Note: AgentCore Memory doesn't have explicit session creation, - so we just validate the session and return it. + event_params = self._prepare_event_params( + payload=[ + {"blob": json.dumps({"session_state": session_agent.to_dict()})} + ], + metadata=metadata, + branch=False, + ) - Args: - session (Session): The session to create. - **kwargs (Any): Additional keyword arguments. + main_event = self.memory_session_manager.create_event(**event_params) + root_event_id = main_event.get("eventId") + else: + root_event_id = main_events[0].get("eventId") - Returns: - Session: The created session. + self._branch_root_events[branch_name] = root_event_id + return root_event_id - Raises: - SessionException: If session ID doesn't match configuration. - """ - if session.session_id != self.config.session_id: - raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session.session_id}") - - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self._get_full_session_id(session.session_id), - sessionId=self.session_id, - payload=[ - {"blob": json.dumps(session.to_dict())}, - ], - eventTimestamp=self._get_monotonic_timestamp(), - ) - logger.info("Created session: %s with event: %s", session.session_id, event.get("event", {}).get("eventId")) - return session + return None - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: - """Read session data. + def create_or_fetch_agent_branch( + self, session_id: str, session_agent: SessionAgent + ) -> Any: + """Create event and update session cache.""" + self._validate_session_id(session_id) - AgentCore Memory does not have a `get_session` method. - Which is fine as AgentCore Memory is a managed service we therefore do not need to read/update - the session data. We just return the session object. + branch_name = self.config.default_branch.name - Args: - session_id (str): The session ID to read. - **kwargs (Any): Additional keyword arguments. + try: + # Get or create branch root + root_event_id = self._get_or_create_branch_root(branch_name, session_agent) + + # Create new session_state event + metadata = { + "agent_id": StringValue.build(session_agent.agent_id), + "event_type": StringValue.build("session_state"), + } + if self.config.metadata: + metadata.update(self.config.metadata) + + if branch_name == "main": + event_params = self._prepare_event_params( + payload=[ + {"blob": json.dumps({"session_state": session_agent.to_dict()})} + ], + metadata=metadata, + ) + created_event = self.memory_session_manager.create_event(**event_params) + else: + created_event = self.memory_session_manager.fork_conversation( + actor_id=self.config.actor_id, + session_id=self.config.session_id, + root_event_id=root_event_id, + branch_name=branch_name, + messages=[ + {"blob": json.dumps({"session_state": session_agent.to_dict()})} + ], + metadata=metadata, + event_timestamp=self._get_monotonic_timestamp(), + ) - Returns: - Optional[Session]: The session if found, None otherwise. - """ - if session_id != self.config.session_id: - return None + self._session_cache = session_agent + return created_event - events = self.memory_client.list_events( - memory_id=self.config.memory_id, - actor_id=self._get_full_session_id(session_id), - session_id=session_id, - max_results=1, - ) - if not events: - return None + except Exception as e: + logger.error("Failed to create or fetch agent branch: %s", e) + raise SessionException( + f"Failed to create or fetch agent branch: {e}" + ) from e - session_data = json.loads(events[0].get("payload", {})[0].get("blob")) - return Session.from_dict(session_data) + def create_session(self, session: Session, **kwargs: Any) -> Session: + """Create a new session.""" + self._validate_session_id(session.session_id) + return session - def delete_session(self, session_id: str, **kwargs: Any) -> None: - """Delete session and all associated data. + def update_agent( + self, session_id: str, session_agent: SessionAgent, **kwargs: Any + ) -> None: + """Update an existing agent.""" + self._validate_session_id(session_id) + self._session_cache = session_agent - Note: AgentCore Memory doesn't support deletion of events, - so this is a no-op operation. + def create_message( + self, + session_id: str, + agent_id: str, + session_message: SessionMessage, + **kwargs: Any, + ) -> None: + """Create a new message.""" + self._validate_session_id(session_id) - Args: - session_id (str): The session ID to delete. - **kwargs (Any): Additional keyword arguments. - """ - logger.warning("Session deletion not supported in AgentCore Memory: %s", session_id) + def update_message( + self, + session_id: str, + agent_id: str, + session_message: SessionMessage, + **kwargs: Any, + ) -> None: + """Update an existing message.""" + self._validate_session_id(session_id) - def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: - """Create a new agent in the session. + # region SessionRepository interface implementation + def create_agent( + self, session_id: str, session_agent: SessionAgent, **kwargs: Any + ) -> SessionAgent: + """Create a new session or get the existing session in AgentCore Memory.""" + logger.debug( + f"Creating agent: {session_agent.agent_id} in session: {session_id}" + ) + self._validate_session_id(session_id) + self.agent_id = session_agent.agent_id + try: + if self.config.default_branch.name == "main": + metadata = { + "agent_id": StringValue.build(session_agent.agent_id), + "event_type": StringValue.build("session_state"), + } + if self.config.metadata: + metadata.update(self.config.metadata) - For AgentCore Memory, we don't need to explicitly create agents; we have Implicit Agent Existence - The agent's existence is inferred from the presence of events/messages in the memory system, - but we validate the session_id matches our config. + # Prepare event parameters + event_params = self._prepare_event_params( + payload=[ + {"blob": json.dumps({"session": session_agent.to_dict()})} + ], + metadata=metadata, + ) + event = self.memory_session_manager.create_event(**event_params) + self._session_cache = session_agent + else: + event = self.create_or_fetch_agent_branch(session_id, session_agent) - Args: - session_id (str): The session ID to create the agent in. - session_agent (SessionAgent): The agent to create. - **kwargs (Any): Additional keyword arguments. + self._session_cache = session_agent + logger.info( + "Created session: %s with event: %s", + session_id, + event.get("eventId"), + ) + logger.debug( + f"Successfully created session with event: {event.get('eventId')}" + ) + return session_agent + except Exception as e: + logger.error("Failed to create session: %s", e) + raise SessionException(f"Failed to create session: {e}") from e - Raises: - SessionException: If session ID doesn't match configuration. - """ - if session_id != self.config.session_id: - raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") - - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self._get_full_agent_id(session_agent.agent_id), - sessionId=self.session_id, - payload=[ - {"blob": json.dumps(session_agent.to_dict())}, - ], - eventTimestamp=self._get_monotonic_timestamp(), - ) - logger.info( - "Created agent: %s in session: %s with event %s", - session_agent.agent_id, - session_id, - event.get("event", {}).get("eventId"), + def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + """Read session data - always from the main branch.""" + logger.debug(f"Reading session: {session_id}") + + # Return cached session if available + if ( + self._session_cache + and hasattr(self._session_cache, "session_id") + and self._session_cache.session_id == session_id + ): + logger.debug(f"Returning cached session: {session_id}") + return self._session_cache + + session = self.read_agent( + session_id, agent_id=self.agent_id if self.agent_id else "default" ) + return session - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: - """Read agent data from AgentCore Memory events. - - We reconstruct the agent state from the conversation history. - - Args: - session_id (str): The session ID to read from. - agent_id (str): The agent ID to read. - **kwargs (Any): Additional keyword arguments. + def read_agent( + self, session_id: str, agent_id: str, **kwargs: Any + ) -> Optional[SessionAgent]: + """Read agent data from AgentCore Memory events (uses branch filtering).""" + logger.debug(f"Reading agent: {agent_id} from session: {session_id}") + self._validate_session_id(session_id) + if self._session_cache: + logger.debug(f"Returning cached session: {session_id}") + return self._session_cache - Returns: - Optional[SessionAgent]: The agent if found, None otherwise. - """ - if session_id != self.config.session_id: - return None try: - events = self.memory_client.list_events( - memory_id=self.config.memory_id, - actor_id=self._get_full_agent_id(agent_id), - session_id=session_id, - max_results=1, + logger.debug(f"Building metadata filters for agent {agent_id}") + # Agent operations use branch filtering + filters = [ + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(key="event_type"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build("session_state"), + ) + ] + if ( + agent_id + and hasattr(self.config, "short_term_retrieval_config") + and hasattr(self.config.short_term_retrieval_config, "metadata") + and self.config.short_term_retrieval_config.metadata + and "agent_id" + in self.config.short_term_retrieval_config.metadata.keys() + ): + logger.debug( + f"Using metadata filter: {self.config.short_term_retrieval_config.metadata['agent_id']}" + ) + filters.append( + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(key="agent_id"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build( + self.config.short_term_retrieval_config.metadata["agent_id"] + ), + ) + ) + list_params = self._prepare_list_params( + eventMetadata=filters, max_results=1 ) + events = self.memory_session_manager.list_events(**list_params) + logger.debug(f"Found {len(events)} agent events for {agent_id}") + if not events: + logger.debug(f"No events found for agent {agent_id}") return None - agent_data = json.loads(events[0].get("payload", {})[0].get("blob")) - return SessionAgent.from_dict(agent_data) + payload = events[-1].get("payload", []) + if payload and "blob" in payload[0]: + blob_data = json.loads(payload[0]["blob"]) + # Extract SessionAgent from blob. + session_data = blob_data.get("session_state") or blob_data.get( + "session" + ) + agent_data = SessionAgent.from_dict(session_data) + logger.debug(f"Successfully read agent {agent_id}") + self._session_cache = agent_data + return agent_data + + logger.debug(f"Agent {agent_id} not found - no valid payload") + return None except Exception as e: - logger.error("Failed to read agent %s", e) + logger.error("Failed to read agent %s: %s", agent_id, e) return None - def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: - """Update agent data. - - Args: - session_id (str): The session ID containing the agent. - session_agent (SessionAgent): The agent to update. - **kwargs (Any): Additional keyword arguments. - - Raises: - SessionException: If session ID doesn't match configuration. - """ - agent_id = session_agent.agent_id - previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id) - if previous_agent is None: - raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") - - session_agent.created_at = previous_agent.created_at - # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent` - self.create_agent(session_id, session_agent) - - def create_message( - self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any - ) -> Optional[dict[str, Any]]: - """Create a new message in AgentCore Memory. - - Args: - session_id (str): The session ID to create the message in. - agent_id (str): The agent ID associated with the message (only here for the interface. - We use the actorId for AgentCore). - session_message (SessionMessage): The message to create. - **kwargs (Any): Additional keyword arguments. - - Returns: - Optional[dict[str, Any]]: The created event data from AgentCore Memory. - - Raises: - SessionException: If session ID doesn't match configuration or message creation fails. - - Note: - The returned created message `event` looks like: - ```python - { - "memoryId": "my-mem-id", - "actorId": "user_1", - "sessionId": "test_session_id", - "eventId": "0000001752235548000#97f30a6b", - "eventTimestamp": datetime.datetime(2025, 8, 18, 12, 45, 48, tzinfo=tzlocal()), - "branch": {"name": "main"}, - } - ``` - """ - if session_id != self.config.session_id: - raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") - + def read_message( + self, session_id: str, agent_id: str, message_id: int, **kwargs: Any + ) -> Optional[SessionMessage]: + """Read a specific message by ID from AgentCore Memory.""" try: - messages = AgentCoreMemoryConverter.message_to_payload(session_message) - if not messages: - return - - # Parse the original timestamp and use it as desired timestamp - original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) - monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) - - if not AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]): - event = self.memory_client.create_event( - memory_id=self.config.memory_id, - actor_id=self.config.actor_id, - session_id=session_id, - messages=messages, - event_timestamp=monotonic_timestamp, - ) - else: - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=session_id, - payload=[ - {"blob": json.dumps(messages[0])}, - ], - eventTimestamp=monotonic_timestamp, - ) - logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id) - return event + result = self.memory_session_manager.get_event( + actor_id=self.config.actor_id, + session_id=session_id, + event_id=str(message_id), + ) + return SessionMessage.from_dict(result) if result else None except Exception as e: - logger.error("Failed to create message in AgentCore Memory: %s", e) - raise SessionException(f"Failed to create message: {e}") from e - - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: - """Read a specific message by ID from AgentCore Memory. - - Args: - session_id (str): The session ID to read from. - agent_id (str): The agent ID associated with the message. - message_id (int): The message ID to read. - **kwargs (Any): Additional keyword arguments. - - Returns: - Optional[SessionMessage]: The message if found, None otherwise. + logger.error("Failed to read message: %s", e) + return None - Note: - This should not be called as (as of now) only the `update_message` method calls this method and - updating messages is not supported in AgentCore Memory. - """ - result = self.memory_client.gmdp_client.get_event( - memoryId=self.config.memory_id, actorId=self.config.actor_id, sessionId=session_id, eventId=message_id - ) - return SessionMessage.from_dict(result) if result else None + def save_turn_messages( + self, event: AfterInvocationEvent, **kwargs: Any + ) -> Optional[dict[str, Any]]: + """Save turn messages to MemorySessionManager with both SessionAgent and Session in blob.""" + try: + logger.debug(f"Saving turn messages for agent: {event.agent.agent_id}") + # Filter messages based on configured message types + filtered_messages = [] + role_map = { + "user": MessageRole.USER, + "assistant": MessageRole.ASSISTANT, + "tool": MessageRole.TOOL, + "other": MessageRole.OTHER, + } + + for message in reversed(event.agent.messages): + role = message.get("role") + if role in self.config.message_types: + content = message.get("content", [{}])[0].get("text", "") + mapped_role = role_map.get(role, MessageRole.ASSISTANT) + if role == "user": + content = self.remove_user_context(content) + filtered_messages.append( + ConversationalMessage(content, mapped_role) + ) + if role == "user": + break + + logger.debug(f"Filtered {len(filtered_messages)} messages to save..") + if not filtered_messages: + return None - def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None: - """Update message data. + # Prepare enhanced metadata using agent properties + event_metadata = { + "agent_id": StringValue.build(event.agent.agent_id), + } - Note: AgentCore Memory doesn't support updating events, - so this is primarily for validation and logging. + # Merge with config metadata + if self.config.metadata: + event_metadata.update(self.config.metadata) - Args: - session_id (str): The session ID containing the message. - agent_id (str): The agent ID associated with the message. - session_message (SessionMessage): The message to update. - **kwargs (Any): Additional keyword arguments. + # Prepare event parameters + event_params = self._prepare_event_params( + payload=filtered_messages, metadata=event_metadata + ) + # Add branch configuration if not "main" + event_params = { + "actor_id": self.config.actor_id, + "session_id": self.config.session_id, + "messages": filtered_messages, + "metadata": event_metadata, + "event_timestamp": self._get_monotonic_timestamp(), + } + if self.config.default_branch and self.config.default_branch.name != "main": + event_params["branch"] = ( + self.config.default_branch.to_agentcore_format() + ) - Raises: - SessionException: If session ID doesn't match configuration. - """ - if session_id != self.config.session_id: - raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") + return self.memory_session_manager.add_turns(**event_params) - logger.debug( - "Message update requested for message: %s (AgentCore Memory doesn't support updates)", - {session_message.message_id}, - ) + except Exception as e: + logger.error("Failed to save turn messages: %s", e) + raise SessionException(f"Failed to save turn messages: {e}") from e def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + self, + session_id: str, + agent_id: str, + limit: Optional[int] = None, + offset: int = 0, + **kwargs: Any, ) -> list[SessionMessage]: - """List messages for an agent from AgentCore Memory with pagination. + """List messages for an agent from AgentCore Memory with pagination.""" + self._validate_session_id(session_id) + logger.debug(f"Listing messages for agent: {agent_id}, limit: {limit}") - Args: - session_id (str): The session ID to list messages from. - agent_id (str): The agent ID to list messages for. - limit (Optional[int], optional): Maximum number of messages to return. Defaults to None. - offset (int, optional): Number of messages to skip. Defaults to 0. - **kwargs (Any): Additional keyword arguments. + try: + max_results = (limit + offset) if limit else 100 - Returns: - list[SessionMessage]: list of messages for the agent. + # Filter for non-session-state events (conversation messages) + message_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(key="event_type"), + operator=OperatorType.NOT_EXISTS, + ) - Raises: - SessionException: If session ID doesn't match configuration. - """ - if session_id != self.config.session_id: - raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") + filters = [message_filter] + + # Create metadata filter for agent_id if configured + if ( + hasattr(self.config, "retrieval_config") + and hasattr(self.config.short_term_retrieval_config, "metadata") + and self.config.short_term_retrieval_config.metadata + and "agent_id" + in self.config.short_term_retrieval_config.metadata.keys() + ): + agent_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(key="agent_id"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(agent_id), + ) + filters.append(agent_filter) - try: - max_results = (limit + offset) if limit else 100 - events = self.memory_client.list_events( - memory_id=self.config.memory_id, - actor_id=self.config.actor_id, - session_id=session_id, + list_params = self._prepare_list_params( max_results=max_results, + eventMetadata=filters, ) + + events = self.memory_session_manager.list_events(**list_params) + logger.debug(f"Found {len(events)} events") messages = AgentCoreMemoryConverter.events_to_messages(events) - if limit is not None: - return messages[offset : offset + limit] - else: - return messages[offset:] + return messages except Exception as e: logger.error("Failed to list messages from AgentCore Memory: %s", e) @@ -468,51 +575,76 @@ def list_messages( # endregion SessionRepository interface implementation # region RepositorySessionManager overrides + @staticmethod + def remove_user_context(text: str) -> str: + """Remove user context from text.""" + return re.sub( + r".*?", "", text, flags=re.DOTALL + ).strip() + @override - def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: - """Append a message to the agent's session using AgentCore's eventId as message_id. + def initialize(self, agent: "Agent", **kwargs: Any) -> None: + """Agent hydration with branch support for multiple actors.""" + try: + self.agent_id = agent.agent_id + # Use create_or_fetch_agent_branch for both main and custom branches + if not self.read_agent(self.config.session_id, agent.agent_id): + self.create_agent( + self.config.session_id, SessionAgent.from_agent(agent) + ) + # Set agent state from cached session + agent.state = AgentState( + self._session_cache.state if self._session_cache else {} + ) - Args: - message: Message to add to the agent in the session - agent: Agent to append the message to - **kwargs: Additional keyword arguments for future extensibility. - """ - created_message = self.create_message(self.session_id, agent.agent_id, SessionMessage.from_message(message, 0)) - session_message = SessionMessage.from_message(message, created_message.get("eventId")) - self._latest_agent_message[agent.agent_id] = session_message + # Load previous messages + prev_messages = self.list_messages( + self.config.session_id, agent.agent_id, limit=10 + ) + agent.messages = prev_messages if prev_messages else [] - def retrieve_customer_context(self, event: MessageAddedEvent) -> None: - """Retrieve customer LTM context before processing support query. + except Exception as e: + logger.error("Failed to initialize agent %s: %s", agent.agent_id, e) + raise - Args: - event (MessageAddedEvent): The message added event containing the agent and message data. - """ + def retrieve_customer_context(self, event: MessageAddedEvent) -> None: + """Retrieve customer context from both short-term and long-term memory.""" messages = event.agent.messages - if not messages or messages[-1].get("role") != "user" or "toolResult" in messages[-1].get("content")[0]: - return None - if not self.config.retrieval_config: - # Only retrieve LTM - return None + if not messages or messages[-1].get("role") != "user": + return + + logger.debug(f"Retrieving customer context for agent:") + # Skip if message contains tool results + last_content = messages[-1].get("content", []) + if last_content and any( + "toolResult" in str(content) for content in last_content + ): + return user_query = messages[-1]["content"][0]["text"] - def retrieve_for_namespace(namespace: str, retrieval_config: AgentCoreMemoryConfig): + def retrieve_for_namespace( + namespace: str, retrieval_config: AgentCoreMemoryConfig + ): """Helper function to retrieve memories for a single namespace.""" resolved_namespace = namespace.format( actorId=self.config.actor_id, sessionId=self.config.session_id, memoryStrategyId=retrieval_config.strategy_id or "", ) - - memories = self.memory_client.retrieve_memories( - memory_id=self.config.memory_id, - namespace=resolved_namespace, + memories = self.memory_session_manager.search_long_term_memories( query=user_query, + namespace_prefix=resolved_namespace, top_k=retrieval_config.top_k, ) + context_items = [] for memory in memories: - if isinstance(memory, dict): + if hasattr(memory, "content") and hasattr(memory.content, "text"): + text = memory.content.text.strip() + if text: + all_context.append(text) + elif isinstance(memory, dict): content = memory.get("content", {}) if isinstance(content, dict): text = content.get("text", "").strip() @@ -521,12 +653,13 @@ def retrieve_for_namespace(namespace: str, retrieval_config: AgentCoreMemoryConf return context_items try: - # Retrieve customer context from all namespaces in parallel all_context = [] - + # Retrieve from long-term memory in parallel with ThreadPoolExecutor() as executor: future_to_namespace = { - executor.submit(retrieve_for_namespace, namespace, retrieval_config): namespace + executor.submit( + retrieve_for_namespace, namespace, retrieval_config + ): namespace for namespace, retrieval_config in self.config.retrieval_config.items() } for future in as_completed(future_to_namespace): @@ -536,40 +669,41 @@ def retrieve_for_namespace(namespace: str, retrieval_config: AgentCoreMemoryConf except Exception as e: # Continue processing other futures event if one fails rather than failing the entire operation namespace = future_to_namespace[future] - logger.error("Failed to retrieve memories for namespace %s: %s", namespace, e) + logger.error( + "Failed to retrieve memories for namespace %s: %s", + namespace, + e, + ) - # Inject customer context into the query if all_context: + original_text = messages[-1]["content"][0]["text"] context_text = "\n".join(all_context) - ltm_msg: Message = { - "role": "assistant", - "content": [{"text": f"{context_text}"}], - } - event.agent.messages.append(ltm_msg) - logger.info("Retrieved %s customer context items", len(all_context)) + messages[-1]["content"][0][ + "text" + ] = f"\n{context_text}\n\n\n{original_text}\n" + event.agent.messages[-1]["content"][0]["text"] = context_text + logger.info("Retrieved %d customer context items", len(all_context)) except Exception as e: logger.error("Failed to retrieve customer context: %s", e) @override def register_hooks(self, registry: HookRegistry, **kwargs) -> None: - """Register additional hooks. + """Register hooks. Args: registry (HookRegistry): The hook registry to register callbacks with. **kwargs: Additional keyword arguments. """ - RepositorySessionManager.register_hooks(self, registry, **kwargs) - registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) - - @override - def initialize(self, agent: "Agent", **kwargs: Any) -> None: - if self.has_existing_agent: - logger.warning( - "An Agent already exists in session %s. We currently support one agent per session.", self.session_id - ) - else: - self.has_existing_agent = True - RepositorySessionManager.initialize(self, agent, **kwargs) - - # endregion RepositorySessionManager overrides + # After the normal Agent initialization behavior, call the session initialize function to restore the agent + registry.add_callback( + AgentInitializedEvent, lambda event: self.initialize(event.agent) + ) + # For each message appended to the Agents messages, store that message in the session + # After an agent was invoked, sync it with the session to capture any conversation manager state updates + registry.add_callback( + AfterInvocationEvent, lambda event: self.save_turn_messages(event) + ) + registry.add_callback( + MessageAddedEvent, lambda event: self.retrieve_customer_context(event) + ) From 58ed5f791b664055744dd21ff299be8871bfaaf0 Mon Sep 17 00:00:00 2001 From: Akarsha Sehwag Date: Wed, 26 Nov 2025 14:38:39 -0500 Subject: [PATCH 2/3] fix(session_manager): update session_manager --- .../integrations/strands/session_manager.py | 79 +++++++------------ 1 file changed, 28 insertions(+), 51 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 39fa716..354b0b7 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -204,14 +204,14 @@ def _get_or_create_branch_root( # Create event in main first metadata = { "agent_id": StringValue.build(session_agent.agent_id), - "event_type": StringValue.build("session_state"), + "event_type": StringValue.build("session_agent"), } if self.config.metadata: metadata.update(self.config.metadata) event_params = self._prepare_event_params( payload=[ - {"blob": json.dumps({"session_state": session_agent.to_dict()})} + {"blob": json.dumps({"session_agent": session_agent.to_dict()})} ], metadata=metadata, branch=False, @@ -239,10 +239,10 @@ def create_or_fetch_agent_branch( # Get or create branch root root_event_id = self._get_or_create_branch_root(branch_name, session_agent) - # Create new session_state event + # Create new session_agent event metadata = { "agent_id": StringValue.build(session_agent.agent_id), - "event_type": StringValue.build("session_state"), + "event_type": StringValue.build("session_agent"), } if self.config.metadata: metadata.update(self.config.metadata) @@ -250,7 +250,7 @@ def create_or_fetch_agent_branch( if branch_name == "main": event_params = self._prepare_event_params( payload=[ - {"blob": json.dumps({"session_state": session_agent.to_dict()})} + {"blob": json.dumps({"session_agent": session_agent.to_dict()})} ], metadata=metadata, ) @@ -262,7 +262,7 @@ def create_or_fetch_agent_branch( root_event_id=root_event_id, branch_name=branch_name, messages=[ - {"blob": json.dumps({"session_state": session_agent.to_dict()})} + {"blob": json.dumps({"session_agent": session_agent.to_dict()})} ], metadata=metadata, event_timestamp=self._get_monotonic_timestamp(), @@ -323,7 +323,7 @@ def create_agent( if self.config.default_branch.name == "main": metadata = { "agent_id": StringValue.build(session_agent.agent_id), - "event_type": StringValue.build("session_state"), + "event_type": StringValue.build("session_agent"), } if self.config.metadata: metadata.update(self.config.metadata) @@ -331,7 +331,7 @@ def create_agent( # Prepare event parameters event_params = self._prepare_event_params( payload=[ - {"blob": json.dumps({"session": session_agent.to_dict()})} + {"blob": json.dumps({"session_agent": session_agent.to_dict()})} ], metadata=metadata, ) @@ -389,29 +389,14 @@ def read_agent( EventMetadataFilter.build_expression( left_operand=LeftExpression.build(key="event_type"), operator=OperatorType.EQUALS_TO, - right_operand=RightExpression.build("session_state"), - ) + right_operand=RightExpression.build("session_agent"), + ), + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(key="agent_id"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(agent_id), + ), ] - if ( - agent_id - and hasattr(self.config, "short_term_retrieval_config") - and hasattr(self.config.short_term_retrieval_config, "metadata") - and self.config.short_term_retrieval_config.metadata - and "agent_id" - in self.config.short_term_retrieval_config.metadata.keys() - ): - logger.debug( - f"Using metadata filter: {self.config.short_term_retrieval_config.metadata['agent_id']}" - ) - filters.append( - EventMetadataFilter.build_expression( - left_operand=LeftExpression.build(key="agent_id"), - operator=OperatorType.EQUALS_TO, - right_operand=RightExpression.build( - self.config.short_term_retrieval_config.metadata["agent_id"] - ), - ) - ) list_params = self._prepare_list_params( eventMetadata=filters, max_results=1 ) @@ -427,9 +412,7 @@ def read_agent( if payload and "blob" in payload[0]: blob_data = json.loads(payload[0]["blob"]) # Extract SessionAgent from blob. - session_data = blob_data.get("session_state") or blob_data.get( - "session" - ) + session_data = blob_data.get("session_agent") agent_data = SessionAgent.from_dict(session_data) logger.debug(f"Successfully read agent {agent_id}") self._session_cache = agent_data @@ -541,22 +524,13 @@ def list_messages( operator=OperatorType.NOT_EXISTS, ) - filters = [message_filter] - - # Create metadata filter for agent_id if configured - if ( - hasattr(self.config, "retrieval_config") - and hasattr(self.config.short_term_retrieval_config, "metadata") - and self.config.short_term_retrieval_config.metadata - and "agent_id" - in self.config.short_term_retrieval_config.metadata.keys() - ): - agent_filter = EventMetadataFilter.build_expression( - left_operand=LeftExpression.build(key="agent_id"), - operator=OperatorType.EQUALS_TO, - right_operand=RightExpression.build(agent_id), - ) - filters.append(agent_filter) + # Always filter by agent_id to prevent cross-pollution + agent_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(key="agent_id"), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(agent_id), + ) + filters = [message_filter, agent_filter] list_params = self._prepare_list_params( max_results=max_results, @@ -565,8 +539,7 @@ def list_messages( events = self.memory_session_manager.list_events(**list_params) logger.debug(f"Found {len(events)} events") - messages = AgentCoreMemoryConverter.events_to_messages(events) - return messages + return AgentCoreMemoryConverter.events_to_messages(events) if events else [] except Exception as e: logger.error("Failed to list messages from AgentCore Memory: %s", e) @@ -654,6 +627,10 @@ def retrieve_for_namespace( try: all_context = [] + # Check if retrieval_config exists and is not empty + if not self.config.retrieval_config: + return + # Retrieve from long-term memory in parallel with ThreadPoolExecutor() as executor: future_to_namespace = { From ec7eea3b71a49f0fed657bc34739633b43df3320 Mon Sep 17 00:00:00 2001 From: Akarsha Sehwag Date: Fri, 28 Nov 2025 15:12:58 -0500 Subject: [PATCH 3/3] feat(session_manager): update create_session functionality --- .../integrations/strands/session_manager.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 354b0b7..c8eb042 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -280,6 +280,24 @@ def create_or_fetch_agent_branch( def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session.""" self._validate_session_id(session.session_id) + + try: + metadata = { + "event_type": StringValue.build("session"), + } + if self.config.metadata: + metadata.update(self.config.metadata) + + event_params = self._prepare_event_params( + payload=[{"blob": json.dumps({"session": session.to_dict()})}], + metadata=metadata, + ) + + self.memory_session_manager.create_event(**event_params) + logger.info(f"Created session: {session.session_id}") + except Exception as e: + logger.error(f"Failed to create session: {e}") + raise SessionException(f"Failed to create session: {e}") from e return session def update_agent(