Skip to content

Commit 0fdcc0b

Browse files
committed
cp
1 parent e8e37f2 commit 0fdcc0b

File tree

14 files changed

+99
-91
lines changed

14 files changed

+99
-91
lines changed

packages/mettagrid/cpp/bindings/mettagrid_c.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ MettaGrid::MettaGrid(const GameConfig& game_config, const py::list map, unsigned
6868

6969
_grid = std::make_unique<Grid>(height, width);
7070
_obs_encoder = std::make_unique<ObservationEncoder>(
71-
game_config.protocol_details_obs, resource_names, game_config.feature_ids, game_config.token_value_max);
71+
game_config.protocol_details_obs, resource_names, game_config.feature_ids, game_config.token_value_base);
7272

7373
// Initialize ObservationFeature namespace with feature IDs
7474
ObservationFeature::Initialize(game_config.feature_ids);

packages/mettagrid/cpp/include/mettagrid/config/mettagrid_config.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ struct GameConfig {
5353
std::shared_ptr<ClipperConfig> clipper = nullptr;
5454

5555
// Observation encoding settings
56-
unsigned int token_value_max = 255; // Maximum value per inventory token (base for encoding)
56+
unsigned int token_value_base = 256; // Base for multi-token inventory encoding (value per token: 0 to base-1)
5757
};
5858

5959
namespace py = pybind11;
@@ -127,7 +127,7 @@ inline void bind_game_config(py::module& m) {
127127
py::arg("clipper") = std::shared_ptr<ClipperConfig>(nullptr),
128128

129129
// Observation encoding
130-
py::arg("token_value_max") = 255)
130+
py::arg("token_value_base") = 256)
131131
.def_readwrite("num_agents", &GameConfig::num_agents)
132132
.def_readwrite("max_steps", &GameConfig::max_steps)
133133
.def_readwrite("episode_truncates", &GameConfig::episode_truncates)
@@ -158,7 +158,7 @@ inline void bind_game_config(py::module& m) {
158158
.def_readwrite("clipper", &GameConfig::clipper)
159159

160160
// Observation encoding
161-
.def_readwrite("token_value_max", &GameConfig::token_value_max);
161+
.def_readwrite("token_value_base", &GameConfig::token_value_base);
162162
}
163163

164164
#endif // PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_CONFIG_METTAGRID_CONFIG_HPP_

packages/mettagrid/cpp/include/mettagrid/objects/chest.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_OBJECTS_CHEST_HPP_
22
#define PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_OBJECTS_CHEST_HPP_
33

4+
#include <algorithm>
45
#include <set>
56
#include <unordered_map>
67
#include <vector>
@@ -125,7 +126,8 @@ class Chest : public GridObject, public Usable, public HasInventory {
125126
throw std::runtime_error("Observation encoder not set for chest");
126127
}
127128
std::vector<PartialObservationToken> features;
128-
features.reserve(1 + this->inventory.get().size() + this->tag_ids.size() + 3);
129+
features.reserve(1 + this->inventory.get().size() * this->obs_encoder->get_num_inventory_tokens() +
130+
this->tag_ids.size() + (this->vibe != 0 ? 1 : 0));
129131

130132
if (this->vibe != 0) features.push_back({ObservationFeature::Vibe, static_cast<ObservationType>(this->vibe)});
131133

packages/mettagrid/cpp/include/mettagrid/objects/inventory.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_OBJECTS_INVENTORY_HPP_
22
#define PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_OBJECTS_INVENTORY_HPP_
33

4+
#include <limits>
45
#include <string>
56
#include <unordered_map>
67
#include <vector>

