diff --git a/examples/07_mech_interp_hooked_transformer.py b/examples/07_mech_interp_hooked_transformer.py new file mode 100644 index 0000000..39a5e3e --- /dev/null +++ b/examples/07_mech_interp_hooked_transformer.py @@ -0,0 +1,145 @@ +"""Example of using ARES with TransformerLens for mechanistic interpretability. + +This example demonstrates how to: +1. Use HookedTransformer with ARES environments +2. Capture activations across an agent trajectory +3. Apply interventions to study model behavior + +Example usage: + + 1. Make sure you have mech_interp dependencies installed + `uv sync --group mech_interp` + 2. Run the example + `uv run -m examples.03_mech_interp_hooked_transformer` +""" + +import asyncio + +from transformer_lens import HookedTransformer + +import ares +from ares.contrib.mech_interp import ActivationCapture +from ares.contrib.mech_interp import HookedTransformerLLMClient +from ares.contrib.mech_interp.hook_utils import InterventionManager, create_zero_ablation_hook + +from . import utils + + +async def main(): + print("=" * 80) + print("ARES + TransformerLens Mechanistic Interpretability Example") + print("=" * 80) + + # Load a small model for demonstration + # For real work, you'd use a larger model like gpt2-medium or pythia-1.4b + print("\nLoading HookedTransformer model...") + model = HookedTransformer.from_pretrained( + "gpt2-small", + device="cpu", # Change to "cuda" if you have a GPU + ) + + # Create the LLM client with reduced token limit for gpt2-small's context window + # gpt2-small has max context of 1024 tokens, so we need to be conservative + client = HookedTransformerLLMClient( + model=model, + max_new_tokens=128, # Keep this small to avoid context overflow + ) + + # Example 1: Basic execution with activation capture + print("\n[Example 1] Running agent with activation capture...") + print("-" * 80) + + async with ares.make("sbv-mswea:0") as env: + # Set up activation capture + with ActivationCapture(model) as capture: + ts = await env.reset() + step_count = 0 + max_steps = 3 # Limit steps for demo + + while not ts.last() and step_count < max_steps: + # Capture activations for this step + capture.start_step() + + # Generate response + assert ts.observation is not None + action = await client(ts.observation) + + # End capture for this step + capture.end_step() + capture.record_step_metadata( + { + "step": step_count, + "action_preview": str(action.data[0].content)[:50], + } + ) + + utils.print_step(step_count, ts.observation, action) + + # Step environment + ts = await env.step(action) + step_count += 1 + + # Analyze captured activations + trajectory = capture.get_trajectory() + print(f"\nCaptured activations for {len(trajectory)} steps") + + # Example: Look at attention patterns in layer 0 + if len(trajectory) > 0: + attn_pattern = trajectory.get_activation(0, "blocks.0.attn.hook_pattern") + print(f"Layer 0 attention pattern shape: {attn_pattern.shape}") + print(" [batch, n_heads, query_pos, key_pos]") + + # Save trajectory for later analysis + print("\nSaving trajectory activations to ./mech_interp_demo/trajectory_001/") + trajectory.save("./mech_interp_demo/trajectory_001") + + # Example 2: Running with interventions + print("\n[Example 2] Running agent with attention head ablation...") + print("-" * 80) + + def create_zero_ablation_hook_with_log(*args, **kwargs): + hook_fn = create_zero_ablation_hook(*args, **kwargs) + def wrapped_hook_fn(*args, **kwargs): + print(f"Running zero ablation hook") + return hook_fn(*args, **kwargs) + return wrapped_hook_fn + + 
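+    # In the block below, manager.increment_step() is called after each env.step() so
+    # that interventions scoped with apply_at_steps stay aligned with the agent's steps.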
async with ares.make("sbv-mswea:0") as env: + # Set up intervention: ablate heads 0-2 in layer 0 + manager = InterventionManager(model) + manager.add_intervention( + hook_name="blocks.0.attn.hook_result", + hook_fn=create_zero_ablation_hook_with_log(heads=[0, 1, 2]), + description="Ablate attention heads 0-2 in layer 0", + ) + + print(manager.get_intervention_summary()) + + with manager: + ts = await env.reset() + step_count = 0 + max_steps = 2 # Limit steps for demo + + while not ts.last() and step_count < max_steps: + assert ts.observation is not None + action = await client(ts.observation) + + utils.print_step(step_count, ts.observation, action) + + ts = await env.step(action) + step_count += 1 + manager.increment_step() + + print("\n" + "=" * 80) + print("Demo complete!") + print("=" * 80) + print("\nNext steps for mechanistic interpretability research:") + print("1. Load saved activations: TrajectoryActivations.load('./mech_interp_demo/trajectory_001')") + print("2. Analyze attention patterns across the trajectory") + print("3. Use interventions to study causal effects") + print("4. Compare 'clean' vs 'corrupted' trajectories with path patching") + print("\nSee src/ares/contrib/mech_interp/README.md for more examples!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index b6b6cff..88aadbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ docs = [ "sphinx-rtd-theme>=2.0.0", "sphinx-autodoc-typehints>=1.25.0", ] +transformer-lens = ["transformer-lens>=2.10.0"] [project.urls] Homepage = "https://github.com/withmartian/ares" @@ -61,6 +62,7 @@ dev = [ "hatchling>=1.28.0", "twine>=6.2.0", ] + examples = [ "transformers>=4.57.3", "llama-cpp-python>=0.3.16", diff --git a/src/ares/contrib/mech_interp/README.md b/src/ares/contrib/mech_interp/README.md new file mode 100644 index 0000000..e0674be --- /dev/null +++ b/src/ares/contrib/mech_interp/README.md @@ -0,0 +1,481 @@ +# Mechanistic Interpretability for ARES + +This module provides deep integration between ARES and [TransformerLens](https://github.com/neelnanda-io/TransformerLens), enabling mechanistic interpretability research on code agents across long-horizon tasks. + +## Why ARES for Mechanistic Interpretability? + +Traditional mechanistic interpretability focuses on static, single-step analysis. But modern AI agents: +- Make decisions across many steps (50-100+ steps per episode) +- Maintain internal state that evolves over time +- Exhibit temporal dependencies and long-horizon planning + +ARES enables **trajectory-level mechanistic interpretability** by: +1. Capturing activations across entire agent episodes +2. Studying how internal representations evolve during multi-step reasoning +3. Identifying critical moments where interventions significantly alter episode-level outcomes +4. Seeing how activations differ across different agent frameworks for the same task +5. You tell us! 
+ +## Quick Start + +### Installation + +```bash +# Install ARES with mech_interp group (includes TransformerLens) +uv add ares[mech-interp] +# or with pip +pip install ares[mech-interp] +``` + +### Basic Example + +```python +import asyncio +from transformer_lens import HookedTransformer +from ares.contrib.mech_interp import HookedTransformerLLMClient, ActivationCapture +from ares.environments import swebench_env + +async def main(): + # Load model + model = HookedTransformer.from_pretrained("gpt2-small") + client = HookedTransformerLLMClient(model=model) + + # Run agent and capture activations + tasks = swebench_env.swebench_verified_tasks()[:1] + + async with swebench_env.SweBenchEnv(tasks=tasks) as env: + with ActivationCapture(model) as capture: + ts = await env.reset() + while not ts.last(): + capture.start_step() + action = await client(ts.observation) + capture.end_step() + ts = await env.step(action) + + # Analyze trajectory + trajectory = capture.get_trajectory() + print(f"Captured {len(trajectory)} steps") + trajectory.save("./activations/episode_001") + +asyncio.run(main()) +``` + +## Core Components + +### 1. HookedTransformerLLMClient + +An ARES-compatible LLM client that uses TransformerLens's `HookedTransformer` for inference. + +```python +from transformer_lens import HookedTransformer +from ares.contrib.mech_interp import HookedTransformerLLMClient + +model = HookedTransformer.from_pretrained("gpt2-medium") +client = HookedTransformerLLMClient( + model=model, + max_new_tokens=1024, + generation_kwargs={"temperature": 0.7} +) +``` + +**With Chat Templates:** + +```python +from transformers import AutoTokenizer +from ares.contrib.mech_interp import create_hooked_transformer_client_with_chat_template + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct") +model = HookedTransformer.from_pretrained("Qwen/Qwen2.5-3B-Instruct") + +client = create_hooked_transformer_client_with_chat_template( + model=model, + tokenizer=tokenizer, + max_new_tokens=2048, +) +``` + +### 2. ActivationCapture + +Captures activations across an agent trajectory for later analysis. + +```python +from ares.contrib.mech_interp import ActivationCapture + +with ActivationCapture(model) as capture: + # Filter which hooks to capture (optional) + # capture = ActivationCapture(model, hook_filter=lambda name: "attn" in name) + + async with env: + ts = await env.reset() + while not ts.last(): + capture.start_step() + action = await client(ts.observation) + capture.end_step() + + # Optionally record metadata + capture.record_step_metadata({"step_reward": ts.reward}) + + ts = await env.step(action) + +# Get trajectory +trajectory = capture.get_trajectory() + +# Access specific activations +attn_pattern_step_5 = trajectory.get_activation(5, "blocks.0.attn.hook_pattern") + +# Get activations across trajectory +all_layer0_resid = trajectory.get_activation_across_trajectory("blocks.0.hook_resid_post") + +# Save for later +trajectory.save("./activations/episode_001") + +# Load later +loaded = TrajectoryActivations.load("./activations/episode_001") +``` + +**Automatic Capture:** + +```python +from ares.contrib.mech_interp import automatic_activation_capture + +with automatic_activation_capture(model) as capture: + # Activations are captured automatically during each client call + async with env: + ts = await env.reset() + while not ts.last(): + response = await client(ts.observation) + ts = await env.step(response) + + trajectory = capture.get_trajectory() +``` + +### 3. 
InterventionManager + +TODO: This needs to rework, can just use step nums with better identifying of interesting behavior + +Apply causal interventions during agent execution to study model behavior. + +```python +from transformer_lens import HookedTransformer +from ares.contrib.mech_interp import ( + HookedTransformerLLMClient, + FullyObservableState, +) + +async def has_readme(state: FullyObservableState) -> str | None: + # Function to only runs hooks when the environment contains a README + cmd_output = await state.container.exec_run("ls", workdir="/path/to/app") + if "README.md" in cmd_output.output and state.step_num > 1: + return "blocks.1.attn.hook_pattern" + else: + return None + +model = HookedTransformer.from_pretrained("gpt2-medium") +client = HookedTransformerLLMClient( + model=model, + fwd_hooks=[ + # The mean ablation hook is defined exactly the same as classic transformer_lens + ("blocks.1.attn.hook_pattern@step2", mean_ablation_hook), + ], + max_new_tokens=1024, + generation_kwargs={"temperature": 0.7} +) + + +# Run with interventions +async with env: + ts = await env.reset() + + while not ts.last(): + action = await client( + ts.observation, + # Should also include the env and timestep here so that + # we can reference the env in the has_readme function. + # Do not need to pass if hooks do not depend on state + env=env, + timestep=ts, + ) + ts = await env.step(action) +``` + +### 4. Hook Utilities + +Pre-built hooks for common interventions: + +**Zero Ablation:** +```python +from ares.contrib.mech_interp import create_zero_ablation_hook + +# Ablate specific positions +hook = create_zero_ablation_hook(positions=[10, 11, 12]) + +# Ablate specific attention heads +hook = create_zero_ablation_hook(heads=[0, 1]) +``` + +**Path Patching:** +```python +from ares.contrib.mech_interp import create_path_patching_hook + +# Run clean and corrupted inputs +clean_cache, _ = model.run_with_cache(clean_tokens) +corrupted_cache, _ = model.run_with_cache(corrupted_tokens) + +# Patch clean activations into corrupted run +hook = create_path_patching_hook( + clean_activation=clean_cache["blocks.0.hook_resid_post"], + positions=[5, 6, 7] +) +``` + +**Mean Ablation:** +```python +from ares.contrib.mech_interp import create_mean_ablation_hook + +# Compute mean activation over dataset +mean_cache = compute_mean_activations(model, dataset) + +hook = create_mean_ablation_hook( + mean_activation=mean_cache["blocks.0.hook_resid_post"], + positions=[10, 11, 12] +) +``` + +**Attention Knockout:** +```python +from ares.contrib.mech_interp import create_attention_knockout_hook + +# Prevent position 20 from attending to positions 0-10 +hook = create_attention_knockout_hook( + source_positions=[20], + target_positions=list(range(11)) +) +``` + +## Research Use Cases + +### 1. Attention Head Analysis Across Trajectories + +Study how attention patterns evolve as agents work through tasks: + +```python +with ActivationCapture(model) as capture: + # Run agent episode + ... + +trajectory = capture.get_trajectory() + +# Analyze attention patterns across all steps +for step in range(len(trajectory)): + attn = trajectory.get_activation(step, "blocks.5.attn.hook_pattern") + # Analyze attention to specific tokens, copy behavior, etc. +``` + +### 2. 
Identifying Critical Decision Points + +Find steps where small perturbations significantly alter outcomes: + +```python +baseline_trajectory = run_episode(env, client) + +# Test interventions at each step +critical_steps = [] +for step in range(len(baseline_trajectory)): + manager = InterventionManager(model) + manager.add_intervention( + hook_name="blocks.3.hook_resid_post", + hook_fn=create_zero_ablation_hook(positions=[10, 11]), + apply_at_steps=[step] + ) + + perturbed_reward = run_episode_with_intervention(env, client, manager) + + if abs(perturbed_reward - baseline_trajectory.reward) > threshold: + critical_steps.append(step) +``` + +### 3. Circuit Discovery in Multi-Step Reasoning + +Use path patching to identify circuits responsible for specific capabilities: + +```python +# Compare successful vs failed trajectories +success_cache = run_and_cache_episode(env, client, success_task) +failure_cache = run_and_cache_episode(env, client, failure_task) + +# Systematically patch components from success to failure +for layer in range(model.cfg.n_layers): + for step in range(len(failure_cache)): + hook = create_path_patching_hook( + clean_activation=success_cache.get_activation(step, f"blocks.{layer}.hook_resid_post"), + positions=None + ) + + # Test if this component recovers successful behavior + ... +``` + +### 4. Temporal Information Flow Analysis + +Track how information propagates through the model across steps: + +```python +# Capture activations with detailed metadata +with ActivationCapture(model) as capture: + async with env: + ts = await env.reset() + while not ts.last(): + capture.start_step() + action = await client(ts.observation) + capture.end_step() + + # Record what the agent was doing + capture.record_step_metadata({ + "action_type": classify_action(action), + "error_present": check_for_errors(ts.observation), + "file_context": extract_file_context(ts.observation), + }) + + ts = await env.step(action) + +trajectory = capture.get_trajectory() + +# Analyze when specific features activate relative to task events +# E.g., does the model activate "file-reading neurons" before bash commands? +``` + +### 5. Comparative Analysis: Different Models + +Compare how different models solve the same task: + +```python +models = [ + HookedTransformer.from_pretrained("gpt2-small"), + HookedTransformer.from_pretrained("gpt2-medium"), + HookedTransformer.from_pretrained("pythia-1.4b"), +] + +trajectories = [] +for model in models: + client = HookedTransformerLLMClient(model=model) + + with ActivationCapture(model) as capture: + # Run same task + trajectory = run_episode(env, client, capture) + trajectories.append(trajectory) + +# Compare attention patterns, activation magnitudes, etc. +compare_trajectories(trajectories) +``` + +## Advanced Examples + +### Example: Finding "Code Understanding" Circuits + +```python +import torch +from transformer_lens import HookedTransformer +from ares.contrib.mech_interp import * +from ares.environments import swebench_env + +async def find_code_understanding_circuits(): + model = HookedTransformer.from_pretrained("gpt2-medium") + client = HookedTransformerLLMClient(model=model) + + # 1. Collect baseline trajectory on a task + tasks = swebench_env.swebench_verified_tasks() + success_task = [t for t in tasks if is_success(t)][0] + + with ActivationCapture(model) as capture: + baseline_reward = await run_episode(swebench_env.SweBenchEnv([success_task]), client, capture) + + baseline_trajectory = capture.get_trajectory() + + # 2. 
For each layer and step, ablate and measure impact + impact_matrix = torch.zeros(model.cfg.n_layers, len(baseline_trajectory)) + + for layer in range(model.cfg.n_layers): + for step in range(len(baseline_trajectory)): + manager = InterventionManager(model) + manager.add_intervention( + hook_name=f"blocks.{layer}.hook_resid_post", + hook_fn=create_mean_ablation_hook(), + apply_at_steps=[step] + ) + + with manager: + perturbed_reward = await run_episode(swebench_env.SweBenchEnv([success_task]), client) + + impact_matrix[layer, step] = abs(baseline_reward - perturbed_reward) + + # 3. Identify critical (layer, step) pairs + critical_components = (impact_matrix > threshold).nonzero() + + print(f"Found {len(critical_components)} critical components") + + # 4. Drill down to attention heads in critical layers + for layer, step in critical_components[:10]: # Top 10 + for head in range(model.cfg.n_heads): + manager = InterventionManager(model) + manager.add_intervention( + hook_name=f"blocks.{layer}.attn.hook_result", + hook_fn=create_zero_ablation_hook(heads=[head]), + apply_at_steps=[step.item()] + ) + + with manager: + head_reward = await run_episode(swebench_env.SweBenchEnv([success_task]), client) + + if abs(baseline_reward - head_reward) > head_threshold: + print(f"Critical: Layer {layer}, Step {step}, Head {head}") + + # Visualize attention pattern + attn_pattern = baseline_trajectory.get_activation(step, f"blocks.{layer}.attn.hook_pattern") + visualize_attention(attn_pattern[:, head, :, :]) + +``` + +## Performance Tips + +**Memory Optimization:** +```python +# Only capture specific activations +ActivationCapture( + model, + hook_filter=lambda name: "attn.hook_pattern" in name or "hook_resid" in name +) + +# Clear trajectory periodically +if len(capture.step_activations) > 100: + trajectory = capture.get_trajectory() + trajectory.save(f"./checkpoints/step_{step}") + capture.clear() +``` + +**Speed Optimization:** +```python +# Use smaller models for initial exploration +model = HookedTransformer.from_pretrained("gpt2-small", device="cuda") + +# Reduce max_new_tokens during ablation studies +client = HookedTransformerLLMClient(model=model, max_new_tokens=256) +``` + +## Resources + +- [TransformerLens Documentation](https://neelnanda-io.github.io/TransformerLens/) +- Example Notebook: Trajectory-Level Analysis *(coming soon)* +- [Blog Post: Beyond Static Mechanistic Interpretability](https://withmartian.com/post/beyond-static-mechanistic-interpretability-agentic-long-horizon-tasks-as-the-next-frontier) + +## Citation + +If you use this module in your research, please cite: + +```bibtex +@software{ares_mech_interp_2025, + title = {ARES Mechanistic Interpretability Module}, + author = {Martian}, + year = {2025}, + url = {https://github.com/withmartian/ares} +} +``` diff --git a/src/ares/contrib/mech_interp/__init__.py b/src/ares/contrib/mech_interp/__init__.py new file mode 100644 index 0000000..cf50f26 --- /dev/null +++ b/src/ares/contrib/mech_interp/__init__.py @@ -0,0 +1,16 @@ +"""Mechanistic interpretability utilities for ARES. + +This module provides tools for analyzing agent behavior using mechanistic interpretability +techniques, with deep integration with TransformerLens for studying model internals across +long-horizon agent trajectories. 
+""" + +from ares.contrib.mech_interp.activation_capture import ActivationCapture +from ares.contrib.mech_interp.activation_capture import TrajectoryActivations +from ares.contrib.mech_interp.hooked_transformer_client import HookedTransformerLLMClient + +__all__ = [ + "ActivationCapture", + "HookedTransformerLLMClient", + "TrajectoryActivations", +] diff --git a/src/ares/contrib/mech_interp/activation_capture.py b/src/ares/contrib/mech_interp/activation_capture.py new file mode 100644 index 0000000..50b6b4d --- /dev/null +++ b/src/ares/contrib/mech_interp/activation_capture.py @@ -0,0 +1,280 @@ +"""Utilities for capturing and analyzing activations across agent trajectories.""" + +from collections.abc import Callable +import dataclasses +import pathlib +from typing import Any + +import torch +from transformer_lens import ActivationCache +from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookPoint + + +@dataclasses.dataclass +class TrajectoryActivations: + """Container for activations captured across an agent trajectory. + + Attributes: + step_activations: List of ActivationCache objects, one per agent step. + step_metadata: Optional metadata for each step (e.g., observation, action). + model_name: Name of the model used. + """ + + step_activations: list[ActivationCache] + step_metadata: list[dict[str, Any]] = dataclasses.field(default_factory=list) + model_name: str = "unknown" + + def __len__(self) -> int: + """Return number of steps in the trajectory.""" + return len(self.step_activations) + + def save(self, path: str | pathlib.Path) -> None: + """Save trajectory activations to disk. + + Args: + path: Directory path to save activations. Will be created if it doesn't exist. + """ + path = pathlib.Path(path) + path.mkdir(parents=True, exist_ok=True) + + # Save each step's activations as plain dicts (ActivationCache can't be pickled) + for i, cache in enumerate(self.step_activations): + # Convert ActivationCache to dict for serialization + cache_dict = dict(cache.cache_dict.items()) + torch.save(cache_dict, path / f"step_{i:04d}.pt") + + # Save metadata + import json + + metadata = { + "model_name": self.model_name, + "num_steps": len(self), + "step_metadata": self.step_metadata, + } + with open(path / "metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + + @classmethod + def load(cls, path: str | pathlib.Path) -> "TrajectoryActivations": + """Load trajectory activations from disk. + + Args: + path: Directory path containing saved activations. + + Returns: + TrajectoryActivations instance. + """ + import json + + path = pathlib.Path(path) + + # Load metadata + with open(path / "metadata.json") as f: + metadata = json.load(f) + + # Load activation dicts (we saved as plain dicts, not ActivationCache objects) + step_files = sorted(path.glob("step_*.pt")) + step_activations_dicts = [torch.load(f, weights_only=False) for f in step_files] + + # Convert back to ActivationCache objects if needed, or just use dicts + # For now, we'll keep them as ActivationCache objects for API compatibility + step_activations = [ + ActivationCache(cache_dict, model=None) for cache_dict in step_activations_dicts + ] + + return cls( + step_activations=step_activations, + step_metadata=metadata.get("step_metadata", []), + model_name=metadata.get("model_name", "unknown"), + ) + + def get_activation(self, step: int, hook_name: str) -> torch.Tensor: + """Get activation tensor for a specific step and hook. + + Args: + step: Step index in trajectory. 
+ hook_name: Name of the hook point (e.g., "blocks.0.attn.hook_pattern"). + + Returns: + Activation tensor for the specified step and hook. + """ + return self.step_activations[step][hook_name] + + def get_activation_across_trajectory(self, hook_name: str) -> list[torch.Tensor]: + """Get activations for a specific hook across all trajectory steps. + + Args: + hook_name: Name of the hook point. + + Returns: + List of activation tensors, one per step. + """ + return [cache[hook_name] for cache in self.step_activations] + + +class ActivationCapture: + """Context manager for capturing activations during agent execution. + + This class provides a convenient way to capture activations from a HookedTransformer + during an ARES agent episode, enabling trajectory-level mechanistic interpretability + analysis. + + Example: + ```python + from transformer_lens import HookedTransformer + from ares.contrib.mech_interp import ActivationCapture, HookedTransformerLLMClient + + model = HookedTransformer.from_pretrained("gpt2-small") + client = HookedTransformerLLMClient(model=model) + + # Capture activations during episode + with ActivationCapture(model) as capture: + async with env: + ts = await env.reset() + while not ts.last(): + response = await client(ts.observation) + capture.record_step_metadata({"action": response}) + ts = await env.step(response) + + # Analyze captured activations + trajectory = capture.get_trajectory() + print(f"Captured {len(trajectory)} steps") + + # Save for later analysis + trajectory.save("./activations/episode_001") + ``` + """ + + def __init__( + self, + model: HookedTransformer, + hook_filter: Callable[[str], bool] | None = None, + ): + """Initialize activation capture. + + Args: + model: HookedTransformer to capture activations from. + hook_filter: Optional function to filter which hooks to capture. + If None, captures all hooks. Example: lambda name: "attn" in name + """ + # TODO: Should we store logits and loss as well? By default or via flag? 
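+        # step_activations accumulates one ActivationCache per completed agent step;
+        # _hook_handles tracks registered hooks so __exit__() can remove them cleanly.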
+ self.model = model + self.hook_filter = hook_filter + self.step_activations: list[ActivationCache] = [] + self.step_metadata: list[dict[str, Any]] = [] + self._hook_handles: list[Any] = [] + + def __enter__(self) -> "ActivationCapture": + """Enter context manager and start capturing activations.""" + # Register hooks to capture activations + for name, hp in self.model.hook_dict.items(): + if self.hook_filter is None or self.hook_filter(name): + handle = hp.add_hook(self._make_capture_hook(name)) + self._hook_handles.append(handle) + + self._current_step_cache: dict[str, torch.Tensor] = {} + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context manager and clean up hooks.""" + # Remove all hooks + for handle in self._hook_handles: + if handle is not None: + handle.remove() + self._hook_handles.clear() + + def _make_capture_hook(self, hook_name: str) -> Callable: + """Create a hook function that captures activations.""" + + def capture_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 + # Store a copy of the activation (detached from computation graph) + self._current_step_cache[hook_name] = activation.detach().cpu() + return activation + + return capture_hook + + def start_step(self) -> None: + """Mark the start of a new agent step.""" + self._current_step_cache = {} + + def end_step(self) -> None: + """Mark the end of an agent step and save captured activations.""" + if self._current_step_cache: + # Convert dict to ActivationCache + cache = ActivationCache( + self._current_step_cache.copy(), + self.model, + ) + self.step_activations.append(cache) + self._current_step_cache = {} + + def record_step_metadata(self, metadata: dict[str, Any]) -> None: + """Record metadata for the current/last step. + + Args: + metadata: Dictionary of metadata to associate with this step. + """ + self.step_metadata.append(metadata) + + def get_trajectory(self) -> TrajectoryActivations: + """Get the complete trajectory of captured activations. + + Returns: + TrajectoryActivations containing all captured steps. + """ + return TrajectoryActivations( + step_activations=self.step_activations.copy(), + step_metadata=self.step_metadata.copy(), + model_name=self.model.cfg.model_name, + ) + + def clear(self) -> None: + """Clear all captured activations and metadata.""" + self.step_activations.clear() + self.step_metadata.clear() + self._current_step_cache = {} + + +def automatic_activation_capture(model: HookedTransformer) -> ActivationCapture: + """Create an ActivationCapture that automatically records steps during generation. + + This wraps the model's generate method to automatically call start_step() and + end_step() around each generation, making it seamless to use with ARES environments. + + Args: + model: HookedTransformer to capture activations from. + + Returns: + ActivationCapture instance with automatic step tracking. 
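+
+    Note:
+        This works by monkey-patching `model.generate`; the original method is not
+        restored after the capture context exits.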
+
+    Example:
+        ```python
+        model = HookedTransformer.from_pretrained("gpt2-small")
+
+        with automatic_activation_capture(model) as capture:
+            client = HookedTransformerLLMClient(model=model)
+            # Now activations are captured automatically during client calls
+            async with env:
+                ts = await env.reset()
+                while not ts.last():
+                    response = await client(ts.observation)
+                    ts = await env.step(response)
+
+            trajectory = capture.get_trajectory()
+        ```
+    """
+    capture = ActivationCapture(model)
+
+    # Wrap model.generate to auto-capture
+    original_generate = model.generate
+
+    def wrapped_generate(*args, **kwargs):
+        capture.start_step()
+        result = original_generate(*args, **kwargs)
+        capture.end_step()
+        return result
+
+    model.generate = wrapped_generate
+
+    return capture
diff --git a/src/ares/contrib/mech_interp/hook_utils.py b/src/ares/contrib/mech_interp/hook_utils.py
new file mode 100644
index 0000000..514a251
--- /dev/null
+++ b/src/ares/contrib/mech_interp/hook_utils.py
@@ -0,0 +1,272 @@
+"""Hook utilities for mechanistic interpretability interventions."""
+
+from collections.abc import Callable
+import dataclasses
+from typing import Any
+
+import torch
+from transformer_lens import HookedTransformer
+from transformer_lens.hook_points import HookPoint
+
+from ares.containers import containers
+from ares.environments.base import TimeStep
+
+
+@dataclasses.dataclass
+class FullyObservableState:
+    timestep: TimeStep | None
+    container: containers.Container | None
+    step_num: int
+
+
+def create_zero_ablation_hook(
+    positions: list[int] | None = None,
+    heads: list[int] | None = None,
+) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]:
+    """Create a hook that zeros out specific positions or attention heads.
+
+    Args:
+        positions: Optional list of token positions to ablate. If None, ablates all.
+        heads: Optional list of attention head indices to ablate (for attention patterns).
+
+    Returns:
+        Hook function that performs zero ablation.
+
+    Example:
+        ```python
+        # Ablate positions 5-10 in layer 0 residual stream
+        hook = create_zero_ablation_hook(positions=list(range(5, 11)))
+        model.run_with_hooks(
+            tokens,
+            fwd_hooks=[("blocks.0.hook_resid_post", hook)]
+        )
+        ```
+    """
+
+    def zero_ablation_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor:
+        ablated = activation.clone()
+
+        if heads is not None:
+            # Attention patterns and attention outputs are both 4D, so use the hook
+            # name to decide which dimension indexes the heads.
+            if hook.name is not None and hook.name.endswith("hook_pattern"):
+                # For attention patterns: [batch, head, query_pos, key_pos]
+                ablated[:, heads, :, :] = 0.0
+            else:
+                # For attention outputs (e.g. hook_z, hook_result): [batch, pos, head_index, d_head]
+                ablated[:, :, heads, :] = 0.0
+        elif positions is not None:
+            # For residual stream or other positional activations
+            ablated[:, positions, :] = 0.0
+        else:
+            # Ablate everything
+            ablated = torch.zeros_like(ablated)
+
+        return ablated
+
+    return zero_ablation_hook
+
+
+def create_path_patching_hook(
+    clean_activation: torch.Tensor,
+    positions: list[int] | None = None,
+) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]:
+    """Create a hook for activation patching (path patching).
+
+    This replaces activations from a corrupted run with those from a clean run,
+    enabling causal analysis of information flow.
+
+    Args:
+        clean_activation: Activation tensor from the clean run to patch in.
+        positions: Optional list of positions to patch. If None, patches all positions.
+
+    Returns:
+        Hook function that performs activation patching.
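+
+    Note:
+        When `positions` is given, `clean_activation` must share batch size and hidden
+        dimension with the activation being patched, and each position index must be
+        valid for both tensors.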
+ + Example: + ```python + # First, run on clean input and cache activations + clean_cache, _ = model.run_with_cache(clean_tokens) + clean_resid = clean_cache["blocks.0.hook_resid_post"] + + # Then run on corrupted input with patching + hook = create_path_patching_hook(clean_resid, positions=[5, 6, 7]) + corrupted_logits = model.run_with_hooks( + corrupted_tokens, + fwd_hooks=[("blocks.0.hook_resid_post", hook)] + ) + ``` + """ + + def path_patching_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 + patched = activation.clone() + + if positions is not None: + patched[:, positions, :] = clean_activation[:, positions, :] + else: + patched = clean_activation.clone() + + return patched + + return path_patching_hook + + +def create_mean_ablation_hook( + mean_activation: torch.Tensor | None = None, + positions: list[int] | None = None, +) -> Callable[[torch.Tensor, HookPoint], torch.Tensor]: + """Create a hook that replaces activations with their mean. + + Mean ablation is often more realistic than zero ablation as it preserves + the scale of activations. + + Args: + mean_activation: Pre-computed mean activation. If None, computes mean on-the-fly. + positions: Optional list of positions to ablate. + + Returns: + Hook function that performs mean ablation. + """ + + def mean_ablation_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: # noqa: ARG001 + ablated = activation.clone() + + # Compute mean across batch and position dimensions if not provided + mean = activation.mean(dim=(0, 1), keepdim=True) if mean_activation is None else mean_activation + + if positions is not None: + ablated[:, positions, :] = mean.expand_as(ablated[:, positions, :]) + else: + ablated = mean.expand_as(ablated) + + return ablated + + return mean_ablation_hook + + +class InterventionManager: + """Manager for applying multiple interventions during agent execution. + + This class helps coordinate multiple hook-based interventions across an agent + trajectory, making it easy to study causal effects of different model components. + + Example: + ```python + model = HookedTransformer.from_pretrained("gpt2-small") + client = HookedTransformerLLMClient(model=model, model_name="gpt2-small") + + # Create intervention manager + manager = InterventionManager(model) + + # Add interventions + manager.add_intervention( + hook_name="blocks.0.attn.hook_pattern", + hook_fn=create_zero_ablation_hook(heads=[0, 1]), + description="Ablate attention heads 0-1 in layer 0" + ) + + # Run with interventions active + with manager: + async with env: + ts = await env.reset() + while not ts.last(): + response = await client(ts.observation) + ts = await env.step(response) + ``` + """ + + def __init__(self, model: HookedTransformer): + """Initialize intervention manager. + + Args: + model: HookedTransformer to apply interventions to. + """ + self.model = model + self.interventions: list[dict[str, Any]] = [] + self._active_handles: list[Any] = [] + + def add_intervention( + self, + hook_name: str, + hook_fn: Callable[[torch.Tensor, HookPoint], torch.Tensor], + description: str = "", + apply_at_steps: list[int] | None = None, + ) -> None: + """Add an intervention to apply during execution. + + Args: + hook_name: Name of the hook point (e.g., "blocks.0.hook_resid_post"). + hook_fn: Hook function to apply. + description: Human-readable description of the intervention. + apply_at_steps: Optional list of step indices when to apply this intervention. + If None, applies at all steps. 
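+
+        Example:
+            A minimal sketch using helpers from this module (assumes `increment_step()`
+            is called between agent steps):
+
+            ```python
+            manager.add_intervention(
+                hook_name="blocks.0.attn.hook_result",
+                hook_fn=create_zero_ablation_hook(heads=[0, 1]),
+                description="Ablate L0 heads 0 and 1",
+                apply_at_steps=[2, 3],  # fires only on the third and fourth agent steps
+            )
+            ```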
+ """ + self.interventions.append( + { + "hook_name": hook_name, + "hook_fn": hook_fn, + "description": description, + "apply_at_steps": apply_at_steps, + "step_count": 0, + } + ) + + def clear_interventions(self) -> None: + """Remove all interventions.""" + self.interventions.clear() + + def __enter__(self) -> "InterventionManager": + """Enter context manager and activate interventions.""" + for intervention in self.interventions: + hook_point = self.model.hook_dict[intervention["hook_name"]] + + # Wrap hook_fn to track step count and properly capture loop variables + def make_wrapped_hook(interv): + # Capture these from the intervention dict to avoid loop variable binding issues + hook_fn = interv["hook_fn"] + steps = interv["apply_at_steps"] + + def wrapped_hook(activation: torch.Tensor, hook: HookPoint) -> torch.Tensor: + # Check if we should apply this intervention at current step + if steps is None or interv["step_count"] in steps: + return hook_fn(activation, hook) + return activation + + return wrapped_hook + + handle = hook_point.add_hook(make_wrapped_hook(intervention)) + self._active_handles.append(handle) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context manager and deactivate interventions.""" + for handle in self._active_handles: + if handle is not None: + handle.remove() + self._active_handles.clear() + + def increment_step(self) -> None: + """Increment the step counter for all interventions. + + Call this between agent steps to track which step you're on. + """ + for intervention in self.interventions: + intervention["step_count"] += 1 + + def get_intervention_summary(self) -> str: + """Get a summary of all active interventions. + + Returns: + Human-readable summary string. + """ + if not self.interventions: + return "No interventions active" + + lines = ["Active interventions:"] + for i, interv in enumerate(self.interventions, 1): + desc = interv["description"] or "No description" + hook = interv["hook_name"] + steps = interv["apply_at_steps"] + step_str = f"steps {steps}" if steps else "all steps" + lines.append(f" {i}. {desc} ({hook}, {step_str})") + + return "\n".join(lines) diff --git a/src/ares/contrib/mech_interp/hooked_transformer_client.py b/src/ares/contrib/mech_interp/hooked_transformer_client.py new file mode 100644 index 0000000..3b00d89 --- /dev/null +++ b/src/ares/contrib/mech_interp/hooked_transformer_client.py @@ -0,0 +1,242 @@ +"""LLM client implementation using TransformerLens HookedTransformer.""" + +from collections.abc import Callable, Sequence +import dataclasses +import inspect +from typing import Any, Protocol, Union, runtime_checkable + +import torch +import transformer_lens + +from ares.contrib.mech_interp import hook_utils +from ares.environments import base as ares_env +from ares.environments import code_env +from ares import llms + + +HookNameFn = Callable[[str], str] + + +@runtime_checkable +class StateIdFn(Protocol): + """A function that returns the `name` param of `add_hook`/`run_with_hooks` if the hook should be applied given + the current state - otherwise `None` if it should not. + """ + async def __call__(self, state: hook_utils.FullyObservableState) -> str | Callable[[str], bool] | None: + ... + + +@dataclasses.dataclass +class HookedTransformerLLMClient: + """LLM client that uses TransformerLens HookedTransformer for inference. + + This client enables mechanistic interpretability research by providing access to + intermediate activations and allowing hook-based interventions during agent execution. 
+ + Args: + model: A TransformerLens HookedTransformer instance. + max_new_tokens: Maximum number of tokens to generate per completion. + generation_kwargs: Additional keyword arguments passed to model.generate(). + format_messages_fn: Optional function to convert chat messages to model input. + If None, uses a simple concatenation of message contents. + + Example: + ```python + from transformer_lens import HookedTransformer + + model = HookedTransformer.from_pretrained("gpt2-small") + client = HookedTransformerLLMClient( + model=model, + max_new_tokens=512 + ) + + # Use with ARES environments + async with env: + ts = await env.reset() + while not ts.last(): + response = await client(ts.observation) + ts = await env.step(response) + ``` + """ + + model: transformer_lens.HookedTransformer + # QUESTION: Is it better to include the "state" object in the name/id fn or in the hook fn? + fwd_hooks: list[tuple[Union[str, HookNameFn, StateIdFn], transformer_lens.hook_points.HookFunction]] | None = None + # TODO: Identify better default max_new_tokens size + max_new_tokens: int = 2048 + generation_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + _format_messages_fn: Callable[[Sequence[Any]], str] | None = dataclasses.field(default=None, repr=False) + + @property + def format_messages_fn(self) -> Callable[[Sequence[Any]], str]: + """Get the message formatting function.""" + if self._format_messages_fn is None: + return self._default_format_messages + return self._format_messages_fn + + @staticmethod + def _default_format_messages(messages: Sequence[Any]) -> str: + """Default message formatter that concatenates all message contents.""" + formatted_parts = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + formatted_parts.append(f"{role.upper()}: {content}") + return "\n\n".join(formatted_parts) + + async def _call_with_hooks( + self, + input_ids: torch.Tensor, + state: hook_utils.FullyObservableState, + **gen_kwargs: Any, + ) -> torch.Tensor: + # self.model.reset_hooks(direction="fwd", including_permanent=False) + + if self.fwd_hooks is not None: + for name_or_id_fn, hook_fn in self.fwd_hooks: + # Check if this is a StateIdFn, or if it is the standard format to run in all states + if inspect.iscoroutinefunction(name_or_id_fn): + name = await name_or_id_fn(state) + else: + name = name_or_id_fn + + self.model.add_hook(name, hook_fn, is_permanent=False) # type: ignore + + with torch.no_grad(): + # TODO: Should we make use of the `__call__(..., return_type="logits | loss")` here instead? + # Generate completion + # Note: HookedTransformer.generate returns full sequence including input + outputs = self.model.generate( + input_ids, + **gen_kwargs, + ) + + # self.model.reset_hooks(direction="fwd", including_permanent=False) + + assert isinstance(outputs, torch.Tensor) # typing + return outputs + + async def __call__( + self, + request: llms.LLMRequest, + env: code_env.CodeEnvironment | None = None, + timestep: ares_env.TimeStep | None = None, + ) -> llms.LLMResponse: + """Generate a completion using the HookedTransformer. + + Args: + request: LLM request containing messages and optional temperature. + + Returns: + LLM response with chat completion and cost information. 
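+
+        Note:
+            `env` and `timestep` are only required when `fwd_hooks` contains state-dependent
+            selectors; they are wrapped into a `FullyObservableState` and passed to those
+            selectors before generation.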
+ """ + # Format messages into text + messages_list = list(request.messages) + input_text = self.format_messages_fn(messages_list) + + # Tokenize input + input_ids = self.model.to_tokens(input_text, prepend_bos=True) + num_input_tokens = input_ids.shape[-1] + + # TODO: Need to support various truncation methods + # Truncate if input + max_new_tokens would exceed model's context window + max_position = self.model.cfg.n_ctx + if num_input_tokens + self.max_new_tokens > max_position: + # Leave room for generation + max_input_tokens = max_position - self.max_new_tokens + input_ids = input_ids[:, :max_input_tokens] + num_input_tokens = input_ids.shape[-1] + + # Prepare generation kwargs + gen_kwargs = { + "max_new_tokens": self.max_new_tokens, + **self.generation_kwargs, + } + + # TODO: This should be more generic - why temperature specifically? + # Add temperature if specified + if request.temperature is not None: + gen_kwargs["temperature"] = request.temperature + + outputs = await self._call_with_hooks( + input_ids, + state=hook_utils.FullyObservableState( + timestep=timestep, + # TODO: Figure out typing here + container=env._container if env is not None else None, + step_num=0 # TODO: How to calculate?, + ), + **gen_kwargs, + ) + + # Extract only the generated tokens + num_output_tokens = outputs.shape[-1] - num_input_tokens + output_ids = outputs[0, num_input_tokens:] + + # Decode output + output_text = self.model.to_string(output_ids) + assert isinstance(output_text, str) # typing + + return llms.LLMResponse( + data=[llms.TextData(content=output_text)], + cost=0.0, # Local inference has no cost + usage=llms.Usage( + prompt_tokens=num_input_tokens, + generated_tokens=num_output_tokens, + ), + ) + + +def create_hooked_transformer_client_with_chat_template( + model: transformer_lens.HookedTransformer, + tokenizer: Any, + max_new_tokens: int = 2048, + **generation_kwargs: Any, +) -> HookedTransformerLLMClient: + """Create a HookedTransformerLLMClient with proper chat template formatting. + + This factory function creates a client that uses the tokenizer's chat template + (like Qwen2.5, Llama, etc.) to properly format messages. + + Args: + model: TransformerLens HookedTransformer instance. + tokenizer: HuggingFace tokenizer with apply_chat_template method. + max_new_tokens: Maximum tokens to generate. + **generation_kwargs: Additional arguments for generation. + + Returns: + Configured HookedTransformerLLMClient instance. + + Example: + ```python + from transformer_lens import HookedTransformer + from transformers import AutoTokenizer + + model = HookedTransformer.from_pretrained("gpt2-small") + tokenizer = AutoTokenizer.from_pretrained("gpt2") + + client = create_hooked_transformer_client_with_chat_template( + model=model, + tokenizer=tokenizer, + max_new_tokens=512, + temperature=0.7, + ) + ``` + """ + + def format_with_chat_template(messages: Sequence[Any]) -> str: + """Format messages using tokenizer's chat template.""" + # Apply chat template without tokenization (just get the text) + formatted = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + ) + return formatted + + return HookedTransformerLLMClient( + model=model, + max_new_tokens=max_new_tokens, + generation_kwargs=generation_kwargs, + _format_messages_fn=format_with_chat_template, + )