@@ -27,7 +27,7 @@ def forward(self, td: TensorDict, action: torch.Tensor = None) -> TensorDict:
         if "values" in td.keys():
             td["values"] = td["values"].flatten()

-        # Dynamics/Muesli predictions - only if modules were created by Dynamics loss
+        # Dynamics/Muesli predictions - only if modules were created
         if hasattr(self, "returns_pred") and self.returns_pred is not None:
             td["pred_input"] = torch.cat([td["core"], td["logits"]], dim=-1)
             self.returns_pred(td)
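A minimal shape sketch of the prediction input built in this hunk, for reference only: the sizes and the bare `nn.Linear` head are assumptions for illustration, while in the diff the head is a module invoked on the TensorDict.

```python
# Illustrative only: assumed sizes; the real returns_pred operates on a TensorDict.
import torch
import torch.nn as nn

latent_dim, num_actions = 512, 16             # assumed dimensions
core = torch.randn(8, latent_dim)             # batch of core outputs (td["core"])
logits = torch.randn(8, num_actions)          # batch of policy logits (td["logits"])

pred_input = torch.cat([core, logits], dim=-1)      # mirrors td["pred_input"]
returns_pred = nn.Linear(latent_dim + num_actions, 1)
print(returns_pred(pred_input).shape)               # torch.Size([8, 1])
```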
@@ -139,7 +139,12 @@ def _prediction_step(self, hidden: torch.Tensor) -> tuple[torch.Tensor, torch.Te


 class ViTDefaultConfig(PolicyArchitecture):
-    """Speed-optimized ViT variant with lighter token embeddings and attention stack."""
+    """Speed-optimized ViT variant with lighter token embeddings and attention stack.
+
+    The trunk uses Axon blocks (post-up experts with residual connections) for efficient
+    feature processing. Configure trunk depth, layer normalization, and hidden dimension
+    scaling independently.
+    """

     class_path: str = "metta.agent.policy_auto_builder.PolicyAutoBuilder"

@@ -155,6 +160,12 @@ class ViTDefaultConfig(PolicyArchitecture):
     # Dynamics/Muesli unroll steps - set > 0 to enable dynamics modules
     unroll_steps: int = 0

+    # Trunk configuration
+    # Number of Axon layers in the trunk (defaults to 1; e.g. 16 for a large model)
+    trunk_num_resnet_layers: int = 1
+    # Enable layer normalization after each trunk layer
+    trunk_use_layer_norm: bool = True
+
     components: List[ComponentConfig] = [
         ObsShimTokensConfig(in_key="env_obs", out_key="obs_shim_tokens", max_tokens=48),
         ObsAttrEmbedFourierConfig(
@@ -180,9 +191,9 @@ class ViTDefaultConfig(PolicyArchitecture):
             key_prefix="vit_cortex_state",
             stack_cfg=build_cortex_auto_config(
                 d_hidden=_latent_dim,
-                num_layers=1,
-                pattern="L",
-                post_norm=False,
+                num_layers=trunk_num_resnet_layers,
+                pattern="A",  # Axon blocks provide residual-like connections
+                post_norm=trunk_use_layer_norm,
             ),
             pass_state_during_training=pass_state_during_training,
         ),
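The Axon block implementation itself is not part of this diff, so the following is a conceptual sketch only, an assumed residual block that illustrates what "residual-like connections" and the `post_norm` flag mean in the configuration above.

```python
# Conceptual sketch; not the real Axon block. Shows a residual feed-forward block
# with an optional post-layer normalization, matching the post_norm flag.
import torch
import torch.nn as nn

class ResidualBlockSketch(nn.Module):
    def __init__(self, d_hidden: int, post_norm: bool = True):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_hidden, 4 * d_hidden),
            nn.GELU(),
            nn.Linear(4 * d_hidden, d_hidden),
        )
        self.norm = nn.LayerNorm(d_hidden) if post_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection, then optional post-layer normalization.
        return self.norm(x + self.ff(x))
```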
@@ -208,20 +219,8 @@ class ViTDefaultConfig(PolicyArchitecture):
     action_probs_config: ActionProbsConfig = ActionProbsConfig(in_key="logits")

     def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
-        # Ensure downstream components match core dimension
-        # (self._latent_dim might have been overridden on the instance without updating all components)
-        cortex = next(c for c in self.components if isinstance(c, CortexTDConfig))
-        core_dim = cortex.out_features or cortex.d_hidden
-
-        actor_mlp = next(c for c in self.components if c.name == "actor_mlp")
-        assert isinstance(actor_mlp, MLPConfig)
-        if actor_mlp.in_features != core_dim:
-            actor_mlp.in_features = core_dim
-
-        critic = next(c for c in self.components if c.name == "critic")
-        assert isinstance(critic, MLPConfig)
-        if critic.in_features != core_dim:
-            critic.in_features = core_dim
+        # Note: trunk configuration (num_layers, layer_norm, scaling) is applied
+        # via the components list definition above; no runtime modification is needed.

         AgentClass = load_symbol(self.class_path)
         policy = AgentClass(policy_env_info, self)
@@ -230,12 +229,12 @@ def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:

         # Only create dynamics modules if unroll_steps > 0
         if self.unroll_steps > 0:
-            latent_dim = core_dim
-            num_actions = policy.num_actions
+            latent_dim = int(self._latent_dim)
+            num_actions = int(policy.num_actions)

             # Dynamics Model: (Hidden + Action) -> (Hidden + Reward)
-            dyn_input_dim = latent_dim + num_actions
-            dyn_output_dim = latent_dim + 1
+            dyn_input_dim = int(latent_dim + num_actions)
+            dyn_output_dim = int(latent_dim + 1)

             dynamics_net = nn.Sequential(
                 nn.Linear(dyn_input_dim, 256),
@@ -245,7 +244,8 @@ def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
             policy.dynamics_model = dynamics_net

             # Returns/Reward Prediction Heads (for Muesli)
-            pred_input_dim = latent_dim + num_actions
+            # Input: Core + Logits
+            pred_input_dim = int(latent_dim + num_actions)

             returns_module = nn.Linear(pred_input_dim, 1)
             reward_module = nn.Linear(pred_input_dim, 1)
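A quick shape-check sketch of the dimensions introduced in these hunks: the hidden layer size of 256 comes from the diff, but the activation and output layer of the Sequential are assumed here, since the hunk truncates the full definition, and the latent/action sizes are placeholders.

```python
# Sketch under assumptions: 512/16 are illustrative sizes; ReLU and the final Linear
# of dynamics_net are assumed, as the diff does not show the rest of the Sequential.
import torch
import torch.nn as nn

latent_dim, num_actions = 512, 16
dyn_input_dim = latent_dim + num_actions          # (Hidden + Action)
dyn_output_dim = latent_dim + 1                   # (Hidden + Reward)

dynamics_net = nn.Sequential(
    nn.Linear(dyn_input_dim, 256),
    nn.ReLU(),
    nn.Linear(256, dyn_output_dim),
)

hidden = torch.randn(4, latent_dim)
action = torch.nn.functional.one_hot(torch.randint(num_actions, (4,)), num_actions).float()
out = dynamics_net(torch.cat([hidden, action], dim=-1))
next_hidden, reward = out[..., :latent_dim], out[..., latent_dim:]

# The Muesli heads consume Core + Logits, i.e. pred_input_dim features each.
pred_input_dim = latent_dim + num_actions
returns_module = nn.Linear(pred_input_dim, 1)
reward_module = nn.Linear(pred_input_dim, 1)
```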