diff --git a/ppo_pong/conf.yaml b/ppo_pong/conf.yaml
index 7754734..3a4d9c0 100644
--- a/ppo_pong/conf.yaml
+++ b/ppo_pong/conf.yaml
@@ -1,28 +1,30 @@
 # Logger
 experiment_name: atari_pong
 agent_name: ppo_agent
-wandb_key: null
+entity: vmoens
+wandb_key: <your-wandb-api-key>
 log_dir: /tmp/atari_pong
 
 # Environment
-env_name: PongNoFrameskip-v4
+env_name: ALE/Pong-v5
 frame_skip: 4
 
 # Collector
 total_frames: 40_000_000 # without accounting for frame skip
-num_parallel_envs: 8
-steps_per_env: 128 # between network updates
+num_parallel_envs: 16
+steps_per_env: 256 # between network updates
 
 # Loss
 gamma: 0.99
-clip_epsilon: 0.1
+clip_epsilon: 0.2
 loss_critic_type: l2
-entropy_coef: 0.0001
+entropy_coef: 0.001
 critic_coef: 1.0
 gae_lamdda: 0.95
+clip_grad: 0.5
 
 # Training loop
 lr: 2.5e-4
 num_ppo_epochs: 3
-mini_batch_size: 256 # so 4 mini_batches - (8 * 128) / 256
-evaluation_frequency: 100 # In number of network updates
+mini_batch_size: 128 # so 32 mini_batches - (16 * 256) / 128
+evaluation_frequency: 500 # In number of network updates
diff --git a/ppo_pong/ppo.py b/ppo_pong/ppo.py
index 286afe4..fed1109 100644
--- a/ppo_pong/ppo.py
+++ b/ppo_pong/ppo.py
@@ -1,26 +1,30 @@
+import gc
+import os
 import time
 import yaml
 import wandb
+import tqdm
 import torch
 import argparse
 
 # Environment imports
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.envs import TransformedEnv
-from torchrl.envs.vec_env import ParallelEnv
-from torchrl.envs.transforms import ToTensorImage, GrayScale, CatFrames, NoopResetEnv, Resize
-from .transforms.reward_sum import RewardSum
-from .transforms.step_limit import StepLimit
+if os.environ.get("PARALLEL", False):
+    from torchrl.envs.vec_env import ParallelEnv
+else:
+    from torchrl.envs.vec_env import SerialEnv as ParallelEnv
+from torchrl.envs.transforms import ToTensorImage, GrayScale, CatFrames, NoopResetEnv, Resize, ObservationNorm, RewardSum, StepCounter
 
 # Model imports
 from torchrl.envs import EnvCreator
-from torchrl.envs.utils import set_exploration_mode
+from torchrl.envs.utils import set_exploration_mode, check_env_specs, step_mdp
 from torchrl.modules.models import ConvNet, MLP
-from torchrl.modules.distributions import OneHotCategorical
 from torchrl.modules import SafeModule, ProbabilisticActor, ValueOperator, ActorValueOperator
 
 # Collector imports
-from torchrl.collectors.collectors import SyncDataCollector
+from torchrl.collectors.collectors import SyncDataCollector, \
+    MultiSyncDataCollector
 
 # Loss imports
 from torchrl.objectives import ClipPPOLoss
@@ -30,24 +34,30 @@
 from torch.optim import Adam
 from torch.optim.lr_scheduler import LinearLR
 from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
-
+import torch.distributions as dist
+from tensordict import TensorDict
+from copy import deepcopy
 
 def main():
     args = get_args()
 
     device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+    device_collection = torch.device("cuda:1") if torch.cuda.device_count() > 1 else torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
     # 1. Define environment --------------------------------------------------------------------------------------------
     # 1.1 Define env factory
-    def env_factory():
+    def env_factory(device=device, init_stats_steps=1000, seed=0, num_parallel_envs=1):
         """Creates an instance of the environment."""
-        create_env_fn = EnvCreator(lambda: GymEnv(env_name=args.env_name, frame_skip=args.frame_skip))
 
         # 1.2 Create env vector
-        vec_env = ParallelEnv(create_env_fn=create_env_fn, num_workers=args.num_parallel_envs)
+        if num_parallel_envs > 1:
+            create_env_fn = EnvCreator(lambda: GymEnv(env_name=args.env_name, frame_skip=args.frame_skip, categorical_action_encoding=True, device=device))
+            vec_env = ParallelEnv(create_env_fn=create_env_fn, num_workers=num_parallel_envs)
+        else:
+            vec_env = GymEnv(env_name=args.env_name, frame_skip=args.frame_skip, categorical_action_encoding=True, device=device)
 
         # 1.3 Apply transformations to vec env - standard DeepMind Atari - Order of transforms is important!
         transformed_vec_env = TransformedEnv(vec_env)
@@ -55,16 +65,34 @@ def env_factory():
         transformed_vec_env.append_transform(ToTensorImage())  # change shape from [h, w, 3] to [3, h, w]
         transformed_vec_env.append_transform(Resize(w=84, h=84))  # Resize image
         transformed_vec_env.append_transform(GrayScale())  # Convert to Grayscale
-        transformed_vec_env.append_transform(CatFrames(N=4))  # Stack last 4 frames
+        transformed_vec_env.append_transform(ObservationNorm(in_keys=["pixels"], standard_normal=True))
+        transformed_vec_env.append_transform(CatFrames(dim=-3, N=4))  # Stack last 4 frames
         transformed_vec_env.append_transform(RewardSum())
-        transformed_vec_env.append_transform(StepLimit())
+        transformed_vec_env.append_transform(StepCounter())
+        transformed_vec_env.set_seed(seed)
+
+        if init_stats_steps:
+            norm_layer = transformed_vec_env.transform[3]
+            batch_dims = len(transformed_vec_env.batch_size)
+            norm_layer.init_stats(
+                num_iter=init_stats_steps,
+                reduce_dim=[*range(batch_dims), batch_dims, batch_dims+2, batch_dims+3],
+                cat_dim=batch_dims,
+                keep_dims=[batch_dims+2, batch_dims+3]
+            )
 
         return transformed_vec_env
 
+    torch.manual_seed(0)
+    test_env = env_factory(device, init_stats_steps=2, seed=0, num_parallel_envs=1)
+    train_env = env_factory(device_collection, init_stats_steps=1000, seed=1, num_parallel_envs=args.num_parallel_envs)
+
+    check_env_specs(train_env)
+
     # Sanity check
-    test_env = env_factory()
-    test_input = test_env.reset()
-    assert "pixels" in test_input.keys()
+    test_env.transform.load_state_dict(train_env.transform.state_dict())
+
+    num_actions = test_env.specs["action_spec"].space.n
 
     # 2. Define model --------------------------------------------------------------------------------------------------
@@ -80,8 +108,8 @@ def env_factory():
         num_cells=[32, 64, 64],
         kernel_sizes=[8, 4, 3],
         strides=[4, 2, 1],
-    )
-    common_cnn_output = common_cnn(torch.ones_like(test_input["pixels"]))
+    ).to(device)
+    common_cnn_output = common_cnn(test_env.reset()["pixels"])
 
     # Add MLP on top of shared MLP
     common_mlp = MLP(
@@ -89,7 +117,7 @@ def env_factory():
         activation_class=torch.nn.ReLU,
         activate_last_layer=True,
         out_features=448,
-        num_cells=[256])
+        num_cells=[]).to(device)
     common_mlp_output = common_mlp(common_cnn_output)
 
     # Define shared net as TensorDictModule
@@ -105,8 +133,9 @@ def env_factory():
     policy_net = MLP(
         in_features=common_mlp_output.shape[-1],
         out_features=num_actions,
-        num_cells=[]
-    )
+        num_cells=[256]
+    ).to(device)
+    policy_net[0].bias.data.fill_(0.0)
 
     # Define TensorDictModule
     policy_module = SafeModule(  # TODO: The naming of SafeModule is confusing
@@ -119,8 +148,7 @@ def env_factory():
     policy_module = ProbabilisticActor(
         policy_module,
         in_keys=["logits"],  # TODO: Seems like only "logits" can be used as in_keys
-        # out_keys=["action"],
-        distribution_class=OneHotCategorical,
+        distribution_class=dist.Categorical,
         distribution_kwargs={},
         return_log_prob=True,
     )
@@ -131,8 +159,8 @@ def env_factory():
     value_net = MLP(
         in_features=common_mlp_output.shape[-1],
         out_features=1,
-        num_cells=[]
-    )
+        num_cells=[256]
+    ).to(device)
 
     # Define TensorDictModule
     value_module = ValueOperator(
@@ -150,32 +178,28 @@ def env_factory():
     # 2.6 Initialize the model by running a forward pass
     with torch.no_grad():
-        td = test_env.rollout(max_steps=1000)
+        td = test_env.rollout(max_steps=1000, break_when_any_done=False)
         td = td.to(device)
         td = actor_critic(td)
+        del td
+    print(actor_critic)
 
     # TODO: why wrap them together and then get separate operators ?
     # Get independent operators for actor and critic, to be able to call only one of them
     actor = actor_critic.get_policy_operator()
     critic = actor_critic.get_value_operator()
 
     # sanity check
-    actor(env_factory().reset())
-    actor_critic(env_factory().reset())
-
-    # # Ugly hack, otherwise I get errors
-    # critic.out_keys = ['state_value', 'common_features']
-    # actor.out_keys = ['action', 'common_features', 'logits']
+    actor(test_env.reset())
+    actor_critic(test_env.reset())
 
     # 2. Define Collector ----------------------------------------------------------------------------------------------
-    collector = SyncDataCollector(
-        create_env_fn=env_factory,
-        create_env_kwargs=None,
-        policy=actor_critic,
-        total_frames=args.total_frames,
-        frames_per_batch=args.steps_per_env * args.num_parallel_envs,
-    )
+    train_env = train_env.to(device_collection)
+    # if device_collection != device:
+    #     actor_collection = deepcopy(actor).to(device_collection).requires_grad_(False)
+    # else:
+    #     actor_collection = actor
 
     # 3. Define Loss ---------------------------------------------------------------------------------------------------
@@ -194,6 +218,7 @@ def env_factory():
         entropy_coef=args.entropy_coef,
         critic_coef=args.critic_coef,
         gamma=args.gamma,
+        normalize_advantage=False,
     )
 
     # 4. Define logger -------------------------------------------------------------------------------------------------
@@ -203,7 +228,7 @@ def env_factory():
         mode = "online"
         wandb.login(key=str(args.wandb_key))
     else:
-        mode = "disabled"
+        mode = "offline"
 
     # 5. Define training loop ------------------------------------------------------------------------------------------
@@ -214,41 +239,74 @@ def env_factory():
     total_network_updates = (args.total_frames // batch_size) * args.num_ppo_epochs * num_mini_batches
     optimizer = Adam(params=actor_critic.parameters(), lr=args.lr)
     scheduler = LinearLR(optimizer, total_iters=total_network_updates, start_factor=1.0, end_factor=0.1)
-    evaluation_frequency = 100  # In number of network frames
 
-    with wandb.init(project=args.experiment_name, name=args.agent_name, config=args, mode=mode):
+    @torch.no_grad()
+    @set_exploration_mode("random")
+    def dataloader(total_frames, fpb):
+        """This is a simplified dataloader."""
+        if device_collection != device:
+            params = TensorDict({k: v for k, v in actor.named_parameters()}, batch_size=[]).unflatten_keys(".")
+            params_collection = TensorDict({k: v for k, v in actor_collection.named_parameters()}, batch_size=[]).unflatten_keys(".")
+        _prev = None
+
+        collected_frames = 0
+        while collected_frames < total_frames:
+            if device_collection != device:
+                params_collection.update_(params)
+            batch = TensorDict({}, batch_size=[fpb, *train_env.batch_size], device=device_collection)
+            for t in range(fpb):
+                if _prev is None:
+                    _prev = train_env.reset()
+                _reset = _prev["_reset"] = _prev["done"].clone().squeeze(-1)
+                if _reset.any():
+                    _prev = train_env.reset(_prev)
+                _new = train_env.step(actor_collection(_prev))
+                batch[t] = _new
+                _prev = step_mdp(_new, exclude_done=False)
+            collected_frames += batch.numel()
+            yield batch
+
+    fpb = args.steps_per_env
+    total_frames = args.total_frames // args.frame_skip
+    # collector = dataloader(total_frames, fpb)
+    collector = MultiSyncDataCollector(
+        [train_env],
+        actor,
+        frames_per_batch=fpb * args.num_parallel_envs,
+        total_frames=total_frames,
+        devices=device_collection,
+        passing_devices=device_collection,
+        split_trajs=False,
+    )
+
+    with wandb.init(project=args.experiment_name, name=args.agent_name, entity=args.entity, config=args, mode=mode):
+        pbar = tqdm.tqdm(total=total_frames)
 
         start_time = time.time()
-        for batch in collector:
+        for batch in collector:
+            batch = batch.cpu()
 
             log_info = {}
 
-            # We don't use memory networks, so sequence dimension is not relevant
-            batch_size = batch["mask"].sum().item()
+            batch_size = batch.numel()
+            pbar.update(batch_size)
             collected_frames += batch_size
 
-            # add episode reward info
-            # train_episode_reward = batch["episode_reward"][batch["done"]]
-            #if batch["episode_reward"][batch["done"]].numel() > 0:
-            #    log_info.update({"train_episode_rewards": train_episode_reward.mean()})
-            # episode_steps = batch["episode_steps"][batch["done"]]
-
             # PPO epochs
             for epoch in range(args.num_ppo_epochs):

                 # Compute advantage with the whole batch
-                batch = advantage_module(batch)
+                batch = advantage_module(batch.to(device)).cpu()
 
-                batch_view = batch[batch["mask"].squeeze(-1)]
+                batch_view = batch.reshape(-1)
 
                 # Create a random permutation in every epoch
                 for mini_batch_idxs in BatchSampler(
                         SubsetRandomSampler(range(batch_size)), args.mini_batch_size, drop_last=True):
 
                     # select idxs to create mini_batch
-                    mini_batch = batch_view[mini_batch_idxs].clone()
+                    mini_batch = batch_view[mini_batch_idxs].clone().to(device)
 
                     # Forward pass
                     loss = loss_module(mini_batch)
@@ -259,48 +317,50 @@ def env_factory():
 
                     # Update networks
                     optimizer.zero_grad()
                     loss_sum.backward()
-                    torch.nn.utils.clip_grad_norm_(actor_critic.parameters(), max_norm=0.5)
+                    grad_norm = torch.nn.utils.clip_grad_norm_(actor_critic.parameters(), max_norm=args.clip_grad)
                     optimizer.step()
                     scheduler.step()
                     network_updates += 1
 
                     log_info.update({
                         "loss": loss_sum.item(),
-                        "loss_critic": loss["loss_critic"].item(),
-                        "loss_entropy": loss["loss_entropy"].item(),
-                        "loss_objective": loss["loss_objective"].item(),
-                        "learning_rate": float(scheduler.get_last_lr()[0]),
-                        "collected_frames": collected_frames,
+                        "loss_cri": loss["loss_critic"].item(),
+                        "loss_ent": loss["loss_entropy"].item(),
+                        "loss_obj": loss["loss_objective"].item(),
+                        "lr": float(scheduler.get_last_lr()[0]),
+                        "grad_norm": grad_norm,
+                        "frames": collected_frames,
+                        "reward": batch["reward"].mean().item(),
+                        "traj_len": batch["step_count"].max().item(),
+                        "pix avg": batch["pixels"].mean(),
+                        "pix std": batch["pixels"].std(),
                     })
 
            if network_updates % args.evaluation_frequency == 0 and network_updates != 0:
 
                 # Run evaluation in test environment
-                with set_exploration_mode("random"):
+                with set_exploration_mode("random"), torch.no_grad():
                     test_env.eval()
                     test_td = test_env.rollout(
                         policy=actor,
-                        max_steps=100000,
+                        max_steps=10000,
                         auto_reset=True,
                         auto_cast_to_device=True,
+                        break_when_any_done=True,
                     ).clone()
                 log_info.update({"test_reward": test_td["reward"].squeeze(-1).sum(-1).mean()})
+                del test_td
 
             # Print an informative message in the terminal
             fps = int(collected_frames * args.frame_skip / (time.time() - start_time))
             print_msg = f"Update {network_updates}, num " \
-                        f"samples collected {collected_frames * args.frame_skip}, FPS {fps}\n "
+                        f"samples collected {collected_frames * args.frame_skip}, FPS {fps} "
             for k, v in log_info.items():
-                print_msg += f"{k}: {v} "
-            print(print_msg, flush=True)
-
-            log_info.update({"collected_frames": int(collected_frames * args.frame_skip), "fps": fps})
-            wandb.log(log_info, step=network_updates)
-            del mini_batch
-
-            # Update collector weights!
+                print_msg += f"{k}: {v: 4.2f} "
+            pbar.set_description(print_msg)
+            log_info.update({"collected_frames": int(collected_frames * args.frame_skip), "fps": fps})
+            wandb.log(log_info, step=network_updates)
             collector.update_policy_weights_()
 
-
 def get_args():
     """Reads conf.yaml file in the same directory"""
diff --git a/sac_mujoco/sac.py b/sac_mujoco/sac.py
index 2cf2e4d..02c1c85 100644
--- a/sac_mujoco/sac.py
+++ b/sac_mujoco/sac.py
@@ -243,7 +243,7 @@ def main():
     )
     actor = ProbabilisticActor(
         spec=action_spec,
-        dist_in_keys=["loc", "scale"],
+        in_keys=["loc", "scale"],
         module=actor_module,
         distribution_class=dist_class,
         distribution_kwargs=dist_kwargs,
@@ -467,7 +467,7 @@ def get_args():
     )
     parser.add_argument(
         "--from_pixels",
-        action=argparse.BooleanOptionalAction,
+        action="store_true",  # argparse.BooleanOptionalAction,
         default=False,
         help="Use pixel observations. Default: False",
     )
@@ -500,7 +500,7 @@ def get_args():
     )
     parser.add_argument(
         "--prb",
-        action=argparse.BooleanOptionalAction,
+        action="store_true",  # argparse.BooleanOptionalAction,
         default=False,
        help="Use Prioritized Experience Replay Buffer. Default: False",
     )
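
Side note on the new PPO mini-batching (not part of the patch): after batch.reshape(-1) flattens the time and env dimensions, BatchSampler(SubsetRandomSampler(...)) draws random index chunks from the flattened batch. The sketch below reproduces only that sampling pattern with plain PyTorch tensors; fake_batch, num_envs, steps_per_env and mini_batch_size are illustrative placeholders, not names taken from the patch.

import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

num_envs, steps_per_env, mini_batch_size = 16, 256, 128      # values from conf.yaml
fake_batch = torch.randn(steps_per_env, num_envs, 4)         # stand-in for the collected batch

# Flatten time and env dimensions, as batch.reshape(-1) does for the TensorDict in the patch
batch_view = fake_batch.reshape(-1, fake_batch.shape[-1])
batch_size = batch_view.shape[0]                             # 16 * 256 = 4096 transitions

for mini_batch_idxs in BatchSampler(
        SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=True):
    mini_batch = batch_view[mini_batch_idxs]                 # 128 random transitions per update
# -> 4096 / 128 = 32 mini-batches (i.e. network updates) per PPO epoch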