
Commit 0a8c698

init
no more .mpt
Merge remote-tracking branch 'origin/main' into richard-unifympt
slim policy spec handler
more concise
cleanup
simplify
Merge remote-tracking branch 'origin/main' into richard-unifympt
re-add
fix policy spex
Update packages/mettagrid/python/src/mettagrid/util/uri_resolvers/schemes.py
Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
Merge remote-tracking branch 'origin/main' into richard-unifympt
bundles
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
bundle
Merge remote-tracking branch 'origin/richard-unifympt' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
simplify?
ugh
compat cleanup
Merge remote-tracking branch 'origin/main' into richard-unifympt
cleanup tests
Merge remote-tracking branch 'origin/main' into richard-unifympt
more tests
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
simplify?
Merge branch 'main' into richard-unifympt
cleanup
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
no more .mpt
remove all .mpt and lint
cleanup local data path fixes
mpt re-add
re-add artifact
lint
Merge remote-tracking branch 'origin/main' into richard-unifympt
more cleanup
Merge remote-tracking branch 'origin/main' into richard-unifympt
diff cleanup
ftt
lint
fix error
Merge remote-tracking branch 'origin/main' into richard-unifympt
more tests
lint
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
checkpoint policy does save/load
lint
checkpoint
moving catcus
lint
Merge branch 'main' into richard-unifympt
fold-in

[pyright 4] Get pyright to pass on app_backend (#4478)

Merge remote-tracking branch 'origin/main' into richard-unifympt

Fix command, add space (#4456)
Added a space to --app:lib--tlsEmulation:off, which makes it --app:lib --tlsEmulation:off; now it runs.

Rename HyperUpdateRule to ScheduleRule (#4483)
## Summary
- rename HyperUpdateRule to ScheduleRule and apply to TrainerConfig via target_path
- update recipes and teacher scheduling to use ScheduleRule
- report PPO stats using ppo_actor/ppo_critic hyperparam keys and update tests
## Testing
- not run (not requested)
---------
Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>

Merge remote-tracking branch 'origin/main' into richard-unifympt

Fix supervisor teacher behavior and legacy BC mode (#4484)
## Summary
- gate PPO actor during supervisor teacher phase
- fix supervisor/no-teacher behavior and add legacy BC (no gating, no PPO resume)
- require supervisor policy URI for sliced_cloner_no_ppo
## Testing
- not run (not requested)
---------
Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
Co-authored-by: Adam S <134907338+gustofied@users.noreply.github.com>

Minor fixes to the slstm triton kernel, causing failures for certain kernel sizes (#4492)

cleanup
Merge remote-tracking branch 'origin/main' into richard-unifympt

fold in training environments and eval environments mismatched (#4487)
I ran a direct config comparison using the training entrypoint (recipes/experiment/cogs_v_clips.train) with variants=["heart_chorus"] and compared the eval suite config it builds (difficulty standard + heart_chorus) for an overlapping mission: hello_world.oxygen_bottleneck. Findings:
- Compass is ON in both training and eval (global_obs.compass=True).
- Vibe count and change-vibe settings match (152 vibes; change_vibe.number_of_vibes=152).
- But the mission parameters differ between training and eval for the same mission name:
  - game.objects.carbon_extractor.max_uses: train 25 vs eval 100
  - game.objects.oxygen_extractor.max_uses: train 5 vs eval 20
  - game.objects.germanium_extractor.max_uses: train 5 vs eval 20
  - game.objects.silicon_extractor.max_uses: train 120 vs eval 480

So the mismatch isn't compass; it's the mission definitions used by training vs eval. Training uses base missions (cogames.cogs_vs_clips.missions), while eval uses integrated eval missions (cogames.cogs_vs_clips.evals.integrated_evals) that have different extractor settings. Also: the eval suite used by recipes/experiment/cogs_v_clips.train does not include machina_1.open_world at all (it only evaluates the 7 HELLO_WORLD integrated evals). So training can be creating hearts on easier missions while your eval runs on machina_1 are a different environment entirely. (A comparison sketch of these extractor differences follows after this commit message.)

Relevant files:
- Compass default: packages/cogames/src/cogames/cogs_vs_clips/mission.py
- Training entrypoint + eval suite wiring: recipes/experiment/cogs_v_clips.py
- Eval mission definitions: packages/cogames/src/cogames/cogs_vs_clips/evals/integrated_evals.py

If you want true parity, we should align which mission templates eval uses (and/or include machina_1.open_world in the eval suite). I can patch this if you want; tell me whether you prefer:
1. Eval suite uses the same mission templates as training (from missions.py), or
2. Training uses the integrated eval mission definitions, or
3. Add machina_1.open_world to the eval suite.

ripping out
Merge remote-tracking branch 'origin/main' into richard-unifympt
simplify
fix and lint
choke
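To make the #4487 mission-parameter mismatch above concrete, here is a small, hypothetical comparison sketch. The nested dict layout and the diff_configs helper are illustrative assumptions, not the actual cogames config objects; only the max_uses numbers come from the report above.

```python
# Hypothetical sketch: report leaf values that differ between a training config and an
# eval config. The dict structure below is illustrative, not the real cogames config.
from typing import Any


def diff_configs(train: dict, evaluation: dict, prefix: str = "") -> list[tuple[str, Any, Any]]:
    """Return (path, train_value, eval_value) for every leaf that differs."""
    diffs: list[tuple[str, Any, Any]] = []
    for key in sorted(set(train) | set(evaluation)):
        path = f"{prefix}.{key}" if prefix else key
        a, b = train.get(key), evaluation.get(key)
        if isinstance(a, dict) and isinstance(b, dict):
            diffs.extend(diff_configs(a, b, path))
        elif a != b:
            diffs.append((path, a, b))
    return diffs


# Extractor max_uses values reported in #4487 (train vs eval):
train_cfg = {"game": {"objects": {
    "carbon_extractor": {"max_uses": 25},
    "oxygen_extractor": {"max_uses": 5},
    "germanium_extractor": {"max_uses": 5},
    "silicon_extractor": {"max_uses": 120},
}}}
eval_cfg = {"game": {"objects": {
    "carbon_extractor": {"max_uses": 100},
    "oxygen_extractor": {"max_uses": 20},
    "germanium_extractor": {"max_uses": 20},
    "silicon_extractor": {"max_uses": 480},
}}}

for path, train_value, eval_value in diff_configs(train_cfg, eval_cfg):
    print(f"{path}: train={train_value} eval={eval_value}")
```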
1 parent 81723ce commit 0a8c698

File tree

3 files changed: +215 / -1 lines changed

agent/src/metta/agent/policies/vit.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ class ViTDefaultConfig(PolicyArchitecture):
 
     def make_policy(self, policy_env_info: PolicyEnvInterface) -> Policy:
         # If the architecture spec already bundled a component list (common for saved
-        # .mpt checkpoints), reuse it instead of regenerating with current defaults.
+        # checkpoint bundles), reuse it instead of regenerating with current defaults.
         # This keeps restored policies aligned with the shapes they were trained with.
         if self.components:
             return super().make_policy(policy_env_info)
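The comment change above points at the reuse-vs-regenerate behaviour of make_policy: a spec restored from a checkpoint bundle keeps its saved component list, while a fresh spec rebuilds current defaults. Below is a minimal, hypothetical illustration of that pattern; the class and component names (ArchSpec, ViTLikeSpec, obs_encoder, ...) are stand-ins, not the actual PolicyArchitecture/ViTDefaultConfig code.

```python
# Hypothetical sketch of the reuse-vs-regenerate pattern described in the comment above.
from dataclasses import dataclass, field


@dataclass
class ArchSpec:
    components: list[str] = field(default_factory=list)

    def make_policy(self) -> list[str]:
        # Base behaviour: build from whatever component list the spec currently carries.
        return list(self.components)


@dataclass
class ViTLikeSpec(ArchSpec):
    def make_policy(self) -> list[str]:
        if self.components:
            # A restored checkpoint bundle already fixed the component list; reuse it
            # so layer shapes stay aligned with the weights about to be loaded.
            return super().make_policy()
        # Fresh spec: regenerate the current default component list.
        self.components = ["obs_encoder", "vit_backbone", "actor_head", "critic_head"]
        return super().make_policy()
```

With that shape, ViTLikeSpec(components=[...]) restored from a bundle returns exactly the saved list, while ViTLikeSpec() falls back to the current defaults.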

metta/rl/mpt_artifact.py

Lines changed: 139 additions & 0 deletions
from __future__ import annotations

import tempfile
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Mapping, MutableMapping, Protocol

import torch
from safetensors.torch import load as load_safetensors
from safetensors.torch import save as save_safetensors

from mettagrid.policy.checkpoint_policy import architecture_from_spec, prepare_state_dict_for_save
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
from mettagrid.util.file import local_copy, write_file
from mettagrid.util.uri_resolvers.schemes import parse_uri


class PolicyArchitectureProtocol(Protocol):
    def make_policy(self, policy_env_info: PolicyEnvInterface) -> Any: ...

    def to_spec(self) -> str:
        """Serialize this architecture to a string specification."""
        ...

    @classmethod
    def from_spec(cls, spec: str) -> "PolicyArchitectureProtocol":
        """Deserialize an architecture from a string specification."""
        ...


@dataclass
class MptArtifact:
    architecture: Any
    state_dict: MutableMapping[str, torch.Tensor]

    def instantiate(
        self,
        policy_env_info: PolicyEnvInterface,
        device: str = "cpu",
        *,
        strict: bool = True,
    ) -> Any:
        torch_device = torch.device(device)

        policy = self.architecture.make_policy(policy_env_info)
        policy = policy.to(torch_device)

        missing, unexpected = policy.load_state_dict(dict(self.state_dict), strict=strict)
        if strict and (missing or unexpected):
            raise RuntimeError(f"Strict loading failed. Missing: {missing}, Unexpected: {unexpected}")

        if hasattr(policy, "initialize_to_environment"):
            policy.initialize_to_environment(policy_env_info, torch_device)

        return policy


def load_mpt(uri: str) -> MptArtifact:
    """Load an .mpt checkpoint from a local path or s3:// URI."""
    with local_copy(uri) as local_path:
        return _load_local_mpt_file(local_path)


def _load_local_mpt_file(path: Path) -> MptArtifact:
    if not path.exists():
        raise FileNotFoundError(f"MPT file not found: {path}")

    with zipfile.ZipFile(path, mode="r") as archive:
        names = set(archive.namelist())

        if "weights.safetensors" not in names:
            raise ValueError(f"Invalid .mpt file: {path} (missing weights)")

        if "modelarchitecture.txt" in names:
            architecture_blob = archive.read("modelarchitecture.txt").decode("utf-8")
        else:
            raise ValueError(f"Invalid .mpt file: {path} (missing architecture)")
        architecture = architecture_from_spec(architecture_blob)

        weights_blob = archive.read("weights.safetensors")
        state_dict = load_safetensors(weights_blob)
        if not isinstance(state_dict, MutableMapping):
            raise TypeError("Loaded safetensors state_dict is not a mutable mapping")

    return MptArtifact(architecture=architecture, state_dict=state_dict)


def save_mpt(
    uri: str | Path,
    *,
    architecture: Any,
    state_dict: Mapping[str, torch.Tensor],
) -> str:
    """Save an .mpt checkpoint to a URI or local path. Returns the saved URI."""
    parsed = parse_uri(str(uri), allow_none=False)

    if parsed.scheme == "s3":
        with tempfile.NamedTemporaryFile(suffix=".mpt", delete=False) as tmp:
            tmp_path = Path(tmp.name)
        try:
            _save_mpt_file_locally(tmp_path, architecture=architecture, state_dict=state_dict)
            write_file(parsed.canonical, str(tmp_path))
        finally:
            tmp_path.unlink(missing_ok=True)
        return parsed.canonical
    else:
        output_path = parsed.local_path or Path(str(uri)).expanduser().resolve()
        _save_mpt_file_locally(output_path, architecture=architecture, state_dict=state_dict)
        return f"file://{output_path.resolve()}"


def _save_mpt_file_locally(
    path: Path,
    *,
    architecture: Any,
    state_dict: Mapping[str, torch.Tensor],
) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    prepared_state = prepare_state_dict_for_save(state_dict)

    with tempfile.NamedTemporaryFile(
        dir=path.parent,
        prefix=f".{path.name}.",
        suffix=".tmp",
        delete=False,
    ) as temp_file:
        temp_path = Path(temp_file.name)

    try:
        with zipfile.ZipFile(temp_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
            weights_blob = save_safetensors(dict(prepared_state))
            archive.writestr("weights.safetensors", weights_blob)
            archive.writestr("modelarchitecture.txt", architecture.to_spec())

        temp_path.replace(path)
    except Exception:
        temp_path.unlink(missing_ok=True)
        raise
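As a usage sketch for the module above: save_mpt and load_mpt round-trip a zip bundle containing weights.safetensors and modelarchitecture.txt. The arch and env_info objects below are assumed to come from the surrounding training code (an architecture with to_spec()/make_policy() that architecture_from_spec can rebuild, plus a PolicyEnvInterface); they and the paths are placeholders, not part of this diff.

```python
# Sketch only: `arch` and `env_info` are assumed objects from the training stack,
# and the paths below are illustrative.
import zipfile

from metta.rl.mpt_artifact import load_mpt, save_mpt

policy = arch.make_policy(env_info)

# save_mpt writes weights.safetensors + modelarchitecture.txt into a zip and returns
# the canonical URI (file:// for local paths, s3:// when given an S3 destination).
uri = save_mpt("/tmp/run0/policy.mpt", architecture=arch, state_dict=policy.state_dict())

# load_mpt accepts a local path or s3:// URI and returns an MptArtifact; instantiate()
# rebuilds the policy from the stored spec and loads the weights (strict by default).
artifact = load_mpt("/tmp/run0/policy.mpt")
restored = artifact.instantiate(env_info, device="cpu", strict=True)

# Because the bundle is a plain zip, it can be inspected without the training code.
with zipfile.ZipFile("/tmp/run0/policy.mpt") as bundle:
    print(bundle.namelist())  # ['weights.safetensors', 'modelarchitecture.txt']
```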

metta/rl/mpt_policy.py

Lines changed: 75 additions & 0 deletions
from __future__ import annotations

from pathlib import Path
from typing import Any

from metta.rl.mpt_artifact import load_mpt, save_mpt
from mettagrid.policy.policy import AgentPolicy, MultiAgentPolicy
from mettagrid.policy.policy_env_interface import PolicyEnvInterface
from mettagrid.util.uri_resolvers.schemes import parse_uri


class MptPolicy(MultiAgentPolicy):
    """Load a policy from an .mpt checkpoint file.

    The .mpt format stores weights and architecture configuration. This allows
    loading trained policies without a build dependency on the training code.
    """

    short_names = ["mpt"]

    def __init__(
        self,
        policy_env_info: PolicyEnvInterface,
        *,
        checkpoint_uri: str | None = None,
        device: str = "cpu",
        strict: bool = True,
    ):
        super().__init__(policy_env_info, device=device)

        self._policy = None
        self._architecture = None
        self._strict = strict
        self._device = device

        if checkpoint_uri:
            self._load_from_checkpoint(checkpoint_uri, device=device)

    def _load_from_checkpoint(self, checkpoint_uri: str, *, device: str) -> None:
        artifact = load_mpt(checkpoint_uri)
        self._architecture = artifact.architecture
        self._policy = artifact.instantiate(self._policy_env_info, device=device, strict=self._strict)
        self._policy.eval()

    def load_policy_data(self, policy_data_path: str) -> None:
        self._load_from_checkpoint(policy_data_path, device=self._device)

    def agent_policy(self, agent_id: int) -> AgentPolicy:
        if self._policy is None:
            raise RuntimeError("MptPolicy has not been initialized with checkpoint data")
        return self._policy.agent_policy(agent_id)

    def eval(self) -> "MptPolicy":
        """Ensure wrapped policy enters eval mode for rollout/play compatibility."""
        if self._policy is not None:
            self._policy.eval()
        return self

    def save_policy(
        self,
        destination: str | Path,
        *,
        policy_architecture: Any | None = None,
    ) -> str:
        """Save the wrapped policy to a URI or local path."""
        architecture = policy_architecture or self._architecture
        if architecture is None:
            raise ValueError("policy_architecture is required to save policy")
        if self._policy is None:
            raise ValueError("Policy has not been loaded; cannot save")

        save_mpt(str(destination), architecture=architecture, state_dict=self._policy.state_dict())

        parsed = parse_uri(str(destination), allow_none=False)
        return parsed.canonical
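And a corresponding usage sketch for MptPolicy: env_info is an assumed PolicyEnvInterface from the surrounding code, and the local and s3:// URIs below are placeholders that rely on the usual mettagrid URI resolvers being available.

```python
# Sketch only: `env_info` is an assumed PolicyEnvInterface; the URIs are placeholders.
from metta.rl.mpt_policy import MptPolicy

# Eager load at construction time (local path or s3:// URI)...
policy = MptPolicy(env_info, checkpoint_uri="s3://my-bucket/run0/policy.mpt", device="cpu")

# ...or construct empty and attach checkpoint data later.
lazy = MptPolicy(env_info, device="cpu")
lazy.load_policy_data("/tmp/run0/policy.mpt")

# Per-agent access for rollouts; raises RuntimeError if no checkpoint was loaded.
agent0 = policy.agent_policy(0)

# Re-export: reuses the architecture recovered from the loaded bundle unless an
# explicit policy_architecture is passed; returns the canonical destination URI.
saved_uri = policy.save_policy("/tmp/run0/policy-copy.mpt")
```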
