Skip to content

Commit 77390d8

Browse files
committed
Fix action space detection for checkpoint bundles
1 parent 2b42d85 commit 77390d8

File tree

1 file changed

+29
-20
lines changed

1 file changed

+29
-20
lines changed

packages/cogames/scripts/run_evaluation.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from concurrent.futures import ThreadPoolExecutor, as_completed
2424
from dataclasses import asdict, dataclass
2525
from pathlib import Path
26-
from typing import Dict, List, Optional
26+
from typing import Dict, List, Mapping, Optional
2727

2828
import matplotlib
2929

@@ -75,32 +75,42 @@ def _get_policy_action_space(policy_path: str) -> Optional[int]:
7575
7676
Returns the number of actions the policy was trained with, or None if detection fails.
7777
"""
78-
if not policy_path.endswith(".mpt"):
78+
if not policy_path:
7979
return None
8080
if policy_path in _policy_action_space_cache:
8181
return _policy_action_space_cache[policy_path]
82-
83-
if not is_s3_uri(policy_path):
84-
return None
82+
if "://" not in policy_path:
83+
candidate = Path(policy_path).expanduser()
84+
if not candidate.exists() and not policy_path.endswith(".mpt"):
85+
return None
8586

8687
try:
87-
from mettagrid.policy.mpt_artifact import load_mpt
88-
89-
artifact = load_mpt(policy_path)
88+
if policy_path.endswith(".mpt"):
89+
from mettagrid.policy.mpt_artifact import load_mpt
9090

91-
# Look for actor head weight to determine action space
92-
for key, tensor in artifact.state_dict.items():
93-
if "actor_head" in key and "weight" in key and len(tensor.shape) == 2:
94-
action_space = tensor.shape[0]
95-
_policy_action_space_cache[policy_path] = action_space
96-
logger.info(f"Detected policy action space: {action_space} actions")
97-
return action_space
91+
artifact = load_mpt(policy_path)
92+
action_space = _action_space_from_state_dict(artifact.state_dict)
93+
else:
94+
from mettagrid.policy.checkpoint_policy import load_state_from_checkpoint_uri
9895

99-
return None
96+
_, state_dict = load_state_from_checkpoint_uri(policy_path, device="cpu")
97+
action_space = _action_space_from_state_dict(state_dict)
10098
except Exception as e:
10199
logger.warning(f"Failed to detect policy action space: {e}")
102100
return None
103101

102+
if action_space is not None:
103+
_policy_action_space_cache[policy_path] = action_space
104+
logger.info(f"Detected policy action space: {action_space} actions")
105+
return action_space
106+
107+
108+
def _action_space_from_state_dict(state_dict: Mapping[str, torch.Tensor]) -> Optional[int]:
109+
for key, tensor in state_dict.items():
110+
if "actor_head" in key and "weight" in key and len(tensor.shape) == 2:
111+
return int(tensor.shape[0])
112+
return None
113+
104114

105115
def _configure_env_for_action_space(env_cfg, num_actions: int) -> None:
106116
"""Configure environment vibes to match a specific action space.
@@ -281,10 +291,9 @@ def _run_case(
281291
_ensure_vibe_supports_gear(env_config)
282292

283293
# Auto-detect policy action space and configure environment to match
284-
if is_s3_uri(agent_config.policy_path):
285-
policy_action_space = _get_policy_action_space(agent_config.policy_path)
286-
if policy_action_space is not None:
287-
_configure_env_for_action_space(env_config, policy_action_space)
294+
policy_action_space = _get_policy_action_space(agent_config.policy_path)
295+
if policy_action_space is not None:
296+
_configure_env_for_action_space(env_config, policy_action_space)
288297

289298
if variant is None or getattr(variant, "max_steps_override", None) is None:
290299
env_config.game.max_steps = max_steps

0 commit comments

Comments
 (0)