120 changes: 105 additions & 15 deletions amago/agent.py
@@ -4,6 +4,7 @@

import abc
import itertools
from functools import partial
from typing import Type, Optional, Tuple, Any, List, Iterable

import torch
@@ -200,6 +201,83 @@ def exp_filter(
return weights


# compiled inside the agent's __init__ (see self._nstep_fn)
def nstep_return(
r: torch.Tensor,
d: torch.Tensor,
q: torch.Tensor,
gamma: torch.Tensor,
n: int,
mask: torch.Tensor,
) -> torch.Tensor:
"""Compute n-step TD targets

Args:
r: Rewards, shape ``(B, L, 1, G, 1)``. Already scaled by reward_multiplier.
d: Dones (float 0 or 1), shape ``(B, L, 1, G, 1)``.
q: Bootstrap Q-values (ensemble-reduced), shape ``(B, L, 1, G, 1)``.
``q[b, t]`` holds Q(s_{t+1}, a'_{t+1}); the values are already aligned as next-state estimates.
gamma: Discount factors, shape ``(G, 1)``.
n: N-step horizon. n=1 recovers the standard 1-step Bellman target.
mask: Valid timestep mask (1=valid, 0=padding), shape ``(B, L, 1, 1, 1)``.

Returns:
N-step TD targets, shape ``(B, L, 1, G, 1)``.
"""
B, L, _one, G, _one2 = r.shape
device = r.device

gamma_pow = gamma.unsqueeze(0).unsqueeze(0).unsqueeze(0) # (1, 1, 1, G, 1)

r_masked = r * mask
d_masked = d * mask

cumulative_reward = torch.zeros(B, L, 1, G, 1, device=device, dtype=r.dtype)
survival = torch.ones(B, L, 1, G, 1, device=device, dtype=r.dtype)

last_survival = torch.ones(B, L, 1, G, 1, device=device, dtype=r.dtype)
last_q_idx = torch.zeros(B, L, 1, G, 1, device=device, dtype=torch.long)
reached_n = torch.zeros(B, L, 1, G, 1, device=device, dtype=torch.bool)

# accumulate discounted rewards for up to n future steps; each timestep stops
# accumulating once it crosses a terminal or runs into padding (reached_n)
for i in range(n):
if i == 0:
r_i = r_masked
d_i = d_masked
valid_i = mask
else:
r_i = F.pad(r_masked[:, i:, ...], (0, 0, 0, 0, 0, 0, 0, i), value=0.0)
d_i = F.pad(d_masked[:, i:, ...], (0, 0, 0, 0, 0, 0, 0, i), value=0.0)
valid_i = F.pad(mask[:, i:, ...], (0, 0, 0, 0, 0, 0, 0, i), value=0.0)

step_valid = valid_i * (~reached_n).float()
cumulative_reward += (gamma_pow**i) * survival * r_i * step_valid

new_survival = survival * (1.0 - d_i)

active = step_valid.bool()
last_survival = torch.where(active, new_survival, last_survival)
last_q_idx = torch.where(active, torch.tensor(i, device=device), last_q_idx)

done_here = (d_i > 0.5) & active
invalid_here = (valid_i < 0.5) & (~reached_n)
reached_n = reached_n | done_here | invalid_here

survival = new_survival

reached_n[:] = True

# bootstrap from the last step each timestep actually reached: k steps ahead,
# discounted by gamma^k and zeroed by last_survival if a terminal was crossed
k = last_q_idx + 1
gamma_k = gamma_pow ** k.float()

t_indices = torch.arange(L, device=device).view(1, L, 1, 1, 1).expand(B, L, 1, G, 1)
boot_indices = (t_indices + last_q_idx).clamp(max=L - 1)

q_bootstrap = torch.gather(q, dim=1, index=boot_indices)
bootstrap = gamma_k * last_survival * q_bootstrap

return cumulative_reward + bootstrap
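
A minimal sanity-check sketch (not part of this diff; it assumes nstep_return above is in scope, e.g. imported from amago.agent): with n=1 and a full mask, the function should reproduce the 1-step target r + gamma * (1 - d) * q that the forward pass previously computed inline.

import torch
from amago.agent import nstep_return

B, L, G = 2, 5, 1
r = torch.rand(B, L, 1, G, 1)
d = torch.zeros(B, L, 1, G, 1)      # no terminations
q = torch.rand(B, L, 1, G, 1)       # next-state Q-values, already aligned
gamma = torch.full((G, 1), 0.99)
mask = torch.ones(B, L, 1, 1, 1)    # no padding

y = nstep_return(r, d, q, gamma, n=1, mask=mask)
assert torch.allclose(y, r + gamma * (1.0 - d) * q)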


#####################
## Built-in Agents ##
#####################
@@ -495,6 +573,10 @@ def V(state, critic, action_dist, k) -> float:
use_target_actor: If True, use a target actor to sample actions used in TD targets.
Defaults to True.
use_multigamma: If True, train on multiple discount horizons (:py:class:`~amago.agent.Multigammas`) in parallel. Defaults to True.
n_step: N-step horizon for TD targets. n_step=1 gives the standard 1-step Bellman target.
Higher values use :py:func:`~amago.agent.nstep_return` to sum discounted
rewards over ``n_step`` future steps before bootstrapping with Q(s_{t+n_step})
(see the short sketch below). Defaults to 1.
actor_type: Actor MLP head for producing action distributions. Defaults to :py:class:`~amago.nets.actor_critic.Actor`.
critic_type: Critic MLP head for producing Q-values. Defaults to :py:class:`~amago.nets.actor_critic.NCritics`.
pass_obs_keys_to_actor: List of keys from the observation space to pass directly to the actor network's forward pass if needed for some reason (e.g., for masking actions). Defaults to None.
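
The sketch below (illustrative only, not part of this diff) makes the n_step argument concrete: for n_step=3 on a short trajectory with no terminations or padding, the target at t=0 sums three discounted rewards and then bootstraps from Q(s_3, a'_3), which lives at index 2 of the next-state-aligned q tensor.

import torch
from amago.agent import nstep_return

gamma = 0.99
r = torch.rand(1, 6, 1, 1, 1)
d = torch.zeros(1, 6, 1, 1, 1)
q = torch.rand(1, 6, 1, 1, 1)            # q[:, t] holds Q(s_{t+1}, a'_{t+1})
mask = torch.ones(1, 6, 1, 1, 1)

y = nstep_return(r, d, q, torch.full((1, 1), gamma), n=3, mask=mask)
# target at t=0: r_0 + gamma * r_1 + gamma^2 * r_2 + gamma^3 * Q(s_3, a'_3)
by_hand = r[:, 0] + gamma * r[:, 1] + gamma**2 * r[:, 2] + gamma**3 * q[:, 2]
assert torch.allclose(y[:, 0], by_hand)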
Expand Down Expand Up @@ -523,6 +605,7 @@ def __init__(
popart: bool = True,
use_target_actor: bool = True,
use_multigamma: bool = True,
n_step: int = 1,
actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCritics,
pass_obs_keys_to_actor: Optional[Iterable[str]] = None,
@@ -546,6 +629,11 @@
self.critic_loss_weight = critic_loss_weight
self.tau = tau
self.use_target_actor = use_target_actor
assert n_step >= 1, f"n_step must be >= 1, got {n_step}"
self.n_step = n_step
# n_step is a fixed Python int here, so torch.compile can unroll the range(n) loop in nstep_return
self._nstep_fn = torch.compile(
partial(nstep_return, n=n_step), mode="reduce-overhead"
)
multigammas = (
Multigammas().discrete if self.discrete else Multigammas().continuous
)
@@ -663,23 +751,23 @@ def get_actions(
dtype = torch.uint8 if (self.discrete or self.multibinary) else torch.float32
return actions.to(dtype=dtype), hidden_state

def _critic_ensemble_to_td_target(self, ensemble_td_target: torch.Tensor):
B, L, C, G, _ = ensemble_td_target.shape
def _reduce_critic_ensemble(self, q_ensemble: torch.Tensor) -> torch.Tensor:
B, L, C, G, _ = q_ensemble.shape
# random subset of critic ensemble
random_subset = torch.randint(
low=0,
high=C,
size=(B, L, self.num_critics_td, G, 1),
device=ensemble_td_target.device,
device=q_ensemble.device,
)
td_target_rand = torch.take_along_dim(ensemble_td_target, random_subset, dim=2)
q_subset = torch.take_along_dim(q_ensemble, random_subset, dim=2)
if self.online_coeff > 0:
# clipped double q
td_target = td_target_rand.min(2, keepdims=True).values
q_reduced = q_subset.min(2, keepdims=True).values
else:
# without DPG updates, the usual min would create strong underestimation; take the mean instead
td_target = td_target_rand.mean(2, keepdims=True)
return td_target
q_reduced = q_subset.mean(2, keepdims=True)
return q_reduced
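
A tiny standalone sketch (not part of this diff; sizes and values are made up) of the two reduction modes above: draw a random subset of the critic ensemble along the critic dim, then take the min (clipped double Q) or the mean.

import torch

B, L, C, G = 1, 1, 4, 1            # 4 critics in the ensemble
num_critics_td = 2                 # size of the random subset
q_ensemble = torch.tensor([1.0, 2.0, 3.0, 4.0]).view(B, L, C, G, 1)

idx = torch.randint(0, C, (B, L, num_critics_td, G, 1))
q_subset = torch.take_along_dim(q_ensemble, idx, dim=2)
clipped = q_subset.min(2, keepdim=True).values    # pessimistic; used when online_coeff > 0
averaged = q_subset.mean(2, keepdim=True)         # used when online_coeff == 0 (no DPG term)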

def _compute_loss(
self,
@@ -809,10 +897,10 @@ def forward(self, batch: Batch, log_step: bool) -> torch.Tensor:
# Q_target(s', a')
q_targ_sp_ap_gp = self.popart(self.target_critics(*sp_ap_gp).mean(0), normalized=False)
assert q_targ_sp_ap_gp.shape == (B, L - 1, C, G, 1)
# y = r + gamma * (1 - d) * Q_target(s', a')
ensemble_td_target = r + gamma * (1.0 - d) * q_targ_sp_ap_gp
assert ensemble_td_target.shape == (B, L - 1, C, G, 1)
td_target = self._critic_ensemble_to_td_target(ensemble_td_target)
q_reduced = self._reduce_critic_ensemble(q_targ_sp_ap_gp)
assert q_reduced.shape == (B, L - 1, 1, G, 1)
nstep_mask = state_mask.float().unsqueeze(-1).unsqueeze(-1) # (B, L-1, 1, 1, 1)
td_target = self._nstep_fn(r, d, q_reduced, gamma, mask=nstep_mask)
assert td_target.shape == (B, L - 1, 1, G, 1)
self.popart.update_stats(
td_target, mask=critic_mask.all(2, keepdim=True)
@@ -1061,6 +1149,7 @@ def __init__(
popart: bool = True,
use_target_actor: bool = True,
use_multigamma: bool = True,
n_step: int = 1,
actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCriticsTwoHot,
pass_obs_keys_to_actor: Optional[Iterable[str]] = None,
@@ -1087,6 +1176,7 @@
critic_loss_weight=critic_loss_weight,
use_target_actor=use_target_actor,
use_multigamma=use_multigamma,
n_step=n_step,
fbc_filter_func=fbc_filter_func,
popart=popart,
actor_type=actor_type,
@@ -1168,10 +1258,10 @@ def forward(self, batch: Batch, log_step: bool):
assert q_targ_sp_ap_gp.probs.shape == (K_c, B, L - 1, C, G, Bins)
q_targ_sp_ap_gp = self.target_critics.bin_dist_to_raw_vals(q_targ_sp_ap_gp).mean(0)
assert q_targ_sp_ap_gp.shape == (B, L - 1, C, G, 1)
# y = r + gamma * (1.0 - d) * Q(s', a')
ensemble_td_target = r + gamma * (1.0 - d) * q_targ_sp_ap_gp
assert ensemble_td_target.shape == (B, L - 1, C, G, 1)
td_target = self._critic_ensemble_to_td_target(ensemble_td_target)
q_reduced = self._reduce_critic_ensemble(q_targ_sp_ap_gp)
assert q_reduced.shape == (B, L - 1, 1, G, 1)
nstep_mask = state_mask.float().unsqueeze(-1).unsqueeze(-1) # (B, L-1, 1, 1, 1)
td_target = self._nstep_fn(r, d, q_reduced, gamma, mask=nstep_mask)
assert td_target.shape == (B, L - 1, 1, G, 1)
self.popart.update_stats(
td_target, mask=critic_mask.all(2, keepdim=True)