Skip to content

Commit f262323

Browse files
committed
cp
1 parent cb175f8 commit f262323

File tree

8 files changed

+236
-36
lines changed

8 files changed

+236
-36
lines changed

packages/mettagrid/cpp/bindings/mettagrid_c.cpp

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ MettaGrid::MettaGrid(const GameConfig& game_config, const py::list map, unsigned
8787

8888
init_action_handlers();
8989

90-
_init_grid(game_config, map);
90+
_init_grid(_game_config, map);
9191

9292
// Set runtime context for Build handler (needs obs_encoder and agents count)
9393
if (_build_handler) {
@@ -143,6 +143,9 @@ void MettaGrid::_init_grid(const GameConfig& game_config, const py::list& map) {
143143
object_type_names[type_id] = object_cfg->type_name;
144144
}
145145

146+
// Collect assemblers to initialize agent tracking after all agents are created
147+
std::vector<Assembler*> assemblers;
148+
146149
// Initialize objects from map
147150
for (GridCoord r = 0; r < height; r++) {
148151
for (GridCoord c = 0; c < width; c++) {
@@ -191,6 +194,7 @@ void MettaGrid::_init_grid(const GameConfig& game_config, const py::list& map) {
191194
assembler->set_grid(_grid.get());
192195
assembler->set_current_timestep_ptr(&current_step);
193196
assembler->set_obs_encoder(_obs_encoder.get());
197+
assemblers.push_back(assembler);
194198
continue;
195199
}
196200

@@ -208,6 +212,7 @@ void MettaGrid::_init_grid(const GameConfig& game_config, const py::list& map) {
208212
std::to_string(c) + ")");
209213
}
210214
}
215+
211216
}
212217

213218
void MettaGrid::_make_buffers(unsigned int num_agents) {
@@ -527,9 +532,10 @@ void MettaGrid::_compute_observation(GridCoord observer_row,
527532
int obs_r = r - static_cast<int>(observer_row) + static_cast<int>(obs_height_radius);
528533
int obs_c = c - static_cast<int>(observer_col) + static_cast<int>(obs_width_radius);
529534

530-
// Encode location and add tokens
535+
// Encode location and add tokens (pass agent_idx for agent-specific observations like per-agent cooldown)
531536
uint8_t location = PackedCoordinate::pack(static_cast<uint8_t>(obs_r), static_cast<uint8_t>(obs_c));
532-
attempted_tokens_written += _obs_encoder->encode_tokens(obj, obs_tokens, location);
537+
attempted_tokens_written +=
538+
_obs_encoder->encode_tokens(obj, obs_tokens, location, static_cast<unsigned int>(agent_idx));
533539
tokens_written = std::min(attempted_tokens_written, static_cast<size_t>(observation_view.shape(1)));
534540
}
535541

@@ -623,9 +629,29 @@ void MettaGrid::_step() {
623629
}
624630
}
625631

626-
// Check and apply damage for all agents
632+
// Apply cell effects to agents (AOE effects from nearby objects)
627633
for (auto* agent : _agents) {
628-
agent->check_and_apply_damage(_rng);
634+
const CellEffect& effect = _grid->effect_at(agent->location.r, agent->location.c);
635+
for (const auto& [item, delta] : effect.resource_deltas) {
636+
agent->inventory.update(item, delta);
637+
638+
// Track AOE stats
639+
const std::string& resource_name = _stats->resource_name(item);
640+
if (delta > 0) {
641+
agent->stats.add("aoe." + resource_name + ".gained", static_cast<float>(delta));
642+
_stats->add("aoe." + resource_name + ".gained", static_cast<float>(delta));
643+
} else if (delta < 0) {
644+
agent->stats.add("aoe." + resource_name + ".lost", static_cast<float>(-delta));
645+
_stats->add("aoe." + resource_name + ".lost", static_cast<float>(-delta));
646+
}
647+
agent->stats.add("aoe." + resource_name + ".delta", static_cast<float>(delta));
648+
_stats->add("aoe." + resource_name + ".delta", static_cast<float>(delta));
649+
}
650+
}
651+
652+
// Check and apply damage for all agents (randomized order for fairness)
653+
for (const auto& agent_idx : agent_indices) {
654+
_agents[agent_idx]->check_and_apply_damage(_rng);
629655
}
630656

631657
// Apply global systems
@@ -1034,9 +1060,19 @@ PYBIND11_MODULE(mettagrid_c, m) {
10341060
.def_readwrite("cost", &DemolishConfig::cost)
10351061
.def_readwrite("scrap", &DemolishConfig::scrap);
10361062

1063+
// Bind AOEEffectConfig for AOE effects on any object
1064+
py::class_<AOEEffectConfig>(m, "AOEEffectConfig")
1065+
.def(py::init<>())
1066+
.def(py::init<unsigned int, const std::unordered_map<InventoryItem, InventoryDelta>&>(),
1067+
py::arg("range") = 1,
1068+
py::arg("resource_deltas") = std::unordered_map<InventoryItem, InventoryDelta>())
1069+
.def_readwrite("range", &AOEEffectConfig::range)
1070+
.def_readwrite("resource_deltas", &AOEEffectConfig::resource_deltas);
1071+
10371072
// Expose this so we can cast python WallConfig / AgentConfig to a common GridConfig cpp object.
10381073
py::class_<GridObjectConfig, std::shared_ptr<GridObjectConfig>>(m, "GridObjectConfig")
1039-
.def_readwrite("demolish", &GridObjectConfig::demolish);
1074+
.def_readwrite("demolish", &GridObjectConfig::demolish)
1075+
.def_readwrite("aoe", &GridObjectConfig::aoe);
10401076

10411077
bind_wall_config(m);
10421078

packages/mettagrid/cpp/include/mettagrid/actions/attack.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class Attack : public ActionHandler {
155155
actor.stats.incr(_action_prefix(actor_group) + "demolish." + target_type_name);
156156
actor.stats.incr("demolish." + target_type_name);
157157

158-
// Call on_demolish before removing
158+
// Call on_demolish before removing (handles AOE cleanup)
159159
target->on_demolish();
160160

161161
// Give scrap resources to actor

packages/mettagrid/cpp/include/mettagrid/core/grid.hpp

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <algorithm>
55
#include <memory>
6+
#include <unordered_map>
67
#include <vector>
78

89
#include "core/grid_object.hpp"
@@ -13,6 +14,26 @@ using std::unique_ptr;
1314
using std::vector;
1415
using GridType = std::vector<std::vector<GridObject*>>;
1516

17+
// Accumulated resource effects at a single grid cell
18+
struct CellEffect {
19+
std::unordered_map<InventoryItem, InventoryDelta> resource_deltas;
20+
21+
void add(InventoryItem item, InventoryDelta delta) {
22+
resource_deltas[item] += delta;
23+
if (resource_deltas[item] == 0) {
24+
resource_deltas.erase(item);
25+
}
26+
}
27+
28+
void subtract(InventoryItem item, InventoryDelta delta) {
29+
add(item, -delta);
30+
}
31+
32+
bool empty() const {
33+
return resource_deltas.empty();
34+
}
35+
};
36+
1637
class Grid {
1738
public:
1839
const GridCoord height;
@@ -21,29 +42,20 @@ class Grid {
2142

2243
private:
2344
GridType grid;
45+
std::vector<std::vector<CellEffect>> _effects;
2446

2547
public:
2648
Grid(GridCoord height, GridCoord width)
27-
: height(height),
28-
width(width),
29-
objects(), // Initialize objects in member init list
30-
grid() { // Initialize grid in member init list
49+
: height(height), width(width), objects(), grid(), _effects(height, std::vector<CellEffect>(width)) {
3150
grid.resize(height, std::vector<GridObject*>(width, nullptr));
3251

3352
// Reserve space for objects to avoid frequent reallocations
34-
// Assume ~50% of grid cells will contain objects
3553
size_t estimated_objects = static_cast<size_t>(height) * width / 2;
36-
37-
// Cap preallocation at ~100MB of pointer memory
3854
constexpr size_t MAX_PREALLOCATED_OBJECTS = 12'500'000;
3955
size_t reserved_objects = std::min(estimated_objects, MAX_PREALLOCATED_OBJECTS);
40-
4156
objects.reserve(reserved_objects);
4257

4358
// GridObjectId "0" is reserved to mean empty space (GridObject pointer = nullptr).
44-
// By pushing nullptr at index 0, we ensure that:
45-
// 1. Grid initialization with zeros automatically represents empty spaces
46-
// 2. Object IDs match their index in the objects vector (no off-by-one adjustments)
4759
objects.push_back(nullptr);
4860
}
4961

@@ -64,6 +76,11 @@ class Grid {
6476
obj->id = static_cast<GridObjectId>(this->objects.size());
6577
this->objects.push_back(std::unique_ptr<GridObject>(obj));
6678
this->grid[obj->location.r][obj->location.c] = obj;
79+
80+
// Register AOE effects if configured
81+
obj->aoe.init(this);
82+
obj->aoe.register_effects(obj->location.r, obj->location.c);
83+
6784
return true;
6885
}
6986

@@ -109,6 +126,9 @@ class Grid {
109126
return false;
110127
}
111128

129+
// Call on_demolish for cleanup (includes AOE unregistration)
130+
obj.on_demolish();
131+
112132
// Clear the grid cell
113133
grid[obj.location.r][obj.location.c] = nullptr;
114134

@@ -132,6 +152,41 @@ class Grid {
132152
inline bool is_empty(GridCoord row, GridCoord col) const {
133153
return grid[row][col] == nullptr;
134154
}
155+
156+
// Get the effect at a cell
157+
const CellEffect& effect_at(GridCoord r, GridCoord c) const {
158+
return _effects[r][c];
159+
}
160+
161+
// Apply AOE effects from a source at (center_r, center_c) with given radius
162+
// If adding=true, adds effects; if false, removes them
163+
void apply_aoe(GridCoord center_r,
164+
GridCoord center_c,
165+
unsigned int radius,
166+
const std::unordered_map<InventoryItem, InventoryDelta>& resource_deltas,
167+
bool adding) {
168+
int r_start = std::max(0, static_cast<int>(center_r) - static_cast<int>(radius));
169+
int r_end = std::min(static_cast<int>(height), static_cast<int>(center_r) + static_cast<int>(radius) + 1);
170+
int c_start = std::max(0, static_cast<int>(center_c) - static_cast<int>(radius));
171+
int c_end = std::min(static_cast<int>(width), static_cast<int>(center_c) + static_cast<int>(radius) + 1);
172+
173+
for (int r = r_start; r < r_end; ++r) {
174+
for (int c = c_start; c < c_end; ++c) {
175+
// Manhattan distance for AOE range
176+
int dr = std::abs(r - static_cast<int>(center_r));
177+
int dc = std::abs(c - static_cast<int>(center_c));
178+
if (dr + dc <= static_cast<int>(radius)) {
179+
for (const auto& [item, delta] : resource_deltas) {
180+
if (adding) {
181+
_effects[r][c].add(item, delta);
182+
} else {
183+
_effects[r][c].subtract(item, delta);
184+
}
185+
}
186+
}
187+
}
188+
}
189+
}
135190
};
136191

137192
#endif // PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_CORE_GRID_HPP_

packages/mettagrid/cpp/include/mettagrid/core/grid_object.hpp

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#include "objects/constants.hpp"
1414
#include "objects/has_vibe.hpp"
1515

16+
// Forward declaration
17+
class Grid;
18+
1619
using TypeId = ObservationType;
1720
using ObservationCoord = ObservationType;
1821
using Vibe = ObservationType;
@@ -61,19 +64,74 @@ struct DemolishConfig {
6164
: cost(cost), scrap(scrap) {}
6265
};
6366

67+
// Configuration for Area of Effect (AOE) resource effects
68+
struct AOEEffectConfig {
69+
unsigned int range = 1; // Radius of effect (Manhattan distance)
70+
std::unordered_map<InventoryItem, InventoryDelta> resource_deltas; // Per-tick resource changes
71+
72+
AOEEffectConfig() = default;
73+
AOEEffectConfig(unsigned int range, const std::unordered_map<InventoryItem, InventoryDelta>& resource_deltas)
74+
: range(range), resource_deltas(resource_deltas) {}
75+
};
76+
6477
struct GridObjectConfig {
6578
TypeId type_id;
6679
std::string type_name;
6780
std::vector<int> tag_ids;
6881
ObservationType initial_vibe;
6982
std::optional<DemolishConfig> demolish; // If set, object can be demolished
83+
std::optional<AOEEffectConfig> aoe; // If set, object emits AOE effects
7084

7185
GridObjectConfig(TypeId type_id, const std::string& type_name, ObservationType initial_vibe = 0)
72-
: type_id(type_id), type_name(type_name), tag_ids({}), initial_vibe(initial_vibe), demolish(std::nullopt) {}
86+
: type_id(type_id),
87+
type_name(type_name),
88+
tag_ids({}),
89+
initial_vibe(initial_vibe),
90+
demolish(std::nullopt),
91+
aoe(std::nullopt) {}
7392

7493
virtual ~GridObjectConfig() = default;
7594
};
7695

96+
// Helper class for managing AOE effects on grid objects
97+
class AOEHelper {
98+
public:
99+
AOEHelper() = default;
100+
101+
// Initialize with grid reference
102+
void init(Grid* grid) {
103+
_grid = grid;
104+
}
105+
106+
// Set the AOE config (call from object constructor)
107+
void set_config(const AOEEffectConfig* config) {
108+
_config = config;
109+
}
110+
111+
// Check if this helper has AOE configured
112+
bool has_aoe() const {
113+
return _config != nullptr && _grid != nullptr;
114+
}
115+
116+
// Register AOE effects at the given location
117+
void register_effects(GridCoord r, GridCoord c);
118+
119+
// Unregister AOE effects (call on demolish or removal)
120+
void unregister_effects();
121+
122+
// Get the config
123+
const AOEEffectConfig* config() const {
124+
return _config;
125+
}
126+
127+
private:
128+
const AOEEffectConfig* _config = nullptr;
129+
Grid* _grid = nullptr;
130+
bool _registered = false;
131+
GridCoord _location_r = 0;
132+
GridCoord _location_c = 0;
133+
};
134+
77135
class GridObject : public HasVibe {
78136
public:
79137
GridObjectId id{};
@@ -82,6 +140,7 @@ class GridObject : public HasVibe {
82140
std::string type_name;
83141
std::vector<int> tag_ids;
84142
const DemolishConfig* demolish_config = nullptr; // Optional demolish config for buildings
143+
AOEHelper aoe; // AOE effect helper
85144

86145
virtual ~GridObject() = default;
87146

@@ -92,17 +151,26 @@ class GridObject : public HasVibe {
92151
this->tag_ids = cfg.tag_ids;
93152
this->vibe = cfg.initial_vibe;
94153
this->demolish_config = cfg.demolish.has_value() ? &cfg.demolish.value() : nullptr;
154+
if (cfg.aoe.has_value()) {
155+
_aoe_config = cfg.aoe.value();
156+
this->aoe.set_config(&_aoe_config.value());
157+
}
95158
}
96159

97160
// Called when this object is demolished. Override for cleanup.
98-
virtual void on_demolish() {}
161+
virtual void on_demolish() {
162+
aoe.unregister_effects();
163+
}
99164

100165
// observer_agent_id: The agent observing this object (UINT_MAX means no specific observer)
101166
// Used by Assembler to report agent-specific cooldowns
102167
virtual std::vector<PartialObservationToken> obs_features(unsigned int observer_agent_id = UINT_MAX) const {
103168
(void)observer_agent_id; // Unused in base class
104169
return {}; // Default: no observable features
105170
}
171+
172+
private:
173+
std::optional<AOEEffectConfig> _aoe_config;
106174
};
107175

108176
#endif // PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_CORE_GRID_OBJECT_HPP_

packages/mettagrid/cpp/include/mettagrid/objects/chest.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "objects/usable.hpp"
1818
#include "systems/observation_encoder.hpp"
1919
#include "systems/stats_tracker.hpp"
20+
2021
class Chest : public GridObject, public Usable, public HasInventory {
2122
private:
2223
// a reference to the game stats tracker
@@ -99,6 +100,7 @@ class Chest : public GridObject, public Usable, public HasInventory {
99100
this->obs_encoder = encoder;
100101
}
101102

103+
private:
102104
// Implement pure virtual method from Usable
103105
virtual bool onUse(Agent& actor, ActionArg /*arg*/) override {
104106
if (!grid) {
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#include "core/grid_object.hpp"
2+
3+
#include "core/grid.hpp"
4+
5+
void AOEHelper::register_effects(GridCoord r, GridCoord c) {
6+
if (has_aoe()) {
7+
_grid->apply_aoe(r, c, _config->range, _config->resource_deltas, true);
8+
_registered = true;
9+
_location_r = r;
10+
_location_c = c;
11+
}
12+
}
13+
14+
void AOEHelper::unregister_effects() {
15+
if (_registered && has_aoe()) {
16+
_grid->apply_aoe(_location_r, _location_c, _config->range, _config->resource_deltas, false);
17+
_registered = false;
18+
}
19+
}

0 commit comments

Comments
 (0)