diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py
index d2d5cef..7265e8e 100644
--- a/src/bedrock_agentcore/memory/integrations/strands/config.py
+++ b/src/bedrock_agentcore/memory/integrations/strands/config.py
@@ -1,8 +1,27 @@
"""Configuration for AgentCore Memory Session Manager."""
-from typing import Dict, Optional
+from typing import Dict, List, Optional
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
+
+from bedrock_agentcore.memory.constants import MessageRole
+from bedrock_agentcore.memory.models import StringValue
+
+
+class BranchConfig(BaseModel):
+ """Configuration for AgentCore Memory branching.
+
+ Attributes:
+ name: Descriptive name for the branch
+ root_event_id: ID of the event from which this branch originates
+ """
+
+ name: str = Field(min_length=1)
+ root_event_id: Optional[str] = ""
+
+ def to_agentcore_format(self) -> dict:
+ """Convert to AgentCore Memory API format."""
+ return {"name": self.name, "rootEventId": self.root_event_id}
class RetrievalConfig(BaseModel):
@@ -21,6 +40,13 @@ class RetrievalConfig(BaseModel):
initialization_query: Optional[str] = None
+class ShortTermRetrievalConfig(BaseModel):
+ """Configuration for Short term memory retrieval operations"""
+
+ branch_filter: Optional[bool] = True
+ metadata: Optional[Dict[str, StringValue]] = None
+
+
class AgentCoreMemoryConfig(BaseModel):
"""Configuration for AgentCore Memory Session Manager.
@@ -29,9 +55,29 @@ class AgentCoreMemoryConfig(BaseModel):
session_id: Required unique ID for the session
actor_id: Required unique ID for the agent instance/user
retrieval_config: Optional dictionary mapping namespaces to retrieval configurations
+ default_branch: Optional default branch configuration for the session
+ message_types: Optional list of message types to filter
+ metadata: Optional dictionary of metadata to include with events
"""
memory_id: str = Field(min_length=1)
session_id: str = Field(min_length=1)
actor_id: str = Field(min_length=1)
retrieval_config: Optional[Dict[str, RetrievalConfig]] = None
+ default_branch: Optional[BranchConfig] = Field(
+ default_factory=lambda: BranchConfig(name="main", root_event_id="")
+ )
+ short_term_retrieval_config: Optional[ShortTermRetrievalConfig] = (
+ ShortTermRetrievalConfig()
+ )
+ message_types: Optional[List[str]] = Field(default=["user", "assistant"])
+ metadata: Optional[Dict[str, StringValue]] = (
+ None # Currently only supports agent_id. Will be extended further.
+ )
+
+ @field_validator("memory_id", "session_id", "actor_id")
+ @classmethod
+ def validate_non_empty_strings(cls, v: str) -> str:
+ if not v or not v.strip():
+ raise ValueError("must be a non-empty string")
+ return v
diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py
index d77db53..c8eb042 100644
--- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py
+++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py
@@ -2,6 +2,7 @@
import json
import logging
+import re
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
@@ -9,7 +10,8 @@
import boto3
from botocore.config import Config as BotocoreConfig
-from strands.hooks import MessageAddedEvent
+from strands.agent.state import AgentState
+from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
from strands.hooks.registry import HookRegistry
from strands.session.repository_session_manager import RepositorySessionManager
from strands.session.session_repository import SessionRepository
@@ -18,7 +20,15 @@
from strands.types.session import Session, SessionAgent, SessionMessage
from typing_extensions import override
-from bedrock_agentcore.memory.client import MemoryClient
+from bedrock_agentcore.memory.constants import ConversationalMessage, MessageRole
+from bedrock_agentcore.memory.models import (
+ EventMetadataFilter,
+ LeftExpression,
+ OperatorType,
+ RightExpression,
+ StringValue,
+)
+from bedrock_agentcore.memory.session import MemorySessionManager
from .bedrock_converter import AgentCoreMemoryConverter
from .config import AgentCoreMemoryConfig
@@ -28,10 +38,6 @@
logger = logging.getLogger(__name__)
-SESSION_PREFIX = "session_"
-AGENT_PREFIX = "agent_"
-MESSAGE_PREFIX = "message_"
-
class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository):
"""AgentCore Memory-based session manager for Bedrock AgentCore Memory integration.
@@ -52,8 +58,62 @@ class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository)
_timestamp_lock = threading.Lock()
_last_timestamp: Optional[datetime] = None
+ def _validate_session_id(self, session_id: str) -> None:
+ """Validate session ID matches configuration."""
+ if session_id != self.config.session_id:
+ raise SessionException(
+ f"Session ID mismatch: expected {self.config.session_id}, got {session_id}"
+ )
+
+ def _prepare_event_params(
+ self,
+ payload: Any,
+ metadata: Optional[dict] = None,
+ branch=True,
+ ) -> dict:
+ """Prepare common event parameters."""
+ event_params = {
+ "memoryId": self.config.memory_id,
+ "actorId": self.config.actor_id,
+ "sessionId": self.config.session_id,
+ "payload": payload,
+ "eventTimestamp": self._get_monotonic_timestamp(),
+ }
+
+ if metadata:
+ event_params["metadata"] = metadata
+
+ if (
+ branch
+ and self.config.default_branch
+ and self.config.default_branch.name != "main"
+ ):
+ event_params["branch"] = self.config.default_branch.to_agentcore_format()
+
+ return event_params
+
+ def _prepare_list_params(self, branch=True, **kwargs) -> dict:
+ """Prepare common list_events parameters."""
+ params = {
+ "actor_id": self.config.actor_id,
+ "session_id": self.config.session_id,
+ **kwargs,
+ }
+
+ if (
+ self.config.default_branch
+ and branch
+ and self.config.short_term_retrieval_config.branch_filter
+ and self.config.default_branch.name != "main"
+ ):
+ params["branch_name"] = self.config.default_branch.name
+
+ return params
+
@classmethod
- def _get_monotonic_timestamp(cls, desired_timestamp: Optional[datetime] = None) -> datetime:
+ def _get_monotonic_timestamp(
+ cls, desired_timestamp: Optional[datetime] = None
+ ) -> datetime:
"""Get a monotonically increasing timestamp.
Args:
@@ -98,368 +158,406 @@ def __init__(
Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
- self.config = agentcore_memory_config
- self.memory_client = MemoryClient(region_name=region_name)
- session = boto_session or boto3.Session(region_name=region_name)
- self.has_existing_agent = False
-
- # Override the clients if custom boto session or config is provided
- # Add strands-agents to the request user agent
- if boto_client_config:
- existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
- if existing_user_agent:
- new_user_agent = f"{existing_user_agent} strands-agents"
- else:
- new_user_agent = "strands-agents"
- client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
- else:
- client_config = BotocoreConfig(user_agent_extra="strands-agents")
-
- # Override the memory client's boto3 clients
- self.memory_client.gmcp_client = session.client(
- "bedrock-agentcore-control", region_name=region_name or session.region_name, config=client_config
- )
- self.memory_client.gmdp_client = session.client(
- "bedrock-agentcore", region_name=region_name or session.region_name, config=client_config
- )
- super().__init__(session_id=self.config.session_id, session_repository=self)
-
- def _get_full_session_id(self, session_id: str) -> str:
- """Get the full session ID with the configured prefix.
- Args:
- session_id (str): The session ID.
-
- Returns:
- str: The full session ID with the prefix.
- """
- full_session_id = f"{SESSION_PREFIX}{session_id}"
- if full_session_id == self.config.actor_id:
+ self.config = agentcore_memory_config
+ self._session_cache: Optional[SessionAgent] = None
+ self._branch_root_events: dict[str, str] = {}
+ # Validate session_id length
+ if len(self.config.session_id) < 33:
raise SessionException(
- f"Cannot have session [ {full_session_id} ] with the same ID as the actor ID: {self.config.actor_id}"
+ f"Session ID must be at least 33 characters long to ensure uniqueness: {self.config.session_id}"
)
- return full_session_id
-
- def _get_full_agent_id(self, agent_id: str) -> str:
- """Get the full agent ID with the configured prefix.
- Args:
- agent_id (str): The agent ID.
+ # Initialize the new MemorySessionManager
+ self.memory_session_manager = MemorySessionManager(
+ memory_id=self.config.memory_id,
+ region_name=region_name,
+ boto3_session=boto_session,
+ boto_client_config=boto_client_config,
+ )
+ self.agent_id = None
+
+ def _get_or_create_branch_root(
+ self, branch_name: str, session_agent: SessionAgent
+ ) -> str:
+ """Get branch root event ID and cache it."""
+ if self._branch_root_events.get(branch_name):
+ return self._branch_root_events[branch_name]
+ if branch_name == "main":
+ return None
+ # Check if branch exists
+ list_params = self._prepare_list_params(max_results=1)
+ branch_events = self.memory_session_manager.list_events(**list_params)
+
+ if branch_events:
+ root_event_id = branch_events[0].get("eventId")
+ self._branch_root_events[branch_name] = root_event_id
+ return root_event_id
+
+ # Branch doesn't exist - get main branch root
+ if branch_name != "main":
+ main_params = self._prepare_list_params(max_results=1, branch=False)
+
+ main_events = self.memory_session_manager.list_events(**main_params)
+
+ if not main_events:
+ # Create event in main first
+ metadata = {
+ "agent_id": StringValue.build(session_agent.agent_id),
+ "event_type": StringValue.build("session_agent"),
+ }
+ if self.config.metadata:
+ metadata.update(self.config.metadata)
- Returns:
- str: The full agent ID with the prefix.
- """
- full_agent_id = f"{AGENT_PREFIX}{agent_id}"
- if full_agent_id == self.config.actor_id:
- raise SessionException(
- f"Cannot create agent [ {full_agent_id} ] with the same ID as the actor ID: {self.config.actor_id}"
- )
- return full_agent_id
+ event_params = self._prepare_event_params(
+ payload=[
+ {"blob": json.dumps({"session_agent": session_agent.to_dict()})}
+ ],
+ metadata=metadata,
+ branch=False,
+ )
- # region SessionRepository interface implementation
- def create_session(self, session: Session, **kwargs: Any) -> Session:
- """Create a new session in AgentCore Memory.
+ main_event = self.memory_session_manager.create_event(**event_params)
+ root_event_id = main_event.get("eventId")
+ else:
+ root_event_id = main_events[0].get("eventId")
- Note: AgentCore Memory doesn't have explicit session creation,
- so we just validate the session and return it.
+ self._branch_root_events[branch_name] = root_event_id
+ return root_event_id
- Args:
- session (Session): The session to create.
- **kwargs (Any): Additional keyword arguments.
+ return None
- Returns:
- Session: The created session.
+ def create_or_fetch_agent_branch(
+ self, session_id: str, session_agent: SessionAgent
+ ) -> Any:
+ """Create event and update session cache."""
+ self._validate_session_id(session_id)
- Raises:
- SessionException: If session ID doesn't match configuration.
- """
- if session.session_id != self.config.session_id:
- raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session.session_id}")
-
- event = self.memory_client.gmdp_client.create_event(
- memoryId=self.config.memory_id,
- actorId=self._get_full_session_id(session.session_id),
- sessionId=self.session_id,
- payload=[
- {"blob": json.dumps(session.to_dict())},
- ],
- eventTimestamp=self._get_monotonic_timestamp(),
- )
- logger.info("Created session: %s with event: %s", session.session_id, event.get("event", {}).get("eventId"))
- return session
+ branch_name = self.config.default_branch.name
- def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]:
- """Read session data.
+ try:
+ # Get or create branch root
+ root_event_id = self._get_or_create_branch_root(branch_name, session_agent)
+
+ # Create new session_agent event
+ metadata = {
+ "agent_id": StringValue.build(session_agent.agent_id),
+ "event_type": StringValue.build("session_agent"),
+ }
+ if self.config.metadata:
+ metadata.update(self.config.metadata)
+
+ if branch_name == "main":
+ event_params = self._prepare_event_params(
+ payload=[
+ {"blob": json.dumps({"session_agent": session_agent.to_dict()})}
+ ],
+ metadata=metadata,
+ )
+ created_event = self.memory_session_manager.create_event(**event_params)
+ else:
+ created_event = self.memory_session_manager.fork_conversation(
+ actor_id=self.config.actor_id,
+ session_id=self.config.session_id,
+ root_event_id=root_event_id,
+ branch_name=branch_name,
+ messages=[
+ {"blob": json.dumps({"session_agent": session_agent.to_dict()})}
+ ],
+ metadata=metadata,
+ event_timestamp=self._get_monotonic_timestamp(),
+ )
- AgentCore Memory does not have a `get_session` method.
- Which is fine as AgentCore Memory is a managed service we therefore do not need to read/update
- the session data. We just return the session object.
+ self._session_cache = session_agent
+ return created_event
- Args:
- session_id (str): The session ID to read.
- **kwargs (Any): Additional keyword arguments.
+ except Exception as e:
+ logger.error("Failed to create or fetch agent branch: %s", e)
+ raise SessionException(
+ f"Failed to create or fetch agent branch: {e}"
+ ) from e
- Returns:
- Optional[Session]: The session if found, None otherwise.
- """
- if session_id != self.config.session_id:
- return None
+ def create_session(self, session: Session, **kwargs: Any) -> Session:
+ """Create a new session."""
+ self._validate_session_id(session.session_id)
- events = self.memory_client.list_events(
- memory_id=self.config.memory_id,
- actor_id=self._get_full_session_id(session_id),
- session_id=session_id,
- max_results=1,
- )
- if not events:
- return None
+ try:
+ metadata = {
+ "event_type": StringValue.build("session"),
+ }
+ if self.config.metadata:
+ metadata.update(self.config.metadata)
+
+ event_params = self._prepare_event_params(
+ payload=[{"blob": json.dumps({"session": session.to_dict()})}],
+ metadata=metadata,
+ )
- session_data = json.loads(events[0].get("payload", {})[0].get("blob"))
- return Session.from_dict(session_data)
+ self.memory_session_manager.create_event(**event_params)
+ logger.info(f"Created session: {session.session_id}")
+ except Exception as e:
+ logger.error(f"Failed to create session: {e}")
+ raise SessionException(f"Failed to create session: {e}") from e
+ return session
- def delete_session(self, session_id: str, **kwargs: Any) -> None:
- """Delete session and all associated data.
+ def update_agent(
+ self, session_id: str, session_agent: SessionAgent, **kwargs: Any
+ ) -> None:
+ """Update an existing agent."""
+ self._validate_session_id(session_id)
+ self._session_cache = session_agent
- Note: AgentCore Memory doesn't support deletion of events,
- so this is a no-op operation.
+ def create_message(
+ self,
+ session_id: str,
+ agent_id: str,
+ session_message: SessionMessage,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new message."""
+ self._validate_session_id(session_id)
- Args:
- session_id (str): The session ID to delete.
- **kwargs (Any): Additional keyword arguments.
- """
- logger.warning("Session deletion not supported in AgentCore Memory: %s", session_id)
+ def update_message(
+ self,
+ session_id: str,
+ agent_id: str,
+ session_message: SessionMessage,
+ **kwargs: Any,
+ ) -> None:
+ """Update an existing message."""
+ self._validate_session_id(session_id)
- def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
- """Create a new agent in the session.
+ # region SessionRepository interface implementation
+ def create_agent(
+ self, session_id: str, session_agent: SessionAgent, **kwargs: Any
+ ) -> SessionAgent:
+ """Create a new session or get the existing session in AgentCore Memory."""
+ logger.debug(
+ f"Creating agent: {session_agent.agent_id} in session: {session_id}"
+ )
+ self._validate_session_id(session_id)
+ self.agent_id = session_agent.agent_id
+ try:
+ if self.config.default_branch.name == "main":
+ metadata = {
+ "agent_id": StringValue.build(session_agent.agent_id),
+ "event_type": StringValue.build("session_agent"),
+ }
+ if self.config.metadata:
+ metadata.update(self.config.metadata)
- For AgentCore Memory, we don't need to explicitly create agents; we have Implicit Agent Existence
- The agent's existence is inferred from the presence of events/messages in the memory system,
- but we validate the session_id matches our config.
+ # Prepare event parameters
+ event_params = self._prepare_event_params(
+ payload=[
+ {"blob": json.dumps({"session_agent": session_agent.to_dict()})}
+ ],
+ metadata=metadata,
+ )
+ event = self.memory_session_manager.create_event(**event_params)
+ self._session_cache = session_agent
+ else:
+ event = self.create_or_fetch_agent_branch(session_id, session_agent)
- Args:
- session_id (str): The session ID to create the agent in.
- session_agent (SessionAgent): The agent to create.
- **kwargs (Any): Additional keyword arguments.
+ self._session_cache = session_agent
+ logger.info(
+ "Created session: %s with event: %s",
+ session_id,
+ event.get("eventId"),
+ )
+ logger.debug(
+ f"Successfully created session with event: {event.get('eventId')}"
+ )
+ return session_agent
+ except Exception as e:
+ logger.error("Failed to create session: %s", e)
+ raise SessionException(f"Failed to create session: {e}") from e
- Raises:
- SessionException: If session ID doesn't match configuration.
- """
- if session_id != self.config.session_id:
- raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")
-
- event = self.memory_client.gmdp_client.create_event(
- memoryId=self.config.memory_id,
- actorId=self._get_full_agent_id(session_agent.agent_id),
- sessionId=self.session_id,
- payload=[
- {"blob": json.dumps(session_agent.to_dict())},
- ],
- eventTimestamp=self._get_monotonic_timestamp(),
- )
- logger.info(
- "Created agent: %s in session: %s with event %s",
- session_agent.agent_id,
- session_id,
- event.get("event", {}).get("eventId"),
+ def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]:
+ """Read session data - always from the main branch."""
+ logger.debug(f"Reading session: {session_id}")
+
+ # Return cached session if available
+ if (
+ self._session_cache
+ and hasattr(self._session_cache, "session_id")
+ and self._session_cache.session_id == session_id
+ ):
+ logger.debug(f"Returning cached session: {session_id}")
+ return self._session_cache
+
+ session = self.read_agent(
+ session_id, agent_id=self.agent_id if self.agent_id else "default"
)
+ return session
- def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]:
- """Read agent data from AgentCore Memory events.
-
- We reconstruct the agent state from the conversation history.
-
- Args:
- session_id (str): The session ID to read from.
- agent_id (str): The agent ID to read.
- **kwargs (Any): Additional keyword arguments.
+ def read_agent(
+ self, session_id: str, agent_id: str, **kwargs: Any
+ ) -> Optional[SessionAgent]:
+ """Read agent data from AgentCore Memory events (uses branch filtering)."""
+ logger.debug(f"Reading agent: {agent_id} from session: {session_id}")
+ self._validate_session_id(session_id)
+ if self._session_cache:
+ logger.debug(f"Returning cached session: {session_id}")
+ return self._session_cache
- Returns:
- Optional[SessionAgent]: The agent if found, None otherwise.
- """
- if session_id != self.config.session_id:
- return None
try:
- events = self.memory_client.list_events(
- memory_id=self.config.memory_id,
- actor_id=self._get_full_agent_id(agent_id),
- session_id=session_id,
- max_results=1,
+ logger.debug(f"Building metadata filters for agent {agent_id}")
+ # Agent operations use branch filtering
+ filters = [
+ EventMetadataFilter.build_expression(
+ left_operand=LeftExpression.build(key="event_type"),
+ operator=OperatorType.EQUALS_TO,
+ right_operand=RightExpression.build("session_agent"),
+ ),
+ EventMetadataFilter.build_expression(
+ left_operand=LeftExpression.build(key="agent_id"),
+ operator=OperatorType.EQUALS_TO,
+ right_operand=RightExpression.build(agent_id),
+ ),
+ ]
+ list_params = self._prepare_list_params(
+ eventMetadata=filters, max_results=1
)
+ events = self.memory_session_manager.list_events(**list_params)
+ logger.debug(f"Found {len(events)} agent events for {agent_id}")
+
if not events:
+ logger.debug(f"No events found for agent {agent_id}")
return None
- agent_data = json.loads(events[0].get("payload", {})[0].get("blob"))
- return SessionAgent.from_dict(agent_data)
+ payload = events[-1].get("payload", [])
+ if payload and "blob" in payload[0]:
+ blob_data = json.loads(payload[0]["blob"])
+ # Extract SessionAgent from blob.
+ session_data = blob_data.get("session_agent")
+ agent_data = SessionAgent.from_dict(session_data)
+ logger.debug(f"Successfully read agent {agent_id}")
+ self._session_cache = agent_data
+ return agent_data
+
+ logger.debug(f"Agent {agent_id} not found - no valid payload")
+ return None
except Exception as e:
- logger.error("Failed to read agent %s", e)
+ logger.error("Failed to read agent %s: %s", agent_id, e)
return None
- def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
- """Update agent data.
-
- Args:
- session_id (str): The session ID containing the agent.
- session_agent (SessionAgent): The agent to update.
- **kwargs (Any): Additional keyword arguments.
-
- Raises:
- SessionException: If session ID doesn't match configuration.
- """
- agent_id = session_agent.agent_id
- previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id)
- if previous_agent is None:
- raise SessionException(f"Agent {agent_id} in session {session_id} does not exist")
-
- session_agent.created_at = previous_agent.created_at
- # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
- self.create_agent(session_id, session_agent)
-
- def create_message(
- self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any
- ) -> Optional[dict[str, Any]]:
- """Create a new message in AgentCore Memory.
-
- Args:
- session_id (str): The session ID to create the message in.
- agent_id (str): The agent ID associated with the message (only here for the interface.
- We use the actorId for AgentCore).
- session_message (SessionMessage): The message to create.
- **kwargs (Any): Additional keyword arguments.
-
- Returns:
- Optional[dict[str, Any]]: The created event data from AgentCore Memory.
-
- Raises:
- SessionException: If session ID doesn't match configuration or message creation fails.
-
- Note:
- The returned created message `event` looks like:
- ```python
- {
- "memoryId": "my-mem-id",
- "actorId": "user_1",
- "sessionId": "test_session_id",
- "eventId": "0000001752235548000#97f30a6b",
- "eventTimestamp": datetime.datetime(2025, 8, 18, 12, 45, 48, tzinfo=tzlocal()),
- "branch": {"name": "main"},
- }
- ```
- """
- if session_id != self.config.session_id:
- raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")
-
+ def read_message(
+ self, session_id: str, agent_id: str, message_id: int, **kwargs: Any
+ ) -> Optional[SessionMessage]:
+ """Read a specific message by ID from AgentCore Memory."""
try:
- messages = AgentCoreMemoryConverter.message_to_payload(session_message)
- if not messages:
- return
-
- # Parse the original timestamp and use it as desired timestamp
- original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00"))
- monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp)
-
- if not AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]):
- event = self.memory_client.create_event(
- memory_id=self.config.memory_id,
- actor_id=self.config.actor_id,
- session_id=session_id,
- messages=messages,
- event_timestamp=monotonic_timestamp,
- )
- else:
- event = self.memory_client.gmdp_client.create_event(
- memoryId=self.config.memory_id,
- actorId=self.config.actor_id,
- sessionId=session_id,
- payload=[
- {"blob": json.dumps(messages[0])},
- ],
- eventTimestamp=monotonic_timestamp,
- )
- logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id)
- return event
+ result = self.memory_session_manager.get_event(
+ actor_id=self.config.actor_id,
+ session_id=session_id,
+ event_id=str(message_id),
+ )
+ return SessionMessage.from_dict(result) if result else None
except Exception as e:
- logger.error("Failed to create message in AgentCore Memory: %s", e)
- raise SessionException(f"Failed to create message: {e}") from e
-
- def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]:
- """Read a specific message by ID from AgentCore Memory.
-
- Args:
- session_id (str): The session ID to read from.
- agent_id (str): The agent ID associated with the message.
- message_id (int): The message ID to read.
- **kwargs (Any): Additional keyword arguments.
+ logger.error("Failed to read message: %s", e)
+ return None
- Returns:
- Optional[SessionMessage]: The message if found, None otherwise.
+ def save_turn_messages(
+ self, event: AfterInvocationEvent, **kwargs: Any
+ ) -> Optional[dict[str, Any]]:
+ """Save turn messages to MemorySessionManager with both SessionAgent and Session in blob."""
+ try:
+ logger.debug(f"Saving turn messages for agent: {event.agent.agent_id}")
+ # Filter messages based on configured message types
+ filtered_messages = []
+ role_map = {
+ "user": MessageRole.USER,
+ "assistant": MessageRole.ASSISTANT,
+ "tool": MessageRole.TOOL,
+ "other": MessageRole.OTHER,
+ }
+
+ for message in reversed(event.agent.messages):
+ role = message.get("role")
+ if role in self.config.message_types:
+ content = message.get("content", [{}])[0].get("text", "")
+ mapped_role = role_map.get(role, MessageRole.ASSISTANT)
+ if role == "user":
+ content = self.remove_user_context(content)
+ filtered_messages.append(
+ ConversationalMessage(content, mapped_role)
+ )
+ if role == "user":
+ break
+
+ logger.debug(f"Filtered {len(filtered_messages)} messages to save..")
+ if not filtered_messages:
+ return None
- Note:
- This should not be called as (as of now) only the `update_message` method calls this method and
- updating messages is not supported in AgentCore Memory.
- """
- result = self.memory_client.gmdp_client.get_event(
- memoryId=self.config.memory_id, actorId=self.config.actor_id, sessionId=session_id, eventId=message_id
- )
- return SessionMessage.from_dict(result) if result else None
+ # Prepare enhanced metadata using agent properties
+ event_metadata = {
+ "agent_id": StringValue.build(event.agent.agent_id),
+ }
- def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
- """Update message data.
+ # Merge with config metadata
+ if self.config.metadata:
+ event_metadata.update(self.config.metadata)
- Note: AgentCore Memory doesn't support updating events,
- so this is primarily for validation and logging.
+ # Prepare event parameters
+ event_params = self._prepare_event_params(
+ payload=filtered_messages, metadata=event_metadata
+ )
+ # Add branch configuration if not "main"
+ event_params = {
+ "actor_id": self.config.actor_id,
+ "session_id": self.config.session_id,
+ "messages": filtered_messages,
+ "metadata": event_metadata,
+ "event_timestamp": self._get_monotonic_timestamp(),
+ }
+ if self.config.default_branch and self.config.default_branch.name != "main":
+ event_params["branch"] = (
+ self.config.default_branch.to_agentcore_format()
+ )
- Args:
- session_id (str): The session ID containing the message.
- agent_id (str): The agent ID associated with the message.
- session_message (SessionMessage): The message to update.
- **kwargs (Any): Additional keyword arguments.
+ return self.memory_session_manager.add_turns(**event_params)
- Raises:
- SessionException: If session ID doesn't match configuration.
- """
- if session_id != self.config.session_id:
- raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")
-
- logger.debug(
- "Message update requested for message: %s (AgentCore Memory doesn't support updates)",
- {session_message.message_id},
- )
+ except Exception as e:
+ logger.error("Failed to save turn messages: %s", e)
+ raise SessionException(f"Failed to save turn messages: {e}") from e
def list_messages(
- self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
+ self,
+ session_id: str,
+ agent_id: str,
+ limit: Optional[int] = None,
+ offset: int = 0,
+ **kwargs: Any,
) -> list[SessionMessage]:
- """List messages for an agent from AgentCore Memory with pagination.
+ """List messages for an agent from AgentCore Memory with pagination."""
+ self._validate_session_id(session_id)
+ logger.debug(f"Listing messages for agent: {agent_id}, limit: {limit}")
- Args:
- session_id (str): The session ID to list messages from.
- agent_id (str): The agent ID to list messages for.
- limit (Optional[int], optional): Maximum number of messages to return. Defaults to None.
- offset (int, optional): Number of messages to skip. Defaults to 0.
- **kwargs (Any): Additional keyword arguments.
+ try:
+ max_results = (limit + offset) if limit else 100
- Returns:
- list[SessionMessage]: list of messages for the agent.
+ # Filter for non-session-state events (conversation messages)
+ message_filter = EventMetadataFilter.build_expression(
+ left_operand=LeftExpression.build(key="event_type"),
+ operator=OperatorType.NOT_EXISTS,
+ )
- Raises:
- SessionException: If session ID doesn't match configuration.
- """
- if session_id != self.config.session_id:
- raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")
+ # Always filter by agent_id to prevent cross-pollution
+ agent_filter = EventMetadataFilter.build_expression(
+ left_operand=LeftExpression.build(key="agent_id"),
+ operator=OperatorType.EQUALS_TO,
+ right_operand=RightExpression.build(agent_id),
+ )
+ filters = [message_filter, agent_filter]
- try:
- max_results = (limit + offset) if limit else 100
- events = self.memory_client.list_events(
- memory_id=self.config.memory_id,
- actor_id=self.config.actor_id,
- session_id=session_id,
+ list_params = self._prepare_list_params(
max_results=max_results,
+ eventMetadata=filters,
)
- messages = AgentCoreMemoryConverter.events_to_messages(events)
- if limit is not None:
- return messages[offset : offset + limit]
- else:
- return messages[offset:]
+
+ events = self.memory_session_manager.list_events(**list_params)
+ logger.debug(f"Found {len(events)} events")
+ return AgentCoreMemoryConverter.events_to_messages(events) if events else []
except Exception as e:
logger.error("Failed to list messages from AgentCore Memory: %s", e)
@@ -468,51 +566,76 @@ def list_messages(
# endregion SessionRepository interface implementation
# region RepositorySessionManager overrides
+ @staticmethod
+ def remove_user_context(text: str) -> str:
+ """Remove user context from text."""
+ return re.sub(
+ r".*?", "", text, flags=re.DOTALL
+ ).strip()
+
@override
- def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
- """Append a message to the agent's session using AgentCore's eventId as message_id.
+ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
+ """Agent hydration with branch support for multiple actors."""
+ try:
+ self.agent_id = agent.agent_id
+ # Use create_or_fetch_agent_branch for both main and custom branches
+ if not self.read_agent(self.config.session_id, agent.agent_id):
+ self.create_agent(
+ self.config.session_id, SessionAgent.from_agent(agent)
+ )
+ # Set agent state from cached session
+ agent.state = AgentState(
+ self._session_cache.state if self._session_cache else {}
+ )
- Args:
- message: Message to add to the agent in the session
- agent: Agent to append the message to
- **kwargs: Additional keyword arguments for future extensibility.
- """
- created_message = self.create_message(self.session_id, agent.agent_id, SessionMessage.from_message(message, 0))
- session_message = SessionMessage.from_message(message, created_message.get("eventId"))
- self._latest_agent_message[agent.agent_id] = session_message
+ # Load previous messages
+ prev_messages = self.list_messages(
+ self.config.session_id, agent.agent_id, limit=10
+ )
+ agent.messages = prev_messages if prev_messages else []
- def retrieve_customer_context(self, event: MessageAddedEvent) -> None:
- """Retrieve customer LTM context before processing support query.
+ except Exception as e:
+ logger.error("Failed to initialize agent %s: %s", agent.agent_id, e)
+ raise
- Args:
- event (MessageAddedEvent): The message added event containing the agent and message data.
- """
+ def retrieve_customer_context(self, event: MessageAddedEvent) -> None:
+ """Retrieve customer context from both short-term and long-term memory."""
messages = event.agent.messages
- if not messages or messages[-1].get("role") != "user" or "toolResult" in messages[-1].get("content")[0]:
- return None
- if not self.config.retrieval_config:
- # Only retrieve LTM
- return None
+ if not messages or messages[-1].get("role") != "user":
+ return
+
+ logger.debug(f"Retrieving customer context for agent:")
+ # Skip if message contains tool results
+ last_content = messages[-1].get("content", [])
+ if last_content and any(
+ "toolResult" in str(content) for content in last_content
+ ):
+ return
user_query = messages[-1]["content"][0]["text"]
- def retrieve_for_namespace(namespace: str, retrieval_config: AgentCoreMemoryConfig):
+ def retrieve_for_namespace(
+ namespace: str, retrieval_config: AgentCoreMemoryConfig
+ ):
"""Helper function to retrieve memories for a single namespace."""
resolved_namespace = namespace.format(
actorId=self.config.actor_id,
sessionId=self.config.session_id,
memoryStrategyId=retrieval_config.strategy_id or "",
)
-
- memories = self.memory_client.retrieve_memories(
- memory_id=self.config.memory_id,
- namespace=resolved_namespace,
+ memories = self.memory_session_manager.search_long_term_memories(
query=user_query,
+ namespace_prefix=resolved_namespace,
top_k=retrieval_config.top_k,
)
+
context_items = []
for memory in memories:
- if isinstance(memory, dict):
+ if hasattr(memory, "content") and hasattr(memory.content, "text"):
+ text = memory.content.text.strip()
+ if text:
+ all_context.append(text)
+ elif isinstance(memory, dict):
content = memory.get("content", {})
if isinstance(content, dict):
text = content.get("text", "").strip()
@@ -521,12 +644,17 @@ def retrieve_for_namespace(namespace: str, retrieval_config: AgentCoreMemoryConf
return context_items
try:
- # Retrieve customer context from all namespaces in parallel
all_context = []
+ # Check if retrieval_config exists and is not empty
+ if not self.config.retrieval_config:
+ return
+ # Retrieve from long-term memory in parallel
with ThreadPoolExecutor() as executor:
future_to_namespace = {
- executor.submit(retrieve_for_namespace, namespace, retrieval_config): namespace
+ executor.submit(
+ retrieve_for_namespace, namespace, retrieval_config
+ ): namespace
for namespace, retrieval_config in self.config.retrieval_config.items()
}
for future in as_completed(future_to_namespace):
@@ -536,40 +664,41 @@ def retrieve_for_namespace(namespace: str, retrieval_config: AgentCoreMemoryConf
except Exception as e:
# Continue processing other futures event if one fails rather than failing the entire operation
namespace = future_to_namespace[future]
- logger.error("Failed to retrieve memories for namespace %s: %s", namespace, e)
+ logger.error(
+ "Failed to retrieve memories for namespace %s: %s",
+ namespace,
+ e,
+ )
- # Inject customer context into the query
if all_context:
+ original_text = messages[-1]["content"][0]["text"]
context_text = "\n".join(all_context)
- ltm_msg: Message = {
- "role": "assistant",
- "content": [{"text": f"{context_text}"}],
- }
- event.agent.messages.append(ltm_msg)
- logger.info("Retrieved %s customer context items", len(all_context))
+ messages[-1]["content"][0][
+ "text"
+ ] = f"\n{context_text}\n\n\n{original_text}\n"
+ event.agent.messages[-1]["content"][0]["text"] = context_text
+ logger.info("Retrieved %d customer context items", len(all_context))
except Exception as e:
logger.error("Failed to retrieve customer context: %s", e)
@override
def register_hooks(self, registry: HookRegistry, **kwargs) -> None:
- """Register additional hooks.
+ """Register hooks.
Args:
registry (HookRegistry): The hook registry to register callbacks with.
**kwargs: Additional keyword arguments.
"""
- RepositorySessionManager.register_hooks(self, registry, **kwargs)
- registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event))
-
- @override
- def initialize(self, agent: "Agent", **kwargs: Any) -> None:
- if self.has_existing_agent:
- logger.warning(
- "An Agent already exists in session %s. We currently support one agent per session.", self.session_id
- )
- else:
- self.has_existing_agent = True
- RepositorySessionManager.initialize(self, agent, **kwargs)
-
- # endregion RepositorySessionManager overrides
+ # After the normal Agent initialization behavior, call the session initialize function to restore the agent
+ registry.add_callback(
+ AgentInitializedEvent, lambda event: self.initialize(event.agent)
+ )
+ # For each message appended to the Agents messages, store that message in the session
+ # After an agent was invoked, sync it with the session to capture any conversation manager state updates
+ registry.add_callback(
+ AfterInvocationEvent, lambda event: self.save_turn_messages(event)
+ )
+ registry.add_callback(
+ MessageAddedEvent, lambda event: self.retrieve_customer_context(event)
+ )