Skip to content

Commit 7523166

Browse files
committed
init
no more .mpt Merge remote-tracking branch 'origin/main' into richard-unifympt slim policy spec handler more concise cleanup simplify Merge remote-tracking branch 'origin/main' into richard-unifympt re-add fix policy spex Update packages/mettagrid/python/src/mettagrid/util/uri_resolvers/schemes.py Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> Merge remote-tracking branch 'origin/main' into richard-unifympt bundles Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt bundle Merge remote-tracking branch 'origin/richard-unifympt' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt simplify? ugh compat cleanup Merge remote-tracking branch 'origin/main' into richard-unifympt cleanup tests Merge remote-tracking branch 'origin/main' into richard-unifympt more tests Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt simplify? Merge branch 'main' into richard-unifympt cleanup Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt no more .mpt remove all .mpt and lint cleanup local data path fixes mpt re-add re-add artifact lint Merge remote-tracking branch 'origin/main' into richard-unifympt more cleanup Merge remote-tracking branch 'origin/main' into richard-unifympt diff cleanup ftt lint fix error Merge remote-tracking branch 'origin/main' into richard-unifympt more tests lint Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt checkpoint policy does save/load lint checkpoint moving catcus lint Merge branch 'main' into richard-unifympt fold-in [pyright 4] Get pyright to pass on app_backend (#4478) Merge remote-tracking branch 'origin/main' into richard-unifympt Fix command, add space (#4456) added space to --app:lib--tlsEmulation:off which makes it --app:lib --tlsEmulation:off now it runs Rename HyperUpdateRule to ScheduleRule (#4483) ## Summary - rename HyperUpdateRule to ScheduleRule and apply to TrainerConfig via target_path - update recipes and teacher scheduling to use ScheduleRule - report PPO stats using ppo_actor/ppo_critic hyperparam keys and update tests ## Testing - not run (not requested) --------- Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> Merge remote-tracking branch 'origin/main' into richard-unifympt Fix supervisor teacher behavior and legacy BC mode (#4484) ## Summary - gate PPO actor during supervisor teacher phase - fix supervisor/no-teacher behavior and add legacy BC (no gating, no PPO resume) - require supervisor policy URI for sliced_cloner_no_ppo ## Testing - not run (not requested) --------- Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> Co-authored-by: Adam S <134907338+gustofied@users.noreply.github.com> Minor fixes to the slstm triton kernel, causing failures for certain kernel sizes (#4492) cleanup Merge remote-tracking branch 'origin/main' into richard-unifympt fold in training environments and eval environments mismatched (#4487) I ran a direct config comparison using the training entrypoint (recipes/experiment/cogs_v_clips.train) with variants=["heart_chorus"] and compared the eval suite config it builds (difficulty standard + heart_chorus) for an overlapping mission: hello_world.oxygen_bottleneck. Findings: - Compass is ON in both training and eval (global_obs.compass=True). - Vibe count and change‑vibe settings match (152 vibes; change_vibe.number_of_vibes=152). - But the mission parameters differ between training and eval for the same mission name: - game.objects.carbon_extractor.max_uses: train 25 vs eval 100 - game.objects.oxygen_extractor.max_uses: train 5 vs eval 20 - game.objects.germanium_extractor.max_uses: train 5 vs eval 20 - game.objects.silicon_extractor.max_uses: train 120 vs eval 480 So the mismatch isn’t compass — it’s the mission definitions used by training vs eval. Training uses base missions (cogames.cogs_vs_clips.missions), while eval uses integrated eval missions (cogames.cogs_vs_clips.evals.integrated_evals) that have different extractor settings. Also: the eval suite used by recipes/experiment/cogs_v_clips.train does not include machina_1.open_world at all (it only evaluates the 7 HELLO_WORLD integrated evals). So training can be creating hearts on easier missions while your eval runs on machina_1 are a different environment entirely. Relevant files: - Compass default: packages/cogames/src/cogames/cogs_vs_clips/mission.py - Training entrypoint + eval suite wiring: recipes/experiment/cogs_v_clips.py - Eval mission definitions: packages/cogames/src/cogames/cogs_vs_clips/evals/integrated_evals.py If you want true parity, we should align which mission templates eval uses (and/or include machina_1.open_world in the eval suite). I can patch this if you want — tell me whether you prefer: 1. Eval suite uses the same mission templates as training (from missions.py), or 2. Training uses the integrated eval mission definitions, or 3. Add machina_1.open_world to the eval suite. ripping out Merge remote-tracking branch 'origin/main' into richard-unifympt simplify fix and lint choke simplify submission zip creation use policy_spec for submission zips tighten checkpoint io helpers shorten checkpoint arg help inline checkpoint policy helpers restore policy spec docstring validate checkpoint data_path before download require checkpoint directory URIs expand policy spec s3 docstring
1 parent e0e1aaa commit 7523166

File tree

7 files changed

+53
-39
lines changed

7 files changed

+53
-39
lines changed

metta/rl/loss/sl_checkpointed_kickstarter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ def _construct_checkpoint_uri(self, epoch: int) -> str:
159159
filename = checkpoint_filename(run_name, epoch)
160160

161161
if parsed.scheme == "file" and parsed.local_path:
162+
if parsed.local_path.is_file():
163+
raise ValueError("Provide a checkpoint directory, not policy_spec.json")
162164
path = parsed.local_path.parent / filename
163165
return f"file://{path}"
164166
elif parsed.scheme == "s3" and parsed.bucket and parsed.key:

metta/rl/metta_scheme_resolver.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from metta.app_backend.metta_repo import PolicyVersionWithName
88
from metta.common.util.constants import PROD_STATS_SERVER_URI
99
from mettagrid.util.uri_resolvers.base import MettaParsedScheme, SchemeResolver
10-
from mettagrid.util.uri_resolvers.schemes import resolve_uri
1110

1211
logger = logging.getLogger(__name__)
1312

@@ -118,12 +117,6 @@ def get_path_to_policy_spec_or_mpt(self, uri: str) -> str:
118117
logger.info(f"Metta scheme resolver: {uri} resolved to s3 policy spec: {policy_version.s3_path}")
119118
return policy_version.s3_path
120119

121-
# If that is missing (probably legacy policy), we send you to the mpt file, and will later assume
122-
# that the class to hydrate from is MptPolicy
123-
mpt_file_path = (policy_version.policy_spec or {}).get("init_kwargs", {}).get("checkpoint_uri")
124-
if not mpt_file_path:
125-
raise ValueError(f"Data not found for policy version {policy_version.id}")
126-
if not mpt_file_path.endswith(".mpt"):
127-
raise ValueError(f"Invalid mpt file path: {mpt_file_path}")
128-
logger.info(f"Metta scheme resolver: {uri} resolved to mpt checkpoint: {mpt_file_path}")
129-
return resolve_uri(mpt_file_path).canonical
120+
raise ValueError(
121+
f"Policy version {policy_version.id} has no s3_path; expected a policy spec submission zip in S3."
122+
)

metta/rl/training/evaluator.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import uuid
99
import zipfile
10+
from pathlib import Path
1011
from typing import Any, Optional
1112

1213
import torch
@@ -27,9 +28,9 @@
2728
from metta.tools.utils.auto_config import auto_replay_dir
2829
from mettagrid.base_config import Config
2930
from mettagrid.policy.policy import PolicySpec
30-
from mettagrid.policy.submission import POLICY_SPEC_FILENAME
31+
from mettagrid.policy.submission import POLICY_SPEC_FILENAME, SubmissionPolicySpec
3132
from mettagrid.util.file import write_data
32-
from mettagrid.util.uri_resolvers.schemes import policy_spec_from_uri
33+
from mettagrid.util.uri_resolvers.schemes import policy_spec_from_uri, resolve_uri
3334

3435
logger = logging.getLogger(__name__)
3536

@@ -140,21 +141,34 @@ def should_evaluate(self, epoch: int) -> bool:
140141
return epoch % interval == 0
141142

142143
def _create_submission_zip(self, policy_spec: PolicySpec) -> bytes:
143-
"""Create a submission zip containing policy_spec.json."""
144+
"""Create a submission zip containing policy_spec.json and optional weights."""
145+
submission_spec = SubmissionPolicySpec.model_validate(policy_spec.model_dump(mode="json"))
146+
data_path = submission_spec.data_path
147+
if data_path and Path(data_path).is_absolute():
148+
data_path = Path(data_path).name
149+
submission_spec.data_path = data_path
150+
spec_bytes = submission_spec.model_dump_json().encode("utf-8")
151+
144152
buffer = io.BytesIO()
145153
with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
146-
zipf.writestr(POLICY_SPEC_FILENAME, policy_spec.model_dump_json())
154+
zipf.writestr(POLICY_SPEC_FILENAME, spec_bytes)
155+
if data_path:
156+
if not policy_spec.data_path:
157+
raise ValueError("policy_spec.data_path missing for submission")
158+
source_path = Path(policy_spec.data_path)
159+
if not source_path.is_absolute():
160+
raise ValueError("policy_spec.data_path must be absolute for submission")
161+
zipf.writestr(data_path, source_path.read_bytes())
147162
return buffer.getvalue()
148163

149-
def _upload_submission_zip(self, policy_spec: PolicySpec) -> str | None:
164+
def _upload_submission_zip(self, policy_spec: PolicySpec, policy_uri: str) -> str | None:
150165
"""Upload a submission zip to S3 and return the s3_path."""
151-
checkpoint_uri = policy_spec.init_kwargs.get("checkpoint_uri")
152-
if not checkpoint_uri or not checkpoint_uri.startswith("s3://"):
166+
if not policy_uri.startswith("s3://"):
153167
return None
154168

155-
submission_path = checkpoint_uri.replace(".mpt", "-submission.zip")
156-
zip_data = self._create_submission_zip(policy_spec)
157-
write_data(submission_path, zip_data, content_type="application/zip")
169+
checkpoint_dir = resolve_uri(policy_uri).canonical.rstrip("/")
170+
submission_path = f"{checkpoint_dir}/submission.zip"
171+
write_data(submission_path, self._create_submission_zip(policy_spec), content_type="application/zip")
158172
logger.info("Uploaded submission zip to %s", submission_path)
159173
return submission_path
160174

@@ -163,6 +177,7 @@ def _create_policy_version(
163177
*,
164178
stats_client: StatsClient,
165179
policy_spec: PolicySpec,
180+
policy_uri: str,
166181
epoch: int,
167182
agent_step: int,
168183
) -> uuid.UUID:
@@ -176,7 +191,7 @@ def _create_policy_version(
176191
)
177192

178193
# Upload submission zip to S3
179-
s3_path = self._upload_submission_zip(policy_spec)
194+
s3_path = self._upload_submission_zip(policy_spec, policy_uri)
180195

181196
# Create policy version
182197
policy_version_id = stats_client.create_policy_version(
@@ -209,6 +224,7 @@ def evaluate(
209224
policy_version_id = self._create_policy_version(
210225
stats_client=self._stats_client,
211226
policy_spec=policy_spec,
227+
policy_uri=policy_uri,
212228
epoch=epoch,
213229
agent_step=agent_step,
214230
)

metta/setup/shell.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ def help_configs() -> None:
2121
success("# Load configs with overrides:")
2222
info('cfg = load_cfg("train_job.yaml", ["training_env.curriculum=/env/mettagrid/arena/advanced"])')
2323
success("# Load checkpoints:")
24-
info('artifact = load_mpt("file://./train_dir/my_run/checkpoints/my_run:v12.mpt")')
25-
info('artifact = load_mpt("s3://bucket/path/my_run/checkpoints/my_run:v12.mpt")')
26-
info('policy = artifact.instantiate(policy_env_info, torch.device("cpu"))')
24+
info('spec = policy_spec_from_uri("file://./train_dir/my_run/checkpoints/my_run:v12")')
25+
info("policy = initialize_or_load_policy(policy_env_info, spec)")
2726
success("# Create checkpoint manager:")
2827
info('cm = CheckpointManager(run="my_run", run_dir="./train_dir")')
2928

packages/cogames/scripts/run_evaluation.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
uv run python packages/cogames/scripts/run_evaluation.py \
1515
--agent cogames.policy.nim_agents.agents.ThinkyAgentsMultiPolicy --cogs 1
1616
uv run python packages/cogames/scripts/run_evaluation.py \
17-
--agent cogames.policy.lstm.LSTMPolicy --checkpoint s3://bucket/path/model.mpt --cogs 1
17+
--agent cogames.policy.lstm.LSTMPolicy --checkpoint s3://bucket/path/run:v<N> --cogs 1
1818
uv run python packages/cogames/scripts/run_evaluation.py \
19-
--agent s3://bucket/path/model.mpt --cogs 1
19+
--agent s3://bucket/path/run:v<N> --cogs 1
2020
"""
2121

2222
import argparse
@@ -38,6 +38,7 @@
3838
import matplotlib.pyplot as plt
3939
import numpy as np
4040
import torch
41+
from safetensors.torch import load as load_safetensors
4142

4243
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
4344
from cogames.cogs_vs_clips.mission import Mission, MissionVariant, NumCogsVariant
@@ -87,12 +88,13 @@ def _get_policy_action_space(policy_path: str) -> Optional[int]:
8788
return None
8889

8990
try:
90-
from mettagrid.policy.mpt_artifact import load_mpt
91-
92-
artifact = load_mpt(policy_path)
91+
spec = policy_spec_from_uri(policy_path)
92+
if not spec.data_path:
93+
return None
94+
weights = load_safetensors(Path(spec.data_path).read_bytes())
9395

9496
# Look for actor head weight to determine action space
95-
for key, tensor in artifact.state_dict.items():
97+
for key, tensor in weights.items():
9698
if "actor_head" in key and "weight" in key and len(tensor.shape) == 2:
9799
action_space = tensor.shape[0]
98100
_policy_action_space_cache[policy_path] = action_space
@@ -1152,7 +1154,7 @@ def lookup_wrapper(_s: str, exp_name: str):
11521154
def main():
11531155
parser = argparse.ArgumentParser(description="Evaluate scripted or custom agents.")
11541156
parser.add_argument("--agent", nargs="*", default=None, help="Agent key, class path, or S3 URI")
1155-
parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint path (or S3 URI)")
1157+
parser.add_argument("--checkpoint", help="Checkpoint directory URI")
11561158
parser.add_argument("--experiments", nargs="*", default=None, help="Experiments to run")
11571159
parser.add_argument("--variants", nargs="*", default=None, help="Variants to apply")
11581160
parser.add_argument("--cogs", nargs="*", type=int, default=None, help="Agent counts to test")

packages/cogames/src/cogames/cli/policy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def list_checkpoints():
5050

5151
def describe_policy_arg(with_proportion: bool):
5252
console.print("[bold cyan]-p [POLICY][/bold cyan] accepts two formats:\n")
53-
console.print("[bold]1. URI format[/bold] (for .mpt checkpoints):")
53+
console.print("[bold]1. URI format[/bold] (checkpoint bundle):")
5454
console.print(" - metta://policy/<name> or metta://policy/<uuid>")
55-
console.print(" - s3://bucket/path/to/checkpoint.mpt")
56-
console.print(" - file:///path/to/checkpoint.mpt or /path/to/checkpoint.mpt")
55+
console.print(" - s3://bucket/path/to/run:v<N>")
56+
console.print(" - file:///path/to/run:v<N> or /path/to/run:v<N>")
5757
console.print()
5858
console.print(
5959
"[bold]2. Key-value format[/bold]: "

packages/mettagrid/python/src/mettagrid/util/uri_resolvers/README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,30 @@
22

33
This package provides a pluggable URI resolution system for handling different resource schemes.
44

5+
Checkpoint URIs point at a checkpoint directory containing `policy_spec.json`.
6+
57
## Usage
68

79
```python
810
from mettagrid.util.uri_resolvers.schemes import parse_uri, resolve_uri, get_checkpoint_metadata
911

1012
# Parse a URI to get its components
11-
parsed = parse_uri("s3://bucket/path/to/file.mpt")
13+
parsed = parse_uri("s3://bucket/path/to/run:v5")
1214
print(parsed.scheme) # "s3"
1315
print(parsed.bucket) # "bucket"
14-
print(parsed.key) # "path/to/file.mpt"
16+
print(parsed.key) # "path/to/run:v5"
1517

1618
# Get checkpoint info (run_name, epoch) from parsed URI
1719
info = parsed.checkpoint_info # ("run_name", 5) or None
1820
if info:
1921
run_name, epoch = info
2022

2123
# Resolve a URI (normalizes and finds latest checkpoint if applicable)
22-
parsed = resolve_uri("file:///path/to/checkpoints")
23-
print(parsed.canonical) # "file:///path/to/checkpoints/run:v5.mpt"
24+
parsed = resolve_uri("file:///path/to/checkpoints:latest")
25+
print(parsed.canonical) # "file:///path/to/checkpoints/run:v5"
2426

2527
# Get full checkpoint metadata (resolves URI first)
26-
metadata = get_checkpoint_metadata("s3://bucket/checkpoints/my-run:v5.mpt")
28+
metadata = get_checkpoint_metadata("s3://bucket/checkpoints/my-run:v5")
2729
print(metadata.run_name) # "my-run"
2830
print(metadata.epoch) # 5
2931
print(metadata.uri) # resolved URI

0 commit comments

Comments
 (0)