Conversation

@relh (Contributor) commented Dec 23, 2025

perf: reduce training overhead in machina_1

Summary

  • Make ActionProbs optionally skip full_log_probs emission when no losses need it (default stays on).
  • Remove redundant TensorDict reshaping in ActionProbs; rely on ComponentPolicy’s reshape step.
  • Reduce rollout/training allocations by caching metadata tensors, indices, and buffers.
  • Avoid CPU↔GPU sync hotspots in Experience bookkeeping and re-use sequential minibatch indices.
  • Set machina_1 recipe to disable deterministic torch and gate full-log-prob emission based on enabled losses.

Motivation / Background

Training throughput was dominated by repeated small allocations and avoidable data movement:

  • Per-step creation of batch/bptt metadata tensors.
  • Per-rollout generation of environment indices and sampling weights.
  • Frequent CPU↔device sync in Experience.store() when reading t_in_row.
  • Unconditional creation/storage of full_log_probs even when no loss consumes them.

This PR focuses on removing those sources of overhead while keeping behavior stable by default.

Detailed Changes (by file)

agent/src/metta/agent/components/actor.py

  • New config flag: ActionProbsConfig.emit_full_log_probs: bool = True.
    • Why: Many runs don’t need the full action distribution; skipping it avoids extra softmax/logits storage.
  • ActionProbs behavior: full_log_probs is now only added to the TensorDict when emit_full_log_probs=True.
  • Training path simplification: removed the extra TD reshape at the end of forward_training().
    • Why: ComponentPolicy already reshapes based on td["batch"] and td["bptt"]; this avoids an extra reshape and potential extra allocation.
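
The gating and the removed trailing reshape might look roughly like the sketch below. This is a minimal illustration assuming TensorDict-style modules; keys such as "logits", "actions", and "act_log_prob" are placeholders, not necessarily the repo's actual field names.

```python
import torch
import torch.nn.functional as F
from tensordict import TensorDict


class ActionProbs(torch.nn.Module):
    def __init__(self, emit_full_log_probs: bool = True):
        super().__init__()
        self.emit_full_log_probs = emit_full_log_probs

    def forward_training(self, td: TensorDict) -> TensorDict:
        log_probs = F.log_softmax(td["logits"], dim=-1)
        actions = td["actions"].long()
        # Always emit the per-action log prob that PPO-style losses need.
        td["act_log_prob"] = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        # Only materialize the full distribution when an enabled loss consumes it.
        if self.emit_full_log_probs:
            td["full_log_probs"] = log_probs
        # No trailing reshape here: ComponentPolicy already reshapes the TD
        # using td["batch"] and td["bptt"] after forward_training() returns.
        return td
```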

metta/rl/training/core.py

  • Preallocate last_action once to experience.total_agents on init.
    • Why: avoid dynamic growth and per‑rollout reallocation.
  • Env index caching: use the prebuilt range tensor directly without per‑rollout device moves.
  • Row metadata: row_id / t_in_row are used directly (already on the correct device).
  • Rollout metadata caching: inline creation of batch and bptt tensors with a per‑device cache.
    • Why: removes repeated torch.full/ones allocations during rollout.
  • Optimizer step: zero_grad(set_to_none=True) for lower memory traffic.
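
Taken together, the buffer-reuse and gradient-zeroing changes follow a pattern like this hedged sketch; Trainer, experience.total_agents, and the optimizer choice are stand-ins, not the real core.py structure.

```python
import torch


class Trainer:
    def __init__(self, experience, policy, device: torch.device):
        self.device = device
        # Preallocate once to the full agent count instead of growing per rollout.
        self.last_action = torch.zeros(
            experience.total_agents, dtype=torch.int64, device=device
        )
        # Prebuilt env-index range, reused every rollout with no device moves.
        self.env_ids = torch.arange(experience.total_agents, device=device)
        self.optimizer = torch.optim.Adam(policy.parameters())

    def optimizer_step(self, loss: torch.Tensor) -> None:
        # set_to_none=True frees grad tensors instead of writing zeros,
        # cutting memory traffic before the next backward pass.
        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        self.optimizer.step()
```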

metta/rl/training/experience.py

  • Duplicate-key helper removed: the check is now inlined, reducing indirection and avoiding an extra list allocation.
  • Row tracking: switch t_in_row/row_slot_ids to int64 and add t_in_row_cpu mirror.
    • Why: use CPU mirror for .item() reads without device syncs; int64 matches torch indexing defaults.
  • Sequential sampling cache: precompute per‑minibatch sequential index tensors.
  • Priority weights reuse: preallocate ones tensor for sequential sampling.
  • Range tensor: _range_tensor now int64 to align with indexing operations.
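
The sync-avoidance and index-caching ideas combine roughly as follows. This is an illustrative sketch, not the actual Experience implementation; the method names are invented.

```python
import torch


class Experience:
    def __init__(self, num_rows: int, device: torch.device):
        # int64 matches torch's default indexing dtype.
        self.t_in_row = torch.zeros(num_rows, dtype=torch.int64, device=device)
        # CPU mirror updated in lockstep, so scalar reads never trigger the
        # device->host synchronization that t_in_row[row].item() would.
        self.t_in_row_cpu = torch.zeros(num_rows, dtype=torch.int64)
        self._seq_minibatch_cache: dict[int, torch.Tensor] = {}

    def advance_row(self, row: int) -> int:
        self.t_in_row[row] += 1
        self.t_in_row_cpu[row] += 1
        return int(self.t_in_row_cpu[row])  # sync-free scalar read

    def sequential_indices(self, mb: int, mb_size: int) -> torch.Tensor:
        # Built once per minibatch position, then reused every epoch.
        if mb not in self._seq_minibatch_cache:
            start = mb * mb_size
            self._seq_minibatch_cache[mb] = torch.arange(
                start, start + mb_size, dtype=torch.int64
            )
        return self._seq_minibatch_cache[mb]
```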

metta/rl/utils.py

  • Global cache for batch/bptt tensors.
    • Used by ensure_sequence_metadata and prepare_policy_forward_td.
    • Why: avoid per‑call allocations for common shapes.
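
A global cache for these tiny constant tensors could be as simple as the sketch below; cached_meta_tensor is a hypothetical name, and callers must treat the returned tensors as read-only since they are shared.

```python
import torch

# Metadata tensors like td["batch"] and td["bptt"] are identical across calls
# for a given value/length/device, so build each one exactly once.
_META_CACHE: dict[tuple, torch.Tensor] = {}


def cached_meta_tensor(value: int, length: int, device: torch.device) -> torch.Tensor:
    key = (value, length, str(device))
    tensor = _META_CACHE.get(key)
    if tensor is None:
        tensor = torch.full((length,), value, dtype=torch.int64, device=device)
        _META_CACHE[key] = tensor
    return tensor
```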

recipes/experiment/machina_1.py

  • Auto‑gate full_log_probs: when supervisor, sliced_scripted_cloner, eer_kickstarter, and eer_cloner are all disabled, set emit_full_log_probs=False.
  • Explicitly disable deterministic torch: tt.system.torch_deterministic = False.
    • Why: speed; deterministic settings can slow kernels and limit algorithm choices.
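
Put together, the recipe-level gating reads something like this hypothetical sketch; the cfg field paths are invented, and only tt.system.torch_deterministic appears verbatim in the change.

```python
# Hypothetical field paths; the real recipe config will differ.
teacher_losses_enabled = any([
    cfg.losses.supervisor.enabled,
    cfg.losses.sliced_scripted_cloner.enabled,
    cfg.losses.eer_kickstarter.enabled,
    cfg.losses.eer_cloner.enabled,
])
# Emit full log probs only when a teacher/supervised loss will consume them.
cfg.policy.action_probs_config.emit_full_log_probs = teacher_losses_enabled

# Trade determinism for speed: deterministic kernels are often slower and
# restrict cuDNN algorithm selection.
tt.system.torch_deterministic = False
```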

Behavior Changes / Compatibility Notes

  • full_log_probs may be absent from TensorDicts when emit_full_log_probs=False.
    • Losses that consume it require emit_full_log_probs=True; the machina_1 recipe now applies this gating automatically.
  • training_env_ids, row_id, and t_in_row are now int64.
    • This is aligned with torch indexing; any downstream code assuming int32 should still work but will see dtype changes.
  • last_action is now fixed-length at total_agents.
    • This assumes the agent count is stable (consistent with current training setup).
  • machina_1 runs are explicitly non‑deterministic now.

Performance Impact (expected)

  • Fewer allocations per rollout step (batch/bptt, indices, sampling weights).
  • Reduced CPU↔device synchronization in Experience.store.
  • Lower memory/compute by skipping full_log_probs when unused.
  • Lower gradient‑zeroing overhead with set_to_none=True.

Prior numbers (for context)

  • Previous perf runs (pre‑revert): e2–e5 avg ~243.6 ksps (3× 10‑min runs)
  • main baseline (same setup): e2–e5 avg ~199.3 ksps (3× 10‑min runs)

How to Use / Manual Notes

  • To force full log probs for a run:
    • Ensure action_probs_config.emit_full_log_probs = True in the policy architecture config.
  • To skip full log probs in machina_1:
    • Ensure none of the teacher/supervised losses are enabled (the recipe now auto‑disables emission).
  • If you enable any of supervisor, sliced_scripted_cloner, eer_kickstarter, or eer_cloner, you must keep emit_full_log_probs=True (see the snippet below).
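
For example, forcing emission explicitly might look like this hypothetical config line, following the flag name above:

```python
# Keep the full action distribution available for teacher/supervised losses.
policy_cfg.action_probs_config.emit_full_log_probs = True
```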

Testing

  • Tests were not rerun after reverting the stats/monitor cadence changes.
  • Perf numbers listed above are from earlier runs.

Risks / Follow‑ups

  • If any custom loss consumes full_log_probs but is not listed in machina_1 gating, it will need to be added.
  • If any downstream code strictly expects int32 env IDs, it should be updated (though int64 is standard for indexing).

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


@relh force-pushed the richard-perf1223-fixes branch from da67e82 to 1da8ffb December 24, 2025 00:09
@relh (Contributor, Author) commented Dec 24, 2025

This stack of pull requests is managed by Graphite. Learn more about stacking.

@datadog-official bot commented Dec 24, 2025

✅ Tests

🎉 All green!

❄️ No new flaky tests detected
🧪 All tests passed

🔗 Commit SHA: 63f25b7

@relh force-pushed the richard-perf1223-fixes branch from 99e6dd7 to bae2839 December 27, 2025 03:53
@relh force-pushed the richard-perf1223-fixes branch from d3ea41d to 13cd08d December 28, 2025 16:21
@relh changed the title from "fix training pipeline performance issues" to "Perf: reduce training overhead in machina_1" Dec 28, 2025
@relh changed the title from "Perf: reduce training overhead in machina_1" to "perf: reduce training overhead in machina_1" Dec 28, 2025
@relh enabled auto-merge December 28, 2025 20:58
@relh added the "review wanted: stamp" label (This PR needs a review from any available team member) Dec 28, 2025

Review thread on this quoted snippet:

```python
# ComponentPolicy reshapes the TD after training forward based on td["batch"] and td["bptt"]
# The reshaping happens in ComponentPolicy.forward() after forward_training()
if "batch" in td.keys() and "bptt" in td.keys():
```
A reviewer (Contributor) commented:

is this part intended as an addition?

@relh (Contributor, Author) replied:

yup! we don't need the .item() calls and reshapes; the .item() calls cause synchronization between the GPUs, and actor heads get called A LOT

@relh changed the title from "perf: reduce training overhead in machina_1" to "perf: training SPS goes from 190k -> 250k SPS in machina_1" Dec 29, 2025
@relh added this pull request to the merge queue Dec 29, 2025
Merged via the queue into main with commit 2252ceb Dec 29, 2025
30 checks passed
@relh deleted the richard-perf1223-fixes branch December 29, 2025 17:02