Skip to content

Commit 9337a11

Browse files
committed
Fix Grid::remove_object to validate object ID before clearing grid cell
Prevents inconsistent state where grid cell is cleared but object may still exist in memory if obj.id >= objects.size(). Now returns false instead of true when the ID is invalid.
1 parent ff5c7e0 commit 9337a11

File tree

10 files changed

+19
-74
lines changed

10 files changed

+19
-74
lines changed

packages/mettagrid/cpp/include/mettagrid/core/grid.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,16 @@ class Grid {
104104
return false; // Object not at expected location
105105
}
106106

107+
// Validate object ID first
108+
if (obj.id >= objects.size()) {
109+
return false;
110+
}
111+
107112
// Clear the grid cell
108113
grid[obj.location.r][obj.location.c] = nullptr;
109114

110115
// Release the object (unique_ptr becomes null but slot remains)
111-
if (obj.id < objects.size()) {
112-
objects[obj.id].reset();
113-
}
116+
objects[obj.id].reset();
114117
return true;
115118
}
116119

packages/mettagrid/cpp/include/mettagrid/core/grid_object.hpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,18 +84,13 @@ class GridObject : public HasVibe {
8484

8585
virtual ~GridObject() = default;
8686

87-
void init(TypeId object_type_id,
88-
const std::string& object_type_name,
89-
const GridLocation& object_location,
90-
const std::vector<int>& tags,
91-
ObservationType object_vibe = 0,
92-
const DemolishConfig* demolish = nullptr) {
93-
this->type_id = object_type_id;
94-
this->type_name = object_type_name;
87+
void init(const GridObjectConfig& cfg, const GridLocation& object_location) {
88+
this->type_id = cfg.type_id;
89+
this->type_name = cfg.type_name;
9590
this->location = object_location;
96-
this->tag_ids = tags;
97-
this->vibe = object_vibe;
98-
this->demolish_config = demolish;
91+
this->tag_ids = cfg.tag_ids;
92+
this->vibe = cfg.initial_vibe;
93+
this->demolish_config = cfg.demolish.has_value() ? &cfg.demolish.value() : nullptr;
9994
}
10095

10196
// Called when this object is demolished. Override for cleanup.

packages/mettagrid/cpp/include/mettagrid/objects/assembler.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,7 @@ class Assembler : public GridObject, public Usable {
266266
allow_partial_usage(cfg.allow_partial_usage),
267267
chest_search_distance(cfg.chest_search_distance),
268268
clipper_ptr(nullptr) {
269-
const DemolishConfig* demolish = cfg.demolish.has_value() ? &cfg.demolish.value() : nullptr;
270-
GridObject::init(cfg.type_id, cfg.type_name, GridLocation(r, c), cfg.tag_ids, cfg.initial_vibe, demolish);
269+
GridObject::init(cfg, GridLocation(r, c));
271270
}
272271
virtual ~Assembler() = default;
273272

packages/mettagrid/cpp/include/mettagrid/objects/chest.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ class Chest : public GridObject, public Usable, public HasInventory {
7878
vibe_transfers(cfg.vibe_transfers),
7979
stats_tracker(stats_tracker),
8080
grid(nullptr) {
81-
const DemolishConfig* demolish = cfg.demolish.has_value() ? &cfg.demolish.value() : nullptr;
82-
GridObject::init(cfg.type_id, cfg.type_name, GridLocation(r, c), cfg.tag_ids, cfg.initial_vibe, demolish);
81+
GridObject::init(cfg, GridLocation(r, c));
8382
// Set initial inventory for all configured resources (ignore limits for initial setup)
8483
for (const auto& [resource, amount] : cfg.initial_inventory) {
8584
if (amount > 0) {

packages/mettagrid/cpp/include/mettagrid/objects/protocol.hpp

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
#include <pybind11/pybind11.h>
55
#include <pybind11/stl.h>
66

7-
#include <algorithm>
8-
#include <cmath>
97
#include <memory>
108
#include <unordered_map>
119

@@ -18,40 +16,13 @@ class Protocol {
1816
std::unordered_map<InventoryItem, InventoryQuantity> input_resources;
1917
std::unordered_map<InventoryItem, InventoryQuantity> output_resources;
2018
unsigned short cooldown;
21-
float slope; // Linear component: adds slope * activation_count to the multiplier
22-
float exponent; // Exponential component: multiplies by (1 + exponent)^activation_count
23-
24-
// Track activation count for inflation calculation
25-
mutable unsigned int activation_count;
2619

2720
Protocol(unsigned short min_agents = 0,
2821
const std::vector<ObservationType>& vibes = {},
2922
const std::unordered_map<InventoryItem, InventoryQuantity>& inputs = {},
3023
const std::unordered_map<InventoryItem, InventoryQuantity>& outputs = {},
31-
unsigned short cooldown = 0,
32-
float slope = 0.0f,
33-
float exponent = 0.0f)
34-
: min_agents(min_agents),
35-
vibes(vibes),
36-
input_resources(inputs),
37-
output_resources(outputs),
38-
cooldown(cooldown),
39-
slope(slope),
40-
exponent(exponent),
41-
activation_count(0) {}
42-
43-
// Calculate cost multiplier based on activation count
44-
// Formula: max(0, 1 + slope * n) * (1 + exponent)^n
45-
// - slope < 0: discounting (starts at 1, decreases linearly, floors at 0)
46-
// - slope > 0: linear cost increase
47-
// - exponent > 0: exponential cost increase
48-
// - exponent < 0: exponential cost decrease
49-
float get_cost_multiplier() const {
50-
float n = static_cast<float>(activation_count);
51-
float linear_component = std::max(0.0f, 1.0f + slope * n);
52-
float exponential_component = std::pow(1.0f + exponent, n);
53-
return linear_component * exponential_component;
54-
}
24+
unsigned short cooldown = 0)
25+
: min_agents(min_agents), vibes(vibes), input_resources(inputs), output_resources(outputs), cooldown(cooldown) {}
5526
};
5627

5728
inline void bind_protocol(py::module& m) {
@@ -61,10 +32,7 @@ inline void bind_protocol(py::module& m) {
6132
.def_readwrite("vibes", &Protocol::vibes)
6233
.def_readwrite("input_resources", &Protocol::input_resources)
6334
.def_readwrite("output_resources", &Protocol::output_resources)
64-
.def_readwrite("cooldown", &Protocol::cooldown)
65-
.def_readwrite("slope", &Protocol::slope)
66-
.def_readwrite("exponent", &Protocol::exponent)
67-
.def_readwrite("activation_count", &Protocol::activation_count);
35+
.def_readwrite("cooldown", &Protocol::cooldown);
6836
}
6937

7038
#endif // PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_OBJECTS_PROTOCOL_HPP_

packages/mettagrid/cpp/include/mettagrid/objects/wall.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ struct WallConfig : public GridObjectConfig {
2020
class Wall : public GridObject {
2121
public:
2222
Wall(GridCoord r, GridCoord c, const WallConfig& cfg) {
23-
const DemolishConfig* demolish = cfg.demolish.has_value() ? &cfg.demolish.value() : nullptr;
24-
GridObject::init(cfg.type_id, cfg.type_name, GridLocation(r, c), cfg.tag_ids, cfg.initial_vibe, demolish);
23+
GridObject::init(cfg, GridLocation(r, c));
2524
}
2625

2726
std::vector<PartialObservationToken> obs_features() const override {

packages/mettagrid/cpp/src/mettagrid/objects/agent.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Agent::Agent(GridCoord r,
4141
}
4242
}
4343
populate_initial_inventory(config.initial_inventory);
44-
GridObject::init(config.type_id, config.type_name, GridLocation(r, c), config.tag_ids, config.initial_vibe);
44+
GridObject::init(config, GridLocation(r, c));
4545
}
4646

4747
void Agent::init(RewardType* reward_ptr) {
-2.35 MB
Binary file not shown.

packages/mettagrid/python/src/mettagrid/config/mettagrid_c_config.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,6 @@ def convert_to_cpp_game_config(game_config: GameConfig):
289289
cpp_protocol.input_resources = input_res
290290
cpp_protocol.output_resources = output_res
291291
cpp_protocol.cooldown = protocol_config.cooldown
292-
cpp_protocol.slope = protocol_config.slope
293-
cpp_protocol.exponent = protocol_config.exponent
294292
protocols.append(cpp_protocol)
295293

296294
# Convert tag names to IDs
@@ -571,8 +569,6 @@ def process_action_config(action_name: str, action_config):
571569
output_res[key] = val
572570
cpp_protocol.output_resources = output_res
573571
cpp_protocol.cooldown = protocol_config.cooldown
574-
cpp_protocol.slope = protocol_config.slope
575-
cpp_protocol.exponent = protocol_config.exponent
576572
clipper_protocols.append(cpp_protocol)
577573
clipper_config = CppClipperConfig()
578574
clipper_config.unclipping_protocols = clipper_protocols

packages/mettagrid/python/src/mettagrid/config/mettagrid_config.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -422,20 +422,6 @@ class ProtocolConfig(Config):
422422
input_resources: dict[str, int] = Field(default_factory=dict)
423423
output_resources: dict[str, int] = Field(default_factory=dict)
424424
cooldown: int = Field(ge=0, default=0)
425-
slope: float = Field(
426-
default=0.0,
427-
description=(
428-
"Linear component of cost multiplier. "
429-
"Negative = discounting (starts at 1x, decreases linearly). "
430-
"Positive = linear cost increase."
431-
),
432-
)
433-
exponent: float = Field(
434-
default=0.0,
435-
description=(
436-
"Exponential component of cost multiplier. Cost multiplier = max(0, 1 + slope * n) * (1 + exponent)^n."
437-
),
438-
)
439425

440426

441427
class AssemblerConfig(GridObjectConfig):

0 commit comments

Comments
 (0)