@@ -31,6 +31,55 @@ class EpisodeRolloutResult(BaseModel):
3131 replay_path : str | None
3232 steps : int
3333 max_steps : int
34+ agent_inventories : list [dict [str , int ]] | None = None # agent_id -> {resource_name -> amount}
35+ resource_names : list [str ] | None = None # list of resource names
36+
37+ def print_agent_stats (self , title : str = "Agent Summary" ) -> None :
38+ """Print agent inventories and stats as a table."""
39+ if not self .agent_inventories or not self .resource_names :
40+ return
41+
42+ # Build plain text table (avoids Rich truncation issues)
43+ columns = ["Agent" , "Reward" ] + self .resource_names
44+ widths = [max (6 , len (c )) for c in columns ]
45+
46+ # Header
47+ header = " " .join (c .rjust (widths [i ]) for i , c in enumerate (columns ))
48+ print (f"\n { title } " )
49+ print (header )
50+ print ("-" * len (header ))
51+
52+ # Agent rows
53+ resource_totals : dict [str , int ] = {}
54+ for agent_id , inventory in enumerate (self .agent_inventories ):
55+ row = [str (agent_id ), f"{ self .rewards [agent_id ]:.2f} " ]
56+ for resource in self .resource_names :
57+ amount = inventory .get (resource , 0 )
58+ resource_totals [resource ] = resource_totals .get (resource , 0 ) + amount
59+ row .append (str (amount ))
60+ print (" " .join (row [i ].rjust (widths [i ]) for i in range (len (row ))))
61+
62+ print ("-" * len (header ))
63+
64+ # Total row
65+ total_row = ["Total" , f"{ float (self .rewards .sum ()):.2f} " ]
66+ for resource in self .resource_names :
67+ total_row .append (str (resource_totals .get (resource , 0 )))
68+ print (" " .join (total_row [i ].rjust (widths [i ]) for i in range (len (total_row ))))
69+
70+ # Gained row (from stats)
71+ gained_row = ["Gained" , "" ]
72+ for resource in self .resource_names :
73+ gained = int (self .stats .get (f"{ resource } .gained" , 0 ))
74+ gained_row .append (str (gained ) if gained else "-" )
75+ print (" " .join (gained_row [i ].rjust (widths [i ]) for i in range (len (gained_row ))))
76+
77+ # Lost row (from stats)
78+ lost_row = ["Lost" , "" ]
79+ for resource in self .resource_names :
80+ lost = int (self .stats .get (f"{ resource } .lost" , 0 ))
81+ lost_row .append (str (lost ) if lost else "-" )
82+ print (" " .join (lost_row [i ].rjust (widths [i ]) for i in range (len (lost_row ))))
3483
3584
3685class MultiEpisodeRolloutResult (BaseModel ):
@@ -120,6 +169,12 @@ def multi_episode_rollout(
120169 all_replay_paths = episode_replay_writer .get_written_replay_urls ()
121170 replay_path = None if not all_replay_paths else list (all_replay_paths .values ())[0 ]
122171
172+ # Collect agent inventories
173+ agent_inventories = []
174+ for agent_id in range (env_cfg .game .num_agents ):
175+ agent = rollout ._sim .agent (agent_id )
176+ agent_inventories .append (dict (agent .inventory ))
177+
123178 result = EpisodeRolloutResult (
124179 assignments = assignments .copy (),
125180 rewards = np .array (rollout ._sim .episode_rewards , dtype = float ),
@@ -128,6 +183,8 @@ def multi_episode_rollout(
128183 replay_path = replay_path ,
129184 steps = rollout ._sim .current_step ,
130185 max_steps = rollout ._sim .config .game .max_steps ,
186+ agent_inventories = agent_inventories ,
187+ resource_names = env_cfg .game .resource_names ,
131188 )
132189
133190 episode_results .append (result )
0 commit comments