From 0490efdea27996ae0625aa0513f06936b5b02ffa Mon Sep 17 00:00:00 2001 From: Ryan Smith Date: Thu, 15 Jan 2026 17:28:33 -0800 Subject: [PATCH 1/5] Initial vibe coded mech interp implementation --- examples/03_mech_interp_hooked_transformer.py | 151 ++++++ pyproject.toml | 2 + src/ares/contrib/mech_interp/README.md | 491 ++++++++++++++++++ src/ares/contrib/mech_interp/__init__.py | 22 + .../contrib/mech_interp/activation_capture.py | 281 ++++++++++ src/ares/contrib/mech_interp/hook_utils.py | 307 +++++++++++ .../mech_interp/hooked_transformer_client.py | 215 ++++++++ 7 files changed, 1469 insertions(+) create mode 100644 examples/03_mech_interp_hooked_transformer.py create mode 100644 src/ares/contrib/mech_interp/README.md create mode 100644 src/ares/contrib/mech_interp/__init__.py create mode 100644 src/ares/contrib/mech_interp/activation_capture.py create mode 100644 src/ares/contrib/mech_interp/hook_utils.py create mode 100644 src/ares/contrib/mech_interp/hooked_transformer_client.py diff --git a/examples/03_mech_interp_hooked_transformer.py b/examples/03_mech_interp_hooked_transformer.py new file mode 100644 index 0000000..4e17b19 --- /dev/null +++ b/examples/03_mech_interp_hooked_transformer.py @@ -0,0 +1,151 @@ +"""Example of using ARES with TransformerLens for mechanistic interpretability. + +This example demonstrates how to: +1. Use HookedTransformer with ARES environments +2. Capture activations across an agent trajectory +3. Apply interventions to study model behavior + +Example usage: + + 1. Make sure you have mech_interp dependencies installed + `uv sync --group mech_interp` + 2. Run the example + `uv run -m examples.03_mech_interp_hooked_transformer` +""" + +import asyncio + +from transformer_lens import HookedTransformer + +from ares.contrib.mech_interp import ActivationCapture +from ares.contrib.mech_interp import HookedTransformerLLMClient +from ares.contrib.mech_interp import InterventionManager +from ares.contrib.mech_interp import create_zero_ablation_hook +from ares.environments import swebench_env + + +async def main(): + print("=" * 80) + print("ARES + TransformerLens Mechanistic Interpretability Example") + print("=" * 80) + + # Load a small model for demonstration + # For real work, you'd use a larger model like gpt2-medium or pythia-1.4b + print("\nLoading HookedTransformer model...") + model = HookedTransformer.from_pretrained( + "gpt2-small", + device="cpu", # Change to "cuda" if you have a GPU + ) + + # Create the LLM client with reduced token limit for gpt2-small's context window + # gpt2-small has max context of 1024 tokens, so we need to be conservative + client = HookedTransformerLLMClient( + model=model, + model_name="gpt2-small", + max_new_tokens=128, # Keep this small to avoid context overflow + ) + + # Load SWE-bench tasks + all_tasks = swebench_env.swebench_verified_tasks() + tasks = [all_tasks[0]] # Just one task for demo + + print(f"\nRunning on task: {tasks[0].instance_id}") + print(f"Repository: {tasks[0].repo}") + print("-" * 80) + + # Example 1: Basic execution with activation capture + print("\n[Example 1] Running agent with activation capture...") + print("-" * 80) + + async with swebench_env.SweBenchEnv(tasks=tasks) as env: + # Set up activation capture + with ActivationCapture(model) as capture: + ts = await env.reset() + step_count = 0 + max_steps = 3 # Limit steps for demo + + while not ts.last() and step_count < max_steps: + # Capture activations for this step + capture.start_step() + + # Generate response + assert ts.observation is not None 
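+                # The client call below runs a full forward pass through the
+                # HookedTransformer, so the capture hooks registered by
+                # ActivationCapture fire here and record this step's activations.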
+ action = await client(ts.observation) + + # End capture for this step + capture.end_step() + capture.record_step_metadata( + { + "step": step_count, + "action_preview": str(action.chat_completion_response.choices[0].message.content)[:50], + } + ) + + assert action.chat_completion_response.usage is not None + print( + f"Step {step_count}: Generated {action.chat_completion_response.usage.completion_tokens} tokens" + ) + + # Step environment + ts = await env.step(action) + step_count += 1 + + # Analyze captured activations + trajectory = capture.get_trajectory() + print(f"\nCaptured activations for {len(trajectory)} steps") + + # Example: Look at attention patterns in layer 0 + if len(trajectory) > 0: + attn_pattern = trajectory.get_activation(0, "blocks.0.attn.hook_pattern") + print(f"Layer 0 attention pattern shape: {attn_pattern.shape}") + print(" [batch, n_heads, query_pos, key_pos]") + + # Save trajectory for later analysis + print("\nSaving trajectory activations to ./mech_interp_demo/trajectory_001/") + trajectory.save("./mech_interp_demo/trajectory_001") + + # Example 2: Running with interventions + print("\n[Example 2] Running agent with attention head ablation...") + print("-" * 80) + + async with swebench_env.SweBenchEnv(tasks=tasks) as env: + # Set up intervention: ablate heads 0-2 in layer 0 + manager = InterventionManager(model) + manager.add_intervention( + hook_name="blocks.0.attn.hook_result", + hook_fn=create_zero_ablation_hook(heads=[0, 1, 2]), + description="Ablate attention heads 0-2 in layer 0", + ) + + print(manager.get_intervention_summary()) + + with manager: + ts = await env.reset() + step_count = 0 + max_steps = 2 # Limit steps for demo + + while not ts.last() and step_count < max_steps: + assert ts.observation is not None + action = await client(ts.observation) + + assert action.chat_completion_response.usage is not None + num_tokens = action.chat_completion_response.usage.completion_tokens + print(f"Step {step_count} (with ablation): Generated {num_tokens} tokens") + + ts = await env.step(action) + step_count += 1 + manager.increment_step() + + print("\n" + "=" * 80) + print("Demo complete!") + print("=" * 80) + print("\nNext steps for mechanistic interpretability research:") + print("1. Load saved activations: TrajectoryActivations.load('./mech_interp_demo/trajectory_001')") + print("2. Analyze attention patterns across the trajectory") + print("3. Use interventions to study causal effects") + print("4. 
Compare 'clean' vs 'corrupted' trajectories with path patching") + print("\nSee src/ares/contrib/mech_interp/README.md for more examples!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index b6b6cff..88aadbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ docs = [ "sphinx-rtd-theme>=2.0.0", "sphinx-autodoc-typehints>=1.25.0", ] +transformer-lens = ["transformer-lens>=2.10.0"] [project.urls] Homepage = "https://github.com/withmartian/ares" @@ -61,6 +62,7 @@ dev = [ "hatchling>=1.28.0", "twine>=6.2.0", ] + examples = [ "transformers>=4.57.3", "llama-cpp-python>=0.3.16", diff --git a/src/ares/contrib/mech_interp/README.md b/src/ares/contrib/mech_interp/README.md new file mode 100644 index 0000000..12f2945 --- /dev/null +++ b/src/ares/contrib/mech_interp/README.md @@ -0,0 +1,491 @@ +# Mechanistic Interpretability for ARES + +This module provides deep integration between ARES and [TransformerLens](https://github.com/neelnanda-io/TransformerLens), enabling mechanistic interpretability research on code agents across long-horizon tasks. + +## Why ARES for Mechanistic Interpretability? + +Traditional mechanistic interpretability focuses on static, single-step analysis. But modern AI agents: +- Make decisions across many steps (50-100+ steps per episode) +- Maintain internal state that evolves over time +- Exhibit temporal dependencies and long-horizon planning + +ARES enables **trajectory-level mechanistic interpretability** by: +1. Capturing activations across entire agent episodes +2. Studying how internal representations evolve during multi-step reasoning +3. Identifying critical moments where interventions significantly alter outcomes +4. Analyzing information flow across dozens of inference steps + +## Quick Start + +### Installation + +```bash +# Install ARES with mech_interp group (includes TransformerLens) +uv sync --group mech_interp +``` + +### Basic Example + +```python +import asyncio +from transformer_lens import HookedTransformer +from ares.contrib.mech_interp import HookedTransformerLLMClient, ActivationCapture +from ares.environments import swebench_env + +async def main(): + # Load model + model = HookedTransformer.from_pretrained("gpt2-small") + client = HookedTransformerLLMClient(model=model, model_name="gpt2-small") + + # Run agent and capture activations + tasks = swebench_env.swebench_verified_tasks()[:1] + + async with swebench_env.SweBenchEnv(tasks=tasks) as env: + with ActivationCapture(model) as capture: + ts = await env.reset() + while not ts.last(): + capture.start_step() + action = await client(ts.observation) + capture.end_step() + ts = await env.step(action) + + # Analyze trajectory + trajectory = capture.get_trajectory() + print(f"Captured {len(trajectory)} steps") + trajectory.save("./activations/episode_001") + +asyncio.run(main()) +``` + +## Core Components + +### 1. HookedTransformerLLMClient + +An ARES-compatible LLM client that uses TransformerLens's `HookedTransformer` for inference. 
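+
+It returns the same `LLMResponse` type as other ARES clients (an OpenAI-style chat completion plus a cost of 0.0 for local inference), so existing agent loops work unchanged. A minimal sketch of reading a response, assuming `observation` comes from `env.reset()`/`env.step()` and `client` is constructed as in the next snippet:
+
+```python
+response = await client(observation)
+print(response.chat_completion_response.choices[0].message.content)
+print(response.chat_completion_response.usage.completion_tokens, "completion tokens")
+```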
+ +```python +from transformer_lens import HookedTransformer +from ares.contrib.mech_interp import HookedTransformerLLMClient + +model = HookedTransformer.from_pretrained("gpt2-medium") +client = HookedTransformerLLMClient( + model=model, + model_name="gpt2-medium", + max_new_tokens=1024, + generation_kwargs={"temperature": 0.7} +) +``` + +**With Chat Templates:** + +```python +from transformers import AutoTokenizer +from ares.contrib.mech_interp import create_hooked_transformer_client_with_chat_template + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct") +model = HookedTransformer.from_pretrained("Qwen/Qwen2.5-3B-Instruct") + +client = create_hooked_transformer_client_with_chat_template( + model=model, + tokenizer=tokenizer, + model_name="Qwen/Qwen2.5-3B-Instruct", + max_new_tokens=2048, +) +``` + +### 2. ActivationCapture + +Captures activations across an agent trajectory for later analysis. + +```python +from ares.contrib.mech_interp import ActivationCapture + +with ActivationCapture(model) as capture: + # Filter which hooks to capture (optional) + # capture = ActivationCapture(model, hook_filter=lambda name: "attn" in name) + + async with env: + ts = await env.reset() + while not ts.last(): + capture.start_step() + action = await client(ts.observation) + capture.end_step() + + # Optionally record metadata + capture.record_step_metadata({"step_reward": ts.reward}) + + ts = await env.step(action) + +# Get trajectory +trajectory = capture.get_trajectory() + +# Access specific activations +attn_pattern_step_5 = trajectory.get_activation(5, "blocks.0.attn.hook_pattern") + +# Get activations across trajectory +all_layer0_resid = trajectory.get_activation_across_trajectory("blocks.0.hook_resid_post") + +# Save for later +trajectory.save("./activations/episode_001") + +# Load later +loaded = TrajectoryActivations.load("./activations/episode_001") +``` + +**Automatic Capture:** + +```python +from ares.contrib.mech_interp import automatic_activation_capture + +with automatic_activation_capture(model) as capture: + # Activations are captured automatically during each client call + async with env: + ts = await env.reset() + while not ts.last(): + response = await client(ts.observation) + ts = await env.step(response) + + trajectory = capture.get_trajectory() +``` + +### 3. InterventionManager + +Apply causal interventions during agent execution to study model behavior. + +```python +from ares.contrib.mech_interp import ( + InterventionManager, + create_zero_ablation_hook, + create_path_patching_hook, +) + +manager = InterventionManager(model) + +# Ablate attention heads +manager.add_intervention( + hook_name="blocks.0.attn.hook_result", + hook_fn=create_zero_ablation_hook(heads=[0, 1, 2]), + description="Ablate heads 0-2 in layer 0", + apply_at_steps=[5, 6, 7] # Optional: only at specific steps +) + +# Run with interventions +with manager: + async with env: + ts = await env.reset() + step_count = 0 + while not ts.last(): + action = await client(ts.observation) + ts = await env.step(action) + + step_count += 1 + manager.increment_step() # Track which step we're on +``` + +### 4. 
Hook Utilities + +Pre-built hooks for common interventions: + +**Zero Ablation:** +```python +from ares.contrib.mech_interp import create_zero_ablation_hook + +# Ablate specific positions +hook = create_zero_ablation_hook(positions=[10, 11, 12]) + +# Ablate specific attention heads +hook = create_zero_ablation_hook(heads=[0, 1]) +``` + +**Path Patching:** +```python +from ares.contrib.mech_interp import create_path_patching_hook + +# Run clean and corrupted inputs +clean_cache, _ = model.run_with_cache(clean_tokens) +corrupted_cache, _ = model.run_with_cache(corrupted_tokens) + +# Patch clean activations into corrupted run +hook = create_path_patching_hook( + clean_activation=clean_cache["blocks.0.hook_resid_post"], + positions=[5, 6, 7] +) +``` + +**Mean Ablation:** +```python +from ares.contrib.mech_interp import create_mean_ablation_hook + +# Compute mean activation over dataset +mean_cache = compute_mean_activations(model, dataset) + +hook = create_mean_ablation_hook( + mean_activation=mean_cache["blocks.0.hook_resid_post"], + positions=[10, 11, 12] +) +``` + +**Attention Knockout:** +```python +from ares.contrib.mech_interp import create_attention_knockout_hook + +# Prevent position 20 from attending to positions 0-10 +hook = create_attention_knockout_hook( + source_positions=[20], + target_positions=list(range(11)) +) +``` + +## Research Use Cases + +### 1. Attention Head Analysis Across Trajectories + +Study how attention patterns evolve as agents work through tasks: + +```python +with ActivationCapture(model) as capture: + # Run agent episode + ... + +trajectory = capture.get_trajectory() + +# Analyze attention patterns across all steps +for step in range(len(trajectory)): + attn = trajectory.get_activation(step, "blocks.5.attn.hook_pattern") + # Analyze attention to specific tokens, copy behavior, etc. +``` + +### 2. Identifying Critical Decision Points + +Find steps where small perturbations significantly alter outcomes: + +```python +baseline_trajectory = run_episode(env, client) + +# Test interventions at each step +critical_steps = [] +for step in range(len(baseline_trajectory)): + manager = InterventionManager(model) + manager.add_intervention( + hook_name="blocks.3.hook_resid_post", + hook_fn=create_zero_ablation_hook(positions=[10, 11]), + apply_at_steps=[step] + ) + + perturbed_reward = run_episode_with_intervention(env, client, manager) + + if abs(perturbed_reward - baseline_trajectory.reward) > threshold: + critical_steps.append(step) +``` + +### 3. Circuit Discovery in Multi-Step Reasoning + +Use path patching to identify circuits responsible for specific capabilities: + +```python +# Compare successful vs failed trajectories +success_cache = run_and_cache_episode(env, client, success_task) +failure_cache = run_and_cache_episode(env, client, failure_task) + +# Systematically patch components from success to failure +for layer in range(model.cfg.n_layers): + for step in range(len(failure_cache)): + hook = create_path_patching_hook( + clean_activation=success_cache.get_activation(step, f"blocks.{layer}.hook_resid_post"), + positions=None + ) + + # Test if this component recovers successful behavior + ... +``` + +### 4. 
Temporal Information Flow Analysis + +Track how information propagates through the model across steps: + +```python +# Capture activations with detailed metadata +with ActivationCapture(model) as capture: + async with env: + ts = await env.reset() + while not ts.last(): + capture.start_step() + action = await client(ts.observation) + capture.end_step() + + # Record what the agent was doing + capture.record_step_metadata({ + "action_type": classify_action(action), + "error_present": check_for_errors(ts.observation), + "file_context": extract_file_context(ts.observation), + }) + + ts = await env.step(action) + +trajectory = capture.get_trajectory() + +# Analyze when specific features activate relative to task events +# E.g., does the model activate "file-reading neurons" before bash commands? +``` + +### 5. Comparative Analysis: Different Models + +Compare how different models solve the same task: + +```python +models = [ + HookedTransformer.from_pretrained("gpt2-small"), + HookedTransformer.from_pretrained("gpt2-medium"), + HookedTransformer.from_pretrained("pythia-1.4b"), +] + +trajectories = [] +for model in models: + client = HookedTransformerLLMClient(model=model, model_name=model.cfg.model_name) + + with ActivationCapture(model) as capture: + # Run same task + trajectory = run_episode(env, client, capture) + trajectories.append(trajectory) + +# Compare attention patterns, activation magnitudes, etc. +compare_trajectories(trajectories) +``` + +## Advanced Examples + +### Example: Finding "Code Understanding" Circuits + +```python +import torch +from transformer_lens import HookedTransformer +from ares.contrib.mech_interp import * +from ares.environments import swebench_env + +async def find_code_understanding_circuits(): + model = HookedTransformer.from_pretrained("gpt2-medium") + client = HookedTransformerLLMClient(model=model, model_name="gpt2-medium") + + # 1. Collect baseline trajectory on a task + tasks = swebench_env.swebench_verified_tasks() + success_task = [t for t in tasks if is_success(t)][0] + + with ActivationCapture(model) as capture: + baseline_reward = await run_episode(swebench_env.SweBenchEnv([success_task]), client, capture) + + baseline_trajectory = capture.get_trajectory() + + # 2. For each layer and step, ablate and measure impact + impact_matrix = torch.zeros(model.cfg.n_layers, len(baseline_trajectory)) + + for layer in range(model.cfg.n_layers): + for step in range(len(baseline_trajectory)): + manager = InterventionManager(model) + manager.add_intervention( + hook_name=f"blocks.{layer}.hook_resid_post", + hook_fn=create_mean_ablation_hook(), + apply_at_steps=[step] + ) + + with manager: + perturbed_reward = await run_episode(swebench_env.SweBenchEnv([success_task]), client) + + impact_matrix[layer, step] = abs(baseline_reward - perturbed_reward) + + # 3. Identify critical (layer, step) pairs + critical_components = (impact_matrix > threshold).nonzero() + + print(f"Found {len(critical_components)} critical components") + + # 4. 
Drill down to attention heads in critical layers + for layer, step in critical_components[:10]: # Top 10 + for head in range(model.cfg.n_heads): + manager = InterventionManager(model) + manager.add_intervention( + hook_name=f"blocks.{layer}.attn.hook_result", + hook_fn=create_zero_ablation_hook(heads=[head]), + apply_at_steps=[step.item()] + ) + + with manager: + head_reward = await run_episode(swebench_env.SweBenchEnv([success_task]), client) + + if abs(baseline_reward - head_reward) > head_threshold: + print(f"Critical: Layer {layer}, Step {step}, Head {head}") + + # Visualize attention pattern + attn_pattern = baseline_trajectory.get_activation(step, f"blocks.{layer}.attn.hook_pattern") + visualize_attention(attn_pattern[:, head, :, :]) + +# Run analysis +asyncio.run(find_code_understanding_circuits()) +``` + +## Best Practices + +1. **Start Small**: Use small models (gpt2-small, gpt2-medium) for initial exploration +2. **Limit Steps**: Use `max_steps` during development to avoid long-running experiments +3. **Save Often**: Save trajectories immediately after capture for later analysis +4. **Filter Hooks**: Use `hook_filter` in ActivationCapture to reduce memory usage +5. **GPU Management**: Move activations to CPU after capture: `.detach().cpu()` +6. **Batch Analysis**: Process multiple trajectories in batch when possible + +## Performance Tips + +**Memory Optimization:** +```python +# Only capture specific hooks +capture = ActivationCapture( + model, + hook_filter=lambda name: "attn.hook_pattern" in name or "hook_resid" in name +) + +# Clear trajectory periodically +if len(capture.step_activations) > 100: + trajectory = capture.get_trajectory() + trajectory.save(f"./checkpoints/step_{step}") + capture.clear() +``` + +**Speed Optimization:** +```python +# Use smaller models for initial exploration +model = HookedTransformer.from_pretrained("gpt2-small", device="cuda") + +# Reduce max_new_tokens during ablation studies +client = HookedTransformerLLMClient(model=model, max_new_tokens=256) + +# Use torch.no_grad() (already done in client, but good to remember) +with torch.no_grad(): + ... +``` + +## Comparison to Traditional MI + +| Traditional MI | ARES + Mech Interp | +|----------------|---------------------| +| Single prompts | Multi-step trajectories (50-100+ steps) | +| Static analysis | Dynamic, evolving state | +| Local causality | Long-horizon dependencies | +| Fixed context | Context grows with episode | +| Immediate outcomes | Delayed rewards | + +## Resources + +- [TransformerLens Documentation](https://neelnanda-io.github.io/TransformerLens/) +- [Anthropic's Circuits Thread](https://transformer-circuits.pub/) +- [Example Notebook: Trajectory-Level Analysis](./notebooks/trajectory_analysis.ipynb) *(coming soon)* +- [Blog Post: Beyond Static MI](https://withmartian.com/post/beyond-static-mechanistic-interpretability-agentic-long-horizon-tasks-as-the-next-frontier) + +## Citation + +If you use this module in your research, please cite: + +```bibtex +@software{ares_mech_interp_2025, + title = {ARES Mechanistic Interpretability Module}, + author = {Martian}, + year = {2025}, + url = {https://github.com/anthropics/ares} +} +``` diff --git a/src/ares/contrib/mech_interp/__init__.py b/src/ares/contrib/mech_interp/__init__.py new file mode 100644 index 0000000..429bf28 --- /dev/null +++ b/src/ares/contrib/mech_interp/__init__.py @@ -0,0 +1,22 @@ +"""Mechanistic interpretability utilities for ARES. 
+ +This module provides tools for analyzing agent behavior using mechanistic interpretability +techniques, with deep integration with TransformerLens for studying model internals across +long-horizon agent trajectories. +""" + +from ares.contrib.mech_interp.activation_capture import ActivationCapture +from ares.contrib.mech_interp.activation_capture import TrajectoryActivations +from ares.contrib.mech_interp.hook_utils import InterventionManager +from ares.contrib.mech_interp.hook_utils import create_path_patching_hook +from ares.contrib.mech_interp.hook_utils import create_zero_ablation_hook +from ares.contrib.mech_interp.hooked_transformer_client import HookedTransformerLLMClient + +__all__ = [ + "ActivationCapture", + "HookedTransformerLLMClient", + "InterventionManager", + "TrajectoryActivations", + "create_path_patching_hook", + "create_zero_ablation_hook", +] diff --git a/src/ares/contrib/mech_interp/activation_capture.py b/src/ares/contrib/mech_interp/activation_capture.py new file mode 100644 index 0000000..0d954a9 --- /dev/null +++ b/src/ares/contrib/mech_interp/activation_capture.py @@ -0,0 +1,281 @@ +"""Utilities for capturing and analyzing activations across agent trajectories.""" + +from collections.abc import Callable +import dataclasses +import pathlib +from typing import Any + +import torch +from transformer_lens import ActivationCache +from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookPoint + + +@dataclasses.dataclass +class TrajectoryActivations: + """Container for activations captured across an agent trajectory. + + Attributes: + step_activations: List of ActivationCache objects, one per agent step. + step_metadata: Optional metadata for each step (e.g., observation, action). + model_name: Name of the model used. + """ + + step_activations: list[ActivationCache] + step_metadata: list[dict[str, Any]] = dataclasses.field(default_factory=list) + model_name: str = "unknown" + + def __len__(self) -> int: + """Return number of steps in the trajectory.""" + return len(self.step_activations) + + def save(self, path: str | pathlib.Path) -> None: + """Save trajectory activations to disk. + + Args: + path: Directory path to save activations. Will be created if it doesn't exist. + """ + path = pathlib.Path(path) + path.mkdir(parents=True, exist_ok=True) + + # Save each step's activations as plain dicts (ActivationCache can't be pickled) + for i, cache in enumerate(self.step_activations): + # Convert ActivationCache to dict for serialization + cache_dict = dict(cache.cache_dict.items()) + torch.save(cache_dict, path / f"step_{i:04d}.pt") + + # Save metadata + import json + + metadata = { + "model_name": self.model_name, + "num_steps": len(self), + "step_metadata": self.step_metadata, + } + with open(path / "metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + + @classmethod + def load(cls, path: str | pathlib.Path) -> "TrajectoryActivations": + """Load trajectory activations from disk. + + Args: + path: Directory path containing saved activations. + + Returns: + TrajectoryActivations instance. 
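+
+        Example:
+            ```python
+            # Path is illustrative; assumes a trajectory was previously written
+            # with TrajectoryActivations.save("./activations/episode_001").
+            trajectory = TrajectoryActivations.load("./activations/episode_001")
+            resid = trajectory.get_activation(0, "blocks.0.hook_resid_post")
+            ```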
+ """ + import json + + path = pathlib.Path(path) + + # Load metadata + with open(path / "metadata.json") as f: + metadata = json.load(f) + + # Load activation dicts (we saved as plain dicts, not ActivationCache objects) + step_files = sorted(path.glob("step_*.pt")) + step_activations_dicts = [torch.load(f, weights_only=False) for f in step_files] + + # Convert back to ActivationCache objects if needed, or just use dicts + # For now, we'll keep them as ActivationCache objects for API compatibility + step_activations = [ + ActivationCache(cache_dict, model=None) for cache_dict in step_activations_dicts + ] + + return cls( + step_activations=step_activations, + step_metadata=metadata.get("step_metadata", []), + model_name=metadata.get("model_name", "unknown"), + ) + + def get_activation(self, step: int, hook_name: str) -> torch.Tensor: + """Get activation tensor for a specific step and hook. + + Args: + step: Step index in trajectory. + hook_name: Name of the hook point (e.g., "blocks.0.attn.hook_pattern"). + + Returns: + Activation tensor for the specified step and hook. + """ + return self.step_activations[step][hook_name] + + def get_activation_across_trajectory(self, hook_name: str) -> list[torch.Tensor]: + """Get activations for a specific hook across all trajectory steps. + + Args: + hook_name: Name of the hook point. + + Returns: + List of activation tensors, one per step. + """ + return [cache[hook_name] for cache in self.step_activations] + + +class ActivationCapture: + """Context manager for capturing activations during agent execution. + + This class provides a convenient way to capture activations from a HookedTransformer + during an ARES agent episode, enabling trajectory-level mechanistic interpretability + analysis. + + Example: + ```python + from transformer_lens import HookedTransformer + from ares.contrib.mech_interp import ActivationCapture, HookedTransformerLLMClient + + model = HookedTransformer.from_pretrained("gpt2-small") + client = HookedTransformerLLMClient(model=model, model_name="gpt2-small") + + # Capture activations during episode + with ActivationCapture(model) as capture: + async with env: + ts = await env.reset() + while not ts.last(): + response = await client(ts.observation) + capture.record_step_metadata({"action": response}) + ts = await env.step(response) + + # Analyze captured activations + trajectory = capture.get_trajectory() + print(f"Captured {len(trajectory)} steps") + + # Save for later analysis + trajectory.save("./activations/episode_001") + ``` + """ + + def __init__( + self, + model: HookedTransformer, + hook_filter: Callable[[str], bool] | None = None, + ): + """Initialize activation capture. + + Args: + model: HookedTransformer to capture activations from. + hook_filter: Optional function to filter which hooks to capture. + If None, captures all hooks. Example: lambda name: "attn" in name + """ + self.model = model + self.hook_filter = hook_filter + self.step_activations: list[ActivationCache] = [] + self.step_metadata: list[dict[str, Any]] = [] + # TODO: What is the purpose of hook_handles? Prob should remove + self._hook_handles: list[Any] = [] + + # TODO: Does this need to be a context manager?? 
+ def __enter__(self) -> "ActivationCapture": + """Enter context manager and start capturing activations.""" + # Register hooks to capture activations + for name, hp in self.model.hook_dict.items(): + if self.hook_filter is None or self.hook_filter(name): + handle = hp.add_hook(self._make_capture_hook(name)) + self._hook_handles.append(handle) + + self._current_step_cache: dict[str, torch.Tensor] = {} + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context manager and clean up hooks.""" + # Remove all hooks + for handle in self._hook_handles: + if handle is not None: + handle.remove() + self._hook_handles.clear() + + def _make_capture_hook(self, hook_name: str) -> Callable: + """Create a hook function that captures activations.""" + + def capture_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 + # Store a copy of the activation (detached from computation graph) + self._current_step_cache[hook_name] = activation.detach().cpu() + return activation + + return capture_hook + + def start_step(self) -> None: + """Mark the start of a new agent step.""" + self._current_step_cache = {} + + def end_step(self) -> None: + """Mark the end of an agent step and save captured activations.""" + if self._current_step_cache: + # Convert dict to ActivationCache + cache = ActivationCache( + self._current_step_cache.copy(), + self.model, + ) + self.step_activations.append(cache) + self._current_step_cache = {} + + def record_step_metadata(self, metadata: dict[str, Any]) -> None: + """Record metadata for the current/last step. + + Args: + metadata: Dictionary of metadata to associate with this step. + """ + self.step_metadata.append(metadata) + + def get_trajectory(self) -> TrajectoryActivations: + """Get the complete trajectory of captured activations. + + Returns: + TrajectoryActivations containing all captured steps. + """ + return TrajectoryActivations( + step_activations=self.step_activations.copy(), + step_metadata=self.step_metadata.copy(), + model_name=self.model.cfg.model_name, + ) + + def clear(self) -> None: + """Clear all captured activations and metadata.""" + self.step_activations.clear() + self.step_metadata.clear() + self._current_step_cache = {} + + +def automatic_activation_capture(model: HookedTransformer) -> ActivationCapture: + """Create an ActivationCapture that automatically records steps during generation. + + This wraps the model's generate method to automatically call start_step() and + end_step() around each generation, making it seamless to use with ARES environments. + + Args: + model: HookedTransformer to capture activations from. + + Returns: + ActivationCapture instance with automatic step tracking. 
+ + Example: + ```python + model = HookedTransformer.from_pretrained("gpt2-small") + + with automatic_activation_capture(model) as capture: + client = HookedTransformerLLMClient(model=model, model_name="gpt2-small") + # Now activations are captured automatically during client calls + async with env: + ts = await env.reset() + while not ts.last(): + response = await client(ts.observation) + ts = await env.step(response) + + trajectory = capture.get_trajectory() + ``` + """ + capture = ActivationCapture(model) + + # Wrap model.generate to auto-capture + original_generate = model.generate + + def wrapped_generate(*args, **kwargs): + capture.start_step() + result = original_generate(*args, **kwargs) + capture.end_step() + return result + + model.generate = wrapped_generate + + return capture diff --git a/src/ares/contrib/mech_interp/hook_utils.py b/src/ares/contrib/mech_interp/hook_utils.py new file mode 100644 index 0000000..b950cf0 --- /dev/null +++ b/src/ares/contrib/mech_interp/hook_utils.py @@ -0,0 +1,307 @@ +"""Hook utilities for mechanistic interpretability interventions.""" + +from collections.abc import Callable +from typing import Any + +import torch +from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookPoint + + +def create_zero_ablation_hook( + positions: list[int] | None = None, + heads: list[int] | None = None, +) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]: + """Create a hook that zeros out specific positions or attention heads. + + Args: + positions: Optional list of token positions to ablate. If None, ablates all. + heads: Optional list of attention head indices to ablate (for attention patterns). + + Returns: + Hook function that performs zero ablation. + + Example: + ```python + # Ablate positions 5-10 in layer 0 residual stream + hook = create_zero_ablation_hook(positions=list(range(5, 11))) + model.run_with_hooks( + tokens, + fwd_hooks=[("blocks.0.hook_resid_post", hook)] + ) + ``` + """ + + def zero_ablation_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 + ablated = activation.clone() + + if heads is not None: + # For attention patterns: [batch, head, query_pos, key_pos] + if len(ablated.shape) == 4: + ablated[:, heads, :, :] = 0.0 + # For attention outputs: [batch, pos, head_index, d_head] + elif len(ablated.shape) == 4: + ablated[:, :, heads, :] = 0.0 + elif positions is not None: + # For residual stream or other positional activations + ablated[:, positions, :] = 0.0 + else: + # Ablate everything + ablated = torch.zeros_like(ablated) + + return ablated + + return zero_ablation_hook + + +def create_path_patching_hook( + clean_activation: torch.Tensor, + positions: list[int] | None = None, +) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]: + """Create a hook for activation patching (path patching). + + This replaces activations from a corrupted run with those from a clean run, + enabling causal analysis of information flow. + + Args: + clean_activation: Activation tensor from the clean run to patch in. + positions: Optional list of positions to patch. If None, patches all positions. + + Returns: + Hook function that performs activation patching. 
+ + Example: + ```python + # First, run on clean input and cache activations + clean_cache, _ = model.run_with_cache(clean_tokens) + clean_resid = clean_cache["blocks.0.hook_resid_post"] + + # Then run on corrupted input with patching + hook = create_path_patching_hook(clean_resid, positions=[5, 6, 7]) + corrupted_logits = model.run_with_hooks( + corrupted_tokens, + fwd_hooks=[("blocks.0.hook_resid_post", hook)] + ) + ``` + """ + + def path_patching_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 + patched = activation.clone() + + if positions is not None: + patched[:, positions, :] = clean_activation[:, positions, :] + else: + patched = clean_activation.clone() + + return patched + + return path_patching_hook + + +def create_mean_ablation_hook( + mean_activation: torch.Tensor | None = None, + positions: list[int] | None = None, +) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]: + """Create a hook that replaces activations with their mean. + + Mean ablation is often more realistic than zero ablation as it preserves + the scale of activations. + + Args: + mean_activation: Pre-computed mean activation. If None, computes mean on-the-fly. + positions: Optional list of positions to ablate. + + Returns: + Hook function that performs mean ablation. + """ + + def mean_ablation_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 + ablated = activation.clone() + + # Compute mean across batch and position dimensions if not provided + mean = activation.mean(dim=(0, 1), keepdim=True) if mean_activation is None else mean_activation + + if positions is not None: + ablated[:, positions, :] = mean.expand_as(ablated[:, positions, :]) + else: + ablated = mean.expand_as(ablated) + + return ablated + + return mean_ablation_hook + + +class InterventionManager: + """Manager for applying multiple interventions during agent execution. + + This class helps coordinate multiple hook-based interventions across an agent + trajectory, making it easy to study causal effects of different model components. + + Example: + ```python + model = HookedTransformer.from_pretrained("gpt2-small") + client = HookedTransformerLLMClient(model=model, model_name="gpt2-small") + + # Create intervention manager + manager = InterventionManager(model) + + # Add interventions + manager.add_intervention( + hook_name="blocks.0.attn.hook_pattern", + hook_fn=create_zero_ablation_hook(heads=[0, 1]), + description="Ablate attention heads 0-1 in layer 0" + ) + + # Run with interventions active + with manager: + async with env: + ts = await env.reset() + while not ts.last(): + response = await client(ts.observation) + ts = await env.step(response) + ``` + """ + + def __init__(self, model: HookedTransformer): + """Initialize intervention manager. + + Args: + model: HookedTransformer to apply interventions to. + """ + self.model = model + self.interventions: list[dict[str, Any]] = [] + self._active_handles: list[Any] = [] + + def add_intervention( + self, + hook_name: str, + hook_fn: Callable[[torch.Tensor, HookPoint], torch.Tensor], + description: str = "", + apply_at_steps: list[int] | None = None, + ) -> None: + """Add an intervention to apply during execution. + + Args: + hook_name: Name of the hook point (e.g., "blocks.0.hook_resid_post"). + hook_fn: Hook function to apply. + description: Human-readable description of the intervention. + apply_at_steps: Optional list of step indices when to apply this intervention. + If None, applies at all steps. 
+ """ + self.interventions.append( + { + "hook_name": hook_name, + "hook_fn": hook_fn, + "description": description, + "apply_at_steps": apply_at_steps, + "step_count": 0, + } + ) + + def clear_interventions(self) -> None: + """Remove all interventions.""" + self.interventions.clear() + + def __enter__(self) -> "InterventionManager": + """Enter context manager and activate interventions.""" + for intervention in self.interventions: + hook_point = self.model.hook_dict[intervention["hook_name"]] + + # Wrap hook_fn to track step count and properly capture loop variables + def make_wrapped_hook(interv): + # Capture these from the intervention dict to avoid loop variable binding issues + hook_fn = interv["hook_fn"] + steps = interv["apply_at_steps"] + + def wrapped_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: + # Check if we should apply this intervention at current step + if steps is None or interv["step_count"] in steps: + return hook_fn(activation, hook) + return activation + + return wrapped_hook + + handle = hook_point.add_hook(make_wrapped_hook(intervention)) + self._active_handles.append(handle) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context manager and deactivate interventions.""" + for handle in self._active_handles: + if handle is not None: + handle.remove() + self._active_handles.clear() + + def increment_step(self) -> None: + """Increment the step counter for all interventions. + + Call this between agent steps to track which step you're on. + """ + for intervention in self.interventions: + intervention["step_count"] += 1 + + def get_intervention_summary(self) -> str: + """Get a summary of all active interventions. + + Returns: + Human-readable summary string. + """ + if not self.interventions: + return "No interventions active" + + lines = ["Active interventions:"] + for i, interv in enumerate(self.interventions, 1): + desc = interv["description"] or "No description" + hook = interv["hook_name"] + steps = interv["apply_at_steps"] + step_str = f"steps {steps}" if steps else "all steps" + lines.append(f" {i}. {desc} ({hook}, {step_str})") + + return "\n".join(lines) + + +def create_attention_knockout_hook( + source_positions: list[int], + target_positions: list[int], +) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]: + """Create a hook that prevents attention from source to target positions. + + This is useful for studying information flow by blocking specific attention patterns. + + Args: + source_positions: Token positions that should not attend to targets. + target_positions: Token positions that should not be attended to. + + Returns: + Hook function that modifies attention patterns. 
+ + Example: + ```python + # Prevent position 10 from attending to positions 0-5 + hook = create_attention_knockout_hook( + source_positions=[10], + target_positions=list(range(6)) + ) + model.run_with_hooks( + tokens, + fwd_hooks=[("blocks.0.attn.hook_pattern", hook)] + ) + ``` + """ + + def attention_knockout_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 + # Attention pattern shape: [batch, head, query_pos, key_pos] + modified = activation.clone() + + # Set attention weights to zero (or very small value) + for source_pos in source_positions: + for target_pos in target_positions: + modified[:, :, source_pos, target_pos] = 0.0 + + # Renormalize attention patterns (they should sum to 1 across key dimension) + modified = modified / (modified.sum(dim=-1, keepdim=True) + 1e-10) + + return modified + + return attention_knockout_hook diff --git a/src/ares/contrib/mech_interp/hooked_transformer_client.py b/src/ares/contrib/mech_interp/hooked_transformer_client.py new file mode 100644 index 0000000..d9aeb7f --- /dev/null +++ b/src/ares/contrib/mech_interp/hooked_transformer_client.py @@ -0,0 +1,215 @@ +"""LLM client implementation using TransformerLens HookedTransformer.""" + +from collections.abc import Callable, Sequence +import dataclasses +import time +from typing import Any +import uuid + +import openai.types.chat.chat_completion +import openai.types.chat.chat_completion_message +import openai.types.completion_usage +import torch +from transformer_lens import HookedTransformer + +from ares.llms import llm_clients + + +@dataclasses.dataclass +class HookedTransformerLLMClient: + """LLM client that uses TransformerLens HookedTransformer for inference. + + This client enables mechanistic interpretability research by providing access to + intermediate activations and allowing hook-based interventions during agent execution. + + Args: + model: A TransformerLens HookedTransformer instance. + model_name: Name of the model (for logging/identification). + max_new_tokens: Maximum number of tokens to generate per completion. + generation_kwargs: Additional keyword arguments passed to model.generate(). + format_messages_fn: Optional function to convert chat messages to model input. + If None, uses a simple concatenation of message contents. 
+ + Example: + ```python + from transformer_lens import HookedTransformer + + model = HookedTransformer.from_pretrained("gpt2-small") + client = HookedTransformerLLMClient( + model=model, + model_name="gpt2-small", + max_new_tokens=512 + ) + + # Use with ARES environments + async with env: + ts = await env.reset() + while not ts.last(): + response = await client(ts.observation) + ts = await env.step(response) + ``` + """ + + model: HookedTransformer + model_name: str + # TODO: Identify better default max_new_tokens size + max_new_tokens: int = 2048 + generation_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + _format_messages_fn: Callable[[Sequence[Any]], str] | None = dataclasses.field(default=None, repr=False) + + @property + def format_messages_fn(self) -> Callable[[Sequence[Any]], str]: + """Get the message formatting function.""" + if self._format_messages_fn is None: + return self._default_format_messages + return self._format_messages_fn + + @staticmethod + def _default_format_messages(messages: Sequence[Any]) -> str: + """Default message formatter that concatenates all message contents.""" + formatted_parts = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + formatted_parts.append(f"{role.upper()}: {content}") + return "\n\n".join(formatted_parts) + + async def __call__(self, request: llm_clients.LLMRequest) -> llm_clients.LLMResponse: + """Generate a completion using the HookedTransformer. + + Args: + request: LLM request containing messages and optional temperature. + + Returns: + LLM response with chat completion and cost information. + """ + # Format messages into text + messages_list = list(request.messages) + input_text = self.format_messages_fn(messages_list) + + # Tokenize input + input_ids = self.model.to_tokens(input_text, prepend_bos=True) + num_input_tokens = input_ids.shape[-1] + + # TODO: Need to support various truncation methods + # Truncate if input + max_new_tokens would exceed model's context window + max_position = self.model.cfg.n_ctx + if num_input_tokens + self.max_new_tokens > max_position: + # Leave room for generation + max_input_tokens = max_position - self.max_new_tokens + input_ids = input_ids[:, :max_input_tokens] + num_input_tokens = input_ids.shape[-1] + + # Prepare generation kwargs + gen_kwargs = { + "max_new_tokens": self.max_new_tokens, + **self.generation_kwargs, + } + + # TODO: This should be more generic - why temperature specifically? + # Add temperature if specified + if request.temperature is not None: + gen_kwargs["temperature"] = request.temperature + + # TODO: We should make use of the `__call__(..., return_type="logits | loss")` here instead? 
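+        # Any hooks registered on the model's hook points (e.g. by
+        # ActivationCapture or an InterventionManager) also run during this
+        # generate() call, which is what enables per-step capture and
+        # intervention from within the agent loop.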
+ # Generate completion + # Note: HookedTransformer.generate returns full sequence including input + with torch.no_grad(): + outputs = self.model.generate( + input_ids, + **gen_kwargs, + ) + + # Extract only the generated tokens + num_output_tokens = outputs.shape[-1] - num_input_tokens + output_ids = outputs[0, num_input_tokens:] + + # Decode output + output_text = self.model.to_string(output_ids) + + # Construct OpenAI-compatible response + chat_completion = openai.types.chat.chat_completion.ChatCompletion( + id=str(uuid.uuid4()), + choices=[ + openai.types.chat.chat_completion.Choice( + message=openai.types.chat.chat_completion_message.ChatCompletionMessage( + content=output_text, + role="assistant", + ), + finish_reason="stop", + index=0, + ) + ], + created=int(time.time()), + model=self.model_name, + object="chat.completion", + usage=openai.types.completion_usage.CompletionUsage( + prompt_tokens=num_input_tokens, + completion_tokens=num_output_tokens, + total_tokens=num_input_tokens + num_output_tokens, + ), + ) + + return llm_clients.LLMResponse( + chat_completion_response=chat_completion, + cost=0.0, # Local inference has no cost + ) + + +def create_hooked_transformer_client_with_chat_template( + model: HookedTransformer, + tokenizer: Any, + model_name: str, + max_new_tokens: int = 2048, + **generation_kwargs: Any, +) -> HookedTransformerLLMClient: + """Create a HookedTransformerLLMClient with proper chat template formatting. + + This factory function creates a client that uses the tokenizer's chat template + (like Qwen2.5, Llama, etc.) to properly format messages. + + Args: + model: TransformerLens HookedTransformer instance. + tokenizer: HuggingFace tokenizer with apply_chat_template method. + model_name: Name of the model. + max_new_tokens: Maximum tokens to generate. + **generation_kwargs: Additional arguments for generation. + + Returns: + Configured HookedTransformerLLMClient instance. 
+ + Example: + ```python + from transformer_lens import HookedTransformer + from transformers import AutoTokenizer + + model = HookedTransformer.from_pretrained("gpt2-small") + tokenizer = AutoTokenizer.from_pretrained("gpt2") + + client = create_hooked_transformer_client_with_chat_template( + model=model, + tokenizer=tokenizer, + model_name="gpt2-small", + max_new_tokens=512, + temperature=0.7, + ) + ``` + """ + + def format_with_chat_template(messages: Sequence[Any]) -> str: + """Format messages using tokenizer's chat template.""" + # Apply chat template without tokenization (just get the text) + formatted = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + ) + return formatted + + return HookedTransformerLLMClient( + model=model, + model_name=model_name, + max_new_tokens=max_new_tokens, + generation_kwargs=generation_kwargs, + _format_messages_fn=format_with_chat_template, + ) From 90c5138168860ccb89c29fed9039ffdc73e1dff5 Mon Sep 17 00:00:00 2001 From: Ryan Smith Date: Fri, 23 Jan 2026 14:02:06 -0800 Subject: [PATCH 2/5] In progress edits/improvements --- examples/03_mech_interp_hooked_transformer.py | 1 - src/ares/contrib/mech_interp/README.md | 108 +++--- .../contrib/mech_interp/activation_capture.py | 7 +- src/ares/contrib/mech_interp/hook_utils.py | 310 +----------------- .../mech_interp/hooked_transformer_client.py | 92 ++++-- 5 files changed, 131 insertions(+), 387 deletions(-) diff --git a/examples/03_mech_interp_hooked_transformer.py b/examples/03_mech_interp_hooked_transformer.py index 4e17b19..bf56c21 100644 --- a/examples/03_mech_interp_hooked_transformer.py +++ b/examples/03_mech_interp_hooked_transformer.py @@ -41,7 +41,6 @@ async def main(): # gpt2-small has max context of 1024 tokens, so we need to be conservative client = HookedTransformerLLMClient( model=model, - model_name="gpt2-small", max_new_tokens=128, # Keep this small to avoid context overflow ) diff --git a/src/ares/contrib/mech_interp/README.md b/src/ares/contrib/mech_interp/README.md index 12f2945..e0674be 100644 --- a/src/ares/contrib/mech_interp/README.md +++ b/src/ares/contrib/mech_interp/README.md @@ -12,8 +12,9 @@ Traditional mechanistic interpretability focuses on static, single-step analysis ARES enables **trajectory-level mechanistic interpretability** by: 1. Capturing activations across entire agent episodes 2. Studying how internal representations evolve during multi-step reasoning -3. Identifying critical moments where interventions significantly alter outcomes -4. Analyzing information flow across dozens of inference steps +3. Identifying critical moments where interventions significantly alter episode-level outcomes +4. Seeing how activations differ across different agent frameworks for the same task +5. You tell us! 
## Quick Start @@ -21,7 +22,9 @@ ARES enables **trajectory-level mechanistic interpretability** by: ```bash # Install ARES with mech_interp group (includes TransformerLens) -uv sync --group mech_interp +uv add ares[mech-interp] +# or with pip +pip install ares[mech-interp] ``` ### Basic Example @@ -35,7 +38,7 @@ from ares.environments import swebench_env async def main(): # Load model model = HookedTransformer.from_pretrained("gpt2-small") - client = HookedTransformerLLMClient(model=model, model_name="gpt2-small") + client = HookedTransformerLLMClient(model=model) # Run agent and capture activations tasks = swebench_env.swebench_verified_tasks()[:1] @@ -70,7 +73,6 @@ from ares.contrib.mech_interp import HookedTransformerLLMClient model = HookedTransformer.from_pretrained("gpt2-medium") client = HookedTransformerLLMClient( model=model, - model_name="gpt2-medium", max_new_tokens=1024, generation_kwargs={"temperature": 0.7} ) @@ -88,7 +90,6 @@ model = HookedTransformer.from_pretrained("Qwen/Qwen2.5-3B-Instruct") client = create_hooked_transformer_client_with_chat_template( model=model, tokenizer=tokenizer, - model_name="Qwen/Qwen2.5-3B-Instruct", max_new_tokens=2048, ) ``` @@ -150,36 +151,51 @@ with automatic_activation_capture(model) as capture: ### 3. InterventionManager +TODO: This needs to rework, can just use step nums with better identifying of interesting behavior + Apply causal interventions during agent execution to study model behavior. ```python +from transformer_lens import HookedTransformer from ares.contrib.mech_interp import ( - InterventionManager, - create_zero_ablation_hook, - create_path_patching_hook, + HookedTransformerLLMClient, + FullyObservableState, ) -manager = InterventionManager(model) +async def has_readme(state: FullyObservableState) -> str | None: + # Function to only runs hooks when the environment contains a README + cmd_output = await state.container.exec_run("ls", workdir="/path/to/app") + if "README.md" in cmd_output.output and state.step_num > 1: + return "blocks.1.attn.hook_pattern" + else: + return None -# Ablate attention heads -manager.add_intervention( - hook_name="blocks.0.attn.hook_result", - hook_fn=create_zero_ablation_hook(heads=[0, 1, 2]), - description="Ablate heads 0-2 in layer 0", - apply_at_steps=[5, 6, 7] # Optional: only at specific steps +model = HookedTransformer.from_pretrained("gpt2-medium") +client = HookedTransformerLLMClient( + model=model, + fwd_hooks=[ + # The mean ablation hook is defined exactly the same as classic transformer_lens + ("blocks.1.attn.hook_pattern@step2", mean_ablation_hook), + ], + max_new_tokens=1024, + generation_kwargs={"temperature": 0.7} ) -# Run with interventions -with manager: - async with env: - ts = await env.reset() - step_count = 0 - while not ts.last(): - action = await client(ts.observation) - ts = await env.step(action) - step_count += 1 - manager.increment_step() # Track which step we're on +# Run with interventions +async with env: + ts = await env.reset() + + while not ts.last(): + action = await client( + ts.observation, + # Should also include the env and timestep here so that + # we can reference the env in the has_readme function. + # Do not need to pass if hooks do not depend on state + env=env, + timestep=ts, + ) + ts = await env.step(action) ``` ### 4. 
Hook Utilities @@ -341,7 +357,7 @@ models = [ trajectories = [] for model in models: - client = HookedTransformerLLMClient(model=model, model_name=model.cfg.model_name) + client = HookedTransformerLLMClient(model=model) with ActivationCapture(model) as capture: # Run same task @@ -364,7 +380,7 @@ from ares.environments import swebench_env async def find_code_understanding_circuits(): model = HookedTransformer.from_pretrained("gpt2-medium") - client = HookedTransformerLLMClient(model=model, model_name="gpt2-medium") + client = HookedTransformerLLMClient(model=model) # 1. Collect baseline trajectory on a task tasks = swebench_env.swebench_verified_tasks() @@ -417,25 +433,14 @@ async def find_code_understanding_circuits(): attn_pattern = baseline_trajectory.get_activation(step, f"blocks.{layer}.attn.hook_pattern") visualize_attention(attn_pattern[:, head, :, :]) -# Run analysis -asyncio.run(find_code_understanding_circuits()) ``` -## Best Practices - -1. **Start Small**: Use small models (gpt2-small, gpt2-medium) for initial exploration -2. **Limit Steps**: Use `max_steps` during development to avoid long-running experiments -3. **Save Often**: Save trajectories immediately after capture for later analysis -4. **Filter Hooks**: Use `hook_filter` in ActivationCapture to reduce memory usage -5. **GPU Management**: Move activations to CPU after capture: `.detach().cpu()` -6. **Batch Analysis**: Process multiple trajectories in batch when possible - ## Performance Tips **Memory Optimization:** ```python -# Only capture specific hooks -capture = ActivationCapture( +# Only capture specific activations +ActivationCapture( model, hook_filter=lambda name: "attn.hook_pattern" in name or "hook_resid" in name ) @@ -454,28 +459,13 @@ model = HookedTransformer.from_pretrained("gpt2-small", device="cuda") # Reduce max_new_tokens during ablation studies client = HookedTransformerLLMClient(model=model, max_new_tokens=256) - -# Use torch.no_grad() (already done in client, but good to remember) -with torch.no_grad(): - ... 
``` -## Comparison to Traditional MI - -| Traditional MI | ARES + Mech Interp | -|----------------|---------------------| -| Single prompts | Multi-step trajectories (50-100+ steps) | -| Static analysis | Dynamic, evolving state | -| Local causality | Long-horizon dependencies | -| Fixed context | Context grows with episode | -| Immediate outcomes | Delayed rewards | - ## Resources - [TransformerLens Documentation](https://neelnanda-io.github.io/TransformerLens/) -- [Anthropic's Circuits Thread](https://transformer-circuits.pub/) -- [Example Notebook: Trajectory-Level Analysis](./notebooks/trajectory_analysis.ipynb) *(coming soon)* -- [Blog Post: Beyond Static MI](https://withmartian.com/post/beyond-static-mechanistic-interpretability-agentic-long-horizon-tasks-as-the-next-frontier) +- Example Notebook: Trajectory-Level Analysis *(coming soon)* +- [Blog Post: Beyond Static Mechanistic Interpretability](https://withmartian.com/post/beyond-static-mechanistic-interpretability-agentic-long-horizon-tasks-as-the-next-frontier) ## Citation @@ -486,6 +476,6 @@ If you use this module in your research, please cite: title = {ARES Mechanistic Interpretability Module}, author = {Martian}, year = {2025}, - url = {https://github.com/anthropics/ares} + url = {https://github.com/withmartian/ares} } ``` diff --git a/src/ares/contrib/mech_interp/activation_capture.py b/src/ares/contrib/mech_interp/activation_capture.py index 0d954a9..50b6b4d 100644 --- a/src/ares/contrib/mech_interp/activation_capture.py +++ b/src/ares/contrib/mech_interp/activation_capture.py @@ -126,7 +126,7 @@ class ActivationCapture: from ares.contrib.mech_interp import ActivationCapture, HookedTransformerLLMClient model = HookedTransformer.from_pretrained("gpt2-small") - client = HookedTransformerLLMClient(model=model, model_name="gpt2-small") + client = HookedTransformerLLMClient(model=model) # Capture activations during episode with ActivationCapture(model) as capture: @@ -158,14 +158,13 @@ def __init__( hook_filter: Optional function to filter which hooks to capture. If None, captures all hooks. Example: lambda name: "attn" in name """ + # TODO: Should we store logits and loss as well? By default or via flag? self.model = model self.hook_filter = hook_filter self.step_activations: list[ActivationCache] = [] self.step_metadata: list[dict[str, Any]] = [] - # TODO: What is the purpose of hook_handles? Prob should remove self._hook_handles: list[Any] = [] - # TODO: Does this need to be a context manager?? 
def __enter__(self) -> "ActivationCapture": """Enter context manager and start capturing activations.""" # Register hooks to capture activations @@ -254,7 +253,7 @@ def automatic_activation_capture(model: HookedTransformer) -> ActivationCapture: model = HookedTransformer.from_pretrained("gpt2-small") with automatic_activation_capture(model) as capture: - client = HookedTransformerLLMClient(model=model, model_name="gpt2-small") + client = HookedTransformerLLMClient(model=model) # Now activations are captured automatically during client calls async with env: ts = await env.reset() diff --git a/src/ares/contrib/mech_interp/hook_utils.py b/src/ares/contrib/mech_interp/hook_utils.py index b950cf0..ff95661 100644 --- a/src/ares/contrib/mech_interp/hook_utils.py +++ b/src/ares/contrib/mech_interp/hook_utils.py @@ -1,307 +1,13 @@ """Hook utilities for mechanistic interpretability interventions.""" -from collections.abc import Callable -from typing import Any +import dataclasses -import torch -from transformer_lens import HookedTransformer -from transformer_lens.hook_points import HookPoint +from ares.containers import containers +from ares.environments.base import TimeStep -def create_zero_ablation_hook( - positions: list[int] | None = None, - heads: list[int] | None = None, -) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]: - """Create a hook that zeros out specific positions or attention heads. - - Args: - positions: Optional list of token positions to ablate. If None, ablates all. - heads: Optional list of attention head indices to ablate (for attention patterns). - - Returns: - Hook function that performs zero ablation. - - Example: - ```python - # Ablate positions 5-10 in layer 0 residual stream - hook = create_zero_ablation_hook(positions=list(range(5, 11))) - model.run_with_hooks( - tokens, - fwd_hooks=[("blocks.0.hook_resid_post", hook)] - ) - ``` - """ - - def zero_ablation_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 - ablated = activation.clone() - - if heads is not None: - # For attention patterns: [batch, head, query_pos, key_pos] - if len(ablated.shape) == 4: - ablated[:, heads, :, :] = 0.0 - # For attention outputs: [batch, pos, head_index, d_head] - elif len(ablated.shape) == 4: - ablated[:, :, heads, :] = 0.0 - elif positions is not None: - # For residual stream or other positional activations - ablated[:, positions, :] = 0.0 - else: - # Ablate everything - ablated = torch.zeros_like(ablated) - - return ablated - - return zero_ablation_hook - - -def create_path_patching_hook( - clean_activation: torch.Tensor, - positions: list[int] | None = None, -) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]: - """Create a hook for activation patching (path patching). - - This replaces activations from a corrupted run with those from a clean run, - enabling causal analysis of information flow. - - Args: - clean_activation: Activation tensor from the clean run to patch in. - positions: Optional list of positions to patch. If None, patches all positions. - - Returns: - Hook function that performs activation patching. 
- - Example: - ```python - # First, run on clean input and cache activations - clean_cache, _ = model.run_with_cache(clean_tokens) - clean_resid = clean_cache["blocks.0.hook_resid_post"] - - # Then run on corrupted input with patching - hook = create_path_patching_hook(clean_resid, positions=[5, 6, 7]) - corrupted_logits = model.run_with_hooks( - corrupted_tokens, - fwd_hooks=[("blocks.0.hook_resid_post", hook)] - ) - ``` - """ - - def path_patching_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 - patched = activation.clone() - - if positions is not None: - patched[:, positions, :] = clean_activation[:, positions, :] - else: - patched = clean_activation.clone() - - return patched - - return path_patching_hook - - -def create_mean_ablation_hook( - mean_activation: torch.Tensor | None = None, - positions: list[int] | None = None, -) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]: - """Create a hook that replaces activations with their mean. - - Mean ablation is often more realistic than zero ablation as it preserves - the scale of activations. - - Args: - mean_activation: Pre-computed mean activation. If None, computes mean on-the-fly. - positions: Optional list of positions to ablate. - - Returns: - Hook function that performs mean ablation. - """ - - def mean_ablation_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 - ablated = activation.clone() - - # Compute mean across batch and position dimensions if not provided - mean = activation.mean(dim=(0, 1), keepdim=True) if mean_activation is None else mean_activation - - if positions is not None: - ablated[:, positions, :] = mean.expand_as(ablated[:, positions, :]) - else: - ablated = mean.expand_as(ablated) - - return ablated - - return mean_ablation_hook - - -class InterventionManager: - """Manager for applying multiple interventions during agent execution. - - This class helps coordinate multiple hook-based interventions across an agent - trajectory, making it easy to study causal effects of different model components. - - Example: - ```python - model = HookedTransformer.from_pretrained("gpt2-small") - client = HookedTransformerLLMClient(model=model, model_name="gpt2-small") - - # Create intervention manager - manager = InterventionManager(model) - - # Add interventions - manager.add_intervention( - hook_name="blocks.0.attn.hook_pattern", - hook_fn=create_zero_ablation_hook(heads=[0, 1]), - description="Ablate attention heads 0-1 in layer 0" - ) - - # Run with interventions active - with manager: - async with env: - ts = await env.reset() - while not ts.last(): - response = await client(ts.observation) - ts = await env.step(response) - ``` - """ - - def __init__(self, model: HookedTransformer): - """Initialize intervention manager. - - Args: - model: HookedTransformer to apply interventions to. - """ - self.model = model - self.interventions: list[dict[str, Any]] = [] - self._active_handles: list[Any] = [] - - def add_intervention( - self, - hook_name: str, - hook_fn: Callable[[torch.Tensor, HookPoint], torch.Tensor], - description: str = "", - apply_at_steps: list[int] | None = None, - ) -> None: - """Add an intervention to apply during execution. - - Args: - hook_name: Name of the hook point (e.g., "blocks.0.hook_resid_post"). - hook_fn: Hook function to apply. - description: Human-readable description of the intervention. - apply_at_steps: Optional list of step indices when to apply this intervention. - If None, applies at all steps. 
- """ - self.interventions.append( - { - "hook_name": hook_name, - "hook_fn": hook_fn, - "description": description, - "apply_at_steps": apply_at_steps, - "step_count": 0, - } - ) - - def clear_interventions(self) -> None: - """Remove all interventions.""" - self.interventions.clear() - - def __enter__(self) -> "InterventionManager": - """Enter context manager and activate interventions.""" - for intervention in self.interventions: - hook_point = self.model.hook_dict[intervention["hook_name"]] - - # Wrap hook_fn to track step count and properly capture loop variables - def make_wrapped_hook(interv): - # Capture these from the intervention dict to avoid loop variable binding issues - hook_fn = interv["hook_fn"] - steps = interv["apply_at_steps"] - - def wrapped_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: - # Check if we should apply this intervention at current step - if steps is None or interv["step_count"] in steps: - return hook_fn(activation, hook) - return activation - - return wrapped_hook - - handle = hook_point.add_hook(make_wrapped_hook(intervention)) - self._active_handles.append(handle) - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Exit context manager and deactivate interventions.""" - for handle in self._active_handles: - if handle is not None: - handle.remove() - self._active_handles.clear() - - def increment_step(self) -> None: - """Increment the step counter for all interventions. - - Call this between agent steps to track which step you're on. - """ - for intervention in self.interventions: - intervention["step_count"] += 1 - - def get_intervention_summary(self) -> str: - """Get a summary of all active interventions. - - Returns: - Human-readable summary string. - """ - if not self.interventions: - return "No interventions active" - - lines = ["Active interventions:"] - for i, interv in enumerate(self.interventions, 1): - desc = interv["description"] or "No description" - hook = interv["hook_name"] - steps = interv["apply_at_steps"] - step_str = f"steps {steps}" if steps else "all steps" - lines.append(f" {i}. {desc} ({hook}, {step_str})") - - return "\n".join(lines) - - -def create_attention_knockout_hook( - source_positions: list[int], - target_positions: list[int], -) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]: - """Create a hook that prevents attention from source to target positions. - - This is useful for studying information flow by blocking specific attention patterns. - - Args: - source_positions: Token positions that should not attend to targets. - target_positions: Token positions that should not be attended to. - - Returns: - Hook function that modifies attention patterns. 
- - Example: - ```python - # Prevent position 10 from attending to positions 0-5 - hook = create_attention_knockout_hook( - source_positions=[10], - target_positions=list(range(6)) - ) - model.run_with_hooks( - tokens, - fwd_hooks=[("blocks.0.attn.hook_pattern", hook)] - ) - ``` - """ - - def attention_knockout_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 - # Attention pattern shape: [batch, head, query_pos, key_pos] - modified = activation.clone() - - # Set attention weights to zero (or very small value) - for source_pos in source_positions: - for target_pos in target_positions: - modified[:, :, source_pos, target_pos] = 0.0 - - # Renormalize attention patterns (they should sum to 1 across key dimension) - modified = modified / (modified.sum(dim=-1, keepdim=True) + 1e-10) - - return modified - - return attention_knockout_hook +@dataclasses.dataclass +class FullyObservableState: + timestep: TimeStep | None + container: containers.Container | None + step_num: int diff --git a/src/ares/contrib/mech_interp/hooked_transformer_client.py b/src/ares/contrib/mech_interp/hooked_transformer_client.py index d9aeb7f..35ad752 100644 --- a/src/ares/contrib/mech_interp/hooked_transformer_client.py +++ b/src/ares/contrib/mech_interp/hooked_transformer_client.py @@ -2,19 +2,34 @@ from collections.abc import Callable, Sequence import dataclasses +import inspect import time -from typing import Any +from typing import Any, Protocol, Union, runtime_checkable import uuid import openai.types.chat.chat_completion import openai.types.chat.chat_completion_message import openai.types.completion_usage import torch -from transformer_lens import HookedTransformer +import transformer_lens +from ares.contrib.mech_interp import hook_utils +from ares.environments import base as environments from ares.llms import llm_clients +HookNameFn = Callable[[str], str] + + +@runtime_checkable +class StateIdFn(Protocol): + """A function that returns the `name` param of `add_hook`/`run_with_hooks` if the hook should be applied given + the current state - otherwise `None` if it should not. + """ + async def __call__(self, state: hook_utils.FullyObservableState) -> str | Callable[[str], bool] | None: + ... + + @dataclasses.dataclass class HookedTransformerLLMClient: """LLM client that uses TransformerLens HookedTransformer for inference. @@ -24,7 +39,6 @@ class HookedTransformerLLMClient: Args: model: A TransformerLens HookedTransformer instance. - model_name: Name of the model (for logging/identification). max_new_tokens: Maximum number of tokens to generate per completion. generation_kwargs: Additional keyword arguments passed to model.generate(). format_messages_fn: Optional function to convert chat messages to model input. @@ -37,7 +51,6 @@ class HookedTransformerLLMClient: model = HookedTransformer.from_pretrained("gpt2-small") client = HookedTransformerLLMClient( model=model, - model_name="gpt2-small", max_new_tokens=512 ) @@ -50,8 +63,9 @@ class HookedTransformerLLMClient: ``` """ - model: HookedTransformer - model_name: str + model: transformer_lens.HookedTransformer + # QUESTION: Is it better to include the "state" object in the name/id fn or in the hook fn? 
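+    # Optional (name_or_id_fn, hook_fn) pairs. The first element is either a hook name (or name
+    # filter) passed straight to `add_hook`, or an async StateIdFn that resolves it from the current
+    # FullyObservableState; hooks are registered in `_call_with_hooks` before each generation.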
+ fwd_hooks: list[tuple[Union[str, HookNameFn, StateIdFn], transformer_lens.hook_points.HookFunction]] | None = None # TODO: Identify better default max_new_tokens size max_new_tokens: int = 2048 generation_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -74,7 +88,44 @@ def _default_format_messages(messages: Sequence[Any]) -> str: formatted_parts.append(f"{role.upper()}: {content}") return "\n\n".join(formatted_parts) - async def __call__(self, request: llm_clients.LLMRequest) -> llm_clients.LLMResponse: + async def _call_with_hooks( + self, + input_ids: torch.Tensor, + state: hook_utils.FullyObservableState, + **gen_kwargs: Any, + ) -> torch.Tensor: + self.model.reset_hooks(direction="fwd", including_permanent=False) + + if self.fwd_hooks is not None: + for name_or_id_fn, hook_fn in self.fwd_hooks: + # Check if this is a StateIdFn, or if it is the standard format to run in all states + if inspect.iscoroutinefunction(name_or_id_fn): + name = await name_or_id_fn(state) + else: + name = name_or_id_fn + + self.model.add_hook(name, hook_fn, is_permanent=False) # type: ignore + + with torch.no_grad(): + # TODO: Should we make use of the `__call__(..., return_type="logits | loss")` here instead? + # Generate completion + # Note: HookedTransformer.generate returns full sequence including input + outputs = self.model.generate( + input_ids, + **gen_kwargs, + ) + + self.model.reset_hooks(direction="fwd", including_permanent=False) + + assert isinstance(outputs, torch.Tensor) # typing + return outputs + + async def __call__( + self, + request: llm_clients.LLMRequest, + env: environments.base.CodeBaseEnv | None = None, + timestep: environments.base.TimeStep | None = None, + ) -> llm_clients.LLMResponse: """Generate a completion using the HookedTransformer. Args: @@ -111,14 +162,16 @@ async def __call__(self, request: llm_clients.LLMRequest) -> llm_clients.LLMResp if request.temperature is not None: gen_kwargs["temperature"] = request.temperature - # TODO: We should make use of the `__call__(..., return_type="logits | loss")` here instead? 
- # Generate completion - # Note: HookedTransformer.generate returns full sequence including input - with torch.no_grad(): - outputs = self.model.generate( - input_ids, - **gen_kwargs, - ) + outputs = await self._call_with_hooks( + input_ids, + state=hook_utils.FullyObservableState( + timestep=timestep, + # TODO: Figure out typing here + container=env.container if env is not None else None, + step_num=0 # TODO: How to calculate?, + ), + **gen_kwargs, + ) # Extract only the generated tokens num_output_tokens = outputs.shape[-1] - num_input_tokens @@ -126,6 +179,7 @@ async def __call__(self, request: llm_clients.LLMRequest) -> llm_clients.LLMResp # Decode output output_text = self.model.to_string(output_ids) + assert isinstance(output_text, str) # typing # Construct OpenAI-compatible response chat_completion = openai.types.chat.chat_completion.ChatCompletion( @@ -141,7 +195,7 @@ async def __call__(self, request: llm_clients.LLMRequest) -> llm_clients.LLMResp ) ], created=int(time.time()), - model=self.model_name, + model=self.model.cfg.model_name, object="chat.completion", usage=openai.types.completion_usage.CompletionUsage( prompt_tokens=num_input_tokens, @@ -157,9 +211,8 @@ async def __call__(self, request: llm_clients.LLMRequest) -> llm_clients.LLMResp def create_hooked_transformer_client_with_chat_template( - model: HookedTransformer, + model: transformer_lens.HookedTransformer, tokenizer: Any, - model_name: str, max_new_tokens: int = 2048, **generation_kwargs: Any, ) -> HookedTransformerLLMClient: @@ -171,7 +224,6 @@ def create_hooked_transformer_client_with_chat_template( Args: model: TransformerLens HookedTransformer instance. tokenizer: HuggingFace tokenizer with apply_chat_template method. - model_name: Name of the model. max_new_tokens: Maximum tokens to generate. **generation_kwargs: Additional arguments for generation. 
@@ -189,7 +241,6 @@ def create_hooked_transformer_client_with_chat_template( client = create_hooked_transformer_client_with_chat_template( model=model, tokenizer=tokenizer, - model_name="gpt2-small", max_new_tokens=512, temperature=0.7, ) @@ -208,7 +259,6 @@ def format_with_chat_template(messages: Sequence[Any]) -> str: return HookedTransformerLLMClient( model=model, - model_name=model_name, max_new_tokens=max_new_tokens, generation_kwargs=generation_kwargs, _format_messages_fn=format_with_chat_template, From d9ae1adeebeb6c9d7cf6cafb49e9ed7ddf74beb8 Mon Sep 17 00:00:00 2001 From: Ryan Smith Date: Fri, 30 Jan 2026 11:32:02 -0800 Subject: [PATCH 3/5] Moved to new example number --- ...hooked_transformer.py => 07_mech_interp_hooked_transformer.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{03_mech_interp_hooked_transformer.py => 07_mech_interp_hooked_transformer.py} (100%) diff --git a/examples/03_mech_interp_hooked_transformer.py b/examples/07_mech_interp_hooked_transformer.py similarity index 100% rename from examples/03_mech_interp_hooked_transformer.py rename to examples/07_mech_interp_hooked_transformer.py From c442c54308e06409222c796b11aaa7c7970e28e8 Mon Sep 17 00:00:00 2001 From: Ryan Smith Date: Fri, 30 Jan 2026 11:54:48 -0800 Subject: [PATCH 4/5] Merge conflict updates --- examples/07_mech_interp_hooked_transformer.py | 86 ++++++++----------- src/ares/contrib/mech_interp/__init__.py | 6 -- .../mech_interp/hooked_transformer_client.py | 49 +++-------- 3 files changed, 51 insertions(+), 90 deletions(-) diff --git a/examples/07_mech_interp_hooked_transformer.py b/examples/07_mech_interp_hooked_transformer.py index bf56c21..cbc1446 100644 --- a/examples/07_mech_interp_hooked_transformer.py +++ b/examples/07_mech_interp_hooked_transformer.py @@ -17,11 +17,9 @@ from transformer_lens import HookedTransformer +import ares from ares.contrib.mech_interp import ActivationCapture from ares.contrib.mech_interp import HookedTransformerLLMClient -from ares.contrib.mech_interp import InterventionManager -from ares.contrib.mech_interp import create_zero_ablation_hook -from ares.environments import swebench_env async def main(): @@ -44,19 +42,11 @@ async def main(): max_new_tokens=128, # Keep this small to avoid context overflow ) - # Load SWE-bench tasks - all_tasks = swebench_env.swebench_verified_tasks() - tasks = [all_tasks[0]] # Just one task for demo - - print(f"\nRunning on task: {tasks[0].instance_id}") - print(f"Repository: {tasks[0].repo}") - print("-" * 80) - # Example 1: Basic execution with activation capture print("\n[Example 1] Running agent with activation capture...") print("-" * 80) - async with swebench_env.SweBenchEnv(tasks=tasks) as env: + async with ares.make("sbv-mswea:0") as env: # Set up activation capture with ActivationCapture(model) as capture: ts = await env.reset() @@ -76,14 +66,14 @@ async def main(): capture.record_step_metadata( { "step": step_count, - "action_preview": str(action.chat_completion_response.choices[0].message.content)[:50], + "action_preview": str(action.data[0].content)[:50], } ) - assert action.chat_completion_response.usage is not None - print( - f"Step {step_count}: Generated {action.chat_completion_response.usage.completion_tokens} tokens" - ) + if action.usage is not None: + print( + f"Step {step_count}: Generated {action.usage.generated_tokens} tokens" + ) # Step environment ts = await env.step(action) @@ -103,37 +93,37 @@ async def main(): print("\nSaving trajectory activations to ./mech_interp_demo/trajectory_001/") 
trajectory.save("./mech_interp_demo/trajectory_001") - # Example 2: Running with interventions - print("\n[Example 2] Running agent with attention head ablation...") - print("-" * 80) - - async with swebench_env.SweBenchEnv(tasks=tasks) as env: - # Set up intervention: ablate heads 0-2 in layer 0 - manager = InterventionManager(model) - manager.add_intervention( - hook_name="blocks.0.attn.hook_result", - hook_fn=create_zero_ablation_hook(heads=[0, 1, 2]), - description="Ablate attention heads 0-2 in layer 0", - ) - - print(manager.get_intervention_summary()) - - with manager: - ts = await env.reset() - step_count = 0 - max_steps = 2 # Limit steps for demo - - while not ts.last() and step_count < max_steps: - assert ts.observation is not None - action = await client(ts.observation) - - assert action.chat_completion_response.usage is not None - num_tokens = action.chat_completion_response.usage.completion_tokens - print(f"Step {step_count} (with ablation): Generated {num_tokens} tokens") - - ts = await env.step(action) - step_count += 1 - manager.increment_step() + # # Example 2: Running with interventions + # print("\n[Example 2] Running agent with attention head ablation...") + # print("-" * 80) + + # async with swebench_env.SweBenchEnv(tasks=tasks) as env: + # # Set up intervention: ablate heads 0-2 in layer 0 + # manager = InterventionManager(model) + # manager.add_intervention( + # hook_name="blocks.0.attn.hook_result", + # hook_fn=create_zero_ablation_hook(heads=[0, 1, 2]), + # description="Ablate attention heads 0-2 in layer 0", + # ) + + # print(manager.get_intervention_summary()) + + # with manager: + # ts = await env.reset() + # step_count = 0 + # max_steps = 2 # Limit steps for demo + + # while not ts.last() and step_count < max_steps: + # assert ts.observation is not None + # action = await client(ts.observation) + + # assert action.chat_completion_response.usage is not None + # num_tokens = action.chat_completion_response.usage.completion_tokens + # print(f"Step {step_count} (with ablation): Generated {num_tokens} tokens") + + # ts = await env.step(action) + # step_count += 1 + # manager.increment_step() print("\n" + "=" * 80) print("Demo complete!") diff --git a/src/ares/contrib/mech_interp/__init__.py b/src/ares/contrib/mech_interp/__init__.py index 429bf28..cf50f26 100644 --- a/src/ares/contrib/mech_interp/__init__.py +++ b/src/ares/contrib/mech_interp/__init__.py @@ -7,16 +7,10 @@ from ares.contrib.mech_interp.activation_capture import ActivationCapture from ares.contrib.mech_interp.activation_capture import TrajectoryActivations -from ares.contrib.mech_interp.hook_utils import InterventionManager -from ares.contrib.mech_interp.hook_utils import create_path_patching_hook -from ares.contrib.mech_interp.hook_utils import create_zero_ablation_hook from ares.contrib.mech_interp.hooked_transformer_client import HookedTransformerLLMClient __all__ = [ "ActivationCapture", "HookedTransformerLLMClient", - "InterventionManager", "TrajectoryActivations", - "create_path_patching_hook", - "create_zero_ablation_hook", ] diff --git a/src/ares/contrib/mech_interp/hooked_transformer_client.py b/src/ares/contrib/mech_interp/hooked_transformer_client.py index 35ad752..dce4b0b 100644 --- a/src/ares/contrib/mech_interp/hooked_transformer_client.py +++ b/src/ares/contrib/mech_interp/hooked_transformer_client.py @@ -3,19 +3,15 @@ from collections.abc import Callable, Sequence import dataclasses import inspect -import time from typing import Any, Protocol, Union, runtime_checkable -import uuid 
-import openai.types.chat.chat_completion -import openai.types.chat.chat_completion_message -import openai.types.completion_usage import torch import transformer_lens from ares.contrib.mech_interp import hook_utils -from ares.environments import base as environments -from ares.llms import llm_clients +from ares.environments import base as ares_env +from ares.environments import code_env +from ares import llms HookNameFn = Callable[[str], str] @@ -122,10 +118,10 @@ async def _call_with_hooks( async def __call__( self, - request: llm_clients.LLMRequest, - env: environments.base.CodeBaseEnv | None = None, - timestep: environments.base.TimeStep | None = None, - ) -> llm_clients.LLMResponse: + request: llms.LLMRequest, + env: code_env.CodeEnvironment | None = None, + timestep: ares_env.TimeStep | None = None, + ) -> llms.LLMResponse: """Generate a completion using the HookedTransformer. Args: @@ -167,7 +163,7 @@ async def __call__( state=hook_utils.FullyObservableState( timestep=timestep, # TODO: Figure out typing here - container=env.container if env is not None else None, + container=env._container if env is not None else None, step_num=0 # TODO: How to calculate?, ), **gen_kwargs, @@ -181,34 +177,15 @@ async def __call__( output_text = self.model.to_string(output_ids) assert isinstance(output_text, str) # typing - # Construct OpenAI-compatible response - chat_completion = openai.types.chat.chat_completion.ChatCompletion( - id=str(uuid.uuid4()), - choices=[ - openai.types.chat.chat_completion.Choice( - message=openai.types.chat.chat_completion_message.ChatCompletionMessage( - content=output_text, - role="assistant", - ), - finish_reason="stop", - index=0, - ) - ], - created=int(time.time()), - model=self.model.cfg.model_name, - object="chat.completion", - usage=openai.types.completion_usage.CompletionUsage( + return llms.LLMResponse( + data=[llms.TextData(content=output_text)], + cost=0.0, # Local inference has no cost + usage=llms.Usage( prompt_tokens=num_input_tokens, - completion_tokens=num_output_tokens, - total_tokens=num_input_tokens + num_output_tokens, + generated_tokens=num_output_tokens, ), ) - return llm_clients.LLMResponse( - chat_completion_response=chat_completion, - cost=0.0, # Local inference has no cost - ) - def create_hooked_transformer_client_with_chat_template( model: transformer_lens.HookedTransformer, From 78396e5087a041ad7f191fe783dbc5d38d226973 Mon Sep 17 00:00:00 2001 From: Ryan Smith Date: Fri, 30 Jan 2026 12:37:29 -0800 Subject: [PATCH 5/5] In progress updates --- examples/07_mech_interp_hooked_transformer.py | 75 ++--- src/ares/contrib/mech_interp/hook_utils.py | 259 ++++++++++++++++++ .../mech_interp/hooked_transformer_client.py | 4 +- 3 files changed, 301 insertions(+), 37 deletions(-) diff --git a/examples/07_mech_interp_hooked_transformer.py b/examples/07_mech_interp_hooked_transformer.py index cbc1446..39a5e3e 100644 --- a/examples/07_mech_interp_hooked_transformer.py +++ b/examples/07_mech_interp_hooked_transformer.py @@ -20,6 +20,9 @@ import ares from ares.contrib.mech_interp import ActivationCapture from ares.contrib.mech_interp import HookedTransformerLLMClient +from ares.contrib.mech_interp.hook_utils import InterventionManager, create_zero_ablation_hook + +from . 
import utils async def main(): @@ -70,10 +73,7 @@ async def main(): } ) - if action.usage is not None: - print( - f"Step {step_count}: Generated {action.usage.generated_tokens} tokens" - ) + utils.print_step(step_count, ts.observation, action) # Step environment ts = await env.step(action) @@ -93,37 +93,42 @@ async def main(): print("\nSaving trajectory activations to ./mech_interp_demo/trajectory_001/") trajectory.save("./mech_interp_demo/trajectory_001") - # # Example 2: Running with interventions - # print("\n[Example 2] Running agent with attention head ablation...") - # print("-" * 80) - - # async with swebench_env.SweBenchEnv(tasks=tasks) as env: - # # Set up intervention: ablate heads 0-2 in layer 0 - # manager = InterventionManager(model) - # manager.add_intervention( - # hook_name="blocks.0.attn.hook_result", - # hook_fn=create_zero_ablation_hook(heads=[0, 1, 2]), - # description="Ablate attention heads 0-2 in layer 0", - # ) - - # print(manager.get_intervention_summary()) - - # with manager: - # ts = await env.reset() - # step_count = 0 - # max_steps = 2 # Limit steps for demo - - # while not ts.last() and step_count < max_steps: - # assert ts.observation is not None - # action = await client(ts.observation) - - # assert action.chat_completion_response.usage is not None - # num_tokens = action.chat_completion_response.usage.completion_tokens - # print(f"Step {step_count} (with ablation): Generated {num_tokens} tokens") - - # ts = await env.step(action) - # step_count += 1 - # manager.increment_step() + # Example 2: Running with interventions + print("\n[Example 2] Running agent with attention head ablation...") + print("-" * 80) + + def create_zero_ablation_hook_with_log(*args, **kwargs): + hook_fn = create_zero_ablation_hook(*args, **kwargs) + def wrapped_hook_fn(*args, **kwargs): + print(f"Running zero ablation hook") + return hook_fn(*args, **kwargs) + return wrapped_hook_fn + + async with ares.make("sbv-mswea:0") as env: + # Set up intervention: ablate heads 0-2 in layer 0 + manager = InterventionManager(model) + manager.add_intervention( + hook_name="blocks.0.attn.hook_result", + hook_fn=create_zero_ablation_hook_with_log(heads=[0, 1, 2]), + description="Ablate attention heads 0-2 in layer 0", + ) + + print(manager.get_intervention_summary()) + + with manager: + ts = await env.reset() + step_count = 0 + max_steps = 2 # Limit steps for demo + + while not ts.last() and step_count < max_steps: + assert ts.observation is not None + action = await client(ts.observation) + + utils.print_step(step_count, ts.observation, action) + + ts = await env.step(action) + step_count += 1 + manager.increment_step() print("\n" + "=" * 80) print("Demo complete!") diff --git a/src/ares/contrib/mech_interp/hook_utils.py b/src/ares/contrib/mech_interp/hook_utils.py index ff95661..514a251 100644 --- a/src/ares/contrib/mech_interp/hook_utils.py +++ b/src/ares/contrib/mech_interp/hook_utils.py @@ -1,6 +1,12 @@ """Hook utilities for mechanistic interpretability interventions.""" +from collections.abc import Callable import dataclasses +from typing import Any + +import torch +from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookPoint from ares.containers import containers from ares.environments.base import TimeStep @@ -11,3 +17,256 @@ class FullyObservableState: timestep: TimeStep | None container: containers.Container | None step_num: int + + +def create_zero_ablation_hook( + positions: list[int] | None = None, + heads: list[int] | None = None, +) -> 
Callable[[torch.Tensor, HookPoint], torch.Tensor]:
+    """Create a hook that zeros out specific positions or attention heads.
+
+    Args:
+        positions: Optional list of token positions to ablate. If None, ablates all.
+        heads: Optional list of attention head indices to ablate (for attention patterns).
+
+    Returns:
+        Hook function that performs zero ablation.
+
+    Example:
+        ```python
+        # Ablate positions 5-10 in layer 0 residual stream
+        hook = create_zero_ablation_hook(positions=list(range(5, 11)))
+        model.run_with_hooks(
+            tokens,
+            fwd_hooks=[("blocks.0.hook_resid_post", hook)]
+        )
+        ```
+    """
+
+    def zero_ablation_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor:
+        ablated = activation.clone()
+
+        if heads is not None:
+            # Attention patterns are [batch, head, query_pos, key_pos]; attention outputs such as
+            # hook_z / hook_result are [batch, pos, head_index, d_head]. Both are 4-D, so the hook
+            # name decides which dimension indexes heads.
+            if hook.name is not None and "pattern" in hook.name:
+                ablated[:, heads, :, :] = 0.0
+            else:
+                ablated[:, :, heads, :] = 0.0
+        elif positions is not None:
+            # For residual stream or other positional activations
+            ablated[:, positions, :] = 0.0
+        else:
+            # Ablate everything
+            ablated = torch.zeros_like(ablated)
+
+        return ablated
+
+    return zero_ablation_hook
+
+
+def create_path_patching_hook(
+    clean_activation: torch.Tensor,
+    positions: list[int] | None = None,
+) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]:
+    """Create a hook for activation patching (path patching).
+
+    This replaces activations from a corrupted run with those from a clean run,
+    enabling causal analysis of information flow.
+
+    Args:
+        clean_activation: Activation tensor from the clean run to patch in.
+        positions: Optional list of positions to patch. If None, patches all positions.
+
+    Returns:
+        Hook function that performs activation patching.
+
+    Example:
+        ```python
+        # First, run on clean input and cache activations
+        _, clean_cache = model.run_with_cache(clean_tokens)
+        clean_resid = clean_cache["blocks.0.hook_resid_post"]
+
+        # Then run on corrupted input with patching
+        hook = create_path_patching_hook(clean_resid, positions=[5, 6, 7])
+        corrupted_logits = model.run_with_hooks(
+            corrupted_tokens,
+            fwd_hooks=[("blocks.0.hook_resid_post", hook)]
+        )
+        ```
+    """
+
+    def path_patching_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor:  # noqa: ARG001
+        patched = activation.clone()
+
+        if positions is not None:
+            patched[:, positions, :] = clean_activation[:, positions, :]
+        else:
+            patched = clean_activation.clone()
+
+        return patched
+
+    return path_patching_hook
+
+
+def create_mean_ablation_hook(
+    mean_activation: torch.Tensor | None = None,
+    positions: list[int] | None = None,
+) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]:
+    """Create a hook that replaces activations with their mean.
+
+    Mean ablation is often more realistic than zero ablation as it preserves
+    the scale of activations.
+
+    Args:
+        mean_activation: Pre-computed mean activation. If None, computes mean on-the-fly.
+        positions: Optional list of positions to ablate.
+
+    Returns:
+        Hook function that performs mean ablation.
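+
+    Example:
+        A minimal usage sketch (the hook point name is illustrative):
+        ```python
+        # Mean-ablate positions 5-10 in the layer 0 residual stream
+        hook = create_mean_ablation_hook(positions=list(range(5, 11)))
+        model.run_with_hooks(
+            tokens,
+            fwd_hooks=[("blocks.0.hook_resid_post", hook)]
+        )
+        ```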
+    """
+
+    def mean_ablation_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor:  # noqa: ARG001
+        ablated = activation.clone()
+
+        # Compute mean across batch and position dimensions if not provided
+        mean = activation.mean(dim=(0, 1), keepdim=True) if mean_activation is None else mean_activation
+
+        if positions is not None:
+            ablated[:, positions, :] = mean.expand_as(ablated[:, positions, :])
+        else:
+            ablated = mean.expand_as(ablated)
+
+        return ablated
+
+    return mean_ablation_hook
+
+
+class InterventionManager:
+    """Manager for applying multiple interventions during agent execution.
+
+    This class helps coordinate multiple hook-based interventions across an agent
+    trajectory, making it easy to study causal effects of different model components.
+
+    Example:
+        ```python
+        model = HookedTransformer.from_pretrained("gpt2-small")
+        client = HookedTransformerLLMClient(model=model)
+
+        # Create intervention manager
+        manager = InterventionManager(model)
+
+        # Add interventions
+        manager.add_intervention(
+            hook_name="blocks.0.attn.hook_pattern",
+            hook_fn=create_zero_ablation_hook(heads=[0, 1]),
+            description="Ablate attention heads 0-1 in layer 0"
+        )
+
+        # Run with interventions active
+        with manager:
+            async with env:
+                ts = await env.reset()
+                while not ts.last():
+                    response = await client(ts.observation)
+                    ts = await env.step(response)
+        ```
+    """
+
+    def __init__(self, model: HookedTransformer):
+        """Initialize intervention manager.
+
+        Args:
+            model: HookedTransformer to apply interventions to.
+        """
+        self.model = model
+        self.interventions: list[dict[str, Any]] = []
+        self._active_handles: list[Any] = []
+
+    def add_intervention(
+        self,
+        hook_name: str,
+        hook_fn: Callable[[torch.Tensor, HookPoint], torch.Tensor],
+        description: str = "",
+        apply_at_steps: list[int] | None = None,
+    ) -> None:
+        """Add an intervention to apply during execution.
+
+        Args:
+            hook_name: Name of the hook point (e.g., "blocks.0.hook_resid_post").
+            hook_fn: Hook function to apply.
+            description: Human-readable description of the intervention.
+            apply_at_steps: Optional list of step indices when to apply this intervention.
+                If None, applies at all steps.
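+
+        Example:
+            A sketch of a step-limited intervention (hook point and heads are illustrative):
+            ```python
+            manager.add_intervention(
+                hook_name="blocks.0.attn.hook_pattern",
+                hook_fn=create_zero_ablation_hook(heads=[0, 1]),
+                description="Ablate heads 0-1 during the first two agent steps",
+                apply_at_steps=[0, 1],
+            )
+            # Call manager.increment_step() after each environment step so the
+            # step filter advances.
+            ```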
+ """ + self.interventions.append( + { + "hook_name": hook_name, + "hook_fn": hook_fn, + "description": description, + "apply_at_steps": apply_at_steps, + "step_count": 0, + } + ) + + def clear_interventions(self) -> None: + """Remove all interventions.""" + self.interventions.clear() + + def __enter__(self) -> "InterventionManager": + """Enter context manager and activate interventions.""" + for intervention in self.interventions: + hook_point = self.model.hook_dict[intervention["hook_name"]] + + # Wrap hook_fn to track step count and properly capture loop variables + def make_wrapped_hook(interv): + # Capture these from the intervention dict to avoid loop variable binding issues + hook_fn = interv["hook_fn"] + steps = interv["apply_at_steps"] + + def wrapped_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: + # Check if we should apply this intervention at current step + if steps is None or interv["step_count"] in steps: + return hook_fn(activation, hook) + return activation + + return wrapped_hook + + handle = hook_point.add_hook(make_wrapped_hook(intervention)) + self._active_handles.append(handle) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context manager and deactivate interventions.""" + for handle in self._active_handles: + if handle is not None: + handle.remove() + self._active_handles.clear() + + def increment_step(self) -> None: + """Increment the step counter for all interventions. + + Call this between agent steps to track which step you're on. + """ + for intervention in self.interventions: + intervention["step_count"] += 1 + + def get_intervention_summary(self) -> str: + """Get a summary of all active interventions. + + Returns: + Human-readable summary string. + """ + if not self.interventions: + return "No interventions active" + + lines = ["Active interventions:"] + for i, interv in enumerate(self.interventions, 1): + desc = interv["description"] or "No description" + hook = interv["hook_name"] + steps = interv["apply_at_steps"] + step_str = f"steps {steps}" if steps else "all steps" + lines.append(f" {i}. {desc} ({hook}, {step_str})") + + return "\n".join(lines) diff --git a/src/ares/contrib/mech_interp/hooked_transformer_client.py b/src/ares/contrib/mech_interp/hooked_transformer_client.py index dce4b0b..3b00d89 100644 --- a/src/ares/contrib/mech_interp/hooked_transformer_client.py +++ b/src/ares/contrib/mech_interp/hooked_transformer_client.py @@ -90,7 +90,7 @@ async def _call_with_hooks( state: hook_utils.FullyObservableState, **gen_kwargs: Any, ) -> torch.Tensor: - self.model.reset_hooks(direction="fwd", including_permanent=False) + # self.model.reset_hooks(direction="fwd", including_permanent=False) if self.fwd_hooks is not None: for name_or_id_fn, hook_fn in self.fwd_hooks: @@ -111,7 +111,7 @@ async def _call_with_hooks( **gen_kwargs, ) - self.model.reset_hooks(direction="fwd", including_permanent=False) + # self.model.reset_hooks(direction="fwd", including_permanent=False) assert isinstance(outputs, torch.Tensor) # typing return outputs