Skip to content

Commit 10c987e

Browse files
committed
cp
1 parent 4a5f8a1 commit 10c987e

File tree

13 files changed

+56
-150
lines changed

13 files changed

+56
-150
lines changed

packages/mettagrid/cpp/bindings/mettagrid_c.cpp

Lines changed: 24 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,10 @@
1414

1515
#include "actions/action_handler.hpp"
1616
#include "actions/attack.hpp"
17-
#include "actions/build.hpp"
18-
#include "actions/build_config.hpp"
1917
#include "actions/change_vibe.hpp"
2018
#include "actions/move.hpp"
2119
#include "actions/move_config.hpp"
2220
#include "actions/noop.hpp"
23-
#include "actions/transfer.hpp"
2421
#include "config/observation_features.hpp"
2522
#include "core/grid.hpp"
2623
#include "core/types.hpp"
@@ -85,15 +82,10 @@ MettaGrid::MettaGrid(const GameConfig& game_config, const py::list map, unsigned
8582

8683
_action_success.resize(num_agents);
8784

88-
init_action_handlers();
85+
init_action_handlers(_game_config);
8986

9087
_init_grid(_game_config, map);
9188

92-
// Set runtime context for Build handler (needs obs_encoder and agents count)
93-
if (_build_handler) {
94-
_build_handler->set_runtime_context(&current_step, _obs_encoder.get(), num_agents);
95-
}
96-
9789
// Pre-compute goal_obs tokens for each agent
9890
if (_global_obs_config.goal_obs) {
9991
_agent_goal_obs_tokens.resize(_agents.size());
@@ -259,11 +251,11 @@ void MettaGrid::_init_buffers(unsigned int num_agents) {
259251
_compute_observations(executed_actions);
260252
}
261253

262-
void MettaGrid::init_action_handlers() {
254+
void MettaGrid::init_action_handlers(const GameConfig& game_config) {
263255
_max_action_priority = 0;
264256

265257
// Noop
266-
auto noop = std::make_unique<Noop>(*_game_config.actions.at("noop"));
258+
auto noop = std::make_unique<Noop>(*game_config.actions.at("noop"));
267259
noop->init(_grid.get(), &_rng);
268260
if (noop->priority > _max_action_priority) _max_action_priority = noop->priority;
269261
for (const auto& action : noop->actions()) {
@@ -272,72 +264,39 @@ void MettaGrid::init_action_handlers() {
272264
_action_handler_impl.push_back(std::move(noop));
273265

274266
// Move
275-
auto move_config = std::static_pointer_cast<const MoveActionConfig>(_game_config.actions.at("move"));
276-
auto move = std::make_unique<Move>(*move_config, &_game_config);
267+
auto move_config = std::static_pointer_cast<const MoveActionConfig>(game_config.actions.at("move"));
268+
auto move = std::make_unique<Move>(*move_config, &game_config);
277269
move->init(_grid.get(), &_rng);
278270
if (move->priority > _max_action_priority) _max_action_priority = move->priority;
279271
for (const auto& action : move->actions()) {
280272
_action_handlers.push_back(action);
281273
}
282-
// Capture the raw pointer to pass to other handlers
283-
Move* move_ptr = move.get();
284274
_action_handler_impl.push_back(std::move(move));
285275

286276
// Attack
287-
auto attack_config = std::static_pointer_cast<const AttackActionConfig>(_game_config.actions.at("attack"));
288-
auto attack = std::make_unique<Attack>(*attack_config, &_game_config);
289-
attack->init(_grid.get(), &_rng);
290-
if (attack->priority > _max_action_priority) _max_action_priority = attack->priority;
291-
for (const auto& action : attack->actions()) {
292-
_action_handlers.push_back(action);
293-
}
294-
295-
// Transfer
296-
auto transfer_config = std::static_pointer_cast<const TransferActionConfig>(_game_config.actions.at("transfer"));
297-
auto transfer = std::make_unique<Transfer>(*transfer_config, &_game_config);
298-
transfer->init(_grid.get(), &_rng);
299-
if (transfer->priority > _max_action_priority) _max_action_priority = transfer->priority;
300-
for (const auto& action : transfer->actions()) {
301-
_action_handlers.push_back(action);
302-
}
303-
304-
// Build (creates objects at previous location after successful move)
305-
_build_handler = nullptr;
306-
std::unique_ptr<Build> build;
307-
if (_game_config.actions.find("build") != _game_config.actions.end()) {
308-
auto build_config = std::static_pointer_cast<const BuildActionConfig>(_game_config.actions.at("build"));
309-
build = std::make_unique<Build>(*build_config, &_game_config, _stats.get());
310-
build->init(_grid.get(), &_rng);
311-
if (build->priority > _max_action_priority) _max_action_priority = build->priority;
312-
// Build doesn't create standalone actions - it's triggered by move
313-
_build_handler = build.get();
314-
}
315-
316-
// Register vibe-triggered action handlers with Move
317-
std::unordered_map<std::string, ActionHandler*> handlers;
318-
handlers["attack"] = attack.get();
319-
handlers["transfer"] = transfer.get();
320-
if (build) {
321-
handlers["build"] = build.get();
322-
}
323-
move_ptr->set_action_handlers(handlers);
324-
325-
_action_handler_impl.push_back(std::move(attack));
326-
_action_handler_impl.push_back(std::move(transfer));
327-
if (build) {
328-
_action_handler_impl.push_back(std::move(build));
277+
if (game_config.actions.find("attack") != game_config.actions.end()) {
278+
auto attack_config = std::static_pointer_cast<const AttackActionConfig>(game_config.actions.at("attack"));
279+
auto attack = std::make_unique<Attack>(*attack_config, &game_config);
280+
attack->init(_grid.get(), &_rng);
281+
if (attack->priority > _max_action_priority) _max_action_priority = attack->priority;
282+
for (const auto& action : attack->actions()) {
283+
_action_handlers.push_back(action);
284+
}
285+
_action_handler_impl.push_back(std::move(attack));
329286
}
330287

331288
// ChangeVibe
332-
auto change_vibe_config =
333-
std::static_pointer_cast<const ChangeVibeActionConfig>(_game_config.actions.at("change_vibe"));
334-
auto change_vibe = std::make_unique<ChangeVibe>(*change_vibe_config, &_game_config);
335-
change_vibe->init(_grid.get(), &_rng);
336-
if (change_vibe->priority > _max_action_priority) _max_action_priority = change_vibe->priority;
337-
for (const auto& action : change_vibe->actions()) {
338-
_action_handlers.push_back(action);
289+
if (game_config.actions.find("change_vibe") != game_config.actions.end()) {
290+
auto change_vibe_config =
291+
std::static_pointer_cast<const ChangeVibeActionConfig>(game_config.actions.at("change_vibe"));
292+
auto change_vibe = std::make_unique<ChangeVibe>(*change_vibe_config, &game_config);
293+
change_vibe->init(_grid.get(), &_rng);
294+
if (change_vibe->priority > _max_action_priority) _max_action_priority = change_vibe->priority;
295+
for (const auto& action : change_vibe->actions()) {
296+
_action_handlers.push_back(action);
297+
}
298+
_action_handler_impl.push_back(std::move(change_vibe));
339299
}
340-
_action_handler_impl.push_back(std::move(change_vibe));
341300
}
342301

343302
void MettaGrid::add_agent(Agent* agent) {
@@ -1091,10 +1050,6 @@ PYBIND11_MODULE(mettagrid_c, m) {
10911050
bind_chest_config(m);
10921051
bind_action_config(m);
10931052
bind_attack_action_config(m);
1094-
bind_vibe_transfer_effect(m);
1095-
bind_transfer_action_config(m);
1096-
bind_vibe_build_effect(m);
1097-
bind_build_action_config(m);
10981053
bind_change_vibe_action_config(m);
10991054
bind_move_action_config(m);
11001055
bind_global_obs_config(m);

packages/mettagrid/cpp/include/mettagrid/core/grid.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_CORE_GRID_HPP_
33

44
#include <algorithm>
5+
#include <cstdlib>
56
#include <memory>
67
#include <unordered_map>
78
#include <vector>

packages/mettagrid/cpp/include/mettagrid/objects/agent.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class Agent : public GridObject, public HasInventory, public Usable {
6565
// Implementation of Usable interface
6666
bool onUse(Agent& actor, ActionArg arg) override;
6767

68-
std::vector<PartialObservationToken> obs_features() const override;
68+
std::vector<PartialObservationToken> obs_features(unsigned int observer_agent_id = UINT_MAX) const override;
6969

7070
// Set observation encoder for inventory feature ID lookup
7171
void set_obs_encoder(const ObservationEncoder* encoder) {

packages/mettagrid/cpp/include/mettagrid/objects/assembler.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ class Assembler : public GridObject, public Usable {
266266
allow_partial_usage(cfg.allow_partial_usage),
267267
chest_search_distance(cfg.chest_search_distance),
268268
clipper_ptr(nullptr) {
269-
GridObject::init(cfg.type_id, cfg.type_name, GridLocation(r, c), cfg.tag_ids, cfg.initial_vibe);
269+
GridObject::init(cfg, GridLocation(r, c));
270270
}
271271
virtual ~Assembler() = default;
272272

@@ -499,7 +499,8 @@ class Assembler : public GridObject, public Usable {
499499
return true;
500500
}
501501

502-
virtual std::vector<PartialObservationToken> obs_features() const override {
502+
virtual std::vector<PartialObservationToken> obs_features(
503+
unsigned int /*observer_agent_id*/ = UINT_MAX) const override {
503504
std::vector<PartialObservationToken> features;
504505

505506
unsigned int remaining = std::min(cooldown_remaining(), 255u);

packages/mettagrid/cpp/include/mettagrid/objects/chest.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class Chest : public GridObject, public Usable, public HasInventory {
7979
vibe_transfers(cfg.vibe_transfers),
8080
stats_tracker(stats_tracker),
8181
grid(nullptr) {
82-
GridObject::init(cfg.type_id, cfg.type_name, GridLocation(r, c), cfg.tag_ids, cfg.initial_vibe);
82+
GridObject::init(cfg, GridLocation(r, c));
8383
// Set initial inventory for all configured resources (ignore limits for initial setup)
8484
for (const auto& [resource, amount] : cfg.initial_inventory) {
8585
if (amount > 0) {
@@ -123,7 +123,8 @@ class Chest : public GridObject, public Usable, public HasInventory {
123123
return false;
124124
}
125125

126-
virtual std::vector<PartialObservationToken> obs_features() const override {
126+
virtual std::vector<PartialObservationToken> obs_features(
127+
unsigned int /*observer_agent_id*/ = UINT_MAX) const override {
127128
if (!this->obs_encoder) {
128129
throw std::runtime_error("Observation encoder not set for chest");
129130
}

packages/mettagrid/cpp/include/mettagrid/objects/wall.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ struct WallConfig : public GridObjectConfig {
2020
class Wall : public GridObject {
2121
public:
2222
Wall(GridCoord r, GridCoord c, const WallConfig& cfg) {
23-
GridObject::init(cfg.type_id, cfg.type_name, GridLocation(r, c), cfg.tag_ids, cfg.initial_vibe);
23+
GridObject::init(cfg, GridLocation(r, c));
2424
}
2525

26-
std::vector<PartialObservationToken> obs_features() const override {
26+
std::vector<PartialObservationToken> obs_features(unsigned int /*observer_agent_id*/ = UINT_MAX) const override {
2727
std::vector<PartialObservationToken> features;
2828
features.reserve(1 + tag_ids.size() + (this->vibe != 0 ? 1 : 0));
2929

packages/mettagrid/cpp/include/mettagrid/systems/observation_encoder.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,11 @@ class ObservationEncoder {
111111

112112
// Returns the number of tokens that were available to write. This will be the number of tokens actually
113113
// written if there was enough space -- or a greater number if there was not enough space.
114-
size_t encode_tokens(const GridObject* obj, ObservationTokens tokens, ObservationType location) {
115-
return append_tokens_if_room_available(tokens, obj->obs_features(), location);
114+
size_t encode_tokens(const GridObject* obj,
115+
ObservationTokens tokens,
116+
ObservationType location,
117+
unsigned int observer_agent_id = UINT_MAX) {
118+
return append_tokens_if_room_available(tokens, obj->obs_features(observer_agent_id), location);
116119
}
117120

118121
size_t get_resource_count() const {

packages/mettagrid/cpp/src/mettagrid/objects/agent.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Agent::Agent(GridCoord r,
4242
}
4343
}
4444
populate_initial_inventory(config.initial_inventory);
45-
GridObject::init(config.type_id, config.type_name, GridLocation(r, c), config.tag_ids, config.initial_vibe);
45+
GridObject::init(config, GridLocation(r, c));
4646
}
4747

4848
void Agent::init(RewardType* reward_ptr) {
@@ -205,7 +205,7 @@ bool Agent::onUse(Agent& actor, ActionArg arg) {
205205
return any_transfer_occurred;
206206
}
207207

208-
std::vector<PartialObservationToken> Agent::obs_features() const {
208+
std::vector<PartialObservationToken> Agent::obs_features(unsigned int /*observer_agent_id*/) const {
209209
if (!this->obs_encoder) {
210210
throw std::runtime_error("Observation encoder not set for agent");
211211
}

packages/mettagrid/python/src/mettagrid/config/mettagrid_c_config.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from mettagrid.mettagrid_c import AssemblerConfig as CppAssemblerConfig
1717
from mettagrid.mettagrid_c import AttackActionConfig as CppAttackActionConfig
1818
from mettagrid.mettagrid_c import AttackOutcome as CppAttackOutcome
19-
from mettagrid.mettagrid_c import BuildActionConfig as CppBuildActionConfig
2019
from mettagrid.mettagrid_c import ChangeVibeActionConfig as CppChangeVibeActionConfig
2120
from mettagrid.mettagrid_c import ChestConfig as CppChestConfig
2221
from mettagrid.mettagrid_c import ClipperConfig as CppClipperConfig
@@ -28,9 +27,6 @@
2827
from mettagrid.mettagrid_c import LimitDef as CppLimitDef
2928
from mettagrid.mettagrid_c import MoveActionConfig as CppMoveActionConfig
3029
from mettagrid.mettagrid_c import Protocol as CppProtocol
31-
from mettagrid.mettagrid_c import TransferActionConfig as CppTransferActionConfig
32-
from mettagrid.mettagrid_c import VibeBuildEffect as CppVibeBuildEffect
33-
from mettagrid.mettagrid_c import VibeTransferEffect as CppVibeTransferEffect
3430
from mettagrid.mettagrid_c import WallConfig as CppWallConfig
3531

3632

@@ -513,26 +509,6 @@ def process_action_config(action_name: str, action_config):
513509
action_params["vibe_bonus"] = {vibe_name_to_id[vibe]: bonus for vibe, bonus in attack_cfg.vibe_bonus.items()}
514510
actions_cpp_params["attack"] = CppAttackActionConfig(**action_params)
515511

516-
# Process transfer - vibes are derived from vibe_transfers keys in C++
517-
transfer_cfg = actions_config.transfer
518-
vibe_transfers_cpp = {}
519-
seen_vibes: set[str] = set()
520-
for vt in transfer_cfg.vibe_transfers:
521-
if vt.vibe not in vibe_name_to_id:
522-
raise ValueError(f"Unknown vibe name '{vt.vibe}' in transfer.vibe_transfers")
523-
if vt.vibe in seen_vibes:
524-
raise ValueError(f"Duplicate vibe name '{vt.vibe}' in transfer.vibe_transfers")
525-
seen_vibes.add(vt.vibe)
526-
vibe_id = vibe_name_to_id[vt.vibe]
527-
target_deltas = {resource_name_to_id[k]: v for k, v in vt.target.items()}
528-
actor_deltas = {resource_name_to_id[k]: v for k, v in vt.actor.items()}
529-
vibe_transfers_cpp[vibe_id] = CppVibeTransferEffect(target_deltas, actor_deltas)
530-
actions_cpp_params["transfer"] = CppTransferActionConfig(
531-
required_resources={resource_name_to_id[k]: int(v) for k, v in transfer_cfg.required_resources.items()},
532-
vibe_transfers=vibe_transfers_cpp,
533-
enabled=transfer_cfg.enabled,
534-
)
535-
536512
# Process change_vibe - always add to map
537513
action_params = process_action_config("change_vibe", actions_config.change_vibe)
538514
# Use vibes length if explicitly set, otherwise fall back to number_of_vibes
@@ -546,39 +522,6 @@ def process_action_config(action_name: str, action_config):
546522
action_params["number_of_vibes"] = 0
547523
actions_cpp_params["change_vibe"] = CppChangeVibeActionConfig(**action_params)
548524

549-
# Process build - collect build configs from all objects that have a build field
550-
vibe_builds: dict[int, CppVibeBuildEffect] = {}
551-
build_vibes: list[int] = []
552-
for object_type, object_config in game_config.objects.items():
553-
if isinstance(object_config, WallConfig) and object_config.build is not None:
554-
build_cfg = object_config.build
555-
# Validate vibe name exists
556-
if build_cfg.vibe not in vibe_name_to_id:
557-
raise ValueError(f"Unknown vibe name '{build_cfg.vibe}' in build config for '{object_type}'")
558-
vibe_id = vibe_name_to_id[build_cfg.vibe]
559-
# Validate and convert resource costs
560-
cost_cpp: dict[int, int] = {}
561-
for resource_name, amount in build_cfg.cost.items():
562-
if resource_name not in resource_name_to_id:
563-
raise ValueError(
564-
f"Unknown resource '{resource_name}' in build config cost for '{object_type}'. "
565-
f"Available resources: {list(resource_name_to_id.keys())}"
566-
)
567-
cost_cpp[resource_name_to_id[resource_name]] = int(amount)
568-
# Create VibeBuildEffect
569-
vibe_builds[vibe_id] = CppVibeBuildEffect(cost_cpp, object_type)
570-
build_vibes.append(vibe_id)
571-
572-
# Only add build action if there are any build configs
573-
if vibe_builds:
574-
actions_cpp_params["build"] = CppBuildActionConfig(
575-
required_resources={},
576-
consumed_resources={},
577-
vibe_builds=vibe_builds,
578-
enabled=True,
579-
vibes=build_vibes,
580-
)
581-
582525
game_cpp_params["actions"] = actions_cpp_params
583526
game_cpp_params["objects"] = objects_cpp_params
584527

packages/mettagrid/python/src/mettagrid/simulator/simulator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,15 @@ def __init__(
6565
self._timer = Stopwatch(log_level=logger.getEffectiveLevel())
6666
self._timer.start()
6767

68-
game_config_dict = self._config.game.model_dump()
69-
7068
with self._timer("sim.init.make_map"):
7169
map_grid = self._make_map().grid.tolist()
7270

7371
# Create C++ config
7472
try:
75-
c_cfg = mettagrid_c_config.convert_to_cpp_game_config(game_config_dict)
73+
c_cfg = mettagrid_c_config.convert_to_cpp_game_config(self._config.game)
7674
except Exception as e:
7775
logger.error(f"Error creating C++ config: {e}")
78-
logger.error(f"Game config: {game_config_dict}")
76+
logger.error(f"Game config: {self._config.game.model_dump()}")
7977
raise e
8078

8179
# Create C++ environment

0 commit comments

Comments
 (0)