Commit fbade4e

Commit message: cp

1 parent c6a7d80

1 file changed: agent/src/metta/agent/policies/vit.py (+22 −23 lines)
@@ -143,7 +143,12 @@ def _prediction_step(self, hidden: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 
 
 class ViTDefaultConfig(PolicyArchitecture):
-    """Speed-optimized ViT variant with lighter token embeddings and attention stack."""
+    """Speed-optimized ViT variant with lighter token embeddings and attention stack.
+
+    The trunk uses Axon blocks (post-up experts with residual connections) for efficient
+    feature processing. Configure trunk depth, layer normalization, and hidden dimension
+    scaling independently.
+    """
 
     class_path: str = "metta.agent.policy_auto_builder.PolicyAutoBuilder"
 
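The "Axon blocks (post-up experts with residual connections)" named in the new docstring are defined elsewhere in the cortex stack, not in this diff. For orientation only, a residual MLP block usually looks like the sketch below; the block structure, widths, and post-norm placement are assumptions, not the actual Axon implementation.

import torch
import torch.nn as nn

class ResidualMLPBlock(nn.Module):
    """Generic residual block; not the Axon block from the cortex stack."""

    def __init__(self, d_hidden: int, post_norm: bool = True) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d_hidden, 4 * d_hidden),
            nn.GELU(),
            nn.Linear(4 * d_hidden, d_hidden),
        )
        # post_norm mirrors the trunk_use_layer_norm idea: normalize after the residual add.
        self.norm = nn.LayerNorm(d_hidden) if post_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection: the block refines x rather than replacing it.
        return self.norm(x + self.mlp(x))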

@@ -156,6 +161,12 @@ class ViTDefaultConfig(PolicyArchitecture):
     pass_state_during_training: bool = False
     _critic_hidden = 512
 
+    # Trunk configuration
+    # Number of Axon layers in the trunk (default: 16 for large model)
+    trunk_num_resnet_layers: int = 1
+    # Enable layer normalization after each trunk layer
+    trunk_use_layer_norm: bool = True
+
     components: List[ComponentConfig] = [
         ObsShimTokensConfig(in_key="env_obs", out_key="obs_shim_tokens", max_tokens=48),
         ObsAttrEmbedFourierConfig(
@@ -181,9 +192,9 @@ class ViTDefaultConfig(PolicyArchitecture):
             key_prefix="vit_cortex_state",
             stack_cfg=build_cortex_auto_config(
                 d_hidden=_latent_dim,
-                num_layers=1,
-                pattern="L",
-                post_norm=False,
+                num_layers=trunk_num_resnet_layers,
+                pattern="A",  # Axon blocks provide residual-like connections
+                post_norm=trunk_use_layer_norm,
             ),
             pass_state_during_training=pass_state_during_training,
         ),
@@ -209,32 +220,20 @@ class ViTDefaultConfig(PolicyArchitecture):
     action_probs_config: ActionProbsConfig = ActionProbsConfig(in_key="logits")
 
     def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
-        # Ensure downstream components match core dimension
-        # (self._latent_dim might have been overridden on the instance without updating all components)
-        cortex = next(c for c in self.components if isinstance(c, CortexTDConfig))
-        core_dim = cortex.out_features or cortex.d_hidden
-
-        actor_mlp = next(c for c in self.components if c.name == "actor_mlp")
-        assert isinstance(actor_mlp, MLPConfig)
-        if actor_mlp.in_features != core_dim:
-            actor_mlp.in_features = core_dim
-
-        critic = next(c for c in self.components if c.name == "critic")
-        assert isinstance(critic, MLPConfig)
-        if critic.in_features != core_dim:
-            critic.in_features = core_dim
+        # Note: trunk configuration (num_layers, layer_norm, scaling) is applied
+        # via the components list definition above, no runtime modification needed
 
         AgentClass = load_symbol(self.class_path)
         policy = AgentClass(policy_env_info, self)
         policy.num_actions = policy_env_info.action_space.n
 
         # Dimensions
-        latent_dim = core_dim
-        num_actions = policy.num_actions
+        latent_dim = int(self._latent_dim)
+        num_actions = int(policy.num_actions)
 
         # Dynamics Model: (Hidden + Action) -> (Hidden + Reward)
-        dyn_input_dim = latent_dim + num_actions
-        dyn_output_dim = latent_dim + 1
+        dyn_input_dim = int(latent_dim + num_actions)
+        dyn_output_dim = int(latent_dim + 1)
 
         # Simple MLP for dynamics
         dynamics_net = nn.Sequential(
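After this change, latent_dim is read directly from self._latent_dim rather than derived from the cortex component's out_features as the removed block did. The dimension bookkeeping for the dynamics model can be checked with a small standalone sketch; the concrete numbers below (latent_dim=512, num_actions=10) and the two-layer MLP are illustrative assumptions, not the actual network that make_policy builds.

import torch
import torch.nn as nn

latent_dim, num_actions = 512, 10          # assumed values for illustration
dyn_input_dim = latent_dim + num_actions   # hidden state concatenated with a one-hot action
dyn_output_dim = latent_dim + 1            # predicted next hidden state plus a scalar reward

dynamics_net = nn.Sequential(
    nn.Linear(dyn_input_dim, dyn_input_dim),
    nn.ReLU(),
    nn.Linear(dyn_input_dim, dyn_output_dim),
)

hidden = torch.randn(4, latent_dim)
actions_one_hot = torch.eye(num_actions)[:4]
out = dynamics_net(torch.cat([hidden, actions_one_hot], dim=-1))
next_hidden, reward = out[:, :latent_dim], out[:, latent_dim:]
print(next_hidden.shape, reward.shape)     # torch.Size([4, 512]) torch.Size([4, 1])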
@@ -246,7 +245,7 @@ def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
 
         # Returns/Reward Prediction Heads (for Muesli)
         # Input: Core + Logits
-        pred_input_dim = latent_dim + num_actions
+        pred_input_dim = int(latent_dim + num_actions)
 
         returns_module = nn.Linear(pred_input_dim, 1)
         reward_module = nn.Linear(pred_input_dim, 1)
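These heads take the concatenation of the core features and the action logits, as the "Input: Core + Logits" comment states. A minimal usage sketch with the same assumed dimensions as above:

import torch
import torch.nn as nn

latent_dim, num_actions = 512, 10            # assumed values for illustration
pred_input_dim = latent_dim + num_actions

returns_module = nn.Linear(pred_input_dim, 1)
reward_module = nn.Linear(pred_input_dim, 1)

core = torch.randn(4, latent_dim)
logits = torch.randn(4, num_actions)
pred_in = torch.cat([core, logits], dim=-1)  # "Core + Logits"
returns_pred = returns_module(pred_in)       # shape (4, 1)
reward_pred = reward_module(pred_in)         # shape (4, 1)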
