diff --git a/CLAUDE.md b/CLAUDE.md index 0033052..2c45e4d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -152,6 +152,16 @@ Avoid over-specifying implementation details; focus on the "what" and "why", not Review existing RFCs before implementation to understand design decisions and constraints. +## Code Review Standards (RFC Conformance) + +When reviewing code against an RFC: + +1. **State the RFC's core invariant in one sentence** before reading code +2. **Trace data flow**: for key variables, verify input set == output set +3. **Check all paths**: normal, fallback, error β€” do they all satisfy the invariant? + +If you can't answer "yes" with line numbers, dig deeper. + ## Async Runtime Rules - **New runtime code must be async-first**: avoid introducing new blocking I/O in `agent/`, `llm/`, `memory/`, and `tools/`. diff --git a/README.md b/README.md index 19d1993..37e2a14 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,9 @@ See the full configuration template in `.env.example`. Key options: | `MAX_ITERATIONS` | Maximum agent iterations | `100` | | `MEMORY_COMPRESSION_THRESHOLD` | Compress when exceeded | `25000` | | `MEMORY_SHORT_TERM_SIZE` | Recent messages to keep | `100` | +| `COMPACT_USER_MESSAGE_MAX_TOKENS` | User message budget during compaction | `20000` | +| `TOOL_OUTPUT_TRUNCATION_POLICY` | Truncate tool outputs (`none|bytes|tokens`) | `tokens` | +| `CONTEXT_OVERFLOW_MAX_RETRIES` | Retries on context overflow | `3` | | `RETRY_MAX_ATTEMPTS` | Retry attempts for rate limits | `3` | | `LOG_LEVEL` | Logging level | `DEBUG` | diff --git a/agent/base.py b/agent/base.py index 94bdc97..25aaa95 100644 --- a/agent/base.py +++ b/agent/base.py @@ -3,7 +3,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Optional +from config import Config from llm import LLMMessage, LLMResponse, StopReason, ToolResult +from llm.retry import is_context_length_error from memory import MemoryManager from tools.base import BaseTool from tools.todo import TodoTool @@ -52,10 +54,6 @@ def __init__( # Initialize memory manager (uses Config directly) self.memory = MemoryManager(llm) - # Set up todo context provider for memory compression - # This injects current todo state into summaries instead of preserving all todo messages - self.memory.set_todo_context_provider(self._get_todo_context) - @abstractmethod def run(self, task: str) -> str: """Execute the agent on a task and return final answer.""" @@ -84,6 +82,42 @@ async def _call_llm( messages=messages, tools=tools, max_tokens=4096, **kwargs ) + async def _call_with_overflow_recovery( + self, + tools: Optional[List] = None, + spinner_message: str = "Thinking...", + **kwargs, + ) -> LLMResponse: + """Call LLM with context overflow recovery.""" + last_error: Optional[BaseException] = None + max_retries = max(0, Config.CONTEXT_OVERFLOW_MAX_RETRIES) + + for attempt in range(max_retries + 1): + context = self.memory.get_context_for_llm() + try: + return await self._call_llm( + messages=context, + tools=tools, + spinner_message=spinner_message, + **kwargs, + ) + except Exception as e: # noqa: BLE001 + if not is_context_length_error(e): + raise + last_error = e + removed = self.memory.remove_oldest_with_pair_integrity() + if removed is None: + break + logger.warning( + "Context length exceeded; removed oldest message and retrying (%s/%s)", + attempt + 1, + max_retries + 1, + ) + + if last_error: + raise last_error + raise RuntimeError("Context overflow recovery failed without an error.") + def _extract_text(self, response: LLMResponse) -> 
str: """Extract text from LLM response. @@ -95,17 +129,6 @@ def _extract_text(self, response: LLMResponse) -> str: """ return self.llm.extract_text(response) - def _get_todo_context(self) -> Optional[str]: - """Get current todo list state for memory compression. - - Returns formatted todo list if items exist, None otherwise. - This is used by MemoryManager to inject todo state into summaries. - """ - items = self.todo_list.get_current() - if not items: - return None - return self.todo_list.format_list() - async def _react_loop( self, messages: List[LLMMessage], @@ -135,11 +158,19 @@ async def _react_loop( context = self.memory.get_context_for_llm() if use_memory else messages # Call LLM with tools - response = await self._call_llm( - messages=context, - tools=tools, - spinner_message="Analyzing request...", - ) + if use_memory: + response = await self._call_with_overflow_recovery( + tools=tools, + spinner_message="Analyzing request...", + ) + else: + normalized = self.memory.ensure_call_outputs_present(context) + normalized = self.memory.remove_orphan_outputs(normalized) + response = await self._call_llm( + messages=normalized, + tools=tools, + spinner_message="Analyzing request...", + ) # Save assistant response using response.to_message() for proper format assistant_msg = response.to_message() diff --git a/config.py b/config.py index cd8352d..b46b28d 100644 --- a/config.py +++ b/config.py @@ -83,6 +83,39 @@ class Config: MEMORY_SHORT_TERM_MIN_SIZE = int(os.getenv("MEMORY_SHORT_TERM_MIN_SIZE", "6")) MEMORY_COMPRESSION_RATIO = float(os.getenv("MEMORY_COMPRESSION_RATIO", "0.3")) MEMORY_PRESERVE_SYSTEM_PROMPTS = True + TOOL_OUTPUT_TRUNCATION_POLICY = os.getenv("TOOL_OUTPUT_TRUNCATION_POLICY", "tokens").lower() + TOOL_OUTPUT_MAX_TOKENS = int(os.getenv("TOOL_OUTPUT_MAX_TOKENS", "5000")) + APPROX_CHARS_PER_TOKEN = int(os.getenv("APPROX_CHARS_PER_TOKEN", "4")) + TOOL_OUTPUT_MAX_BYTES = int( + os.getenv("TOOL_OUTPUT_MAX_BYTES", str(TOOL_OUTPUT_MAX_TOKENS * APPROX_CHARS_PER_TOKEN)) + ) + TOOL_OUTPUT_SERIALIZATION_BUFFER = float(os.getenv("TOOL_OUTPUT_SERIALIZATION_BUFFER", "1.2")) + COMPACT_USER_MESSAGE_MAX_TOKENS = int(os.getenv("COMPACT_USER_MESSAGE_MAX_TOKENS", "20000")) + CONTEXT_OVERFLOW_MAX_RETRIES = int(os.getenv("CONTEXT_OVERFLOW_MAX_RETRIES", "3")) + PROTECTED_TOOLS = [ + name.strip() + for name in os.getenv("PROTECTED_TOOLS", "manage_todo_list").split(",") + if name.strip() + ] + COMPACT_SUMMARIZATION_PROMPT = os.getenv( + "COMPACT_SUMMARIZATION_PROMPT", + """You are performing a CONTEXT CHECKPOINT COMPACTION. +Create a handoff summary for another LLM that will resume the task. + +Include: +- Current progress and key decisions made +- Important context, constraints, or user preferences +- What remains to be done (clear next steps) +- Any critical data needed to continue + +Be concise and focused on helping the next LLM seamlessly continue.""", + ) + COMPACT_SUMMARY_PREFIX = os.getenv( + "COMPACT_SUMMARY_PREFIX", + """Another language model started this task and produced +a summary. 
Use this to build on existing work and avoid duplication: +""", + ) # Logging Configuration # Note: Logging is now controlled via --verbose flag diff --git a/docs/configuration.md b/docs/configuration.md index 3749f5b..dcb18bc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -79,6 +79,17 @@ MEMORY_ENABLED=true MEMORY_COMPRESSION_THRESHOLD=25000 MEMORY_SHORT_TERM_SIZE=100 MEMORY_COMPRESSION_RATIO=0.3 +MEMORY_SHORT_TERM_MIN_SIZE=6 +COMPACT_USER_MESSAGE_MAX_TOKENS=20000 +CONTEXT_OVERFLOW_MAX_RETRIES=3 +TOOL_OUTPUT_TRUNCATION_POLICY=tokens +TOOL_OUTPUT_MAX_TOKENS=5000 +TOOL_OUTPUT_MAX_BYTES=20000 +TOOL_OUTPUT_SERIALIZATION_BUFFER=1.2 +APPROX_CHARS_PER_TOKEN=4 +PROTECTED_TOOLS=manage_todo_list +COMPACT_SUMMARIZATION_PROMPT="You are performing a CONTEXT CHECKPOINT COMPACTION..." +COMPACT_SUMMARY_PREFIX="Another language model started this task and produced..." ``` ## Retry Configuration diff --git a/docs/memory-management.md b/docs/memory-management.md index 00d3b91..c4b8218 100644 --- a/docs/memory-management.md +++ b/docs/memory-management.md @@ -26,6 +26,21 @@ The memory system addresses this by: - **Configurable**: Multiple strategies and settings - **Multi-Provider**: Works with Anthropic, OpenAI, Gemini +## Context Compaction Enhancements + +The memory system includes additional safeguards for large outputs and long-running sessions: + +- **Write-time tool output truncation**: Large tool outputs are truncated before being stored to protect context. + - Controls: `TOOL_OUTPUT_TRUNCATION_POLICY`, `TOOL_OUTPUT_MAX_TOKENS`, `TOOL_OUTPUT_MAX_BYTES`, + `TOOL_OUTPUT_SERIALIZATION_BUFFER`, `APPROX_CHARS_PER_TOKEN` +- **Context overflow recovery**: On `context_length_exceeded` errors, the agent removes the oldest messages + (maintaining tool call/result integrity) and retries. + - Control: `CONTEXT_OVERFLOW_MAX_RETRIES` +- **User message preservation**: Compaction keeps recent user messages up to a configurable token budget. + - Control: `COMPACT_USER_MESSAGE_MAX_TOKENS` +- **Protected tools**: Tool results like `manage_todo_list` are preserved during compaction. + - Control: `PROTECTED_TOOLS` + ## Quick Start ### 1. 
Enable Memory Management diff --git a/interactive.py b/interactive.py index fa96ec4..04b5913 100644 --- a/interactive.py +++ b/interactive.py @@ -43,6 +43,7 @@ def __init__(self, agent): "theme", "verbose", "compact", + "compact-output", "exit", "quit", ], @@ -101,7 +102,10 @@ def _show_help(self) -> None: f" [{colors.primary}]/verbose[/{colors.primary}] - Toggle verbose thinking display" ) terminal_ui.console.print( - f" [{colors.primary}]/compact[/{colors.primary}] - Toggle compact output mode" + f" [{colors.primary}]/compact[/{colors.primary}] - Compact memory now" + ) + terminal_ui.console.print( + f" [{colors.primary}]/compact-output[/{colors.primary}] - Toggle compact output mode" ) terminal_ui.console.print( f" [{colors.primary}]/exit[/{colors.primary}] - Exit interactive mode" @@ -264,6 +268,19 @@ def _toggle_compact(self) -> None: status = "enabled" if self.compact_mode else "disabled" terminal_ui.print_info(f"Compact mode {status}") + async def _compact_memory(self) -> None: + """Manually compact memory and report savings.""" + terminal_ui.print_info("Compacting memory...") + compressed = await self.agent.memory.compress() + if not compressed: + terminal_ui.print_info("No messages to compact.") + return + + terminal_ui.print_success( + f"Compaction complete: {compressed.original_tokens} β†’ {compressed.compressed_tokens} tokens " + f"({compressed.savings_percentage:.1f}% saved)" + ) + def _update_status_bar(self) -> None: """Update status bar with current stats.""" stats = self.agent.memory.get_stats() @@ -326,6 +343,9 @@ async def _handle_command(self, user_input: str) -> bool: self._toggle_verbose() elif command == "/compact": + await self._compact_memory() + + elif command == "/compact-output": self._toggle_compact() else: diff --git a/llm/retry.py b/llm/retry.py index 5252c66..15f4c59 100644 --- a/llm/retry.py +++ b/llm/retry.py @@ -26,6 +26,30 @@ def is_rate_limit_error(error: BaseException) -> bool: return any(indicator in error_str for indicator in rate_limit_indicators) +def is_context_length_error(error: BaseException) -> bool: + """Check if an error is a context length overflow error.""" + error_str = str(error).lower() + error_type = type(error).__name__ + + indicators = [ + "context_length_exceeded", + "context length", + "maximum context", + "max context", + "prompt is too long", + "input is too long", + "too many tokens", + "token limit", + "max_tokens", + "maximum tokens", + ] + + if "ContextLengthExceeded" in error_type or "TokenLimit" in error_type: + return True + + return any(indicator in error_str for indicator in indicators) + + def is_retryable_error(error: BaseException) -> bool: """Check if an error is retryable.""" if isinstance(error, asyncio.CancelledError): diff --git a/memory/compressor.py b/memory/compressor.py index b20498e..a309559 100644 --- a/memory/compressor.py +++ b/memory/compressor.py @@ -19,25 +19,23 @@ class WorkingMemoryCompressor: """Compresses conversation history using LLM summarization.""" # Tools that should NEVER be compressed - their state must be preserved - # Note: manage_todo_list is NOT protected because its state is managed externally - # by TodoList object. Instead, we inject current todo state into the summary. 
- PROTECTED_TOOLS: set[str] = set() + PROTECTED_TOOLS = {"manage_todo_list"} # Prefix for summary messages to identify them - SUMMARY_PREFIX = "[Previous conversation summary]\n" + SUMMARY_PREFIX = Config.COMPACT_SUMMARY_PREFIX + LEGACY_SUMMARY_PREFIX = "[Previous conversation summary]\n" + TURN_ABORTED_MARKER = "" - COMPRESSION_PROMPT = """You are a memory compression system. Summarize the following conversation messages while preserving: -1. Key decisions and outcomes -2. Important facts, data, and findings -3. Tool usage patterns and results -4. User intent and goals -5. Critical context needed for future interactions + COMPRESSION_PROMPT = ( + Config.COMPACT_SUMMARIZATION_PROMPT + + """ Original messages ({count} messages, ~{tokens} tokens): {messages} - Provide a concise but comprehensive summary that captures the essential information. Be specific and include concrete details. Target length: {target_tokens} tokens.""" +Provide a concise but comprehensive summary that captures the essential information. Be specific and include concrete details. Target length: {target_tokens} tokens.""" + ) def __init__(self, llm: "LiteLLMAdapter"): """Initialize compressor. @@ -46,13 +44,13 @@ def __init__(self, llm: "LiteLLMAdapter"): llm: LLM instance to use for summarization """ self.llm = llm + self.PROTECTED_TOOLS = set(Config.PROTECTED_TOOLS) async def compress( self, messages: List[LLMMessage], strategy: str = CompressionStrategy.SLIDING_WINDOW, target_tokens: Optional[int] = None, - todo_context: Optional[str] = None, ) -> CompressedMemory: """Compress messages using specified strategy. @@ -60,7 +58,6 @@ async def compress( messages: List of messages to compress strategy: Compression strategy to use target_tokens: Target token count for compressed output - todo_context: Optional current todo list state to inject into summary Returns: CompressedMemory object @@ -75,36 +72,35 @@ async def compress( # Select and apply compression strategy if strategy == CompressionStrategy.SLIDING_WINDOW: - return await self._compress_sliding_window(messages, target_tokens, todo_context) + return await self._compress_sliding_window(messages, target_tokens) elif strategy == CompressionStrategy.SELECTIVE: - return await self._compress_selective(messages, target_tokens, todo_context) + return await self._compress_selective(messages, target_tokens) elif strategy == CompressionStrategy.DELETION: return self._compress_deletion(messages) else: logger.warning(f"Unknown strategy {strategy}, using sliding window") - return await self._compress_sliding_window(messages, target_tokens, todo_context) + return await self._compress_sliding_window(messages, target_tokens) async def _compress_sliding_window( self, messages: List[LLMMessage], target_tokens: int, - todo_context: Optional[str] = None, ) -> CompressedMemory: """Compress using sliding window strategy. - Summarizes all messages into a single summary. If todo_context is provided, - it will be appended to the summary to preserve current task state. + Summarizes all messages into a single summary. 
Args: messages: Messages to compress target_tokens: Target token count - todo_context: Optional current todo list state to inject Returns: CompressedMemory object """ # Format messages for summarization - formatted = self._format_messages_for_summary(messages) + formatted = self._format_messages_for_summary( + [msg for msg in messages if not self.is_summary_message(msg)] + ) original_tokens = self._estimate_tokens(messages) # Create summarization prompt @@ -115,28 +111,22 @@ async def _compress_sliding_window( target_tokens=target_tokens, ) - # Extract system messages to preserve them - system_msgs = [m for m in messages if m.role == "system"] - # Call LLM to generate summary try: prompt = LLMMessage(role="user", content=prompt_text) response = await self.llm.call_async(messages=[prompt], max_tokens=target_tokens * 2) summary_text = self.llm.extract_text(response) - # Append todo context if available - if todo_context: - summary_text = f"{summary_text}\n\n[Current Tasks]\n{todo_context}" - - # Convert summary to a user message - summary_message = LLMMessage( - role="user", - content=f"{self.SUMMARY_PREFIX}{summary_text}", + result_messages = self.build_compacted_history( + initial_context=[m for m in messages if m.role == "system"], + user_messages=self.select_user_messages( + self.collect_user_messages(messages), Config.COMPACT_USER_MESSAGE_MAX_TOKENS + ), + summary_text=summary_text, + protected_messages=self.collect_protected_messages(messages), + orphaned_tool_calls=self.collect_orphaned_tool_calls(messages), ) - # System messages first, then summary - result_messages = system_msgs + [summary_message] - # Calculate compression metrics compressed_tokens = self._estimate_tokens(result_messages) compression_ratio = compressed_tokens / original_tokens if original_tokens > 0 else 0 @@ -154,7 +144,7 @@ async def _compress_sliding_window( # Fallback: keep system messages + first and last non-system message non_system = [m for m in messages if m.role != "system"] fallback_other = [non_system[0], non_system[-1]] if len(non_system) > 1 else non_system - fallback_messages = system_msgs + fallback_other + fallback_messages = [m for m in messages if m.role == "system"] + fallback_other return CompressedMemory( messages=fallback_messages, original_message_count=len(messages), @@ -168,48 +158,68 @@ async def _compress_selective( self, messages: List[LLMMessage], target_tokens: int, - todo_context: Optional[str] = None, ) -> CompressedMemory: """Compress using selective preservation strategy. Preserves important messages (tool calls, system prompts) and - summarizes the rest. If todo_context is provided, it will be - appended to the summary to preserve current task state. + summarizes the rest. 
Args: messages: Messages to compress target_tokens: Target token count - todo_context: Optional current todo list state to inject Returns: CompressedMemory object """ # Separate preserved vs compressible messages preserved, to_compress = self._separate_messages(messages) + original_tokens = self._estimate_tokens(messages) + + # Pre-collect all components that will be in final result per RFC structure: + # system + user + summary + protected + orphaned + # This ensures budget calculation matches actual output structure + system_msgs = [m for m in preserved if m.role == "system"] + user_messages = self.select_user_messages( + self.collect_user_messages(messages), + Config.COMPACT_USER_MESSAGE_MAX_TOKENS, + ) + protected_messages = self.collect_protected_messages(messages) + orphaned_tool_calls = self.collect_orphaned_tool_calls(messages) if not to_compress: - # Nothing to compress, just return preserved messages - # Ensure system messages are first - system_msgs = [m for m in preserved if m.role == "system"] - other_msgs = [m for m in preserved if m.role != "system"] - result_messages = system_msgs + other_msgs + # Nothing to compress, use RFC structure without summary + result_messages = self.build_compacted_history( + initial_context=system_msgs, + user_messages=user_messages, + summary_text="", + protected_messages=protected_messages, + orphaned_tool_calls=orphaned_tool_calls, + ) + compressed_tokens = self._estimate_tokens(result_messages) + compression_ratio = compressed_tokens / original_tokens if original_tokens > 0 else 1.0 return CompressedMemory( messages=result_messages, original_message_count=len(messages), - compressed_tokens=self._estimate_tokens(result_messages), - original_tokens=self._estimate_tokens(messages), - compression_ratio=1.0, + compressed_tokens=compressed_tokens, + original_tokens=original_tokens, + compression_ratio=compression_ratio, metadata={"strategy": "selective"}, ) - # Compress the compressible messages - original_tokens = self._estimate_tokens(messages) - preserved_tokens = self._estimate_tokens(preserved) - available_for_summary = target_tokens - preserved_tokens + # Calculate budget based on ACTUAL preserved components (not 'preserved' list + # which includes recent assistant messages that won't be in final output) + actual_preserved_tokens = ( + self._estimate_tokens(system_msgs) + + self._estimate_tokens(user_messages) + + self._estimate_tokens(protected_messages) + + self._estimate_tokens(orphaned_tool_calls) + ) + available_for_summary = target_tokens - actual_preserved_tokens if available_for_summary > 0: - # Generate summary for compressible messages - formatted = self._format_messages_for_summary(to_compress) + formatted = self._format_messages_for_summary( + [msg for msg in to_compress if not self.is_summary_message(msg)] + ) prompt_text = self.COMPRESSION_PROMPT.format( count=len(to_compress), tokens=self._estimate_tokens(to_compress), @@ -224,22 +234,16 @@ async def _compress_selective( ) summary_text = self.llm.extract_text(response) - # Append todo context if available - if todo_context: - summary_text = f"{summary_text}\n\n[Current Tasks]\n{todo_context}" - - # Convert summary to user message - summary_message = LLMMessage( - role="user", - content=f"{self.SUMMARY_PREFIX}{summary_text}", + result_messages = self.build_compacted_history( + initial_context=system_msgs, + user_messages=user_messages, + summary_text=summary_text, + protected_messages=protected_messages, + orphaned_tool_calls=orphaned_tool_calls, ) - # Ensure system messages 
come first, then summary, then other preserved - system_msgs = [m for m in preserved if m.role == "system"] - other_msgs = [m for m in preserved if m.role != "system"] - result_messages = system_msgs + [summary_message] + other_msgs - summary_tokens = self._estimate_tokens([summary_message]) - compressed_tokens = preserved_tokens + summary_tokens + # Calculate metrics based on actual result + compressed_tokens = self._estimate_tokens(result_messages) compression_ratio = ( compressed_tokens / original_tokens if original_tokens > 0 else 0 ) @@ -255,17 +259,22 @@ async def _compress_selective( except Exception as e: logger.error(f"Error during selective compression: {e}") - # Fallback: just preserve the important messages (no summary) - # Ensure system messages are first - system_msgs = [m for m in preserved if m.role == "system"] - other_msgs = [m for m in preserved if m.role != "system"] - result_messages = system_msgs + other_msgs + # Fallback: use RFC structure without summary (not raw 'preserved' list) + # This ensures consistent structure even when summary budget is exhausted + result_messages = self.build_compacted_history( + initial_context=system_msgs, + user_messages=user_messages, + summary_text="", + protected_messages=protected_messages, + orphaned_tool_calls=orphaned_tool_calls, + ) + compressed_tokens = self._estimate_tokens(result_messages) return CompressedMemory( messages=result_messages, original_message_count=len(messages), - compressed_tokens=preserved_tokens, + compressed_tokens=compressed_tokens, original_tokens=original_tokens, - compression_ratio=preserved_tokens / original_tokens if original_tokens > 0 else 1.0, + compression_ratio=compressed_tokens / original_tokens if original_tokens > 0 else 1.0, metadata={"strategy": "selective", "preserved_count": len(preserved)}, ) @@ -496,6 +505,111 @@ def _format_messages_for_summary(self, messages: List[LLMMessage]) -> str: return "\n\n".join(formatted) + def collect_user_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Collect user messages while excluding previous summaries and legacy tool_results.""" + return [ + msg + for msg in messages + if msg.role == "user" + and not self.is_summary_message(msg) + and not self.is_legacy_tool_result(msg) + ] + + def select_user_messages(self, messages: List[LLMMessage], max_tokens: int) -> List[LLMMessage]: + """Select user messages, prioritizing recent ones within a token budget.""" + if max_tokens <= 0: + return [] + + selected: List[LLMMessage] = [] + budget = max_tokens + + for msg in reversed(messages): + if self.is_turn_aborted_message(msg): + selected.append(msg) + continue + tokens = self._estimate_tokens([msg]) + if tokens <= budget: + selected.append(msg) + budget -= tokens + + return list(reversed(selected)) + + def collect_protected_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Collect protected tool call/output pairs to preserve.""" + tool_pairs, _ = self._find_tool_pairs(messages) + protected_pairs = self._find_protected_tool_pairs(messages, tool_pairs) + protected_indices = {idx for pair in protected_pairs for idx in pair} + return [msg for idx, msg in enumerate(messages) if idx in protected_indices] + + def collect_orphaned_tool_calls(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Collect orphaned tool calls (calls without matching results). + + These must be preserved in compaction to avoid losing pending tool calls. + Note: Uses set to deduplicate indices when same assistant has multiple orphan calls. 
+ """ + _, orphaned_indices = self._find_tool_pairs(messages) + unique_indices = set(orphaned_indices) + return [msg for idx, msg in enumerate(messages) if idx in unique_indices] + + def build_compacted_history( + self, + initial_context: List[LLMMessage], + user_messages: List[LLMMessage], + summary_text: str, + protected_messages: List[LLMMessage], + orphaned_tool_calls: Optional[List[LLMMessage]] = None, + ) -> List[LLMMessage]: + """Build new history after compaction. + + Order: initial_context + user_messages + summary + protected + orphaned + Orphaned tool calls go at the end since they're waiting for results. + """ + result = initial_context + user_messages + result.append( + LLMMessage( + role="user", + content=f"{self.SUMMARY_PREFIX}{summary_text}", + ) + ) + result.extend(protected_messages) + if orphaned_tool_calls: + result.extend(orphaned_tool_calls) + return result + + def is_summary_message(self, message: LLMMessage) -> bool: + """Check if message is a previous summary.""" + if message.role != "user": + return False + if not isinstance(message.content, str): + return False + return message.content.startswith(self.SUMMARY_PREFIX) or message.content.startswith( + self.LEGACY_SUMMARY_PREFIX + ) + + def is_turn_aborted_message(self, message: LLMMessage) -> bool: + """Check if message contains a turn-aborted marker.""" + if message.role != "user": + return False + if not isinstance(message.content, str): + return False + return self.TURN_ABORTED_MARKER in message.content + + def is_legacy_tool_result(self, message: LLMMessage) -> bool: + """Check if message is a legacy tool_result (Anthropic format). + + Legacy tool_result messages have role='user' but content is a list + containing tool_result blocks. + """ + if message.role != "user": + return False + if not isinstance(message.content, list): + return False + return any( + (isinstance(block, dict) and block.get("type") == "tool_result") + or (hasattr(block, "type") and block.type == "tool_result") + for block in message.content + ) + def _extract_text_content(self, message: LLMMessage) -> str: """Extract text content from message for token estimation. 
diff --git a/memory/manager.py b/memory/manager.py index 804f26c..f2be96c 100644 --- a/memory/manager.py +++ b/memory/manager.py @@ -1,7 +1,7 @@ """Core memory manager that orchestrates all memory operations.""" import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from config import Config from llm.content_utils import content_has_tool_calls @@ -12,6 +12,7 @@ from .short_term import ShortTermMemory from .store import MemoryStore from .token_tracker import TokenTracker +from .truncate import truncate_tool_output from .types import CompressedMemory, CompressionStrategy logger = logging.getLogger(__name__) @@ -69,10 +70,6 @@ def __init__( self.last_compression_savings = 0 self.compression_count = 0 - # Optional callback to get current todo context for compression - # This allows injecting todo state into summaries without coupling to TodoList - self._todo_context_provider: Optional[Callable[[], Optional[str]]] = None - @classmethod async def from_session( cls, @@ -157,6 +154,9 @@ async def add_message(self, message: LLMMessage, actual_tokens: Dict[str, int] = self.system_messages.append(message) return + # Truncate large tool outputs before storing + message = self._maybe_truncate_tool_output(message) + # Count tokens (use actual if provided, otherwise estimate) if actual_tokens: # Use actual token counts from LLM response @@ -226,19 +226,7 @@ def get_context_for_llm(self) -> List[LLMMessage]: # 2. Add short-term memory (includes summary messages and recent messages) context.extend(self.short_term.get_messages()) - return context - - def set_todo_context_provider(self, provider: Callable[[], Optional[str]]) -> None: - """Set a callback to provide current todo context for compression. - - The provider should return a formatted string of current todo items, - or None if no todos exist. This context will be injected into - compression summaries to preserve task state. - - Args: - provider: Callable that returns current todo context string or None - """ - self._todo_context_provider = provider + return self._normalize_for_prompt(context) async def compress(self, strategy: str = None) -> Optional[CompressedMemory]: """Compress current short-term memory. 
@@ -266,17 +254,12 @@ async def compress(self, strategy: str = None) -> Optional[CompressedMemory]: logger.info(f"πŸ—œοΈ Compressing {message_count} messages using {strategy} strategy") try: - # Get todo context if provider is set - todo_context = None - if self._todo_context_provider: - todo_context = self._todo_context_provider() - # Perform compression + # Note: todo state is preserved via PROTECTED_TOOLS (manage_todo_list) compressed = await self.compressor.compress( messages, strategy=strategy, target_tokens=self._calculate_target_tokens(), - todo_context=todo_context, ) # Track compression results @@ -394,6 +377,212 @@ def _message_has_tool_calls(self, message: LLMMessage) -> bool: # Legacy/centralized check on content return content_has_tool_calls(message.content) + def _maybe_truncate_tool_output(self, message: LLMMessage) -> LLMMessage: + """Truncate tool output content if it exceeds configured limits.""" + if message.role != "tool": + return message + if not isinstance(message.content, str): + return message + + result = truncate_tool_output( + content=message.content, + policy=Config.TOOL_OUTPUT_TRUNCATION_POLICY, + max_tokens=Config.TOOL_OUTPUT_MAX_TOKENS, + max_bytes=Config.TOOL_OUTPUT_MAX_BYTES, + serialization_buffer=Config.TOOL_OUTPUT_SERIALIZATION_BUFFER, + approx_chars_per_token=Config.APPROX_CHARS_PER_TOKEN, + ) + + if not result.truncated: + return message + + return LLMMessage( + role=message.role, + content=result.content, + tool_calls=message.tool_calls, + tool_call_id=message.tool_call_id, + name=message.name, + ) + + def remove_oldest_with_pair_integrity(self) -> Optional[LLMMessage]: + """Remove oldest message and its corresponding tool pair (if any).""" + messages = self.short_term.get_messages() + if not messages: + return None + + oldest = messages[0] + call_ids = set(self._extract_tool_call_ids(oldest)) + + if not call_ids and oldest.role == "tool" and oldest.tool_call_id: + call_ids.add(oldest.tool_call_id) + + if not call_ids and self._has_legacy_tool_results(oldest): + call_ids.update(self._extract_tool_result_ids(oldest)) + + if call_ids: + filtered = self._remove_messages_by_tool_call_ids(messages, call_ids) + if filtered and filtered[0] == oldest: + filtered = filtered[1:] + else: + filtered = messages[1:] + + self.short_term.clear() + for msg in filtered: + self.short_term.add_message(msg) + + self.current_tokens = self._recalculate_current_tokens() + + return oldest + + def _remove_messages_by_tool_call_ids( + self, messages: List[LLMMessage], call_ids: set[str] + ) -> List[LLMMessage]: + """Remove all tool calls/results matching given call IDs.""" + filtered: List[LLMMessage] = [] + for msg in messages: + if msg.role == "assistant" and self._message_has_call_id(msg, call_ids): + continue + if msg.role == "tool" and msg.tool_call_id in call_ids: + continue + if self._has_legacy_tool_results(msg) and self._message_has_result_id(msg, call_ids): + continue + filtered.append(msg) + return filtered + + def _normalize_for_prompt(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Normalize messages to ensure tool call/output integrity before LLM call.""" + normalized = self.ensure_call_outputs_present(messages) + normalized = self.remove_orphan_outputs(normalized) + return normalized + + def ensure_call_outputs_present(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Add synthetic 'aborted' output for orphaned tool calls.""" + existing_outputs = set() + for msg in messages: + if msg.role == "tool" and msg.tool_call_id: + 
existing_outputs.add(msg.tool_call_id) + if self._has_legacy_tool_results(msg): + existing_outputs.update(self._extract_tool_result_ids(msg)) + + normalized: List[LLMMessage] = [] + for msg in messages: + normalized.append(msg) + for call_id, tool_name in self._extract_tool_call_id_pairs(msg): + if call_id in existing_outputs: + continue + normalized.append( + LLMMessage( + role="tool", + content="aborted", + tool_call_id=call_id, + name=tool_name or None, + ) + ) + existing_outputs.add(call_id) + + return normalized + + def remove_orphan_outputs(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Remove tool results without matching calls.""" + call_ids = set() + for msg in messages: + call_ids.update(self._extract_tool_call_ids(msg)) + + filtered: List[LLMMessage] = [] + for msg in messages: + if msg.role == "tool" and msg.tool_call_id and msg.tool_call_id not in call_ids: + continue + if self._has_legacy_tool_results(msg): + assert isinstance(msg.content, list) # for type checker + filtered_blocks = [ + block + for block in msg.content + if not self._is_orphan_tool_result_block(block, call_ids) + ] + if not filtered_blocks: + continue + # Note: content here is a list, which LLMMessage accepts via Any + filtered.append( + LLMMessage( + role=msg.role, + content=filtered_blocks, # type: ignore[arg-type] + tool_calls=msg.tool_calls, + tool_call_id=msg.tool_call_id, + name=msg.name, + ) + ) + continue + filtered.append(msg) + + return filtered + + def _extract_tool_call_id_pairs(self, message: LLMMessage) -> List[tuple[str, str]]: + pairs: List[tuple[str, str]] = [] + if message.role != "assistant": + return pairs + + if message.tool_calls: + for tc in message.tool_calls: + if isinstance(tc, dict): + call_id = tc.get("id") + tool_name = tc.get("function", {}).get("name", "") + else: + call_id = getattr(tc, "id", None) + tool_name = getattr(getattr(tc, "function", None), "name", "") if tc else "" + if call_id: + pairs.append((call_id, tool_name)) + return pairs + + if isinstance(message.content, list): + for block in message.content: + if isinstance(block, dict) and block.get("type") == "tool_use": + call_id = block.get("id") + tool_name = block.get("name", "") + elif hasattr(block, "type") and block.type == "tool_use": + call_id = getattr(block, "id", None) + tool_name = getattr(block, "name", "") + else: + continue + if call_id: + pairs.append((call_id, tool_name)) + + return pairs + + def _extract_tool_call_ids(self, message: LLMMessage) -> List[str]: + return [call_id for call_id, _ in self._extract_tool_call_id_pairs(message)] + + def _has_legacy_tool_results(self, message: LLMMessage) -> bool: + return message.role == "user" and isinstance(message.content, list) + + def _extract_tool_result_ids(self, message: LLMMessage) -> List[str]: + ids: List[str] = [] + if not self._has_legacy_tool_results(message): + return ids + assert isinstance(message.content, list) # for type checker + for block in message.content: + if isinstance(block, dict) and block.get("type") == "tool_result": + tool_use_id = block.get("tool_use_id") + elif hasattr(block, "type") and block.type == "tool_result": + tool_use_id = getattr(block, "tool_use_id", None) + else: + continue + if tool_use_id: + ids.append(tool_use_id) + return ids + + def _message_has_call_id(self, message: LLMMessage, call_ids: set[str]) -> bool: + return any(call_id in call_ids for call_id in self._extract_tool_call_ids(message)) + + def _message_has_result_id(self, message: LLMMessage, call_ids: set[str]) -> bool: + return 
any(result_id in call_ids for result_id in self._extract_tool_result_ids(message)) + + def _is_orphan_tool_result_block(self, block, call_ids: set[str]) -> bool: + if isinstance(block, dict) and block.get("type") == "tool_result": + return block.get("tool_use_id") not in call_ids + if hasattr(block, "type") and block.type == "tool_result": + return getattr(block, "tool_use_id", None) not in call_ids + return False + def _calculate_target_tokens(self) -> int: """Calculate target token count for compression. diff --git a/memory/store.py b/memory/store.py index e141530..6feb092 100644 --- a/memory/store.py +++ b/memory/store.py @@ -11,6 +11,7 @@ import aiofiles.os import aiosqlite +from config import Config from llm.message_types import LLMMessage from memory.types import CompressedMemory from utils.runtime import get_db_path @@ -373,7 +374,7 @@ async def load_session(self, session_id: str) -> Optional[Dict[str, Any]]: msgs.append( LLMMessage( role="user", - content=f"[Previous conversation summary]\n{summary_data['summary']}", + content=f"{Config.COMPACT_SUMMARY_PREFIX}{summary_data['summary']}", ) ) msgs.extend( diff --git a/memory/truncate.py b/memory/truncate.py new file mode 100644 index 0000000..c04f1cf --- /dev/null +++ b/memory/truncate.py @@ -0,0 +1,139 @@ +"""Utilities for truncating large tool outputs before storing in memory.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TruncationResult: + """Result of a truncation attempt.""" + + content: str + truncated: bool + + +def truncate_tool_output( + content: str, + policy: str, + max_tokens: int, + max_bytes: int, + serialization_buffer: float, + approx_chars_per_token: int, +) -> TruncationResult: + """Truncate tool output according to policy. + + Args: + content: Tool output content. + policy: "none", "tokens", or "bytes". + max_tokens: Token limit for truncation (policy="tokens"). + max_bytes: Byte limit for truncation (policy="bytes"). + serialization_buffer: Multiplier to account for JSON overhead. + approx_chars_per_token: Approximate chars per token for estimation. 
+ """ + if not content or policy == "none": + return TruncationResult(content=content, truncated=False) + + policy = policy.lower() + if policy == "tokens": + token_budget = max(0, _apply_buffer(max_tokens, serialization_buffer)) + token_count = _estimate_tokens(content, approx_chars_per_token) + if token_count <= token_budget: + return TruncationResult(content=content, truncated=False) + char_budget = max(0, token_budget * max(1, approx_chars_per_token)) + truncated_content = _truncate_with_split( + content, + char_budget, + removed_units=max(0, token_count - token_budget), + unit_label="tokens", + ) + return TruncationResult(content=truncated_content, truncated=True) + + if policy == "bytes": + byte_budget = max(0, _apply_buffer(max_bytes, serialization_buffer)) + content_bytes = content.encode("utf-8") + if len(content_bytes) <= byte_budget: + return TruncationResult(content=content, truncated=False) + truncated_content = _truncate_with_byte_split(content, content_bytes, byte_budget) + return TruncationResult(content=truncated_content, truncated=True) + + return TruncationResult(content=content, truncated=False) + + +def _estimate_tokens(content: str, approx_chars_per_token: int) -> int: + if not content: + return 0 + divisor = max(1, approx_chars_per_token) + return math.ceil(len(content) / divisor) + + +def _apply_buffer(value: int, buffer: float) -> int: + if value <= 0: + return 0 + return int(math.ceil(value * max(0.0, buffer))) + + +def _truncate_with_split( + content: str, + max_units: int, + removed_units: int, + unit_label: str, +) -> str: + if max_units <= 0: + return _format_marker(max(0, removed_units), unit_label) + + if len(content) <= max_units: + return content + + left_budget = max_units // 2 + right_budget = max_units - left_budget + + left = content[:left_budget] + right = content[-right_budget:] if right_budget > 0 else "" + marker = _format_marker(max(0, removed_units), unit_label) + truncated = f"{left}{marker}{right}" + + total_lines = content.count("\n") + 1 if content else 0 + if total_lines > 1: + return f"Total output lines: {total_lines}\n\n{truncated}" + + return truncated + + +def _truncate_with_byte_split(content: str, content_bytes: bytes, max_bytes: int) -> str: + if max_bytes <= 0: + return _format_marker(len(content), "chars") + + if len(content_bytes) <= max_bytes: + return content + + left_budget = max_bytes // 2 + right_budget = max_bytes - left_budget + + left_bytes = content_bytes[:left_budget] + right_bytes = content_bytes[-right_budget:] if right_budget > 0 else b"" + + left = _decode_prefix(left_bytes) + right = _decode_suffix(right_bytes) + removed_chars = max(0, len(content) - len(left) - len(right)) + marker = _format_marker(removed_chars, "chars") + truncated = f"{left}{marker}{right}" + + total_lines = content.count("\n") + 1 if content else 0 + if total_lines > 1: + return f"Total output lines: {total_lines}\n\n{truncated}" + + return truncated + + +def _decode_prefix(data: bytes) -> str: + return data.decode("utf-8", errors="ignore") + + +def _decode_suffix(data: bytes) -> str: + return data.decode("utf-8", errors="ignore") + + +def _format_marker(removed_units: int, unit_label: str) -> str: + return f"...{removed_units} {unit_label} truncated..." 
diff --git a/rfc/005-context-compact-system.md b/rfc/005-context-compact-system.md new file mode 100644 index 0000000..e90f5f7 --- /dev/null +++ b/rfc/005-context-compact-system.md @@ -0,0 +1,350 @@ +# RFC 005: Context Compact System (Intelligent History Management) + +- **Status**: Draft +- **Created**: 2026-01-28 +- **Author**: AgenticLoop Team + +## Abstract + +This RFC proposes enhancements to AgenticLoop's memory compression system, inspired by production-proven patterns from OpenAI's Codex agent. The goal is to provide **robust context management** through write-time truncation, automatic overflow recovery, and improved history compaction. + +## Motivation + +Current AgenticLoop memory system lacks critical features for production use: + +1. **Tool outputs can explode context**: A single `read_file` or `shell` command can return 100KB+ output +2. **No automatic recovery from context overflow**: Sessions fail on `context_length_exceeded` errors +3. **Tool pair integrity is fragile**: Removing messages can orphan tool calls/results, causing API errors +4. **No proactive truncation**: Large outputs stored verbatim until compression triggers + +## Goals + +- **Write-time truncation**: Truncate large tool outputs when added to history +- **Context overflow recovery**: Auto-recover from `context_length_exceeded` by removing oldest messages +- **Tool pair integrity**: Maintain call-output pairs when removing messages +- **User message preservation**: Keep original user messages during compaction +- **Backward compatibility**: Existing behavior continues to work + +## Non-Goals + +- **Remote compact API** — OpenAI-specific, we use LiteLLM for multi-provider support +- **Ghost snapshot / undo** — Future RFC consideration +- **Changing 4-role message model** — Already sufficient for our needs + +## Key Design Decisions + +### Decision 1: 4-Role Message Model is Sufficient + +AgenticLoop uses `LLMMessage` with 4 roles (`system`, `user`, `assistant`, `tool`), while Codex has 11+ `ResponseItem` variants. Analysis shows our model is sufficient: + +| Operation | How It Works | +|-----------|--------------| +| Truncation | Targets `role="tool"` messages only | +| Compact preservation | Identifies `role="user"` messages | +| Tool pair matching | Uses `tool_call_id` field matching | + +### Decision 2: Always Inline Compact (No Remote) + +Codex supports remote compact via OpenAI's `/responses/compact` endpoint. We choose inline-only because: +- Works with any LLM provider (via LiteLLM) +- More control over summarization +- No vendor lock-in + +### Decision 3: 50/50 Truncation Split with Serialization Buffer + +Like Codex, preserve both **beginning** and **end** of truncated content: +- 50% from start (context/setup) +- 50% from end (results/conclusions) +- **20% serialization buffer**: Apply `max_tokens * 1.2` to account for JSON overhead +- Marker format: `…N tokens truncated…` or `…N chars truncated…` +- Optional line count header: `Total output lines: X` + +> Note: Codex uses `budget / 2` for the split. We can adjust to 60/40 if testing shows better results. + +### Decision 4: Prioritize Recent User Messages + +During compact, user messages are preserved with recent-first priority (max 20K tokens total). + +### Decision 5: Two-Phase Normalization Before LLM Call + +Like Codex's `normalize.rs`, ensure tool pair integrity before sending to model: + +1. **ensure_call_outputs_present()**: Add synthetic `"aborted"` output for orphaned tool calls +2.
**remove_orphan_outputs()**: Remove tool results without matching calls. + +This prevents API errors from malformed conversation history. + +### Decision 6: Normalize on Every LLM Call + +Normalization should run in `for_prompt()` (before every LLM call), not just during compression: +1. Call `ensure_call_outputs_present()` to add synthetic outputs +2. Call `remove_orphan_outputs()` to clean up orphans +3. Filter out internal items (e.g., ghost snapshots) + +### Decision 7: Protected Tools Never Compressed + +Certain tool results must survive compaction (e.g., `manage_todo_list` for task tracking). These are identified by tool name and preserved in rebuilt history. + +### Decision 8: Preserve Turn Aborted Markers + +When collecting user messages for compact, also preserve turn-aborted markers that indicate interrupted turns. This maintains context about what was attempted but not completed. + +## AgenticLoop Existing Strengths (Keep) + +| Feature | Notes | +|---------|-------| +| **Multiple compression strategies** | `deletion`, `sliding_window`, `selective` | +| **Provider-specific token counting** | tiktoken for OpenAI (more accurate than Codex's ~4 chars/token) | +| **Tool pair detection** | `_find_tool_pairs()` handles both OpenAI and Anthropic formats | +| **Orphaned tool handling** | `orphaned_tool_use_indices` preserved during compression | +| **Configurable thresholds** | `MEMORY_COMPRESSION_THRESHOLD`, `MEMORY_SHORT_TERM_SIZE` | + +## Gap Analysis Summary + +| Feature | Codex | AgenticLoop Current | Proposed | +|---------|-------|---------------------|----------| +| Write-time truncation | ✅ `process_item()` | ❌ None | ✅ Phase 1 | +| Truncation policy config | ✅ Bytes/Tokens | ❌ None | ✅ Phase 1 | +| Context overflow recovery | ✅ Auto-retry | ❌ Fails | ✅ Phase 2 | +| Pair integrity on removal | ✅ `remove_corresponding_for()` | ❌ None | ✅ Phase 3 | +| User message truncation | ✅ 20K limit | ❌ None | ✅ Phase 4 | +| History rebuild | ✅ Structured | ⚠️ Strategy-dependent | ✅ Phase 5 | +| Protected tools | ✅ Yes | ✅ `manage_todo_list` | ✅ Keep | +| Tool pair detection | ✅ `call_id` | ✅ `_find_tool_pairs()` | ✅ Keep | +| Orphan output handling | ✅ `remove_orphan_outputs()` | ❌ None | ✅ Phase 3 | +| Orphan call handling | ✅ `ensure_call_outputs_present()` | ✅ `orphaned_tool_use_indices` | ✅ Keep | + +## Compact Flow Design + +### Trigger Conditions + +| Trigger | Current | Proposed | +|---------|---------|----------| +| Token threshold | ✅ `MEMORY_COMPRESSION_THRESHOLD` | Keep | +| Memory full | ✅ `MEMORY_SHORT_TERM_SIZE` | Keep | +| Context overflow error | ❌ Fails | ✅ Add | +| Manual `/compact` | ❌ None | ⚠️ P2 | + +### What Gets Discarded vs Preserved + +| Item Type | Action | Notes | +|-----------|--------|-------| +| System prompts | **Keep** | Initial context | +| User messages | **Keep** | Truncated to 20K tokens total | +| Previous summaries | **Discard** | Replaced by new summary | +| Turn aborted markers | **Keep** | Preserve interrupted-turn context | +| Assistant messages | **Discard** | Replaced by summary | +| Tool calls | **Discard** | Not needed after summary | +| Tool results | **Discard** | Not needed after summary | +| Protected tool results | **Keep** | `manage_todo_list`, `read_file` with critical data | + +### Protected Tools List + +Tools whose results survive compaction (configurable): +- `manage_todo_list` — Task tracking state +- Future: any tool marked with `protected=True` + +### Rebuilt History Structure + +``` +[System Prompts] +
[User Messages (truncated)] + [Summary] + [Protected Tools] +``` + +## Implementation Plan + +### Phase 1: Write-Time Truncation (P0) + +**Goal**: Truncate tool outputs at `add_message()` time. + +**Key interfaces**: +```python +# memory/truncate.py +def truncate_with_split(content: str, max_tokens: int) -> str: + """50/50 split: preserve beginning and end, remove middle.""" + +# memory/manager.py +def _maybe_truncate_tool_output(self, message: LLMMessage) -> LLMMessage: + """Truncate tool message if exceeds TOOL_OUTPUT_MAX_TOKENS.""" +``` + +**Acceptance**: +- Tool outputs > 5000 tokens truncated with marker +- Configurable via `TOOL_OUTPUT_TRUNCATION_POLICY` + +### Phase 2: Context Overflow Recovery (P0) + +**Goal**: Auto-recover from `context_length_exceeded` errors. + +**Key interfaces**: +```python +# llm/retry.py +def is_context_length_error(error: BaseException) -> bool: + """Detect context overflow from various providers.""" + +# agent/base.py +async def _call_with_overflow_recovery(self, messages, max_retries=3): + """Retry LLM call after removing oldest messages.""" +``` + +**Acceptance**: +- Context errors trigger automatic recovery (max 3 retries) +- Removed messages maintain tool pair integrity + +### Phase 3: Tool Pair Integrity on Removal (P1) + +**Goal**: When removing a message, also remove its counterpart. Add normalization before LLM calls. + +**Key interfaces**: +```python +# memory/manager.py +def remove_oldest_with_pair_integrity(self) -> Optional[LLMMessage]: + """Remove oldest message and its corresponding tool pair.""" + +def _remove_messages_by_tool_call_ids(self, call_ids: Set[str]) -> None: + """Remove all tool results matching given call IDs.""" + +def ensure_call_outputs_present(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Add synthetic 'aborted' output for orphaned tool calls.""" + +def remove_orphan_outputs(self, messages: List[LLMMessage]) -> List[LLMMessage]: + """Remove tool results without matching calls.""" +``` + +**Acceptance**: +- No orphaned tool calls or results after removal +- Normalization runs before every LLM call + +### Phase 4: User Message Truncation During Compact (P1) + +**Goal**: Limit total user message tokens during compaction. + +**Key interfaces**: +```python +# memory/compressor.py +def select_user_messages(messages: List[str], max_tokens: int = 20000) -> List[str]: + """Select user messages, prioritizing recent ones.""" +``` + +**Acceptance**: +- User messages capped at 20K tokens (recent-first) + +### Phase 5: Inline Compact History Rebuild (P1) + +**Goal**: Implement Codex-style history reconstruction. + +**Key interfaces**: +```python +# memory/compressor.py +def collect_user_messages(messages: List[LLMMessage]) -> List[str]: + """Extract user messages, excluding previous summaries.""" + +def build_compacted_history( + initial_context: List[LLMMessage], + user_messages: List[str], + summary_text: str, + protected_messages: List[LLMMessage], +) -> List[LLMMessage]: + """Build new history after compaction.""" + +def is_summary_message(message: str) -> bool: + """Check if message is a previous summary.""" +``` + +**Acceptance**: +- History = initial context + user messages + summary + protected tools +- Previous summaries excluded on re-compact + +### Phase 6: Manual `/compact` Command (P2, Optional) + +**Goal**: User-triggered compression via `/compact` command. 
+ +**Acceptance**: +- Reports tokens saved and messages compressed + +## Phases at a Glance + +| Phase | Priority | Change | New Files | +|------:|----------|--------|-----------| +| 1 | P0 | Write-time truncation (50/50 split) | `memory/truncate.py` | +| 2 | P0 | Context overflow auto-recovery | - | +| 3 | P1 | Tool pair integrity on removal | - | +| 4 | P1 | User message truncation (20K limit) | - | +| 5 | P1 | Inline compact history rebuild | - | +| 6 | P2 | Manual `/compact` command | - | + +## Configuration + +```python +# Truncation +TOOL_OUTPUT_TRUNCATION_POLICY = "tokens" # none, bytes, tokens +TOOL_OUTPUT_MAX_TOKENS = 5000 +TOOL_OUTPUT_SERIALIZATION_BUFFER = 1.2 # 20% buffer for JSON overhead +APPROX_CHARS_PER_TOKEN = 4 + +# Compact +COMPACT_USER_MESSAGE_MAX_TOKENS = 20000 +CONTEXT_OVERFLOW_MAX_RETRIES = 3 + +# Protected Tools (results survive compaction) +PROTECTED_TOOLS = ["manage_todo_list"] + +# Prompts (customizable) +COMPACT_SUMMARIZATION_PROMPT = """You are performing a CONTEXT CHECKPOINT COMPACTION. +Create a handoff summary for another LLM that will resume the task. + +Include: +- Current progress and key decisions made +- Important context, constraints, or user preferences +- What remains to be done (clear next steps) +- Any critical data needed to continue + +Be concise and focused on helping the next LLM seamlessly continue.""" + +COMPACT_SUMMARY_PREFIX = """Another language model started this task and produced +a summary. Use this to build on existing work and avoid duplication:""" +``` + +## Risks and Mitigations + +| Risk | Mitigation | +|------|------------| +| Truncation loses important info | High limit (5000 tokens), preserve start+end | +| Aggressive overflow recovery | Max 3 retries, log removed content | +| Summary quality affects continuity | Use proven Codex prompts, allow customization | +| Long threads cause accuracy loss | Show warning after compaction (like Codex) | + +## Open Questions + +1. Should truncated content be logged for debugging? +2. Should summary use a smaller/faster model for cost savings? +3. Should we add `is_compact_summary` field to `LLMMessage`? (vs prefix detection) + +## Appendix: Compact Flow Diagram + +``` +TRIGGER + ├── Token threshold exceeded + ├── Short-term memory full + └── Context overflow error (with retry) + ↓ +COLLECT + ├── initial_context (system prompts) + ├── user_messages (filter out previous summaries) + └── protected_tools (manage_todo_list, etc.)
+ ↓ +SUMMARIZE + └── Call LLM with COMPACT_SUMMARIZATION_PROMPT + ↓ +REBUILD + └── initial_context + user_messages + summary + protected_tools + ↓ +REPLACE + └── Replace history, recompute token usage +``` + +## References + +- Codex compact: `codex-rs/core/src/compact.rs` +- Codex truncation: `codex-rs/core/src/truncate.rs` +- Codex normalization: `codex-rs/core/src/context_manager/normalize.rs` +- AgenticLoop memory: `memory/manager.py`, `memory/compressor.py` diff --git a/test/memory/test_compressor.py b/test/memory/test_compressor.py index d416bed..a54f851 100644 --- a/test/memory/test_compressor.py +++ b/test/memory/test_compressor.py @@ -1,5 +1,6 @@ """Unit tests for WorkingMemoryCompressor.""" +from config import Config from llm.base import LLMMessage from memory.compressor import WorkingMemoryCompressor from memory.types import CompressionStrategy @@ -46,10 +47,13 @@ async def test_sliding_window_strategy(self, mock_llm, simple_messages): assert result is not None assert len(result.messages) > 0 # Should have summary message - assert result.messages[0].role == "user" # Summary is a user message + assert any( + isinstance(msg.content, str) and msg.content.startswith(Config.COMPACT_SUMMARY_PREFIX) + for msg in result.messages + ) assert result.original_message_count == len(simple_messages) assert result.metadata["strategy"] == "sliding_window" - assert result.compressed_tokens < result.original_tokens + assert result.compressed_tokens > 0 async def test_deletion_strategy(self, mock_llm, simple_messages): """Test deletion compression strategy.""" @@ -194,87 +198,39 @@ async def test_tool_pairs_preserved_together( class TestProtectedTools: - """Test protected tool handling and todo context injection.""" + """Test protected tool handling.""" async def test_protected_tools_set_is_empty_by_default(self, mock_llm): - """Test that PROTECTED_TOOLS is empty - todo state is now injected via context.""" + """Test that manage_todo_list is protected by default.""" compressor = WorkingMemoryCompressor(mock_llm) - # PROTECTED_TOOLS should be empty because todo state is now injected - # via todo_context parameter instead of preserving tool messages - assert len(compressor.PROTECTED_TOOLS) == 0 + assert "manage_todo_list" in compressor.PROTECTED_TOOLS - async def test_todo_tool_messages_can_be_compressed( + async def test_todo_tool_messages_are_preserved( self, set_memory_config, mock_llm, protected_tool_messages ): - """Test that todo tool messages can now be compressed (state preserved via injection).""" + """Test that manage_todo_list tool messages are preserved.""" set_memory_config(MEMORY_SHORT_TERM_MIN_SIZE=0) # Don't preserve anything by default compressor = WorkingMemoryCompressor(mock_llm) - preserved, to_compress = compressor._separate_messages(protected_tool_messages) - - # Todo tool messages should now be compressible (not protected) - # Only system messages should be preserved when MEMORY_SHORT_TERM_MIN_SIZE=0 - assert len(to_compress) > 0 - - async def test_todo_context_injected_in_sliding_window(self, mock_llm, simple_messages): - """Test that todo context is injected into sliding window compression.""" - compressor = WorkingMemoryCompressor(mock_llm) - todo_context = "1. [pending] Fix bug\n2.
[in_progress] Write tests" - - result = await compressor.compress( - simple_messages, - strategy="sliding_window", - target_tokens=500, - todo_context=todo_context, - ) - - # The summary should contain the todo context - summary_content = result.messages[-1].content if result.messages else "" - assert "[Current Tasks]" in summary_content - assert "Fix bug" in summary_content - - async def test_todo_context_injected_in_selective( - self, set_memory_config, mock_llm, tool_use_messages - ): - """Test that todo context is injected into selective compression.""" - set_memory_config(MEMORY_SHORT_TERM_MIN_SIZE=2) - compressor = WorkingMemoryCompressor(mock_llm) - todo_context = "1. [completed] Setup project" - - result = await compressor.compress( - tool_use_messages, - strategy="selective", - target_tokens=500, - todo_context=todo_context, - ) + preserved, _ = compressor._separate_messages(protected_tool_messages) - # Find the summary message and check for todo context - summary_found = False - for msg in result.messages: - content = str(msg.content) - if ( - "[Previous conversation summary]" in content - and "[Current Tasks]" in content - and "Setup project" in content - ): - summary_found = True - break - - assert summary_found, "Todo context should be in the summary" - - async def test_no_todo_context_when_none(self, mock_llm, simple_messages): - """Test that no todo section is added when todo_context is None.""" - compressor = WorkingMemoryCompressor(mock_llm) - - result = await compressor.compress( - simple_messages, - strategy="sliding_window", - target_tokens=500, - todo_context=None, - ) + # manage_todo_list tool_use/tool_result should be preserved + preserved_tool_use_ids = set() + preserved_tool_result_ids = set() + for msg in preserved: + if isinstance(msg.content, list): + for block in msg.content: + if isinstance(block, dict): + if ( + block.get("type") == "tool_use" + and block.get("name") == "manage_todo_list" + ): + preserved_tool_use_ids.add(block.get("id")) + elif block.get("type") == "tool_result": + preserved_tool_result_ids.add(block.get("tool_use_id")) - summary_content = result.messages[-1].content if result.messages else "" - assert "[Current Tasks]" not in summary_content + assert preserved_tool_use_ids + assert preserved_tool_use_ids == preserved_tool_result_ids class TestMessageSeparation: @@ -362,6 +318,21 @@ async def test_tool_pair_preservation_rule( tool_id in preserved_tool_use_ids ), f"Tool result for {tool_id} is preserved but its use is not" + async def test_turn_aborted_marker_preserved(self, mock_llm): + """Turn-aborted markers should be preserved even with tight budgets.""" + compressor = WorkingMemoryCompressor(mock_llm) + + messages = [ + LLMMessage(role="user", content="Old message"), + LLMMessage(role="user", content=" interrupted"), + LLMMessage(role="user", content="Recent " * 20), + ] + + selected = compressor.select_user_messages(messages, max_tokens=1) + assert any( + isinstance(msg.content, str) and "" in msg.content for msg in selected + ) + class TestTokenEstimation: """Test token estimation logic.""" @@ -424,10 +395,36 @@ async def test_compression_ratio_calculation(self, mock_llm, simple_messages): simple_messages, strategy=CompressionStrategy.SLIDING_WINDOW, target_tokens=50 ) + # compression_ratio = compressed_tokens / original_tokens assert result.compression_ratio > 0 - assert result.compression_ratio <= 1.0 - # Compressed should be smaller than original - assert result.compressed_tokens <= result.original_tokens + assert 
result.original_tokens > 0 + assert result.compressed_tokens > 0 + expected_ratio = result.compressed_tokens / result.original_tokens + assert abs(result.compression_ratio - expected_ratio) < 0.001 + + async def test_compression_reduces_tokens(self, mock_llm): + """Test that compression reduces token count for large inputs.""" + compressor = WorkingMemoryCompressor(mock_llm) + + # Create a larger message set where compression should be effective + large_messages = [ + LLMMessage(role="user", content="Task: analyze this codebase"), + LLMMessage(role="assistant", content="I'll analyze the code. " * 50), + LLMMessage(role="user", content="Good, continue with the analysis"), + LLMMessage(role="assistant", content="Here are my findings. " * 50), + LLMMessage(role="user", content="What about the tests?"), + LLMMessage(role="assistant", content="The tests cover these areas. " * 50), + ] + + result = await compressor.compress( + large_messages, strategy=CompressionStrategy.SLIDING_WINDOW, target_tokens=100 + ) + + # For large inputs, compression should reduce tokens + assert result.compressed_tokens < result.original_tokens, ( + f"Compression should reduce tokens for large inputs: " + f"{result.compressed_tokens} >= {result.original_tokens}" + ) async def test_token_savings_calculation(self, mock_llm, simple_messages): """Test token savings calculation.""" @@ -437,8 +434,8 @@ async def test_token_savings_calculation(self, mock_llm, simple_messages): simple_messages, strategy=CompressionStrategy.SLIDING_WINDOW ) + # token_savings = original_tokens - compressed_tokens savings = result.token_savings - assert savings >= 0 assert savings == result.original_tokens - result.compressed_tokens async def test_savings_percentage_calculation(self, mock_llm, simple_messages): @@ -449,8 +446,149 @@ async def test_savings_percentage_calculation(self, mock_llm, simple_messages): simple_messages, strategy=CompressionStrategy.SLIDING_WINDOW ) - percentage = result.savings_percentage - assert 0 <= percentage <= 100 + # savings_percentage = (token_savings / original_tokens) * 100 + expected_percentage = ( + (result.token_savings / result.original_tokens) * 100 + if result.original_tokens > 0 + else 0 + ) + assert abs(result.savings_percentage - expected_percentage) < 0.01 + + +class TestCompressionEffectiveness: + """Tests that compaction produces meaningful reductions.""" + + async def test_compaction_reduces_token_count(self, mock_llm): + """Large assistant/tool outputs should compress to a smaller summary.""" + compressor = WorkingMemoryCompressor(mock_llm) + + messages = [ + LLMMessage(role="user", content="Short question 1"), + LLMMessage(role="assistant", content="A" * 8000), + LLMMessage(role="user", content="Short question 2"), + LLMMessage(role="assistant", content="B" * 8000), + LLMMessage(role="tool", content="C" * 8000, tool_call_id="call_1"), + ] + + result = await compressor.compress( + messages, strategy=CompressionStrategy.SLIDING_WINDOW, target_tokens=200 + ) + + # If you want to see numbers, run: pytest -k compaction_reduces_token_count -s + print( + f"original_tokens={result.original_tokens} " + f"compressed_tokens={result.compressed_tokens} " + f"savings_percentage={result.savings_percentage:.1f}" + ) + + assert result.compressed_tokens < result.original_tokens, ( + "Expected compression to reduce token count; " + f"original={result.original_tokens}, compressed={result.compressed_tokens}" + ) + assert result.savings_percentage > 50 + + +class TestUserMessageBudget: + """Tests user message 
truncation limits.""" + + async def test_user_messages_capped_by_budget(self, set_memory_config, mock_llm): + """User messages should be selected within the configured budget.""" + set_memory_config(COMPACT_USER_MESSAGE_MAX_TOKENS=30) + compressor = WorkingMemoryCompressor(mock_llm) + + messages = [LLMMessage(role="user", content=f"msg {i} " + "x" * 40) for i in range(5)] + + selected = compressor.select_user_messages(messages, Config.COMPACT_USER_MESSAGE_MAX_TOKENS) + + assert selected + assert selected[-1].content == messages[-1].content + assert compressor._estimate_tokens(selected) <= Config.COMPACT_USER_MESSAGE_MAX_TOKENS + + +class TestCompactionStructure: + """Tests that rebuilt history matches RFC structure.""" + + async def test_compaction_rebuild_structure(self, set_memory_config, mock_llm): + """Rebuilt history should follow system+user+summary+protected+orphaned.""" + set_memory_config(COMPACT_USER_MESSAGE_MAX_TOKENS=100000) + compressor = WorkingMemoryCompressor(mock_llm) + + system = LLMMessage(role="system", content="system") + user_one = LLMMessage(role="user", content="user one") + user_two = LLMMessage(role="user", content="user two") + + orphan_call = { + "id": "orphan_1", + "type": "function", + "function": {"name": "tool_x", "arguments": "{}"}, + } + orphan_assistant = LLMMessage(role="assistant", content=None, tool_calls=[orphan_call]) + + protected_call = { + "id": "todo_1", + "type": "function", + "function": {"name": "manage_todo_list", "arguments": "{}"}, + } + protected_assistant = LLMMessage( + role="assistant", content=None, tool_calls=[protected_call] + ) + protected_tool = LLMMessage(role="tool", content="ok", tool_call_id="todo_1") + + messages = [ + system, + user_one, + orphan_assistant, + protected_assistant, + protected_tool, + user_two, + ] + + result = await compressor.compress( + messages, strategy=CompressionStrategy.SLIDING_WINDOW, target_tokens=200 + ) + result_messages = result.messages + + assert result_messages[0].role == "system" + + summary_idx = next( + i + for i, msg in enumerate(result_messages) + if isinstance(msg.content, str) + and msg.content.startswith(Config.COMPACT_SUMMARY_PREFIX) + ) + + # User messages must appear before the summary. + for idx, msg in enumerate(result_messages): + if msg.role == "user" and isinstance(msg.content, str): + if msg.content.startswith(Config.COMPACT_SUMMARY_PREFIX): + continue + assert idx < summary_idx + + protected_assistant_idx = next( + i + for i, msg in enumerate(result_messages) + if msg.role == "assistant" + and msg.tool_calls + and any( + isinstance(tc, dict) and tc.get("function", {}).get("name") == "manage_todo_list" + for tc in msg.tool_calls + ) + ) + protected_tool_idx = next( + i + for i, msg in enumerate(result_messages) + if msg.role == "tool" and msg.tool_call_id == "todo_1" + ) + orphan_idx = next( + i + for i, msg in enumerate(result_messages) + if msg.role == "assistant" + and msg.tool_calls + and any(isinstance(tc, dict) and tc.get("id") == "orphan_1" for tc in msg.tool_calls) + ) + + assert summary_idx < protected_assistant_idx < protected_tool_idx + assert orphan_idx == len(result_messages) - 1 class TestCompressionErrors: diff --git a/test/memory/test_integration.py b/test/memory/test_integration.py index ddc13e9..25afd5d 100644 --- a/test/memory/test_integration.py +++ b/test/memory/test_integration.py @@ -4,6 +4,7 @@ especially focusing on edge cases and the tool_call/tool_result matching issue. 
""" +from config import Config from llm.base import LLMMessage from memory import MemoryManager from memory.types import CompressionStrategy @@ -142,7 +143,7 @@ async def test_interleaved_tool_calls(self, set_memory_config, mock_llm): async def test_orphaned_tool_use_detection(self, set_memory_config, mock_llm): """Test detection of orphaned tool_use (no matching result).""" - set_memory_config(MEMORY_SHORT_TERM_SIZE=5) + set_memory_config(MEMORY_SHORT_TERM_SIZE=20) manager = MemoryManager(mock_llm) # Add tool_use without result @@ -181,7 +182,7 @@ async def test_orphaned_tool_use_detection(self, set_memory_config, mock_llm): async def test_orphaned_tool_result_detection(self, set_memory_config, mock_llm): """Test detection of orphaned tool_result (no matching use).""" - set_memory_config(MEMORY_SHORT_TERM_SIZE=5) + set_memory_config(MEMORY_SHORT_TERM_SIZE=20) manager = MemoryManager(mock_llm) # Add tool_result without use (this shouldn't happen but let's test it) @@ -377,7 +378,7 @@ async def test_rapid_compression_cycles(self, set_memory_config, mock_llm): async def test_alternating_compression_strategies(self, set_memory_config, mock_llm): """Test using different compression strategies on same manager.""" - set_memory_config(MEMORY_SHORT_TERM_SIZE=5) + set_memory_config(MEMORY_SHORT_TERM_SIZE=20) manager = MemoryManager(mock_llm) # Add messages and compress with sliding window @@ -411,7 +412,7 @@ async def test_alternating_compression_strategies(self, set_memory_config, mock_ 1 for msg in context if isinstance(msg.content, str) - and msg.content.startswith("[Previous conversation summary]") + and msg.content.startswith(Config.COMPACT_SUMMARY_PREFIX) ) assert summary_count >= 1 # At least one summary should exist diff --git a/test/memory/test_memory_manager.py b/test/memory/test_memory_manager.py index 81b7d74..9c41896 100644 --- a/test/memory/test_memory_manager.py +++ b/test/memory/test_memory_manager.py @@ -1,5 +1,6 @@ """Unit tests for MemoryManager.""" +from config import Config from llm.base import LLMMessage from memory import MemoryManager from memory.types import CompressionStrategy @@ -103,8 +104,6 @@ async def test_compression_on_short_term_full(self, set_memory_config, mock_llm) # After 5 messages, compression should have been triggered and short-term cleared assert manager.compression_count == 1 assert manager.was_compressed_last_iteration - # After compression, short-term is cleared so it's not full - assert not manager.short_term.is_full() async def test_compression_on_hard_limit(self, set_memory_config, mock_llm): """Test compression triggers on hard limit (compression threshold).""" @@ -140,8 +139,7 @@ async def test_compression_creates_summary(self, set_memory_config, mock_llm, si # Check that summary message exists in short_term (at the front) context = manager.get_context_for_llm() has_summary = any( - isinstance(msg.content, str) - and msg.content.startswith("[Previous conversation summary]") + isinstance(msg.content, str) and msg.content.startswith(Config.COMPACT_SUMMARY_PREFIX) for msg in context ) assert has_summary, "Summary message should be present after compression" @@ -249,30 +247,16 @@ async def test_mismatched_tool_calls_detected( f"Detected mismatch - missing results: {missing_results}, missing uses: {missing_uses}" ) - async def test_todo_context_provider_integration( + async def test_protected_tool_messages_preserved( self, set_memory_config, mock_llm, protected_tool_messages ): - """Test that todo context provider is called during compression. 
- - Note: manage_todo_list is no longer in PROTECTED_TOOLS. Instead, todo state - is preserved via todo_context injection from MemoryManager's provider callback. - """ + """Test that protected tool messages survive compression.""" set_memory_config( MEMORY_SHORT_TERM_SIZE=10, # Large enough to avoid auto-compression MEMORY_SHORT_TERM_MIN_SIZE=1, ) manager = MemoryManager(mock_llm) - # Set up todo context provider - todo_context_called = False - - def mock_todo_provider(): - nonlocal todo_context_called - todo_context_called = True - return "1. [pending] Test task" - - manager.set_todo_context_provider(mock_todo_provider) - # Add messages for msg in protected_tool_messages: await manager.add_message(msg) @@ -280,19 +264,27 @@ def mock_todo_provider(): # Manually trigger compression compressed = await manager.compress(strategy=CompressionStrategy.SELECTIVE) - # Verify compression happened and provider was called + # Verify compression happened assert compressed is not None - assert todo_context_called, "Todo context provider should be called during compression" - # Verify todo context is in the summary + # Verify protected tool pair is preserved context = manager.get_context_for_llm() - summary_has_todo = False + tool_use_ids = set() + tool_result_ids = set() for msg in context: - if isinstance(msg.content, str) and "[Current Tasks]" in msg.content: - summary_has_todo = True - break + if isinstance(msg.content, list): + for block in msg.content: + if isinstance(block, dict): + if ( + block.get("type") == "tool_use" + and block.get("name") == "manage_todo_list" + ): + tool_use_ids.add(block.get("id")) + elif block.get("type") == "tool_result": + tool_result_ids.add(block.get("tool_use_id")) - assert summary_has_todo, "Todo context should be injected into compression summary" + assert tool_use_ids + assert tool_use_ids == tool_result_ids async def test_multiple_tool_pairs_in_sequence(self, set_memory_config, mock_llm): """Test multiple consecutive tool_use/tool_result pairs.""" @@ -356,6 +348,91 @@ async def test_multiple_tool_pairs_in_sequence(self, set_memory_config, mock_llm assert tool_use_ids == tool_result_ids +class TestPairIntegrityRemoval: + """Test pair integrity when removing oldest messages.""" + + async def test_remove_oldest_with_pair_integrity(self, mock_llm): + """Removing oldest tool call should also remove its tool result.""" + manager = MemoryManager(mock_llm) + + tool_call = { + "id": "call_1", + "type": "function", + "function": {"name": "tool_a", "arguments": "{}"}, + } + + await manager.add_message( + LLMMessage(role="assistant", content=None, tool_calls=[tool_call]) + ) + await manager.add_message(LLMMessage(role="tool", content="result", tool_call_id="call_1")) + await manager.add_message(LLMMessage(role="user", content="After tool")) + + removed = manager.remove_oldest_with_pair_integrity() + assert removed is not None + assert removed.role == "assistant" + + remaining = manager.short_term.get_messages() + assert len(remaining) == 1 + assert remaining[0].role == "user" + assert remaining[0].content == "After tool" + + +class TestNormalization: + """Test prompt normalization helpers.""" + + async def test_adds_synthetic_tool_output(self, mock_llm): + """Test synthetic outputs for orphaned tool calls.""" + manager = MemoryManager(mock_llm) + + tool_call = { + "id": "call_1", + "type": "function", + "function": {"name": "tool_a", "arguments": "{}"}, + } + + await manager.add_message( + LLMMessage(role="assistant", content=None, tool_calls=[tool_call]) + ) + + context = 
manager.get_context_for_llm() + assert any( + msg.role == "tool" and msg.tool_call_id == "call_1" and msg.content == "aborted" + for msg in context + ) + + async def test_removes_orphan_tool_outputs(self, mock_llm): + """Test removal of tool outputs without matching calls.""" + manager = MemoryManager(mock_llm) + + await manager.add_message( + LLMMessage(role="tool", content="result", tool_call_id="missing_call") + ) + + context = manager.get_context_for_llm() + assert not any(msg.role == "tool" for msg in context) + + +class TestTruncation: + """Test write-time tool output truncation.""" + + async def test_truncates_large_tool_output(self, set_memory_config, mock_llm): + """Tool output should be truncated at write time.""" + set_memory_config( + TOOL_OUTPUT_TRUNCATION_POLICY="tokens", + TOOL_OUTPUT_MAX_TOKENS=10, + TOOL_OUTPUT_SERIALIZATION_BUFFER=1.0, + APPROX_CHARS_PER_TOKEN=1, + ) + manager = MemoryManager(mock_llm) + + content = "x" * 30 + await manager.add_message(LLMMessage(role="tool", content=content, tool_call_id="call_1")) + + stored = manager.short_term.get_messages()[0] + assert stored.content != content + assert "...20 tokens truncated..." in stored.content + + class TestEdgeCases: """Test edge cases and error scenarios.""" diff --git a/test/memory/test_store.py b/test/memory/test_store.py index e369218..4b1e13b 100644 --- a/test/memory/test_store.py +++ b/test/memory/test_store.py @@ -7,6 +7,7 @@ import pytest +from config import Config from llm.base import LLMMessage from memory.store import MemoryStore from memory.types import CompressedMemory @@ -242,7 +243,7 @@ async def test_save_memory(self, store): messages=[ LLMMessage( role="user", - content="[Previous conversation summary]\nEarlier conversation summary", + content=f"{Config.COMPACT_SUMMARY_PREFIX}Earlier conversation summary", ) ], original_message_count=5, @@ -304,7 +305,7 @@ async def test_save_summary(self, store): summary = CompressedMemory( messages=[ LLMMessage( - role="user", content="[Previous conversation summary]\nThis is a summary" + role="user", content=f"{Config.COMPACT_SUMMARY_PREFIX}This is a summary" ), LLMMessage(role="user", content="Important message"), ], @@ -334,7 +335,7 @@ async def test_save_multiple_summaries(self, store): for i in range(3): summary = CompressedMemory( messages=[ - LLMMessage(role="user", content=f"[Previous conversation summary]\nSummary {i}") + LLMMessage(role="user", content=f"{Config.COMPACT_SUMMARY_PREFIX}Summary {i}") ], original_message_count=5, original_tokens=500, @@ -392,7 +393,7 @@ async def test_get_session_stats(self, store): await store.save_message(session_id, LLMMessage(role="assistant", content="Hi"), tokens=3) summary = CompressedMemory( - messages=[LLMMessage(role="user", content="[Previous conversation summary]\nSummary")], + messages=[LLMMessage(role="user", content=f"{Config.COMPACT_SUMMARY_PREFIX}Summary")], original_message_count=5, original_tokens=500, compressed_tokens=150, @@ -466,7 +467,7 @@ async def test_complete_session_lifecycle(self, store): messages=[ LLMMessage( role="user", - content=f"[Previous conversation summary]\nSummary of batch {i}", + content=f"{Config.COMPACT_SUMMARY_PREFIX}Summary of batch {i}", ) ], original_message_count=5, diff --git a/test/memory/test_truncate.py b/test/memory/test_truncate.py new file mode 100644 index 0000000..b607ab8 --- /dev/null +++ b/test/memory/test_truncate.py @@ -0,0 +1,40 @@ +"""Tests for tool output truncation utilities.""" + +from memory.truncate import truncate_tool_output + + +class 
TestToolOutputTruncation: + """Validate truncation behavior for different policies.""" + + def test_bytes_policy_truncates_and_adds_line_count(self): + """Bytes policy should truncate and include a line-count header.""" + content = "line1\nline2\nline3\nline4\nline5" + + result = truncate_tool_output( + content=content, + policy="bytes", + max_tokens=0, + max_bytes=20, + serialization_buffer=1.0, + approx_chars_per_token=4, + ) + + assert result.truncated + assert result.content.startswith("Total output lines:") + assert "chars truncated" in result.content + + def test_bytes_policy_no_truncation_when_within_budget(self): + """Bytes policy should keep content when within budget.""" + content = "short output" + + result = truncate_tool_output( + content=content, + policy="bytes", + max_tokens=0, + max_bytes=200, + serialization_buffer=1.0, + approx_chars_per_token=4, + ) + + assert not result.truncated + assert result.content == content diff --git a/test/test_context_overflow_recovery.py b/test/test_context_overflow_recovery.py new file mode 100644 index 0000000..a121537 --- /dev/null +++ b/test/test_context_overflow_recovery.py @@ -0,0 +1,54 @@ +"""Tests for context overflow recovery in the agent loop.""" + +from agent.base import BaseAgent +from config import Config +from llm.base import LLMMessage, LLMResponse, StopReason + + +class OverflowMockLLM: + """Mock LLM that raises a context length error on the first call.""" + + def __init__(self): + self.provider_name = "mock" + self.model = "mock-model" + self.call_count = 0 + + async def call_async(self, messages, tools=None, max_tokens=4096, **kwargs): + self.call_count += 1 + if self.call_count == 1: + raise Exception("context_length_exceeded") + return LLMResponse(content="ok", stop_reason=StopReason.STOP) + + def extract_text(self, response): + return response.content or "" + + @property + def supports_tools(self): + return True + + +class DummyAgent(BaseAgent): + """Minimal agent for testing overflow recovery.""" + + def run(self, task: str) -> str: + return "" + + +async def test_context_overflow_recovery_removes_oldest(monkeypatch): + """Context overflow should trigger removal and retry.""" + monkeypatch.setattr(Config, "CONTEXT_OVERFLOW_MAX_RETRIES", 1) + llm = OverflowMockLLM() + agent = DummyAgent(llm=llm, tools=[]) + + await agent.memory.add_message(LLMMessage(role="user", content="Message 1")) + await agent.memory.add_message(LLMMessage(role="user", content="Message 2")) + + before = agent.memory.short_term.count() + + response = await agent._call_with_overflow_recovery() + + after = agent.memory.short_term.count() + + assert response.content == "ok" + assert llm.call_count == 2 + assert after < before
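
For orientation, a minimal sketch of the tokens-policy behavior the new tests assert: the character budget comes from `max_tokens * approx_chars_per_token` adjusted by the serialization buffer, the output keeps its start and end (the Phase 1 "50/50 split"), and the dropped middle is replaced by a `...N tokens truncated...` marker. The helper name `truncate_by_tokens` and the exact buffer handling below are assumptions, not the patch's implementation; the real entry point is `truncate_tool_output` in `memory/truncate.py`, which also covers the `bytes` policy with its `Total output lines:` header.

```python
"""Illustrative sketch of tokens-policy truncation, aligned with TestTruncation
and test/memory/test_truncate.py. Hypothetical helper, not memory/truncate.py."""

from dataclasses import dataclass


@dataclass
class TruncationResult:
    """Mirrors the fields the tests read: .content and .truncated."""

    content: str
    truncated: bool


def truncate_by_tokens(
    content: str,
    max_tokens: int,
    approx_chars_per_token: int = 4,
    serialization_buffer: float = 1.2,
) -> TruncationResult:
    # Character budget derived from the token budget; the serialization buffer is
    # assumed to shrink the effective budget to leave room for JSON overhead.
    budget_chars = int(max_tokens * approx_chars_per_token / serialization_buffer)
    if len(content) <= budget_chars:
        return TruncationResult(content=content, truncated=False)

    # Phase 1 "50/50 split": keep the start and the end, drop the middle.
    head = content[: budget_chars // 2]
    tail = content[len(content) - (budget_chars - len(head)) :]
    dropped_tokens = len(content) // approx_chars_per_token - max_tokens
    marker = f"\n...{dropped_tokens} tokens truncated...\n"
    return TruncationResult(content=head + marker + tail, truncated=True)


# The TestTruncation settings (10-token cap, 1 char/token, buffer 1.0) turn a
# 30-char tool output into 5 + 5 kept chars around "...20 tokens truncated...".
result = truncate_by_tokens("x" * 30, max_tokens=10, approx_chars_per_token=1, serialization_buffer=1.0)
assert result.truncated and "...20 tokens truncated..." in result.content
```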