Skip to content

Commit 5c3c4f9

Browse files
committed
cp
1 parent 0ecb7ae commit 5c3c4f9

File tree

10 files changed

+241
-40
lines changed

10 files changed

+241
-40
lines changed

packages/mettagrid/cpp/bindings/mettagrid_c.cpp

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ MettaGrid::MettaGrid(const GameConfig& game_config, const py::list map, unsigned
8585

8686
init_action_handlers();
8787

88-
_init_grid(game_config, map);
88+
_init_grid(_game_config, map);
8989

9090
// Pre-compute goal_obs tokens for each agent
9191
if (_global_obs_config.goal_obs) {
@@ -136,6 +136,9 @@ void MettaGrid::_init_grid(const GameConfig& game_config, const py::list& map) {
136136
object_type_names[type_id] = object_cfg->type_name;
137137
}
138138

139+
// Collect assemblers to initialize agent tracking after all agents are created
140+
std::vector<Assembler*> assemblers;
141+
139142
// Initialize objects from map
140143
for (GridCoord r = 0; r < height; r++) {
141144
for (GridCoord c = 0; c < width; c++) {
@@ -184,6 +187,7 @@ void MettaGrid::_init_grid(const GameConfig& game_config, const py::list& map) {
184187
assembler->set_grid(_grid.get());
185188
assembler->set_current_timestep_ptr(&current_step);
186189
assembler->set_obs_encoder(_obs_encoder.get());
190+
assemblers.push_back(assembler);
187191
continue;
188192
}
189193

@@ -201,6 +205,7 @@ void MettaGrid::_init_grid(const GameConfig& game_config, const py::list& map) {
201205
std::to_string(c) + ")");
202206
}
203207
}
208+
204209
}
205210

206211
void MettaGrid::_make_buffers(unsigned int num_agents) {
@@ -502,9 +507,10 @@ void MettaGrid::_compute_observation(GridCoord observer_row,
502507
int obs_r = r - static_cast<int>(observer_row) + static_cast<int>(obs_height_radius);
503508
int obs_c = c - static_cast<int>(observer_col) + static_cast<int>(obs_width_radius);
504509

505-
// Encode location and add tokens
510+
// Encode location and add tokens (pass agent_idx for agent-specific observations like per-agent cooldown)
506511
uint8_t location = PackedCoordinate::pack(static_cast<uint8_t>(obs_r), static_cast<uint8_t>(obs_c));
507-
attempted_tokens_written += _obs_encoder->encode_tokens(obj, obs_tokens, location);
512+
attempted_tokens_written +=
513+
_obs_encoder->encode_tokens(obj, obs_tokens, location, static_cast<unsigned int>(agent_idx));
508514
tokens_written = std::min(attempted_tokens_written, static_cast<size_t>(observation_view.shape(1)));
509515
}
510516

@@ -598,9 +604,29 @@ void MettaGrid::_step() {
598604
}
599605
}
600606

601-
// Check and apply damage for all agents
607+
// Apply cell effects to agents (AOE effects from nearby objects)
602608
for (auto* agent : _agents) {
603-
agent->check_and_apply_damage(_rng);
609+
const CellEffect& effect = _grid->effect_at(agent->location.r, agent->location.c);
610+
for (const auto& [item, delta] : effect.resource_deltas) {
611+
agent->inventory.update(item, delta);
612+
613+
// Track AOE stats
614+
const std::string& resource_name = _stats->resource_name(item);
615+
if (delta > 0) {
616+
agent->stats.add("aoe." + resource_name + ".gained", static_cast<float>(delta));
617+
_stats->add("aoe." + resource_name + ".gained", static_cast<float>(delta));
618+
} else if (delta < 0) {
619+
agent->stats.add("aoe." + resource_name + ".lost", static_cast<float>(-delta));
620+
_stats->add("aoe." + resource_name + ".lost", static_cast<float>(-delta));
621+
}
622+
agent->stats.add("aoe." + resource_name + ".delta", static_cast<float>(delta));
623+
_stats->add("aoe." + resource_name + ".delta", static_cast<float>(delta));
624+
}
625+
}
626+
627+
// Check and apply damage for all agents (randomized order for fairness)
628+
for (const auto& agent_idx : agent_indices) {
629+
_agents[agent_idx]->check_and_apply_damage(_rng);
604630
}
605631

606632
// Apply global systems
@@ -1009,9 +1035,19 @@ PYBIND11_MODULE(mettagrid_c, m) {
10091035
.def_readwrite("cost", &DemolishConfig::cost)
10101036
.def_readwrite("scrap", &DemolishConfig::scrap);
10111037

1038+
// Bind AOEEffectConfig for AOE effects on any object
1039+
py::class_<AOEEffectConfig>(m, "AOEEffectConfig")
1040+
.def(py::init<>())
1041+
.def(py::init<unsigned int, const std::unordered_map<InventoryItem, InventoryDelta>&>(),
1042+
py::arg("range") = 1,
1043+
py::arg("resource_deltas") = std::unordered_map<InventoryItem, InventoryDelta>())
1044+
.def_readwrite("range", &AOEEffectConfig::range)
1045+
.def_readwrite("resource_deltas", &AOEEffectConfig::resource_deltas);
1046+
10121047
// Expose this so we can cast python WallConfig / AgentConfig to a common GridConfig cpp object.
10131048
py::class_<GridObjectConfig, std::shared_ptr<GridObjectConfig>>(m, "GridObjectConfig")
1014-
.def_readwrite("demolish", &GridObjectConfig::demolish);
1049+
.def_readwrite("demolish", &GridObjectConfig::demolish)
1050+
.def_readwrite("aoe", &GridObjectConfig::aoe);
10151051

10161052
bind_wall_config(m);
10171053

packages/mettagrid/cpp/include/mettagrid/actions/attack.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class Attack : public ActionHandler {
152152
actor.stats.incr(_action_prefix(actor_group) + "demolish." + target_type_name);
153153
actor.stats.incr("demolish." + target_type_name);
154154

155-
// Call on_demolish before removing
155+
// Call on_demolish before removing (handles AOE cleanup)
156156
target->on_demolish();
157157

158158
// Give scrap resources to actor

packages/mettagrid/cpp/include/mettagrid/core/grid.hpp

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <algorithm>
55
#include <memory>
6+
#include <unordered_map>
67
#include <vector>
78

89
#include "core/grid_object.hpp"
@@ -13,6 +14,26 @@ using std::unique_ptr;
1314
using std::vector;
1415
using GridType = std::vector<std::vector<GridObject*>>;
1516

17+
// Accumulated resource effects at a single grid cell
18+
struct CellEffect {
19+
std::unordered_map<InventoryItem, InventoryDelta> resource_deltas;
20+
21+
void add(InventoryItem item, InventoryDelta delta) {
22+
resource_deltas[item] += delta;
23+
if (resource_deltas[item] == 0) {
24+
resource_deltas.erase(item);
25+
}
26+
}
27+
28+
void subtract(InventoryItem item, InventoryDelta delta) {
29+
add(item, -delta);
30+
}
31+
32+
bool empty() const {
33+
return resource_deltas.empty();
34+
}
35+
};
36+
1637
class Grid {
1738
public:
1839
const GridCoord height;
@@ -21,29 +42,20 @@ class Grid {
2142

2243
private:
2344
GridType grid;
45+
std::vector<std::vector<CellEffect>> _effects;
2446

2547
public:
2648
Grid(GridCoord height, GridCoord width)
27-
: height(height),
28-
width(width),
29-
objects(), // Initialize objects in member init list
30-
grid() { // Initialize grid in member init list
49+
: height(height), width(width), objects(), grid(), _effects(height, std::vector<CellEffect>(width)) {
3150
grid.resize(height, std::vector<GridObject*>(width, nullptr));
3251

3352
// Reserve space for objects to avoid frequent reallocations
34-
// Assume ~50% of grid cells will contain objects
3553
size_t estimated_objects = static_cast<size_t>(height) * width / 2;
36-
37-
// Cap preallocation at ~100MB of pointer memory
3854
constexpr size_t MAX_PREALLOCATED_OBJECTS = 12'500'000;
3955
size_t reserved_objects = std::min(estimated_objects, MAX_PREALLOCATED_OBJECTS);
40-
4156
objects.reserve(reserved_objects);
4257

4358
// GridObjectId "0" is reserved to mean empty space (GridObject pointer = nullptr).
44-
// By pushing nullptr at index 0, we ensure that:
45-
// 1. Grid initialization with zeros automatically represents empty spaces
46-
// 2. Object IDs match their index in the objects vector (no off-by-one adjustments)
4759
objects.push_back(nullptr);
4860
}
4961

@@ -64,6 +76,11 @@ class Grid {
6476
obj->id = static_cast<GridObjectId>(this->objects.size());
6577
this->objects.push_back(std::unique_ptr<GridObject>(obj));
6678
this->grid[obj->location.r][obj->location.c] = obj;
79+
80+
// Register AOE effects if configured
81+
obj->aoe.init(this);
82+
obj->aoe.register_effects(obj->location.r, obj->location.c);
83+
6784
return true;
6885
}
6986

@@ -104,6 +121,9 @@ class Grid {
104121
return false; // Object not at expected location
105122
}
106123

124+
// Call on_demolish for cleanup (includes AOE unregistration)
125+
obj.on_demolish();
126+
107127
// Clear the grid cell
108128
grid[obj.location.r][obj.location.c] = nullptr;
109129

@@ -129,6 +149,41 @@ class Grid {
129149
inline bool is_empty(GridCoord row, GridCoord col) const {
130150
return grid[row][col] == nullptr;
131151
}
152+
153+
// Get the effect at a cell
154+
const CellEffect& effect_at(GridCoord r, GridCoord c) const {
155+
return _effects[r][c];
156+
}
157+
158+
// Apply AOE effects from a source at (center_r, center_c) with given radius
159+
// If adding=true, adds effects; if false, removes them
160+
void apply_aoe(GridCoord center_r,
161+
GridCoord center_c,
162+
unsigned int radius,
163+
const std::unordered_map<InventoryItem, InventoryDelta>& resource_deltas,
164+
bool adding) {
165+
int r_start = std::max(0, static_cast<int>(center_r) - static_cast<int>(radius));
166+
int r_end = std::min(static_cast<int>(height), static_cast<int>(center_r) + static_cast<int>(radius) + 1);
167+
int c_start = std::max(0, static_cast<int>(center_c) - static_cast<int>(radius));
168+
int c_end = std::min(static_cast<int>(width), static_cast<int>(center_c) + static_cast<int>(radius) + 1);
169+
170+
for (int r = r_start; r < r_end; ++r) {
171+
for (int c = c_start; c < c_end; ++c) {
172+
// Manhattan distance for AOE range
173+
int dr = std::abs(r - static_cast<int>(center_r));
174+
int dc = std::abs(c - static_cast<int>(center_c));
175+
if (dr + dc <= static_cast<int>(radius)) {
176+
for (const auto& [item, delta] : resource_deltas) {
177+
if (adding) {
178+
_effects[r][c].add(item, delta);
179+
} else {
180+
_effects[r][c].subtract(item, delta);
181+
}
182+
}
183+
}
184+
}
185+
}
186+
}
132187
};
133188

134189
#endif // PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_CORE_GRID_HPP_

packages/mettagrid/cpp/include/mettagrid/core/grid_object.hpp

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#include "objects/constants.hpp"
1414
#include "objects/has_vibe.hpp"
1515

16+
// Forward declaration
17+
class Grid;
18+
1619
using TypeId = ObservationType;
1720
using ObservationCoord = ObservationType;
1821
using Vibe = ObservationType;
@@ -61,19 +64,74 @@ struct DemolishConfig {
6164
: cost(cost), scrap(scrap) {}
6265
};
6366

67+
// Configuration for Area of Effect (AOE) resource effects
68+
struct AOEEffectConfig {
69+
unsigned int range = 1; // Radius of effect (Manhattan distance)
70+
std::unordered_map<InventoryItem, InventoryDelta> resource_deltas; // Per-tick resource changes
71+
72+
AOEEffectConfig() = default;
73+
AOEEffectConfig(unsigned int range, const std::unordered_map<InventoryItem, InventoryDelta>& resource_deltas)
74+
: range(range), resource_deltas(resource_deltas) {}
75+
};
76+
6477
struct GridObjectConfig {
6578
TypeId type_id;
6679
std::string type_name;
6780
std::vector<int> tag_ids;
6881
ObservationType initial_vibe;
6982
std::optional<DemolishConfig> demolish; // If set, object can be demolished
83+
std::optional<AOEEffectConfig> aoe; // If set, object emits AOE effects
7084

7185
GridObjectConfig(TypeId type_id, const std::string& type_name, ObservationType initial_vibe = 0)
72-
: type_id(type_id), type_name(type_name), tag_ids({}), initial_vibe(initial_vibe), demolish(std::nullopt) {}
86+
: type_id(type_id),
87+
type_name(type_name),
88+
tag_ids({}),
89+
initial_vibe(initial_vibe),
90+
demolish(std::nullopt),
91+
aoe(std::nullopt) {}
7392

7493
virtual ~GridObjectConfig() = default;
7594
};
7695

96+
// Helper class for managing AOE effects on grid objects
97+
class AOEHelper {
98+
public:
99+
AOEHelper() = default;
100+
101+
// Initialize with grid reference
102+
void init(Grid* grid) {
103+
_grid = grid;
104+
}
105+
106+
// Set the AOE config (call from object constructor)
107+
void set_config(const AOEEffectConfig* config) {
108+
_config = config;
109+
}
110+
111+
// Check if this helper has AOE configured
112+
bool has_aoe() const {
113+
return _config != nullptr && _grid != nullptr;
114+
}
115+
116+
// Register AOE effects at the given location
117+
void register_effects(GridCoord r, GridCoord c);
118+
119+
// Unregister AOE effects (call on demolish or removal)
120+
void unregister_effects();
121+
122+
// Get the config
123+
const AOEEffectConfig* config() const {
124+
return _config;
125+
}
126+
127+
private:
128+
const AOEEffectConfig* _config = nullptr;
129+
Grid* _grid = nullptr;
130+
bool _registered = false;
131+
GridCoord _location_r = 0;
132+
GridCoord _location_c = 0;
133+
};
134+
77135
class GridObject : public HasVibe {
78136
public:
79137
GridObjectId id{};
@@ -82,6 +140,7 @@ class GridObject : public HasVibe {
82140
std::string type_name;
83141
std::vector<int> tag_ids;
84142
const DemolishConfig* demolish_config = nullptr; // Optional demolish config for buildings
143+
AOEHelper aoe; // AOE effect helper
85144

86145
virtual ~GridObject() = default;
87146

@@ -90,24 +149,34 @@ class GridObject : public HasVibe {
90149
const GridLocation& object_location,
91150
const std::vector<int>& tags,
92151
ObservationType object_vibe = 0,
93-
const DemolishConfig* demolish = nullptr) {
152+
const DemolishConfig* demolish = nullptr,
153+
const std::optional<AOEEffectConfig>& aoe_config = std::nullopt) {
94154
this->type_id = object_type_id;
95155
this->type_name = object_type_name;
96156
this->location = object_location;
97157
this->tag_ids = tags;
98158
this->vibe = object_vibe;
99159
this->demolish_config = demolish;
160+
if (aoe_config.has_value()) {
161+
_aoe_config = aoe_config.value();
162+
this->aoe.set_config(&_aoe_config.value());
163+
}
100164
}
101165

102166
// Called when this object is demolished. Override for cleanup.
103-
virtual void on_demolish() {}
167+
virtual void on_demolish() {
168+
aoe.unregister_effects();
169+
}
104170

105171
// observer_agent_id: The agent observing this object (UINT_MAX means no specific observer)
106172
// Used by Assembler to report agent-specific cooldowns
107173
virtual std::vector<PartialObservationToken> obs_features(unsigned int observer_agent_id = UINT_MAX) const {
108174
(void)observer_agent_id; // Unused in base class
109175
return {}; // Default: no observable features
110176
}
177+
178+
private:
179+
std::optional<AOEEffectConfig> _aoe_config;
111180
};
112181

113182
#endif // PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_CORE_GRID_OBJECT_HPP_

packages/mettagrid/cpp/include/mettagrid/objects/assembler.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ class Assembler : public GridObject, public Usable {
276276
agent_cooldown(cfg.agent_cooldown),
277277
clipper_ptr(nullptr) {
278278
const DemolishConfig* demolish = cfg.demolish.has_value() ? &cfg.demolish.value() : nullptr;
279-
GridObject::init(cfg.type_id, cfg.type_name, GridLocation(r, c), cfg.tag_ids, cfg.initial_vibe, demolish);
279+
GridObject::init(cfg.type_id, cfg.type_name, GridLocation(r, c), cfg.tag_ids, cfg.initial_vibe, demolish, cfg.aoe);
280280
}
281281
virtual ~Assembler() = default;
282282

0 commit comments

Comments
 (0)