Skip to content

Commit ce97bc2

Browse files
committed
Update stats reporter PPO hyperparam keys
1 parent 9fe14c9 commit ce97bc2

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

metta/rl/training/stats_reporter.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -529,17 +529,20 @@ def _collect_hyperparameters(
529529

530530
losses = getattr(trainer_cfg, "losses", None)
531531
loss_configs = getattr(losses, "loss_configs", {}) if losses else {}
532-
ppo_cfg = loss_configs.get("ppo") if isinstance(loss_configs, dict) else None
533-
if ppo_cfg is not None:
534-
for attr in (
535-
"clip_coef",
536-
"vf_clip_coef",
537-
"ent_coef",
538-
"l2_reg_loss_coef",
539-
"l2_init_loss_coef",
540-
):
541-
value = getattr(ppo_cfg, attr, None)
542-
if value is None:
543-
continue
544-
hyperparameters[f"ppo_{attr}"] = value
532+
if isinstance(loss_configs, dict):
533+
ppo_actor_cfg = loss_configs.get("ppo_actor")
534+
if ppo_actor_cfg is not None:
535+
for attr in ("clip_coef", "ent_coef", "norm_adv", "target_kl"):
536+
value = getattr(ppo_actor_cfg, attr, None)
537+
if value is None:
538+
continue
539+
hyperparameters[f"ppo_actor_{attr}"] = value
540+
541+
ppo_critic_cfg = loss_configs.get("ppo_critic")
542+
if ppo_critic_cfg is not None:
543+
for attr in ("vf_coef", "vf_clip_coef", "clip_vloss"):
544+
value = getattr(ppo_critic_cfg, attr, None)
545+
if value is None:
546+
continue
547+
hyperparameters[f"ppo_critic_{attr}"] = value
545548
return hyperparameters

tests/rl/test_stats_reporter_defaults.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ def _reporter(existing_heart: float | None = None) -> StatsReporter:
2424
reporter._context = SimpleNamespace( # type: ignore[attr-defined, assignment]
2525
config=SimpleNamespace(
2626
optimizer=SimpleNamespace(learning_rate=0.0, type="adam"),
27-
losses=SimpleNamespace(ppo=SimpleNamespace(enabled=False)),
27+
losses=SimpleNamespace(
28+
ppo_actor=SimpleNamespace(enabled=False),
29+
ppo_critic=SimpleNamespace(enabled=False),
30+
),
2831
),
2932
stopwatch=timer,
3033
experience=SimpleNamespace(stats=lambda: {}),

0 commit comments

Comments
 (0)