perf: training SPS goes from 190k -> 250k SPS in machina_1 #4526
Conversation
A reviewer quoted this snippet from the diff:

```python
# ComponentPolicy reshapes the TD after training forward based on td["batch"] and td["bptt"]
# The reshaping happens in ComponentPolicy.forward() after forward_training()
if "batch" in td.keys() and "bptt" in td.keys():
```
is this part intended as an addition?
yup! we don't need the `.item()` calls and reshapes; the `.item()` calls cause synchronization between the GPUs, and actor heads get called A LOT
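For context on why those `.item()` calls hurt: reading a scalar off a CUDA tensor copies it to the host and blocks until the GPU stream drains. A minimal illustration (hypothetical accessor, not the actual actor code):

```python
import torch

def read_shape_with_sync(td: dict) -> tuple[int, int]:
    # .item() performs a device-to-host copy and synchronizes the stream;
    # doing this inside an actor head that runs every forward stalls training.
    batch = td["batch"].item()
    bptt = td["bptt"].item()
    return batch, bptt

td = {"batch": torch.tensor(1024), "bptt": torch.tensor(16)}
print(read_shape_with_sync(td))  # cheap on CPU; costly when the tensors live on a GPU
```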
perf: reduce training overhead in machina_1
Summary
- Gate `full_log_probs` emission when no losses need it (default stays on).
- Update the `machina_1` recipe to disable deterministic torch and gate full-log-prob emission based on enabled losses.
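A minimal sketch of the gate, assuming a simplified actor head; only `ActionProbsConfig.emit_full_log_probs` and the `full_log_probs` key come from this PR, the rest is illustrative:

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch.distributions import Categorical


@dataclass
class ActionProbsConfig:
    emit_full_log_probs: bool = True  # default keeps prior behavior


def actor_head(logits: torch.Tensor, td: dict, cfg: ActionProbsConfig) -> dict:
    """Simplified output step, not the real ComponentPolicy code."""
    log_probs = F.log_softmax(logits, dim=-1)
    action = Categorical(logits=logits).sample()
    td["actions"] = action
    td["act_log_prob"] = log_probs.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    if cfg.emit_full_log_probs:
        # Only materialize the full distribution when a loss consumes it.
        td["full_log_probs"] = log_probs
    return td
```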
Motivation / Background
Training throughput was dominated by repeated small allocations and avoidable data movement:
- Repeated allocation of small `batch`/`bptt` metadata tensors.
- Device syncs in `Experience.store()` when reading `t_in_row` (see the sketch below).
- Emission of `full_log_probs` even when no loss consumes them.

This PR focuses on removing those sources of overhead while keeping behavior stable by default.
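A sketch of the CPU-mirror fix for the store-path syncs; only the `t_in_row`/`t_in_row_cpu` names come from this PR, the class is simplified:

```python
import torch


class ExperienceSketch:
    """Illustrative buffer: mirrors write positions on the host so store()
    never needs a .item() read from a CUDA tensor."""

    def __init__(self, rows: int, device: torch.device):
        self.t_in_row = torch.zeros(rows, dtype=torch.int64, device=device)
        self.t_in_row_cpu = torch.zeros(rows, dtype=torch.int64)  # host mirror

    def store(self, row_id: torch.Tensor, row_id_cpu: torch.Tensor) -> None:
        # Advance both copies; host-side queries hit the mirror and never
        # synchronize with the GPU stream.
        self.t_in_row[row_id] += 1
        self.t_in_row_cpu[row_id_cpu] += 1
```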
Detailed Changes (by file)
`agent/src/metta/agent/components/actor.py`

- New flag `ActionProbsConfig.emit_full_log_probs: bool = True`.
- `full_log_probs` is now only added to the TensorDict when `emit_full_log_probs=True`.
- `forward_training()` no longer reads `td["batch"]` and `td["bptt"]` via `.item()`; this avoids an extra reshape and potential extra allocation.

`metta/rl/training/core.py`

- Allocate `last_action` once, sized to `experience.total_agents`, on init.
- `row_id`/`t_in_row` are used directly (already on the correct device).
- Reuse `batch` and `bptt` tensors with a per-device cache.
- Avoid repeated `torch.full`/`ones` allocations during rollout.
- Use `zero_grad(set_to_none=True)` for lower memory traffic.

`metta/rl/training/experience.py`

- Cast `t_in_row`/`row_slot_ids` to `int64` and add a `t_in_row_cpu` mirror; the mirror avoids `.item()` reads that would force device syncs, and `int64` matches torch indexing defaults.
- Reuse the `ones` tensor for sequential sampling.
- `_range_tensor` is now `int64` to align with indexing operations.

`metta/rl/utils.py`

- Cache `batch`/`bptt` tensors in `ensure_sequence_metadata` and `prepare_policy_forward_td` (see the sketch following this list).

`recipes/experiment/machina_1.py`

- Gate `full_log_probs`: when `supervisor`, `sliced_scripted_cloner`, `eer_kickstarter`, and `eer_cloner` are all disabled, set `emit_full_log_probs=False`.
- Set `tt.system.torch_deterministic = False`.
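The per-device cache mentioned for `core.py` and `utils.py` can be as small as a dict keyed by value and device; a hedged sketch (function name hypothetical):

```python
import torch

# Reuse the same scalar metadata tensor across steps instead of
# re-allocating it on every forward.
_META_CACHE: dict[tuple[int, torch.device], torch.Tensor] = {}


def cached_meta_tensor(value: int, device: torch.device) -> torch.Tensor:
    key = (value, device)
    tensor = _META_CACHE.get(key)
    if tensor is None:
        tensor = torch.tensor(value, dtype=torch.int64, device=device)
        _META_CACHE[key] = tensor
    return tensor

# e.g. td["batch"] = cached_meta_tensor(batch_size, obs.device)
```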
Behavior Changes / Compatibility Notes

- `full_log_probs` may be absent from TensorDicts when `emit_full_log_probs=False`. Losses that consume it must run with `emit_full_log_probs=True`; the `machina_1` recipe now does this gating automatically.
- `training_env_ids`, `row_id`, and `t_in_row` are now `int64`. Code that assumed `int32` should still work but will see dtype changes.
- `last_action` is now fixed-length at `total_agents`.
- `machina_1` runs are explicitly non-deterministic now.
Performance Impact (expected)

- Fewer small per-step allocations (`batch`/`bptt`, indices, sampling weights).
- No device syncs in `Experience.store`.
- No emission of `full_log_probs` when unused.
- Lower memory traffic from `set_to_none=True`.
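The `set_to_none` change is the standard torch idiom, shown here on a toy optimizer:

```python
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# set_to_none=True frees the .grad tensors instead of zero-filling them,
# skipping a memset over every parameter each step.
optimizer.zero_grad(set_to_none=True)
```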
Prior numbers (for context)

Training SPS in `machina_1`: roughly 190k before this PR, roughly 250k after (per the PR title).

How to Use / Manual Notes
- To keep emitting full log-probs, set `action_probs_config.emit_full_log_probs = True` in the policy architecture config.
- In `machina_1`, enabling any of `supervisor`, `sliced_scripted_cloner`, `eer_kickstarter`, or `eer_cloner` keeps `emit_full_log_probs=True`; a sketch of this gating follows.
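A sketch of the recipe-side gating (the loss names come from this PR; the recipe wiring is simplified and the function name is hypothetical):

```python
# Losses that consume full_log_probs in machina_1.
LOG_PROB_CONSUMERS = (
    "supervisor",
    "sliced_scripted_cloner",
    "eer_kickstarter",
    "eer_cloner",
)


def should_emit_full_log_probs(enabled_losses: set[str]) -> bool:
    # Emit the full distribution only if some enabled loss needs it.
    return any(name in enabled_losses for name in LOG_PROB_CONSUMERS)

# e.g. policy.action_probs_config.emit_full_log_probs = should_emit_full_log_probs(losses)
```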
Testing

Risks / Follow-ups
- If a new loss consumes `full_log_probs` but is not listed in the `machina_1` gating, it will need to be added.
- If any code relies on `int32` env IDs, it should be updated (though `int64` is standard for indexing).