Commit bae2839

Reduce hot-loop syncs in rollout and action sampling

Parent: 3f1583e

5 files changed: +22 −8 lines

agent/src/metta/agent/util/distribution_utils.py

Lines changed: 4 additions & 6 deletions
@@ -31,9 +31,8 @@ def sample_actions(action_logits: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tenso
     # Sample actions from categorical distribution (replacement=True is implicit when num_samples=1)
     actions = torch.multinomial(action_probs, num_samples=1).view(-1)  # [batch_size]
 
-    # Extract log-probabilities for sampled actions using advanced indexing
-    batch_indices = torch.arange(actions.shape[0], device=actions.device)
-    act_log_prob = full_log_probs[batch_indices, actions]  # [batch_size]
+    # Extract log-probabilities for sampled actions without arange churn
+    act_log_prob = full_log_probs.gather(1, actions.view(-1, 1)).squeeze(1)  # [batch_size]
 
     # Compute policy entropy: H(π) = -∑π(a|s)log π(a|s)
     entropy = -torch.sum(action_probs * full_log_probs, dim=-1)  # [batch_size]

@@ -65,9 +64,8 @@ def evaluate_actions(action_logits: Tensor, actions: Tensor) -> Tuple[Tensor, Te
     action_log_probs = F.log_softmax(action_logits, dim=-1)  # [batch_size, num_actions]
     action_probs = torch.exp(action_log_probs)  # [batch_size, num_actions]
 
-    # Extract log-probabilities for the provided actions using advanced indexing
-    batch_indices = torch.arange(actions.shape[0], device=actions.device)
-    log_probs = action_log_probs[batch_indices, actions]  # [batch_size]
+    # Extract log-probabilities for the provided actions without arange churn
+    log_probs = action_log_probs.gather(1, actions.view(-1, 1)).squeeze(1)  # [batch_size]
 
     # Compute policy entropy: H(π) = -∑π(a|s)log π(a|s)
     entropy = -torch.sum(action_probs * action_log_probs, dim=-1)  # [batch_size]
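
For context, a quick standalone check (not part of the commit) that the gather form returns the same values as the arange-based advanced indexing it replaces, just without allocating a fresh index tensor on the device every call:

import torch

logits = torch.randn(4, 6)                     # [batch_size, num_actions]
log_probs = torch.log_softmax(logits, dim=-1)
actions = torch.randint(0, 6, (4,))            # one action index per row

# Old approach: build an arange helper tensor, then advanced-index.
batch_indices = torch.arange(actions.shape[0], device=actions.device)
via_indexing = log_probs[batch_indices, actions]

# New approach: gather along the action dimension, no helper tensor.
via_gather = log_probs.gather(1, actions.view(-1, 1)).squeeze(1)

assert torch.allclose(via_indexing, via_gather)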

metta/rl/loss/loss.py

Lines changed: 10 additions & 1 deletion
@@ -145,6 +145,7 @@ class Loss:
     loss_tracker: dict[str, list[float]] | None = None
     _zero_tensor: Tensor | None = None
     _context: ComponentContext | None = None
+    _metric_mb_idx: int | None = field(default=None, init=False, repr=False)
 
     _state_attrs: set[str] = field(default_factory=set, init=False, repr=False)
 

@@ -205,7 +206,11 @@ def train(
         ctx = self._ensure_context(context)
         if not self._loss_gate_allows("train", ctx):
             return self._zero(), shared_loss_data, False
-        return self.run_train(shared_loss_data, ctx, mb_idx)
+        self._metric_mb_idx = mb_idx
+        try:
+            return self.run_train(shared_loss_data, ctx, mb_idx)
+        finally:
+            self._metric_mb_idx = None
 
     def run_train(
         self,

@@ -248,6 +253,10 @@ def stats(self) -> dict[str, float]:
 
     def track_metric(self, key: str, value: Tensor | float) -> None:
         """Track a scalar metric."""
+        interval = getattr(self.trainer_cfg, "loss_metric_interval", 1)
+        if interval > 1 and self._metric_mb_idx is not None:
+            if self._metric_mb_idx % interval != 0:
+                return
         _track_metric(self.loss_tracker, key, value)
 
     def metric_mean(self, key: str) -> float:
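
A minimal sketch (hypothetical should_record helper, not repo code) of the gating that track_metric now applies when loss_metric_interval is above 1:

def should_record(mb_idx: int | None, interval: int) -> bool:
    """Record every minibatch at interval 1, otherwise only every interval-th one."""
    if interval > 1 and mb_idx is not None:
        return mb_idx % interval == 0
    return True

assert should_record(0, 8)       # minibatch 0 is always recorded
assert not should_record(3, 8)   # off-interval minibatches are skipped
assert should_record(5, 1)       # interval of 1 preserves the previous behaviour
assert should_record(None, 8)    # calls outside train() (no mb_idx set) still record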

metta/rl/trainer_config.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ class TrainerConfig(Config):
 
     # Debug/perf toggles.
     synchronize_after_optimizer_step: bool = False
+    loss_metric_interval: int = Field(default=1, ge=1)
     update_epochs: int = Field(default=1, gt=0)
     scale_batches_by_world_size: bool = False
 
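
As a small illustration (standalone model, not the actual TrainerConfig), the ge=1 bound means Pydantic rejects zero or negative intervals at construction time:

from pydantic import BaseModel, Field, ValidationError

class _Cfg(BaseModel):
    loss_metric_interval: int = Field(default=1, ge=1)

assert _Cfg().loss_metric_interval == 1                        # default keeps per-minibatch metrics
assert _Cfg(loss_metric_interval=8).loss_metric_interval == 8
try:
    _Cfg(loss_metric_interval=0)
except ValidationError:
    print("values below 1 are rejected")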

metta/rl/training/experience.py

Lines changed: 6 additions & 1 deletion
@@ -51,6 +51,8 @@ def __init__(
 
         # Row-aligned tracking (per-agent row slot id and position within row)
         self.t_in_row = torch.zeros(total_agents, device=self.device, dtype=torch.int32)
+        # Keep a CPU mirror to avoid GPU syncs for scalar reads.
+        self._t_in_row_cpu = torch.zeros(total_agents, device="cpu", dtype=torch.int32)
         self.row_slot_ids = torch.arange(total_agents, device=self.device, dtype=torch.int32) % self.segments
         self.free_idx = total_agents % self.segments
 

@@ -114,7 +116,7 @@ def store(self, data_td: TensorDict, env_id: slice) -> None:
         assert isinstance(env_id, slice), (
             f"TypeError: env_id expected to be a slice for segmented storage. Got {type(env_id).__name__} instead."
         )
-        t_in_row_val = self.t_in_row[env_id.start].item()
+        t_in_row_val = int(self._t_in_row_cpu[env_id.start].item())
         row_ids = self.row_slot_ids[env_id]
 
         # Scheduler updates these keys based on the active losses for the epoch.

@@ -124,6 +126,7 @@ def store(self, data_td: TensorDict, env_id: slice) -> None:
             raise ValueError("No store keys set. set_store_keys() was likely used incorrectly.")
 
         self.t_in_row[env_id] += 1
+        self._t_in_row_cpu[env_id] += 1
 
         if t_in_row_val + 1 >= self.bptt_horizon:
             self._reset_completed_episodes(env_id)

@@ -133,6 +136,7 @@ def _reset_completed_episodes(self, env_id) -> None:
         num_full = env_id.stop - env_id.start
         self.row_slot_ids[env_id] = (self.free_idx + self._range_tensor[:num_full]) % self.segments
         self.t_in_row[env_id] = 0
+        self._t_in_row_cpu[env_id] = 0
         self.free_idx = (self.free_idx + num_full) % self.segments
         self.full_rows += num_full
 

@@ -142,6 +146,7 @@ def reset_for_rollout(self) -> None:
         self.free_idx = self.total_agents % self.segments
         self.row_slot_ids = self._range_tensor % self.segments
         self.t_in_row.zero_()
+        self._t_in_row_cpu.zero_()
 
     def update(self, indices: Tensor, data_td: TensorDict) -> None:
         """Update buffer with new data for given indices."""

recipes/experiment/machina_1.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def train(
         teacher=teacher,
     )
     tt.policy_architecture = policy_architecture or ViTDefaultConfig()
+    tt.trainer.loss_metric_interval = 8
 
     # Explicitly keep full vibe/action definitions so saved checkpoints remain compatible.
    env_cfg = tt.training_env.curriculum.task_generator.env

0 commit comments
