@@ -51,6 +51,8 @@ def __init__(
 
         # Row-aligned tracking (per-agent row slot id and position within row)
         self.t_in_row = torch.zeros(total_agents, device=self.device, dtype=torch.int32)
+        # Keep a CPU mirror to avoid GPU syncs for scalar reads.
+        self._t_in_row_cpu = torch.zeros(total_agents, device="cpu", dtype=torch.int32)
         self.row_slot_ids = torch.arange(total_agents, device=self.device, dtype=torch.int32) % self.segments
         self.free_idx = total_agents % self.segments
 
@@ -114,7 +116,7 @@ def store(self, data_td: TensorDict, env_id: slice) -> None:
         assert isinstance(env_id, slice), (
             f"TypeError: env_id expected to be a slice for segmented storage. Got {type(env_id).__name__} instead."
         )
-        t_in_row_val = self.t_in_row[env_id.start].item()
+        t_in_row_val = int(self._t_in_row_cpu[env_id.start].item())
         row_ids = self.row_slot_ids[env_id]
 
         # Scheduler updates these keys based on the active losses for the epoch.
@@ -124,6 +126,7 @@ def store(self, data_td: TensorDict, env_id: slice) -> None:
             raise ValueError("No store keys set. set_store_keys() was likely used incorrectly.")
 
         self.t_in_row[env_id] += 1
+        self._t_in_row_cpu[env_id] += 1
 
         if t_in_row_val + 1 >= self.bptt_horizon:
             self._reset_completed_episodes(env_id)
@@ -133,6 +136,7 @@ def _reset_completed_episodes(self, env_id) -> None:
         num_full = env_id.stop - env_id.start
         self.row_slot_ids[env_id] = (self.free_idx + self._range_tensor[:num_full]) % self.segments
         self.t_in_row[env_id] = 0
+        self._t_in_row_cpu[env_id] = 0
         self.free_idx = (self.free_idx + num_full) % self.segments
         self.full_rows += num_full
 
@@ -142,6 +146,7 @@ def reset_for_rollout(self) -> None:
         self.free_idx = self.total_agents % self.segments
         self.row_slot_ids = self._range_tensor % self.segments
         self.t_in_row.zero_()
+        self._t_in_row_cpu.zero_()
 
     def update(self, indices: Tensor, data_td: TensorDict) -> None:
         """Update buffer with new data for given indices."""
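
The change keeps a CPU mirror of t_in_row because calling .item() on a CUDA tensor blocks the host until pending GPU work finishes, while the same scalar read from a CPU tensor updated in lockstep does not. A minimal, self-contained sketch of that pattern follows; the names (counts_gpu, counts_cpu) are illustrative only and are not taken from this PR.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    counts_gpu = torch.zeros(8, device=device, dtype=torch.int32)
    counts_cpu = torch.zeros(8, device="cpu", dtype=torch.int32)  # host-side mirror

    # Update both tensors in lockstep; the mirror update is a cheap CPU op.
    counts_gpu[2:5] += 1
    counts_cpu[2:5] += 1

    # Scalar read: counts_gpu[2].item() would force a host-device sync,
    # whereas reading the mirror returns the same value without one.
    val = int(counts_cpu[2].item())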