diff --git a/amago/agent.py b/amago/agent.py
index a880595..01abf7a 100644
--- a/amago/agent.py
+++ b/amago/agent.py
@@ -4,6 +4,7 @@
 import abc
 import itertools
+from functools import partial
 from typing import Type, Optional, Tuple, Any, List, Iterable
 
 import torch
 import torch.nn.functional as F
@@ -200,6 +201,83 @@ def exp_filter(
     return weights
 
 
+# compiled within the agent
+def nstep_return(
+    r: torch.Tensor,
+    d: torch.Tensor,
+    q: torch.Tensor,
+    gamma: torch.Tensor,
+    n: int,
+    mask: torch.Tensor,
+) -> torch.Tensor:
+    """Compute n-step TD targets.
+
+    Args:
+        r: Rewards, shape ``(B, L, 1, G, 1)``. Already scaled by reward_multiplier.
+        d: Dones (float 0 or 1), shape ``(B, L, 1, G, 1)``.
+        q: Bootstrap Q-values (ensemble-reduced), shape ``(B, L, 1, G, 1)``.
+            ``q[b, t]`` = Q(s_{t+1}, a'_{t+1}), already aligned as next-state values.
+        gamma: Discount factors, shape ``(G, 1)``.
+        n: N-step horizon. n=1 recovers the standard 1-step Bellman target.
+        mask: Valid timestep mask (1 = valid, 0 = padding), shape ``(B, L, 1, 1, 1)``.
+
+    Returns:
+        N-step TD targets, shape ``(B, L, 1, G, 1)``.
+    """
+    B, L, _one, G, _one2 = r.shape
+    device = r.device
+
+    gamma_pow = gamma.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # (1, 1, 1, G, 1)
+
+    r_masked = r * mask
+    d_masked = d * mask
+
+    cumulative_reward = torch.zeros(B, L, 1, G, 1, device=device, dtype=r.dtype)
+    survival = torch.ones(B, L, 1, G, 1, device=device, dtype=r.dtype)
+
+    last_survival = torch.ones(B, L, 1, G, 1, device=device, dtype=r.dtype)
+    last_q_idx = torch.zeros(B, L, 1, G, 1, device=device, dtype=torch.long)
+    reached_n = torch.zeros(B, L, 1, G, 1, device=device, dtype=torch.bool)
+
+    for i in range(n):
+        if i == 0:
+            r_i = r_masked
+            d_i = d_masked
+            valid_i = mask
+        else:
+            r_i = F.pad(r_masked[:, i:, ...], (0, 0, 0, 0, 0, 0, 0, i), value=0.0)
+            d_i = F.pad(d_masked[:, i:, ...], (0, 0, 0, 0, 0, 0, 0, i), value=0.0)
+            valid_i = F.pad(mask[:, i:, ...], (0, 0, 0, 0, 0, 0, 0, i), value=0.0)
+
+        step_valid = valid_i * (~reached_n).float()
+        cumulative_reward += (gamma_pow**i) * survival * r_i * step_valid
+
+        new_survival = survival * (1.0 - d_i)
+
+        active = step_valid.bool()
+        last_survival = torch.where(active, new_survival, last_survival)
+        last_q_idx = torch.where(active, torch.tensor(i, device=device), last_q_idx)
+
+        done_here = (d_i > 0.5) & active
+        invalid_here = (valid_i < 0.5) & (~reached_n)
+        reached_n = reached_n | done_here | invalid_here
+
+        survival = new_survival
+
+    reached_n[:] = True
+
+    k = last_q_idx + 1
+    gamma_k = gamma_pow ** k.float()
+
+    t_indices = torch.arange(L, device=device).view(1, L, 1, 1, 1).expand(B, L, 1, G, 1)
+    boot_indices = (t_indices + last_q_idx).clamp(max=L - 1)
+
+    q_bootstrap = torch.gather(q, dim=1, index=boot_indices)
+    bootstrap = gamma_k * last_survival * q_bootstrap
+
+    return cumulative_reward + bootstrap
+
+
 #####################
 ## Built-in Agents ##
 #####################
@@ -495,6 +573,10 @@ def V(state, critic, action_dist, k) -> float:
         use_target_actor: If True, use a target actor to sample actions used in TD targets.
             Defaults to True.
         use_multigamma: If True, train on multiple discount horizons
             (:py:class:`~amago.agent.Multigammas`) in parallel. Defaults to True.
+        n_step: N-step horizon for TD targets. n=1 gives standard 1-step Bellman targets.
+            Higher values use :py:func:`~amago.agent.nstep_return` to sum discounted
+            rewards over ``n`` future steps before bootstrapping with Q(s_{t+n}).
+            Defaults to 1.
         actor_type: Actor MLP head for producing action distributions.
             Defaults to :py:class:`~amago.nets.actor_critic.Actor`.
         critic_type: Critic MLP head for producing Q-values.
             Defaults to :py:class:`~amago.nets.actor_critic.NCritics`.
         pass_obs_keys_to_actor: List of keys from the observation space to pass directly
             to the actor network's forward pass if needed for some reason (e.g., for
             masking actions). Defaults to None.
@@ -523,6 +605,7 @@ def __init__(
         popart: bool = True,
         use_target_actor: bool = True,
         use_multigamma: bool = True,
+        n_step: int = 1,
         actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
         critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCritics,
         pass_obs_keys_to_actor: Optional[Iterable[str]] = None,
@@ -546,6 +629,11 @@ def __init__(
         self.critic_loss_weight = critic_loss_weight
         self.tau = tau
         self.use_target_actor = use_target_actor
+        assert n_step >= 1, f"n_step must be >= 1, got {n_step}"
+        self.n_step = n_step
+        self._nstep_fn = torch.compile(
+            partial(nstep_return, n=n_step), mode="reduce-overhead"
+        )
         multigammas = (
             Multigammas().discrete if self.discrete else Multigammas().continuous
         )
@@ -663,23 +751,23 @@ def get_actions(
         dtype = torch.uint8 if (self.discrete or self.multibinary) else torch.float32
         return actions.to(dtype=dtype), hidden_state
 
-    def _critic_ensemble_to_td_target(self, ensemble_td_target: torch.Tensor):
-        B, L, C, G, _ = ensemble_td_target.shape
+    def _reduce_critic_ensemble(self, q_ensemble: torch.Tensor) -> torch.Tensor:
+        B, L, C, G, _ = q_ensemble.shape
         # random subset of critic ensemble
         random_subset = torch.randint(
             low=0,
             high=C,
             size=(B, L, self.num_critics_td, G, 1),
-            device=ensemble_td_target.device,
+            device=q_ensemble.device,
         )
-        td_target_rand = torch.take_along_dim(ensemble_td_target, random_subset, dim=2)
+        q_subset = torch.take_along_dim(q_ensemble, random_subset, dim=2)
         if self.online_coeff > 0:
             # clipped double q
-            td_target = td_target_rand.min(2, keepdims=True).values
+            q_reduced = q_subset.min(2, keepdims=True).values
         else:
             # without DPG updates the usual min creates strong underestimation. take mean instead
-            td_target = td_target_rand.mean(2, keepdims=True)
-        return td_target
+            q_reduced = q_subset.mean(2, keepdims=True)
+        return q_reduced
 
     def _compute_loss(
         self,
@@ -809,10 +897,10 @@ def forward(self, batch: Batch, log_step: bool) -> torch.Tensor:
         # Q_target(s', a')
         q_targ_sp_ap_gp = self.popart(self.target_critics(*sp_ap_gp).mean(0), normalized=False)
         assert q_targ_sp_ap_gp.shape == (B, L - 1, C, G, 1)
-        # y = r + gamma * (1 - d) * Q_target(s', a')
-        ensemble_td_target = r + gamma * (1.0 - d) * q_targ_sp_ap_gp
-        assert ensemble_td_target.shape == (B, L - 1, C, G, 1)
-        td_target = self._critic_ensemble_to_td_target(ensemble_td_target)
+        q_reduced = self._reduce_critic_ensemble(q_targ_sp_ap_gp)
+        assert q_reduced.shape == (B, L - 1, 1, G, 1)
+        nstep_mask = state_mask.float().unsqueeze(-1).unsqueeze(-1)  # (B, L-1, 1, 1, 1)
+        td_target = self._nstep_fn(r, d, q_reduced, gamma, mask=nstep_mask)
         assert td_target.shape == (B, L - 1, 1, G, 1)
         self.popart.update_stats(
             td_target, mask=critic_mask.all(2, keepdim=True)
@@ -1061,6 +1149,7 @@ def __init__(
         popart: bool = True,
         use_target_actor: bool = True,
         use_multigamma: bool = True,
+        n_step: int = 1,
         actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
         critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCriticsTwoHot,
         pass_obs_keys_to_actor: Optional[Iterable[str]] = None,
@@ -1087,6 +1176,7 @@ def __init__(
             critic_loss_weight=critic_loss_weight,
             use_target_actor=use_target_actor,
             use_multigamma=use_multigamma,
+            n_step=n_step,
             fbc_filter_func=fbc_filter_func,
             popart=popart,
             actor_type=actor_type,
@@ -1168,10 +1258,10 @@ def forward(self, batch: Batch, log_step: bool):
         assert q_targ_sp_ap_gp.probs.shape == (K_c, B, L - 1, C, G, Bins)
         q_targ_sp_ap_gp = self.target_critics.bin_dist_to_raw_vals(q_targ_sp_ap_gp).mean(0)
         assert q_targ_sp_ap_gp.shape == (B, L - 1, C, G, 1)
-        # y = r + gamma * (1.0 - d) * Q(s', a')
-        ensemble_td_target = r + gamma * (1.0 - d) * q_targ_sp_ap_gp
-        assert ensemble_td_target.shape == (B, L - 1, C, G, 1)
-        td_target = self._critic_ensemble_to_td_target(ensemble_td_target)
+        q_reduced = self._reduce_critic_ensemble(q_targ_sp_ap_gp)
+        assert q_reduced.shape == (B, L - 1, 1, G, 1)
+        nstep_mask = state_mask.float().unsqueeze(-1).unsqueeze(-1)  # (B, L-1, 1, 1, 1)
+        td_target = self._nstep_fn(r, d, q_reduced, gamma, mask=nstep_mask)
        assert td_target.shape == (B, L - 1, 1, G, 1)
         self.popart.update_stats(
             td_target, mask=critic_mask.all(2, keepdim=True)
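
Review note: with n_step=1 this patch should be behavior-preserving. The old path reduced the ensemble of full targets min_c(r + gamma * (1 - d) * Q_c), which equals r + gamma * (1 - d) * min_c(Q_c) because r, gamma, and d are constant across the critic dimension, so reducing the ensemble first (as _reduce_critic_ensemble now does) and then building the target gives the same numbers. A minimal sanity check of that equivalence, assuming the patched amago.agent is importable (hypothetical reviewer script, not part of the patch; shapes follow the nstep_return docstring):

# Hypothetical check: nstep_return with n=1 reproduces the 1-step Bellman target.
import torch
from amago.agent import nstep_return

B, L, G = 2, 8, 1
gamma = torch.tensor([[0.99]])                 # (G, 1)
r = torch.randn(B, L, 1, G, 1)                 # rewards (already reward-scaled)
d = (torch.rand(B, L, 1, G, 1) < 0.1).float()  # random terminals
q = torch.randn(B, L, 1, G, 1)                 # q[t] = Q(s_{t+1}, a'_{t+1})
mask = torch.ones(B, L, 1, 1, 1)               # no padding

one_step = nstep_return(r, d, q, gamma, n=1, mask=mask)
# y = r + gamma * (1 - d) * Q(s', a'), exactly the target the old code built
assert torch.allclose(one_step, r + gamma * (1.0 - d) * q, atol=1e-6)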
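
For n > 1, the vectorized termination logic in nstep_return can be cross-checked against a naive per-timestep loop. The reference below is a hypothetical sketch for review only (B=1 and a full mask keep the scalar done-check simple): it accumulates discounted rewards until the first terminal or the n-step horizon, then bootstraps with the aligned next-state value q[t + n - 1] = Q(s_{t+n}).

# Hypothetical brute-force reference for interior timesteps (B=1, no padding).
import torch
from amago.agent import nstep_return

torch.manual_seed(0)
L, n, gamma_f = 10, 3, 0.99
gamma = torch.tensor([[gamma_f]])              # (G, 1) with G = 1
r = torch.randn(1, L, 1, 1, 1)
d = (torch.rand(1, L, 1, 1, 1) < 0.2).float()  # random terminals
q = torch.randn(1, L, 1, 1, 1)                 # q[t] = Q(s_{t+1}, a'_{t+1})
mask = torch.ones(1, L, 1, 1, 1)

target = nstep_return(r, d, q, gamma, n=n, mask=mask)

def reference(t):
    # accumulate up to n rewards; stop at the first terminal (no bootstrap)
    total = torch.zeros_like(r[:, 0])
    for i in range(n):
        total = total + (gamma_f ** i) * r[:, t + i]
        if d[0, t + i, 0, 0, 0] > 0.5:
            return total
    return total + (gamma_f ** n) * q[:, t + n - 1]

for t in range(L - n):
    assert torch.allclose(target[:, t], reference(t), atol=1e-5)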