Skip to content

Commit f7d72ce

Browse files
committed
Fix dangling references in action handlers
- Remove unused _config reference from Move class - Use _game_config member instead of parameter reference for action handler initialization - Fixes potential UB from dangling references/pointers to config objects
1 parent 6edbcf9 commit f7d72ce

File tree

10 files changed

+70
-29
lines changed

10 files changed

+70
-29
lines changed

packages/mettagrid/cpp/bindings/mettagrid_c.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "actions/move.hpp"
1919
#include "actions/move_config.hpp"
2020
#include "actions/noop.hpp"
21+
#include "actions/transfer.hpp"
2122
#include "config/observation_features.hpp"
2223
#include "core/grid.hpp"
2324
#include "core/types.hpp"
@@ -82,7 +83,7 @@ MettaGrid::MettaGrid(const GameConfig& game_config, const py::list map, unsigned
8283

8384
_action_success.resize(num_agents);
8485

85-
init_action_handlers(game_config);
86+
init_action_handlers();
8687

8788
_init_grid(game_config, map);
8889

@@ -238,11 +239,11 @@ void MettaGrid::_init_buffers(unsigned int num_agents) {
238239
_compute_observations(executed_actions);
239240
}
240241

241-
void MettaGrid::init_action_handlers(const GameConfig& game_config) {
242+
void MettaGrid::init_action_handlers() {
242243
_max_action_priority = 0;
243244

244245
// Noop
245-
auto noop = std::make_unique<Noop>(*game_config.actions.at("noop"));
246+
auto noop = std::make_unique<Noop>(*_game_config.actions.at("noop"));
246247
noop->init(_grid.get(), &_rng);
247248
if (noop->priority > _max_action_priority) _max_action_priority = noop->priority;
248249
for (const auto& action : noop->actions()) {
@@ -251,8 +252,8 @@ void MettaGrid::init_action_handlers(const GameConfig& game_config) {
251252
_action_handler_impl.push_back(std::move(noop));
252253

253254
// Move
254-
auto move_config = std::static_pointer_cast<const MoveActionConfig>(game_config.actions.at("move"));
255-
auto move = std::make_unique<Move>(*move_config, &game_config);
255+
auto move_config = std::static_pointer_cast<const MoveActionConfig>(_game_config.actions.at("move"));
256+
auto move = std::make_unique<Move>(*move_config, &_game_config);
256257
move->init(_grid.get(), &_rng);
257258
if (move->priority > _max_action_priority) _max_action_priority = move->priority;
258259
for (const auto& action : move->actions()) {
@@ -263,25 +264,36 @@ void MettaGrid::init_action_handlers(const GameConfig& game_config) {
263264
_action_handler_impl.push_back(std::move(move));
264265

265266
// Attack
266-
auto attack_config = std::static_pointer_cast<const AttackActionConfig>(game_config.actions.at("attack"));
267-
auto attack = std::make_unique<Attack>(*attack_config, &game_config);
267+
auto attack_config = std::static_pointer_cast<const AttackActionConfig>(_game_config.actions.at("attack"));
268+
auto attack = std::make_unique<Attack>(*attack_config, &_game_config);
268269
attack->init(_grid.get(), &_rng);
269270
if (attack->priority > _max_action_priority) _max_action_priority = attack->priority;
270271
for (const auto& action : attack->actions()) {
271272
_action_handlers.push_back(action);
272273
}
273274

274-
// Register Attack handler with Move handler for vibe-triggered attacks
275+
// Transfer
276+
auto transfer_config = std::static_pointer_cast<const TransferActionConfig>(_game_config.actions.at("transfer"));
277+
auto transfer = std::make_unique<Transfer>(*transfer_config, &_game_config);
278+
transfer->init(_grid.get(), &_rng);
279+
if (transfer->priority > _max_action_priority) _max_action_priority = transfer->priority;
280+
for (const auto& action : transfer->actions()) {
281+
_action_handlers.push_back(action);
282+
}
283+
284+
// Register vibe-triggered action handlers with Move
275285
std::unordered_map<std::string, ActionHandler*> handlers;
276286
handlers["attack"] = attack.get();
287+
handlers["transfer"] = transfer.get();
277288
move_ptr->set_action_handlers(handlers);
278289

279290
_action_handler_impl.push_back(std::move(attack));
291+
_action_handler_impl.push_back(std::move(transfer));
280292

281293
// ChangeVibe
282294
auto change_vibe_config =
283-
std::static_pointer_cast<const ChangeVibeActionConfig>(game_config.actions.at("change_vibe"));
284-
auto change_vibe = std::make_unique<ChangeVibe>(*change_vibe_config, &game_config);
295+
std::static_pointer_cast<const ChangeVibeActionConfig>(_game_config.actions.at("change_vibe"));
296+
auto change_vibe = std::make_unique<ChangeVibe>(*change_vibe_config, &_game_config);
285297
change_vibe->init(_grid.get(), &_rng);
286298
if (change_vibe->priority > _max_action_priority) _max_action_priority = change_vibe->priority;
287299
for (const auto& action : change_vibe->actions()) {
@@ -979,6 +991,8 @@ PYBIND11_MODULE(mettagrid_c, m) {
979991
bind_attack_action_config(m);
980992
bind_change_vibe_action_config(m);
981993
bind_move_action_config(m);
994+
bind_vibe_transfer_effect(m);
995+
bind_transfer_action_config(m);
982996
bind_global_obs_config(m);
983997
bind_clipper_config(m);
984998
bind_game_config(m);

packages/mettagrid/cpp/bindings/mettagrid_c.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class METTAGRID_API MettaGrid {
147147
// Global systems
148148
std::unique_ptr<Clipper> _clipper;
149149

150-
void init_action_handlers(const GameConfig& game_config);
150+
void init_action_handlers();
151151
void add_agent(Agent* agent);
152152
void _init_grid(const GameConfig& game_config, const py::list& map);
153153
void _make_buffers(unsigned int num_agents);

packages/mettagrid/cpp/include/mettagrid/actions/move.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "actions/attack.hpp"
1111
#include "actions/move_config.hpp"
1212
#include "actions/orientation.hpp"
13+
#include "actions/transfer.hpp"
1314
#include "core/grid_object.hpp"
1415
#include "core/types.hpp"
1516
#include "objects/agent.hpp"
@@ -58,6 +59,13 @@ class Move : public ActionHandler {
5859
_vibe_handlers[vibe] = handler;
5960
}
6061
}
62+
} else if (name == "transfer") {
63+
Transfer* transfer = dynamic_cast<Transfer*>(handler);
64+
if (transfer) {
65+
for (ObservationType vibe : transfer->get_vibes()) {
66+
_vibe_handlers[vibe] = handler;
67+
}
68+
}
6169
}
6270
}
6371
}
@@ -100,6 +108,10 @@ class Move : public ActionHandler {
100108
if (attack_handler && attack_handler->try_attack(actor, target_object)) {
101109
return true;
102110
}
111+
Transfer* transfer_handler = dynamic_cast<Transfer*>(handler);
112+
if (transfer_handler && transfer_handler->try_transfer(actor, target_object)) {
113+
return true;
114+
}
103115
}
104116

105117
// If location is empty, move

packages/mettagrid/cpp/include/mettagrid/objects/agent_config.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ struct AgentConfig : public GridObjectConfig {
4040
const std::unordered_map<InventoryItem, InventoryQuantity>& initial_inventory = {},
4141
const std::unordered_map<InventoryItem, InventoryQuantity>& inventory_regen_amounts = {},
4242
const std::vector<InventoryItem>& diversity_tracked_resources = {},
43+
ObservationType initial_vibe = 0,
4344
const DamageConfig& damage_config = DamageConfig())
44-
: GridObjectConfig(type_id, type_name, 0),
45+
: GridObjectConfig(type_id, type_name, initial_vibe),
4546
group_id(group_id),
4647
group_name(group_name),
4748
freeze_duration(freeze_duration),
@@ -90,6 +91,7 @@ inline void bind_agent_config(py::module& m) {
9091
const std::unordered_map<InventoryItem, InventoryQuantity>&,
9192
const std::unordered_map<InventoryItem, InventoryQuantity>&,
9293
const std::vector<InventoryItem>&,
94+
ObservationType,
9395
const DamageConfig&>(),
9496
py::arg("type_id"),
9597
py::arg("type_name") = "agent",
@@ -102,6 +104,7 @@ inline void bind_agent_config(py::module& m) {
102104
py::arg("initial_inventory") = std::unordered_map<InventoryItem, InventoryQuantity>(),
103105
py::arg("inventory_regen_amounts") = std::unordered_map<InventoryItem, InventoryQuantity>(),
104106
py::arg("diversity_tracked_resources") = std::vector<InventoryItem>(),
107+
py::arg("initial_vibe") = 0,
105108
py::arg("damage_config") = DamageConfig())
106109
.def_readwrite("type_id", &AgentConfig::type_id)
107110
.def_readwrite("type_name", &AgentConfig::type_name)
@@ -115,6 +118,7 @@ inline void bind_agent_config(py::module& m) {
115118
.def_readwrite("initial_inventory", &AgentConfig::initial_inventory)
116119
.def_readwrite("inventory_regen_amounts", &AgentConfig::inventory_regen_amounts)
117120
.def_readwrite("diversity_tracked_resources", &AgentConfig::diversity_tracked_resources)
121+
.def_readwrite("initial_vibe", &AgentConfig::initial_vibe)
118122
.def_readwrite("damage_config", &AgentConfig::damage_config);
119123
}
120124

packages/mettagrid/cpp/src/mettagrid/objects/agent.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,17 +89,17 @@ void Agent::on_inventory_change(InventoryItem item, InventoryDelta delta) {
8989

9090
void Agent::update_inventory_diversity_stats(InventoryItem item, InventoryQuantity amount) {
9191
const size_t index = static_cast<size_t>(item);
92-
if (index >= diversity_tracked_mask.size() || diversity_tracked_mask[index] == 0) {
92+
if (index >= this->diversity_tracked_mask.size() || this->diversity_tracked_mask[index] == 0) {
9393
return;
9494
}
9595

96-
const bool had = tracked_resource_presence[index] != 0;
96+
const bool had = this->tracked_resource_presence[index] != 0;
9797
const bool has = amount > 0;
9898

9999
if (had != has) {
100-
tracked_resource_presence[index] = has ? 1 : 0;
101-
tracked_resource_diversity += has ? 1 : static_cast<std::size_t>(-1);
102-
this->stats.set("inventory.diversity", static_cast<float>(tracked_resource_diversity));
100+
this->tracked_resource_presence[index] = has ? 1 : 0;
101+
this->tracked_resource_diversity += has ? 1 : static_cast<std::size_t>(-1);
102+
this->stats.set("inventory.diversity", static_cast<float>(this->tracked_resource_diversity));
103103
}
104104
}
105105

packages/mettagrid/python/src/mettagrid/config/mettagrid_c_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from mettagrid.mettagrid_c import LimitDef as CppLimitDef
2222
from mettagrid.mettagrid_c import MoveActionConfig as CppMoveActionConfig
2323
from mettagrid.mettagrid_c import Protocol as CppProtocol
24+
from mettagrid.mettagrid_c import TransferActionConfig as CppTransferActionConfig
25+
from mettagrid.mettagrid_c import VibeTransferEffect as CppVibeTransferEffect
2426
from mettagrid.mettagrid_c import WallConfig as CppWallConfig
2527

2628

@@ -224,6 +226,7 @@ def convert_to_cpp_game_config(mettagrid_config: dict | GameConfig):
224226
initial_inventory=initial_inventory,
225227
inventory_regen_amounts=inventory_regen_amounts,
226228
diversity_tracked_resources=diversity_tracked_resources,
229+
initial_vibe=agent_props.get("initial_vibe", 0),
227230
damage_config=cpp_damage_config,
228231
)
229232
cpp_agent_config.tag_ids = tag_ids
@@ -449,7 +452,6 @@ def process_action_config(action_name: str, action_config):
449452
action_params["vibes"] = [vibe_name_to_id[vibe] for vibe in actions_config.attack.vibes if vibe in vibe_name_to_id]
450453
actions_cpp_params["attack"] = CppAttackActionConfig(**action_params)
451454

452-
<<<<<<< HEAD
453455
# Process transfer - vibes are derived from vibe_transfers keys in C++
454456
transfer_cfg = actions_config.transfer
455457
vibe_transfers_cpp = {}
@@ -470,8 +472,6 @@ def process_action_config(action_name: str, action_config):
470472
enabled=transfer_cfg.enabled,
471473
)
472474

473-
=======
474-
>>>>>>> d8324275fe (Remove Transfer action, keep only vibe-triggered Attack on Move)
475475
# Process change_vibe - always add to map
476476
action_params = process_action_config("change_vibe", actions_config.change_vibe)
477477
action_params["number_of_vibes"] = (

packages/mettagrid/python/src/mettagrid/config/mettagrid_config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ def Attack(self, location: Literal["1", "2", "3", "4", "5", "6", "7", "8", "9"])
216216
return Action(name=f"attack_{location}")
217217

218218

219-
<<<<<<< HEAD
220219
class VibeTransfer(Config):
221220
"""Configuration for resource transfers triggered by a specific vibe.
222221
@@ -254,8 +253,6 @@ def _actions(self) -> list[Action]:
254253
return []
255254

256255

257-
=======
258-
>>>>>>> d8324275fe (Remove Transfer action, keep only vibe-triggered Attack on Move)
259256
class ActionsConfig(Config):
260257
"""
261258
Actions configuration.
@@ -266,11 +263,12 @@ class ActionsConfig(Config):
266263
noop: NoopActionConfig = Field(default_factory=lambda: NoopActionConfig())
267264
move: MoveActionConfig = Field(default_factory=lambda: MoveActionConfig())
268265
attack: AttackActionConfig = Field(default_factory=lambda: AttackActionConfig(enabled=False))
266+
transfer: TransferActionConfig = Field(default_factory=lambda: TransferActionConfig(enabled=False))
269267
change_vibe: ChangeVibeActionConfig = Field(default_factory=lambda: ChangeVibeActionConfig())
270268

271269
def actions(self) -> list[Action]:
272270
return sum(
273-
[action.actions() for action in [self.noop, self.move, self.attack, self.change_vibe]],
271+
[action.actions() for action in [self.noop, self.move, self.attack, self.transfer, self.change_vibe]],
274272
[],
275273
)
276274

packages/mettagrid/python/src/mettagrid/mettagrid_c.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class AgentConfig(GridObjectConfig):
9797
initial_inventory: dict[int, int] = {},
9898
inventory_regen_amounts: dict[int, int] = {},
9999
diversity_tracked_resources: list[int] = [],
100+
initial_vibe: int = 0,
100101
damage_config: DamageConfig = ...,
101102
) -> None: ...
102103
type_id: int
@@ -111,6 +112,7 @@ class AgentConfig(GridObjectConfig):
111112
initial_inventory: dict[int, int]
112113
inventory_regen_amounts: dict[int, int]
113114
diversity_tracked_resources: list[int]
115+
initial_vibe: int
114116
damage_config: DamageConfig
115117

116118
class ActionConfig:

packages/mettagrid/tests/test_transfer_action.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@
2525
from mettagrid.simulator import Simulation
2626
from mettagrid.test_support.map_builders import ObjectNameMapBuilder
2727

28+
# Skip all tests in this module if C++ doesn't support transfer action
29+
try:
30+
from mettagrid.mettagrid_c import TransferActionConfig as _ # noqa: F401
31+
32+
HAS_TRANSFER = True
33+
except ImportError:
34+
HAS_TRANSFER = False
35+
36+
pytestmark = pytest.mark.skipif(not HAS_TRANSFER, reason="Transfer action not available in C++ bindings")
37+
2838
# Use vibes from the global VIBES list
2939
# Default (vibe id 0): "default"
3040
# Charger (vibe id 1): "charger" - use for energy transfer tests

packages/mettagrid/tests/test_vibe_triggered_actions.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
AttackActionConfig,
77
ChangeVibeActionConfig,
88
GameConfig,
9+
InventoryConfig,
910
MettaGridConfig,
1011
MoveActionConfig,
1112
NoopActionConfig,
@@ -64,12 +65,12 @@ def test_attack_triggers_on_move_with_matching_vibe(self):
6465
AgentConfig(
6566
team_id=0,
6667
freeze_duration=5,
67-
initial_inventory={"energy": 10, "heart": 5},
68+
inventory=InventoryConfig(initial={"energy": 10, "heart": 5}),
6869
),
6970
AgentConfig(
7071
team_id=1,
7172
freeze_duration=5,
72-
initial_inventory={"energy": 10, "heart": 5},
73+
inventory=InventoryConfig(initial={"energy": 10, "heart": 5}),
7374
),
7475
],
7576
)
@@ -128,12 +129,12 @@ def test_no_attack_without_matching_vibe(self):
128129
AgentConfig(
129130
team_id=0,
130131
freeze_duration=5,
131-
initial_inventory={"energy": 10, "heart": 5},
132+
inventory=InventoryConfig(initial={"energy": 10, "heart": 5}),
132133
),
133134
AgentConfig(
134135
team_id=1,
135136
freeze_duration=5,
136-
initial_inventory={"energy": 10, "heart": 5},
137+
inventory=InventoryConfig(initial={"energy": 10, "heart": 5}),
137138
),
138139
],
139140
)
@@ -179,7 +180,7 @@ def test_movement_works_normally_into_empty_space(self):
179180
),
180181
objects={"wall": WallConfig()},
181182
agent=AgentConfig(
182-
initial_inventory={"energy": 10},
183+
inventory=InventoryConfig(initial={"energy": 10}),
183184
),
184185
)
185186

0 commit comments

Comments
 (0)