Skip to content

Commit 1ef135b

Browse files
committed
cp
1 parent 0f2d96f commit 1ef135b

File tree

5 files changed

+80
-7
lines changed

5 files changed

+80
-7
lines changed

packages/mettagrid/cpp/include/mettagrid/objects/assembler.hpp

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <algorithm>
55
#include <cassert>
66
#include <cmath>
7+
#include <iostream>
78
#include <stdexcept>
89
#include <string>
910
#include <unordered_map>
@@ -250,6 +251,12 @@ class Assembler : public GridObject, public Usable {
250251
// Chest search distance - if > 0, assembler can use inventories from chests within this distance
251252
unsigned int chest_search_distance;
252253

254+
// Per-agent cooldown duration - number of timesteps before an agent can use this assembler again
255+
unsigned int agent_cooldown;
256+
257+
// Per-agent cooldown tracking - maps agent_id to timestep when cooldown ends
258+
std::vector<unsigned int> agent_cooldown_ends;
259+
253260
Assembler(GridCoord r, GridCoord c, const AssemblerConfig& cfg, StatsTracker* stats)
254261
: protocols(build_protocol_map(cfg.protocols)),
255262
unclip_protocols(),
@@ -266,6 +273,7 @@ class Assembler : public GridObject, public Usable {
266273
obs_encoder(nullptr),
267274
allow_partial_usage(cfg.allow_partial_usage),
268275
chest_search_distance(cfg.chest_search_distance),
276+
agent_cooldown(cfg.agent_cooldown),
269277
clipper_ptr(nullptr) {
270278
GridObject::init(cfg.type_id, cfg.type_name, GridLocation(r, c), cfg.tag_ids, cfg.initial_vibe);
271279
}
@@ -287,9 +295,36 @@ class Assembler : public GridObject, public Usable {
287295
}
288296

289297
// Initialize the per-agent tracking array (call after knowing num_agents)
290-
// Note: This is a no-op for now since we removed per-agent cooldown, but kept for API compat
291-
void init_agent_tracking(unsigned int /*num_agents*/) {
292-
// No-op - per-agent cooldown was removed
298+
void init_agent_tracking(unsigned int num_agents) {
299+
if (agent_cooldown > 0) {
300+
agent_cooldown_ends.resize(num_agents, 0);
301+
}
302+
}
303+
304+
// Check if a specific agent is on cooldown
305+
bool is_agent_on_cooldown(unsigned int agent_id) const {
306+
if (agent_cooldown == 0 || agent_id >= agent_cooldown_ends.size()) {
307+
return false;
308+
}
309+
return current_timestep_ptr && agent_cooldown_ends[agent_id] > *current_timestep_ptr;
310+
}
311+
312+
// Get remaining cooldown for a specific agent
313+
unsigned int get_agent_cooldown_remaining(unsigned int agent_id) const {
314+
if (agent_cooldown == 0 || agent_id >= agent_cooldown_ends.size() || !current_timestep_ptr) {
315+
return 0;
316+
}
317+
if (agent_cooldown_ends[agent_id] <= *current_timestep_ptr) {
318+
return 0;
319+
}
320+
return agent_cooldown_ends[agent_id] - *current_timestep_ptr;
321+
}
322+
323+
// Set cooldown for a specific agent
324+
void set_agent_cooldown(unsigned int agent_id) {
325+
if (agent_cooldown > 0 && agent_id < agent_cooldown_ends.size() && current_timestep_ptr) {
326+
agent_cooldown_ends[agent_id] = *current_timestep_ptr + agent_cooldown;
327+
}
293328
}
294329

295330
// Get the remaining cooldown duration in ticks (0 when ready for use)
@@ -439,6 +474,11 @@ class Assembler : public GridObject, public Usable {
439474
return false;
440475
}
441476

477+
// Check per-agent cooldown
478+
if (is_agent_on_cooldown(actor.agent_id)) {
479+
return false;
480+
}
481+
442482
// Check if on cooldown and whether partial usage is allowed
443483
unsigned int remaining = cooldown_remaining();
444484
if (remaining > 0 && !allow_partial_usage) {
@@ -497,6 +537,9 @@ class Assembler : public GridObject, public Usable {
497537
cooldown_duration = static_cast<unsigned int>(protocol_to_use.cooldown);
498538
cooldown_end_timestep = *current_timestep_ptr + cooldown_duration;
499539

540+
// Set per-agent cooldown
541+
set_agent_cooldown(actor.agent_id);
542+
500543
// If we were clipped and successfully used an unclip protocol, become unclipped. Also, don't count this as a use.
501544
if (is_clipped) {
502545
become_unclipped();
@@ -515,6 +558,14 @@ class Assembler : public GridObject, public Usable {
515558
features.push_back({ObservationFeature::CooldownRemaining, static_cast<ObservationType>(remaining)});
516559
}
517560

561+
// Add per-agent cooldown remaining if the observer is a valid agent
562+
if (observer_agent_id != UINT_MAX && agent_cooldown > 0) {
563+
unsigned int agent_remaining = std::min(get_agent_cooldown_remaining(observer_agent_id), 255u);
564+
if (agent_remaining > 0) {
565+
features.push_back({ObservationFeature::AgentCooldownRemaining, static_cast<ObservationType>(agent_remaining)});
566+
}
567+
}
568+
518569
// Add clipped status to observations if clipped
519570
if (is_clipped) {
520571
features.push_back({ObservationFeature::Clipped, static_cast<ObservationType>(1)});

packages/mettagrid/cpp/include/mettagrid/objects/assembler_config.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ struct AssemblerConfig : public GridObjectConfig {
1717
AssemblerConfig(TypeId type_id, const std::string& type_name, ObservationType initial_vibe = 0)
1818
: GridObjectConfig(type_id, type_name, initial_vibe),
1919
allow_partial_usage(false),
20-
max_uses(0), // 0 means unlimited uses
21-
clip_immune(false), // Not immune by default
22-
start_clipped(false), // Not clipped at start by default
23-
chest_search_distance(0) {} // 0 means chests are not searched
20+
max_uses(0), // 0 means unlimited uses
21+
clip_immune(false), // Not immune by default
22+
start_clipped(false), // Not clipped at start by default
23+
chest_search_distance(0), // 0 means chests are not searched
24+
agent_cooldown(0) {} // 0 means no per-agent cooldown
2425

2526
// List of protocols - GroupVibe keys will be calculated from each protocol's vibes vector
2627
std::vector<std::shared_ptr<Protocol>> protocols;
@@ -40,6 +41,10 @@ struct AssemblerConfig : public GridObjectConfig {
4041
// Distance is measured as Chebyshev distance (max of row and column differences)
4142
// 0 means chests are not searched
4243
unsigned int chest_search_distance;
44+
45+
// Per-agent cooldown duration - number of timesteps before an agent can use this assembler again
46+
// 0 means no per-agent cooldown
47+
unsigned int agent_cooldown;
4348
};
4449

4550
namespace py = pybind11;
@@ -59,6 +64,7 @@ inline void bind_assembler_config(py::module& m) {
5964
.def_readwrite("clip_immune", &AssemblerConfig::clip_immune)
6065
.def_readwrite("start_clipped", &AssemblerConfig::start_clipped)
6166
.def_readwrite("chest_search_distance", &AssemblerConfig::chest_search_distance)
67+
.def_readwrite("agent_cooldown", &AssemblerConfig::agent_cooldown)
6268
.def_readwrite("initial_vibe", &AssemblerConfig::initial_vibe);
6369
}
6470

packages/mettagrid/python/src/mettagrid/config/mettagrid_c_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def convert_to_cpp_game_config(game_config: GameConfig):
292292
cpp_assembler_config.clip_immune = object_config.clip_immune
293293
cpp_assembler_config.start_clipped = object_config.start_clipped
294294
cpp_assembler_config.chest_search_distance = object_config.chest_search_distance
295+
cpp_assembler_config.agent_cooldown = object_config.agent_cooldown
295296
# Key by map_name so map grid (which uses map_name) resolves directly.
296297
objects_cpp_params[object_config.map_name or object_type] = cpp_assembler_config
297298
elif isinstance(object_config, ChestConfig):

packages/mettagrid/python/src/mettagrid/config/mettagrid_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,11 @@ class AssemblerConfig(GridObjectConfig):
386386
ge=0,
387387
description="Distance within which assembler can use inventories from chests",
388388
)
389+
agent_cooldown: int = Field(
390+
default=0,
391+
ge=0,
392+
description="Per-agent cooldown duration in timesteps before they can use this assembler again",
393+
)
389394

390395

391396
class ChestConfig(GridObjectConfig):

packages/mettagrid/python/src/mettagrid/simulator/multi_episode/rollout.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class EpisodeRolloutResult(BaseModel):
3131
replay_path: str | None
3232
steps: int
3333
max_steps: int
34+
agent_inventories: list[dict[str, int]] | None = None # agent_id -> {resource_name -> amount}
35+
resource_names: list[str] | None = None # list of resource names
3436

3537
def print_agent_stats(self, title: str = "Agent Summary") -> None:
3638
"""Print agent inventories and stats as a table."""
@@ -167,6 +169,12 @@ def multi_episode_rollout(
167169
all_replay_paths = episode_replay_writer.get_written_replay_urls()
168170
replay_path = None if not all_replay_paths else list(all_replay_paths.values())[0]
169171

172+
# Collect agent inventories
173+
agent_inventories = []
174+
for agent_id in range(env_cfg.game.num_agents):
175+
agent = rollout._sim.agent(agent_id)
176+
agent_inventories.append(dict(agent.inventory))
177+
170178
result = EpisodeRolloutResult(
171179
assignments=assignments.copy(),
172180
rewards=np.array(rollout._sim.episode_rewards, dtype=float),
@@ -175,6 +183,8 @@ def multi_episode_rollout(
175183
replay_path=replay_path,
176184
steps=rollout._sim.current_step,
177185
max_steps=rollout._sim.config.game.max_steps,
186+
agent_inventories=agent_inventories,
187+
resource_names=env_cfg.game.resource_names,
178188
)
179189

180190
episode_results.append(result)

0 commit comments

Comments
 (0)