
Commit aec5381

committed
cp
1 parent d4caa66 commit aec5381

File tree: 11 files changed, +1287 −9 lines changed

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
"""Diversity injection for RL policy networks.

When reward gradients vanish (stuck in local minima or flat regions), this module
automatically expands exploration of nearby representational variants by injecting
agent-specific random perturbations into the encoder output.

Key insight: when PPO loss → 0 (stuck), the diversity loss term automatically
dominates, pushing α higher and increasing representational spread across agents.
"""

from __future__ import annotations

import torch
import torch.nn as nn
from tensordict import TensorDict

from metta.agent.components.component_config import ComponentConfig


class DiversityInjectionConfig(ComponentConfig):
    """Configuration for diversity injection layer."""

    in_key: str
    out_key: str
    name: str = "diversity_injection"

    # Number of agent slots to support (should match max agents in training)
    num_agents: int = 256

    # Low-rank approximation rank for memory efficiency
    # W = U @ V.T where U, V are (hidden_dim, rank)
    projection_rank: int = 32

    # Initial value for log_alpha (α = exp(log_alpha))
    # -1.0 means α starts at ~0.37
    log_alpha_init: float = -1.0

    # Maximum value for α to prevent explosion
    alpha_max: float = 5.0

    # Whether to apply LayerNorm after injection for stability
    use_layer_norm: bool = True

    # Key in TensorDict containing agent IDs (training_env_ids by default)
    agent_id_key: str = "training_env_ids"

    def make_component(self, env=None) -> nn.Module:
        return DiversityInjection(config=self)


class DiversityInjection(nn.Module):
    """Applies agent-specific random perturbations to encoder output.

    Architecture:
        obs → [shared encoder] → h → h + α * perturbation → [policy_head] → logits
                                                           → [value_head] → value

    Where perturbation = W_rand[agent_id] @ h using low-rank factorization.
    """

    def __init__(self, config: DiversityInjectionConfig):
        super().__init__()
        self.config = config
        self.in_key = config.in_key
        self.out_key = config.out_key
        self.agent_id_key = config.agent_id_key
        self.alpha_max = config.alpha_max

        # Learned scalar controlling perturbation strength
        self.log_alpha = nn.Parameter(torch.tensor(config.log_alpha_init))

        # Lazy initialization - we don't know hidden_dim until first forward
        self._hidden_dim: int | None = None

        # Register placeholder buffers (will be replaced on first forward)
        self.register_buffer("_projection_u", None)
        self.register_buffer("_projection_v", None)

        self.layer_norm: nn.LayerNorm | None = None

    def _initialize_projections(self, hidden_dim: int, device: torch.device, dtype: torch.dtype) -> None:
        """Initialize random projection matrices on first forward pass."""
        if self._hidden_dim == hidden_dim and self._projection_u is not None:
            # Already initialized, just ensure device matches
            if self._projection_u.device != device:
                self._projection_u = self._projection_u.to(device)
                self._projection_v = self._projection_v.to(device)
                if self.layer_norm is not None:
                    self.layer_norm = self.layer_norm.to(device)
            return

        self._hidden_dim = hidden_dim
        rank = self.config.projection_rank
        num_agents = self.config.num_agents

        # Create low-rank factorization: W = U @ V.T
        # Scale by 1/sqrt(rank) for stable initialization
        scale = 1.0 / (rank**0.5)

        # Generate deterministic random projections per agent using seeded generators
        projection_u = torch.zeros(num_agents, hidden_dim, rank, dtype=dtype, device=device)
        projection_v = torch.zeros(num_agents, rank, hidden_dim, dtype=dtype, device=device)

        for agent_idx in range(num_agents):
            gen = torch.Generator()
            gen.manual_seed(agent_idx * 31337)  # Deterministic per-agent seed
            projection_u[agent_idx] = (
                torch.randn(hidden_dim, rank, generator=gen, dtype=dtype, device="cpu").to(device) * scale
            )
            projection_v[agent_idx] = (
                torch.randn(rank, hidden_dim, generator=gen, dtype=dtype, device="cpu").to(device) * scale
            )

        # Update buffers in-place
        self._projection_u = projection_u
        self._projection_v = projection_v

        # Initialize LayerNorm if enabled
        if self.config.use_layer_norm and self.layer_norm is None:
            self.layer_norm = nn.LayerNorm(hidden_dim).to(device)

    @property
    def alpha(self) -> torch.Tensor:
        """Current perturbation strength coefficient."""
        return self.log_alpha.exp().clamp(max=self.alpha_max)

    def forward(self, td: TensorDict) -> TensorDict:
        h = td[self.in_key]  # (batch, hidden_dim) or (batch, time, hidden_dim)

        # Initialize on first forward
        self._initialize_projections(h.shape[-1], h.device, h.dtype)

        # Get agent IDs - handle both (batch,) and (batch, time) shapes
        if self.agent_id_key in td.keys():
            agent_ids = td[self.agent_id_key]
            # Flatten to 1D if needed, take first element per batch item if (batch, time)
            if agent_ids.dim() > 1:
                agent_ids = agent_ids[:, 0] if agent_ids.shape[1] > 0 else agent_ids.squeeze(-1)
            agent_ids = agent_ids.long() % self.config.num_agents
        else:
            # Default to agent 0 if no agent IDs provided (e.g., during eval)
            agent_ids = torch.zeros(h.shape[0], dtype=torch.long, device=h.device)

        # Compute perturbation using low-rank factorization
        # h @ U @ V.T = (h @ U) @ V.T
        original_shape = h.shape
        if h.dim() == 3:
            # (batch, time, hidden) -> (batch * time, hidden)
            batch, time, hidden = h.shape
            h_flat = h.reshape(batch * time, hidden)
            # Expand agent_ids to match flattened batch
            agent_ids = agent_ids.unsqueeze(1).expand(batch, time).reshape(batch * time)
        else:
            h_flat = h
            batch, time = h.shape[0], 1

        # Gather projection matrices for each sample's agent
        # _projection_u: (num_agents, hidden_dim, rank)
        # _projection_v: (num_agents, rank, hidden_dim)
        u = self._projection_u[agent_ids]  # (batch, hidden_dim, rank)
        v = self._projection_v[agent_ids]  # (batch, rank, hidden_dim)

        # Compute perturbation: h @ U @ V.T
        # (batch, hidden) @ (batch, hidden, rank) -> (batch, rank)
        intermediate = torch.einsum("bh,bhr->br", h_flat, u)
        # (batch, rank) @ (batch, rank, hidden) -> (batch, hidden)
        perturbation = torch.einsum("br,brh->bh", intermediate, v)

        # Apply perturbation with learned coefficient
        alpha = self.alpha
        h_div = h_flat + alpha * perturbation

        # Apply LayerNorm for stability when α is large
        if self.layer_norm is not None:
            h_div = self.layer_norm(h_div)

        # Reshape back if needed
        if len(original_shape) == 3:
            h_div = h_div.reshape(original_shape)

        td[self.out_key] = h_div

        return td

    def get_diversity_loss(self) -> torch.Tensor:
        """Return diversity loss term: -log_alpha.

        This encourages α to grow when other losses are small.
        """
        return -self.log_alpha

    def extra_repr(self) -> str:
        return (
            f"in_key={self.in_key}, out_key={self.out_key}, "
            f"num_agents={self.config.num_agents}, rank={self.config.projection_rank}, "
            f"alpha_max={self.alpha_max}, use_layer_norm={self.config.use_layer_norm}"
        )
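
Note (not part of the commit): the component is a TensorDict-in, TensorDict-out layer, so it can be smoke-tested in isolation. A minimal sketch, assuming DiversityInjectionConfig accepts keyword construction like a pydantic model and that torch and tensordict are installed; shapes and key names are illustrative.

    import torch
    from tensordict import TensorDict

    # Small sizes for a quick check; in_key/out_key mirror how the ViT config wires it.
    cfg = DiversityInjectionConfig(in_key="core", out_key="core", num_agents=8, projection_rank=4)
    layer = cfg.make_component()

    td = TensorDict(
        {
            "core": torch.randn(6, 128),          # (batch, hidden_dim)
            "training_env_ids": torch.arange(6),  # one agent id per sample
        },
        batch_size=[6],
    )
    out = layer(td)
    assert out["core"].shape == (6, 128)          # shape is preserved; values are perturbed per agent
    print(layer.alpha)                            # starts near exp(-1) ≈ 0.37, clamped at alpha_max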

agent/src/metta/agent/policies/vit.py

Lines changed: 20 additions & 0 deletions
@@ -10,6 +10,7 @@
 from metta.agent.components.actor import ActionProbsConfig, ActorHeadConfig
 from metta.agent.components.component_config import ComponentConfig
 from metta.agent.components.cortex import CortexTDConfig
+from metta.agent.components.diversity_injection import DiversityInjectionConfig
 from metta.agent.components.misc import MLPConfig
 from metta.agent.components.obs_enc import ObsPerceiverLatentConfig
 from metta.agent.components.obs_shim import ObsShimTokensConfig
@@ -168,6 +169,10 @@ class ViTDefaultConfig(PolicyArchitecture):
     # Whether to torch.compile the trunk (Cortex stack)
     core_compile: bool = False

+    # Diversity injection - auto-expands exploration when gradients vanish
+    # Enable with losses.diversity.enabled=True losses.diversity.diversity_coef=0.01
+    use_diversity_injection: bool = False
+
     components: List[ComponentConfig] = [
         ObsShimTokensConfig(in_key="env_obs", out_key="obs_shim_tokens", max_tokens=48),
         ObsAttrEmbedFourierConfig(
@@ -233,6 +238,21 @@ def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
             compile_blocks=self.core_compile,
         )

+        # Conditionally add diversity injection after Cortex
+        if self.use_diversity_injection:
+            # Find Cortex index and insert diversity injection after it
+            cortex_idx = next(i for i, c in enumerate(self.components) if isinstance(c, CortexTDConfig))
+            # Check if already inserted
+            if not any(isinstance(c, DiversityInjectionConfig) for c in self.components):
+                self.components.insert(
+                    cortex_idx + 1,
+                    DiversityInjectionConfig(
+                        in_key="core",
+                        out_key="core",  # in-place replacement
+                        name="diversity_injection",
+                    ),
+                )
+
         AgentClass = load_symbol(self.class_path)
         if not isinstance(AgentClass, type):
             raise TypeError(f"Loaded symbol {self.class_path} is not a class")
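
Note (not part of the commit): a hypothetical sketch of opting in from code. The use_diversity_injection field is the one added above; the construction style and the policy_env_info argument are assumed to follow the repo's usual pattern. Per the comment in the diff, training runs are expected to also enable losses.diversity.enabled=True with losses.diversity.diversity_coef=0.01.

    # policy_env_info is assumed to be an existing PolicyEnvInterface instance.
    arch = ViTDefaultConfig(use_diversity_injection=True)
    policy = arch.make_policy(policy_env_info)  # DiversityInjection is inserted right after the Cortex block, rewriting "core"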

metta/rl/loss/cmpo.py

Lines changed: 18 additions & 8 deletions
@@ -109,14 +109,24 @@ def update_target_network(self) -> None:
             return

         with torch.no_grad():
-            for target_param, online_param in zip(
-                self.target_policy.parameters(),
-                self.policy.parameters(),
-                strict=False,
-            ):
-                target_param.data = (
-                    self.cfg.target_ema_decay * target_param.data + (1 - self.cfg.target_ema_decay) * online_param.data
-                )
+            target_state = self.target_policy.state_dict()
+            online_state = self.policy.state_dict()
+
+            for name, online_param in online_state.items():
+                if name in target_state:
+                    target_param = target_state[name]
+                    if target_param.shape == online_param.shape:
+                        target_state[name] = (
+                            self.cfg.target_ema_decay * target_param + (1 - self.cfg.target_ema_decay) * online_param
+                        )
+                    else:
+                        # Shape mismatch (e.g., lazy init resize) - copy directly
+                        target_state[name] = online_param.clone()
+                else:
+                    # New parameter - add it
+                    target_state[name] = online_param.clone()
+
+            self.target_policy.load_state_dict(target_state)

     def compute_cmpo_policy(
         self,
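
Note (not part of the commit): switching from zipping parameters() to merging state_dict()s means buffers and lazily created entries (such as the diversity projection buffers above) are carried into the target network as well, and shape changes after lazy initialization no longer break the EMA update. A self-contained sketch of the same pattern, outside the repo's classes:

    import torch
    import torch.nn as nn

    def ema_update(target: nn.Module, online: nn.Module, decay: float = 0.995) -> None:
        """Blend online weights into target; copy entries that are new or changed shape."""
        with torch.no_grad():
            target_state = target.state_dict()
            for name, online_param in online.state_dict().items():
                current = target_state.get(name)
                if current is not None and current.shape == online_param.shape:
                    target_state[name] = decay * current + (1 - decay) * online_param
                else:
                    target_state[name] = online_param.clone()
            target.load_state_dict(target_state)

    online, target = nn.Linear(4, 4), nn.Linear(4, 4)
    ema_update(target, online)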

metta/rl/loss/diversity.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
"""Diversity loss for encouraging representational exploration.

This loss works in conjunction with DiversityInjection component to automatically
increase exploration when policy gradients vanish. The key insight is that when
PPO loss → 0 (stuck in local minima), the diversity loss term dominates, pushing
α higher and increasing representational spread across agents.
"""

from typing import Any

import torch
from pydantic import Field
from tensordict import TensorDict
from torch import Tensor
from torchrl.data import Composite

from metta.agent.policy import Policy
from metta.rl.loss.loss import Loss, LossConfig
from metta.rl.training import ComponentContext, TrainingEnvironment


class DiversityLossConfig(LossConfig):
    """Configuration for diversity loss."""

    # Coefficient for diversity loss term (-log_alpha)
    # Start small (~0.01) and tune as needed
    diversity_coef: float = Field(default=0.01, ge=0)

    # Name of the DiversityInjection component in the policy
    # Used to find the log_alpha parameter
    diversity_component_name: str = "diversity_injection"

    def create(
        self,
        policy: Policy,
        trainer_cfg: Any,
        env: TrainingEnvironment,
        device: torch.device,
        instance_name: str,
    ) -> "DiversityLoss":
        return DiversityLoss(policy, trainer_cfg, env, device, instance_name, self)


class DiversityLoss(Loss):
    """Diversity loss that encourages exploration when policy gradients vanish.

    Loss = -diversity_coef * log_alpha

    When α is small (low diversity), log_alpha is negative, so -log_alpha is positive
    and this loss encourages α to grow. When PPO loss is meaningful, its gradients
    dominate and α stays controlled. When stuck (PPO loss ≈ 0), diversity loss
    dominates and α grows, increasing representational spread.
    """

    def __init__(
        self,
        policy: Policy,
        trainer_cfg: Any,
        env: TrainingEnvironment,
        device: torch.device,
        instance_name: str,
        cfg: DiversityLossConfig,
    ):
        super().__init__(policy, trainer_cfg, env, device, instance_name, cfg)
        self._diversity_component = None
        self._find_diversity_component()

    def _find_diversity_component(self) -> None:
        """Find the DiversityInjection component in the policy."""
        if hasattr(self.policy, "components"):
            component_name = self.cfg.diversity_component_name
            if component_name in self.policy.components:
                self._diversity_component = self.policy.components[component_name]
            else:
                # Try to find any DiversityInjection component
                from metta.agent.components.diversity_injection import DiversityInjection

                for _, component in self.policy.components.items():
                    if isinstance(component, DiversityInjection):
                        self._diversity_component = component
                        break

    def get_experience_spec(self) -> Composite:
        """Diversity loss doesn't require additional experience fields."""
        return Composite()

    def run_rollout(self, td: TensorDict, context: ComponentContext) -> None:
        """No-op during rollout - diversity loss only affects training."""
        pass

    def run_train(
        self, shared_loss_data: TensorDict, context: ComponentContext, mb_idx: int
    ) -> tuple[Tensor, TensorDict, bool]:
        """Compute diversity loss from the DiversityInjection component."""
        if self._diversity_component is None:
            # No diversity component found, return zero loss
            zero_loss = torch.tensor(0.0, device=self.device, requires_grad=True)
            return zero_loss, shared_loss_data, False

        # Get diversity loss from component
        diversity_loss = self._diversity_component.get_diversity_loss()
        weighted_loss = self.cfg.diversity_coef * diversity_loss

        # Track metrics
        alpha = self._diversity_component.alpha
        self._track("diversity_loss", weighted_loss)
        self._track("diversity_alpha", alpha)
        self._track("diversity_log_alpha", self._diversity_component.log_alpha)

        return weighted_loss, shared_loss_data, False

    def _track(self, key: str, value: Tensor) -> None:
        """Track a metric value."""
        if value.numel() == 1:
            self.loss_tracker[key].append(float(value.item()))
        else:
            self.loss_tracker[key].append(float(value.mean().item()))
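
Note (not part of the commit): a quick standalone check of the dynamics described in the docstring. With only the diversity term active, log_alpha receives a constant gradient of -diversity_coef, so a plain SGD step raises it by lr * diversity_coef regardless of its current value, and α = exp(log_alpha) grows geometrically until the component clamps it at alpha_max; any other loss with remaining signal pushes back through the same parameter via the perturbed representation.

    import torch

    log_alpha = torch.nn.Parameter(torch.tensor(-1.0))
    opt = torch.optim.SGD([log_alpha], lr=0.1)
    diversity_coef = 0.01

    for step in range(3):
        loss = -diversity_coef * log_alpha   # same form as DiversityLoss.run_train
        opt.zero_grad()
        loss.backward()
        opt.step()                           # log_alpha grows by lr * diversity_coef = 0.001 per step
        print(step, round(log_alpha.item(), 4), round(log_alpha.exp().item(), 4))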

metta/rl/loss/dynamics.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@

 class DynamicsConfig(LossConfig):
     returns_step_look_ahead: int = Field(default=1)
-    unroll_steps: int = Field(default=0)
+    unroll_steps: int = Field(default=2)
     returns_pred_coef: float = Field(default=1.0, ge=0, le=1.0)
     reward_pred_coef: float = Field(default=1.0, ge=0, le=1.0)
