Commit f967d87

Commit message: cp

1 parent 5213d93 commit f967d87

8 files changed (+32 lines added, -34 lines removed)


.bazelversion

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+latest

agent/src/metta/agent/policies/vit.py

Lines changed: 24 additions & 24 deletions
@@ -27,7 +27,7 @@ def forward(self, td: TensorDict, action: torch.Tensor = None) -> TensorDict:
         if "values" in td.keys():
             td["values"] = td["values"].flatten()

-        # Dynamics/Muesli predictions - only if modules were created by Dynamics loss
+        # Dynamics/Muesli predictions - only if modules were created
         if hasattr(self, "returns_pred") and self.returns_pred is not None:
             td["pred_input"] = torch.cat([td["core"], td["logits"]], dim=-1)
             self.returns_pred(td)

@@ -139,7 +139,12 @@ def _prediction_step(self, hidden: torch.Tensor) -> tuple[torch.Tensor, torch.Te


 class ViTDefaultConfig(PolicyArchitecture):
-    """Speed-optimized ViT variant with lighter token embeddings and attention stack."""
+    """Speed-optimized ViT variant with lighter token embeddings and attention stack.
+
+    The trunk uses Axon blocks (post-up experts with residual connections) for efficient
+    feature processing. Configure trunk depth, layer normalization, and hidden dimension
+    scaling independently.
+    """

     class_path: str = "metta.agent.policy_auto_builder.PolicyAutoBuilder"

@@ -155,6 +160,12 @@ class ViTDefaultConfig(PolicyArchitecture):
     # Dynamics/Muesli unroll steps - set > 0 to enable dynamics modules
     unroll_steps: int = 0

+    # Trunk configuration
+    # Number of Axon layers in the trunk (default: 16 for large model)
+    trunk_num_resnet_layers: int = 1
+    # Enable layer normalization after each trunk layer
+    trunk_use_layer_norm: bool = True
+
     components: List[ComponentConfig] = [
         ObsShimTokensConfig(in_key="env_obs", out_key="obs_shim_tokens", max_tokens=48),
         ObsAttrEmbedFourierConfig(

@@ -180,9 +191,9 @@ class ViTDefaultConfig(PolicyArchitecture):
             key_prefix="vit_cortex_state",
             stack_cfg=build_cortex_auto_config(
                 d_hidden=_latent_dim,
-                num_layers=1,
-                pattern="L",
-                post_norm=False,
+                num_layers=trunk_num_resnet_layers,
+                pattern="A",  # Axon blocks provide residual-like connections
+                post_norm=trunk_use_layer_norm,
             ),
             pass_state_during_training=pass_state_during_training,
         ),

@@ -208,20 +219,8 @@ class ViTDefaultConfig(PolicyArchitecture):
     action_probs_config: ActionProbsConfig = ActionProbsConfig(in_key="logits")

     def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
-        # Ensure downstream components match core dimension
-        # (self._latent_dim might have been overridden on the instance without updating all components)
-        cortex = next(c for c in self.components if isinstance(c, CortexTDConfig))
-        core_dim = cortex.out_features or cortex.d_hidden
-
-        actor_mlp = next(c for c in self.components if c.name == "actor_mlp")
-        assert isinstance(actor_mlp, MLPConfig)
-        if actor_mlp.in_features != core_dim:
-            actor_mlp.in_features = core_dim
-
-        critic = next(c for c in self.components if c.name == "critic")
-        assert isinstance(critic, MLPConfig)
-        if critic.in_features != core_dim:
-            critic.in_features = core_dim
+        # Note: trunk configuration (num_layers, layer_norm, scaling) is applied
+        # via the components list definition above, no runtime modification needed

         AgentClass = load_symbol(self.class_path)
         policy = AgentClass(policy_env_info, self)

@@ -230,12 +229,12 @@ def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:

         # Only create dynamics modules if unroll_steps > 0
         if self.unroll_steps > 0:
-            latent_dim = core_dim
-            num_actions = policy.num_actions
+            latent_dim = int(self._latent_dim)
+            num_actions = int(policy.num_actions)

             # Dynamics Model: (Hidden + Action) -> (Hidden + Reward)
-            dyn_input_dim = latent_dim + num_actions
-            dyn_output_dim = latent_dim + 1
+            dyn_input_dim = int(latent_dim + num_actions)
+            dyn_output_dim = int(latent_dim + 1)

             dynamics_net = nn.Sequential(
                 nn.Linear(dyn_input_dim, 256),

@@ -245,7 +244,8 @@ def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
             policy.dynamics_model = dynamics_net

             # Returns/Reward Prediction Heads (for Muesli)
-            pred_input_dim = latent_dim + num_actions
+            # Input: Core + Logits
+            pred_input_dim = int(latent_dim + num_actions)

             returns_module = nn.Linear(pred_input_dim, 1)
             reward_module = nn.Linear(pred_input_dim, 1)
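
For orientation, the shape arithmetic behind the dynamics head above can be exercised in isolation. The sketch below is illustrative only: the hidden ReLU layer and output projection after the first nn.Linear, and the example values of latent_dim and num_actions, are assumptions, since the commit elides the middle of the Sequential.

import torch
import torch.nn as nn

latent_dim, num_actions = 128, 9           # illustrative sizes, not from the repo
dyn_input_dim = latent_dim + num_actions   # hidden state concatenated with action logits
dyn_output_dim = latent_dim + 1            # next hidden state plus a scalar reward

dynamics_net = nn.Sequential(
    nn.Linear(dyn_input_dim, 256),
    nn.ReLU(),                             # assumed middle of the elided Sequential
    nn.Linear(256, dyn_output_dim),
)

x = torch.cat([torch.randn(4, latent_dim), torch.randn(4, num_actions)], dim=-1)
out = dynamics_net(x)
next_hidden, reward = out[:, :latent_dim], out[:, latent_dim:]
print(next_hidden.shape, reward.shape)     # torch.Size([4, 128]) torch.Size([4, 1])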

agent/src/metta/agent/policy_auto_builder.py

Lines changed: 1 addition & 2 deletions
@@ -63,8 +63,7 @@ def initialize_to_environment(
         self.to(device)
         if device.type == "cuda":
             self._configure_sdp()
-            torch.backends.cuda.matmul.fp32_precision = "tf32"  # type: ignore[attr-defined]
-            torch.backends.cudnn.conv.fp32_precision = "tf32"  # type: ignore[attr-defined]
+            torch.set_float32_matmul_precision("high")
         logs = []
         for _, value in self.components.items():
             if hasattr(value, "initialize_to_environment"):
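
For reference, torch.set_float32_matmul_precision is the standard PyTorch call this commit standardizes on across files. A minimal sketch of its semantics (real API; the printed values are just a demonstration):

import torch

# "highest" keeps full FP32 matmul precision; "high" permits TF32 (or bf16x3)
# kernels on supported GPUs; "medium" additionally permits bf16 accumulation.
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # -> "high"

# cuDNN convolutions keep a separate flag, which defaults to enabled,
# presumably why the explicit conv setting could be dropped here.
print(torch.backends.cudnn.allow_tf32)       # -> True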

metta/rl/training/distributed_helper.py

Lines changed: 1 addition & 2 deletions
@@ -54,8 +54,7 @@ def _setup_torch_optimizations(self) -> None:
         """Configure PyTorch for optimal performance."""
         # Keep TF32 fast paths enabled on compatible GPUs (using new API)
         if torch.cuda.is_available() and hasattr(torch.backends, "cuda"):
-            torch.backends.cuda.matmul.fp32_precision = "tf32"  # type: ignore[attr-defined]
-            torch.backends.cudnn.conv.fp32_precision = "tf32"  # type: ignore[attr-defined]
+            torch.set_float32_matmul_precision("high")
             # Enable SDPA optimizations for better attention performance
             torch.backends.cuda.enable_flash_sdp(True)
             torch.backends.cuda.enable_mem_efficient_sdp(True)
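
The SDPA toggles retained as context above are standard torch.backends.cuda calls. A standalone sketch follows; the math-path toggle and the sample attention call are illustrative additions, not part of this commit.

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    torch.backends.cuda.enable_flash_sdp(True)          # FlashAttention kernels
    torch.backends.cuda.enable_mem_efficient_sdp(True)  # memory-efficient kernels
    torch.backends.cuda.enable_math_sdp(True)           # plain-math fallback
    print(torch.backends.cuda.flash_sdp_enabled())      # -> True

# scaled_dot_product_attention dispatches to the fastest enabled backend.
device = "cuda" if torch.cuda.is_available() else "cpu"
q = k = v = torch.randn(1, 8, 64, 32, device=device)
out = F.scaled_dot_product_attention(q, k, v)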

metta/tools/train.py

Lines changed: 1 addition & 2 deletions
@@ -318,8 +318,7 @@ def _configure_torch_backends(self) -> None:
             return

         try:
-            torch.backends.cuda.matmul.fp32_precision = "tf32"  # type: ignore[attr-defined]
-            torch.backends.cudnn.conv.fp32_precision = "tf32"  # type: ignore[attr-defined]
+            torch.set_float32_matmul_precision("high")
         except Exception as exc:  # pragma: no cover - diagnostic only
             logger.debug("Skipping CUDA matmul backend configuration: %s", exc)

packages/cortex/evaluations/run.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ def _enable_determinism() -> None:
     """Force deterministic behavior where possible (CUDA/cuBLAS/torch)."""
     os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
     torch.use_deterministic_algorithms(True)
-    torch.backends.cuda.matmul.allow_tf32 = False  # type: ignore[attr-defined]
+    torch.backends.cuda.matmul.fp32_precision = "highest"  # type: ignore[attr-defined]
     torch.backends.cudnn.deterministic = True  # type: ignore[attr-defined]
     torch.backends.cudnn.benchmark = False  # type: ignore[attr-defined]
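
Note the determinism path goes the other way: rather than dropping the per-backend knob, it repoints it at full FP32. A minimal standalone equivalent using only the portable high-level call (an assumption; the commit itself uses the newer fp32_precision attribute):

import os
import torch

os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # must precede first cuBLAS call
torch.use_deterministic_algorithms(True)
torch.set_float32_matmul_precision("highest")  # disable TF32 matmul fast paths
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False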

packages/cortex/src/cortex/utils.py

Lines changed: 2 additions & 1 deletion
@@ -91,7 +91,8 @@ def configure_tf32_precision() -> None:
     if not torch.cuda.is_available():
         return

-    torch.backends.cudnn.conv.fp32_precision = "tf32"  # type: ignore[attr-defined]
+    # Use the official high-level API to avoid conflicts with torch.compile
+    torch.set_float32_matmul_precision("high")


 __all__ = ["TRITON_AVAILABLE", "select_backend", "configure_tf32_precision"]
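
The new comment's rationale suggests setting precision once, before any torch.compile call, so compiled kernels see a consistent setting. A sketch of that ordering (the model and compile call are illustrative, not from this repo):

import torch

torch.set_float32_matmul_precision("high")  # configure before compiling

model = torch.nn.Linear(64, 64)
compiled = torch.compile(model)             # compiled kernels honor the precision setting
y = compiled(torch.randn(2, 64))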

packages/pufferlib-core/src/pufferlib/pufferl.py

Lines changed: 1 addition & 2 deletions
@@ -77,8 +77,7 @@ class PuffeRL:
     def __init__(self, config, vecenv, policy, logger=None):
         # Backend perf optimization (using new API)
         if torch.cuda.is_available():
-            torch.backends.cuda.matmul.fp32_precision = "tf32"  # type: ignore[attr-defined]
-            torch.backends.cudnn.conv.fp32_precision = "tf32"  # type: ignore[attr-defined]
+            torch.set_float32_matmul_precision("high")
         torch.backends.cudnn.deterministic = config["torch_deterministic"]
         torch.backends.cudnn.benchmark = True
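
One caveat on the trailing context lines: cudnn.benchmark autotuning and determinism pull in opposite directions. A short illustration of the trade-off (illustrative flags, not a change made by this commit):

import torch

# benchmark=True autotunes conv algorithms per input shape: faster for fixed
# shapes, but algorithm selection can vary from run to run.
torch.backends.cudnn.benchmark = True

# For strict reproducibility one would instead set:
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True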
