|
23 | 23 | from concurrent.futures import ThreadPoolExecutor, as_completed |
24 | 24 | from dataclasses import asdict, dataclass |
25 | 25 | from pathlib import Path |
26 | | -from typing import Dict, List, Optional |
| 26 | +from typing import Dict, List, Mapping, Optional |
27 | 27 |
|
28 | 28 | import matplotlib |
29 | 29 |
|
@@ -75,32 +75,42 @@ def _get_policy_action_space(policy_path: str) -> Optional[int]: |
75 | 75 |
|
76 | 76 | Returns the number of actions the policy was trained with, or None if detection fails. |
77 | 77 | """ |
78 | | - if not policy_path.endswith(".mpt"): |
| 78 | + if not policy_path: |
79 | 79 | return None |
80 | 80 | if policy_path in _policy_action_space_cache: |
81 | 81 | return _policy_action_space_cache[policy_path] |
82 | | - |
83 | | - if not is_s3_uri(policy_path): |
84 | | - return None |
| 82 | + if "://" not in policy_path: |
| 83 | + candidate = Path(policy_path).expanduser() |
| 84 | + if not candidate.exists() and not policy_path.endswith(".mpt"): |
| 85 | + return None |
85 | 86 |
|
86 | 87 | try: |
87 | | - from mettagrid.policy.mpt_artifact import load_mpt |
88 | | - |
89 | | - artifact = load_mpt(policy_path) |
| 88 | + if policy_path.endswith(".mpt"): |
| 89 | + from mettagrid.policy.mpt_artifact import load_mpt |
90 | 90 |
|
91 | | - # Look for actor head weight to determine action space |
92 | | - for key, tensor in artifact.state_dict.items(): |
93 | | - if "actor_head" in key and "weight" in key and len(tensor.shape) == 2: |
94 | | - action_space = tensor.shape[0] |
95 | | - _policy_action_space_cache[policy_path] = action_space |
96 | | - logger.info(f"Detected policy action space: {action_space} actions") |
97 | | - return action_space |
| 91 | + artifact = load_mpt(policy_path) |
| 92 | + action_space = _action_space_from_state_dict(artifact.state_dict) |
| 93 | + else: |
| 94 | + from mettagrid.policy.checkpoint_policy import load_state_from_checkpoint_uri |
98 | 95 |
|
99 | | - return None |
| 96 | + _, state_dict = load_state_from_checkpoint_uri(policy_path, device="cpu") |
| 97 | + action_space = _action_space_from_state_dict(state_dict) |
100 | 98 | except Exception as e: |
101 | 99 | logger.warning(f"Failed to detect policy action space: {e}") |
102 | 100 | return None |
103 | 101 |
|
| 102 | + if action_space is not None: |
| 103 | + _policy_action_space_cache[policy_path] = action_space |
| 104 | + logger.info(f"Detected policy action space: {action_space} actions") |
| 105 | + return action_space |
| 106 | + |
| 107 | + |
| 108 | +def _action_space_from_state_dict(state_dict: Mapping[str, torch.Tensor]) -> Optional[int]: |
| 109 | + for key, tensor in state_dict.items(): |
| 110 | + if "actor_head" in key and "weight" in key and len(tensor.shape) == 2: |
| 111 | + return int(tensor.shape[0]) |
| 112 | + return None |
| 113 | + |
104 | 114 |
|
105 | 115 | def _configure_env_for_action_space(env_cfg, num_actions: int) -> None: |
106 | 116 | """Configure environment vibes to match a specific action space. |
@@ -281,10 +291,9 @@ def _run_case( |
281 | 291 | _ensure_vibe_supports_gear(env_config) |
282 | 292 |
|
283 | 293 | # Auto-detect policy action space and configure environment to match |
284 | | - if is_s3_uri(agent_config.policy_path): |
285 | | - policy_action_space = _get_policy_action_space(agent_config.policy_path) |
286 | | - if policy_action_space is not None: |
287 | | - _configure_env_for_action_space(env_config, policy_action_space) |
| 294 | + policy_action_space = _get_policy_action_space(agent_config.policy_path) |
| 295 | + if policy_action_space is not None: |
| 296 | + _configure_env_for_action_space(env_config, policy_action_space) |
288 | 297 |
|
289 | 298 | if variant is None or getattr(variant, "max_steps_override", None) is None: |
290 | 299 | env_config.game.max_steps = max_steps |
|
0 commit comments