Skip to content

Commit a8d6ad8

Browse files
committed
init
no more .mpt Merge remote-tracking branch 'origin/main' into richard-unifympt slim policy spec handler more concise cleanup simplify Merge remote-tracking branch 'origin/main' into richard-unifympt re-add fix policy spex Update packages/mettagrid/python/src/mettagrid/util/uri_resolvers/schemes.py Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> Merge remote-tracking branch 'origin/main' into richard-unifympt bundles Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt bundle Merge remote-tracking branch 'origin/richard-unifympt' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt simplify? ugh compat cleanup Merge remote-tracking branch 'origin/main' into richard-unifympt cleanup tests Merge remote-tracking branch 'origin/main' into richard-unifympt more tests Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt simplify? Merge branch 'main' into richard-unifympt cleanup Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt no more .mpt remove all .mpt and lint cleanup local data path fixes mpt re-add re-add artifact lint Merge remote-tracking branch 'origin/main' into richard-unifympt more cleanup Merge remote-tracking branch 'origin/main' into richard-unifympt diff cleanup ftt lint fix error Merge remote-tracking branch 'origin/main' into richard-unifympt more tests lint Merge remote-tracking branch 'origin/main' into richard-unifympt Merge remote-tracking branch 'origin/main' into richard-unifympt checkpoint policy does save/load lint checkpoint moving catcus lint Merge branch 'main' into richard-unifympt fold-in [pyright 4] Get pyright to pass on app_backend (#4478) Merge remote-tracking branch 'origin/main' into richard-unifympt Fix command, add space (#4456) added space to --app:lib--tlsEmulation:off which makes it --app:lib --tlsEmulation:off now it runs Rename HyperUpdateRule to ScheduleRule (#4483) - rename HyperUpdateRule to ScheduleRule and apply to TrainerConfig via target_path - update recipes and teacher scheduling to use ScheduleRule - report PPO stats using ppo_actor/ppo_critic hyperparam keys and update tests - not run (not requested) --------- Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> Merge remote-tracking branch 'origin/main' into richard-unifympt Fix supervisor teacher behavior and legacy BC mode (#4484) - gate PPO actor during supervisor teacher phase - fix supervisor/no-teacher behavior and add legacy BC (no gating, no PPO resume) - require supervisor policy URI for sliced_cloner_no_ppo - not run (not requested) --------- Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> Co-authored-by: Adam S <134907338+gustofied@users.noreply.github.com> Minor fixes to the slstm triton kernel, causing failures for certain kernel sizes (#4492) cleanup Merge remote-tracking branch 'origin/main' into richard-unifympt fold in training environments and eval environments mismatched (#4487) I ran a direct config comparison using the training entrypoint (recipes/experiment/cogs_v_clips.train) with variants=["heart_chorus"] and compared the eval suite config it builds (difficulty standard + heart_chorus) for an overlapping mission: hello_world.oxygen_bottleneck. Findings: - Compass is ON in both training and eval (global_obs.compass=True). - Vibe count and change‑vibe settings match (152 vibes; change_vibe.number_of_vibes=152). - But the mission parameters differ between training and eval for the same mission name: - game.objects.carbon_extractor.max_uses: train 25 vs eval 100 - game.objects.oxygen_extractor.max_uses: train 5 vs eval 20 - game.objects.germanium_extractor.max_uses: train 5 vs eval 20 - game.objects.silicon_extractor.max_uses: train 120 vs eval 480 So the mismatch isn’t compass — it’s the mission definitions used by training vs eval. Training uses base missions (cogames.cogs_vs_clips.missions), while eval uses integrated eval missions (cogames.cogs_vs_clips.evals.integrated_evals) that have different extractor settings. Also: the eval suite used by recipes/experiment/cogs_v_clips.train does not include machina_1.open_world at all (it only evaluates the 7 HELLO_WORLD integrated evals). So training can be creating hearts on easier missions while your eval runs on machina_1 are a different environment entirely. Relevant files: - Compass default: packages/cogames/src/cogames/cogs_vs_clips/mission.py - Training entrypoint + eval suite wiring: recipes/experiment/cogs_v_clips.py - Eval mission definitions: packages/cogames/src/cogames/cogs_vs_clips/evals/integrated_evals.py If you want true parity, we should align which mission templates eval uses (and/or include machina_1.open_world in the eval suite). I can patch this if you want — tell me whether you prefer: 1. Eval suite uses the same mission templates as training (from missions.py), or 2. Training uses the integrated eval mission definitions, or 3. Add machina_1.open_world to the eval suite. ripping out Merge remote-tracking branch 'origin/main' into richard-unifympt simplify fix and lint choke simplify submission zip creation use policy_spec for submission zips tighten checkpoint io helpers shorten checkpoint arg help inline checkpoint policy helpers restore policy spec docstring validate checkpoint data_path before download require checkpoint directory URIs expand policy spec s3 docstring
1 parent 7ed4f29 commit a8d6ad8

File tree

4 files changed

+25
-23
lines changed

4 files changed

+25
-23
lines changed

metta/rl/loss/sl_checkpointed_kickstarter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def _construct_checkpoint_uri(self, epoch: int) -> str:
148148
filename = checkpoint_filename(run_name, epoch)
149149

150150
if parsed.scheme == "file" and parsed.local_path:
151+
if parsed.local_path.is_file():
152+
raise ValueError("Provide a checkpoint directory, not policy_spec.json")
151153
path = parsed.local_path.parent / filename
152154
return f"file://{path}"
153155
elif parsed.scheme == "s3" and parsed.bucket and parsed.key:

metta/rl/metta_scheme_resolver.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from metta.app_backend.metta_repo import PolicyVersionWithName
88
from metta.common.util.constants import PROD_STATS_SERVER_URI
99
from mettagrid.util.uri_resolvers.base import MettaParsedScheme, SchemeResolver
10-
from mettagrid.util.uri_resolvers.schemes import resolve_uri
1110

1211
logger = logging.getLogger(__name__)
1312

@@ -118,12 +117,6 @@ def get_path_to_policy_spec_or_mpt(self, uri: str) -> str:
118117
logger.info(f"Metta scheme resolver: {uri} resolved to s3 policy spec: {policy_version.s3_path}")
119118
return policy_version.s3_path
120119

121-
# If that is missing (probably legacy policy), we send you to the mpt file, and will later assume
122-
# that the class to hydrate from is MptPolicy
123-
mpt_file_path = (policy_version.policy_spec or {}).get("init_kwargs", {}).get("checkpoint_uri")
124-
if not mpt_file_path:
125-
raise ValueError(f"Data not found for policy version {policy_version.id}")
126-
if not mpt_file_path.endswith(".mpt"):
127-
raise ValueError(f"Invalid mpt file path: {mpt_file_path}")
128-
logger.info(f"Metta scheme resolver: {uri} resolved to mpt checkpoint: {mpt_file_path}")
129-
return resolve_uri(mpt_file_path).canonical
120+
raise ValueError(
121+
f"Policy version {policy_version.id} has no s3_path; expected a policy spec submission zip in S3."
122+
)

packages/cogames/scripts/run_evaluation.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import matplotlib.pyplot as plt
3434
import numpy as np
3535
import torch
36+
from safetensors.torch import load as load_safetensors
3637

3738
from cogames.cogs_vs_clips.evals.diagnostic_evals import DIAGNOSTIC_EVALS
3839
from cogames.cogs_vs_clips.mission import Mission, MissionVariant, NumCogsVariant
@@ -85,16 +86,20 @@ def _get_policy_action_space(policy_path: str) -> Optional[int]:
8586
return None
8687

8788
try:
88-
if policy_path.endswith(".mpt"):
89-
from mettagrid.policy.mpt_artifact import load_mpt
89+
spec = policy_spec_from_uri(policy_path)
90+
if not spec.data_path:
91+
return None
92+
weights = load_safetensors(Path(spec.data_path).read_bytes())
9093

91-
artifact = load_mpt(policy_path)
92-
action_space = _action_space_from_state_dict(artifact.state_dict)
93-
else:
94-
from mettagrid.policy.checkpoint_policy import load_state_from_checkpoint_uri
94+
# Look for actor head weight to determine action space
95+
for key, tensor in weights.items():
96+
if "actor_head" in key and "weight" in key and len(tensor.shape) == 2:
97+
action_space = tensor.shape[0]
98+
_policy_action_space_cache[policy_path] = action_space
99+
logger.info(f"Detected policy action space: {action_space} actions")
100+
return action_space
95101

96-
_, state_dict = load_state_from_checkpoint_uri(policy_path, device="cpu")
97-
action_space = _action_space_from_state_dict(state_dict)
102+
return None
98103
except Exception as e:
99104
logger.warning(f"Failed to detect policy action space: {e}")
100105
return None

packages/mettagrid/python/src/mettagrid/util/uri_resolvers/README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,30 @@
22

33
This package provides a pluggable URI resolution system for handling different resource schemes.
44

5+
Checkpoint URIs point at a checkpoint directory containing `policy_spec.json`.
6+
57
## Usage
68

79
```python
810
from mettagrid.util.uri_resolvers.schemes import parse_uri, resolve_uri, get_checkpoint_metadata
911

1012
# Parse a URI to get its components
11-
parsed = parse_uri("s3://bucket/path/to/file.mpt")
13+
parsed = parse_uri("s3://bucket/path/to/run:v5")
1214
print(parsed.scheme) # "s3"
1315
print(parsed.bucket) # "bucket"
14-
print(parsed.key) # "path/to/file.mpt"
16+
print(parsed.key) # "path/to/run:v5"
1517

1618
# Get checkpoint info (run_name, epoch) from parsed URI
1719
info = parsed.checkpoint_info # ("run_name", 5) or None
1820
if info:
1921
run_name, epoch = info
2022

2123
# Resolve a URI (normalizes and finds latest checkpoint if applicable)
22-
parsed = resolve_uri("file:///path/to/checkpoints")
23-
print(parsed.canonical) # "file:///path/to/checkpoints/run:v5.mpt"
24+
parsed = resolve_uri("file:///path/to/checkpoints:latest")
25+
print(parsed.canonical) # "file:///path/to/checkpoints/run:v5"
2426

2527
# Get full checkpoint metadata (resolves URI first)
26-
metadata = get_checkpoint_metadata("s3://bucket/checkpoints/my-run:v5.mpt")
28+
metadata = get_checkpoint_metadata("s3://bucket/checkpoints/my-run:v5")
2729
print(metadata.run_name) # "my-run"
2830
print(metadata.epoch) # 5
2931
print(metadata.uri) # resolved URI

0 commit comments

Comments
 (0)