Commit 7c95ff5

init
no more .mpt
Merge remote-tracking branch 'origin/main' into richard-unifympt
slim policy spec handler
more concise
cleanup
simplify
Merge remote-tracking branch 'origin/main' into richard-unifympt
re-add
fix policy spec
Update packages/mettagrid/python/src/mettagrid/util/uri_resolvers/schemes.py
Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
Merge remote-tracking branch 'origin/main' into richard-unifympt
bundles
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
bundle
Merge remote-tracking branch 'origin/richard-unifympt' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
simplify?
ugh
compat
cleanup
Merge remote-tracking branch 'origin/main' into richard-unifympt
cleanup tests
Merge remote-tracking branch 'origin/main' into richard-unifympt
more tests
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
simplify?
Merge branch 'main' into richard-unifympt
cleanup
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
no more .mpt
remove all .mpt and lint
cleanup local data path fixes
mpt re-add
re-add artifact
lint
Merge remote-tracking branch 'origin/main' into richard-unifympt
more cleanup
Merge remote-tracking branch 'origin/main' into richard-unifympt
diff cleanup
ftt
lint
fix error
Merge remote-tracking branch 'origin/main' into richard-unifympt
more tests
lint
Merge remote-tracking branch 'origin/main' into richard-unifympt
Merge remote-tracking branch 'origin/main' into richard-unifympt
checkpoint policy does save/load
lint
checkpoint
moving cactus
lint
Merge branch 'main' into richard-unifympt
fold-in

[pyright 4] Get pyright to pass on app_backend (#4478)

Merge remote-tracking branch 'origin/main' into richard-unifympt

Fix command, add space (#4456)
Added a space to --app:lib--tlsEmulation:off, making it --app:lib --tlsEmulation:off; now it runs.

Rename HyperUpdateRule to ScheduleRule (#4483)
## Summary
- rename HyperUpdateRule to ScheduleRule and apply to TrainerConfig via target_path
- update recipes and teacher scheduling to use ScheduleRule
- report PPO stats using ppo_actor/ppo_critic hyperparam keys and update tests
## Testing
- not run (not requested)
Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>

Merge remote-tracking branch 'origin/main' into richard-unifympt

Fix supervisor teacher behavior and legacy BC mode (#4484)
## Summary
- gate PPO actor during supervisor teacher phase
- fix supervisor/no-teacher behavior and add legacy BC (no gating, no PPO resume)
- require supervisor policy URI for sliced_cloner_no_ppo
## Testing
- not run (not requested)
Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
Co-authored-by: Adam S <134907338+gustofied@users.noreply.github.com>

Minor fixes to the slstm triton kernel, causing failures for certain kernel sizes (#4492)

cleanup
Merge remote-tracking branch 'origin/main' into richard-unifympt
fold in

training environments and eval environments mismatched (#4487)
I ran a direct config comparison using the training entrypoint (recipes/experiment/cogs_v_clips.train) with variants=["heart_chorus"] and compared the eval suite config it builds (difficulty standard + heart_chorus) for an overlapping mission: hello_world.oxygen_bottleneck.

Findings:
- Compass is ON in both training and eval (global_obs.compass=True).
- Vibe count and change-vibe settings match (152 vibes; change_vibe.number_of_vibes=152).
- But the mission parameters differ between training and eval for the same mission name:
  - game.objects.carbon_extractor.max_uses: train 25 vs eval 100
  - game.objects.oxygen_extractor.max_uses: train 5 vs eval 20
  - game.objects.germanium_extractor.max_uses: train 5 vs eval 20
  - game.objects.silicon_extractor.max_uses: train 120 vs eval 480

So the mismatch isn't compass; it's the mission definitions used by training vs eval. Training uses base missions (cogames.cogs_vs_clips.missions), while eval uses integrated eval missions (cogames.cogs_vs_clips.evals.integrated_evals) that have different extractor settings.

Also: the eval suite used by recipes/experiment/cogs_v_clips.train does not include machina_1.open_world at all (it only evaluates the 7 HELLO_WORLD integrated evals). So training can be creating hearts on easier missions while your eval runs on machina_1 are a different environment entirely.

Relevant files:
- Compass default: packages/cogames/src/cogames/cogs_vs_clips/mission.py
- Training entrypoint + eval suite wiring: recipes/experiment/cogs_v_clips.py
- Eval mission definitions: packages/cogames/src/cogames/cogs_vs_clips/evals/integrated_evals.py

If you want true parity, we should align which mission templates eval uses (and/or include machina_1.open_world in the eval suite). I can patch this if you want; tell me whether you prefer:
1. Eval suite uses the same mission templates as training (from missions.py), or
2. Training uses the integrated eval mission definitions, or
3. Add machina_1.open_world to the eval suite.

ripping out
Merge remote-tracking branch 'origin/main' into richard-unifympt
simplify
fix and lint
choke
simplify submission zip creation
use policy_spec for submission zips
tighten checkpoint io helpers
shorten checkpoint arg help
inline checkpoint policy helpers
restore policy spec docstring
validate checkpoint data_path before download
require checkpoint directory URIs
expand policy spec s3 docstring
1 parent d0b2e7f commit 7c95ff5

9 files changed: +269 −69 lines changed


metta/rl/checkpoint_manager.py

Lines changed: 25 additions & 9 deletions

@@ -9,8 +9,10 @@
 from metta.rl.system_config import SystemConfig
 from metta.rl.training.optimizer import is_schedulefree_optimizer
 from metta.tools.utils.auto_config import auto_policy_storage_decision
-from mettagrid.policy.mpt_artifact import save_mpt
-from mettagrid.util.uri_resolvers.schemes import checkpoint_filename, resolve_uri
+from mettagrid.policy.checkpoint_policy import WEIGHTS_FILENAME, CheckpointPolicy
+from mettagrid.policy.submission import POLICY_SPEC_FILENAME
+from mettagrid.util.file import write_data
+from mettagrid.util.uri_resolvers.schemes import resolve_uri

 logger = logging.getLogger(__name__)


@@ -80,16 +82,30 @@ def try_resolve(uri: str) -> tuple[str, int] | None:
         return max(candidates, key=lambda x: x[1])[0]

     def save_policy_checkpoint(self, state_dict: dict, architecture, epoch: int) -> str:
-        filename = checkpoint_filename(self.run_name, epoch)
         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
-
-        local_uri = save_mpt(self.checkpoint_dir / filename, architecture=architecture, state_dict=state_dict)
+        checkpoint_dir = CheckpointPolicy.write_checkpoint_dir(
+            base_dir=self.checkpoint_dir,
+            run_name=self.run_name,
+            epoch=epoch,
+            architecture=architecture,
+            state_dict=state_dict,
+        )

         if self._remote_prefix:
-            remote_uri = save_mpt(f"{self.output_uri}/{filename}", architecture=architecture, state_dict=state_dict)
-            logger.debug("Policy checkpoint saved remotely to %s", remote_uri)
-            return remote_uri
-
+            remote_dir = f"{self.output_uri.rstrip('/')}/{checkpoint_dir.name}"
+            write_data(
+                f"{remote_dir}/{WEIGHTS_FILENAME}",
+                (checkpoint_dir / WEIGHTS_FILENAME).read_bytes(),
+            )
+            write_data(
+                f"{remote_dir}/{POLICY_SPEC_FILENAME}",
+                (checkpoint_dir / POLICY_SPEC_FILENAME).read_bytes(),
+                content_type="application/json",
+            )
+            logger.debug("Policy checkpoint saved remotely to %s", remote_dir)
+            return remote_dir
+
+        local_uri = checkpoint_dir.as_uri()
         logger.debug("Policy checkpoint saved locally to %s", local_uri)
         return local_uri

metta/rl/loss/eer_kickstarter.py

Lines changed: 3 additions & 3 deletions

@@ -10,8 +10,8 @@
 from metta.rl.advantage import compute_advantage
 from metta.rl.loss.loss import Loss, LossConfig
 from metta.rl.training import ComponentContext
+from mettagrid.policy.checkpoint_policy import CheckpointPolicy
 from mettagrid.policy.loader import initialize_or_load_policy
-from mettagrid.policy.mpt_policy import MptPolicy
 from mettagrid.util.uri_resolvers.schemes import policy_spec_from_uri

 if TYPE_CHECKING:

@@ -59,8 +59,8 @@ def __init__(
             raise RuntimeError("Environment metadata is required to instantiate teacher policy")
         teacher_spec = policy_spec_from_uri(self.cfg.teacher_uri, device=str(self.device))
         self.teacher_policy = initialize_or_load_policy(policy_env_info, teacher_spec)
-        if isinstance(self.teacher_policy, MptPolicy):
-            self.teacher_policy = self.teacher_policy._policy
+        if isinstance(self.teacher_policy, CheckpointPolicy):
+            self.teacher_policy = self.teacher_policy.wrapped_policy

     def get_experience_spec(self) -> Composite:
         act_space = self.env.single_action_space
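
The same two-line change repeats in kickstarter.py, logit_kickstarter.py, and sliced_kickstarter.py below: a teacher loaded from a checkpoint URI now comes back wrapped in CheckpointPolicy, and the losses unwrap it through the public wrapped_policy property instead of reaching into the old MptPolicy._policy private attribute. A minimal sketch of the pattern, assuming teacher_uri, device, and policy_env_info come from the loss config and environment as in the diffs:

    teacher_spec = policy_spec_from_uri(teacher_uri, device=str(device))
    teacher_policy = initialize_or_load_policy(policy_env_info, teacher_spec)
    if isinstance(teacher_policy, CheckpointPolicy):
        # Kickstarting distills from the raw inner policy, not the checkpoint wrapper.
        teacher_policy = teacher_policy.wrapped_policy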

metta/rl/loss/kickstarter.py

Lines changed: 3 additions & 3 deletions

@@ -10,8 +10,8 @@
 from metta.agent.policy import Policy
 from metta.rl.loss.loss import Loss, LossConfig
 from metta.rl.training import ComponentContext
+from mettagrid.policy.checkpoint_policy import CheckpointPolicy
 from mettagrid.policy.loader import initialize_or_load_policy
-from mettagrid.policy.mpt_policy import MptPolicy
 from mettagrid.util.uri_resolvers.schemes import policy_spec_from_uri

 if TYPE_CHECKING:

@@ -60,8 +60,8 @@ def __init__(
             raise RuntimeError("Environment metadata is required to instantiate teacher policy")
         teacher_spec = policy_spec_from_uri(self.cfg.teacher_uri, device=str(self.device))
         self.teacher_policy = initialize_or_load_policy(policy_env_info, teacher_spec)
-        if isinstance(self.teacher_policy, MptPolicy):
-            self.teacher_policy = self.teacher_policy._policy
+        if isinstance(self.teacher_policy, CheckpointPolicy):
+            self.teacher_policy = self.teacher_policy.wrapped_policy

     def get_experience_spec(self) -> Composite:
         # Get action space size for logits shape

metta/rl/loss/logit_kickstarter.py

Lines changed: 3 additions & 3 deletions

@@ -10,8 +10,8 @@
 from metta.agent.policy import Policy
 from metta.rl.loss.loss import Loss, LossConfig
 from metta.rl.training import ComponentContext
+from mettagrid.policy.checkpoint_policy import CheckpointPolicy
 from mettagrid.policy.loader import initialize_or_load_policy
-from mettagrid.policy.mpt_policy import MptPolicy
 from mettagrid.util.uri_resolvers.schemes import policy_spec_from_uri

 if TYPE_CHECKING:

@@ -67,8 +67,8 @@ def __init__(

         teacher_spec = policy_spec_from_uri(self.cfg.teacher_uri, device=self.device)
         self.teacher_policy = initialize_or_load_policy(base_policy_env_info, teacher_spec)
-        if isinstance(self.teacher_policy, MptPolicy):
-            self.teacher_policy = self.teacher_policy._policy
+        if isinstance(self.teacher_policy, CheckpointPolicy):
+            self.teacher_policy = self.teacher_policy.wrapped_policy

     def get_experience_spec(self) -> Composite:
         # Get action space size for logits shape

metta/rl/loss/sliced_kickstarter.py

Lines changed: 3 additions & 3 deletions

@@ -10,8 +10,8 @@
 from metta.agent.policy import Policy
 from metta.rl.loss.loss import Loss, LossConfig
 from metta.rl.training import ComponentContext
+from mettagrid.policy.checkpoint_policy import CheckpointPolicy
 from mettagrid.policy.loader import initialize_or_load_policy
-from mettagrid.policy.mpt_policy import MptPolicy
 from mettagrid.util.uri_resolvers.schemes import policy_spec_from_uri

 if TYPE_CHECKING:

@@ -64,8 +64,8 @@ def __init__(

         teacher_spec = policy_spec_from_uri(self.cfg.teacher_uri, device=self.device)
         self.teacher_policy = initialize_or_load_policy(base_policy_env_info, teacher_spec)
-        if isinstance(self.teacher_policy, MptPolicy):
-            self.teacher_policy = self.teacher_policy._policy
+        if isinstance(self.teacher_policy, CheckpointPolicy):
+            self.teacher_policy = self.teacher_policy.wrapped_policy

     def get_experience_spec(self) -> Composite:
         # Get action space size for logits shape

metta/rl/training/checkpointer.py

Lines changed: 28 additions & 9 deletions

@@ -1,18 +1,22 @@
 """Policy checkpoint management component."""

 import logging
+from pathlib import Path
 from typing import Optional

 import torch
 from pydantic import Field
+from safetensors.torch import load as load_safetensors

 from metta.agent.policy import Policy, PolicyArchitecture
 from metta.rl.checkpoint_manager import CheckpointManager
 from metta.rl.training import DistributedHelper, TrainerComponent
 from mettagrid.base_config import Config
-from mettagrid.policy.mpt_artifact import MptArtifact, load_mpt
+from mettagrid.policy.checkpoint_policy import CheckpointPolicy
+from mettagrid.policy.loader import initialize_or_load_policy
 from mettagrid.policy.policy_env_interface import PolicyEnvInterface
-from mettagrid.util.uri_resolvers.schemes import resolve_uri
+from mettagrid.util.module import load_symbol
+from mettagrid.util.uri_resolvers.schemes import policy_spec_from_uri, resolve_uri

 logger = logging.getLogger(__name__)


@@ -55,21 +59,30 @@ def load_or_create_policy(
         candidate_uri = policy_uri or self._checkpoint_manager.get_latest_checkpoint()
         load_device = torch.device(self._distributed.config.device)

+        def load_state_from_checkpoint_uri(uri: str) -> tuple[str, dict[str, torch.Tensor]]:
+            spec = policy_spec_from_uri(uri, device=str(load_device))
+            architecture_spec = spec.init_kwargs.get("architecture_spec")
+            if not architecture_spec:
+                raise ValueError("policy_spec.json missing init_kwargs.architecture_spec")
+            if not spec.data_path:
+                raise ValueError("policy_spec.json missing data_path")
+            state_dict = load_safetensors(Path(spec.data_path).read_bytes())
+            return architecture_spec, dict(state_dict)
+
         if self._distributed.is_distributed:
             normalized_uri = None
             if self._distributed.is_master() and candidate_uri:
                 normalized_uri = resolve_uri(candidate_uri).canonical
             normalized_uri = self._distributed.broadcast_from_master(normalized_uri)

             if normalized_uri:
-                artifact: MptArtifact | None = None
+                loaded: tuple[str, dict[str, torch.Tensor]] | None = None
                 if self._distributed.is_master():
-                    artifact = load_mpt(normalized_uri)
-
+                    loaded = load_state_from_checkpoint_uri(normalized_uri)
                 state_dict = self._distributed.broadcast_from_master(
-                    {k: v.cpu() for k, v in artifact.state_dict.items()} if artifact else None
+                    {k: v.cpu() for k, v in loaded[1].items()} if loaded else None
                 )
-                arch = self._distributed.broadcast_from_master(artifact.architecture if artifact else None)
+                architecture_spec = self._distributed.broadcast_from_master(loaded[0] if loaded else None)
                 action_count = self._distributed.broadcast_from_master(
                     len(policy_env_info.actions.actions()) if self._distributed.is_master() else None
                 )

@@ -78,6 +91,10 @@ def load_or_create_policy(
                 if local_action_count != action_count:
                     raise ValueError(f"Action space mismatch: master={action_count}, rank={local_action_count}")

+                if architecture_spec is None:
+                    raise ValueError("Missing architecture_spec from master")
+                class_path = architecture_spec.split("(", 1)[0].strip()
+                arch = load_symbol(class_path).from_spec(architecture_spec)
                 policy = arch.make_policy(policy_env_info).to(load_device)
                 if hasattr(policy, "initialize_to_environment"):
                     policy.initialize_to_environment(policy_env_info, load_device)

@@ -91,8 +108,10 @@ def load_or_create_policy(
             return policy

         if candidate_uri:
-            artifact = load_mpt(candidate_uri)
-            policy = artifact.instantiate(policy_env_info, self._distributed.config.device)
+            spec = policy_spec_from_uri(candidate_uri, device=str(load_device))
+            policy = initialize_or_load_policy(policy_env_info, spec, device_override=str(load_device))
+            if isinstance(policy, CheckpointPolicy):
+                policy = policy.wrapped_policy
             self._latest_policy_uri = resolve_uri(candidate_uri).canonical
             logger.info("Loaded policy from %s", candidate_uri)
             return policy
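
In the distributed branch, only the master rank touches storage: it resolves the URI to a policy_spec.json, reads the safetensors blob, and broadcasts the CPU state dict plus the architecture_spec string; every rank then rebuilds the architecture locally and checks its action count against the master's. The spec string carries the class path before its first parenthesis, so ranks can recover the class without deserializing any pickled object. A minimal sketch of that last step, with a hypothetical spec value:

    architecture_spec = "my_pkg.MyArchitecture(hidden_size=256)"   # hypothetical value
    class_path = architecture_spec.split("(", 1)[0].strip()        # -> "my_pkg.MyArchitecture"
    arch = load_symbol(class_path).from_spec(architecture_spec)    # the class re-parses its own spec
    policy = arch.make_policy(policy_env_info).to(load_device)     # identical construction on every rank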

packages/mettagrid/python/src/mettagrid/policy/checkpoint_policy.py

Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Mapping
+
+import torch
+from safetensors.torch import load as load_safetensors
+from safetensors.torch import save as save_safetensors
+
+from mettagrid.policy.policy import AgentPolicy, MultiAgentPolicy
+from mettagrid.policy.policy_env_interface import PolicyEnvInterface
+from mettagrid.policy.submission import POLICY_SPEC_FILENAME, SubmissionPolicySpec
+from mettagrid.util.module import load_symbol
+from mettagrid.util.uri_resolvers.schemes import checkpoint_filename
+
+WEIGHTS_FILENAME = "weights.safetensors"
+
+
+def prepare_state_dict_for_save(state_dict: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    result: dict[str, torch.Tensor] = {}
+    seen_storage: set[int] = set()
+    for key, tensor in state_dict.items():
+        if not isinstance(tensor, torch.Tensor):
+            raise TypeError(f"State dict entry '{key}' is not a torch.Tensor")
+        value = tensor.detach().cpu()
+        data_ptr = value.data_ptr()
+        if data_ptr in seen_storage:
+            value = value.clone()
+        else:
+            seen_storage.add(data_ptr)
+        result[key] = value
+    return result
+
+
+def _resolve_policy_data_path(path: Path) -> Path:
+    if path.is_dir():
+        spec_path = path / POLICY_SPEC_FILENAME
+        if not spec_path.exists():
+            raise FileNotFoundError(f"{POLICY_SPEC_FILENAME} not found in checkpoint directory: {path}")
+        submission_spec = SubmissionPolicySpec.model_validate_json(spec_path.read_text())
+        if not submission_spec.data_path:
+            raise ValueError(f"{POLICY_SPEC_FILENAME} missing data_path in {path}")
+        weights_path = path / submission_spec.data_path
+        if not weights_path.exists():
+            raise FileNotFoundError(f"Policy data path does not exist: {weights_path}")
+        return weights_path
+
+    if path.is_file() and path.name != POLICY_SPEC_FILENAME:
+        return path
+
+    raise FileNotFoundError(f"Policy data path does not exist: {path}")
+
+
+def write_policy_spec(checkpoint_dir: Path, architecture_spec: str) -> None:
+    spec = SubmissionPolicySpec(
+        class_path="mettagrid.policy.checkpoint_policy.CheckpointPolicy",
+        data_path=WEIGHTS_FILENAME,
+        init_kwargs={"architecture_spec": architecture_spec},
+    )
+    (checkpoint_dir / POLICY_SPEC_FILENAME).write_text(spec.model_dump_json())
+
+
+class CheckpointPolicy(MultiAgentPolicy):
+    short_names = ["checkpoint"]
+
+    def __init__(
+        self,
+        policy_env_info: PolicyEnvInterface,
+        *,
+        architecture_spec: str,
+        device: str = "cpu",
+        strict: bool = True,
+    ):
+        super().__init__(policy_env_info, device=device)
+        self._strict = strict
+        self._device = torch.device(device)
+        self._policy_env_info = policy_env_info
+        self._architecture_spec = architecture_spec
+        class_path = architecture_spec.split("(", 1)[0].strip()
+        self._architecture = load_symbol(class_path).from_spec(architecture_spec)
+        self._policy = self._architecture.make_policy(policy_env_info).to(self._device)
+        self._policy.eval()
+
+    def load_policy_data(self, policy_data_path: str) -> None:
+        weights_blob = _resolve_policy_data_path(Path(policy_data_path).expanduser()).read_bytes()
+        state_dict = load_safetensors(weights_blob)
+        missing, unexpected = self._policy.load_state_dict(dict(state_dict), strict=self._strict)
+        if self._strict and (missing or unexpected):
+            raise RuntimeError(f"Strict loading failed. Missing: {missing}, Unexpected: {unexpected}")
+        if hasattr(self._policy, "initialize_to_environment"):
+            self._policy.initialize_to_environment(self._policy_env_info, self._device)
+        self._policy.eval()
+
+    def save_policy_data(self, policy_data_path: str) -> None:
+        target_dir = Path(policy_data_path).expanduser()
+        target_dir.mkdir(parents=True, exist_ok=True)
+        (target_dir / WEIGHTS_FILENAME).write_bytes(
+            save_safetensors(prepare_state_dict_for_save(self._policy.state_dict()))
+        )
+        write_policy_spec(target_dir, self._architecture_spec)
+
+    @staticmethod
+    def write_checkpoint_dir(
+        *,
+        base_dir: Path,
+        run_name: str,
+        epoch: int,
+        architecture: Any,
+        state_dict: Mapping[str, torch.Tensor],
+    ) -> Path:
+        architecture_spec = architecture if isinstance(architecture, str) else architecture.to_spec()
+        checkpoint_dir = (base_dir / checkpoint_filename(run_name, epoch)).expanduser().resolve()
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        (checkpoint_dir / WEIGHTS_FILENAME).write_bytes(
+            save_safetensors(prepare_state_dict_for_save(state_dict))
+        )
+        write_policy_spec(checkpoint_dir, architecture_spec)
+        return checkpoint_dir
+
+    def agent_policy(self, agent_id: int) -> AgentPolicy:
+        return self._policy.agent_policy(agent_id)
+
+    def eval(self) -> "CheckpointPolicy":
+        self._policy.eval()
+        return self
+
+    @property
+    def wrapped_policy(self) -> Any:
+        return self._policy
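
This new module is the load/save wrapper that replaces MptPolicy. Two details worth noting: prepare_state_dict_for_save clones any tensor whose data pointer was already seen, because safetensors refuses to serialize tensors that share memory (e.g. tied weights), and load_policy_data accepts either a checkpoint directory (resolved through policy_spec.json) or a bare weights file. A hypothetical round trip, assuming an existing PolicyEnvInterface instance env_info and an architecture spec string spec_str:

    policy = CheckpointPolicy(env_info, architecture_spec=spec_str, device="cpu")
    policy.save_policy_data("/tmp/ckpt")   # writes weights.safetensors + policy_spec.json
    policy.load_policy_data("/tmp/ckpt")   # directory form: weights path read from policy_spec.json
    policy.load_policy_data("/tmp/ckpt/weights.safetensors")   # bare weights file also works
    agent0 = policy.agent_policy(0)        # per-agent view delegates to the wrapped policy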
