+ import types
from typing import List

+ import torch
from cortex.stacks import build_cortex_auto_config
+ from tensordict import TensorDict
+ from tensordict.nn import TensorDictModule as TDM
+ from torch import nn

from metta.agent.components.actor import ActionProbsConfig, ActorHeadConfig
from metta.agent.components.component_config import ComponentConfig
from metta.agent.components.obs_enc import ObsPerceiverLatentConfig
from metta.agent.components.obs_shim import ObsShimTokensConfig
from metta.agent.components.obs_tokenizers import ObsAttrEmbedFourierConfig
- from metta.agent.policy import PolicyArchitecture
+ from metta.agent.policy import Policy, PolicyArchitecture
+ from mettagrid.policy.policy_env_interface import PolicyEnvInterface
+ from mettagrid.util.module import load_symbol
+
+
+ def forward(self, td: TensorDict, action: torch.Tensor | None = None) -> TensorDict:
+     """Forward pass for the ViT policy with dynamics heads."""
+     self.network(td)
+     self.action_probs(td, action)
+
+     if "values" in td.keys():
+         td["values"] = td["values"].flatten()
+
+     # Dynamics/Muesli predictions - only if modules were created by Dynamics loss
+     if hasattr(self, "returns_pred") and self.returns_pred is not None:
+         td["pred_input"] = torch.cat([td["core"], td["logits"]], dim=-1)
+         self.returns_pred(td)
+         self.reward_pred(td)
+
+     # K-step dynamics unrolling for Muesli
+     if action is not None and self.unroll_steps > 0:
+         _compute_unrolled_predictions(self, td, action)
+
+     return td
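# Illustrative sketch (not part of the diff): "pred_input" is the core hidden state
# concatenated with the policy logits, so its width is latent_dim + num_actions,
# the same in_features used for the returns/reward heads in make_policy below.
# The sizes here (latent_dim=128, num_actions=5, batch of 4) are made up.
import torch

core, logits = torch.randn(4, 128), torch.randn(4, 5)
pred_input = torch.cat([core, logits], dim=-1)
assert pred_input.shape == (4, 128 + 5)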
+
+
+ def _compute_unrolled_predictions(self, td: TensorDict, actions: torch.Tensor) -> None:
+     """Compute K-step unrolled predictions and write to TensorDict.
+
+     Output shapes are (B*T, K, ...) to match TensorDict batch dimension.
+     Only the first T_eff positions have valid predictions; rest are zero-padded.
+     """
+     K = self.unroll_steps
+     hidden = td["core"]  # (B*T, H)
+
+     B = int(td["batch"][0].item())
+     T = int(td["bptt"][0].item())
+     BT = B * T
+
+     if T <= K:
+         return  # Not enough timesteps
+
+     T_eff = T - K
+
+     # Reshape to (B, T, ...) for temporal indexing
+     hidden_bt = hidden.view(B, T, -1)
+     actions_bt = actions.view(B, T)
+     current_h = hidden_bt[:, :T_eff]  # (B, T_eff, H)
+
+     unrolled_logits_list: list[torch.Tensor] = []
+     unrolled_rewards_list: list[torch.Tensor] = []
+     unrolled_returns_list: list[torch.Tensor] = []
+
+     for k in range(K):
+         step_actions = actions_bt[:, k : k + T_eff]
+
+         # Dynamics: (h_k, a_k) -> (h_{k+1}, r_k)
+         next_h, r_pred_k = _dynamics_step(self, current_h, step_actions)
+
+         # Prediction: h_{k+1} -> (pi_{k+1}, v_{k+1})
+         p_logits_k, v_pred_k = _prediction_step(self, next_h)
+
+         unrolled_logits_list.append(p_logits_k)  # (B, T_eff, A)
+         unrolled_rewards_list.append(r_pred_k.squeeze(-1))  # (B, T_eff)
+         unrolled_returns_list.append(v_pred_k.squeeze(-1))  # (B, T_eff)
+         current_h = next_h
+
+     # Stack along K dimension: (B, T_eff, K, ...)
+     stacked_logits = torch.stack(unrolled_logits_list, dim=2)  # (B, T_eff, K, A)
+     stacked_rewards = torch.stack(unrolled_rewards_list, dim=2)  # (B, T_eff, K)
+     stacked_returns = torch.stack(unrolled_returns_list, dim=2)  # (B, T_eff, K)
+
+     # Pad T_eff to T so we can reshape to (B*T, K, ...)
+     A = stacked_logits.shape[-1]
+     padded_logits = torch.zeros(B, T, K, A, device=hidden.device, dtype=hidden.dtype)
+     padded_rewards = torch.zeros(B, T, K, device=hidden.device, dtype=hidden.dtype)
+     padded_returns = torch.zeros(B, T, K, device=hidden.device, dtype=hidden.dtype)
+
+     padded_logits[:, :T_eff] = stacked_logits
+     padded_rewards[:, :T_eff] = stacked_rewards
+     padded_returns[:, :T_eff] = stacked_returns
+
+     # Reshape to (B*T, K, ...) to match TensorDict batch dimension
+     td["unrolled_logits"] = padded_logits.view(BT, K, A)
+     td["unrolled_rewards"] = padded_rewards.view(BT, K)
+     td["unrolled_returns"] = padded_returns.view(BT, K)
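# Illustrative sketch (not part of the diff): the zero-padding and reshape above with
# made-up sizes B=2, T=6, K=2 (so T_eff=4) and random stand-in predictions. It only
# demonstrates the (B, T_eff, K, ...) -> (B*T, K, ...) bookkeeping.
import torch

B, T, K, H = 2, 6, 2, 8
T_eff = T - K
stacked = torch.randn(B, T_eff, K, H)   # valid per-step predictions
padded = torch.zeros(B, T, K, H)        # last K timesteps stay zero-padded
padded[:, :T_eff] = stacked
flat = padded.view(B * T, K, H)         # matches the TensorDict batch dimension
assert torch.equal(flat.view(B, T, K, H)[:, :T_eff], stacked)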
+
+
+ def _dynamics_step(self, hidden: torch.Tensor, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     """Dynamics function: (h_t, a_t) -> (h_{t+1}, r_t)"""
+     if action.dim() > 1 and action.shape[-1] == 1:
+         action = action.squeeze(-1)
+
+     # One-hot encode actions
+     if action.dtype in (torch.long, torch.int32, torch.int64):
+         action_emb = torch.nn.functional.one_hot(action.long(), num_classes=self.num_actions).float()
+     else:
+         action_emb = action
+
+     dyn_input = torch.cat([hidden, action_emb], dim=-1)
+     output = self.dynamics_model(dyn_input)
+
+     next_hidden = output[..., :-1]
+     reward = output[..., -1:]
+
+     return next_hidden, reward
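# Illustrative sketch (not part of the diff): a standalone dynamics step with made-up
# sizes (latent_dim=16, num_actions=5, batch of 3). The last output unit is read as the
# reward and the remaining units as the next hidden state, mirroring the slicing above.
import torch
from torch import nn

latent_dim, num_actions = 16, 5
dynamics_model = nn.Sequential(
    nn.Linear(latent_dim + num_actions, 32),
    nn.SiLU(),
    nn.Linear(32, latent_dim + 1),
)
hidden = torch.randn(3, latent_dim)
action = torch.randint(0, num_actions, (3,))
action_emb = nn.functional.one_hot(action, num_classes=num_actions).float()
out = dynamics_model(torch.cat([hidden, action_emb], dim=-1))
next_hidden, reward = out[..., :-1], out[..., -1:]
assert next_hidden.shape == (3, latent_dim) and reward.shape == (3, 1)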
+
+
+ def _prediction_step(self, hidden: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     """Prediction function: h_t -> (pi_logits_t, v_t)"""
+     td = TensorDict({"core": hidden}, batch_size=hidden.shape[0], device=hidden.device)
+
+     self.components["actor_mlp"](td)
+     self.components["actor_head"](td)
+     logits = td["logits"]
+
+     pred_input = torch.cat([hidden, logits], dim=-1)
+     td["pred_input"] = pred_input
+     self.returns_pred(td)
+     returns = td["returns_pred"]
+
+     return logits, returns
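# Illustrative sketch (not part of the diff): how a TensorDictModule-wrapped linear head
# like `returns_pred` reads "pred_input" and writes "returns_pred". Sizes are made up.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as TDM
from torch import nn

pred_input_dim = 16 + 5  # latent_dim + num_actions
returns_pred = TDM(nn.Linear(pred_input_dim, 1), in_keys=["pred_input"], out_keys=["returns_pred"])
td = TensorDict({"pred_input": torch.randn(4, pred_input_dim)}, batch_size=[4])
returns_pred(td)
assert td["returns_pred"].shape == (4, 1)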


class ViTDefaultConfig(PolicyArchitecture):
@@ -26,6 +152,9 @@ class ViTDefaultConfig(PolicyArchitecture):
    pass_state_during_training: bool = False
    _critic_hidden = 512

+     # Dynamics/Muesli unroll steps - set > 0 to enable dynamics modules
+     unroll_steps: int = 0
+
    components: List[ComponentConfig] = [
        ObsShimTokensConfig(in_key="env_obs", out_key="obs_shim_tokens", max_tokens=48),
        ObsAttrEmbedFourierConfig(
@@ -73,7 +202,58 @@ class ViTDefaultConfig(PolicyArchitecture):
            out_features=1,
            hidden_features=[_critic_hidden],
        ),
-         ActorHeadConfig(in_key="actor_hidden", out_key="logits", input_dim=_actor_hidden),
+         ActorHeadConfig(in_key="actor_hidden", out_key="logits", input_dim=_actor_hidden, name="actor_head"),
    ]

    action_probs_config: ActionProbsConfig = ActionProbsConfig(in_key="logits")
+
+     def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
+         # Ensure downstream components match core dimension
+         # (self._latent_dim might have been overridden on the instance without updating all components)
+         cortex = next(c for c in self.components if isinstance(c, CortexTDConfig))
+         core_dim = cortex.out_features or cortex.d_hidden
+
+         actor_mlp = next(c for c in self.components if c.name == "actor_mlp")
+         assert isinstance(actor_mlp, MLPConfig)
+         if actor_mlp.in_features != core_dim:
+             actor_mlp.in_features = core_dim
+
+         critic = next(c for c in self.components if c.name == "critic")
+         assert isinstance(critic, MLPConfig)
+         if critic.in_features != core_dim:
+             critic.in_features = core_dim
+
+         AgentClass = load_symbol(self.class_path)
+         policy = AgentClass(policy_env_info, self)
+         policy.num_actions = policy_env_info.action_space.n
+         policy.unroll_steps = self.unroll_steps
+
+         # Only create dynamics modules if unroll_steps > 0
+         if self.unroll_steps > 0:
+             latent_dim = core_dim
+             num_actions = policy.num_actions
+
+             # Dynamics Model: (Hidden + Action) -> (Hidden + Reward)
+             dyn_input_dim = latent_dim + num_actions
+             dyn_output_dim = latent_dim + 1
+
+             dynamics_net = nn.Sequential(
+                 nn.Linear(dyn_input_dim, 256),
+                 nn.SiLU(),
+                 nn.Linear(256, dyn_output_dim),
+             )
+             policy.dynamics_model = dynamics_net
+
+             # Returns/Reward Prediction Heads (for Muesli)
+             pred_input_dim = latent_dim + num_actions
+
+             returns_module = nn.Linear(pred_input_dim, 1)
+             reward_module = nn.Linear(pred_input_dim, 1)
+
+             policy.returns_pred = TDM(returns_module, in_keys=["pred_input"], out_keys=["returns_pred"])
+             policy.reward_pred = TDM(reward_module, in_keys=["pred_input"], out_keys=["reward_pred"])
+
+         # Attach methods
+         policy.forward = types.MethodType(forward, policy)
+
+         return policy
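# Illustrative sketch (not part of the diff): enabling the dynamics heads from config.
# `make_env_interface()` is a hypothetical stand-in for however a PolicyEnvInterface
# is constructed in this codebase; only `unroll_steps` comes from the code above.
#
#   cfg = ViTDefaultConfig(unroll_steps=5)
#   policy = cfg.make_policy(make_env_interface())
#   # During training forwards (action is not None), the TensorDict additionally holds
#   # "unrolled_logits" (B*T, K, A), "unrolled_rewards" (B*T, K), "unrolled_returns" (B*T, K).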