From 8d548c8514bb9c210ca9bb111dc0431488339fde Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 18 Jan 2023 10:41:07 +0000 Subject: [PATCH 1/4] init --- ppo_pong/conf.yaml | 5 ++-- ppo_pong/ppo.py | 67 +++++++++++++++++++++++++++++----------------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/ppo_pong/conf.yaml b/ppo_pong/conf.yaml index 7754734..0321f01 100644 --- a/ppo_pong/conf.yaml +++ b/ppo_pong/conf.yaml @@ -1,11 +1,12 @@ # Logger experiment_name: atari_pong agent_name: ppo_agent +entity: vmoens wandb_key: null log_dir: /tmp/atari_pong # Environment -env_name: PongNoFrameskip-v4 +env_name: ALE/Pong-v5 frame_skip: 4 # Collector @@ -25,4 +26,4 @@ gae_lamdda: 0.95 lr: 2.5e-4 num_ppo_epochs: 3 mini_batch_size: 256 # so 4 mini_batches - (8 * 128) / 256 -evaluation_frequency: 100 # In number of network updates +evaluation_frequency: 500 # In number of network updates diff --git a/ppo_pong/ppo.py b/ppo_pong/ppo.py index 286afe4..e10db28 100644 --- a/ppo_pong/ppo.py +++ b/ppo_pong/ppo.py @@ -1,6 +1,8 @@ +import gc import time import yaml import wandb +import tqdm import torch import argparse @@ -8,13 +10,11 @@ from torchrl.envs.libs.gym import GymEnv from torchrl.envs import TransformedEnv from torchrl.envs.vec_env import ParallelEnv -from torchrl.envs.transforms import ToTensorImage, GrayScale, CatFrames, NoopResetEnv, Resize -from .transforms.reward_sum import RewardSum -from .transforms.step_limit import StepLimit +from torchrl.envs.transforms import ToTensorImage, GrayScale, CatFrames, NoopResetEnv, Resize, ObservationNorm, RewardSum, StepCounter # Model imports from torchrl.envs import EnvCreator -from torchrl.envs.utils import set_exploration_mode +from torchrl.envs.utils import set_exploration_mode, check_env_specs from torchrl.modules.models import ConvNet, MLP from torchrl.modules.distributions import OneHotCategorical from torchrl.modules import SafeModule, ProbabilisticActor, ValueOperator, ActorValueOperator @@ -30,6 +30,7 @@ from torch.optim import Adam from torch.optim.lr_scheduler import LinearLR from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler +import torch.distributions as dist def main(): @@ -37,14 +38,15 @@ def main(): args = get_args() device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + device_collection = torch.device("cuda:1") if torch.cuda.device_count() > 1 else torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # 1. 
Define environment -------------------------------------------------------------------------------------------- # 1.1 Define env factory - def env_factory(): + def env_factory(device=device): """Creates an instance of the environment.""" - create_env_fn = EnvCreator(lambda: GymEnv(env_name=args.env_name, frame_skip=args.frame_skip)) + create_env_fn = EnvCreator(lambda: GymEnv(env_name=args.env_name, frame_skip=args.frame_skip, categorical_action_encoding=True)) # 1.2 Create env vector vec_env = ParallelEnv(create_env_fn=create_env_fn, num_workers=args.num_parallel_envs) @@ -55,14 +57,23 @@ def env_factory(): transformed_vec_env.append_transform(ToTensorImage()) # change shape from [h, w, 3] to [3, h, w] transformed_vec_env.append_transform(Resize(w=84, h=84)) # Resize image transformed_vec_env.append_transform(GrayScale()) # Convert to Grayscale + transformed_vec_env.append_transform(ObservationNorm(in_keys=["pixels"], standard_normal=True)) transformed_vec_env.append_transform(CatFrames(N=4)) # Stack last 4 frames transformed_vec_env.append_transform(RewardSum()) - transformed_vec_env.append_transform(StepLimit()) + transformed_vec_env.append_transform(StepCounter()) - return transformed_vec_env + norm_layer = transformed_vec_env.transform[3] + norm_layer.init_stats(num_iter=1000, reduce_dim=[0, 1, 3, 4], cat_dim=0, keep_dims=[3, 4]) + + return transformed_vec_env.to(device) # Sanity check test_env = env_factory() + train_env = env_factory() + check_env_specs(train_env) + + test_env.load_state_dict(train_env.state_dict()) + test_input = test_env.reset() assert "pixels" in test_input.keys() num_actions = test_env.specs["action_spec"].space.n @@ -80,7 +91,7 @@ def env_factory(): num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], - ) + ).to(device) common_cnn_output = common_cnn(torch.ones_like(test_input["pixels"])) # Add MLP on top of shared MLP @@ -89,7 +100,7 @@ def env_factory(): activation_class=torch.nn.ReLU, activate_last_layer=True, out_features=448, - num_cells=[256]) + num_cells=[256]).to(device) common_mlp_output = common_mlp(common_cnn_output) # Define shared net as TensorDictModule @@ -106,7 +117,7 @@ def env_factory(): in_features=common_mlp_output.shape[-1], out_features=num_actions, num_cells=[] - ) + ).to(device) # Define TensorDictModule policy_module = SafeModule( # TODO: The naming of SafeModule is confusing @@ -120,7 +131,7 @@ def env_factory(): policy_module, in_keys=["logits"], # TODO: Seems like only "logits" can be used as in_keys # out_keys=["action"], - distribution_class=OneHotCategorical, + distribution_class=dist.Categorical, distribution_kwargs={}, return_log_prob=True, ) @@ -132,7 +143,7 @@ def env_factory(): in_features=common_mlp_output.shape[-1], out_features=1, num_cells=[] - ) + ).to(device) # Define TensorDictModule value_module = ValueOperator( @@ -154,14 +165,15 @@ def env_factory(): td = td.to(device) td = actor_critic(td) + print(actor_critic) # TODO: why wrap them together and then get separate operators ? # Get independent operators for actor and critic, to be able to call only one of them actor = actor_critic.get_policy_operator() critic = actor_critic.get_value_operator() # sanity check - actor(env_factory().reset()) - actor_critic(env_factory().reset()) + actor(test_env.reset()) + actor_critic(test_env.reset()) # # Ugly hack, otherwise I get errors # critic.out_keys = ['state_value', 'common_features'] @@ -170,11 +182,13 @@ def env_factory(): # 2. 
Define Collector ---------------------------------------------------------------------------------------------- collector = SyncDataCollector( - create_env_fn=env_factory, + create_env_fn=train_env, create_env_kwargs=None, policy=actor_critic, total_frames=args.total_frames, frames_per_batch=args.steps_per_env * args.num_parallel_envs, + device=device_collection, + passing_device=device_collection, ) # 3. Define Loss --------------------------------------------------------------------------------------------------- @@ -203,7 +217,7 @@ def env_factory(): mode = "online" wandb.login(key=str(args.wandb_key)) else: - mode = "disabled" + mode = "offline" # 5. Define training loop ------------------------------------------------------------------------------------------ @@ -216,17 +230,18 @@ def env_factory(): scheduler = LinearLR(optimizer, total_iters=total_network_updates, start_factor=1.0, end_factor=0.1) evaluation_frequency = 100 # In number of network frames - with wandb.init(project=args.experiment_name, name=args.agent_name, config=args, mode=mode): + with wandb.init(project=args.experiment_name, name=args.agent_name, entity=args.entity, config=args, mode=mode): + pbar = tqdm.tqdm(total=collector.total_frames) start_time = time.time() for batch in collector: - + batch = batch.cpu() log_info = {} - # We don't use memory networks, so sequence dimension is not relevant batch_size = batch["mask"].sum().item() + pbar.update(batch_size) collected_frames += batch_size # add episode reward info @@ -239,7 +254,7 @@ def env_factory(): for epoch in range(args.num_ppo_epochs): # Compute advantage with the whole batch - batch = advantage_module(batch) + batch = advantage_module(batch.to(device)) batch_view = batch[batch["mask"].squeeze(-1)] @@ -248,7 +263,7 @@ def env_factory(): SubsetRandomSampler(range(batch_size)), args.mini_batch_size, drop_last=True): # select idxs to create mini_batch - mini_batch = batch_view[mini_batch_idxs].clone() + mini_batch = batch_view[mini_batch_idxs].clone().to(device) # Forward pass loss = loss_module(mini_batch) @@ -271,6 +286,7 @@ def env_factory(): "loss_objective": loss["loss_objective"].item(), "learning_rate": float(scheduler.get_last_lr()[0]), "collected_frames": collected_frames, + "reward": batch["reward"][batch["mask"]].mean().item(), }) if network_updates % args.evaluation_frequency == 0 and network_updates != 0: @@ -288,10 +304,10 @@ def env_factory(): # Print an informative message in the terminal fps = int(collected_frames * args.frame_skip / (time.time() - start_time)) print_msg = f"Update {network_updates}, num " \ - f"samples collected {collected_frames * args.frame_skip}, FPS {fps}\n " + f"samples collected {collected_frames * args.frame_skip}, FPS {fps} " for k, v in log_info.items(): - print_msg += f"{k}: {v} " - print(print_msg, flush=True) + print_msg += f"{k}: {v: 4.4f} " + pbar.set_description(print_msg) log_info.update({"collected_frames": int(collected_frames * args.frame_skip), "fps": fps}) wandb.log(log_info, step=network_updates) @@ -299,6 +315,7 @@ def env_factory(): # Update collector weights! 
collector.update_policy_weights_() + gc.collect() def get_args(): From ccd597e41dd92512d545eea087657a753e7e389a Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 31 Jan 2023 18:02:41 +0000 Subject: [PATCH 2/4] init --- ppo_pong/conf.yaml | 10 +-- ppo_pong/ppo.py | 157 +++++++++++++++++++++++++++------------------ sac_mujoco/sac.py | 6 +- 3 files changed, 104 insertions(+), 69 deletions(-) diff --git a/ppo_pong/conf.yaml b/ppo_pong/conf.yaml index 0321f01..3e82c44 100644 --- a/ppo_pong/conf.yaml +++ b/ppo_pong/conf.yaml @@ -2,7 +2,7 @@ experiment_name: atari_pong agent_name: ppo_agent entity: vmoens -wandb_key: null +wandb_key: d0bee782a83f90cbc11177e36a092de77585cbb3 log_dir: /tmp/atari_pong # Environment @@ -11,19 +11,19 @@ frame_skip: 4 # Collector total_frames: 40_000_000 # without accounting for frame skip -num_parallel_envs: 8 +num_parallel_envs: 16 steps_per_env: 128 # between network updates # Loss gamma: 0.99 clip_epsilon: 0.1 -loss_critic_type: l2 -entropy_coef: 0.0001 +loss_critic_type: l1_smooth +entropy_coef: 0.001 critic_coef: 1.0 gae_lamdda: 0.95 # Training loop lr: 2.5e-4 -num_ppo_epochs: 3 +num_ppo_epochs: 10 mini_batch_size: 256 # so 4 mini_batches - (8 * 128) / 256 evaluation_frequency: 500 # In number of network updates diff --git a/ppo_pong/ppo.py b/ppo_pong/ppo.py index e10db28..81c2fb6 100644 --- a/ppo_pong/ppo.py +++ b/ppo_pong/ppo.py @@ -1,4 +1,5 @@ import gc +import os import time import yaml import wandb @@ -9,14 +10,16 @@ # Environment imports from torchrl.envs.libs.gym import GymEnv from torchrl.envs import TransformedEnv -from torchrl.envs.vec_env import ParallelEnv +if os.environ.get("PARALLEL", False): + from torchrl.envs.vec_env import ParallelEnv +else: + from torchrl.envs.vec_env import SerialEnv as ParallelEnv from torchrl.envs.transforms import ToTensorImage, GrayScale, CatFrames, NoopResetEnv, Resize, ObservationNorm, RewardSum, StepCounter # Model imports from torchrl.envs import EnvCreator -from torchrl.envs.utils import set_exploration_mode, check_env_specs +from torchrl.envs.utils import set_exploration_mode, check_env_specs, step_mdp from torchrl.modules.models import ConvNet, MLP -from torchrl.modules.distributions import OneHotCategorical from torchrl.modules import SafeModule, ProbabilisticActor, ValueOperator, ActorValueOperator # Collector imports @@ -31,7 +34,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler import torch.distributions as dist - +from tensordict import TensorDict +from copy import deepcopy def main(): @@ -43,13 +47,16 @@ def main(): # 1. 
Define environment -------------------------------------------------------------------------------------------- # 1.1 Define env factory - def env_factory(device=device): + def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs=1): """Creates an instance of the environment.""" - create_env_fn = EnvCreator(lambda: GymEnv(env_name=args.env_name, frame_skip=args.frame_skip, categorical_action_encoding=True)) # 1.2 Create env vector - vec_env = ParallelEnv(create_env_fn=create_env_fn, num_workers=args.num_parallel_envs) + if num_parallel_envs > 1: + create_env_fn = EnvCreator(lambda: GymEnv(env_name=args.env_name, frame_skip=args.frame_skip, categorical_action_encoding=True, device=device)) + vec_env = ParallelEnv(create_env_fn=create_env_fn, num_workers=num_parallel_envs) + else: + vec_env = GymEnv(env_name=args.env_name, frame_skip=args.frame_skip, categorical_action_encoding=True, device=device) # 1.3 Apply transformations to vec env - standard DeepMind Atari - Order of transforms is important! transformed_vec_env = TransformedEnv(vec_env) @@ -58,24 +65,33 @@ def env_factory(device=device): transformed_vec_env.append_transform(Resize(w=84, h=84)) # Resize image transformed_vec_env.append_transform(GrayScale()) # Convert to Grayscale transformed_vec_env.append_transform(ObservationNorm(in_keys=["pixels"], standard_normal=True)) - transformed_vec_env.append_transform(CatFrames(N=4)) # Stack last 4 frames + transformed_vec_env.append_transform(CatFrames(dim=-3, N=4)) # Stack last 4 frames transformed_vec_env.append_transform(RewardSum()) transformed_vec_env.append_transform(StepCounter()) + transformed_vec_env.set_seed(seed) - norm_layer = transformed_vec_env.transform[3] - norm_layer.init_stats(num_iter=1000, reduce_dim=[0, 1, 3, 4], cat_dim=0, keep_dims=[3, 4]) + if init_stats_steps: + norm_layer = transformed_vec_env.transform[3] + batch_dims = len(transformed_vec_env.batch_size) + norm_layer.init_stats( + num_iter=init_stats_steps, + reduce_dim=[*range(batch_dims), batch_dims, batch_dims+2, batch_dims+3], + cat_dim=batch_dims, + keep_dims=[batch_dims+2, batch_dims+3] + ) - return transformed_vec_env.to(device) + return transformed_vec_env + + torch.manual_seed(0) + test_env = env_factory(device, init_stats_steps=2, seed=0, num_parallel_envs=1) + train_env = env_factory(device_collection, init_stats_steps=1000, seed=1, num_parallel_envs=args.num_parallel_envs) - # Sanity check - test_env = env_factory() - train_env = env_factory() check_env_specs(train_env) - test_env.load_state_dict(train_env.state_dict()) + # Sanity check + test_env.transform.load_state_dict(train_env.transform.state_dict()) + - test_input = test_env.reset() - assert "pixels" in test_input.keys() num_actions = test_env.specs["action_spec"].space.n # 2. 
Define model -------------------------------------------------------------------------------------------------- @@ -87,17 +103,17 @@ def env_factory(device=device): # Define CNN common_cnn = ConvNet( - activation_class=torch.nn.ReLU, + activation_class=torch.nn.ELU, num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], ).to(device) - common_cnn_output = common_cnn(torch.ones_like(test_input["pixels"])) + common_cnn_output = common_cnn(test_env.reset()["pixels"]) # Add MLP on top of shared MLP common_mlp = MLP( in_features=common_cnn_output.shape[-1], - activation_class=torch.nn.ReLU, + activation_class=torch.nn.ELU, activate_last_layer=True, out_features=448, num_cells=[256]).to(device) @@ -118,6 +134,7 @@ def env_factory(device=device): out_features=num_actions, num_cells=[] ).to(device) + policy_net[0].bias.data.fill_(0.0) # Define TensorDictModule policy_module = SafeModule( # TODO: The naming of SafeModule is confusing @@ -161,9 +178,10 @@ def env_factory(device=device): # 2.6 Initialize the model by running a forward pass with torch.no_grad(): - td = test_env.rollout(max_steps=1000) + td = test_env.rollout(max_steps=1000, break_when_any_done=False) td = td.to(device) td = actor_critic(td) + del td print(actor_critic) # TODO: why wrap them together and then get separate operators ? @@ -181,15 +199,11 @@ def env_factory(device=device): # 2. Define Collector ---------------------------------------------------------------------------------------------- - collector = SyncDataCollector( - create_env_fn=train_env, - create_env_kwargs=None, - policy=actor_critic, - total_frames=args.total_frames, - frames_per_batch=args.steps_per_env * args.num_parallel_envs, - device=device_collection, - passing_device=device_collection, - ) + train_env = train_env.to(device_collection) + if device_collection != device: + actor_collection = deepcopy(actor).to(device_collection).requires_grad_(False) + else: + actor_collection = actor # 3. 
Define Loss --------------------------------------------------------------------------------------------------- @@ -228,35 +242,56 @@ def env_factory(device=device): total_network_updates = (args.total_frames // batch_size) * args.num_ppo_epochs * num_mini_batches optimizer = Adam(params=actor_critic.parameters(), lr=args.lr) scheduler = LinearLR(optimizer, total_iters=total_network_updates, start_factor=1.0, end_factor=0.1) - evaluation_frequency = 100 # In number of network frames - with wandb.init(project=args.experiment_name, name=args.agent_name, entity=args.entity, config=args, mode=mode): + @torch.no_grad() + @set_exploration_mode("random") + def dataloader(total_frames, fpb): + """This is a simplified dataloader.""" + if device_collection != device: + params = TensorDict({k: v for k, v in actor.named_parameters()}, batch_size=[]).unflatten_keys(".") + params_collection = TensorDict({k: v for k, v in actor_collection.named_parameters()}, batch_size=[]).unflatten_keys(".") + _prev = None + + collected_frames = 0 + while collected_frames < total_frames: + if device_collection != device: + params_collection.update_(params) + batch = TensorDict({}, batch_size=[fpb, *train_env.batch_size], device=device_collection) + for t in range(fpb): + if _prev is None: + _prev = train_env.reset() + _reset = _prev["_reset"] = _prev["done"].clone().squeeze(-1) + if _reset.any(): + _prev = train_env.reset(_prev) + _new = train_env.step(actor_collection(_prev)) + batch[t] = _new + _prev = step_mdp(_new, exclude_done=False) + collected_frames += batch.numel() + yield batch - pbar = tqdm.tqdm(total=collector.total_frames) + with wandb.init(project=args.experiment_name, name=args.agent_name, entity=args.entity, config=args, mode=mode): + total_frames = args.total_frames // args.frame_skip + fpb = args.steps_per_env + pbar = tqdm.tqdm(total=total_frames) start_time = time.time() - for batch in collector: + + for batch in dataloader(total_frames, fpb): batch = batch.cpu() log_info = {} # We don't use memory networks, so sequence dimension is not relevant - batch_size = batch["mask"].sum().item() + batch_size = batch.numel() pbar.update(batch_size) collected_frames += batch_size - # add episode reward info - # train_episode_reward = batch["episode_reward"][batch["done"]] - #if batch["episode_reward"][batch["done"]].numel() > 0: - # log_info.update({"train_episode_rewards": train_episode_reward.mean()}) - # episode_steps = batch["episode_steps"][batch["done"]] - # PPO epochs for epoch in range(args.num_ppo_epochs): # Compute advantage with the whole batch - batch = advantage_module(batch.to(device)) + batch = advantage_module(batch.to(device)).cpu() - batch_view = batch[batch["mask"].squeeze(-1)] + batch_view = batch.reshape(-1) # Create a random permutation in every epoch for mini_batch_idxs in BatchSampler( @@ -274,19 +309,23 @@ def env_factory(device=device): # Update networks optimizer.zero_grad() loss_sum.backward() - torch.nn.utils.clip_grad_norm_(actor_critic.parameters(), max_norm=0.5) + grad_norm = torch.nn.utils.clip_grad_norm_(actor_critic.parameters(), max_norm=10.0) optimizer.step() scheduler.step() network_updates += 1 log_info.update({ "loss": loss_sum.item(), - "loss_critic": loss["loss_critic"].item(), - "loss_entropy": loss["loss_entropy"].item(), - "loss_objective": loss["loss_objective"].item(), - "learning_rate": float(scheduler.get_last_lr()[0]), - "collected_frames": collected_frames, - "reward": batch["reward"][batch["mask"]].mean().item(), + "loss_cri": loss["loss_critic"].item(), + 
"loss_ent": loss["loss_entropy"].item(), + "loss_obj": loss["loss_objective"].item(), + "lr": float(scheduler.get_last_lr()[0]), + "gn": grad_norm, + "frames": collected_frames, + "reward": batch["reward"].mean().item(), + "traj_len": batch["step_count"].max().item(), + "pix avg": batch["pixels"].mean(), + "pix std": batch["pixels"].std(), }) if network_updates % args.evaluation_frequency == 0 and network_updates != 0: @@ -295,27 +334,23 @@ def env_factory(device=device): test_env.eval() test_td = test_env.rollout( policy=actor, - max_steps=100000, + max_steps=10000, auto_reset=True, auto_cast_to_device=True, + break_when_any_done=False, ).clone() log_info.update({"test_reward": test_td["reward"].squeeze(-1).sum(-1).mean()}) + del test_td # Print an informative message in the terminal fps = int(collected_frames * args.frame_skip / (time.time() - start_time)) print_msg = f"Update {network_updates}, num " \ f"samples collected {collected_frames * args.frame_skip}, FPS {fps} " for k, v in log_info.items(): - print_msg += f"{k}: {v: 4.4f} " - pbar.set_description(print_msg) - - log_info.update({"collected_frames": int(collected_frames * args.frame_skip), "fps": fps}) - wandb.log(log_info, step=network_updates) - del mini_batch - - # Update collector weights! - collector.update_policy_weights_() - gc.collect() + print_msg += f"{k}: {v: 4.2f} " + pbar.set_description(print_msg) + log_info.update({"collected_frames": int(collected_frames * args.frame_skip), "fps": fps}) + wandb.log(log_info, step=network_updates) def get_args(): diff --git a/sac_mujoco/sac.py b/sac_mujoco/sac.py index 2cf2e4d..02c1c85 100644 --- a/sac_mujoco/sac.py +++ b/sac_mujoco/sac.py @@ -243,7 +243,7 @@ def main(): ) actor = ProbabilisticActor( spec=action_spec, - dist_in_keys=["loc", "scale"], + in_keys=["loc", "scale"], module=actor_module, distribution_class=dist_class, distribution_kwargs=dist_kwargs, @@ -467,7 +467,7 @@ def get_args(): ) parser.add_argument( "--from_pixels", - action=argparse.BooleanOptionalAction, + action="store_true",#argparse.BooleanOptionalAction, default=False, help="Use pixel observations. Default: False", ) @@ -500,7 +500,7 @@ def get_args(): ) parser.add_argument( "--prb", - action=argparse.BooleanOptionalAction, + action="store_true",#argparse.BooleanOptionalAction, default=False, help="Use Prioritized Experience Replay Buffer. 
Default: False", ) From 4784bdb238891d9c7c9f07785c4bb46b896e688b Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 31 Jan 2023 21:03:25 +0000 Subject: [PATCH 3/4] amend --- ppo_pong/conf.yaml | 1 + ppo_pong/ppo.py | 7 +------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ppo_pong/conf.yaml b/ppo_pong/conf.yaml index 3e82c44..e730219 100644 --- a/ppo_pong/conf.yaml +++ b/ppo_pong/conf.yaml @@ -21,6 +21,7 @@ loss_critic_type: l1_smooth entropy_coef: 0.001 critic_coef: 1.0 gae_lamdda: 0.95 +clip_grad: 0.5 # Training loop lr: 2.5e-4 diff --git a/ppo_pong/ppo.py b/ppo_pong/ppo.py index 81c2fb6..9dc35ac 100644 --- a/ppo_pong/ppo.py +++ b/ppo_pong/ppo.py @@ -147,7 +147,6 @@ def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs= policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], # TODO: Seems like only "logits" can be used as in_keys - # out_keys=["action"], distribution_class=dist.Categorical, distribution_kwargs={}, return_log_prob=True, @@ -193,10 +192,6 @@ def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs= actor(test_env.reset()) actor_critic(test_env.reset()) - # # Ugly hack, otherwise I get errors - # critic.out_keys = ['state_value', 'common_features'] - # actor.out_keys = ['action', 'common_features', 'logits'] - # 2. Define Collector ---------------------------------------------------------------------------------------------- train_env = train_env.to(device_collection) @@ -309,7 +304,7 @@ def dataloader(total_frames, fpb): # Update networks optimizer.zero_grad() loss_sum.backward() - grad_norm = torch.nn.utils.clip_grad_norm_(actor_critic.parameters(), max_norm=10.0) + grad_norm = torch.nn.utils.clip_grad_norm_(actor_critic.parameters(), max_norm=args.clip_grad) optimizer.step() scheduler.step() network_updates += 1 From b0837e688aa1cf94780124c1b5b23af6c0f969d4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 6 Feb 2023 10:31:04 +0000 Subject: [PATCH 4/4] amend --- ppo_pong/conf.yaml | 10 +++++----- ppo_pong/ppo.py | 47 +++++++++++++++++++++++++++++----------------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/ppo_pong/conf.yaml b/ppo_pong/conf.yaml index e730219..3a4d9c0 100644 --- a/ppo_pong/conf.yaml +++ b/ppo_pong/conf.yaml @@ -12,12 +12,12 @@ frame_skip: 4 # Collector total_frames: 40_000_000 # without accounting for frame skip num_parallel_envs: 16 -steps_per_env: 128 # between network updates +steps_per_env: 256 # between network updates # Loss gamma: 0.99 -clip_epsilon: 0.1 -loss_critic_type: l1_smooth +clip_epsilon: 0.2 +loss_critic_type: l2 entropy_coef: 0.001 critic_coef: 1.0 gae_lamdda: 0.95 @@ -25,6 +25,6 @@ clip_grad: 0.5 # Training loop lr: 2.5e-4 -num_ppo_epochs: 10 -mini_batch_size: 256 # so 4 mini_batches - (8 * 128) / 256 +num_ppo_epochs: 3 +mini_batch_size: 128 # so 4 mini_batches - (8 * 128) / 256 evaluation_frequency: 500 # In number of network updates diff --git a/ppo_pong/ppo.py b/ppo_pong/ppo.py index 9dc35ac..fed1109 100644 --- a/ppo_pong/ppo.py +++ b/ppo_pong/ppo.py @@ -23,7 +23,8 @@ from torchrl.modules import SafeModule, ProbabilisticActor, ValueOperator, ActorValueOperator # Collector imports -from torchrl.collectors.collectors import SyncDataCollector +from torchrl.collectors.collectors import SyncDataCollector, \ + MultiSyncDataCollector # Loss imports from torchrl.objectives import ClipPPOLoss @@ -103,7 +104,7 @@ def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs= # Define CNN common_cnn = ConvNet( - 
activation_class=torch.nn.ELU, + activation_class=torch.nn.ReLU, num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], @@ -113,10 +114,10 @@ def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs= # Add MLP on top of shared MLP common_mlp = MLP( in_features=common_cnn_output.shape[-1], - activation_class=torch.nn.ELU, + activation_class=torch.nn.ReLU, activate_last_layer=True, out_features=448, - num_cells=[256]).to(device) + num_cells=[]).to(device) common_mlp_output = common_mlp(common_cnn_output) # Define shared net as TensorDictModule @@ -132,7 +133,7 @@ def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs= policy_net = MLP( in_features=common_mlp_output.shape[-1], out_features=num_actions, - num_cells=[] + num_cells=[256] ).to(device) policy_net[0].bias.data.fill_(0.0) @@ -158,7 +159,7 @@ def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs= value_net = MLP( in_features=common_mlp_output.shape[-1], out_features=1, - num_cells=[] + num_cells=[256] ).to(device) # Define TensorDictModule @@ -195,10 +196,10 @@ def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs= # 2. Define Collector ---------------------------------------------------------------------------------------------- train_env = train_env.to(device_collection) - if device_collection != device: - actor_collection = deepcopy(actor).to(device_collection).requires_grad_(False) - else: - actor_collection = actor + # if device_collection != device: + # actor_collection = deepcopy(actor).to(device_collection).requires_grad_(False) + # else: + # actor_collection = actor # 3. Define Loss --------------------------------------------------------------------------------------------------- @@ -217,6 +218,7 @@ def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs= entropy_coef=args.entropy_coef, critic_coef=args.critic_coef, gamma=args.gamma, + normalize_advantage=False, ) # 4. 
Define logger ------------------------------------------------------------------------------------------------- @@ -264,14 +266,25 @@ def dataloader(total_frames, fpb): collected_frames += batch.numel() yield batch + fpb = args.steps_per_env + total_frames = args.total_frames // args.frame_skip + # collector = dataloader(total_frames, fpb) + collector = MultiSyncDataCollector( + [train_env], + actor, + frames_per_batch=fpb * args.num_parallel_envs, + total_frames=total_frames, + devices=device_collection, + passing_devices=device_collection, + split_trajs=False, + ) + with wandb.init(project=args.experiment_name, name=args.agent_name, entity=args.entity, config=args, mode=mode): - total_frames = args.total_frames // args.frame_skip - fpb = args.steps_per_env pbar = tqdm.tqdm(total=total_frames) start_time = time.time() - for batch in dataloader(total_frames, fpb): + for batch in collector: batch = batch.cpu() log_info = {} @@ -315,7 +328,7 @@ def dataloader(total_frames, fpb): "loss_ent": loss["loss_entropy"].item(), "loss_obj": loss["loss_objective"].item(), "lr": float(scheduler.get_last_lr()[0]), - "gn": grad_norm, + "grad_norm": grad_norm, "frames": collected_frames, "reward": batch["reward"].mean().item(), "traj_len": batch["step_count"].max().item(), @@ -325,14 +338,14 @@ def dataloader(total_frames, fpb): if network_updates % args.evaluation_frequency == 0 and network_updates != 0: # Run evaluation in test environment - with set_exploration_mode("random"): + with set_exploration_mode("random"), torch.no_grad(): test_env.eval() test_td = test_env.rollout( policy=actor, max_steps=10000, auto_reset=True, auto_cast_to_device=True, - break_when_any_done=False, + break_when_any_done=True, ).clone() log_info.update({"test_reward": test_td["reward"].squeeze(-1).sum(-1).mean()}) del test_td @@ -346,7 +359,7 @@ def dataloader(total_frames, fpb): pbar.set_description(print_msg) log_info.update({"collected_frames": int(collected_frames * args.frame_skip), "fps": fps}) wandb.log(log_info, step=network_updates) - + collector.update_policy_weights_() def get_args(): """Reads conf.yaml file in the same directory"""
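
Note on the final configuration: below is a minimal sketch, assuming the hyperparameters in the last conf.yaml revision of this series (16 parallel envs, 256 steps per env, mini-batch size 128, 3 PPO epochs, clip_epsilon 0.2), of the per-collection update arithmetic and of the clipped surrogate term that ClipPPOLoss minimizes on each mini-batch. The helper is illustrative only and is not TorchRL's internal implementation; the argument names are hypothetical stand-ins for the tensors the script already carries (the behavior-policy log-probability written by the actor with return_log_prob=True, and the advantage produced by the GAE module).

    import torch

    # Update arithmetic implied by the final conf.yaml (assumption: these numbers
    # simply restate the config values, they are not read from the file here).
    num_parallel_envs = 16
    steps_per_env = 256
    mini_batch_size = 128
    num_ppo_epochs = 3

    frames_per_batch = num_parallel_envs * steps_per_env              # 4096 frames per collection
    mini_batches_per_epoch = frames_per_batch // mini_batch_size      # 32 mini-batches per epoch
    updates_per_collection = num_ppo_epochs * mini_batches_per_epoch  # 96 optimizer steps per collection


    def clipped_surrogate(new_log_prob, old_log_prob, advantage, clip_epsilon=0.2):
        """Sketch of the clipped PPO objective term.

        ClipPPOLoss additionally adds the critic and entropy terms weighted by
        critic_coef and entropy_coef; only the policy term is shown here.
        """
        # Importance ratio between the current policy and the policy that
        # collected the data.
        ratio = torch.exp(new_log_prob - old_log_prob)
        unclipped = ratio * advantage
        clipped = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantage
        # Negative sign because the training loop minimizes the summed losses.
        return -torch.min(unclipped, clipped).mean()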