From b6d9c4416875a684896b37123752d3a65da1c704 Mon Sep 17 00:00:00 2001 From: Max Rabin Date: Mon, 24 Nov 2025 13:47:29 +0200 Subject: [PATCH] Upgrade Python syntax by using Pyupgrade, adding it to Ruff (UP). --- pyproject.toml | 3 +- src/bedrock_agentcore/identity/auth.py | 19 +- src/bedrock_agentcore/memory/client.py | 231 +++++++++--------- src/bedrock_agentcore/memory/constants.py | 14 +- src/bedrock_agentcore/memory/controlplane.py | 56 ++--- .../integrations/strands/bedrock_converter.py | 6 +- .../memory/integrations/strands/config.py | 7 +- .../integrations/strands/session_manager.py | 22 +- .../memory/models/DictWrapper.py | 4 +- .../memory/models/__init__.py | 14 +- .../memory/models/filters.py | 10 +- src/bedrock_agentcore/memory/session.py | 155 ++++++------ src/bedrock_agentcore/runtime/app.py | 22 +- src/bedrock_agentcore/runtime/context.py | 29 ++- src/bedrock_agentcore/runtime/utils.py | 2 +- src/bedrock_agentcore/services/identity.py | 29 +-- src/bedrock_agentcore/tools/browser_client.py | 60 ++--- .../tools/code_interpreter_client.py | 52 ++-- src/bedrock_agentcore/tools/config.py | 55 ++--- .../bedrock_agentcore/memory/test_session.py | 28 +-- tests/bedrock_agentcore/runtime/test_app.py | 2 +- tests/bedrock_agentcore/runtime/test_utils.py | 7 +- .../async/interactive_async_strands.py | 8 +- .../async/test_async_status_example.py | 6 +- tests_integ/memory/test_devex.py | 16 +- tests_integ/memory/test_memory_client.py | 12 +- tests_integ/runtime/base_test.py | 3 +- 27 files changed, 437 insertions(+), 435 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e25421..333e02d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ ignore_missing_imports = false [tool.ruff] line-length = 120 -include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"] +include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"] exclude = ["**/*.md"] [tool.ruff.lint] @@ -79,6 +79,7 @@ select = [ "G", # logging format "I", # isort "LOG", # logging + "UP", # pyupgrade ] [tool.ruff.lint.per-file-ignores] diff --git a/src/bedrock_agentcore/identity/auth.py b/src/bedrock_agentcore/identity/auth.py index c8b08ba..08a3988 100644 --- a/src/bedrock_agentcore/identity/auth.py +++ b/src/bedrock_agentcore/identity/auth.py @@ -4,8 +4,9 @@ import contextvars import logging import os +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, Dict, List, Literal, Optional +from typing import Any, Literal import boto3 @@ -22,14 +23,14 @@ def requires_access_token( *, provider_name: str, into: str = "access_token", - scopes: List[str], - on_auth_url: Optional[Callable[[str], Any]] = None, + scopes: list[str], + on_auth_url: Callable[[str], Any] | None = None, auth_flow: Literal["M2M", "USER_FEDERATION"], - callback_url: Optional[str] = None, + callback_url: str | None = None, force_authentication: bool = False, - token_poller: Optional[TokenPoller] = None, - custom_state: Optional[str] = None, - custom_parameters: Optional[Dict[str, str]] = None, + token_poller: TokenPoller | None = None, + custom_state: str | None = None, + custom_parameters: dict[str, str] | None = None, ) -> Callable: """Decorator that fetches an OAuth2 access token before calling the decorated function. @@ -151,7 +152,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: return decorator -def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]): +def _get_oauth2_callback_url(user_provided_oauth2_callback_url: str | None): if user_provided_oauth2_callback_url: return user_provided_oauth2_callback_url @@ -184,7 +185,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str: config = {} if config_path.exists(): try: - with open(config_path, "r", encoding="utf-8") as file: + with open(config_path, encoding="utf-8") as file: config = json.load(file) or {} except Exception: print("Could not find existing workload identity and user id") diff --git a/src/bedrock_agentcore/memory/client.py b/src/bedrock_agentcore/memory/client.py index 29280bf..52cfd64 100644 --- a/src/bedrock_agentcore/memory/client.py +++ b/src/bedrock_agentcore/memory/client.py @@ -13,8 +13,9 @@ import time import uuid import warnings +from collections.abc import Callable from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any import boto3 from botocore.exceptions import ClientError @@ -60,7 +61,7 @@ class MemoryClient: "list_memory_strategies", } - def __init__(self, region_name: Optional[str] = None): + def __init__(self, region_name: str | None = None): """Initialize the Memory client.""" self.region_name = region_name or boto3.Session().region_name or "us-west-2" @@ -119,11 +120,11 @@ def __getattr__(self, name: str): def create_memory( self, name: str, - strategies: Optional[List[Dict[str, Any]]] = None, - description: Optional[str] = None, + strategies: list[dict[str, Any]] | None = None, + description: str | None = None, event_expiry_days: int = 90, - memory_execution_role_arn: Optional[str] = None, - ) -> Dict[str, Any]: + memory_execution_role_arn: str | None = None, + ) -> dict[str, Any]: """Create a memory with simplified configuration.""" if strategies is None: strategies = [] @@ -160,11 +161,11 @@ def create_memory( def create_or_get_memory( self, name: str, - strategies: Optional[List[Dict[str, Any]]] = None, - description: Optional[str] = None, + strategies: list[dict[str, Any]] | None = None, + description: str | None = None, event_expiry_days: int = 90, - memory_execution_role_arn: Optional[str] = None, - ) -> Dict[str, Any]: + memory_execution_role_arn: str | None = None, + ) -> dict[str, Any]: """Create a memory resource or fetch the existing memory details if it already exists. Returns: @@ -194,13 +195,13 @@ def create_or_get_memory( def create_memory_and_wait( self, name: str, - strategies: List[Dict[str, Any]], - description: Optional[str] = None, + strategies: list[dict[str, Any]], + description: str | None = None, event_expiry_days: int = 90, - memory_execution_role_arn: Optional[str] = None, + memory_execution_role_arn: str | None = None, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Create a memory and wait for it to become ACTIVE. This method creates a memory and polls until it reaches ACTIVE status, @@ -253,7 +254,7 @@ def create_memory_and_wait( # Get failure reason if available response = self.gmcp_client.get_memory(memoryId=memory_id) # Input uses old field name failure_reason = response["memory"].get("failureReason", "Unknown") - raise RuntimeError("Memory creation failed: %s" % failure_reason) + raise RuntimeError(f"Memory creation failed: {failure_reason}") else: logger.debug("Memory status: %s (%d seconds elapsed)", status, elapsed) @@ -263,11 +264,11 @@ def create_memory_and_wait( time.sleep(poll_interval) - raise TimeoutError("Memory %s did not become ACTIVE within %d seconds" % (memory_id, max_wait)) + raise TimeoutError(f"Memory {memory_id} did not become ACTIVE within {max_wait} seconds") def retrieve_memories( - self, memory_id: str, namespace: str, query: str, actor_id: Optional[str] = None, top_k: int = 3 - ) -> List[Dict[str, Any]]: + self, memory_id: str, namespace: str, query: str, actor_id: str | None = None, top_k: int = 3 + ) -> list[dict[str, Any]]: """Retrieve relevant memories from a namespace. Note: Wildcards (*) are NOT supported in namespaces. You must provide the @@ -332,10 +333,10 @@ def create_event( memory_id: str, actor_id: str, session_id: str, - messages: List[Tuple[str, str]], - event_timestamp: Optional[datetime] = None, - branch: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: + messages: list[tuple[str, str]], + event_timestamp: datetime | None = None, + branch: dict[str, str] | None = None, + ) -> dict[str, Any]: """Save an event of an agent interaction or conversation with a user. This is the basis of short-term memory. If you configured your Memory resource @@ -410,7 +411,7 @@ def create_event( role_enum = MessageRole(role.upper()) except ValueError as err: raise ValueError( - "Invalid role '%s'. Must be one of: %s" % (role, ", ".join([r.value for r in MessageRole])) + "Invalid role '{}'. Must be one of: {}".format(role, ", ".join([r.value for r in MessageRole])) ) from err payload.append({"conversational": {"content": {"text": text}, "role": role_enum.value}}) @@ -447,9 +448,9 @@ def create_blob_event( actor_id: str, session_id: str, blob_data: Any, - event_timestamp: Optional[datetime] = None, - branch: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: + event_timestamp: datetime | None = None, + branch: dict[str, str] | None = None, + ) -> dict[str, Any]: """Save a blob event to AgentCore Memory. Args: @@ -505,10 +506,10 @@ def save_conversation( memory_id: str, actor_id: str, session_id: str, - messages: List[Tuple[str, str]], - event_timestamp: Optional[datetime] = None, - branch: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: + messages: list[tuple[str, str]], + event_timestamp: datetime | None = None, + branch: dict[str, str] | None = None, + ) -> dict[str, Any]: """DEPRECATED: Use create_event() instead. Args: @@ -564,7 +565,7 @@ def save_conversation( role_enum = MessageRole(role.upper()) except ValueError as err: raise ValueError( - "Invalid role '%s'. Must be one of: %s" % (role, ", ".join([r.value for r in MessageRole])) + "Invalid role '{}'. Must be one of: {}".format(role, ", ".join([r.value for r in MessageRole])) ) from err payload.append({"conversational": {"content": {"text": text}, "role": role_enum.value}}) @@ -603,8 +604,8 @@ def save_turn( session_id: str, user_input: str, agent_response: str, - event_timestamp: Optional[datetime] = None, - ) -> Dict[str, Any]: + event_timestamp: datetime | None = None, + ) -> dict[str, Any]: """DEPRECATED: Use save_conversation() for more flexibility. This method will be removed in v1.0.0. @@ -633,11 +634,11 @@ def process_turn( session_id: str, user_input: str, agent_response: str, - event_timestamp: Optional[datetime] = None, - retrieval_namespace: Optional[str] = None, - retrieval_query: Optional[str] = None, + event_timestamp: datetime | None = None, + retrieval_namespace: str | None = None, + retrieval_query: str | None = None, top_k: int = 3, - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + ) -> tuple[list[dict[str, Any]], dict[str, Any]]: """DEPRECATED: Use retrieve_memories() and save_conversation() separately. This method will be removed in v1.0.0. @@ -674,12 +675,12 @@ def process_turn_with_llm( actor_id: str, session_id: str, user_input: str, - llm_callback: Callable[[str, List[Dict[str, Any]]], str], - retrieval_namespace: Optional[str] = None, - retrieval_query: Optional[str] = None, + llm_callback: Callable[[str, list[dict[str, Any]]], str], + retrieval_namespace: str | None = None, + retrieval_query: str | None = None, top_k: int = 3, - event_timestamp: Optional[datetime] = None, - ) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]: + event_timestamp: datetime | None = None, + ) -> tuple[list[dict[str, Any]], str, dict[str, Any]]: r"""Complete conversation turn with LLM callback integration. This method combines memory retrieval, LLM invocation, and response storage @@ -760,11 +761,11 @@ def list_events( memory_id: str, actor_id: str, session_id: str, - branch_name: Optional[str] = None, + branch_name: str | None = None, include_parent_branches: bool = False, max_results: int = 100, include_payload: bool = True, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """List all events in a session with pagination support. This method provides direct access to the raw events API, allowing developers @@ -830,7 +831,7 @@ def list_events( logger.error("Failed to list events: %s", e) raise - def list_branches(self, memory_id: str, actor_id: str, session_id: str) -> List[Dict[str, Any]]: + def list_branches(self, memory_id: str, actor_id: str, session_id: str) -> list[dict[str, Any]]: """List all branches in a session. This method handles pagination automatically and provides a structured view @@ -908,10 +909,10 @@ def list_branch_events( memory_id: str, actor_id: str, session_id: str, - branch_name: Optional[str] = None, + branch_name: str | None = None, include_parent_branches: bool = False, max_results: int = 100, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """List events in a specific branch. This method provides complex filtering and pagination that would require @@ -966,7 +967,7 @@ def list_branch_events( logger.error("Failed to list branch events: %s", e) raise - def get_conversation_tree(self, memory_id: str, actor_id: str, session_id: str) -> Dict[str, Any]: + def get_conversation_tree(self, memory_id: str, actor_id: str, session_id: str) -> dict[str, Any]: """Get a tree structure of the conversation with all branches. This method transforms a flat list of events into a hierarchical tree structure, @@ -1035,7 +1036,7 @@ def get_conversation_tree(self, memory_id: str, actor_id: str, session_id: str) def merge_branch_context( self, memory_id: str, actor_id: str, session_id: str, branch_name: str, include_parent: bool = True - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Get all messages from a branch for context building. Args: @@ -1085,10 +1086,10 @@ def get_last_k_turns( actor_id: str, session_id: str, k: int = 5, - branch_name: Optional[str] = None, + branch_name: str | None = None, include_branches: bool = False, max_results: int = 100, - ) -> List[List[Dict[str, Any]]]: + ) -> list[list[dict[str, Any]]]: """Get the last K conversation turns. A "turn" typically consists of a user message followed by assistant response(s). @@ -1147,9 +1148,9 @@ def fork_conversation( session_id: str, root_event_id: str, branch_name: str, - new_messages: List[Tuple[str, str]], - event_timestamp: Optional[datetime] = None, - ) -> Dict[str, Any]: + new_messages: list[tuple[str, str]], + event_timestamp: datetime | None = None, + ) -> dict[str, Any]: """Fork a conversation from a specific event to create a new branch.""" try: branch = {"rootEventId": root_event_id, "name": branch_name} @@ -1170,7 +1171,7 @@ def fork_conversation( logger.error("Failed to fork conversation: %s", e) raise - def get_memory_strategies(self, memory_id: str) -> List[Dict[str, Any]]: + def get_memory_strategies(self, memory_id: str) -> list[dict[str, Any]]: """Get all strategies for a memory.""" try: response = self.gmcp_client.get_memory(memoryId=memory_id) # Input uses old field name @@ -1212,7 +1213,7 @@ def get_memory_status(self, memory_id: str) -> str: logger.error("Failed to get memory status: %s", e) raise - def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]: + def list_memories(self, max_results: int = 100) -> list[dict[str, Any]]: """List all memories for the account.""" try: # Ensure max_results doesn't exceed API limit per request @@ -1247,7 +1248,7 @@ def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]: logger.error("Failed to list memories: %s", e) raise - def delete_memory(self, memory_id: str) -> Dict[str, Any]: + def delete_memory(self, memory_id: str) -> dict[str, Any]: """Delete a memory resource.""" try: response = self.gmcp_client.delete_memory( @@ -1259,7 +1260,7 @@ def delete_memory(self, memory_id: str) -> Dict[str, Any]: logger.error("Failed to delete memory: %s", e) raise - def delete_memory_and_wait(self, memory_id: str, max_wait: int = 300, poll_interval: int = 10) -> Dict[str, Any]: + def delete_memory_and_wait(self, memory_id: str, max_wait: int = 300, poll_interval: int = 10) -> dict[str, Any]: """Delete a memory and wait for deletion to complete. This method deletes a memory and polls until it's fully deleted, @@ -1299,20 +1300,20 @@ def delete_memory_and_wait(self, memory_id: str, max_wait: int = 300, poll_inter time.sleep(poll_interval) - raise TimeoutError("Memory %s was not deleted within %d seconds" % (memory_id, max_wait)) + raise TimeoutError(f"Memory {memory_id} was not deleted within {max_wait} seconds") def add_semantic_strategy( self, memory_id: str, name: str, - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, - ) -> Dict[str, Any]: + description: str | None = None, + namespaces: list[str] | None = None, + ) -> dict[str, Any]: """Add a semantic memory strategy. Note: Configuration is no longer provided for built-in strategies as per API changes. """ - strategy: Dict = { + strategy: dict = { StrategyType.SEMANTIC.value: { "name": name, } @@ -1329,11 +1330,11 @@ def add_semantic_strategy_and_wait( self, memory_id: str, name: str, - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, + description: str | None = None, + namespaces: list[str] | None = None, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Add a semantic strategy and wait for memory to return to ACTIVE state. This addresses the issue where adding a strategy puts the memory into @@ -1349,14 +1350,14 @@ def add_summary_strategy( self, memory_id: str, name: str, - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, - ) -> Dict[str, Any]: + description: str | None = None, + namespaces: list[str] | None = None, + ) -> dict[str, Any]: """Add a summary memory strategy. Note: Configuration is no longer provided for built-in strategies as per API changes. """ - strategy: Dict = { + strategy: dict = { StrategyType.SUMMARY.value: { "name": name, } @@ -1373,11 +1374,11 @@ def add_summary_strategy_and_wait( self, memory_id: str, name: str, - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, + description: str | None = None, + namespaces: list[str] | None = None, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Add a summary strategy and wait for memory to return to ACTIVE state.""" self.add_summary_strategy(memory_id, name, description, namespaces) return self._wait_for_memory_active(memory_id, max_wait, poll_interval) @@ -1386,14 +1387,14 @@ def add_user_preference_strategy( self, memory_id: str, name: str, - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, - ) -> Dict[str, Any]: + description: str | None = None, + namespaces: list[str] | None = None, + ) -> dict[str, Any]: """Add a user preference memory strategy. Note: Configuration is no longer provided for built-in strategies as per API changes. """ - strategy: Dict = { + strategy: dict = { StrategyType.USER_PREFERENCE.value: { "name": name, } @@ -1410,11 +1411,11 @@ def add_user_preference_strategy_and_wait( self, memory_id: str, name: str, - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, + description: str | None = None, + namespaces: list[str] | None = None, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Add a user preference strategy and wait for memory to return to ACTIVE state.""" self.add_user_preference_strategy(memory_id, name, description, namespaces) return self._wait_for_memory_active(memory_id, max_wait, poll_interval) @@ -1423,11 +1424,11 @@ def add_custom_semantic_strategy( self, memory_id: str, name: str, - extraction_config: Dict[str, Any], - consolidation_config: Dict[str, Any], - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, - ) -> Dict[str, Any]: + extraction_config: dict[str, Any], + consolidation_config: dict[str, Any], + description: str | None = None, + namespaces: list[str] | None = None, + ) -> dict[str, Any]: """Add a custom semantic strategy with prompts. Args: @@ -1469,13 +1470,13 @@ def add_custom_semantic_strategy_and_wait( self, memory_id: str, name: str, - extraction_config: Dict[str, Any], - consolidation_config: Dict[str, Any], - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, + extraction_config: dict[str, Any], + consolidation_config: dict[str, Any], + description: str | None = None, + namespaces: list[str] | None = None, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Add a custom semantic strategy and wait for memory to return to ACTIVE state.""" self.add_custom_semantic_strategy( memory_id, name, extraction_config, consolidation_config, description, namespaces @@ -1486,12 +1487,12 @@ def modify_strategy( self, memory_id: str, strategy_id: str, - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, - configuration: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + description: str | None = None, + namespaces: list[str] | None = None, + configuration: dict[str, Any] | None = None, + ) -> dict[str, Any]: """Modify a strategy with full control over configuration.""" - modify_config: Dict = {"memoryStrategyId": strategy_id} # Using old field name for input + modify_config: dict = {"memoryStrategyId": strategy_id} # Using old field name for input if description is not None: modify_config["description"] = description @@ -1502,17 +1503,17 @@ def modify_strategy( return self.update_memory_strategies(memory_id=memory_id, modify_strategies=[modify_config]) - def delete_strategy(self, memory_id: str, strategy_id: str) -> Dict[str, Any]: + def delete_strategy(self, memory_id: str, strategy_id: str) -> dict[str, Any]: """Delete a strategy from a memory.""" return self.update_memory_strategies(memory_id=memory_id, delete_strategy_ids=[strategy_id]) def update_memory_strategies( self, memory_id: str, - add_strategies: Optional[List[Dict[str, Any]]] = None, - modify_strategies: Optional[List[Dict[str, Any]]] = None, - delete_strategy_ids: Optional[List[str]] = None, - ) -> Dict[str, Any]: + add_strategies: list[dict[str, Any]] | None = None, + modify_strategies: list[dict[str, Any]] | None = None, + delete_strategy_ids: list[str] | None = None, + ) -> dict[str, Any]: """Update memory strategies - add, modify, or delete.""" try: memory_strategies = {} @@ -1534,7 +1535,7 @@ def update_memory_strategies( strategy_info = strategy_map.get(strategy_id) if not strategy_info: - raise ValueError("Strategy %s not found in memory %s" % (strategy_id, memory_id)) + raise ValueError(f"Strategy {strategy_id} not found in memory {memory_id}") strategy_type = strategy_info["memoryStrategyType"] # Using normalized field override_type = strategy_info.get("configuration", {}).get("type") @@ -1575,12 +1576,12 @@ def update_memory_strategies( def update_memory_strategies_and_wait( self, memory_id: str, - add_strategies: Optional[List[Dict[str, Any]]] = None, - modify_strategies: Optional[List[Dict[str, Any]]] = None, - delete_strategy_ids: Optional[List[str]] = None, + add_strategies: list[dict[str, Any]] | None = None, + modify_strategies: list[dict[str, Any]] | None = None, + delete_strategy_ids: list[str] | None = None, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Update memory strategies and wait for memory to return to ACTIVE state. This method handles the temporary CREATING state that occurs when @@ -1663,7 +1664,7 @@ def wait_for_memories( logger.info("Note: Encountered %d service errors during polling", service_errors) return False - def add_strategy(self, memory_id: str, strategy: Dict[str, Any]) -> Dict[str, Any]: + def add_strategy(self, memory_id: str, strategy: dict[str, Any]) -> dict[str, Any]: """Add a strategy to a memory (without waiting). WARNING: After adding a strategy, the memory enters CREATING state temporarily. @@ -1686,7 +1687,7 @@ def add_strategy(self, memory_id: str, strategy: Dict[str, Any]) -> Dict[str, An # Private methods - def _normalize_memory_response(self, memory: Dict[str, Any]) -> Dict[str, Any]: + def _normalize_memory_response(self, memory: dict[str, Any]) -> dict[str, Any]: """Normalize memory response to include both old and new field names. The API returns new field names but SDK users might expect old ones. @@ -1728,11 +1729,11 @@ def _normalize_memory_response(self, memory: Dict[str, Any]) -> Dict[str, Any]: return memory - def _add_strategy(self, memory_id: str, strategy: Dict[str, Any]) -> Dict[str, Any]: + def _add_strategy(self, memory_id: str, strategy: dict[str, Any]) -> dict[str, Any]: """Internal method to add a single strategy.""" return self.update_memory_strategies(memory_id=memory_id, add_strategies=[strategy]) - def _wait_for_memory_active(self, memory_id: str, max_wait: int, poll_interval: int) -> Dict[str, Any]: + def _wait_for_memory_active(self, memory_id: str, max_wait: int, poll_interval: int) -> dict[str, Any]: """Wait for memory to return to ACTIVE state after strategy update.""" logger.info("Waiting for memory %s to return to ACTIVE state...", memory_id) @@ -1751,7 +1752,7 @@ def _wait_for_memory_active(self, memory_id: str, max_wait: int, poll_interval: elif status == MemoryStatus.FAILED.value: response = self.gmcp_client.get_memory(memoryId=memory_id) # Input uses old field name failure_reason = response["memory"].get("failureReason", "Unknown") - raise RuntimeError("Memory update failed: %s" % failure_reason) + raise RuntimeError(f"Memory update failed: {failure_reason}") else: logger.debug("Memory status: %s (%d seconds elapsed)", status, elapsed) @@ -1761,9 +1762,9 @@ def _wait_for_memory_active(self, memory_id: str, max_wait: int, poll_interval: time.sleep(poll_interval) - raise TimeoutError("Memory %s did not return to ACTIVE state within %d seconds" % (memory_id, max_wait)) + raise TimeoutError(f"Memory {memory_id} did not return to ACTIVE state within {max_wait} seconds") - def _add_default_namespaces(self, strategies: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _add_default_namespaces(self, strategies: list[dict[str, Any]]) -> list[dict[str, Any]]: """Add default namespaces to strategies that don't have them.""" processed = [] @@ -1794,7 +1795,7 @@ def _validate_namespace(self, namespace: str) -> bool: return True - def _validate_strategy_config(self, strategy: Dict[str, Any], strategy_type: str) -> None: + def _validate_strategy_config(self, strategy: dict[str, Any], strategy_type: str) -> None: """Validate strategy configuration parameters.""" strategy_config = strategy[strategy_type] @@ -1803,8 +1804,8 @@ def _validate_strategy_config(self, strategy: Dict[str, Any], strategy_type: str self._validate_namespace(namespace) def _wrap_configuration( - self, config: Dict[str, Any], strategy_type: str, override_type: Optional[str] = None - ) -> Dict[str, Any]: + self, config: dict[str, Any], strategy_type: str, override_type: str | None = None + ) -> dict[str, Any]: """Wrap configuration based on strategy type.""" wrapped_config = {} diff --git a/src/bedrock_agentcore/memory/constants.py b/src/bedrock_agentcore/memory/constants.py index 52ca16b..00e99f0 100644 --- a/src/bedrock_agentcore/memory/constants.py +++ b/src/bedrock_agentcore/memory/constants.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import BaseModel, Field @@ -69,7 +69,7 @@ class MessageRole(Enum): # Default namespaces for each strategy type -DEFAULT_NAMESPACES: Dict[StrategyType, List[str]] = { +DEFAULT_NAMESPACES: dict[StrategyType, list[str]] = { StrategyType.SEMANTIC: ["/actor/{actorId}/strategy/{strategyId}/{sessionId}"], StrategyType.SUMMARY: ["/actor/{actorId}/strategy/{strategyId}/{sessionId}"], StrategyType.USER_PREFERENCE: ["/actor/{actorId}/strategy/{strategyId}"], @@ -78,17 +78,17 @@ class MessageRole(Enum): # Configuration wrapper keys for update operations # These are still needed for wrapping configurations during updates -EXTRACTION_WRAPPER_KEYS: Dict[MemoryStrategyTypeEnum, str] = { +EXTRACTION_WRAPPER_KEYS: dict[MemoryStrategyTypeEnum, str] = { MemoryStrategyTypeEnum.SEMANTIC: "semanticExtractionConfiguration", MemoryStrategyTypeEnum.USER_PREFERENCE: "userPreferenceExtractionConfiguration", } -CUSTOM_EXTRACTION_WRAPPER_KEYS: Dict[OverrideType, str] = { +CUSTOM_EXTRACTION_WRAPPER_KEYS: dict[OverrideType, str] = { OverrideType.SEMANTIC_OVERRIDE: "semanticExtractionOverride", OverrideType.USER_PREFERENCE_OVERRIDE: "userPreferenceExtractionOverride", } -CUSTOM_CONSOLIDATION_WRAPPER_KEYS: Dict[OverrideType, str] = { +CUSTOM_CONSOLIDATION_WRAPPER_KEYS: dict[OverrideType, str] = { OverrideType.SEMANTIC_OVERRIDE: "semanticConsolidationOverride", OverrideType.SUMMARY_OVERRIDE: "summaryConsolidationOverride", OverrideType.USER_PREFERENCE_OVERRIDE: "userPreferenceConsolidationOverride", @@ -149,5 +149,5 @@ class RetrievalConfig(BaseModel): top_k: int = Field(default=10, gt=1, le=100) relevance_score: float = Field(default=0.0, ge=0.0, le=1.0) - strategy_id: Optional[str] = None - retrieval_query: Optional[str] = None + strategy_id: str | None = None + retrieval_query: str | None = None diff --git a/src/bedrock_agentcore/memory/controlplane.py b/src/bedrock_agentcore/memory/controlplane.py index 25e75d0..a37a072 100644 --- a/src/bedrock_agentcore/memory/controlplane.py +++ b/src/bedrock_agentcore/memory/controlplane.py @@ -8,7 +8,7 @@ import os import time import uuid -from typing import Any, Dict, List, Optional +from typing import Any import boto3 from botocore.exceptions import ClientError @@ -48,13 +48,13 @@ def create_memory( self, name: str, event_expiry_days: int = 90, - description: Optional[str] = None, - memory_execution_role_arn: Optional[str] = None, - strategies: Optional[List[Dict[str, Any]]] = None, + description: str | None = None, + memory_execution_role_arn: str | None = None, + strategies: list[dict[str, Any]] | None = None, wait_for_active: bool = False, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Create a memory resource with optional strategies. Args: @@ -101,7 +101,7 @@ def create_memory( logger.error("Failed to create memory: %s", e) raise - def get_memory(self, memory_id: str, include_strategies: bool = True) -> Dict[str, Any]: + def get_memory(self, memory_id: str, include_strategies: bool = True) -> dict[str, Any]: """Get a memory resource by ID. Args: @@ -129,7 +129,7 @@ def get_memory(self, memory_id: str, include_strategies: bool = True) -> Dict[st logger.error("Failed to get memory: %s", e) raise - def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]: + def list_memories(self, max_results: int = 100) -> list[dict[str, Any]]: """List all memories for the account with pagination support. Args: @@ -168,16 +168,16 @@ def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]: def update_memory( self, memory_id: str, - description: Optional[str] = None, - event_expiry_days: Optional[int] = None, - memory_execution_role_arn: Optional[str] = None, - add_strategies: Optional[List[Dict[str, Any]]] = None, - modify_strategies: Optional[List[Dict[str, Any]]] = None, - delete_strategy_ids: Optional[List[str]] = None, + description: str | None = None, + event_expiry_days: int | None = None, + memory_execution_role_arn: str | None = None, + add_strategies: list[dict[str, Any]] | None = None, + modify_strategies: list[dict[str, Any]] | None = None, + delete_strategy_ids: list[str] | None = None, wait_for_active: bool = False, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Update a memory resource properties and/or strategies. Args: @@ -195,7 +195,7 @@ def update_memory( Returns: Updated memory object """ - params: Dict = { + params: dict = { "memoryId": memory_id, "clientToken": str(uuid.uuid4()), } @@ -248,7 +248,7 @@ def delete_memory( wait_for_strategies: bool = False, # Changed default to False max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Delete a memory resource. Args: @@ -320,11 +320,11 @@ def delete_memory( def add_strategy( self, memory_id: str, - strategy: Dict[str, Any], + strategy: dict[str, Any], wait_for_active: bool = False, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Add a strategy to a memory resource. Args: @@ -373,7 +373,7 @@ def add_strategy( return memory - def get_strategy(self, memory_id: str, strategy_id: str) -> Dict[str, Any]: + def get_strategy(self, memory_id: str, strategy_id: str) -> dict[str, Any]: """Get a specific strategy from a memory resource. Args: @@ -401,13 +401,13 @@ def update_strategy( self, memory_id: str, strategy_id: str, - description: Optional[str] = None, - namespaces: Optional[List[str]] = None, - configuration: Optional[Dict[str, Any]] = None, + description: str | None = None, + namespaces: list[str] | None = None, + configuration: dict[str, Any] | None = None, wait_for_active: bool = False, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Update a strategy in a memory resource. Args: @@ -424,7 +424,7 @@ def update_strategy( Updated memory object """ # Note: API expects memoryStrategyId for input but returns strategyId in response - modify_config: Dict = {"memoryStrategyId": strategy_id} + modify_config: dict = {"memoryStrategyId": strategy_id} if description is not None: modify_config["description"] = description @@ -455,7 +455,7 @@ def remove_strategy( wait_for_active: bool = False, max_wait: int = 300, poll_interval: int = 10, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Remove a strategy from a memory resource. Args: @@ -480,7 +480,7 @@ def remove_strategy( # ==================== HELPER METHODS ==================== - def _wait_for_memory_active(self, memory_id: str, max_wait: int, poll_interval: int) -> Dict[str, Any]: + def _wait_for_memory_active(self, memory_id: str, max_wait: int, poll_interval: int) -> dict[str, Any]: """Wait for memory to return to ACTIVE state.""" logger.info("Waiting for memory %s to become ACTIVE...", memory_id) return self._wait_for_status( @@ -489,7 +489,7 @@ def _wait_for_memory_active(self, memory_id: str, max_wait: int, poll_interval: def _wait_for_strategy_active( self, memory_id: str, strategy_id: str, max_wait: int, poll_interval: int - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Wait for specific memory strategy to become ACTIVE.""" logger.info("Waiting for strategy %s to become ACTIVE (max wait: %d seconds)...", strategy_id, max_wait) @@ -536,7 +536,7 @@ def _wait_for_strategy_active( def _wait_for_status( self, memory_id: str, target_status: str, max_wait: int, poll_interval: int, check_strategies: bool = True - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Generic method to wait for a memory to reach a specific status. Args: diff --git a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py index 2098e84..f582b8b 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py +++ b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Tuple +from typing import Any from strands.types.session import SessionMessage @@ -15,7 +15,7 @@ class AgentCoreMemoryConverter: """Handles conversion between Strands and Bedrock AgentCore Memory formats.""" @staticmethod - def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]: + def message_to_payload(session_message: SessionMessage) -> list[tuple[str, str]]: """Convert a SessionMessage to Bedrock AgentCore Memory message format. Args: @@ -65,7 +65,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: elif "blob" in payload_item: try: blob_data = json.loads(payload_item["blob"]) - if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2: + if isinstance(blob_data, tuple | list) and len(blob_data) == 2: try: messages.append(SessionMessage.from_dict(json.loads(blob_data[0]))) except (json.JSONDecodeError, ValueError): diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index d2d5cef..5730533 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -1,6 +1,5 @@ """Configuration for AgentCore Memory Session Manager.""" -from typing import Dict, Optional from pydantic import BaseModel, Field @@ -17,8 +16,8 @@ class RetrievalConfig(BaseModel): top_k: int = Field(default=10, gt=0, le=1000) relevance_score: float = Field(default=0.2, ge=0.0, le=1.0) - strategy_id: Optional[str] = None - initialization_query: Optional[str] = None + strategy_id: str | None = None + initialization_query: str | None = None class AgentCoreMemoryConfig(BaseModel): @@ -34,4 +33,4 @@ class AgentCoreMemoryConfig(BaseModel): memory_id: str = Field(min_length=1) session_id: str = Field(min_length=1) actor_id: str = Field(min_length=1) - retrieval_config: Optional[Dict[str, RetrievalConfig]] = None + retrieval_config: dict[str, RetrievalConfig] | None = None diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index d77db53..da1bb82 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -5,7 +5,7 @@ import threading from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import boto3 from botocore.config import Config as BotocoreConfig @@ -50,10 +50,10 @@ class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository) # Class-level timestamp tracking for monotonic ordering _timestamp_lock = threading.Lock() - _last_timestamp: Optional[datetime] = None + _last_timestamp: datetime | None = None @classmethod - def _get_monotonic_timestamp(cls, desired_timestamp: Optional[datetime] = None) -> datetime: + def _get_monotonic_timestamp(cls, desired_timestamp: datetime | None = None) -> datetime: """Get a monotonically increasing timestamp. Args: @@ -83,9 +83,9 @@ def _get_monotonic_timestamp(cls, desired_timestamp: Optional[datetime] = None) def __init__( self, agentcore_memory_config: AgentCoreMemoryConfig, - region_name: Optional[str] = None, - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, + region_name: str | None = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, **kwargs: Any, ): """Initialize AgentCoreMemorySessionManager with Bedrock AgentCore Memory. @@ -188,7 +188,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: logger.info("Created session: %s with event: %s", session.session_id, event.get("event", {}).get("eventId")) return session - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read session data. AgentCore Memory does not have a `get_session` method. @@ -263,7 +263,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A event.get("event", {}).get("eventId"), ) - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read agent data from AgentCore Memory events. We reconstruct the agent state from the conversation history. @@ -317,7 +317,7 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A def create_message( self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any - ) -> Optional[dict[str, Any]]: + ) -> dict[str, Any] | None: """Create a new message in AgentCore Memory. Args: @@ -382,7 +382,7 @@ def create_message( logger.error("Failed to create message in AgentCore Memory: %s", e) raise SessionException(f"Failed to create message: {e}") from e - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read a specific message by ID from AgentCore Memory. Args: @@ -427,7 +427,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio ) def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List messages for an agent from AgentCore Memory with pagination. diff --git a/src/bedrock_agentcore/memory/models/DictWrapper.py b/src/bedrock_agentcore/memory/models/DictWrapper.py index 75dadba..1013425 100644 --- a/src/bedrock_agentcore/memory/models/DictWrapper.py +++ b/src/bedrock_agentcore/memory/models/DictWrapper.py @@ -1,12 +1,12 @@ """Dictionary wrapper module for bedrock-agentcore memory models.""" -from typing import Any, Dict +from typing import Any class DictWrapper: """A wrapper class that provides dictionary-like access to data.""" - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): """Initialize the DictWrapper with data. Args: diff --git a/src/bedrock_agentcore/memory/models/__init__.py b/src/bedrock_agentcore/memory/models/__init__.py index 0213c5f..885e7ab 100644 --- a/src/bedrock_agentcore/memory/models/__init__.py +++ b/src/bedrock_agentcore/memory/models/__init__.py @@ -1,6 +1,6 @@ """Module containing all the model classes.""" -from typing import Any, Dict +from typing import Any from .DictWrapper import DictWrapper from .filters import ( @@ -17,7 +17,7 @@ class ActorSummary(DictWrapper): """A class representing an actor summary.""" - def __init__(self, actor_summary: Dict[str, Any]): + def __init__(self, actor_summary: dict[str, Any]): """Initialize an ActorSummary instance. Args: @@ -29,7 +29,7 @@ def __init__(self, actor_summary: Dict[str, Any]): class Branch(DictWrapper): """A class representing a branch.""" - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): """Initialize a Branch instance. Args: @@ -41,7 +41,7 @@ def __init__(self, data: Dict[str, Any]): class Event(DictWrapper): """A class representing an event.""" - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): """Initialize an Event instance. Args: @@ -53,7 +53,7 @@ def __init__(self, data: Dict[str, Any]): class EventMessage(DictWrapper): """A class representing an event message.""" - def __init__(self, event_message: Dict[str, Any]): + def __init__(self, event_message: dict[str, Any]): """Initialize an EventMessage instance. Args: @@ -65,7 +65,7 @@ def __init__(self, event_message: Dict[str, Any]): class MemoryRecord(DictWrapper): """A class representing a memory record.""" - def __init__(self, memory_record: Dict[str, Any]): + def __init__(self, memory_record: dict[str, Any]): """Initialize a MemoryRecord instance. Args: @@ -77,7 +77,7 @@ def __init__(self, memory_record: Dict[str, Any]): class SessionSummary(DictWrapper): """A class representing a session summary.""" - def __init__(self, session_summary: Dict[str, Any]): + def __init__(self, session_summary: dict[str, Any]): """Initialize a SessionSummary instance. Args: diff --git a/src/bedrock_agentcore/memory/models/filters.py b/src/bedrock_agentcore/memory/models/filters.py index 9ab25f7..b6d2389 100644 --- a/src/bedrock_agentcore/memory/models/filters.py +++ b/src/bedrock_agentcore/memory/models/filters.py @@ -1,7 +1,7 @@ """Event metadata filter models for querying events based on metadata.""" from enum import Enum -from typing import Optional, TypedDict, Union +from typing import TypedDict class StringValue(TypedDict): @@ -15,7 +15,7 @@ def build(value: str) -> "StringValue": return {"stringValue": value} -MetadataValue = Union[StringValue] +MetadataValue = StringValue """ Union type representing metadata values. @@ -23,7 +23,7 @@ def build(value: str) -> "StringValue": - StringValue: {"stringValue": str} - String metadata value """ -MetadataKey = Union[str] +MetadataKey = str """ Union type representing metadata key. """ @@ -80,12 +80,12 @@ class EventMetadataFilter(TypedDict): left: LeftExpression operator: OperatorType - right: Optional[RightExpression] + right: RightExpression | None def build_expression( left_operand: LeftExpression, operator: OperatorType, - right_operand: Optional[RightExpression] = None, + right_operand: RightExpression | None = None, ) -> "EventMetadataFilter": """Build the required event metadata filter expression. diff --git a/src/bedrock_agentcore/memory/session.py b/src/bedrock_agentcore/memory/session.py index 14d688d..5139deb 100644 --- a/src/bedrock_agentcore/memory/session.py +++ b/src/bedrock_agentcore/memory/session.py @@ -3,8 +3,9 @@ import logging import os import uuid +from collections.abc import Awaitable, Callable from datetime import datetime, timezone -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union +from typing import Any import boto3 from botocore.config import Config as BotocoreConfig @@ -98,9 +99,9 @@ def my_llm(user_input: str, memories: List[Dict]) -> str: def __init__( self, memory_id: str, - region_name: Optional[str] = None, - boto3_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, + region_name: str | None = None, + boto3_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, ): """Initialize a MemorySessionManager instance. @@ -147,7 +148,7 @@ def __init__( "batch_update_memory_records", } - def _validate_and_resolve_region(self, region_name: Optional[str], session: Optional[boto3.Session]) -> str: + def _validate_and_resolve_region(self, region_name: str | None, session: boto3.Session | None) -> str: """Validate region consistency and resolve the final region to use. Args: @@ -175,7 +176,7 @@ def _validate_and_resolve_region(self, region_name: Optional[str], session: Opti region_name or session_region or os.environ.get("AWS_REGION") or boto3.Session().region_name or "us-west-2" ) - def _build_client_config(self, boto_client_config: Optional[BotocoreConfig]) -> BotocoreConfig: + def _build_client_config(self, boto_client_config: BotocoreConfig | None) -> BotocoreConfig: """Build the final boto3 client configuration with SDK user agent. Args: @@ -255,11 +256,11 @@ def process_turn_with_llm( actor_id: str, session_id: str, user_input: str, - llm_callback: Callable[[str, List[Dict[str, Any]]], str], - retrieval_config: Optional[Dict[str, RetrievalConfig]], - metadata: Optional[Dict[str, MetadataValue]] = None, - event_timestamp: Optional[datetime] = None, - ) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]: + llm_callback: Callable[[str, list[dict[str, Any]]], str], + retrieval_config: dict[str, RetrievalConfig] | None, + metadata: dict[str, MetadataValue] | None = None, + event_timestamp: datetime | None = None, + ) -> tuple[list[dict[str, Any]], str, dict[str, Any]]: r"""Complete conversation turn with LLM callback integration. This method combines memory retrieval, LLM invocation, and response storage @@ -334,11 +335,11 @@ async def process_turn_with_llm_async( actor_id: str, session_id: str, user_input: str, - llm_callback: Callable[[str, List[Dict[str, Any]]], Awaitable[str]], - retrieval_config: Optional[Dict[str, RetrievalConfig]], - metadata: Optional[Dict[str, MetadataValue]] = None, - event_timestamp: Optional[datetime] = None, - ) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]: + llm_callback: Callable[[str, list[dict[str, Any]]], Awaitable[str]], + retrieval_config: dict[str, RetrievalConfig] | None, + metadata: dict[str, MetadataValue] | None = None, + event_timestamp: datetime | None = None, + ) -> tuple[list[dict[str, Any]], str, dict[str, Any]]: r"""Complete conversation turn with async LLM callback integration. This method combines memory retrieval, LLM invocation, and response storage @@ -384,8 +385,8 @@ def _retrieve_memories_for_llm( actor_id: str, session_id: str, user_input: str, - retrieval_config: Optional[Dict[str, RetrievalConfig]], - ) -> List[Dict[str, Any]]: + retrieval_config: dict[str, RetrievalConfig] | None, + ) -> list[dict[str, Any]]: """Helper method to retrieve memories for LLM context.""" retrieved_memories = [] if retrieval_config: @@ -417,9 +418,9 @@ def _save_conversation_turn( session_id: str, user_input: str, agent_response: str, - metadata: Optional[Dict[str, MetadataValue]], - event_timestamp: Optional[datetime], - ) -> Dict[str, Any]: + metadata: dict[str, MetadataValue] | None, + event_timestamp: datetime | None, + ) -> dict[str, Any]: """Helper method to save conversation turn.""" event = self.add_turns( actor_id=actor_id, @@ -438,10 +439,10 @@ def add_turns( self, actor_id: str, session_id: str, - messages: List[Union[ConversationalMessage, BlobMessage]], - branch: Optional[Dict[str, str]] = None, - metadata: Optional[Dict[str, MetadataValue]] = None, - event_timestamp: Optional[datetime] = None, + messages: list[ConversationalMessage | BlobMessage], + branch: dict[str, str] | None = None, + metadata: dict[str, MetadataValue] | None = None, + event_timestamp: datetime | None = None, ) -> Event: """Adds conversational turns or blob objects to short-term memory. @@ -529,10 +530,10 @@ def fork_conversation( session_id: str, root_event_id: str, branch_name: str, - messages: List[Union[ConversationalMessage, BlobMessage]], - metadata: Optional[Dict[str, MetadataValue]] = None, - event_timestamp: Optional[datetime] = None, - ) -> Dict[str, Any]: + messages: list[ConversationalMessage | BlobMessage], + metadata: dict[str, MetadataValue] | None = None, + event_timestamp: datetime | None = None, + ) -> dict[str, Any]: """Fork a conversation from a specific event to create a new branch.""" try: branch = {"rootEventId": root_event_id, "name": branch_name} @@ -557,12 +558,12 @@ def list_events( self, actor_id: str, session_id: str, - branch_name: Optional[str] = None, + branch_name: str | None = None, include_parent_branches: bool = False, - eventMetadata: Optional[List[EventMetadataFilter]] = None, + eventMetadata: list[EventMetadataFilter] | None = None, max_results: int = 100, include_payload: bool = True, - ) -> List[Event]: + ) -> list[Event]: """List all events in a session with pagination support. This method provides direct access to the raw events API, allowing developers @@ -634,7 +635,7 @@ def list_events( ``` """ try: - all_events: List[Event] = [] + all_events: list[Event] = [] next_token = None max_iterations = 1000 # Safety limit to prevent infinite loops @@ -692,7 +693,7 @@ def list_events( logger.error("Failed to list events: %s", e) raise - def list_branches(self, actor_id: str, session_id: str) -> List[Branch]: + def list_branches(self, actor_id: str, session_id: str) -> list[Branch]: """List all branches in a session. This method handles pagination automatically and provides a structured view @@ -755,7 +756,7 @@ def list_branches(self, actor_id: str, session_id: str) -> List[Branch]: main_branch_events.append(event) # Build result list - result: List[Branch] = [] + result: list[Branch] = [] # Only add main branch if there are actual events if main_branch_events: @@ -784,10 +785,10 @@ def get_last_k_turns( actor_id: str, session_id: str, k: int = 5, - branch_name: Optional[str] = None, + branch_name: str | None = None, include_parent_branches: bool = False, max_results: int = 100, - ) -> List[List[EventMessage]]: + ) -> list[list[EventMessage]]: """Get the last K conversation turns. A "turn" typically consists of a user message followed by assistant response(s). @@ -875,7 +876,7 @@ def search_long_term_memories( top_k: int = 3, strategy_id: str = None, max_results: int = 20, - ) -> List[MemoryRecord]: + ) -> list[MemoryRecord]: """Performs a semantic search against the long-term memory for this actor. Maps to: bedrock-agentcore.retrieve_memory_records. @@ -903,8 +904,8 @@ def search_long_term_memories( raise def list_long_term_memory_records( - self, namespace_prefix: str, strategy_id: Optional[str] = None, max_results: int = 10 - ) -> List[MemoryRecord]: + self, namespace_prefix: str, strategy_id: str | None = None, max_results: int = 10 + ) -> list[MemoryRecord]: """Lists all long-term memory records for this actor without a semantic query. Maps to: bedrock-agentcore.list_memory_records. @@ -923,7 +924,7 @@ def list_long_term_memory_records( params["memoryStrategyId"] = strategy_id pages = paginator.paginate(**params) - all_records: List[MemoryRecord] = [] + all_records: list[MemoryRecord] = [] for page in pages: memory_records = page.get("memoryRecords", []) @@ -944,7 +945,7 @@ def list_long_term_memory_records( logger.error(" ❌ Error listing long-term records: %s", e) raise - def list_actors(self) -> List[ActorSummary]: + def list_actors(self) -> list[ActorSummary]: """Lists all actors who have events in a specific memory. Maps to: bedrock-agentcore.list_actors. @@ -991,7 +992,7 @@ def delete_memory_record(self, record_id: str): logger.error(" ❌ Error deleting record: %s", e) raise - def list_actor_sessions(self, actor_id: str) -> List[SessionSummary]: + def list_actor_sessions(self, actor_id: str) -> list[SessionSummary]: """Lists all sessions for a specific actor in a specific memory. Maps to: bedrock-agentcore.list_sessions. @@ -1000,7 +1001,7 @@ def list_actor_sessions(self, actor_id: str) -> List[SessionSummary]: try: paginator = self._data_plane_client.get_paginator("list_sessions") pages = paginator.paginate(memoryId=self._memory_id, actorId=actor_id) - all_sessions: List[SessionSummary] = [] + all_sessions: list[SessionSummary] = [] for page in pages: response = page.get("sessionSummaries", []) all_sessions.extend([SessionSummary(session) for session in response]) @@ -1010,7 +1011,7 @@ def list_actor_sessions(self, actor_id: str) -> List[SessionSummary]: logger.error(" ❌ Error listing sessions: %s", e) raise - def delete_all_long_term_memories_in_namespace(self, namespace: str) -> Dict[str, Any]: + def delete_all_long_term_memories_in_namespace(self, namespace: str) -> dict[str, Any]: """Delete all long-term memory records within a specific namespace. This method retrieves all memory records in the specified namespace and performs @@ -1083,27 +1084,27 @@ def __init__(self, memory_id: str, actor_id: str, session_id: str, manager: Memo self._manager = manager super().__init__(self._construct_session_dict()) - def _construct_session_dict(self) -> Dict[str, Any]: + def _construct_session_dict(self) -> dict[str, Any]: """Constructs a dictionary representing the session.""" return {"memoryId": self._memory_id, "actorId": self._actor_id, "sessionId": self._session_id} def add_turns( self, - messages: List[Union[ConversationalMessage, BlobMessage]], - branch: Optional[Dict[str, str]] = None, - metadata: Optional[Dict[str, MetadataValue]] = None, - event_timestamp: Optional[datetime] = None, + messages: list[ConversationalMessage | BlobMessage], + branch: dict[str, str] | None = None, + metadata: dict[str, MetadataValue] | None = None, + event_timestamp: datetime | None = None, ) -> Event: """Delegates to manager.add_turns.""" return self._manager.add_turns(self._actor_id, self._session_id, messages, branch, metadata, event_timestamp) def fork_conversation( self, - messages: List[Union[ConversationalMessage, BlobMessage]], + messages: list[ConversationalMessage | BlobMessage], root_event_id: str, branch_name: str, - metadata: Optional[Dict[str, MetadataValue]] = None, - event_timestamp: Optional[datetime] = None, + metadata: dict[str, MetadataValue] | None = None, + event_timestamp: datetime | None = None, ) -> Event: """Delegates to manager.fork_conversation.""" return self._manager.fork_conversation( @@ -1113,11 +1114,11 @@ def fork_conversation( def process_turn_with_llm( self, user_input: str, - llm_callback: Callable[[str, List[Dict[str, Any]]], str], - retrieval_config: Optional[Dict[str, RetrievalConfig]], - metadata: Optional[Dict[str, MetadataValue]] = None, - event_timestamp: Optional[datetime] = None, - ) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]: + llm_callback: Callable[[str, list[dict[str, Any]]], str], + retrieval_config: dict[str, RetrievalConfig] | None, + metadata: dict[str, MetadataValue] | None = None, + event_timestamp: datetime | None = None, + ) -> tuple[list[dict[str, Any]], str, dict[str, Any]]: """Delegates to manager.process_turn_with_llm.""" return self._manager.process_turn_with_llm( self._actor_id, @@ -1132,11 +1133,11 @@ def process_turn_with_llm( async def process_turn_with_llm_async( self, user_input: str, - llm_callback: Callable[[str, List[Dict[str, Any]]], Awaitable[str]], - retrieval_config: Optional[Dict[str, RetrievalConfig]], - metadata: Optional[Dict[str, MetadataValue]] = None, - event_timestamp: Optional[datetime] = None, - ) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]: + llm_callback: Callable[[str, list[dict[str, Any]]], Awaitable[str]], + retrieval_config: dict[str, RetrievalConfig] | None, + metadata: dict[str, MetadataValue] | None = None, + event_timestamp: datetime | None = None, + ) -> tuple[list[dict[str, Any]], str, dict[str, Any]]: """Delegates to manager.process_turn_with_llm_async.""" return await self._manager.process_turn_with_llm_async( self._actor_id, @@ -1151,10 +1152,10 @@ async def process_turn_with_llm_async( def get_last_k_turns( self, k: int = 5, - branch_name: Optional[str] = None, - include_parent_branches: Optional[bool] = None, + branch_name: str | None = None, + include_parent_branches: bool | None = None, max_results: int = 100, - ) -> List[List[EventMessage]]: + ) -> list[list[EventMessage]]: """Delegates to manager.get_last_k_turns.""" return self._manager.get_last_k_turns( self._actor_id, self._session_id, k, branch_name, include_parent_branches, max_results @@ -1181,30 +1182,30 @@ def search_long_term_memories( query: str, namespace_prefix: str, top_k: int = 3, - strategy_id: Optional[str] = None, + strategy_id: str | None = None, max_results: int = 20, - ) -> List[MemoryRecord]: + ) -> list[MemoryRecord]: """Delegates to manager.search_long_term_memories.""" return self._manager.search_long_term_memories(query, namespace_prefix, top_k, strategy_id, max_results) def list_long_term_memory_records( - self, namespace_prefix: str, strategy_id: Optional[str] = None, max_results: int = 10 - ) -> List[MemoryRecord]: + self, namespace_prefix: str, strategy_id: str | None = None, max_results: int = 10 + ) -> list[MemoryRecord]: """Delegates to manager.list_long_term_memory_records.""" return self._manager.list_long_term_memory_records(namespace_prefix, strategy_id, max_results) - def list_actors(self) -> List[ActorSummary]: + def list_actors(self) -> list[ActorSummary]: """Delegates to manager.list_actors.""" return self._manager.list_actors() def list_events( self, - branch_name: Optional[str] = None, + branch_name: str | None = None, include_parent_branches: bool = False, - eventMetadata: Optional[List[EventMetadataFilter]] = None, + eventMetadata: list[EventMetadataFilter] | None = None, max_results: int = 100, include_payload: bool = True, - ) -> List[Event]: + ) -> list[Event]: """Delegates to manager.list_events.""" return self._manager.list_events( actor_id=self._actor_id, @@ -1216,7 +1217,7 @@ def list_events( max_results=max_results, ) - def list_branches(self) -> List[Branch]: + def list_branches(self) -> list[Branch]: """Delegates to manager.list_branches.""" return self._manager.list_branches(self._actor_id, self._session_id) @@ -1238,12 +1239,12 @@ def __init__(self, actor_id: str, session_manager: MemorySessionManager): self._session_manager = session_manager super().__init__(self._construct_session_dict()) - def _construct_session_dict(self) -> Dict[str, Any]: + def _construct_session_dict(self) -> dict[str, Any]: """Constructs a dictionary representing the actor.""" return { "actorId": self._id, } - def list_sessions(self) -> List[SessionSummary]: + def list_sessions(self) -> list[SessionSummary]: """Delegates to _session_manager.list_actor_sessions.""" return self._session_manager.list_actor_sessions(self._id) diff --git a/src/bedrock_agentcore/runtime/app.py b/src/bedrock_agentcore/runtime/app.py index 2b84056..982ae61 100644 --- a/src/bedrock_agentcore/runtime/app.py +++ b/src/bedrock_agentcore/runtime/app.py @@ -11,8 +11,8 @@ import threading import time import uuid -from collections.abc import Sequence -from typing import Any, Callable, Dict, Optional +from collections.abc import Callable, Sequence +from typing import Any from starlette.applications import Starlette from starlette.middleware import Middleware @@ -78,7 +78,7 @@ class BedrockAgentCoreApp(Starlette): def __init__( self, debug: bool = False, - lifespan: Optional[Lifespan] = None, + lifespan: Lifespan | None = None, middleware: Sequence[Middleware] | None = None, ): """Initialize Bedrock AgentCore application. @@ -88,11 +88,11 @@ def __init__( lifespan: Optional lifespan context manager for startup/shutdown middleware: Optional sequence of Starlette Middleware objects (or Middleware(...) entries) """ - self.handlers: Dict[str, Callable] = {} - self._ping_handler: Optional[Callable] = None - self._active_tasks: Dict[int, Dict[str, Any]] = {} + self.handlers: dict[str, Callable] = {} + self._ping_handler: Callable | None = None + self._active_tasks: dict[int, dict[str, Any]] = {} self._task_counter_lock: threading.Lock = threading.Lock() - self._forced_ping_status: Optional[PingStatus] = None + self._forced_ping_status: PingStatus | None = None self._last_status_update_time: float = time.time() routes = [ @@ -199,7 +199,7 @@ def clear_forced_ping_status(self): """Clear forced status and resume automatic.""" self._forced_ping_status = None - def get_async_task_info(self) -> Dict[str, Any]: + def get_async_task_info(self) -> dict[str, Any]: """Get info about running async tasks.""" running_jobs = [] for t in self._active_tasks.values(): @@ -213,7 +213,7 @@ def get_async_task_info(self) -> Dict[str, Any]: return {"active_count": len(self._active_tasks), "running_jobs": running_jobs} - def add_async_task(self, name: str, metadata: Optional[Dict] = None) -> int: + def add_async_task(self, name: str, metadata: dict | None = None) -> int: """Register an async task for interactive health tracking. This method provides granular control over async task lifecycle, @@ -386,7 +386,7 @@ def _handle_ping(self, request): self.logger.exception("Ping endpoint failed") return JSONResponse({"status": PingStatus.HEALTHY.value, "time_of_last_update": int(time.time())}) - def run(self, port: int = 8080, host: Optional[str] = None, **kwargs): + def run(self, port: int = 8080, host: str | None = None, **kwargs): """Start the Bedrock AgentCore server. Args: @@ -430,7 +430,7 @@ async def _invoke_handler(self, handler, request_context, takes_context, payload self.logger.debug("Handler '%s' execution failed", handler_name) raise - def _handle_task_action(self, payload: dict) -> Optional[JSONResponse]: + def _handle_task_action(self, payload: dict) -> JSONResponse | None: """Handle task management actions if present in payload.""" action = payload.get("_agent_core_app_action") if not action: diff --git a/src/bedrock_agentcore/runtime/context.py b/src/bedrock_agentcore/runtime/context.py index b32fbe1..f8dec5b 100644 --- a/src/bedrock_agentcore/runtime/context.py +++ b/src/bedrock_agentcore/runtime/context.py @@ -4,7 +4,6 @@ """ from contextvars import ContextVar -from typing import Dict, Optional from pydantic import BaseModel, Field @@ -12,18 +11,18 @@ class RequestContext(BaseModel): """Request context containing metadata from HTTP requests.""" - session_id: Optional[str] = Field(None) - request_headers: Optional[Dict[str, str]] = Field(None) + session_id: str | None = Field(None) + request_headers: dict[str, str] | None = Field(None) class BedrockAgentCoreContext: """Unified context manager for Bedrock AgentCore.""" - _workload_access_token: ContextVar[Optional[str]] = ContextVar("workload_access_token") - _oauth2_callback_url: ContextVar[Optional[str]] = ContextVar("oauth2_callback_url") - _request_id: ContextVar[Optional[str]] = ContextVar("request_id") - _session_id: ContextVar[Optional[str]] = ContextVar("session_id") - _request_headers: ContextVar[Optional[Dict[str, str]]] = ContextVar("request_headers") + _workload_access_token: ContextVar[str | None] = ContextVar("workload_access_token") + _oauth2_callback_url: ContextVar[str | None] = ContextVar("oauth2_callback_url") + _request_id: ContextVar[str | None] = ContextVar("request_id") + _session_id: ContextVar[str | None] = ContextVar("session_id") + _request_headers: ContextVar[dict[str, str] | None] = ContextVar("request_headers") @classmethod def set_workload_access_token(cls, token: str): @@ -31,7 +30,7 @@ def set_workload_access_token(cls, token: str): cls._workload_access_token.set(token) @classmethod - def get_workload_access_token(cls) -> Optional[str]: + def get_workload_access_token(cls) -> str | None: """Get the workload access token from the context.""" try: return cls._workload_access_token.get() @@ -44,7 +43,7 @@ def set_oauth2_callback_url(cls, workload_callback_url: str): cls._oauth2_callback_url.set(workload_callback_url) @classmethod - def get_oauth2_callback_url(cls) -> Optional[str]: + def get_oauth2_callback_url(cls) -> str | None: """Get the oauth2 callback url from the context.""" try: return cls._oauth2_callback_url.get() @@ -52,13 +51,13 @@ def get_oauth2_callback_url(cls) -> Optional[str]: return None @classmethod - def set_request_context(cls, request_id: str, session_id: Optional[str] = None): + def set_request_context(cls, request_id: str, session_id: str | None = None): """Set request-scoped identifiers.""" cls._request_id.set(request_id) cls._session_id.set(session_id) @classmethod - def get_request_id(cls) -> Optional[str]: + def get_request_id(cls) -> str | None: """Get current request ID.""" try: return cls._request_id.get() @@ -66,7 +65,7 @@ def get_request_id(cls) -> Optional[str]: return None @classmethod - def get_session_id(cls) -> Optional[str]: + def get_session_id(cls) -> str | None: """Get current session ID.""" try: return cls._session_id.get() @@ -74,12 +73,12 @@ def get_session_id(cls) -> Optional[str]: return None @classmethod - def set_request_headers(cls, headers: Dict[str, str]): + def set_request_headers(cls, headers: dict[str, str]): """Set request headers in the context.""" cls._request_headers.set(headers) @classmethod - def get_request_headers(cls) -> Optional[Dict[str, str]]: + def get_request_headers(cls) -> dict[str, str] | None: """Get request headers from the context.""" try: return cls._request_headers.get() diff --git a/src/bedrock_agentcore/runtime/utils.py b/src/bedrock_agentcore/runtime/utils.py index 351cdd0..387e37d 100644 --- a/src/bedrock_agentcore/runtime/utils.py +++ b/src/bedrock_agentcore/runtime/utils.py @@ -23,7 +23,7 @@ def convert_complex_objects(obj: Any, _depth: int = 0) -> Any: return {k: convert_complex_objects(v, _depth + 1) for k, v in obj.items()} # Handle lists and tuples recursively - elif isinstance(obj, (list, tuple)): + elif isinstance(obj, list | tuple): return [convert_complex_objects(item, _depth + 1) for item in obj] # Handle sets (convert to list) diff --git a/src/bedrock_agentcore/services/identity.py b/src/bedrock_agentcore/services/identity.py index 80d30ac..638ee5e 100644 --- a/src/bedrock_agentcore/services/identity.py +++ b/src/bedrock_agentcore/services/identity.py @@ -5,7 +5,8 @@ import time import uuid from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from collections.abc import Callable +from typing import Any, Literal import boto3 from pydantic import BaseModel @@ -97,8 +98,8 @@ def create_api_key_credential_provider(self, req): return self.cp_client.create_api_key_credential_provider(**req) def get_workload_access_token( - self, workload_name: str, user_token: Optional[str] = None, user_id: Optional[str] = None - ) -> Dict: + self, workload_name: str, user_token: str | None = None, user_id: str | None = None + ) -> dict: """Get a workload access token using workload name and optionally user token.""" if user_token: if user_id is not None: @@ -116,8 +117,8 @@ def get_workload_access_token( return resp def create_workload_identity( - self, name: Optional[str] = None, allowed_resource_oauth_2_return_urls: Optional[list[str]] = None - ) -> Dict: + self, name: str | None = None, allowed_resource_oauth_2_return_urls: list[str] | None = None + ) -> dict: """Create workload identity with optional name.""" self.logger.info("Creating workload identity...") if not name: @@ -126,7 +127,7 @@ def create_workload_identity( name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls or [] ) - def update_workload_identity(self, name: str, allowed_resource_oauth_2_return_urls: list[str]) -> Dict: + def update_workload_identity(self, name: str, allowed_resource_oauth_2_return_urls: list[str]) -> dict: """Update an existing workload identity with allowed resource OAuth2 callback urls.""" self.logger.info( "Updating workload identity '%s' with callback urls: %s", name, allowed_resource_oauth_2_return_urls @@ -135,13 +136,13 @@ def update_workload_identity(self, name: str, allowed_resource_oauth_2_return_ur name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls ) - def get_workload_identity(self, name: str) -> Dict: + def get_workload_identity(self, name: str) -> dict: """Retrieves information about a workload identity.""" self.logger.info("Fetching workload identity '%s'", name) return self.cp_client.get_workload_identity(name=name) def complete_resource_token_auth( - self, session_uri: str, user_identifier: Union[UserTokenIdentifier, UserIdIdentifier] + self, session_uri: str, user_identifier: UserTokenIdentifier | UserIdIdentifier ): """Confirms the user authentication session for obtaining OAuth2.0 tokens for a resource.""" self.logger.info("Completing 3LO OAuth2 flow...") @@ -160,15 +161,15 @@ async def get_token( self, *, provider_name: str, - scopes: Optional[List[str]] = None, + scopes: list[str] | None = None, agent_identity_token: str, - on_auth_url: Optional[Callable[[str], Any]] = None, + on_auth_url: Callable[[str], Any] | None = None, auth_flow: Literal["M2M", "USER_FEDERATION"], - callback_url: Optional[str] = None, + callback_url: str | None = None, force_authentication: bool = False, - token_poller: Optional[TokenPoller] = None, - custom_state: Optional[str] = None, - custom_parameters: Optional[Dict[str, str]] = None, + token_poller: TokenPoller | None = None, + custom_state: str | None = None, + custom_parameters: dict[str, str] | None = None, ) -> str: """Get an OAuth2 access token for the specified provider. diff --git a/src/bedrock_agentcore/tools/browser_client.py b/src/bedrock_agentcore/tools/browser_client.py index 4ebccbc..1180de7 100644 --- a/src/bedrock_agentcore/tools/browser_client.py +++ b/src/bedrock_agentcore/tools/browser_client.py @@ -10,8 +10,8 @@ import logging import secrets import uuid +from collections.abc import Generator from contextlib import contextmanager -from typing import Dict, Generator, Optional, Tuple from urllib.parse import urlparse import boto3 @@ -69,22 +69,22 @@ def __init__(self, region: str) -> None: self._session_id = None @property - def identifier(self) -> Optional[str]: + def identifier(self) -> str | None: """Get the current browser identifier.""" return self._identifier @identifier.setter - def identifier(self, value: Optional[str]): + def identifier(self, value: str | None): """Set the browser identifier.""" self._identifier = value @property - def session_id(self) -> Optional[str]: + def session_id(self) -> str | None: """Get the current session ID.""" return self._session_id @session_id.setter - def session_id(self, value: Optional[str]): + def session_id(self, value: str | None): """Set the session ID.""" self._session_id = value @@ -92,13 +92,13 @@ def create_browser( self, name: str, execution_role_arn: str, - network_configuration: Optional[Dict] = None, - description: Optional[str] = None, - recording: Optional[Dict] = None, - browser_signing: Optional[Dict] = None, - tags: Optional[Dict[str, str]] = None, - client_token: Optional[str] = None, - ) -> Dict: + network_configuration: dict | None = None, + description: str | None = None, + recording: dict | None = None, + browser_signing: dict | None = None, + tags: dict[str, str] | None = None, + client_token: str | None = None, + ) -> dict: """Create a custom browser with specific configuration. This is a control plane operation that provisions a new browser with @@ -183,7 +183,7 @@ def create_browser( response = self.control_plane_client.create_browser(**request_params) return response - def delete_browser(self, browser_id: str, client_token: Optional[str] = None) -> Dict: + def delete_browser(self, browser_id: str, client_token: str | None = None) -> dict: """Delete a custom browser. Args: @@ -208,7 +208,7 @@ def delete_browser(self, browser_id: str, client_token: Optional[str] = None) -> response = self.control_plane_client.delete_browser(**request_params) return response - def get_browser(self, browser_id: str) -> Dict: + def get_browser(self, browser_id: str) -> dict: """Get detailed information about a browser. Args: @@ -237,10 +237,10 @@ def get_browser(self, browser_id: str) -> Dict: def list_browsers( self, - browser_type: Optional[str] = None, + browser_type: str | None = None, max_results: int = 10, - next_token: Optional[str] = None, - ) -> Dict: + next_token: str | None = None, + ) -> dict: """List all browsers in the account. Args: @@ -272,10 +272,10 @@ def list_browsers( def start( self, - identifier: Optional[str] = DEFAULT_IDENTIFIER, - name: Optional[str] = None, - session_timeout_seconds: Optional[int] = DEFAULT_SESSION_TIMEOUT, - viewport: Optional[Dict[str, int]] = None, + identifier: str | None = DEFAULT_IDENTIFIER, + name: str | None = None, + session_timeout_seconds: int | None = DEFAULT_SESSION_TIMEOUT, + viewport: dict[str, int] | None = None, ) -> str: """Start a browser sandbox session. @@ -341,7 +341,7 @@ def stop(self) -> bool: self.session_id = None return True - def get_session(self, browser_id: Optional[str] = None, session_id: Optional[str] = None) -> Dict: + def get_session(self, browser_id: str | None = None, session_id: str | None = None) -> dict: """Get detailed information about a browser session. Args: @@ -377,11 +377,11 @@ def get_session(self, browser_id: Optional[str] = None, session_id: Optional[str def list_sessions( self, - browser_id: Optional[str] = None, - status: Optional[str] = None, + browser_id: str | None = None, + status: str | None = None, max_results: int = 10, - next_token: Optional[str] = None, - ) -> Dict: + next_token: str | None = None, + ) -> dict: """List browser sessions for a specific browser. Args: @@ -419,8 +419,8 @@ def list_sessions( def update_stream( self, stream_status: str, - browser_id: Optional[str] = None, - session_id: Optional[str] = None, + browser_id: str | None = None, + session_id: str | None = None, ) -> None: """Update the browser automation stream status. @@ -451,7 +451,7 @@ def update_stream( streamUpdate={"automationStreamUpdate": {"streamStatus": stream_status}}, ) - def generate_ws_headers(self) -> Tuple[str, Dict[str, str]]: + def generate_ws_headers(self) -> tuple[str, dict[str, str]]: """Generate the WebSocket headers needed for connecting to the browser sandbox. Returns: @@ -568,7 +568,7 @@ def release_control(self): @contextmanager def browser_session( - region: str, viewport: Optional[Dict[str, int]] = None, identifier: Optional[str] = None + region: str, viewport: dict[str, int] | None = None, identifier: str | None = None ) -> Generator[BrowserClient, None, None]: """Context manager for creating and managing a browser sandbox session. diff --git a/src/bedrock_agentcore/tools/code_interpreter_client.py b/src/bedrock_agentcore/tools/code_interpreter_client.py index 36b5415..ea04a4a 100644 --- a/src/bedrock_agentcore/tools/code_interpreter_client.py +++ b/src/bedrock_agentcore/tools/code_interpreter_client.py @@ -6,8 +6,8 @@ import logging import uuid +from collections.abc import Generator from contextlib import contextmanager -from typing import Dict, Generator, Optional import boto3 @@ -33,7 +33,7 @@ class CodeInterpreter: session_id (str, optional): The active session ID. """ - def __init__(self, region: str, session: Optional[boto3.Session] = None) -> None: + def __init__(self, region: str, session: boto3.Session | None = None) -> None: """Initialize a Code Interpreter client for the specified AWS region. Args: @@ -64,22 +64,22 @@ def __init__(self, region: str, session: Optional[boto3.Session] = None) -> None self._session_id = None @property - def identifier(self) -> Optional[str]: + def identifier(self) -> str | None: """Get the current code interpreter identifier.""" return self._identifier @identifier.setter - def identifier(self, value: Optional[str]): + def identifier(self, value: str | None): """Set the code interpreter identifier.""" self._identifier = value @property - def session_id(self) -> Optional[str]: + def session_id(self) -> str | None: """Get the current session ID.""" return self._session_id @session_id.setter - def session_id(self, value: Optional[str]): + def session_id(self, value: str | None): """Set the session ID.""" self._session_id = value @@ -87,11 +87,11 @@ def create_code_interpreter( self, name: str, execution_role_arn: str, - network_configuration: Optional[Dict] = None, - description: Optional[str] = None, - tags: Optional[Dict[str, str]] = None, - client_token: Optional[str] = None, - ) -> Dict: + network_configuration: dict | None = None, + description: str | None = None, + tags: dict[str, str] | None = None, + client_token: str | None = None, + ) -> dict: """Create a custom code interpreter with specific configuration. This is a control plane operation that provisions a new code interpreter @@ -157,7 +157,7 @@ def create_code_interpreter( response = self.control_plane_client.create_code_interpreter(**request_params) return response - def delete_code_interpreter(self, interpreter_id: str, client_token: Optional[str] = None) -> Dict: + def delete_code_interpreter(self, interpreter_id: str, client_token: str | None = None) -> dict: """Delete a custom code interpreter. Args: @@ -182,7 +182,7 @@ def delete_code_interpreter(self, interpreter_id: str, client_token: Optional[st response = self.control_plane_client.delete_code_interpreter(**request_params) return response - def get_code_interpreter(self, interpreter_id: str) -> Dict: + def get_code_interpreter(self, interpreter_id: str) -> dict: """Get detailed information about a code interpreter. Args: @@ -207,10 +207,10 @@ def get_code_interpreter(self, interpreter_id: str) -> Dict: def list_code_interpreters( self, - interpreter_type: Optional[str] = None, + interpreter_type: str | None = None, max_results: int = 10, - next_token: Optional[str] = None, - ) -> Dict: + next_token: str | None = None, + ) -> dict: """List all code interpreters in the account. Args: @@ -242,9 +242,9 @@ def list_code_interpreters( def start( self, - identifier: Optional[str] = DEFAULT_IDENTIFIER, - name: Optional[str] = None, - session_timeout_seconds: Optional[int] = DEFAULT_TIMEOUT, + identifier: str | None = DEFAULT_IDENTIFIER, + name: str | None = None, + session_timeout_seconds: int | None = DEFAULT_TIMEOUT, ) -> str: """Start a code interpreter sandbox session. @@ -302,7 +302,7 @@ def stop(self) -> bool: self.session_id = None return True - def get_session(self, interpreter_id: Optional[str] = None, session_id: Optional[str] = None) -> Dict: + def get_session(self, interpreter_id: str | None = None, session_id: str | None = None) -> dict: """Get detailed information about a code interpreter session. Args: @@ -335,11 +335,11 @@ def get_session(self, interpreter_id: Optional[str] = None, session_id: Optional def list_sessions( self, - interpreter_id: Optional[str] = None, - status: Optional[str] = None, + interpreter_id: str | None = None, + status: str | None = None, max_results: int = 10, - next_token: Optional[str] = None, - ) -> Dict: + next_token: str | None = None, + ) -> dict: """List code interpreter sessions for a specific interpreter. Args: @@ -374,7 +374,7 @@ def list_sessions( response = self.data_plane_client.list_code_interpreter_sessions(**request_params) return response - def invoke(self, method: str, params: Optional[Dict] = None): + def invoke(self, method: str, params: dict | None = None): r"""Invoke a method in the code interpreter sandbox. If no session is active, automatically starts a new session. @@ -407,7 +407,7 @@ def invoke(self, method: str, params: Optional[Dict] = None): @contextmanager def code_session( - region: str, session: Optional[boto3.Session] = None, identifier: Optional[str] = None + region: str, session: boto3.Session | None = None, identifier: str | None = None ) -> Generator[CodeInterpreter, None, None]: """Context manager for creating and managing a code interpreter session. diff --git a/src/bedrock_agentcore/tools/config.py b/src/bedrock_agentcore/tools/config.py index bf5312f..30ca929 100644 --- a/src/bedrock_agentcore/tools/config.py +++ b/src/bedrock_agentcore/tools/config.py @@ -5,7 +5,6 @@ """ from dataclasses import dataclass, field -from typing import Dict, List, Optional @dataclass @@ -17,10 +16,10 @@ class VpcConfig: subnets: List of subnet IDs """ - security_groups: List[str] - subnets: List[str] + security_groups: list[str] + subnets: list[str] - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to API-compatible dictionary.""" return {"securityGroups": self.security_groups, "subnets": self.subnets} @@ -35,7 +34,7 @@ class NetworkConfiguration: """ network_mode: str = "PUBLIC" - vpc_config: Optional[VpcConfig] = None + vpc_config: VpcConfig | None = None def __post_init__(self): """Validate configuration.""" @@ -45,7 +44,7 @@ def __post_init__(self): if self.network_mode == "VPC" and not self.vpc_config: raise ValueError("vpc_config is required when network_mode is 'VPC'") - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to API-compatible dictionary.""" config = {"networkMode": self.network_mode} if self.vpc_config: @@ -58,7 +57,7 @@ def public(cls) -> "NetworkConfiguration": return cls(network_mode="PUBLIC") @classmethod - def vpc(cls, security_groups: List[str], subnets: List[str]) -> "NetworkConfiguration": + def vpc(cls, security_groups: list[str], subnets: list[str]) -> "NetworkConfiguration": """Create a VPC network configuration. Args: @@ -81,9 +80,9 @@ class S3Location: """ bucket: str - key_prefix: Optional[str] = None + key_prefix: str | None = None - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to API-compatible dictionary.""" location = {"bucket": self.bucket} if self.key_prefix: @@ -101,9 +100,9 @@ class RecordingConfiguration: """ enabled: bool = True - s3_location: Optional[S3Location] = None + s3_location: S3Location | None = None - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to API-compatible dictionary.""" config = {"enabled": self.enabled} if self.s3_location: @@ -116,7 +115,7 @@ def disabled(cls) -> "RecordingConfiguration": return cls(enabled=False) @classmethod - def enabled_with_location(cls, bucket: str, key_prefix: Optional[str] = None) -> "RecordingConfiguration": + def enabled_with_location(cls, bucket: str, key_prefix: str | None = None) -> "RecordingConfiguration": """Create an enabled recording configuration with S3 location. Args: @@ -141,7 +140,7 @@ class BrowserSigningConfiguration: enabled: bool = True - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to API-compatible dictionary.""" return {"enabled": self.enabled} @@ -168,7 +167,7 @@ class ViewportConfiguration: width: int height: int - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to API-compatible dictionary.""" return {"width": self.width, "height": self.height} @@ -217,12 +216,12 @@ class BrowserConfiguration: name: str execution_role_arn: str network_configuration: NetworkConfiguration - description: Optional[str] = None - recording: Optional[RecordingConfiguration] = None - browser_signing: Optional[BrowserSigningConfiguration] = None - tags: Optional[Dict[str, str]] = field(default_factory=dict) + description: str | None = None + recording: RecordingConfiguration | None = None + browser_signing: BrowserSigningConfiguration | None = None + tags: dict[str, str] | None = field(default_factory=dict) - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to API-compatible dictionary for create_browser.""" config = { "name": self.name, @@ -260,10 +259,10 @@ class CodeInterpreterConfiguration: name: str execution_role_arn: str network_configuration: NetworkConfiguration - description: Optional[str] = None - tags: Optional[Dict[str, str]] = field(default_factory=dict) + description: str | None = None + tags: dict[str, str] | None = field(default_factory=dict) - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to API-compatible dictionary for create_code_interpreter.""" config = { "name": self.name, @@ -285,13 +284,13 @@ def create_browser_config( execution_role_arn: str, enable_web_bot_auth: bool = False, enable_recording: bool = False, - recording_bucket: Optional[str] = None, - recording_prefix: Optional[str] = None, + recording_bucket: str | None = None, + recording_prefix: str | None = None, use_vpc: bool = False, - security_groups: Optional[List[str]] = None, - subnets: Optional[List[str]] = None, - description: Optional[str] = None, - tags: Optional[Dict[str, str]] = None, + security_groups: list[str] | None = None, + subnets: list[str] | None = None, + description: str | None = None, + tags: dict[str, str] | None = None, ) -> BrowserConfiguration: """Create a browser configuration with common options. diff --git a/tests/bedrock_agentcore/memory/test_session.py b/tests/bedrock_agentcore/memory/test_session.py index 549d346..34b1397 100644 --- a/tests/bedrock_agentcore/memory/test_session.py +++ b/tests/bedrock_agentcore/memory/test_session.py @@ -3,7 +3,7 @@ import asyncio import uuid from datetime import datetime, timezone -from typing import Any, Dict, List +from typing import Any from unittest import mock from unittest.mock import MagicMock, Mock, patch @@ -402,7 +402,7 @@ def test_process_turn_with_llm_success(self): mock_event = {"eventId": "event-123", "memoryId": "testMemory-1234567890"} with patch.object(manager, "add_turns", return_value=Event(mock_event)): # Define LLM callback - def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: return f"Response to: {user_input} with {len(memories)} memories" # Test process_turn_with_llm with new RetrievalConfig API @@ -434,7 +434,7 @@ def test_process_turn_with_llm_no_retrieval(self): mock_event = {"eventId": "event-123", "memoryId": "testMemory-1234567890"} with patch.object(manager, "add_turns", return_value=Event(mock_event)): # Define LLM callback - def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: return f"Response to: {user_input}" # Test process_turn_with_llm without retrieval (None retrieval_config) @@ -465,7 +465,7 @@ def test_process_turn_with_llm_async_method(self): mock_event = {"eventId": "event-123", "memoryId": "testMemory-1234567890"} with patch.object(manager, "add_turns", return_value=Event(mock_event)): # Define async LLM callback - async def mock_async_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + async def mock_async_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: return f"Async method response to: {user_input}" # Test process_turn_with_llm_async @@ -496,7 +496,7 @@ def test_process_turn_with_llm_callback_error(self): manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") # Define failing LLM callback - def failing_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + def failing_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: raise Exception("LLM service error") # Test process_turn_with_llm with callback error @@ -523,7 +523,7 @@ def test_process_turn_with_llm_invalid_callback_return(self): manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") # Define callback that returns non-string - def invalid_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> int: + def invalid_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> int: return 123 # type: ignore # Test process_turn_with_llm with invalid return type @@ -1778,7 +1778,7 @@ def test_session_process_turn_with_llm_delegation(self): manager, "process_turn_with_llm", return_value=(mock_memories, mock_response, mock_event) ) as mock_process: - def mock_llm(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm(user_input: str, memories: list[dict[str, Any]]) -> str: return "Response" memories, response, event = session.process_turn_with_llm( @@ -2045,7 +2045,7 @@ def test_process_turn_with_llm_custom_retrieval_config(self): mock_event = {"eventId": "event-123"} with patch.object(manager, "add_turns", return_value=Event(mock_event)): - def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: return "Response" # Test with custom retrieval config @@ -2589,7 +2589,7 @@ def test_memory_session_process_turn_with_llm_with_metadata(self): manager, "process_turn_with_llm", return_value=(mock_memories, mock_response, mock_event) ) as mock_process: - def mock_llm(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm(user_input: str, memories: list[dict[str, Any]]) -> str: return "Response" memories, response, event = session.process_turn_with_llm( @@ -2619,7 +2619,7 @@ def test_process_turn_with_llm_with_metadata_parameter(self): mock_event = {"eventId": "event-123", "memoryId": "testMemory-1234567890"} with patch.object(manager, "add_turns", return_value=Event(mock_event)) as mock_add_turns: # Define LLM callback - def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: return f"Response to: {user_input} with {len(memories)} memories" # Test process_turn_with_llm with metadata @@ -2733,7 +2733,7 @@ def test_process_turn_with_llm_with_retrieval_query_fallback(self): mock_event = {"eventId": "event-123"} with patch.object(manager, "add_turns", return_value=Event(mock_event)): - def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: return "Response" # Test with retrieval_config but no retrieval_query (should use user_input) @@ -2865,7 +2865,7 @@ def test_process_turn_with_llm_with_relevance_score_filtering(self): mock_event = {"eventId": "event-123"} with patch.object(manager, "add_turns", return_value=Event(mock_event)): - def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: return f"Response with {len(memories)} memories" # Test with relevance_score filtering (should filter out low relevance) @@ -3094,7 +3094,7 @@ def test_process_turn_with_llm_no_relevance_score_config(self): mock_event = {"eventId": "event-123"} with patch.object(manager, "add_turns", return_value=Event(mock_event)): - def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: return "Response" # Test with RetrievalConfig that has a very low relevance_score (effectively no filtering) @@ -3716,7 +3716,7 @@ def test_process_turn_with_llm_no_retrieval_namespace(self): # Mock search_long_term_memories to ensure it's not called with patch.object(manager, "search_long_term_memories") as mock_search: - def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + def mock_llm_callback(user_input: str, memories: list[dict[str, Any]]) -> str: return "Response" # Test without retrieval_config (should not call search) diff --git a/tests/bedrock_agentcore/runtime/test_app.py b/tests/bedrock_agentcore/runtime/test_app.py index c0c2568..8e9e80b 100644 --- a/tests/bedrock_agentcore/runtime/test_app.py +++ b/tests/bedrock_agentcore/runtime/test_app.py @@ -1387,7 +1387,7 @@ def test_circular_references(self): # Should fallback to string representation or error object parsed = json.loads(result) - assert isinstance(parsed, (str, dict)) + assert isinstance(parsed, str | dict) # If it's a string, should contain some representation if isinstance(parsed, str): diff --git a/tests/bedrock_agentcore/runtime/test_utils.py b/tests/bedrock_agentcore/runtime/test_utils.py index 4f335a7..05250cb 100644 --- a/tests/bedrock_agentcore/runtime/test_utils.py +++ b/tests/bedrock_agentcore/runtime/test_utils.py @@ -1,7 +1,6 @@ """Tests for Bedrock AgentCore runtime utilities.""" from dataclasses import dataclass -from typing import List, Optional from pydantic import BaseModel @@ -64,7 +63,7 @@ def test_dataclasses(self): class TestDataClass: name: str value: int - items: List[str] + items: list[str] data = TestDataClass(name="test", value=100, items=["a", "b", "c"]) result = convert_complex_objects(data) @@ -123,7 +122,7 @@ class ConfigModel(BaseModel): @dataclass class ConfigData: version: str - features: List[str] + features: list[str] test_dict = { "config": ConfigModel(setting="test", enabled=True), @@ -202,7 +201,7 @@ class UserModel(BaseModel): @dataclass class UserProfile: bio: str - avatar_url: Optional[str] + avatar_url: str | None class PostModel(BaseModel): title: str diff --git a/tests_integ/async/interactive_async_strands.py b/tests_integ/async/interactive_async_strands.py index 43025ae..723c03d 100644 --- a/tests_integ/async/interactive_async_strands.py +++ b/tests_integ/async/interactive_async_strands.py @@ -24,7 +24,7 @@ import threading import time from datetime import datetime, timedelta -from typing import Any, Dict, Optional +from typing import Any from strands import Agent, tool @@ -76,13 +76,13 @@ def __init__( total_seconds = duration_minutes * 60 self.base_processing_speed = self.total_items / total_seconds - def get_current_stage(self) -> Dict[str, Any]: + def get_current_stage(self) -> dict[str, Any]: """Get current processing stage info.""" if self.current_stage_index < len(self.PROCESSING_STAGES): return self.PROCESSING_STAGES[self.current_stage_index] return {"name": "completed", "weight": 0, "description": "Processing completed"} - def calculate_progress(self) -> Dict[str, Any]: + def calculate_progress(self) -> dict[str, Any]: """Calculate detailed progress information.""" current_stage = self.get_current_stage() @@ -330,7 +330,7 @@ def get_processing_progress(task_id: Optional[int] = None) -> str: `start_data_processing(dataset_size="medium", processing_type="data_analysis")`""" try: - with open(result_file, "r") as f: + with open(result_file) as f: progress = json.load(f) status = progress.get("status", "unknown") diff --git a/tests_integ/async/test_async_status_example.py b/tests_integ/async/test_async_status_example.py index 8cf7e0c..bdd934b 100644 --- a/tests_integ/async/test_async_status_example.py +++ b/tests_integ/async/test_async_status_example.py @@ -6,7 +6,7 @@ """ import time -from typing import Any, Dict +from typing import Any import requests @@ -45,7 +45,7 @@ def test_ping_endpoint(self): print(f" ❌ Error testing ping endpoint: {e}") return None - def test_rpc_action(self, action: str, expected_fields: list = None) -> Dict[Any, Any]: + def test_rpc_action(self, action: str, expected_fields: list = None) -> dict[Any, Any]: """Test a debug action via POST /invocations.""" print(f"🔍 Testing debug action: {action}") try: @@ -70,7 +70,7 @@ def test_rpc_action(self, action: str, expected_fields: list = None) -> Dict[Any print(f" ❌ Error testing debug action '{action}': {e}") return {} - def test_business_action(self, action: str, payload: dict = None) -> Dict[Any, Any]: + def test_business_action(self, action: str, payload: dict = None) -> dict[Any, Any]: """Test a regular business logic action.""" print(f"🔍 Testing business action: {action}") try: diff --git a/tests_integ/memory/test_devex.py b/tests_integ/memory/test_devex.py index 9a864a6..db1fd40 100644 --- a/tests_integ/memory/test_devex.py +++ b/tests_integ/memory/test_devex.py @@ -43,8 +43,8 @@ def test_complete_agent_workflow(client: MemoryClient, memory_id: str): logger.info("COMPLETE AGENT WORKFLOW TEST") logger.info("=" * 80) - actor_id = "customer-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - session_id = "support-%s" % datetime.now().strftime("%Y%m%d%H%M%S") + actor_id = "customer-{}".format(datetime.now().strftime("%Y%m%d%H%M%S")) + session_id = "support-{}".format(datetime.now().strftime("%Y%m%d%H%M%S")) logger.info("\n1. Memory strategies already configured during creation") @@ -356,7 +356,7 @@ def save_with_retry(memory_id, actor_id, session_id, messages, branch=None, max_ logger.info("Waiting 30 seconds for extraction to trigger...") time.sleep(30) - namespace = "support/facts/%s" % session_id + namespace = f"support/facts/{session_id}" if client.wait_for_memories(memory_id, namespace, max_wait=180): logger.info("✓ Memories extracted and indexed successfully") @@ -411,8 +411,8 @@ def test_bedrock_integration(client: MemoryClient, memory_id: str): logger.info("Skipping Bedrock test - ensure AWS credentials are configured") return - actor_id = "bedrock-test-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - session_id = "bedrock-session-%s" % datetime.now().strftime("%Y%m%d%H%M%S") + actor_id = "bedrock-test-{}".format(datetime.now().strftime("%Y%m%d%H%M%S")) + session_id = "bedrock-session-{}".format(datetime.now().strftime("%Y%m%d%H%M%S")) # Create initial context logger.info("\n1. Creating initial conversation context...") @@ -439,7 +439,7 @@ def test_bedrock_integration(client: MemoryClient, memory_id: str): # Retrieve relevant memories logger.info("\n4. Retrieving relevant context...") - namespace = "support/facts/%s" % session_id + namespace = f"support/facts/{session_id}" memories = client.retrieve_memories(memory_id=memory_id, namespace=namespace, query=user_query, top_k=5) context = "" @@ -453,7 +453,7 @@ def test_bedrock_integration(client: MemoryClient, memory_id: str): messages = [] if context: messages.append( - {"role": "assistant", "content": "Here's what I know from our previous conversation:\n%s" % context} + {"role": "assistant", "content": f"Here's what I know from our previous conversation:\n{context}"} ) messages.append({"role": "user", "content": user_query}) @@ -694,7 +694,7 @@ def main(): logger.info("\nCreating test memory with strategies...") memory = client.create_memory( - name="DXTest_%s" % datetime.now().strftime("%Y%m%d%H%M%S"), + name="DXTest_{}".format(datetime.now().strftime("%Y%m%d%H%M%S")), description="Developer experience evaluation", strategies=[ { diff --git a/tests_integ/memory/test_memory_client.py b/tests_integ/memory/test_memory_client.py index 6fb69d7..a907ff9 100644 --- a/tests_integ/memory/test_memory_client.py +++ b/tests_integ/memory/test_memory_client.py @@ -18,8 +18,8 @@ def test_list_events_api(client: MemoryClient, memory_id: str): logger.info("TESTING LIST_EVENTS PUBLIC API (Issue #1)") logger.info("=" * 80) - actor_id = "test-list-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - session_id = "session-%s" % datetime.now().strftime("%Y%m%d%H%M%S") + actor_id = "test-list-{}".format(datetime.now().strftime("%Y%m%d%H%M%S")) + session_id = "session-{}".format(datetime.now().strftime("%Y%m%d%H%M%S")) # Create some events logger.info("\n1. Creating test events...") @@ -79,7 +79,7 @@ def test_strategy_polling_fix(client: MemoryClient): # Create memory without strategies logger.info("\n1. Creating memory without strategies...") memory = client.create_memory_and_wait( - name="PollingTest_%s" % datetime.now().strftime("%Y%m%d%H%M%S"), + name="PollingTest_{}".format(datetime.now().strftime("%Y%m%d%H%M%S")), strategies=[], # No strategies initially event_expiry_days=7, ) @@ -135,8 +135,8 @@ def test_get_last_k_turns_fix(client: MemoryClient, memory_id: str): logger.info("TESTING GET_LAST_K_TURNS FIX (Issue #3)") logger.info("=" * 80) - actor_id = "restaurant-user-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - session_id = "restaurant-session-%s" % datetime.now().strftime("%Y%m%d%H%M%S") + actor_id = "restaurant-user-{}".format(datetime.now().strftime("%Y%m%d%H%M%S")) + session_id = "restaurant-session-{}".format(datetime.now().strftime("%Y%m%d%H%M%S")) # Create the exact conversation from the issue logger.info("\n1. Creating restaurant conversation...") @@ -357,7 +357,7 @@ def main(): logger.info("\n\nCreating memory for remaining tests...") # Explicitly define strategy with clear namespace pattern for testing memory = client.create_memory_and_wait( - name="RetrievalTest_%s" % datetime.now().strftime("%Y%m%d%H%M%S"), + name="RetrievalTest_{}".format(datetime.now().strftime("%Y%m%d%H%M%S")), strategies=[ { "semanticMemoryStrategy": { diff --git a/tests_integ/runtime/base_test.py b/tests_integ/runtime/base_test.py index 5373fec..f5022c6 100644 --- a/tests_integ/runtime/base_test.py +++ b/tests_integ/runtime/base_test.py @@ -4,9 +4,10 @@ import threading import time from abc import ABC, abstractmethod +from collections.abc import Generator from contextlib import contextmanager from subprocess import Popen -from typing import IO, Generator +from typing import IO logger = logging.getLogger("sdk-runtime-base-test")