@@ -143,7 +143,12 @@ def _prediction_step(self, hidden: torch.Tensor) -> tuple[torch.Tensor, torch.Te
 
 
 class ViTDefaultConfig(PolicyArchitecture):
-    """Speed-optimized ViT variant with lighter token embeddings and attention stack."""
+    """Speed-optimized ViT variant with lighter token embeddings and attention stack.
+
+    The trunk uses Axon blocks (post-up experts with residual connections) for efficient
+    feature processing. Configure trunk depth, layer normalization, and hidden dimension
+    scaling independently.
+    """
 
     class_path: str = "metta.agent.policy_auto_builder.PolicyAutoBuilder"
 
@@ -156,6 +161,12 @@ class ViTDefaultConfig(PolicyArchitecture):
     pass_state_during_training: bool = False
     _critic_hidden = 512
 
+    # Trunk configuration
+    # Number of Axon layers in the trunk (1 by default; larger variants may use more, e.g. 16)
+    trunk_num_resnet_layers: int = 1
+    # Enable layer normalization after each trunk layer
+    trunk_use_layer_norm: bool = True
+
     components: List[ComponentConfig] = [
         ObsShimTokensConfig(in_key="env_obs", out_key="obs_shim_tokens", max_tokens=48),
         ObsAttrEmbedFourierConfig(
@@ -181,9 +192,9 @@ class ViTDefaultConfig(PolicyArchitecture):
             key_prefix="vit_cortex_state",
             stack_cfg=build_cortex_auto_config(
                 d_hidden=_latent_dim,
-                num_layers=1,
-                pattern="L",
-                post_norm=False,
+                num_layers=trunk_num_resnet_layers,
+                pattern="A",  # Axon blocks provide residual-like connections
+                post_norm=trunk_use_layer_norm,
             ),
             pass_state_during_training=pass_state_during_training,
         ),
@@ -209,32 +220,20 @@ class ViTDefaultConfig(PolicyArchitecture):
     action_probs_config: ActionProbsConfig = ActionProbsConfig(in_key="logits")
 
     def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
-        # Ensure downstream components match core dimension
-        # (self._latent_dim might have been overridden on the instance without updating all components)
-        cortex = next(c for c in self.components if isinstance(c, CortexTDConfig))
-        core_dim = cortex.out_features or cortex.d_hidden
-
-        actor_mlp = next(c for c in self.components if c.name == "actor_mlp")
-        assert isinstance(actor_mlp, MLPConfig)
-        if actor_mlp.in_features != core_dim:
-            actor_mlp.in_features = core_dim
-
-        critic = next(c for c in self.components if c.name == "critic")
-        assert isinstance(critic, MLPConfig)
-        if critic.in_features != core_dim:
-            critic.in_features = core_dim
+        # Note: trunk configuration (num_layers, layer_norm, scaling) is applied
+        # via the components list above; no runtime modification is needed.
 
         AgentClass = load_symbol(self.class_path)
         policy = AgentClass(policy_env_info, self)
         policy.num_actions = policy_env_info.action_space.n
 
         # Dimensions
-        latent_dim = core_dim
-        num_actions = policy.num_actions
+        latent_dim = int(self._latent_dim)
+        num_actions = int(policy.num_actions)
 
         # Dynamics Model: (Hidden + Action) -> (Hidden + Reward)
-        dyn_input_dim = latent_dim + num_actions
-        dyn_output_dim = latent_dim + 1
+        dyn_input_dim = int(latent_dim + num_actions)
+        dyn_output_dim = int(latent_dim + 1)
 
         # Simple MLP for dynamics
         dynamics_net = nn.Sequential(
@@ -246,7 +245,7 @@ def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
 
         # Returns/Reward Prediction Heads (for Muesli)
         # Input: Core + Logits
-        pred_input_dim = latent_dim + num_actions
+        pred_input_dim = int(latent_dim + num_actions)
 
         returns_module = nn.Linear(pred_input_dim, 1)
         reward_module = nn.Linear(pred_input_dim, 1)
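
For reference, a minimal standalone sketch (plain PyTorch, independent of the metta classes in this diff) of how the dynamics-model dimensions line up: the input concatenates the hidden state with an action, and the output splits back into the next hidden state plus a scalar reward. The batch size, the 256-wide hidden layer, and the one-hot action encoding below are illustrative assumptions, not details taken from the actual dynamics_net.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes; in the config these come from _latent_dim and action_space.n.
latent_dim, num_actions = 128, 9

dyn_input_dim = latent_dim + num_actions   # hidden state + one-hot action
dyn_output_dim = latent_dim + 1            # next hidden state + scalar reward

# Simple two-layer MLP stand-in for the dynamics network.
dynamics_net = nn.Sequential(
    nn.Linear(dyn_input_dim, 256),
    nn.ReLU(),
    nn.Linear(256, dyn_output_dim),
)

hidden = torch.randn(4, latent_dim)                                   # batch of hidden states
action = F.one_hot(torch.randint(num_actions, (4,)), num_actions).float()

out = dynamics_net(torch.cat([hidden, action], dim=-1))
next_hidden, reward = out[:, :latent_dim], out[:, latent_dim:]        # (4, 128) and (4, 1)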