Skip to content

Commit 1b4184e

Browse files
committed
remove legacy mpt artifact files
1 parent 8229709 commit 1b4184e

File tree

4 files changed

+2
-244
lines changed

4 files changed

+2
-244
lines changed

metta/rl/mpt_artifact.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,13 @@ def instantiate(
3838
self,
3939
policy_env_info: PolicyEnvInterface,
4040
device: str = "cpu",
41-
*,
42-
strict: bool = True,
4341
) -> Any:
4442
torch_device = torch.device(device)
4543

4644
policy = self.architecture.make_policy(policy_env_info)
4745
policy = policy.to(torch_device)
4846

49-
missing, unexpected = policy.load_state_dict(dict(self.state_dict), strict=strict)
50-
if strict and (missing or unexpected):
51-
raise RuntimeError(f"Strict loading failed. Missing: {missing}, Unexpected: {unexpected}")
47+
policy.load_state_dict(dict(self.state_dict))
5248

5349
if hasattr(policy, "initialize_to_environment"):
5450
policy.initialize_to_environment(policy_env_info, torch_device)

metta/rl/mpt_policy.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@ def __init__(
2424
*,
2525
checkpoint_uri: str | None = None,
2626
device: str = "cpu",
27-
strict: bool = True,
2827
):
2928
super().__init__(policy_env_info, device=device)
3029

3130
self._policy = None
3231
self._architecture = None
33-
self._strict = strict
3432
self._device = device
3533

3634
if checkpoint_uri:
@@ -39,7 +37,7 @@ def __init__(
3937
def _load_from_checkpoint(self, checkpoint_uri: str, *, device: str) -> None:
4038
artifact = load_mpt(checkpoint_uri)
4139
self._architecture = artifact.architecture
42-
self._policy = artifact.instantiate(self._policy_env_info, device=device, strict=self._strict)
40+
self._policy = artifact.instantiate(self._policy_env_info, device=device)
4341
self._policy.eval()
4442

4543
def load_policy_data(self, policy_data_path: str) -> None:

packages/mettagrid/python/src/mettagrid/policy/mpt_artifact.py

Lines changed: 0 additions & 179 deletions
This file was deleted.

packages/mettagrid/python/src/mettagrid/policy/mpt_policy.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

0 commit comments

Comments
 (0)