
Commit 867b705

Commit message: cp
1 parent: a0936ef

File tree: 11 files changed, +1553 −18 lines


agent/src/metta/agent/policies/fast_dynamics.py

Lines changed: 1 addition & 3 deletions
@@ -26,7 +26,7 @@
 
 def forward(self, td: TensorDict, action: torch.Tensor = None) -> TensorDict:
     """Forward pass for the FastDynamics policy."""
-    self.network(td)
+    self.network()(td)
     self.action_probs(td, action)
 
     td["pred_input"] = torch.cat([td["core"], td["logits"]], dim=-1)
@@ -39,8 +39,6 @@ def forward(self, td: TensorDict, action: torch.Tensor = None) -> TensorDict:
 class FastDynamicsConfig(PolicyArchitecture):
     class_path: str = "metta.agent.policy_auto_builder.PolicyAutoBuilder"
 
-    class_path: str = "metta.agent.policy_auto_builder.PolicyAutoBuilder"
-
     _latent_dim = 64
     _token_embed_dim = 8
     _fourier_freqs = 3
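
The FastDynamics forward above threads a single TensorDict through the policy: the network writes "core" and "logits" into it, and the dynamics-head input is just their concatenation under "pred_input". A minimal, self-contained sketch of that pattern (the 4-sample batch, 64-dim core, and 9-way logits are illustrative values, not the real config):

import torch
from tensordict import TensorDict

# Toy TensorDict standing in for the one the policy components populate.
td = TensorDict({"core": torch.randn(4, 64), "logits": torch.randn(4, 9)}, batch_size=[4])

# Same concatenation as in the diff: dynamics heads read core features + policy logits.
td["pred_input"] = torch.cat([td["core"], td["logits"]], dim=-1)
print(td["pred_input"].shape)  # torch.Size([4, 73])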

agent/src/metta/agent/policies/vit.py

Lines changed: 182 additions & 2 deletions
@@ -1,6 +1,11 @@
+import types
 from typing import List
 
+import torch
 from cortex.stacks import build_cortex_auto_config
+from tensordict import TensorDict
+from tensordict.nn import TensorDictModule as TDM
+from torch import nn
 
 from metta.agent.components.actor import ActionProbsConfig, ActorHeadConfig
 from metta.agent.components.component_config import ComponentConfig
@@ -9,7 +14,128 @@
 from metta.agent.components.obs_enc import ObsPerceiverLatentConfig
 from metta.agent.components.obs_shim import ObsShimTokensConfig
 from metta.agent.components.obs_tokenizers import ObsAttrEmbedFourierConfig
-from metta.agent.policy import PolicyArchitecture
+from metta.agent.policy import Policy, PolicyArchitecture
+from mettagrid.policy.policy_env_interface import PolicyEnvInterface
+from mettagrid.util.module import load_symbol
+
+
+def forward(self, td: TensorDict, action: torch.Tensor = None) -> TensorDict:
+    """Forward pass for the ViT policy with dynamics heads."""
+    self.network()(td)
+    self.action_probs(td, action)
+
+    if "values" in td.keys():
+        td["values"] = td["values"].flatten()
+
+    # Dynamics/Muesli predictions - only if modules were created by Dynamics loss
+    if hasattr(self, "returns_pred") and self.returns_pred is not None:
+        td["pred_input"] = torch.cat([td["core"], td["logits"]], dim=-1)
+        self.returns_pred(td)
+        self.reward_pred(td)
+
+        # K-step dynamics unrolling for Muesli
+        if action is not None and self.unroll_steps > 0:
+            _compute_unrolled_predictions(self, td, action)
+
+    return td
+
+
+def _compute_unrolled_predictions(self, td: TensorDict, actions: torch.Tensor) -> None:
+    """Compute K-step unrolled predictions and write to TensorDict.
+
+    Output shapes are (B*T, K, ...) to match TensorDict batch dimension.
+    Only the first T_eff positions have valid predictions; rest are zero-padded.
+    """
+    K = self.unroll_steps
+    hidden = td["core"]  # (B*T, H)
+
+    B = int(td["batch"][0].item())
+    T = int(td["bptt"][0].item())
+    BT = B * T
+
+    if T <= K:
+        return  # Not enough timesteps
+
+    T_eff = T - K
+
+    # Reshape to (B, T, ...) for temporal indexing
+    hidden_bt = hidden.view(B, T, -1)
+    actions_bt = actions.view(B, T)
+    current_h = hidden_bt[:, :T_eff]  # (B, T_eff, H)
+
+    unrolled_logits_list: list[torch.Tensor] = []
+    unrolled_rewards_list: list[torch.Tensor] = []
+    unrolled_returns_list: list[torch.Tensor] = []
+
+    for k in range(K):
+        step_actions = actions_bt[:, k : k + T_eff]
+
+        # Dynamics: (h_k, a_k) -> (h_{k+1}, r_k)
+        next_h, r_pred_k = _dynamics_step(self, current_h, step_actions)
+
+        # Prediction: h_{k+1} -> (pi_{k+1}, v_{k+1})
+        p_logits_k, v_pred_k = _prediction_step(self, next_h)
+
+        unrolled_logits_list.append(p_logits_k)  # (B, T_eff, A)
+        unrolled_rewards_list.append(r_pred_k.squeeze(-1))  # (B, T_eff)
+        unrolled_returns_list.append(v_pred_k.squeeze(-1))  # (B, T_eff)
+        current_h = next_h
+
+    # Stack along K dimension: (B, T_eff, K, ...)
+    stacked_logits = torch.stack(unrolled_logits_list, dim=2)  # (B, T_eff, K, A)
+    stacked_rewards = torch.stack(unrolled_rewards_list, dim=2)  # (B, T_eff, K)
+    stacked_returns = torch.stack(unrolled_returns_list, dim=2)  # (B, T_eff, K)
+
+    # Pad T_eff to T so we can reshape to (B*T, K, ...)
+    A = stacked_logits.shape[-1]
+    padded_logits = torch.zeros(B, T, K, A, device=hidden.device, dtype=hidden.dtype)
+    padded_rewards = torch.zeros(B, T, K, device=hidden.device, dtype=hidden.dtype)
+    padded_returns = torch.zeros(B, T, K, device=hidden.device, dtype=hidden.dtype)
+
+    padded_logits[:, :T_eff] = stacked_logits
+    padded_rewards[:, :T_eff] = stacked_rewards
+    padded_returns[:, :T_eff] = stacked_returns
+
+    # Reshape to (B*T, K, ...) to match TensorDict batch dimension
+    td["unrolled_logits"] = padded_logits.view(BT, K, A)
+    td["unrolled_rewards"] = padded_rewards.view(BT, K)
+    td["unrolled_returns"] = padded_returns.view(BT, K)
+
+
+def _dynamics_step(self, hidden: torch.Tensor, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Dynamics function: (h_t, a_t) -> (h_{t+1}, r_t)"""
+    if action.dim() > 1 and action.shape[-1] == 1:
+        action = action.squeeze(-1)
+
+    # One-hot encode actions
+    if action.dtype in (torch.long, torch.int32, torch.int64):
+        action_emb = torch.nn.functional.one_hot(action.long(), num_classes=self.num_actions).float()
+    else:
+        action_emb = action
+
+    dyn_input = torch.cat([hidden, action_emb], dim=-1)
+    output = self.dynamics_model(dyn_input)
+
+    next_hidden = output[..., :-1]
+    reward = output[..., -1:]
+
+    return next_hidden, reward
+
+
+def _prediction_step(self, hidden: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Prediction function: h_t -> (pi_logits_t, v_t)"""
+    td = TensorDict({"core": hidden}, batch_size=hidden.shape[0], device=hidden.device)
+
+    self.components["actor_mlp"](td)
+    self.components["actor_head"](td)
+    logits = td["logits"]
+
+    pred_input = torch.cat([hidden, logits], dim=-1)
+    td["pred_input"] = pred_input
+    self.returns_pred(td)
+    returns = td["returns_pred"]
+
+    return logits, returns
 
 
 class ViTDefaultConfig(PolicyArchitecture):
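
To make the bookkeeping in _compute_unrolled_predictions easier to follow, here is a stand-alone shape walk-through of the stack/pad/flatten steps. The sizes (B=2 envs, T=5 BPTT steps, K=2 unroll steps, H=4, A=3) are toy values chosen only for illustration:

import torch

B, T, K, H, A = 2, 5, 2, 4, 3
T_eff = T - K                         # only the first T_eff positions can be unrolled K steps
hidden_bt = torch.randn(B, T, H)      # td["core"] viewed as (B, T, H)
current_h = hidden_bt[:, :T_eff]      # (B, T_eff, H) starting states for the unroll

# Pretend each of the K steps produced per-position return predictions of shape (B, T_eff),
# then stack them along a new K dimension, as the loop above does.
stacked_returns = torch.stack([torch.randn(B, T_eff) for _ in range(K)], dim=2)  # (B, T_eff, K)

# Zero-pad T_eff back to T and flatten so the result matches the (B*T, ...) batch dimension.
padded_returns = torch.zeros(B, T, K)
padded_returns[:, :T_eff] = stacked_returns
print(padded_returns.view(B * T, K).shape)  # torch.Size([10, 2])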
@@ -26,6 +152,9 @@ class ViTDefaultConfig(PolicyArchitecture):
     pass_state_during_training: bool = False
     _critic_hidden = 512
 
+    # Dynamics/Muesli unroll steps - set > 0 to enable dynamics modules
+    unroll_steps: int = 0
+
     components: List[ComponentConfig] = [
         ObsShimTokensConfig(in_key="env_obs", out_key="obs_shim_tokens", max_tokens=48),
         ObsAttrEmbedFourierConfig(
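
The unroll_steps field added here defaults to 0, so the dynamics/Muesli machinery stays off unless a config opts in. A hypothetical way to turn it on (the attribute-override style is an assumption; what the diff does establish is that make_policy, in the next hunk, only builds the dynamics modules when the value is positive):

config = ViTDefaultConfig()
config.unroll_steps = 3  # > 0: make_policy attaches dynamics_model, returns_pred and reward_pred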
@@ -73,7 +202,58 @@ class ViTDefaultConfig(PolicyArchitecture):
             out_features=1,
             hidden_features=[_critic_hidden],
         ),
-        ActorHeadConfig(in_key="actor_hidden", out_key="logits", input_dim=_actor_hidden),
+        ActorHeadConfig(in_key="actor_hidden", out_key="logits", input_dim=_actor_hidden, name="actor_head"),
     ]
 
     action_probs_config: ActionProbsConfig = ActionProbsConfig(in_key="logits")
+
+    def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
+        # Ensure downstream components match core dimension
+        # (self._latent_dim might have been overridden on the instance without updating all components)
+        cortex = next(c for c in self.components if isinstance(c, CortexTDConfig))
+        core_dim = cortex.out_features or cortex.d_hidden
+
+        actor_mlp = next(c for c in self.components if c.name == "actor_mlp")
+        assert isinstance(actor_mlp, MLPConfig)
+        if actor_mlp.in_features != core_dim:
+            actor_mlp.in_features = core_dim
+
+        critic = next(c for c in self.components if c.name == "critic")
+        assert isinstance(critic, MLPConfig)
+        if critic.in_features != core_dim:
+            critic.in_features = core_dim
+
+        AgentClass = load_symbol(self.class_path)
+        policy = AgentClass(policy_env_info, self)
+        policy.num_actions = policy_env_info.action_space.n
+        policy.unroll_steps = self.unroll_steps
+
+        # Only create dynamics modules if unroll_steps > 0
+        if self.unroll_steps > 0:
+            latent_dim = core_dim
+            num_actions = policy.num_actions
+
+            # Dynamics Model: (Hidden + Action) -> (Hidden + Reward)
+            dyn_input_dim = latent_dim + num_actions
+            dyn_output_dim = latent_dim + 1
+
+            dynamics_net = nn.Sequential(
+                nn.Linear(dyn_input_dim, 256),
+                nn.SiLU(),
+                nn.Linear(256, dyn_output_dim),
+            )
+            policy.dynamics_model = dynamics_net
+
+            # Returns/Reward Prediction Heads (for Muesli)
+            pred_input_dim = latent_dim + num_actions
+
+            returns_module = nn.Linear(pred_input_dim, 1)
+            reward_module = nn.Linear(pred_input_dim, 1)
+
+            policy.returns_pred = TDM(returns_module, in_keys=["pred_input"], out_keys=["returns_pred"])
+            policy.reward_pred = TDM(reward_module, in_keys=["pred_input"], out_keys=["reward_pred"])
+
+        # Attach methods
+        policy.forward = types.MethodType(forward, policy)
+
+        return policy
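
make_policy wraps the plain nn.Linear heads in tensordict's TensorDictModule (imported above as TDM), which is what lets forward call self.returns_pred(td) and have the result land under a key. A small self-contained sketch of that wrapping, with an assumed 73-dim pred_input rather than the real latent_dim + num_actions:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as TDM

# Wrap a plain linear head so it reads "pred_input" and writes "returns_pred" in place.
returns_pred = TDM(nn.Linear(73, 1), in_keys=["pred_input"], out_keys=["returns_pred"])

td = TensorDict({"pred_input": torch.randn(4, 73)}, batch_size=[4])
returns_pred(td)
print(td["returns_pred"].shape)  # torch.Size([4, 1])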
