Skip to content

Commit 21c74a8

Browse files
committed
feat: support metta://policy/<short_name_or_class_path> URIs for built-in policies
1 parent a0360e5 commit 21c74a8

File tree

4 files changed

+83
-1
lines changed

4 files changed

+83
-1
lines changed

CLAUDE.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,7 @@ Internal `metta/` folder dependencies are enforced by `import-linter`. Run `uv r
5454
```
5555

5656
See `common/src/metta/common/tool/README.md` for details.
57+
58+
## Git Hub Integration
59+
60+
Use graphite ("gt") to create PRs. Name the branch $user-short-issue-name

metta/rl/metta_scheme_resolver.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from metta.app_backend.clients.stats_client import StatsClient
88
from metta.app_backend.metta_repo import PolicyVersionWithName
99
from metta.common.util.constants import PROD_STATS_SERVER_URI
10+
from mettagrid.policy.loader import discover_and_register_policies
11+
from mettagrid.policy.policy_registry import get_policy_registry
1012
from mettagrid.util.uri_resolvers.base import MettaParsedScheme, SchemeResolver
1113
from mettagrid.util.uri_resolvers.schemes import resolve_uri
1214

@@ -112,7 +114,40 @@ def get_policy_version(self, uri: str) -> PolicyVersionWithName:
112114

113115
return policy_version
114116

117+
def _resolve_builtin_policy_class_path(self, identifier: str) -> str | None:
118+
"""Resolve a policy identifier to a full class path.
119+
120+
Supports:
121+
- Short names (e.g., "noop", "random") via the policy registry
122+
- Full class paths (e.g., "mettagrid.policy.noop.NoopPolicy")
123+
124+
Returns the full class path if valid, None otherwise.
125+
"""
126+
discover_and_register_policies()
127+
registry = get_policy_registry()
128+
129+
# First check if it's a registered short name
130+
if identifier in registry:
131+
return registry[identifier]
132+
133+
# Otherwise, check if it looks like a full class path (contains dots)
134+
if "." in identifier:
135+
return identifier
136+
137+
return None
138+
115139
def get_path_to_policy_spec_or_mpt(self, uri: str) -> str:
140+
parsed = self.parse(uri)
141+
path = parsed.path
142+
143+
# Check for built-in policy: metta://policy/<short_name_or_class_path>
144+
if path.startswith("policy/"):
145+
policy_identifier = path[len("policy/") :]
146+
class_path = self._resolve_builtin_policy_class_path(policy_identifier)
147+
if class_path:
148+
logger.info(f"Metta scheme resolver: {uri} resolved to builtin policy: {class_path}")
149+
return f"builtin://{class_path}"
150+
116151
policy_version = self.get_policy_version(uri)
117152
# By default we send you to the s3 path that contains the policy spec
118153
if policy_version.s3_path:

packages/mettagrid/python/src/mettagrid/util/uri_resolvers/base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,23 @@ def checkpoint_info(self) -> tuple[str, int] | None:
7575
return _extract_run_and_epoch(self.path)
7676

7777

78-
ParsedScheme = Union[FileParsedScheme, S3ParsedScheme, MockParsedScheme, MettaParsedScheme]
78+
class BuiltinParsedScheme(BaseModel, frozen=True):
79+
"""Parsed scheme for built-in policies loaded by class path."""
80+
81+
scheme: Literal["builtin"] = "builtin"
82+
canonical: str
83+
class_path: str
84+
85+
@property
86+
def local_path(self) -> None:
87+
return None
88+
89+
@property
90+
def checkpoint_info(self) -> tuple[str, int] | None:
91+
return None
92+
93+
94+
ParsedScheme = Union[FileParsedScheme, S3ParsedScheme, MockParsedScheme, MettaParsedScheme, BuiltinParsedScheme]
7995

8096

8197
class SchemeResolver(ABC):

packages/mettagrid/python/src/mettagrid/util/uri_resolvers/schemes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from mettagrid.util.module import load_symbol
1111
from mettagrid.util.uri_resolvers.base import (
12+
BuiltinParsedScheme,
1213
CheckpointMetadata,
1314
FileParsedScheme,
1415
MockParsedScheme,
@@ -198,12 +199,34 @@ def parse(self, uri: str) -> MockParsedScheme:
198199
return MockParsedScheme(canonical=canonical, path=path)
199200

200201

202+
class BuiltinSchemeResolver(SchemeResolver):
203+
"""Resolves builtin:// URIs for built-in policies loaded by class path.
204+
205+
Supported formats:
206+
- builtin://full.class.path.PolicyClass
207+
"""
208+
209+
@property
210+
def scheme(self) -> str:
211+
return "builtin"
212+
213+
def parse(self, uri: str) -> BuiltinParsedScheme:
214+
if not uri.startswith("builtin://"):
215+
raise ValueError(f"Expected builtin:// URI, got: {uri}")
216+
class_path = uri[len("builtin://") :]
217+
if not class_path:
218+
raise ValueError("builtin:// URIs must include a class path")
219+
canonical = f"builtin://{class_path}"
220+
return BuiltinParsedScheme(canonical=canonical, class_path=class_path)
221+
222+
201223
_SCHEME_RESOLVERS: list[str] = [
202224
"mettagrid.util.uri_resolvers.schemes.FileSchemeResolver",
203225
"mettagrid.util.uri_resolvers.schemes.S3SchemeResolver",
204226
"mettagrid.util.uri_resolvers.schemes.HttpSchemeResolver",
205227
"mettagrid.util.uri_resolvers.schemes.HttpSchemeResolver",
206228
"mettagrid.util.uri_resolvers.schemes.MockSchemeResolver",
229+
"mettagrid.util.uri_resolvers.schemes.BuiltinSchemeResolver",
207230
"metta.rl.metta_scheme_resolver.MettaSchemeResolver",
208231
]
209232

@@ -272,6 +295,10 @@ def policy_spec_from_uri(
272295

273296
parsed = resolve_uri(uri)
274297

298+
# Handle built-in policies (e.g., builtin://mettagrid.policy.noop.NoopPolicy)
299+
if parsed.scheme == "builtin":
300+
return PolicySpec(class_path=parsed.class_path)
301+
275302
if parsed.canonical.endswith(".mpt"):
276303
checkpoint_path = str(parsed.local_path) if parsed.local_path else parsed.canonical
277304
return PolicySpec(

0 commit comments

Comments
 (0)