Skip to content

Commit 88e0519

Browse files
committed
Add rollout agent inventory stats printing
Adds print_agent_stats() to EpisodeRolloutResult that displays a table of agent inventories and rewards at the end of each episode.
1 parent aba948e commit 88e0519

File tree

1 file changed

+57
-0
lines changed
  • packages/mettagrid/python/src/mettagrid/simulator/multi_episode

1 file changed

+57
-0
lines changed

packages/mettagrid/python/src/mettagrid/simulator/multi_episode/rollout.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,55 @@ class EpisodeRolloutResult(BaseModel):
3131
replay_path: str | None
3232
steps: int
3333
max_steps: int
34+
agent_inventories: list[dict[str, int]] | None = None # agent_id -> {resource_name -> amount}
35+
resource_names: list[str] | None = None # list of resource names
36+
37+
def print_agent_stats(self, title: str = "Agent Summary") -> None:
38+
"""Print agent inventories and stats as a table."""
39+
if not self.agent_inventories or not self.resource_names:
40+
return
41+
42+
# Build plain text table (avoids Rich truncation issues)
43+
columns = ["Agent", "Reward"] + self.resource_names
44+
widths = [max(6, len(c)) for c in columns]
45+
46+
# Header
47+
header = " ".join(c.rjust(widths[i]) for i, c in enumerate(columns))
48+
print(f"\n{title}")
49+
print(header)
50+
print("-" * len(header))
51+
52+
# Agent rows
53+
resource_totals: dict[str, int] = {}
54+
for agent_id, inventory in enumerate(self.agent_inventories):
55+
row = [str(agent_id), f"{self.rewards[agent_id]:.2f}"]
56+
for resource in self.resource_names:
57+
amount = inventory.get(resource, 0)
58+
resource_totals[resource] = resource_totals.get(resource, 0) + amount
59+
row.append(str(amount))
60+
print(" ".join(row[i].rjust(widths[i]) for i in range(len(row))))
61+
62+
print("-" * len(header))
63+
64+
# Total row
65+
total_row = ["Total", f"{float(self.rewards.sum()):.2f}"]
66+
for resource in self.resource_names:
67+
total_row.append(str(resource_totals.get(resource, 0)))
68+
print(" ".join(total_row[i].rjust(widths[i]) for i in range(len(total_row))))
69+
70+
# Gained row (from stats)
71+
gained_row = ["Gained", ""]
72+
for resource in self.resource_names:
73+
gained = int(self.stats.get(f"{resource}.gained", 0))
74+
gained_row.append(str(gained) if gained else "-")
75+
print(" ".join(gained_row[i].rjust(widths[i]) for i in range(len(gained_row))))
76+
77+
# Lost row (from stats)
78+
lost_row = ["Lost", ""]
79+
for resource in self.resource_names:
80+
lost = int(self.stats.get(f"{resource}.lost", 0))
81+
lost_row.append(str(lost) if lost else "-")
82+
print(" ".join(lost_row[i].rjust(widths[i]) for i in range(len(lost_row))))
3483

3584

3685
class MultiEpisodeRolloutResult(BaseModel):
@@ -120,6 +169,12 @@ def multi_episode_rollout(
120169
all_replay_paths = episode_replay_writer.get_written_replay_urls()
121170
replay_path = None if not all_replay_paths else list(all_replay_paths.values())[0]
122171

172+
# Collect agent inventories
173+
agent_inventories = []
174+
for agent_id in range(env_cfg.game.num_agents):
175+
agent = rollout._sim.agent(agent_id)
176+
agent_inventories.append(dict(agent.inventory))
177+
123178
result = EpisodeRolloutResult(
124179
assignments=assignments.copy(),
125180
rewards=np.array(rollout._sim.episode_rewards, dtype=float),
@@ -128,6 +183,8 @@ def multi_episode_rollout(
128183
replay_path=replay_path,
129184
steps=rollout._sim.current_step,
130185
max_steps=rollout._sim.config.game.max_steps,
186+
agent_inventories=agent_inventories,
187+
resource_names=env_cfg.game.resource_names,
131188
)
132189

133190
episode_results.append(result)

0 commit comments

Comments
 (0)