Skip to content

Commit 69496e9

Browse files
authored
purge some uninteresting stats from W&B logging (#4546)
# Summary Reduce training/environment stats volume and simplify StatsReporter output. This trims low‑value metrics, removes weight/dormant‑neuron reporting, and narrows rolling averages to a small set of required env metrics. # Why We’re intentionally cutting noisy or redundant metrics to lower logging overhead and make dashboards easier to read, while keeping core signals intact. # What changed ## Environment stats production (mettagrid) - Removed label completion tracking/emission from `StatsTracker`. - Reward estimate diffs and env timing stats are still produced in the env but are now filtered out before logging. ## Stats filtering (metta/rl/stats.py) - `filter_movement_metrics` now drops: - `env_attributes/*` - `env_reward_estimates/*` - `env_timing_per_epoch/*` - `env_timing_cumulative/*` - Removed stale `env_label_completions/*` filter entry (metric no longer produced in this branch). - Core movement metrics kept: only the four direction counters. ## StatsReporter (metta/rl/training/stats_reporter.py) - Removed weight metrics + dormant‑neuron analysis. - Removed `parameters/*` payload (learning rate, epoch steps, minibatches, schedulefree extras). - Hyperparameters now come only from optimizer param groups (lr, schedulefree scheduled_lr, lr_max). - Rolling averages computed only for `default_zero_metrics` instead of all env metrics. ## W&B logger (metta/rl/training/wandb_logger.py) - Removed direct logging of `latest_losses_stats`; losses now only flow through StatsReporter. - If StatsReporter is disabled, loss metrics will not appear. ## System monitor (mettagrid) - Dropped static counters: cpu_count, cpu_count_logical/physical, memory_total_mb, gpu_count. # Behavior / compatibility notes - W&B/env metric output is substantially reduced. Any dashboards/scripts that reference removed metrics will need updates. - Loss metrics are no longer emitted by WandbLogger when StatsReporter is disabled. # Testing Not run (stats/logging changes only). 
# Files touched - metta/rl/stats.py - metta/rl/training/stats_reporter.py - metta/rl/training/wandb_logger.py - packages/mettagrid/python/src/mettagrid/envs/stats_tracker.py - packages/mettagrid/python/src/mettagrid/profiling/system_monitor.py - tests/rl/test_stats_reporter_defaults.py [Asana Task](https://app.asana.com/1/1209016784099267/project/1210348820405981/task/1212600739220124)
1 parent 4e9733f commit 69496e9

File tree

6 files changed

+32
-180
lines changed

6 files changed

+32
-180
lines changed

metta/rl/stats.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def accumulate_rollout_stats(
5050

5151

5252
def filter_movement_metrics(stats: dict[str, Any]) -> dict[str, Any]:
53-
"""Filter movement metrics to only keep core values, removing derived stats."""
53+
"""Filter verbose environment metrics while keeping core values."""
5454
filtered = {}
5555

5656
# Core movement metrics we want to keep (without any suffix)
@@ -60,14 +60,18 @@ def filter_movement_metrics(stats: dict[str, Any]) -> dict[str, Any]:
6060
"env_agent/movement.direction.down",
6161
"env_agent/movement.direction.left",
6262
"env_agent/movement.direction.right",
63-
"env_agent/movement.sequential_rotations",
64-
"env_agent/movement.rotation.to_up",
65-
"env_agent/movement.rotation.to_down",
66-
"env_agent/movement.rotation.to_left",
67-
"env_agent/movement.rotation.to_right",
6863
}
64+
noisy_prefixes = (
65+
"env_reward_estimates/",
66+
"env_timing_per_epoch/",
67+
"env_timing_cumulative/",
68+
)
6969

7070
for key, value in stats.items():
71+
if key.startswith("env_attributes/"):
72+
continue
73+
if key.startswith(noisy_prefixes):
74+
continue
7175
# Check if this is a core metric (exact match)
7276
if key in core_metrics:
7377
filtered[key] = value

metta/rl/training/stats_reporter.py

Lines changed: 20 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77

88
import numpy as np
99
import torch
10-
import torch.nn as nn
1110
from pydantic import Field
1211

1312
from metta.common.wandb.context import WandbRun
14-
from metta.rl.model_analysis import compute_dormant_neuron_stats
1513
from metta.rl.stats import accumulate_rollout_stats, compute_timing_stats, process_training_stats
1614
from metta.rl.training.component import TrainerComponent
17-
from metta.rl.utils import should_run
1815
from mettagrid.base_config import Config
1916

2017
logger = logging.getLogger(__name__)
@@ -43,11 +40,9 @@ def _to_scalar(value: Any) -> Optional[float]:
4340
def build_wandb_payload(
4441
processed_stats: dict[str, Any],
4542
timing_info: dict[str, Any],
46-
weight_stats: dict[str, Any],
4743
grad_stats: dict[str, float],
4844
system_stats: dict[str, Any],
4945
memory_stats: dict[str, Any],
50-
parameters: dict[str, Any],
5146
hyperparameters: dict[str, Any],
5247
*,
5348
agent_step: int,
@@ -87,12 +82,9 @@ def _update(items: dict[str, Any], *, prefix: str = "") -> None:
8782
_update(experience_stats, prefix="experience/")
8883

8984
_update(processed_stats.get("environment_stats", {}))
90-
_update(parameters, prefix="parameters/")
9185
_update(hyperparameters, prefix="hyperparameters/")
92-
9386
_update(system_stats)
9487
_update({f"trainer_memory/{k}": v for k, v in memory_stats.items()})
95-
_update(weight_stats)
9688
_update(grad_stats)
9789
_update(timing_info.get("timing_stats", {}))
9890

@@ -107,10 +99,6 @@ class StatsReporterConfig(Config):
10799
grad_mean_variance_interval: int = 50
108100
interval: int = 1
109101
"""How often to report stats (in epochs)"""
110-
analyze_weights_interval: int = 0
111-
"""How often to compute weight metrics (0 disables)."""
112-
dormant_neuron_threshold: float = 1e-6
113-
"""Threshold for considering a neuron dormant based on mean absolute weight magnitude."""
114102
rolling_window: int = Field(default=5, ge=1, description="Number of epochs for metric rolling averages")
115103
default_zero_metrics: tuple[str, ...] = Field(
116104
default_factory=lambda: ("env_game/assembler.heart.created",),
@@ -355,27 +343,16 @@ def _build_wandb_payload(
355343
timing_info = compute_timing_stats(timer=timer, agent_step=agent_step)
356344
self._normalize_steps_per_second(timing_info, agent_step)
357345

358-
weight_stats = self._collect_weight_stats(policy=policy, epoch=epoch)
359-
dormant_stats = self._compute_dormant_neuron_stats(policy=policy)
360-
if dormant_stats:
361-
weight_stats.update(dormant_stats)
362346
system_stats = self._collect_system_stats()
363347
memory_stats = self._collect_memory_stats()
364-
parameters = self._collect_parameters(
365-
experience=experience,
366-
optimizer=optimizer,
367-
timing_info=timing_info,
368-
)
369-
hyperparameters = self._collect_hyperparameters(trainer_cfg=trainer_cfg, parameters=parameters)
348+
hyperparameters = self._collect_hyperparameters(optimizer=optimizer)
370349

371350
return build_wandb_payload(
372351
processed_stats=processed,
373352
timing_info=timing_info,
374-
weight_stats=weight_stats,
375353
grad_stats=self._state.grad_stats,
376354
system_stats=system_stats,
377355
memory_stats=memory_stats,
378-
parameters=parameters,
379356
hyperparameters=hyperparameters,
380357
agent_step=agent_step,
381358
epoch=epoch,
@@ -386,7 +363,10 @@ def _augment_with_rolling_averages(self, processed: dict[str, Any]) -> None:
386363
if not isinstance(env_stats, dict):
387364
return
388365

389-
tracked_keys = set(env_stats.keys()) | set(self._state.rolling_stats.keys())
366+
tracked_keys = set(self._config.default_zero_metrics)
367+
for key in list(self._state.rolling_stats.keys()):
368+
if key not in tracked_keys:
369+
del self._state.rolling_stats[key]
390370
window = self._config.rolling_window
391371

392372
for key in tracked_keys:
@@ -402,9 +382,7 @@ def _augment_with_rolling_averages(self, processed: dict[str, Any]) -> None:
402382
continue
403383
history.append(scalar)
404384
env_stats.setdefault(key, scalar)
405-
# Skip creating .avg versions for env_per_label metrics
406-
if not (key.startswith("env_per_label_rewards/") or key.startswith("env_per_label_chest_deposits/")):
407-
env_stats[f"{key}.avg"] = sum(history) / len(history)
385+
env_stats[f"{key}.avg"] = sum(history) / len(history)
408386

409387
def _normalize_steps_per_second(self, timing_info: dict[str, Any], agent_step: int) -> None:
410388
"""Adjust SPS to account for agent steps accumulated before a resume."""
@@ -433,43 +411,6 @@ def _normalize_steps_per_second(self, timing_info: dict[str, Any], agent_step: i
433411
if isinstance(timing_stats, dict):
434412
timing_stats["timing_cumulative/sps"] = sps
435413

436-
def _collect_weight_stats(self, *, policy: Any, epoch: int) -> dict[str, float]:
437-
interval = self._config.analyze_weights_interval
438-
if not interval:
439-
policy_config = getattr(policy, "config", None)
440-
interval = getattr(policy_config, "analyze_weights_interval", 0) if policy_config else 0
441-
442-
if not interval or not should_run(epoch, interval):
443-
return {}
444-
445-
if not hasattr(policy, "compute_weight_metrics"):
446-
return {}
447-
448-
weight_stats: dict[str, float] = {}
449-
try:
450-
for metrics in policy.compute_weight_metrics():
451-
name = metrics.get("name", "unknown")
452-
for key, value in metrics.items():
453-
if key == "name":
454-
continue
455-
scalar = _to_scalar(value)
456-
if scalar is None:
457-
continue
458-
weight_stats[f"weights/{key}/{name}"] = scalar
459-
except Exception as exc: # pragma: no cover - safeguard against model-specific failures
460-
logger.warning("Failed to compute weight metrics: %s", exc, exc_info=True)
461-
return weight_stats
462-
463-
def _compute_dormant_neuron_stats(self, *, policy: Any) -> dict[str, float]:
464-
if not isinstance(policy, nn.Module):
465-
return {}
466-
threshold = getattr(self._config, "dormant_neuron_threshold", 1e-6)
467-
try:
468-
return compute_dormant_neuron_stats(policy, threshold=threshold)
469-
except Exception as exc: # pragma: no cover - safeguard against model-specific failures
470-
logger.debug("Failed to compute dormant neuron stats: %s", exc, exc_info=True)
471-
return {}
472-
473414
def _collect_system_stats(self) -> dict[str, Any]:
474415
system_monitor = getattr(self.context, "system_monitor", None)
475416
if system_monitor is None:
@@ -490,79 +431,19 @@ def _collect_memory_stats(self) -> dict[str, Any]:
490431
logger.debug("Memory monitor stats failed: %s", exc, exc_info=True)
491432
return {}
492433

493-
def _collect_parameters(
494-
self,
495-
*,
496-
experience: Any,
497-
optimizer: torch.optim.Optimizer,
498-
timing_info: dict[str, Any],
499-
) -> dict[str, Any]:
500-
learning_rate = getattr(self.context.config.optimizer, "learning_rate", 0)
501-
if optimizer and optimizer.param_groups:
502-
learning_rate = optimizer.param_groups[0].get("lr", learning_rate)
503-
504-
parameters: dict[str, Any] = {
505-
"learning_rate": learning_rate,
506-
"epoch_steps": timing_info.get("epoch_steps", 0),
507-
"num_minibatches": getattr(experience, "num_minibatches", 0),
508-
}
509-
510-
# Add ScheduleFree optimizer information
511-
if optimizer and optimizer.param_groups:
512-
param_group = optimizer.param_groups[0]
513-
is_schedulefree = "train_mode" in param_group
514-
515-
if is_schedulefree:
516-
scheduled_lr = param_group.get("scheduled_lr")
517-
if scheduled_lr is not None:
518-
parameters["schedulefree_scheduled_lr"] = scheduled_lr
519-
lr_max = param_group.get("lr_max")
520-
if lr_max is not None:
521-
parameters["schedulefree_lr_max"] = lr_max
522-
523-
return parameters
524-
525-
def _collect_hyperparameters(
526-
self,
527-
*,
528-
trainer_cfg: Any,
529-
parameters: dict[str, Any],
530-
) -> dict[str, Any]:
434+
def _collect_hyperparameters(self, *, optimizer: torch.optim.Optimizer) -> dict[str, Any]:
531435
hyperparameters: dict[str, Any] = {}
532-
if "learning_rate" in parameters:
533-
hyperparameters["learning_rate"] = parameters["learning_rate"]
534-
535-
optimizer_cfg = getattr(trainer_cfg, "optimizer", None)
536-
if optimizer_cfg:
537-
hyperparameters["optimizer_type"] = optimizer_cfg.type
538-
if "schedulefree" in optimizer_cfg.type:
539-
warmup_steps = getattr(optimizer_cfg, "warmup_steps", None)
540-
if warmup_steps is not None:
541-
hyperparameters["schedulefree_warmup_steps"] = warmup_steps
542-
543-
losses = getattr(trainer_cfg, "losses", None)
544-
loss_configs = getattr(losses, "loss_configs", {}) if losses else {}
545-
if isinstance(loss_configs, dict):
546-
ppo_actor_cfg = loss_configs.get("ppo_actor")
547-
if ppo_actor_cfg is not None:
548-
for attr in ("clip_coef", "ent_coef", "norm_adv", "target_kl"):
549-
value = getattr(ppo_actor_cfg, attr, None)
550-
if value is None:
551-
continue
552-
hyperparameters[f"ppo_actor_{attr}"] = value
553-
554-
ppo_critic_cfg = loss_configs.get("ppo_critic")
555-
if ppo_critic_cfg is not None:
556-
for attr in (
557-
"vf_coef",
558-
"vf_clip_coef",
559-
"clip_vloss",
560-
"critic_update",
561-
"aux_coef",
562-
"beta",
563-
):
564-
value = getattr(ppo_critic_cfg, attr, None)
565-
if value is None:
566-
continue
567-
hyperparameters[f"ppo_critic_{attr}"] = value
436+
param_groups = optimizer.param_groups
437+
if not param_groups:
438+
return hyperparameters
439+
param_group = param_groups[0]
440+
learning_rate = param_group.get("lr")
441+
if learning_rate is not None:
442+
hyperparameters["learning_rate"] = learning_rate
443+
scheduled_lr = param_group.get("scheduled_lr")
444+
if scheduled_lr is not None:
445+
hyperparameters["schedulefree_scheduled_lr"] = scheduled_lr
446+
lr_max = param_group.get("lr_max")
447+
if lr_max is not None:
448+
hyperparameters["schedulefree_lr_max"] = lr_max
568449
return hyperparameters

metta/rl/training/wandb_logger.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ def on_epoch_end(self, epoch: int) -> None: # noqa: D401 - documented in base c
5454
for k, v in elapsed.items():
5555
self._prev_elapsed[k] = float(v)
5656

57-
for key, value in context.latest_losses_stats.items():
58-
metric_key = key if "/" in key else f"loss/{key}"
59-
payload[metric_key] = float(value)
60-
6157
self._wandb_run.log(payload)
6258

6359
def on_training_complete(self) -> None: # noqa: D401

packages/mettagrid/python/src/mettagrid/envs/stats_tracker.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ def __init__(self, stats_writer: StatsWriter):
1515
self._stats_writer = stats_writer
1616
self._episode_start_ts = datetime.datetime.now()
1717
self._episode_end_ts = None
18-
self._label_completions = {"completed_tasks": [], "completion_rates": {}}
1918
self._per_label_rewards = {}
2019
self._per_label_chest_deposits = {} # Track chest deposits per label
2120

@@ -55,11 +54,6 @@ def on_episode_end(self) -> None:
5554
config.game.reward_estimates["worst_case_optimal_reward"] - mean_reward
5655
)
5756

58-
self._update_label_completions()
59-
60-
# only plot label completions once we have a full moving average window, to prevent initial bias
61-
if len(self._label_completions["completed_tasks"]) >= 50:
62-
infos["label_completions"] = self._label_completions["completion_rates"]
6357
self._per_label_rewards[config.label] = mean_reward
6458
infos["per_label_rewards"] = self._per_label_rewards
6559

@@ -147,21 +141,3 @@ def _add_timing_info(self) -> None:
147141

148142
def on_close(self) -> None:
149143
self._stats_writer.close()
150-
151-
def _update_label_completions(self, moving_avg_window: int = 500) -> None:
152-
"""Update label completions."""
153-
label = self._sim.config.label
154-
155-
# keep track of a list of the last 500 labels
156-
if len(self._label_completions["completed_tasks"]) >= moving_avg_window:
157-
self._label_completions["completed_tasks"].pop(0)
158-
self._label_completions["completed_tasks"].append(label)
159-
160-
# moving average of the completion rates
161-
self._label_completions["completion_rates"] = {t: 0 for t in set(self._label_completions["completed_tasks"])}
162-
for t in self._label_completions["completed_tasks"]:
163-
self._label_completions["completion_rates"][t] += 1
164-
self._label_completions["completion_rates"] = {
165-
t: self._label_completions["completion_rates"][t] / len(self._label_completions["completed_tasks"])
166-
for t in self._label_completions["completion_rates"]
167-
}

packages/mettagrid/python/src/mettagrid/profiling/system_monitor.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,10 @@ def _initialize_default_metrics(self):
103103
self._metric_collectors = {
104104
# CPU metrics
105105
"cpu_percent": lambda: psutil.cpu_percent(interval=0),
106-
"cpu_count": lambda: psutil.cpu_count(),
107-
"cpu_count_logical": lambda: psutil.cpu_count(logical=True),
108-
"cpu_count_physical": lambda: psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True),
109106
# Memory metrics
110107
"memory_percent": lambda: psutil.virtual_memory().percent,
111108
"memory_available_mb": lambda: psutil.virtual_memory().available / (1024 * 1024),
112109
"memory_used_mb": lambda: psutil.virtual_memory().used / (1024 * 1024),
113-
"memory_total_mb": lambda: psutil.virtual_memory().total / (1024 * 1024),
114110
# Process-specific metrics
115111
"process_memory_mb": lambda: psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024),
116112
"process_cpu_percent": lambda: self._process.cpu_percent(),
@@ -154,7 +150,6 @@ def _initialize_default_metrics(self):
154150
# Add aggregate metrics (rename to make it clear they're aggregates)
155151
self._metric_collectors.update(
156152
{
157-
"gpu_count": lambda: gpu_count,
158153
"gpu_utilization_avg": self._get_gpu_utilization_cuda,
159154
"gpu_memory_percent_avg": self._get_gpu_memory_percent_cuda,
160155
"gpu_memory_used_mb_total": self._get_gpu_memory_used_mb_cuda,

tests/rl/test_stats_reporter_defaults.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _reporter(existing_heart: float | None = None) -> StatsReporter:
3434
stopwatch=timer,
3535
experience=SimpleNamespace(stats=lambda: {}),
3636
policy=None,
37-
optimizer=SimpleNamespace(param_groups=[]),
37+
optimizer=None,
3838
epoch=0,
3939
agent_step=0,
4040
run_name=None,
@@ -61,7 +61,7 @@ def test_heart_metric_zero_fill_and_preserve(existing: float | None, expected: f
6161
agent_step=0,
6262
epoch=0,
6363
timer=reporter.context.stopwatch,
64-
optimizer=reporter.context.optimizer,
64+
optimizer=SimpleNamespace(param_groups=[]),
6565
)
6666

6767
assert payload["env_game/assembler.heart.created"] == expected

0 commit comments

Comments (0)