packages/mettagrid/cpp/include/mettagrid/systems/observation_encoder.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ class ObservationEncoder {
5454

5555
// Build inventory feature ID maps using multi-token encoding
5656
// inv:{resource} = base token (always emitted)
57-
// inv:{resource}:p1 = first power token (emitted if amount >= token_value_max)
58-
// inv:{resource}:p2 = second power token (emitted if amount >= token_value_max^2)
57+
// inv:{resource}:p1 = first power token (emitted if amount >= token_value_base)
58+
// inv:{resource}:p2 = second power token (emitted if amount >= token_value_base^2)
5959
// etc.
6060
_inventory_feature_ids.resize(resource_names.size());
6161
_inventory_power_feature_ids.resize(resource_names.size());
@@ -158,9 +158,9 @@ class ObservationEncoder {
158158
}
159159

160160
// Encode inventory amount using multi-token encoding with configurable base.
161-
// inv:{resource} = amount % token_value_max (always emitted)
162-
// inv:{resource}:p1 = (amount / token_value_max) % token_value_max (only emitted if amount >= token_value_max)
163-
// inv:{resource}:p2 = (amount / token_value_max^2) % token_value_max (only emitted if amount >= token_value_max^2)
161+
// inv:{resource} = amount % token_value_base (always emitted)
162+
// inv:{resource}:p1 = (amount / token_value_base) % token_value_base (only emitted if amount >= token_value_base)
163+
// inv:{resource}:p2 = (amount / token_value_base^2) % token_value_base (only emitted if amount >= token_value_base^2)
164164
// etc.
165165
void append_inventory_tokens(std::vector<PartialObservationToken>& features,
166166
InventoryItem item,
@@ -179,7 +179,7 @@ class ObservationEncoder {
179179
}
180180
}
181181

182-
unsigned int get_token_value_max() const {
182+
unsigned int get_token_value_base() const {
183183
return _token_value_base;
184184
}
185185

@@ -195,7 +195,8 @@ class ObservationEncoder {
195195
size_t _num_inventory_tokens;
196196
std::vector<ObservationType> _input_feature_ids;
197197
std::vector<ObservationType> _output_feature_ids;
198-
std::vector<ObservationType> _inventory_feature_ids; // Maps item index to base feature ID (amount % token_value_max)
198+
std::vector<ObservationType>
199+
_inventory_feature_ids; // Maps item index to base feature ID (amount % token_value_base)
199200
std::vector<std::vector<ObservationType>> _inventory_power_feature_ids; // Maps item index to power feature IDs
200201
};
201202

packages/mettagrid/cpp/src/mettagrid/objects/agent.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ std::vector<PartialObservationToken> Agent::obs_features() const {
209209
if (!this->obs_encoder) {
210210
throw std::runtime_error("Observation encoder not set for agent");
211211
}
212-
const size_t num_tokens = this->inventory.get().size() + this->tag_ids.size() + 5;
212+
const size_t num_tokens =
213+
this->inventory.get().size() * this->obs_encoder->get_num_inventory_tokens() + this->tag_ids.size() + 5;
213214

214215
std::vector<PartialObservationToken> features;
215216
features.reserve(num_tokens);

packages/mettagrid/docs/observations.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,36 +116,36 @@ depend on your game configuration (number of resources, whether protocol details
116116
| `cooldown_remaining` | Remaining cooldown time for objects | assembler, extractors | Value capped at 255 |
117117
| `clipped` | Whether an assembler is clipped or not | extractors | |
118118
| `remaining_uses` | Remaining uses for objects with use limits | extractors | Value capped at 255. Only emitted if `max_uses > 0` |
119-
| `inv:{resource_name}` | Base inventory amount (amount % token_value_max) | agents, chests | One feature per resource. See [Inventory Encoding](#inventory-encoding) below. |
119+
| `inv:{resource_name}` | Base inventory amount (amount % token_value_base) | agents, chests | One feature per resource. See [Inventory Encoding](#inventory-encoding) below. |
120120
| `inv:{resource_name}:p1` | Power 1 component ((amount / B) % B) | agents, chests | Only emitted if amount >= B. See [Inventory Encoding](#inventory-encoding). |
121121
| `inv:{resource_name}:p2` | Power 2 component ((amount / B²) % B) | agents, chests | Only emitted if amount >= B². See [Inventory Encoding](#inventory-encoding). |
122122
| `protocol_input:{resource_name}` | Required input resource amount for current protocol | assembler, extractors | One feature per resource |
123123
| `protocol_output:{resource_name}` | Output resource amount for current protocol | assembler, extractors | One feature per resource |
124124

125125
### Inventory Encoding
126126

127-
Inventory values are encoded using a multi-token scheme with a configurable base (`ObsConfig.token_value_max`, default
128-
255). This allows representing large amounts while keeping individual token values bounded. The number of tokens is
127+
Inventory values are encoded using a multi-token scheme with a configurable base (`ObsConfig.token_value_base`, default
128+
256). This allows representing large amounts while keeping individual token values bounded. The number of tokens is
129129
dynamically computed based on the maximum inventory value (uint16_t max = 65535).
130130

131131
- **`inv:{resource}`**: Base value = `amount % B` (always emitted if amount > 0)
132132
- **`inv:{resource}:p1`**: Power 1 = `(amount / B) % B` (only emitted if amount >= B)
133133
- **`inv:{resource}:p2`**: Power 2 = `(amount / B²) % B` (only emitted if amount >= B²)
134134
- etc.
135135

136-
Where B = `token_value_max` (default 255).
136+
Where B = `token_value_base` (default 256).
137137

138138
The full value is reconstructed as: `base + p1 * B + p2 * B² + ...`
139139

140-
**Examples with token_value_max=255:**
140+
**Examples with token_value_base=256 (default):**
141141

142-
| Amount | `inv:food` | `inv:food:p1` | `inv:food:p2` | Reconstruction |
143-
| ------ | ---------- | ------------- | ------------- | ------------------------- |
144-
| 42 | 42 | (not emitted) | (not emitted) | 42 |
145-
| 1234 | 214 | 4 | (not emitted) | 214 + 4 \* 255 = 1234 |
146-
| 65535 | 0 | 2 | 1 | 0 + 2 \* 255 + 1 \* 65025 |
142+
| Amount | `inv:food` | `inv:food:p1` | `inv:food:p2` | Reconstruction |
143+
| ------ | ---------- | ------------- | ------------- | ------------------------ |
144+
| 42 | 42 | (not emitted) | (not emitted) | 42 |
145+
| 1234 | 210 | 4 | (not emitted) | 210 + 4 \* 256 = 1234 |
146+
| 65535 | 255 | 255 | (not emitted) | 255 + 255 \* 256 = 65535 |
147147

148-
**Examples with token_value_max=100:**
148+
**Examples with token_value_base=100:**
149149

150150
| Amount | `inv:food` | `inv:food:p1` | `inv:food:p2` | Reconstruction |
151151
| ------ | ---------- | ------------- | ------------- | --------------------------- |

packages/mettagrid/docs/simulator_api.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,11 @@ obs = agent.observation
265265
for token in obs.tokens:
266266
if token.feature.name == "object_type":
267267
print(f"Object at ({token.col()}, {token.row()}): {token.value}")
268-
elif token.feature.name.startswith("inv:"):
269-
resource = token.feature.name[4:] # Remove "inv:" prefix
270-
print(f"Inventory {resource}: {token.value}")
268+
269+
# For inventory, use the agent.inventory property which handles the encoding
270+
inventory = agent.inventory
271+
for resource, amount in inventory.items():
272+
print(f"Inventory {resource}: {amount}")
271273
```
272274

273275
## Event Handling

packages/mettagrid/python/src/mettagrid/config/id_map.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,20 @@
1515
from mettagrid.config.mettagrid_config import GameConfig
1616

1717

18-
def num_inventory_tokens_needed(max_inventory_value: int, token_value_max: int) -> int:
18+
def num_inventory_tokens_needed(max_inventory_value: int, token_value_base: int) -> int:
1919
"""Calculate how many tokens are needed to encode max_inventory_value with given base.
2020
2121
Args:
2222
max_inventory_value: Maximum inventory value to encode (e.g., 65535 for uint16_t)
23-
token_value_max: Maximum value per token (base for encoding)
23+
token_value_base: Base for encoding (value per token: 0 to base-1)
2424
2525
Returns:
2626
Number of tokens needed
2727
"""
2828
if max_inventory_value == 0:
2929
return 1
3030
# Need ceil(log_base(max_value + 1)) tokens
31-
return math.ceil(math.log(max_inventory_value + 1, token_value_max))
31+
return math.ceil(math.log(max_inventory_value + 1, token_value_base))
3232

3333

3434
class ObservationFeatureSpec(BaseModel):
@@ -137,14 +137,14 @@ def _compute_features(self) -> list[ObservationFeatureSpec]:
137137
feature_id += 1
138138

139139
# Inventory features using multi-token encoding with configurable base
140-
# inv:{resource} = amount % token_value_max (always emitted)
141-
# inv:{resource}:p1 = (amount / token_value_max) % token_value_max (emitted if amount >= token_value_max)
142-
# inv:{resource}:p2 = (amount / token_value_max^2) % token_value_max (emitted if amount >= token_value_max^2)
140+
# inv:{resource} = amount % token_value_base (always emitted)
141+
# inv:{resource}:p1 = (amount / token_value_base) % token_value_base (emitted if amount >= token_value_base)
142+
# inv:{resource}:p2 = (amount / token_value_base^2) % token_value_base (emitted if amount >= token_value_base^2)
143143
# etc.
144144
# Number of tokens is computed based on max uint16_t value (65535)
145-
token_value_max = self._config.obs.token_value_max
146-
num_inv_tokens = num_inventory_tokens_needed(65535, token_value_max)
147-
normalization = float(token_value_max)
145+
token_value_base = self._config.obs.token_value_base
146+
num_inv_tokens = num_inventory_tokens_needed(65535, token_value_base)
147+
normalization = float(token_value_base)
148148
for resource_name in self._config.resource_names:
149149
# Base token (always present)
150150
name = f"inv:{resource_name}"

packages/mettagrid/python/src/mettagrid/config/mettagrid_c_config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,7 @@ def convert_to_cpp_game_config(mettagrid_config: dict | GameConfig):
348348
if name in resource_name_to_id
349349
}
350350
limit_defs.append(
351-
CppLimitDef(
352-
resources=resource_ids, base_limit=min(resource_limit.limit, 255), modifiers=modifier_ids
353-
)
351+
CppLimitDef(resources=resource_ids, base_limit=resource_limit.limit, modifiers=modifier_ids)
354352
)
355353

356354
inventory_config = CppInventoryConfig()
@@ -383,7 +381,7 @@ def convert_to_cpp_game_config(mettagrid_config: dict | GameConfig):
383381
game_cpp_params["obs_width"] = obs_config["width"]
384382
game_cpp_params["obs_height"] = obs_config["height"]
385383
game_cpp_params["num_observation_tokens"] = obs_config["num_tokens"]
386-
game_cpp_params["token_value_max"] = obs_config.get("token_value_max", 255)
384+
game_cpp_params["token_value_base"] = obs_config.get("token_value_base", 256)
387385
# Note: token_dim is not used by C++ GameConfig, it's only used in Python
388386

389387
# Convert observation features from Python to C++

0 commit comments

Comments
 (0)