Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 27 additions & 47 deletions joyrl/algos/TD3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,67 +13,42 @@
import torch.nn as nn
import torch.nn.functional as F
from common.memories import ReplayBufferQue
# Actor: policy network built from three fully connected layers
# (n_states -> hidden_dim -> hidden_dim -> n_actions).
class Actor(nn.Module):

# Builds the three linear layers; hidden_dim sets the width of both hidden layers.
def __init__(self, n_states, n_actions, hidden_dim = 256):
super(Actor, self).__init__()
# NOTE(review): this import appears in the middle of the Actor lines in this
# diff view; it presumably belongs with the module-level imports - confirm
# against the actual file before relying on its position.
from common.models import MLP, Critic

self.l1 = nn.Linear(n_states, hidden_dim)
self.l2 = nn.Linear(hidden_dim, hidden_dim)
self.l3 = nn.Linear(hidden_dim, n_actions)

# Maps a state tensor to an action tensor: ReLU after the first two layers,
# tanh on the last, so every output component lies in [-1, 1].
def forward(self, state):

x = F.relu(self.l1(state))
x = F.relu(self.l2(x))
x = torch.tanh(self.l3(x))
return x

class Critic(nn.Module):
    """Q-network that scores a (state, action) pair with a single value.

    The state and action tensors are concatenated along dim 1 and passed
    through two ReLU hidden layers followed by a linear output layer.

    Args:
        n_states: Size of the state vector (columns of ``state``).
        n_actions: Size of the action vector (columns of ``action``).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, n_states: int, n_actions: int, hidden_dim: int = 256):
        super().__init__()

        # The first layer must emit ``hidden_dim`` features so that it lines up
        # with ``l2``; a hard-coded 256 here breaks every other hidden size.
        self.l1 = nn.Linear(n_states + n_actions, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return the Q-value for each row of ``state``/``action``.

        Both inputs must be 2-D with matching first dimension because they are
        concatenated along dim 1; the result has one column per row.
        """
        sa = torch.cat([state, action], 1)
        q = F.relu(self.l1(sa))
        q = F.relu(self.l2(q))
        return self.l3(q)


class Agent(object):
def __init__(self,cfg):
self.gamma = cfg.gamma
self.actor_lr = cfg.actor_lr
self.critic_lr = cfg.critic_lr
self.policy_noise = cfg.policy_noise
self.noise_clip = cfg.noise_clip
self.expl_noise = cfg.expl_noise
self.policy_freq = cfg.policy_freq
self.policy_noise = cfg.policy_noise # noise added to target policy during critic update
self.noise_clip = cfg.noise_clip # range to clip target policy noise
self.expl_noise = cfg.expl_noise # std of Gaussian exploration noise
self.policy_freq = cfg.policy_freq # policy update frequency
self.batch_size = cfg.batch_size
self.tau = cfg.tau
self.sample_count = 0
self.policy_freq = cfg.policy_freq
self.explore_steps = cfg.explore_steps
self.explore_steps = cfg.explore_steps # exploration steps before training
self.device = torch.device(cfg.device)
self.n_actions = cfg.n_actions
self.action_space = cfg.action_space
self.actor_input_dim = cfg.n_states
self.actor_output_dim = cfg.n_actions
self.critic_input_dim = cfg.n_states + cfg.n_actions
self.critic_output_dim = 1
self.action_scale = torch.tensor((self.action_space.high - self.action_space.low)/2, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
self.action_bias = torch.tensor((self.action_space.high + self.action_space.low)/2, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
self.actor = Actor(cfg.n_states, cfg.n_actions, hidden_dim = cfg.actor_hidden_dim).to(self.device)
self.actor_target = Actor(cfg.n_states, cfg.n_actions, hidden_dim = cfg.actor_hidden_dim).to(self.device)
self.actor = MLP(self.actor_input_dim, self.actor_output_dim, hidden_dim = cfg.actor_hidden_dim).to(self.device)
self.actor_target = MLP(self.actor_input_dim, self.actor_output_dim, hidden_dim = cfg.actor_hidden_dim).to(self.device)
self.actor_target.load_state_dict(self.actor.state_dict())

self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr = self.actor_lr)

self.critic_1 = Critic(cfg.n_states, cfg.n_actions, hidden_dim = cfg.critic_hidden_dim).to(self.device)
self.critic_2 = Critic(cfg.n_states, cfg.n_actions, hidden_dim = cfg.critic_hidden_dim).to(self.device)
self.critic_1_target = Critic(cfg.n_states, cfg.n_actions, hidden_dim = cfg.critic_hidden_dim).to(self.device)
self.critic_2_target = Critic(cfg.n_states, cfg.n_actions, hidden_dim = cfg.critic_hidden_dim).to(self.device)
self.critic_1 = Critic(self.critic_input_dim, self.critic_output_dim, hidden_dim = cfg.critic_hidden_dim).to(self.device)
self.critic_2 = Critic(self.critic_input_dim, self.critic_output_dim, hidden_dim = cfg.critic_hidden_dim).to(self.device)
self.critic_1_target = Critic(self.critic_input_dim, self.critic_output_dim, hidden_dim = cfg.critic_hidden_dim).to(self.device)
self.critic_2_target = Critic(self.critic_input_dim, self.critic_output_dim, hidden_dim = cfg.critic_hidden_dim).to(self.device)
self.critic_1_target.load_state_dict(self.critic_1.state_dict())
self.critic_2_target.load_state_dict(self.critic_2.state_dict())

Expand All @@ -88,7 +63,7 @@ def sample_action(self, state):
return self.action_space.sample()
else:
state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
action = self.actor(state)
action = torch.tanh(self.actor(state))
action = self.action_scale * action + self.action_bias
action = action.detach().cpu().numpy()[0]
action_noise = np.random.normal(0, self.action_scale.cpu().numpy()[0] * self.expl_noise, size=self.n_actions)
Expand All @@ -98,7 +73,7 @@ def sample_action(self, state):
@torch.no_grad()
def predict_action(self, state):
    """Return the policy's action for ``state`` without exploration noise.

    Args:
        state: A single observation; a batch dimension is added before the
            actor is called and stripped again from the result.

    Returns:
        NumPy array holding the action for that observation, rescaled with
        ``self.action_scale`` and shifted by ``self.action_bias``.
    """
    state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(dim=0)
    # One forward pass only: the actor output is squashed to [-1, 1] with tanh
    # here, then mapped onto the environment's action range below.
    action = torch.tanh(self.actor(state))
    action = self.action_scale * action + self.action_bias
    return action.detach().cpu().numpy()[0]

Expand All @@ -108,6 +83,7 @@ def update(self):
if len(self.memory) < self.explore_steps:
return
state, action, reward, next_state, done = self.memory.sample(self.batch_size)
# convert to tensor
state = torch.tensor(np.array(state), device=self.device, dtype=torch.float32)
action = torch.tensor(np.array(action), device=self.device, dtype=torch.float32)
next_state = torch.tensor(np.array(next_state), device=self.device, dtype=torch.float32)
Expand All @@ -116,10 +92,13 @@ def update(self):
# update critic
noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
next_action = (self.actor_target(next_state) + noise).clamp(-self.action_scale+self.action_bias, self.action_scale+self.action_bias)
target_q1, target_q2 = self.critic_1_target(next_state, next_action).detach(), self.critic_2_target(next_state, next_action).detach()
target_q = torch.min(target_q1, target_q2)
next_sa = torch.cat([next_state, next_action], 1) # shape:[train_batch_size,n_states+n_actions]
target_q1, target_q2 = self.critic_1_target(next_sa).detach(), self.critic_2_target(next_sa).detach()
target_q = torch.min(target_q1, target_q2) # shape:[train_batch_size,n_actions]
target_q = reward + self.gamma * target_q * (1 - done)
current_q1, current_q2 = self.critic_1(state, action), self.critic_2(state, action)
sa = torch.cat([state, action], 1)
current_q1, current_q2 = self.critic_1(sa), self.critic_2(sa)
# compute critic loss
critic_1_loss = F.mse_loss(current_q1, target_q)
critic_2_loss = F.mse_loss(current_q2, target_q)
self.critic_1_optimizer.zero_grad()
Expand All @@ -130,7 +109,8 @@ def update(self):
self.critic_2_optimizer.step()
# Delayed policy updates
if self.sample_count % self.policy_freq == 0:
actor_loss = -self.critic_1(state, self.actor(state)).mean()
# compute actor loss
actor_loss = -self.critic_1(torch.cat([state, torch.tanh(self.actor(state))], 1)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
Expand Down