From b3068c99bbb327264ae87b2aae71af2372ace6ec Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Wed, 11 Jan 2023 15:34:33 -0600 Subject: [PATCH 01/58] Added sac codebase. Works independently. --- scripts/sac_mujoco/config/group/group1.yaml | 4 + scripts/sac_mujoco/config/group/group2.yaml | 4 + .../config/hydra/launcher/local.yaml | 11 + .../config/hydra/launcher/slurm.yaml | 14 + .../sac_mujoco/config/hydra/output/local.yaml | 8 + .../sac_mujoco/config/hydra/output/slurm.yaml | 8 + scripts/sac_mujoco/config/sac.yaml | 40 ++ scripts/sac_mujoco/sac.py | 485 ++++++++++++++++++ scripts/sac_mujoco/sac_loss.py | 311 +++++++++++ 9 files changed, 885 insertions(+) create mode 100644 scripts/sac_mujoco/config/group/group1.yaml create mode 100644 scripts/sac_mujoco/config/group/group2.yaml create mode 100644 scripts/sac_mujoco/config/hydra/launcher/local.yaml create mode 100644 scripts/sac_mujoco/config/hydra/launcher/slurm.yaml create mode 100644 scripts/sac_mujoco/config/hydra/output/local.yaml create mode 100644 scripts/sac_mujoco/config/hydra/output/slurm.yaml create mode 100644 scripts/sac_mujoco/config/sac.yaml create mode 100644 scripts/sac_mujoco/sac.py create mode 100644 scripts/sac_mujoco/sac_loss.py diff --git a/scripts/sac_mujoco/config/group/group1.yaml b/scripts/sac_mujoco/config/group/group1.yaml new file mode 100644 index 000000000..6730093ca --- /dev/null +++ b/scripts/sac_mujoco/config/group/group1.yaml @@ -0,0 +1,4 @@ +# @package _group_ + grp1a: 11 + grp1b: aaa + gra1c: $group_seed{group.seed}_exp_seed{exp.seed} \ No newline at end of file diff --git a/scripts/sac_mujoco/config/group/group2.yaml b/scripts/sac_mujoco/config/group/group2.yaml new file mode 100644 index 000000000..b2f47d6dd --- /dev/null +++ b/scripts/sac_mujoco/config/group/group2.yaml @@ -0,0 +1,4 @@ +# @package _group_ + grp2a: 22 + grp2b: bbb + gra2c: $group_seed{group.seed}_exp_seed{exp.seed} \ No newline at end of file diff --git a/scripts/sac_mujoco/config/hydra/launcher/local.yaml b/scripts/sac_mujoco/config/hydra/launcher/local.yaml new file mode 100644 index 000000000..30a563c70 --- /dev/null +++ b/scripts/sac_mujoco/config/hydra/launcher/local.yaml @@ -0,0 +1,11 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: 12 + gpus_per_node: 1 + tasks_per_node: 1 + timeout_min: 4320 + mem_gb: 64 + name: ${hydra.job.name} + _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher + submitit_folder: ${hydra.sweep.dir}/.submitit/%j diff --git a/scripts/sac_mujoco/config/hydra/launcher/slurm.yaml b/scripts/sac_mujoco/config/hydra/launcher/slurm.yaml new file mode 100644 index 000000000..e9b58047c --- /dev/null +++ b/scripts/sac_mujoco/config/hydra/launcher/slurm.yaml @@ -0,0 +1,14 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: 16 + gpus_per_node: 1 + tasks_per_node: 1 + timeout_min: 4320 + mem_gb: 64 + name: ${hydra.job.name} + # partition: devlab + # array_parallelism: 256 + _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + submitit_folder: ${hydra.sweep.dir}/.submitit/%j + partition: dgx diff --git a/scripts/sac_mujoco/config/hydra/output/local.yaml b/scripts/sac_mujoco/config/hydra/output/local.yaml new file mode 100644 index 000000000..aee5a513f --- /dev/null +++ b/scripts/sac_mujoco/config/hydra/output/local.yaml @@ -0,0 +1,8 @@ +# @package _global_ +hydra: + run: + dir: ./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num}_${hydra.job.override_dirname} + sweep: + dir: 
./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num}_${hydra.job.override_dirname} \ No newline at end of file diff --git a/scripts/sac_mujoco/config/hydra/output/slurm.yaml b/scripts/sac_mujoco/config/hydra/output/slurm.yaml new file mode 100644 index 000000000..8b7afb41a --- /dev/null +++ b/scripts/sac_mujoco/config/hydra/output/slurm.yaml @@ -0,0 +1,8 @@ +# @package _global_ +hydra: + run: + dir: /scratch/cluster/rutavms/robohive/outputs_sac_robohive/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num}_${hydra.job.override_dirname} + sweep: + dir: /scratch/cluster/rutavms/robohive/outputs_sac_robohive/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num}_${hydra.job.override_dirname} diff --git a/scripts/sac_mujoco/config/sac.yaml b/scripts/sac_mujoco/config/sac.yaml new file mode 100644 index 000000000..e367c09ba --- /dev/null +++ b/scripts/sac_mujoco/config/sac.yaml @@ -0,0 +1,40 @@ +default: + - override hydra/output: local + - override hydra/launcher: local + +# Logger +exp_name: sac +record_interval: 1 +device: "cuda:0" + +# Environment +task: visual_franka_slide_random-v3 +frame_skip: 2 +from_pixels: true +reward_scaling: 5.0 +init_env_steps: 1000 +seed: 42 + +# Collector +env_per_collector: 1 +max_frames_per_traj: -1 +total_frames: 1000000 +init_random_frames: 25000 +frames_per_batch: 1000 + +# Replay Buffer +prb: 0 +buffer_size: 100000 +buffer_scratch_dir: /tmp/ + +# Optimization +gamma: 0.99 +batch_size: 256 +lr: 3.0e-4 +weight_decay: 0.0 +target_update_polyak: 0.995 +utd_ratio: 1 + +hydra: + job: + name: sac_${task}_${seed} diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py new file mode 100644 index 000000000..3cd058c93 --- /dev/null +++ b/scripts/sac_mujoco/sac.py @@ -0,0 +1,485 @@ +# Make all the necessary imports for training + + +import os +import gc +import argparse +import yaml +from typing import Optional + +import numpy as np +import torch +import torch.cuda +import tqdm + +import hydra +from omegaconf import DictConfig, OmegaConf, open_dict +import wandb +#from torchrl.objectives import SACLoss +from sac_loss import SACLoss + +from torch import nn, optim +from torchrl.collectors import MultiaSyncDataCollector +from torchrl.collectors.collectors import RandomPolicy +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer + +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + ObservationNorm, + ParallelEnv, +) +from torchrl.envs import EnvCreator +from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose +from torchrl.envs.utils import set_exploration_mode +from torchrl.modules import MLP, NormalParamWrapper, ProbabilisticActor, SafeModule +from torchrl.modules.distributions import TanhNormal + +from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + +from torchrl.objectives import SoftUpdate +from torchrl.trainers import Recorder + +from rlhive.rl_envs import RoboHiveEnv +from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform + +os.environ['WANDB_MODE'] = 'offline' ## offline sync. 
TODO: Remove this behavior + +def make_env(): + """ + Create a base env + """ + env_args = (args.task,) + env_library = GymEnv + + env_kwargs = { + "device": device, + "frame_skip": args.frame_skip, + "from_pixels": args.from_pixels, + "pixels_only": args.from_pixels, + } + env = env_library(*env_args, **env_kwargs) + + env_name = args.task + base_env = RoboHiveEnv(env_name, device=device) + env = TransformedEnv(base_env, R3MTransform('resnet50', in_keys=["pixels"], download=True)) + assert env.device == device + + return env + + +def make_transformed_env( + env, + stats=None, +): + """ + Apply transforms to the env (such as reward scaling and state normalization) + """ + env = TransformedEnv(env, Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + env.append_transform(RewardScaling(loc=0.0, scale=5.0)) + selected_keys = list(env.observation_spec.keys()) + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + + # we normalize the states + if stats is None: + _stats = {"loc": 0.0, "scale": 1.0} + else: + _stats = stats + env.append_transform( + ObservationNorm(**_stats, in_keys=[out_key], standard_normal=True) + ) + env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) + return env + + +def parallel_env_constructor( + stats, + num_worker=1, + **env_kwargs, +): + if num_worker == 1: + env_creator = EnvCreator( + lambda: make_transformed_env(make_env(), stats, **env_kwargs) + ) + return env_creator + + parallel_env = ParallelEnv( + num_workers=num_worker, + create_env_fn=EnvCreator(lambda: make_env()), + create_env_kwargs=None, + pin_memory=False, + ) + env = make_transformed_env(parallel_env, stats, **env_kwargs) + return env + + +def get_stats_random_rollout(proof_environment, key: Optional[str] = None): + print("computing state stats") + n = 0 + td_stats = [] + while n < args.init_env_steps: + _td_stats = proof_environment.rollout(max_steps=args.init_env_steps) + n += _td_stats.numel() + _td_stats_select = _td_stats.to_tensordict().select(key).cpu() + if not len(list(_td_stats_select.keys())): + raise RuntimeError( + f"key {key} not found in tensordict with keys {list(_td_stats.keys())}" + ) + td_stats.append(_td_stats_select) + del _td_stats, _td_stats_select + td_stats = torch.cat(td_stats, 0) + + m = td_stats.get(key).mean(dim=0) + s = td_stats.get(key).std(dim=0) + m[s == 0] = 0.0 + s[s == 0] = 1.0 + + print( + f"stats computed for {td_stats.numel()} steps. 
Got: \n" + f"loc = {m}, \n" + f"scale: {s}" + ) + if not torch.isfinite(m).all(): + raise RuntimeError("non-finite values found in mean") + if not torch.isfinite(s).all(): + raise RuntimeError("non-finite values found in sd") + stats = {"loc": m, "scale": s} + return stats + + +def get_env_stats(): + """ + Gets the stats of an environment + """ + proof_env = make_transformed_env(make_env(), None) + proof_env.set_seed(args.seed) + stats = get_stats_random_rollout( + proof_env, + key="observation_vector", + ) + # make sure proof_env is closed + proof_env.close() + return stats + + +def make_recorder( + task: str, + frame_skip: int, + record_interval: int, + actor_model_explore: object, + device: torch.device + ): + _base_env = RoboHiveEnv(task, device=device) # TODO: Move this to make_env() function + test_env = make_transformed_env(_base_env) + recorder_obj = Recorder( + record_frames=1000, + frame_skip=frame_skip, + policy_exploration=actor_model_explore, + recorder=test_env, + exploration_mode="mean", + record_interval=record_interval, + ) + return recorder_obj + + +def make_replay_buffer( + prb: bool, + buffer_size: int, + buffer_scratch_dir: str, + device: torch.device, + make_replay_buffer: int = 3 + ): + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=make_replay_buffer, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + device=device, + ), + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=make_replay_buffer, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + device=device, + ), + ) + return replay_buffer + + + +@hydra.main(config_name="sac.yaml", config_path="config") +def main(args: DictConfig): + device = ( + torch.device("cuda:0") + if torch.cuda.is_available() + and torch.cuda.device_count() > 0 + and args.device == "cuda:0" + else torch.device("cpu") + ) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + # Create Environment + base_env = RoboHiveEnv(args.task, device=args.device) # TODO: Move this to make_env() function + train_env = make_transformed_env(base_env) + + # Create Agent + + # Define Actor Network + in_keys = ["observation_vector"] + action_spec = train_env.action_spec + actor_net_kwargs = { + "num_cells": [256, 256], + "out_features": 2 * action_spec.shape[-1], + "activation_class": nn.ReLU, + } + + actor_net = MLP(**actor_net_kwargs) + + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": False, + } + actor_net = NormalParamWrapper( + actor_net, + scale_mapping=f"biased_softplus_{1.0}", + scale_lb=0.1, + ) + in_keys_actor = in_keys + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "loc", + "scale", + ], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": nn.ReLU, + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # add forward pass for initialization with proof env + _base_env = RoboHiveEnv(args.task, device=args.device) # TODO: move this to 
make_env + proof_env = make_transformed_env(_base_env) + # init nets + with torch.no_grad(), set_exploration_mode("random"): + td = proof_env.reset() + td = td.to(device) + #print(td[in_keys[0]].shape) + for net in model: + net(td) + del td + proof_env.close() + + actor_model_explore = model[0] + + # Create SAC loss + loss_module = SACLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + gamma=args.gamma, + loss_function="smooth_l1", + ) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, args.target_update_polyak) + + # Make Off-Policy Collector + + collector = MultiaSyncDataCollector( + create_env_fn=[train_env], + policy=actor_model_explore, + total_frames=args.total_frames, + max_frames_per_traj=args.frames_per_batch, + frames_per_batch=args.env_per_collector * args.frames_per_batch, + init_random_frames=args.init_random_frames, + reset_at_each_iter=False, + postproc=None, + split_trajs=True, + devices=[device], # device for execution + passing_devices=[device], # device where data will be stored and passed + seed=None, + pin_memory=False, + update_at_each_batch=False, + exploration_mode="random", + ) + collector.set_seed(args.seed) + + # Make Replay Buffer + replay_buffer = make_replay_buffer( + prb=args.prb, + buffer_size=args.buffer_size, + buffer_scratch_dir=args.buffer_scratch_dir, + device=device, + ) + + # Trajectory recorder for evaluation + recorder = make_recorder( + task=args.task, + frame_skip=args.frame_skip, + record_interval=args.record_interval, + actor_model_explore=actor_model_explore, + device=device + ) + + # Optimizers + params = list(loss_module.parameters()) + list([loss_module.log_alpha]) + optimizer_actor = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) + + rewards = [] + rewards_eval = [] + + # Main loop + target_net_updater.init_() + + collected_frames = 0 + episodes = 0 + pbar = tqdm.tqdm(total=args.total_frames) + r0 = None + loss = None + + with wandb.init(project="SAC_TorchRL", name=args.exp_name, config=args): + for i, tensordict in enumerate(collector): + + # update weights of the inference policy + collector.update_policy_weights_() + + if r0 is None: + r0 = tensordict["reward"].sum(-1).mean().item() + pbar.update(tensordict.numel()) + + # extend the replay buffer with the new data + if "mask" in tensordict.keys(): + # if multi-step, a mask is present to help filter padded values + current_frames = tensordict["mask"].sum() + tensordict = tensordict[tensordict.get("mask").squeeze(-1)] + else: + tensordict = tensordict.view(-1) + current_frames = tensordict.numel() + collected_frames += current_frames + episodes += args.env_per_collector + replay_buffer.extend(tensordict.cpu()) + + # optimization steps + if collected_frames >= args.init_random_frames: + ( + total_losses, + actor_losses, + q_losses, + alpha_losses, + alphas, + entropies, + ) = ([], [], [], [], [], []) + for _ in range( + args.env_per_collector * args.frames_per_batch * args.utd_ratio + ): + # sample from replay buffer + sampled_tensordict = replay_buffer.sample(args.batch_size).clone() + + loss_td = loss_module(sampled_tensordict) + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_qvalue"] + alpha_loss = loss_td["loss_alpha"] + + loss = actor_loss + q_loss + alpha_loss + optimizer_actor.zero_grad() + loss.backward() + optimizer_actor.step() + + # update qnet_target params + target_net_updater.step() + + # update priority + if args.prb: + replay_buffer.update_priority(sampled_tensordict) + + 
total_losses.append(loss.item()) + actor_losses.append(actor_loss.item()) + q_losses.append(q_loss.item()) + alpha_losses.append(alpha_loss.item()) + alphas.append(loss_td["alpha"].item()) + entropies.append(loss_td["entropy"].item()) + + rewards.append( + (i, tensordict["reward"].sum().item() / args.env_per_collector) + ) + wandb.log( + { + "train_reward": rewards[-1][1], + "collected_frames": collected_frames, + "episodes": episodes, + } + ) + if loss is not None: + wandb.log( + { + "total_loss": np.mean(total_losses), + "actor_loss": np.mean(actor_losses), + "q_loss": np.mean(q_losses), + "alpha_loss": np.mean(alpha_losses), + "alpha": np.mean(alphas), + "entropy": np.mean(entropies), + } + ) + td_record = recorder(None) + if td_record is not None: + rewards_eval.append( + ( + i, + td_record["total_r_evaluation"] + / 1, # divide by number of eval worker + ) + ) + wandb.log({"test_reward": rewards_eval[-1][1]}) + if len(rewards_eval): + pbar.set_description( + f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" + ) + del tensordict + gc.collect() + + collector.shutdown() + +if __name__ == "__main__": + main() diff --git a/scripts/sac_mujoco/sac_loss.py b/scripts/sac_mujoco/sac_loss.py new file mode 100644 index 000000000..cebe7f2e9 --- /dev/null +++ b/scripts/sac_mujoco/sac_loss.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from numbers import Number +from typing import Union + +import numpy as np +import torch + +from tensordict.nn import TensorDictSequential +from tensordict.tensordict import TensorDict, TensorDictBase +from torch import Tensor + +from torchrl.envs.utils import set_exploration_mode, step_mdp +from torchrl.modules import SafeModule +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import ( + distance_loss, + next_state_value as get_next_state_value, +) + +try: + from functorch import vmap + + FUNCTORCH_ERR = "" + _has_functorch = True +except ImportError as err: + FUNCTORCH_ERR = str(err) + _has_functorch = False + + +class SACLoss(LossModule): + """SAC Loss module. + Args: + actor_network (SafeModule): the actor to be trained + qvalue_network (SafeModule): a single Q-value network that will be multiplicated as many times as needed. + num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. + gamma (Number, optional): gamma decay factor. Default is 0.99. + priotity_key (str, optional): Key where to write the priority value for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", + "l1", Default is "smooth_l1". + alpha_init (float, optional): initial entropy multiplier. + Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is 0.1. + max_alpha (float, optional): max value of alpha. + Default is 10.0. + fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is :obj:`False`. + target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". + delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used + for data collection. Default is :obj:`False`. 
+ gSDE (bool, optional): Knowing if gSDE is used is necessary to create random noise variables. + Default is False + """ + + delay_actor: bool = False + + def __init__( + self, + actor_network: SafeModule, + qvalue_network: SafeModule, + num_qvalue_nets: int = 2, + gamma: Number = 0.99, + priotity_key: str = "td_error", + loss_function: str = "smooth_l1", + alpha_init: float = 1.0, + min_alpha: float = 0.1, + max_alpha: float = 10.0, + fixed_alpha: bool = False, + target_entropy: Union[str, Number] = "auto", + delay_qvalue: bool = True, + gSDE: bool = False, + ): + if not _has_functorch: + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" + ) + + super().__init__() + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + funs_to_decorate=["forward", "get_dist_params"], + ) + + # let's make sure that actor_network has `return_log_prob` set to True + self.actor_network.return_log_prob = True + + self.delay_qvalue = delay_qvalue + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + compare_against=list(actor_network.parameters()), + ) + self.num_qvalue_nets = num_qvalue_nets + self.register_buffer("gamma", torch.tensor(gamma)) + self.priority_key = priotity_key + self.loss_function = loss_function + + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + if target_entropy == "auto": + if actor_network.spec["action"] is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." 
+ ) + target_entropy = -float(np.prod(actor_network.spec["action"].shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + self.gSDE = gSDE + + @property + def alpha(self): + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + obs_keys = self.actor_network.in_keys + tensordict_select = tensordict.select( + "reward", "done", "next", *obs_keys, "action" + ) + + actor_params = torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) + + tensordict_actor_grad = tensordict_select.select( + *obs_keys + ) # to avoid overwriting keys + next_td_actor = step_mdp(tensordict_select).select( + *self.actor_network.in_keys + ) # next_observation -> + tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) + tensordict_actor = tensordict_actor.contiguous() + + with set_exploration_mode("random"): + if self.gSDE: + tensordict_actor.set( + "_eps_gSDE", + torch.zeros(tensordict_actor.shape, device=tensordict_actor.device), + ) + # vmap doesn't support sampling, so we take it out from the vmap + td_params = vmap(self.actor_network.get_dist_params)( + tensordict_actor, + actor_params, + ) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict_actor[sample_key] = tensordict_actor_dist.rsample() + tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict_actor[sample_key] + ) + + # repeat tensordict_actor to match the qvalue size + _actor_loss_td = ( + tensordict_actor[0] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) + ) # for actor loss + _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, + *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, + ) # for qvalue loss + _next_val_td = ( + tensordict_actor[1] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) + ) # for next value estimation + tensordict_qval = torch.cat( + [ + _actor_loss_td, + _next_val_td, + _qval_td, + ], + 0, + ) + + # cat params + q_params_detach = self.qvalue_network_params.detach() + qvalue_params = torch.cat( + [ + q_params_detach, + self.target_qvalue_network_params, + self.qvalue_network_params, + ], + 0, + ) + tensordict_qval = vmap(self.qvalue_network)( + tensordict_qval, + qvalue_params, + ) + + state_action_value = tensordict_qval.get("state_action_value").squeeze(-1) + ( + state_action_value_actor, + next_state_action_value_qvalue, + state_action_value_qvalue, + ) = state_action_value.split( + [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], + dim=0, + ) + sample_log_prob = tensordict_actor.get("sample_log_prob").squeeze(-1) + ( + action_log_prob_actor, + next_action_log_prob_qvalue, + ) = sample_log_prob.unbind(0) + + loss_actor = -( + state_action_value_actor.min(0)[0] - self.alpha * action_log_prob_actor + ).mean() + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + target_value = get_next_state_value( + tensordict, + 
gamma=self.gamma, + pred_next_val=next_state_value, + ) + pred_val = state_action_value_qvalue + td_error = (pred_val - target_value).pow(2) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + + tensordict.set("td_error", td_error.detach().max(0)[0]) + + loss_alpha = self._loss_alpha(sample_log_prob) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + "state_action_value_actor": state_action_value_actor.mean().detach(), + "action_log_prob_actor": action_log_prob_actor.mean().detach(), + "next.state_value": next_state_value.mean().detach(), + "target_value": target_value.mean().detach(), + }, + [], + ) + + return td_out + + def _loss_alpha(self, log_pi: Tensor) -> Tensor: + if torch.is_grad_enabled() and not log_pi.requires_grad: + raise RuntimeError( + "expected log_pi to require gradient for the alpha loss)" + ) + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss From f37ac892f9985268eb687aa1daa6b3497d0f5f85 Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Wed, 11 Jan 2023 16:25:16 -0600 Subject: [PATCH 02/58] Added small test codebase. --- scripts/sac_mujoco/test.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 scripts/sac_mujoco/test.py diff --git a/scripts/sac_mujoco/test.py b/scripts/sac_mujoco/test.py new file mode 100644 index 000000000..01b0c812e --- /dev/null +++ b/scripts/sac_mujoco/test.py @@ -0,0 +1,33 @@ +import torch +from rlhive.rl_envs import RoboHiveEnv +from torchrl.envs.utils import set_exploration_mode +from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose +from torchrl.envs import TransformedEnv, R3MTransform +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + ObservationNorm, +) + +def make_transformed_env( + env, + stats=None, +): + env = TransformedEnv(env, Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) + return env + +from torchrl.envs import default_info_dict_reader +reader = default_info_dict_reader(["solved"]) +base_env = RoboHiveEnv("visual_franka_slide_random-v3", device=torch.device('cuda:0')) +env = make_transformed_env(base_env) +env = env.set_info_dict_reader(info_dict_reader=reader) +with torch.no_grad(), set_exploration_mode("random"): + td = env.reset() + td = env.rand_step() + print(td) + #print(td['observation_vector'].shape) + #print(td['r3m_vec'].shape) + #print(env) + #print(td) + From 6c03e9c82ba29bca6446eba3d29c4083f2517b3c Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Thu, 12 Jan 2023 22:27:21 -0600 Subject: [PATCH 03/58] test.py updated with another bug --- scripts/sac_mujoco/test.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/scripts/sac_mujoco/test.py b/scripts/sac_mujoco/test.py index 01b0c812e..9b52574a5 100644 --- a/scripts/sac_mujoco/test.py +++ b/scripts/sac_mujoco/test.py @@ -17,17 +17,11 @@ def 
make_transformed_env( env = TransformedEnv(env, Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) return env -from torchrl.envs import default_info_dict_reader -reader = default_info_dict_reader(["solved"]) base_env = RoboHiveEnv("visual_franka_slide_random-v3", device=torch.device('cuda:0')) +env = base_env env = make_transformed_env(base_env) -env = env.set_info_dict_reader(info_dict_reader=reader) +print(env) with torch.no_grad(), set_exploration_mode("random"): td = env.reset() td = env.rand_step() print(td) - #print(td['observation_vector'].shape) - #print(td['r3m_vec'].shape) - #print(env) - #print(td) - From 50ae2e04feb912b5a03374c6511123aa048233c0 Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Fri, 13 Jan 2023 12:59:59 -0600 Subject: [PATCH 04/58] small change with updated torchrl --- scripts/sac_mujoco/sac.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py index 3cd058c93..d4266d4da 100644 --- a/scripts/sac_mujoco/sac.py +++ b/scripts/sac_mujoco/sac.py @@ -81,7 +81,8 @@ def make_transformed_env( """ env = TransformedEnv(env, Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 env.append_transform(RewardScaling(loc=0.0, scale=5.0)) - selected_keys = list(env.observation_spec.keys()) + #selected_keys = list(env.observation_spec.keys()) + selected_keys = ["r3m_vec", "observation"] out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) From f2d9b4305d7dc0c7bc735cf296ec159abc8b3bc8 Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Fri, 13 Jan 2023 20:52:14 -0600 Subject: [PATCH 05/58] working sac codebase. cleanup --- scripts/sac_mujoco/config/sac.yaml | 4 +- scripts/sac_mujoco/sac.py | 152 +++++++++-------------------- 2 files changed, 50 insertions(+), 106 deletions(-) diff --git a/scripts/sac_mujoco/config/sac.yaml b/scripts/sac_mujoco/config/sac.yaml index e367c09ba..1f8e904a0 100644 --- a/scripts/sac_mujoco/config/sac.yaml +++ b/scripts/sac_mujoco/config/sac.yaml @@ -9,11 +9,11 @@ device: "cuda:0" # Environment task: visual_franka_slide_random-v3 -frame_skip: 2 -from_pixels: true +frame_skip: 1 reward_scaling: 5.0 init_env_steps: 1000 seed: 42 +eval_traj: 25 # Collector env_per_collector: 1 diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py index d4266d4da..ebdbc959d 100644 --- a/scripts/sac_mujoco/sac.py +++ b/scripts/sac_mujoco/sac.py @@ -45,47 +45,37 @@ from torchrl.trainers import Recorder from rlhive.rl_envs import RoboHiveEnv -from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform +from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform, SelectTransform os.environ['WANDB_MODE'] = 'offline' ## offline sync. 
TODO: Remove this behavior -def make_env(): - """ - Create a base env - """ - env_args = (args.task,) - env_library = GymEnv - - env_kwargs = { - "device": device, - "frame_skip": args.frame_skip, - "from_pixels": args.from_pixels, - "pixels_only": args.from_pixels, - } - env = env_library(*env_args, **env_kwargs) - - env_name = args.task - base_env = RoboHiveEnv(env_name, device=device) - env = TransformedEnv(base_env, R3MTransform('resnet50', in_keys=["pixels"], download=True)) - assert env.device == device +def make_env( + task, + reward_scaling, + device + ): + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env(env=base_env, reward_scaling=reward_scaling) return env def make_transformed_env( env, + reward_scaling=5.0, stats=None, ): """ Apply transforms to the env (such as reward scaling and state normalization) """ - env = TransformedEnv(env, Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 - env.append_transform(RewardScaling(loc=0.0, scale=5.0)) - #selected_keys = list(env.observation_spec.keys()) + env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) + env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) selected_keys = ["r3m_vec", "observation"] out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + # we normalize the states if stats is None: _stats = {"loc": 0.0, "scale": 1.0} @@ -97,93 +87,24 @@ def make_transformed_env( env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) return env - -def parallel_env_constructor( - stats, - num_worker=1, - **env_kwargs, -): - if num_worker == 1: - env_creator = EnvCreator( - lambda: make_transformed_env(make_env(), stats, **env_kwargs) - ) - return env_creator - - parallel_env = ParallelEnv( - num_workers=num_worker, - create_env_fn=EnvCreator(lambda: make_env()), - create_env_kwargs=None, - pin_memory=False, - ) - env = make_transformed_env(parallel_env, stats, **env_kwargs) - return env - - -def get_stats_random_rollout(proof_environment, key: Optional[str] = None): - print("computing state stats") - n = 0 - td_stats = [] - while n < args.init_env_steps: - _td_stats = proof_environment.rollout(max_steps=args.init_env_steps) - n += _td_stats.numel() - _td_stats_select = _td_stats.to_tensordict().select(key).cpu() - if not len(list(_td_stats_select.keys())): - raise RuntimeError( - f"key {key} not found in tensordict with keys {list(_td_stats.keys())}" - ) - td_stats.append(_td_stats_select) - del _td_stats, _td_stats_select - td_stats = torch.cat(td_stats, 0) - - m = td_stats.get(key).mean(dim=0) - s = td_stats.get(key).std(dim=0) - m[s == 0] = 0.0 - s[s == 0] = 1.0 - - print( - f"stats computed for {td_stats.numel()} steps. 
Got: \n" - f"loc = {m}, \n" - f"scale: {s}" - ) - if not torch.isfinite(m).all(): - raise RuntimeError("non-finite values found in mean") - if not torch.isfinite(s).all(): - raise RuntimeError("non-finite values found in sd") - stats = {"loc": m, "scale": s} - return stats - - -def get_env_stats(): - """ - Gets the stats of an environment - """ - proof_env = make_transformed_env(make_env(), None) - proof_env.set_seed(args.seed) - stats = get_stats_random_rollout( - proof_env, - key="observation_vector", - ) - # make sure proof_env is closed - proof_env.close() - return stats - - def make_recorder( task: str, frame_skip: int, record_interval: int, actor_model_explore: object, - device: torch.device + eval_traj: int, + env_configs: dict, ): - _base_env = RoboHiveEnv(task, device=device) # TODO: Move this to make_env() function - test_env = make_transformed_env(_base_env) + test_env = make_env(task=task, **env_configs) recorder_obj = Recorder( - record_frames=1000, + record_frames=eval_traj*test_env.horizon, frame_skip=frame_skip, policy_exploration=actor_model_explore, recorder=test_env, exploration_mode="mean", record_interval=record_interval, + log_keys=["reward", "solved"], + out_keys={"reward": "r_evaluation", "solved" : "success"} ) return recorder_obj @@ -220,6 +141,20 @@ def make_replay_buffer( return replay_buffer +def evaluate_success( + env_success_fn, + td_record: dict, + eval_traj: int + ): + td_record["success"] = td_record["success"].reshape((eval_traj, -1)) + paths = [] + for traj, solved_traj in zip(range(eval_traj), td_record["success"]): + path = {"env_infos": {"solved": solved_traj.data.cpu().numpy()}} + paths.append(path) + success_percentage = env_success_fn(paths) + return success_percentage + + @hydra.main(config_name="sac.yaml", config_path="config") def main(args: DictConfig): @@ -234,8 +169,11 @@ def main(args: DictConfig): np.random.seed(args.seed) # Create Environment - base_env = RoboHiveEnv(args.task, device=args.device) # TODO: Move this to make_env() function - train_env = make_transformed_env(base_env) + env_configs = { + "reward_scaling": args.reward_scaling, + "device": args.device, + } + train_env = make_env(task=args.task, **env_configs) # Create Agent @@ -299,13 +237,11 @@ def main(args: DictConfig): model = nn.ModuleList([actor, qvalue]).to(device) # add forward pass for initialization with proof env - _base_env = RoboHiveEnv(args.task, device=args.device) # TODO: move this to make_env - proof_env = make_transformed_env(_base_env) + proof_env = make_env(task=args.task, **env_configs) # init nets with torch.no_grad(), set_exploration_mode("random"): td = proof_env.reset() td = td.to(device) - #print(td[in_keys[0]].shape) for net in model: net(td) del td @@ -354,13 +290,15 @@ def main(args: DictConfig): device=device, ) + # Trajectory recorder for evaluation recorder = make_recorder( task=args.task, frame_skip=args.frame_skip, record_interval=args.record_interval, actor_model_explore=actor_model_explore, - device=device + eval_traj=args.eval_traj, + env_configs=env_configs, ) # Optimizers @@ -464,6 +402,11 @@ def main(args: DictConfig): } ) td_record = recorder(None) + success_percentage = evaluate_success( + env_success_fn=train_env.evaluate_success, + td_record=td_record, + eval_traj=args.eval_traj + ) if td_record is not None: rewards_eval.append( ( @@ -473,6 +416,7 @@ def main(args: DictConfig): ) ) wandb.log({"test_reward": rewards_eval[-1][1]}) + wandb.log({"success": success_percentage}) if len(rewards_eval): pbar.set_description( f"reward: 
{rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" From b5766827581fcc599bf4951008ae720e00b1798e Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Sun, 15 Jan 2023 14:01:43 -0600 Subject: [PATCH 06/58] added installation script. sac configs correct --- scripts/installation.sh | 14 ++++++++++++++ scripts/sac_mujoco/config/sac.yaml | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 scripts/installation.sh diff --git a/scripts/installation.sh b/scripts/installation.sh new file mode 100644 index 000000000..aa89d5275 --- /dev/null +++ b/scripts/installation.sh @@ -0,0 +1,14 @@ +export MJENV_LIB_PATH="mj_envs" + +python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 + +here=$(pwd) +git clone -c submodule.mj_envs/sims/neuromuscular_sim.update=none --branch add_all_xmls --recursive https://github.com/vmoens/mj_envs.git $MJENV_LIB_PATH +cd $MJENV_LIB_PATH +python3 -mpip install . # one can also install it locally with the -e flag +cd $here + +python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) +python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) +pip install wandb +pip install hydra-submitit-launcher --upgrade diff --git a/scripts/sac_mujoco/config/sac.yaml b/scripts/sac_mujoco/config/sac.yaml index 1f8e904a0..b45185234 100644 --- a/scripts/sac_mujoco/config/sac.yaml +++ b/scripts/sac_mujoco/config/sac.yaml @@ -3,7 +3,7 @@ default: - override hydra/launcher: local # Logger -exp_name: sac +exp_name: ${task}_sac_r3m record_interval: 1 device: "cuda:0" From 2f07d0c4e74c76e3381fb3393bf90b76b63168fa Mon Sep 17 00:00:00 2001 From: ShahRutav <43668417+ShahRutav@users.noreply.github.com> Date: Sun, 15 Jan 2023 14:37:52 -0600 Subject: [PATCH 07/58] Added a new running instruction for SAC+R3M --- scripts/README.md | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 scripts/README.md diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 000000000..f801eefe5 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,39 @@ +## Installation +``` +git clone --branch=sac_dev https://github.com/facebookresearch/rlhive.git +conda create -n rlhive -y python=3.8 +conda activate rlhive +bash rlhive/scripts/installation.sh +cd rlhive +pip install -e . 
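+# optional quick check (a sketch; assumes the editable install above succeeded)
+python -c "import rlhive"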
+``` + +## Testing installation +``` +python -c "import mj_envs" +MUJOCO_GL=egl sim_backend=MUJOCO python -c """ +from rlhive.rl_envs import RoboHiveEnv +env_name = 'visual_franka_slide_random-v3' +base_env = RoboHiveEnv(env_name,) +print(base_env.rollout(3)) + +# check that the env specs are ok +from torchrl.envs.utils import check_env_specs +check_env_specs(base_env) +""" +``` + +## Launching experiments +[NOTE] Set ulimit for your shell (default 1024): `ulimit -n 4096` +Set your slurm configs especially `partition` and `hydra.run.dir` +Slurm files are located at `sac_mujoco/config/hydra/launcher/slurm.yaml` and `sac_mujoco/config/hydra/output/slurm.yaml` +``` +cd sac_mujoco +sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m hydra/launcher=slurm hydra/output=slurm +``` + +To run a small experiment for testing, run the following command: +``` +cd sac_mujoco +sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m total_frames=2000 init_random_frames=25 buffer_size=2000 hydra/launcher=slurm hydra/output=slurm +``` From e6067c461690922857fe399b4244135e92a511e4 Mon Sep 17 00:00:00 2001 From: ShahRutav <43668417+ShahRutav@users.noreply.github.com> Date: Sun, 15 Jan 2023 14:38:53 -0600 Subject: [PATCH 08/58] Fixed readme --- scripts/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/README.md b/scripts/README.md index f801eefe5..78123b066 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -28,12 +28,12 @@ check_env_specs(base_env) Set your slurm configs especially `partition` and `hydra.run.dir` Slurm files are located at `sac_mujoco/config/hydra/launcher/slurm.yaml` and `sac_mujoco/config/hydra/output/slurm.yaml` ``` -cd sac_mujoco +cd scripts/sac_mujoco sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m hydra/launcher=slurm hydra/output=slurm ``` To run a small experiment for testing, run the following command: ``` -cd sac_mujoco +cd scripts/sac_mujoco sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m total_frames=2000 init_random_frames=25 buffer_size=2000 hydra/launcher=slurm hydra/output=slurm ``` From c6084e893f52faf019f8089831332f4e6e5ac590 Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Sun, 15 Jan 2023 14:47:46 -0600 Subject: [PATCH 09/58] Added redq codebase from torchrl --- scripts/redq/config.yaml | 36 +++++++ scripts/redq/redq.py | 216 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 scripts/redq/config.yaml create mode 100644 scripts/redq/redq.py diff --git a/scripts/redq/config.yaml b/scripts/redq/config.yaml new file mode 100644 index 000000000..e595c3db4 --- /dev/null +++ b/scripts/redq/config.yaml @@ -0,0 +1,36 @@ +env_name: HalfCheetah-v4 +env_task: "" +env_library: gym +async_collection: 1 +record_video: 0 +normalize_rewards_online: 1 +normalize_rewards_online_scale: 5 +frame_skip: 1 +frames_per_batch: 1024 +optim_steps_per_batch: 1024 +batch_size: 256 +total_frames: 1000000 +prb: 1 +lr: 3e-4 +ou_exploration: 1 +multi_step: 1 +init_random_frames: 25000 +activation: elu +gSDE: 0 +from_pixels: 0 +#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1] +collector_devices: [cpu,cpu] +env_per_collector: 1 +num_workers: 2 +lr_scheduler: "" +value_network_update_interval: 200 +record_interval: 10 +max_frames_per_traj: -1 +weight_decay: 0.0 +annealing_frames: 1000000 +init_env_steps: 10000 +record_frames: 10000 +loss_function: smooth_l1 +batch_transform: 1 +buffer_prefetch: 64 +norm_stats: 1 diff --git a/scripts/redq/redq.py b/scripts/redq/redq.py new file mode 100644 index 000000000..16beb8557 --- /dev/null 
+++ b/scripts/redq/redq.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import dataclasses +import uuid +from datetime import datetime + +import hydra +import torch.cuda +from hydra.core.config_store import ConfigStore +from torchrl.envs import EnvCreator, ParallelEnv +from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.envs.utils import set_exploration_mode +from torchrl.modules import OrnsteinUhlenbeckProcessWrapper +from torchrl.record import VideoRecorder +from torchrl.trainers.helpers.collectors import ( + make_collector_offpolicy, + OffPolicyCollectorConfig, +) +from torchrl.trainers.helpers.envs import ( + correct_for_frame_skip, + EnvConfig, + initialize_observation_norm_transforms, + parallel_env_constructor, + retrieve_observation_norms_state_dict, + transformed_env_constructor, +) +from torchrl.trainers.helpers.logger import LoggerConfig +from torchrl.trainers.helpers.losses import LossConfig, make_redq_loss +from torchrl.trainers.helpers.models import make_redq_model, REDQModelConfig +from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig +from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig +from torchrl.trainers.loggers.utils import generate_exp_name, get_logger + +config_fields = [ + (config_field.name, config_field.type, config_field) + for config_cls in ( + TrainerConfig, + OffPolicyCollectorConfig, + EnvConfig, + LossConfig, + REDQModelConfig, + LoggerConfig, + ReplayArgsConfig, + ) + for config_field in dataclasses.fields(config_cls) +] + +Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) +cs = ConfigStore.instance() +cs.store(name="config", node=Config) + +DEFAULT_REWARD_SCALING = { + "Hopper-v1": 5, + "Walker2d-v1": 5, + "HalfCheetah-v1": 5, + "cheetah": 5, + "Ant-v2": 5, + "Humanoid-v2": 20, + "humanoid": 100, +} + + +@hydra.main(version_base=None, config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + + cfg = correct_for_frame_skip(cfg) + + if not isinstance(cfg.reward_scaling, float): + cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env_name, 5.0) + + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda:0") + ) + + exp_name = "_".join( + [ + "REDQ", + cfg.exp_name, + str(uuid.uuid4())[:8], + datetime.now().strftime("%y_%m_%d-%H_%M_%S"), + ] + ) + + exp_name = generate_exp_name("REDQ", cfg.exp_name) + logger = get_logger( + logger_type=cfg.logger, logger_name="redq_logging", experiment_name=exp_name + ) + video_tag = exp_name if cfg.record_video else "" + + key, init_env_steps, stats = None, None, None + if not cfg.vecnorm and cfg.norm_stats: + if not hasattr(cfg, "init_env_steps"): + raise AttributeError("init_env_steps missing from arguments.") + key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector") + init_env_steps = cfg.init_env_steps + stats = {"loc": None, "scale": None} + elif cfg.from_pixels: + stats = {"loc": 0.5, "scale": 0.5} + + proof_env = transformed_env_constructor( + cfg=cfg, + use_env_creator=False, + stats=stats, + )() + initialize_observation_norm_transforms( + proof_environment=proof_env, num_iter=init_env_steps, key=key + ) + _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] + + model = make_redq_model( + proof_env, + cfg=cfg, + device=device, + ) + 
loss_module, target_net_updater = make_redq_loss(model, cfg) + + actor_model_explore = model[0] + if cfg.ou_exploration: + if cfg.gSDE: + raise RuntimeError("gSDE and ou_exploration are incompatible") + actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + actor_model_explore, + annealing_num_steps=cfg.annealing_frames, + sigma=cfg.ou_sigma, + theta=cfg.ou_theta, + ).to(device) + if device == torch.device("cpu"): + # mostly for debugging + actor_model_explore.share_memory() + + if cfg.gSDE: + with torch.no_grad(), set_exploration_mode("random"): + # get dimensions to build the parallel env + proof_td = actor_model_explore(proof_env.reset().to(device)) + action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:] + del proof_td + else: + action_dim_gsde, state_dim_gsde = None, None + + proof_env.close() + create_env_fn = parallel_env_constructor( + cfg=cfg, + obs_norm_state_dict=obs_norm_state_dict, + action_dim_gsde=action_dim_gsde, + state_dim_gsde=state_dim_gsde, + ) + + collector = make_collector_offpolicy( + make_env=create_env_fn, + actor_model_explore=actor_model_explore, + cfg=cfg, + # make_env_kwargs=[ + # {"device": device} if device >= 0 else {} + # for device in args.env_rendering_devices + # ], + ) + + replay_buffer = make_replay_buffer(device, cfg) + + recorder = transformed_env_constructor( + cfg, + video_tag=video_tag, + norm_obs_only=True, + obs_norm_state_dict=obs_norm_state_dict, + logger=logger, + use_env_creator=False, + )() + + # remove video recorder from recorder to have matching state_dict keys + if cfg.record_video: + recorder_rm = TransformedEnv(recorder.base_env) + for transform in recorder.transform: + if not isinstance(transform, VideoRecorder): + recorder_rm.append_transform(transform.clone()) + else: + recorder_rm = recorder + + if isinstance(create_env_fn, ParallelEnv): + recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"]) + create_env_fn.close() + elif isinstance(create_env_fn, EnvCreator): + recorder_rm.load_state_dict(create_env_fn().state_dict()) + else: + recorder_rm.load_state_dict(create_env_fn.state_dict()) + + # reset reward scaling + for t in recorder.transform: + if isinstance(t, RewardScaling): + t.scale.fill_(1.0) + t.loc.fill_(0.0) + + trainer = make_trainer( + collector, + loss_module, + recorder, + target_net_updater, + actor_model_explore, + replay_buffer, + logger, + cfg, + ) + + final_seed = collector.set_seed(cfg.seed) + print(f"init seed: {cfg.seed}, final seed: {final_seed}") + + trainer.train() + return (logger.log_dir, trainer._log_dict) + + +if __name__ == "__main__": + main() From 1f02c30c2d37d182cf180942c0f1945837f13dfa Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Sun, 15 Jan 2023 23:27:21 -0600 Subject: [PATCH 10/58] updated redq script with robohive env --- scripts/redq/config.yaml | 8 ++- scripts/redq/redq.py | 117 +++++++++++++++++++++++++++------------ 2 files changed, 87 insertions(+), 38 deletions(-) diff --git a/scripts/redq/config.yaml b/scripts/redq/config.yaml index e595c3db4..37e0ce8d6 100644 --- a/scripts/redq/config.yaml +++ b/scripts/redq/config.yaml @@ -1,4 +1,5 @@ -env_name: HalfCheetah-v4 +# Environment +env_name: visual_franka_slide_random-v3 env_task: "" env_library: gym async_collection: 1 @@ -6,6 +7,7 @@ record_video: 0 normalize_rewards_online: 1 normalize_rewards_online_scale: 5 frame_skip: 1 +reward_scaling: 5.0 frames_per_batch: 1024 optim_steps_per_batch: 1024 batch_size: 256 @@ -17,9 +19,11 @@ multi_step: 1 init_random_frames: 25000 activation: elu gSDE: 0 +# Internal 
assumption by make_redq, hard codes key values if from_pixels is True from_pixels: 0 #collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1] -collector_devices: [cpu,cpu] +#collector_devices: [cpu,cpu] +collector_devices: [cuda:0] env_per_collector: 1 num_workers: 2 lr_scheduler: "" diff --git a/scripts/redq/redq.py b/scripts/redq/redq.py index 16beb8557..709aac78b 100644 --- a/scripts/redq/redq.py +++ b/scripts/redq/redq.py @@ -25,7 +25,7 @@ initialize_observation_norm_transforms, parallel_env_constructor, retrieve_observation_norms_state_dict, - transformed_env_constructor, + #transformed_env_constructor, ) from torchrl.trainers.helpers.logger import LoggerConfig from torchrl.trainers.helpers.losses import LossConfig, make_redq_loss @@ -33,6 +33,54 @@ from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig from torchrl.trainers.loggers.utils import generate_exp_name, get_logger +from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose + +from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform, SelectTransform +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + ObservationNorm, + ParallelEnv, +) +from rlhive.rl_envs import RoboHiveEnv +def make_env( + task, + reward_scaling, + device + ): + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env(env=base_env, reward_scaling=reward_scaling) + + return env + + +def make_transformed_env( + env, + reward_scaling=5.0, + stats=None, +): + """ + Apply transforms to the env (such as reward scaling and state normalization) + """ + env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) + env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) + selected_keys = ["r3m_vec", "observation"] + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + + + # we normalize the states + if stats is None: + _stats = {"loc": 0.0, "scale": 1.0} + else: + _stats = stats + env.append_transform( + ObservationNorm(**_stats, in_keys=[out_key], standard_normal=True) + ) + env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) + return env config_fields = [ (config_field.name, config_field.type, config_field) @@ -65,27 +113,12 @@ @hydra.main(version_base=None, config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - - cfg = correct_for_frame_skip(cfg) - - if not isinstance(cfg.reward_scaling, float): - cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env_name, 5.0) - device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda:0") ) - exp_name = "_".join( - [ - "REDQ", - cfg.exp_name, - str(uuid.uuid4())[:8], - datetime.now().strftime("%y_%m_%d-%H_%M_%S"), - ] - ) - exp_name = generate_exp_name("REDQ", cfg.exp_name) logger = get_logger( logger_type=cfg.logger, logger_name="redq_logging", experiment_name=exp_name @@ -96,26 +129,28 @@ def main(cfg: "DictConfig"): # noqa: F821 if not cfg.vecnorm and cfg.norm_stats: if not hasattr(cfg, "init_env_steps"): raise AttributeError("init_env_steps missing from arguments.") - key = ("next", "pixels") if cfg.from_pixels else ("next", 
"observation_vector") + key = ("next", "observation_vector") init_env_steps = cfg.init_env_steps stats = {"loc": None, "scale": None} elif cfg.from_pixels: stats = {"loc": 0.5, "scale": 0.5} - proof_env = transformed_env_constructor( - cfg=cfg, - use_env_creator=False, - stats=stats, - )() + proof_env = make_env( + task=cfg.env_name, + reward_scaling=cfg.reward_scaling, + device=device, + ) initialize_observation_norm_transforms( proof_environment=proof_env, num_iter=init_env_steps, key=key ) _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] + print(proof_env) model = make_redq_model( proof_env, cfg=cfg, device=device, + in_keys=["observation_vector"], ) loss_module, target_net_updater = make_redq_loss(model, cfg) @@ -143,12 +178,17 @@ def main(cfg: "DictConfig"): # noqa: F821 action_dim_gsde, state_dim_gsde = None, None proof_env.close() - create_env_fn = parallel_env_constructor( - cfg=cfg, - obs_norm_state_dict=obs_norm_state_dict, - action_dim_gsde=action_dim_gsde, - state_dim_gsde=state_dim_gsde, - ) + #create_env_fn = parallel_env_constructor( + # cfg=cfg, + # obs_norm_state_dict=obs_norm_state_dict, + # action_dim_gsde=action_dim_gsde, + # state_dim_gsde=state_dim_gsde, + #) + create_env_fn = make_env( ## Pass EnvBase instead of the create_env_fn + task=cfg.env_name, + reward_scaling=cfg.reward_scaling, + device=device, + ) collector = make_collector_offpolicy( make_env=create_env_fn, @@ -162,14 +202,19 @@ def main(cfg: "DictConfig"): # noqa: F821 replay_buffer = make_replay_buffer(device, cfg) - recorder = transformed_env_constructor( - cfg, - video_tag=video_tag, - norm_obs_only=True, - obs_norm_state_dict=obs_norm_state_dict, - logger=logger, - use_env_creator=False, - )() + #recorder = transformed_env_constructor( + # cfg, + # video_tag=video_tag, + # norm_obs_only=True, + # obs_norm_state_dict=obs_norm_state_dict, + # logger=logger, + # use_env_creator=False, + #)() + recorder = make_env( + task=cfg.env_name, + reward_scaling=cfg.reward_scaling, + device=device, + ) # remove video recorder from recorder to have matching state_dict keys if cfg.record_video: From fab90848fcd4fb9f55275e24f812d6d09ed35b94 Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Mon, 16 Jan 2023 14:05:43 -0600 Subject: [PATCH 11/58] Added RRLTransform --- scripts/sac_mujoco/config/sac.yaml | 3 +- scripts/sac_mujoco/rrl.py | 324 +++++++++++++++++++++++++++++ scripts/sac_mujoco/sac.py | 20 +- scripts/sac_mujoco/test.py | 54 ++++- 4 files changed, 391 insertions(+), 10 deletions(-) create mode 100644 scripts/sac_mujoco/rrl.py diff --git a/scripts/sac_mujoco/config/sac.yaml b/scripts/sac_mujoco/config/sac.yaml index b45185234..bb914f9ec 100644 --- a/scripts/sac_mujoco/config/sac.yaml +++ b/scripts/sac_mujoco/config/sac.yaml @@ -3,7 +3,8 @@ default: - override hydra/launcher: local # Logger -exp_name: ${task}_sac_r3m +exp_name: ${task}_sac_${visual_transform} +visual_transform: r3m record_interval: 1 device: "cuda:0" diff --git a/scripts/sac_mujoco/rrl.py b/scripts/sac_mujoco/rrl.py new file mode 100644 index 000000000..b9b501485 --- /dev/null +++ b/scripts/sac_mujoco/rrl.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import List, Optional, Union + +import torch +from tensordict import TensorDict +from torch.hub import load_state_dict_from_url +from torch.nn import Identity + +from torchrl.data.tensor_specs import ( + CompositeSpec, + TensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs.transforms.transforms import ( + CatTensors, + Compose, + FlattenObservation, + ObservationNorm, + Resize, + ToTensorImage, + Transform, + UnsqueezeTransform, +) + +try: + from torchvision import models + + _has_tv = True +except ImportError: + _has_tv = False + + +class _RRLNet(Transform): + + inplace = False + + def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True): + if not _has_tv: + raise ImportError( + "Tried to instantiate RRL without torchvision. Make sure you have " + "torchvision installed in your environment." + ) + if model_name == "resnet18": + self.model_name = "rrl_18" + self.outdim = 512 + convnet = models.resnet18(pretrained=True) + elif model_name == "resnet34": + self.model_name = "rrl_34" + self.outdim = 512 + convnet = models.resnet34(pretrained=True) + elif model_name == "resnet50": + self.model_name = "rrl_50" + self.outdim = 2048 + convnet = models.resnet50(pretrained=True) + else: + raise NotImplementedError( + f"model {model_name} is currently not supported by RRL" + ) + convnet.fc = Identity() + super().__init__(in_keys=in_keys, out_keys=out_keys) + self.convnet = convnet + self.del_keys = del_keys + + def _call(self, tensordict): + tensordict_view = tensordict.view(-1) + super()._call(tensordict_view) + if self.del_keys: + tensordict.exclude(*self.in_keys, inplace=True) + return tensordict + + @torch.no_grad() + def _apply_transform(self, obs: torch.Tensor) -> None: + shape = None + if obs.ndimension() > 4: + shape = obs.shape[:-3] + obs = obs.flatten(0, -4) + out = self.convnet(obs) + if shape is not None: + out = out.view(*shape, *out.shape[1:]) + return out + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if not isinstance(observation_spec, CompositeSpec): + raise ValueError("_RRLNet can only infer CompositeSpec") + + keys = [key for key in observation_spec._specs.keys() if key in self.in_keys] + device = observation_spec[keys[0]].device + dim = observation_spec[keys[0]].shape[:-3] + + observation_spec = CompositeSpec(observation_spec) + if self.del_keys: + for in_key in keys: + del observation_spec[in_key] + + for out_key in self.out_keys: + observation_spec[out_key] = UnboundedContinuousTensorSpec( + shape=torch.Size([*dim, self.outdim]), device=device + ) + + return observation_spec + + #@staticmethod + #def _load_weights(model_name, r3m_instance, dir_prefix): + # if model_name not in ("r3m_50", "r3m_34", "r3m_18"): + # raise ValueError( + # "model_name should be one of 'r3m_50', 'r3m_34' or 'r3m_18'" + # ) + # # url = "https://download.pytorch.org/models/rl/r3m/" + model_name + # url = "https://pytorch.s3.amazonaws.com/models/rl/r3m/" + model_name + ".pt" + # d = load_state_dict_from_url( + # url, + # progress=True, + # map_location=next(r3m_instance.parameters()).device, + # model_dir=dir_prefix, + # ) + # td = TensorDict(d["r3m"], []).unflatten_keys(".") + # td_flatten = td["module"]["convnet"].flatten_keys(".") + # state_dict = td_flatten.to_dict() + # r3m_instance.convnet.load_state_dict(state_dict) + + #def load_weights(self, dir_prefix=None): + # self._load_weights(self.model_name, self, dir_prefix) + + +def _init_first(fun): + def new_fun(self, 
*args, **kwargs): + if not self.initialized: + self._init() + return fun(self, *args, **kwargs) + + return new_fun + + +class RRLTransform(Compose): + """RRL Transform class. + + RRL provides pre-trained ResNet weights aimed at facilitating visual + embedding for robotic tasks. The models are standard torchvision ResNets pre-trained on ImageNet. + + See the paper: + Shah, Rutav, and Vikash Kumar. "RRL: Resnet as representation for Reinforcement Learning." + arXiv preprint arXiv:2107.03380 (2021). + The RRLTransform is created in a lazy manner: the object will be initialized + only when an attribute (a spec or the forward method) will be queried. + The reason for this is that the :obj:`_init()` method requires some attributes of + the parent environment (if any) to be accessed: by making the class lazy we + can ensure that the following code snippet works as expected: + + Examples: + >>> transform = RRLTransform("resnet50", in_keys=["pixels"]) + >>> env.append_transform(transform) + >>> # the forward method will first call _init which will look at env.observation_spec + >>> env.reset() + + Args: + model_name (str): one of resnet50, resnet34 or resnet18 + in_keys (list of str): list of input keys. If left empty, the + "pixels" key is assumed. + out_keys (list of str, optional): list of output keys. If left empty, + "rrl_vec" is assumed. + size (int, optional): Size of the image to feed to resnet. + Defaults to 244. + stack_images (bool, optional): if False, the images given in the :obj:`in_keys` + argument will be treated separately and each will be given a single, + separate entry in the output tensordict. Defaults to :obj:`True`. + download (bool, optional): if True, the weights will be downloaded using + the torch.hub download API (i.e. weights will be cached for future use). + Defaults to False. + download_path (str, optional): path where to download the models. + Default is None (cache path determined by torch.hub utils). + tensor_pixels_keys (list of str, optional): Optionally, one can keep the + original images (as collected from the env) in the output tensordict. + If no value is provided, this won't be collected.
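+
+    A hypothetical multi-camera example (the camera key names below are
+    illustrative only and not keys defined in this repo): with ``stack_images=False``
+    each input key gets its own embedding entry:
+        >>> transform = RRLTransform(
+        ...     "resnet50",
+        ...     in_keys=["left_pixels", "right_pixels"],
+        ...     stack_images=False,
+        ... )
+        >>> # writes "rrl_vec_0" and "rrl_vec_1" into the output tensordict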
+ """ + + @classmethod + def __new__(cls, *args, **kwargs): + cls.initialized = False + cls._device = None + cls._dtype = None + return super().__new__(cls) + + def __init__( + self, + model_name: str, + in_keys: List[str], + out_keys: List[str] = None, + size: int = 244, + stack_images: bool = True, + download: bool = False, + download_path: Optional[str] = None, + tensor_pixels_keys: List[str] = None, + ): + super().__init__() + self.in_keys = in_keys if in_keys is not None else ["pixels"] + self.download = download + self.download_path = download_path + self.model_name = model_name + self.out_keys = out_keys + self.size = size + self.stack_images = stack_images + self.tensor_pixels_keys = tensor_pixels_keys + self._init() + + def _init(self): + """Initializer for RRL.""" + self.initialized = True + in_keys = self.in_keys + model_name = self.model_name + out_keys = self.out_keys + size = self.size + stack_images = self.stack_images + tensor_pixels_keys = self.tensor_pixels_keys + + # ToTensor + transforms = [] + if tensor_pixels_keys: + for i in range(len(in_keys)): + transforms.append( + CatTensors( + in_keys=[in_keys[i]], + out_key=tensor_pixels_keys[i], + del_keys=False, + ) + ) + + totensor = ToTensorImage( + unsqueeze=False, + in_keys=in_keys, + ) + transforms.append(totensor) + + # Normalize + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + normalize = ObservationNorm( + in_keys=in_keys, + loc=torch.tensor(mean).view(3, 1, 1), + scale=torch.tensor(std).view(3, 1, 1), + standard_normal=True, + ) + transforms.append(normalize) + + # Resize: note that resize is a no-op if the tensor has the desired size already + resize = Resize(size, size, in_keys=in_keys) + transforms.append(resize) + + # RRL + if out_keys is None: + if stack_images: + out_keys = ["rrl_vec"] + else: + out_keys = [f"rrl_vec_{i}" for i in range(len(in_keys))] + self.out_keys = out_keys + elif stack_images and len(out_keys) != 1: + raise ValueError( + f"out_key must be of length 1 if stack_images is True. Got out_keys={out_keys}" + ) + elif not stack_images and len(out_keys) != len(in_keys): + raise ValueError( + "out_key must be of length equal to in_keys if stack_images is False." 
+ ) + + if stack_images and len(in_keys) > 1: + + unsqueeze = UnsqueezeTransform( + in_keys=in_keys, + out_keys=in_keys, + unsqueeze_dim=-4, + ) + transforms.append(unsqueeze) + + cattensors = CatTensors( + in_keys, + out_keys[0], + dim=-4, + ) + network = _RRLNet( + in_keys=out_keys, + out_keys=out_keys, + model_name=model_name, + del_keys=False, + ) + flatten = FlattenObservation(-2, -1, out_keys) + transforms = [*transforms, cattensors, network, flatten] + + else: + network = _RRLNet( + in_keys=in_keys, + out_keys=out_keys, + model_name=model_name, + del_keys=True, + ) + transforms = [*transforms, network] + + for transform in transforms: + self.append(transform) + #if self.download: + # self[-1].load_weights(dir_prefix=self.download_path) + + if self._device is not None: + self.to(self._device) + if self._dtype is not None: + self.to(self._dtype) + + def to(self, dest: Union[DEVICE_TYPING, torch.dtype]): + if isinstance(dest, torch.dtype): + self._dtype = dest + else: + self._device = dest + return super().to(dest) + + @property + def device(self): + return self._device + + @property + def dtype(self): + return self._dtype diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py index ebdbc959d..05a5da9b5 100644 --- a/scripts/sac_mujoco/sac.py +++ b/scripts/sac_mujoco/sac.py @@ -46,16 +46,20 @@ from rlhive.rl_envs import RoboHiveEnv from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform, SelectTransform +from rrl import RRLTransform os.environ['WANDB_MODE'] = 'offline' ## offline sync. TODO: Remove this behavior def make_env( task, + visual_transform, reward_scaling, device ): + assert visual_transform in ('rrl', 'r3m') base_env = RoboHiveEnv(task, device=device) - env = make_transformed_env(env=base_env, reward_scaling=reward_scaling) + env = make_transformed_env(env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform) + print(env) return env @@ -63,15 +67,24 @@ def make_env( def make_transformed_env( env, reward_scaling=5.0, + visual_transform='r3m', stats=None, ): """ Apply transforms to the env (such as reward scaling and state normalization) """ env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) - env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + if visual_transform == 'rrl': + vec_keys = ["rrl_vec"] + selected_keys = ["observation", "rrl_vec"] + env.append_transform(Compose(RRLTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == 'r3m': + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + else: + raise NotImplementedError env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) - selected_keys = ["r3m_vec", "observation"] out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) @@ -171,6 +184,7 @@ def main(args: DictConfig): # Create Environment env_configs = { "reward_scaling": args.reward_scaling, 
+ "visual_transform": args.visual_transform, "device": args.device, } train_env = make_env(task=args.task, **env_configs) diff --git a/scripts/sac_mujoco/test.py b/scripts/sac_mujoco/test.py index 9b52574a5..d2af1146b 100644 --- a/scripts/sac_mujoco/test.py +++ b/scripts/sac_mujoco/test.py @@ -2,25 +2,67 @@ from rlhive.rl_envs import RoboHiveEnv from torchrl.envs.utils import set_exploration_mode from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose -from torchrl.envs import TransformedEnv, R3MTransform +from torchrl.envs import TransformedEnv, R3MTransform, SelectTransform from torchrl.envs import ( CatTensors, DoubleToFloat, EnvCreator, ObservationNorm, ) +from rrl import RRLTransform + +def make_env( + task, + visual_transform, + reward_scaling, + device + ): + assert visual_transform in ('rrl', 'r3m') + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env(env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform) + print(env) + #exit() + + return env + def make_transformed_env( env, + reward_scaling=5.0, + visual_transform='r3m', stats=None, ): - env = TransformedEnv(env, Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) + """ + Apply transforms to the env (such as reward scaling and state normalization) + """ + env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) + if visual_transform == 'rrl': + vec_keys = ["rrl_vec"] + selected_keys = ["observation", "rrl_vec"] + env.append_transform(Compose(RRLTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == 'r3m': + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + else: + raise NotImplementedError + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + + + # we normalize the states + if stats is None: + _stats = {"loc": 0.0, "scale": 1.0} + else: + _stats = stats + env.append_transform( + ObservationNorm(**_stats, in_keys=[out_key], standard_normal=True) + ) + env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) return env -base_env = RoboHiveEnv("visual_franka_slide_random-v3", device=torch.device('cuda:0')) -env = base_env -env = make_transformed_env(base_env) -print(env) +env = make_env(task="visual_franka_slide_random-v3", reward_scaling=5.0, device=torch.device('cuda:0'), visual_transform='rrl') with torch.no_grad(), set_exploration_mode("random"): td = env.reset() td = env.rand_step() From 76e601a9ff26e5c482230f064ef2977a9d6eecad Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Mon, 16 Jan 2023 14:22:23 -0600 Subject: [PATCH 12/58] moved rrl_transform inside helpers --- .../rrl.py => rlhive/sim_algos/helpers/rrl_transform.py | 0 scripts/sac_mujoco/sac.py | 2 +- scripts/sac_mujoco/test.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename scripts/sac_mujoco/rrl.py => rlhive/sim_algos/helpers/rrl_transform.py (100%) diff --git a/scripts/sac_mujoco/rrl.py 
b/rlhive/sim_algos/helpers/rrl_transform.py similarity index 100% rename from scripts/sac_mujoco/rrl.py rename to rlhive/sim_algos/helpers/rrl_transform.py diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py index 05a5da9b5..56bb61f99 100644 --- a/scripts/sac_mujoco/sac.py +++ b/scripts/sac_mujoco/sac.py @@ -46,7 +46,7 @@ from rlhive.rl_envs import RoboHiveEnv from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform, SelectTransform -from rrl import RRLTransform +from rlhive.sim_algos.helpers.rrl_transform import RRLTransform os.environ['WANDB_MODE'] = 'offline' ## offline sync. TODO: Remove this behavior diff --git a/scripts/sac_mujoco/test.py b/scripts/sac_mujoco/test.py index d2af1146b..fec7a06d1 100644 --- a/scripts/sac_mujoco/test.py +++ b/scripts/sac_mujoco/test.py @@ -9,7 +9,7 @@ EnvCreator, ObservationNorm, ) -from rrl import RRLTransform +from rlhive.sim_algos.helpers.rrl_transform import RRLTransform def make_env( task, From 850c3d977bfbd1f47296b6d0847126e944a03390 Mon Sep 17 00:00:00 2001 From: ShahRutav <43668417+ShahRutav@users.noreply.github.com> Date: Tue, 17 Jan 2023 15:09:05 -0600 Subject: [PATCH 13/58] Updated README with parameter sweep --- scripts/README.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/scripts/README.md b/scripts/README.md index 78123b066..f2a8a9153 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -37,3 +37,35 @@ To run a small experiment for testing, run the following command: cd scripts/sac_mujoco sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m total_frames=2000 init_random_frames=25 buffer_size=2000 hydra/launcher=slurm hydra/output=slurm ``` + +## Parameter Sweep +1. R3M and RRL experiments: `visual_transform=r3m,rrl` +2. Multiple seeds: `seed=42,43,44` +3. 
List of environments: + ``` +task=visual_franka_slide_random-v3,\ + visual_franka_slide_close-v3,\ + visual_franka_slide_open-v3,\ + visual_franka_micro_random-v3,\ + visual_franka_micro_close-v3,\ + visual_franka_micro_open-v3,\ + visual_kitchen_knob1_off-v3,\ + visual_kitchen_knob1_on-v3,\ + visual_kitchen_knob2_off-v3,\ + visual_kitchen_knob2_on-v3,\ + visual_kitchen_knob3_off-v3,\ + visual_kitchen_knob3_on-v3,\ + visual_kitchen_knob4_off-v3,\ + visual_kitchen_knob4_on-v3,\ + visual_kitchen_light_off-v3,\ + visual_kitchen_light_on-v3,\ + visual_kitchen_sdoor_close-v3,\ + visual_kitchen_sdoor_open-v3,\ + visual_kitchen_ldoor_close-v3,\ + visual_kitchen_ldoor_open-v3,\ + visual_kitchen_rdoor_close-v3,\ + visual_kitchen_rdoor_open-v3,\ + visual_kitchen_micro_close-v3,\ + visual_kitchen_micro_open-v3,\ + visual_kitchen_close-v3 + ``` From 2a942abdac1cda9e1647b03d054fecab1ec8f318 Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Mon, 23 Jan 2023 18:17:37 -0600 Subject: [PATCH 14/58] updated redq with action, state, and obs norms --- scripts/redq/redq.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/scripts/redq/redq.py b/scripts/redq/redq.py index 709aac78b..9800bd3cb 100644 --- a/scripts/redq/redq.py +++ b/scripts/redq/redq.py @@ -47,11 +47,22 @@ def make_env( task, reward_scaling, - device + device, + obs_norm_state_dict=None, + action_dim_gsde=None, + state_dim_gsde=None, ): base_env = RoboHiveEnv(task, device=device) env = make_transformed_env(env=base_env, reward_scaling=reward_scaling) + if not obs_norm_state_dict is None: + obs_norm = ObservationNorm(**obs_norm_state_dict, in_keys=["observation_vector"]) + env.append_transform(obs_norm) + + if not action_dim_gsde is None: + env.append_transform( + gSDENoise(action_dim=action_dim_gsde, state_dim=state_dim_gsde) + ) return env @@ -178,16 +189,13 @@ def main(cfg: "DictConfig"): # noqa: F821 action_dim_gsde, state_dim_gsde = None, None proof_env.close() - #create_env_fn = parallel_env_constructor( - # cfg=cfg, - # obs_norm_state_dict=obs_norm_state_dict, - # action_dim_gsde=action_dim_gsde, - # state_dim_gsde=state_dim_gsde, - #) create_env_fn = make_env( ## Pass EnvBase instead of the create_env_fn task=cfg.env_name, reward_scaling=cfg.reward_scaling, device=device, + obs_norm_state_dict=obs_norm_state_dict, + action_dim_gsde=action_dim_gsde, + state_dim_gsde=state_dim_gsde ) collector = make_collector_offpolicy( @@ -214,6 +222,9 @@ def main(cfg: "DictConfig"): # noqa: F821 task=cfg.env_name, reward_scaling=cfg.reward_scaling, device=device, + obs_norm_state_dict=obs_norm_state_dict, + action_dim_gsde=action_dim_gsde, + state_dim_gsde=state_dim_gsde ) # remove video recorder from recorder to have matching state_dict keys From 5823199bb7989724ffd3f89ed9fd2b2c71a53409 Mon Sep 17 00:00:00 2001 From: ShahRutav Date: Wed, 25 Jan 2023 16:02:25 -0600 Subject: [PATCH 15/58] updated the code with torchrl sacloss and rrl transform --- scripts/sac_mujoco/sac.py | 15 +- scripts/sac_mujoco/sac_loss.py | 311 --------------------------------- 2 files changed, 8 insertions(+), 318 deletions(-) delete mode 100644 scripts/sac_mujoco/sac_loss.py diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py index 56bb61f99..885c6d5ff 100644 --- a/scripts/sac_mujoco/sac.py +++ b/scripts/sac_mujoco/sac.py @@ -1,4 +1,7 @@ -# Make all the necessary imports for training +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import os @@ -15,8 +18,7 @@ import hydra from omegaconf import DictConfig, OmegaConf, open_dict import wandb -#from torchrl.objectives import SACLoss -from sac_loss import SACLoss +from torchrl.objectives import SACLoss from torch import nn, optim from torchrl.collectors import MultiaSyncDataCollector @@ -46,7 +48,6 @@ from rlhive.rl_envs import RoboHiveEnv from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform, SelectTransform -from rlhive.sim_algos.helpers.rrl_transform import RRLTransform os.environ['WANDB_MODE'] = 'offline' ## offline sync. TODO: Remove this behavior @@ -75,9 +76,9 @@ def make_transformed_env( """ env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) if visual_transform == 'rrl': - vec_keys = ["rrl_vec"] - selected_keys = ["observation", "rrl_vec"] - env.append_transform(Compose(RRLTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download="IMAGENET1K_V1"), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 elif visual_transform == 'r3m': vec_keys = ["r3m_vec"] selected_keys = ["observation", "r3m_vec"] diff --git a/scripts/sac_mujoco/sac_loss.py b/scripts/sac_mujoco/sac_loss.py deleted file mode 100644 index cebe7f2e9..000000000 --- a/scripts/sac_mujoco/sac_loss.py +++ /dev/null @@ -1,311 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import math -from numbers import Number -from typing import Union - -import numpy as np -import torch - -from tensordict.nn import TensorDictSequential -from tensordict.tensordict import TensorDict, TensorDictBase -from torch import Tensor - -from torchrl.envs.utils import set_exploration_mode, step_mdp -from torchrl.modules import SafeModule -from torchrl.objectives.common import LossModule -from torchrl.objectives.utils import ( - distance_loss, - next_state_value as get_next_state_value, -) - -try: - from functorch import vmap - - FUNCTORCH_ERR = "" - _has_functorch = True -except ImportError as err: - FUNCTORCH_ERR = str(err) - _has_functorch = False - - -class SACLoss(LossModule): - """SAC Loss module. - Args: - actor_network (SafeModule): the actor to be trained - qvalue_network (SafeModule): a single Q-value network that will be multiplicated as many times as needed. - num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. - gamma (Number, optional): gamma decay factor. Default is 0.99. - priotity_key (str, optional): Key where to write the priority value for prioritized replay buffers. Default is - `"td_error"`. - loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", - "l1", Default is "smooth_l1". - alpha_init (float, optional): initial entropy multiplier. - Default is 1.0. - min_alpha (float, optional): min value of alpha. - Default is 0.1. - max_alpha (float, optional): max value of alpha. - Default is 10.0. 
- fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is :obj:`False`. - target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". - delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used - for data collection. Default is :obj:`False`. - gSDE (bool, optional): Knowing if gSDE is used is necessary to create random noise variables. - Default is False - """ - - delay_actor: bool = False - - def __init__( - self, - actor_network: SafeModule, - qvalue_network: SafeModule, - num_qvalue_nets: int = 2, - gamma: Number = 0.99, - priotity_key: str = "td_error", - loss_function: str = "smooth_l1", - alpha_init: float = 1.0, - min_alpha: float = 0.1, - max_alpha: float = 10.0, - fixed_alpha: bool = False, - target_entropy: Union[str, Number] = "auto", - delay_qvalue: bool = True, - gSDE: bool = False, - ): - if not _has_functorch: - raise ImportError( - f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" - ) - - super().__init__() - self.convert_to_functional( - actor_network, - "actor_network", - create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist_params"], - ) - - # let's make sure that actor_network has `return_log_prob` set to True - self.actor_network.return_log_prob = True - - self.delay_qvalue = delay_qvalue - self.convert_to_functional( - qvalue_network, - "qvalue_network", - num_qvalue_nets, - create_target_params=self.delay_qvalue, - compare_against=list(actor_network.parameters()), - ) - self.num_qvalue_nets = num_qvalue_nets - self.register_buffer("gamma", torch.tensor(gamma)) - self.priority_key = priotity_key - self.loss_function = loss_function - - try: - device = next(self.parameters()).device - except AttributeError: - device = torch.device("cpu") - - self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) - self.register_buffer( - "min_log_alpha", torch.tensor(min_alpha, device=device).log() - ) - self.register_buffer( - "max_log_alpha", torch.tensor(max_alpha, device=device).log() - ) - self.fixed_alpha = fixed_alpha - if fixed_alpha: - self.register_buffer( - "log_alpha", torch.tensor(math.log(alpha_init), device=device) - ) - else: - self.register_parameter( - "log_alpha", - torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), - ) - - if target_entropy == "auto": - if actor_network.spec["action"] is None: - raise RuntimeError( - "Cannot infer the dimensionality of the action. Consider providing " - "the target entropy explicitely or provide the spec of the " - "action tensor in the actor network." 
- ) - target_entropy = -float(np.prod(actor_network.spec["action"].shape)) - self.register_buffer( - "target_entropy", torch.tensor(target_entropy, device=device) - ) - self.gSDE = gSDE - - @property - def alpha(self): - self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) - with torch.no_grad(): - alpha = self.log_alpha.exp() - return alpha - - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - obs_keys = self.actor_network.in_keys - tensordict_select = tensordict.select( - "reward", "done", "next", *obs_keys, "action" - ) - - actor_params = torch.stack( - [self.actor_network_params, self.target_actor_network_params], 0 - ) - - tensordict_actor_grad = tensordict_select.select( - *obs_keys - ) # to avoid overwriting keys - next_td_actor = step_mdp(tensordict_select).select( - *self.actor_network.in_keys - ) # next_observation -> - tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) - tensordict_actor = tensordict_actor.contiguous() - - with set_exploration_mode("random"): - if self.gSDE: - tensordict_actor.set( - "_eps_gSDE", - torch.zeros(tensordict_actor.shape, device=tensordict_actor.device), - ) - # vmap doesn't support sampling, so we take it out from the vmap - td_params = vmap(self.actor_network.get_dist_params)( - tensordict_actor, - actor_params, - ) - if isinstance(self.actor_network, TensorDictSequential): - sample_key = self.actor_network[-1].out_keys[0] - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) - else: - sample_key = self.actor_network.out_keys[0] - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) - tensordict_actor[sample_key] = tensordict_actor_dist.rsample() - tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( - tensordict_actor[sample_key] - ) - - # repeat tensordict_actor to match the qvalue size - _actor_loss_td = ( - tensordict_actor[0] - .select(*self.qvalue_network.in_keys) - .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) - ) # for actor loss - _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( - self.num_qvalue_nets, - *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, - ) # for qvalue loss - _next_val_td = ( - tensordict_actor[1] - .select(*self.qvalue_network.in_keys) - .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) - ) # for next value estimation - tensordict_qval = torch.cat( - [ - _actor_loss_td, - _next_val_td, - _qval_td, - ], - 0, - ) - - # cat params - q_params_detach = self.qvalue_network_params.detach() - qvalue_params = torch.cat( - [ - q_params_detach, - self.target_qvalue_network_params, - self.qvalue_network_params, - ], - 0, - ) - tensordict_qval = vmap(self.qvalue_network)( - tensordict_qval, - qvalue_params, - ) - - state_action_value = tensordict_qval.get("state_action_value").squeeze(-1) - ( - state_action_value_actor, - next_state_action_value_qvalue, - state_action_value_qvalue, - ) = state_action_value.split( - [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], - dim=0, - ) - sample_log_prob = tensordict_actor.get("sample_log_prob").squeeze(-1) - ( - action_log_prob_actor, - next_action_log_prob_qvalue, - ) = sample_log_prob.unbind(0) - - loss_actor = -( - state_action_value_actor.min(0)[0] - self.alpha * action_log_prob_actor - ).mean() - - next_state_value = ( - next_state_action_value_qvalue.min(0)[0] - - self.alpha * next_action_log_prob_qvalue - ) - - target_value = get_next_state_value( - tensordict, - 
gamma=self.gamma, - pred_next_val=next_state_value, - ) - pred_val = state_action_value_qvalue - td_error = (pred_val - target_value).pow(2) - loss_qval = ( - distance_loss( - pred_val, - target_value.expand_as(pred_val), - loss_function=self.loss_function, - ) - .mean(-1) - .sum() - * 0.5 - ) - - tensordict.set("td_error", td_error.detach().max(0)[0]) - - loss_alpha = self._loss_alpha(sample_log_prob) - if not loss_qval.shape == loss_actor.shape: - raise RuntimeError( - f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" - ) - td_out = TensorDict( - { - "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qval.mean(), - "loss_alpha": loss_alpha.mean(), - "alpha": self.alpha.detach(), - "entropy": -sample_log_prob.mean().detach(), - "state_action_value_actor": state_action_value_actor.mean().detach(), - "action_log_prob_actor": action_log_prob_actor.mean().detach(), - "next.state_value": next_state_value.mean().detach(), - "target_value": target_value.mean().detach(), - }, - [], - ) - - return td_out - - def _loss_alpha(self, log_pi: Tensor) -> Tensor: - if torch.is_grad_enabled() and not log_pi.requires_grad: - raise RuntimeError( - "expected log_pi to require gradient for the alpha loss)" - ) - if self.target_entropy is not None: - # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) - else: - # placeholder - alpha_loss = torch.zeros_like(log_pi) - return alpha_loss From e68f917a93eba43ce2ffb85b46116961bcb5dea9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 15:32:01 +0000 Subject: [PATCH 16/58] init --- examples/config/sac.yaml | 42 ++ examples/sac.py | 471 ++++++++++++++++++ rlhive/envs.py | 36 +- rlhive/rl_envs.py | 10 +- rlhive/sim_algos/helpers/rrl_transform.py | 8 +- scripts/redq/redq.py | 79 +-- scripts/sac_mujoco/config/group/group1.yaml | 2 +- scripts/sac_mujoco/config/group/group2.yaml | 2 +- .../sac_mujoco/config/hydra/output/local.yaml | 2 +- scripts/sac_mujoco/sac.py | 145 +++--- scripts/sac_mujoco/test.py | 60 ++- 11 files changed, 710 insertions(+), 147 deletions(-) create mode 100644 examples/config/sac.yaml create mode 100644 examples/sac.py diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml new file mode 100644 index 000000000..aaaf96bf9 --- /dev/null +++ b/examples/config/sac.yaml @@ -0,0 +1,42 @@ +default: + - override hydra/output: local + - override hydra/launcher: local + +# Logger +exp_name: ${task}_sac_${visual_transform} +visual_transform: r3m +record_interval: 1 +device: "cpu" + +# Environment +task: visual_franka_slide_random-v3 +frame_skip: 1 +reward_scaling: 5.0 +init_env_steps: 1000 +seed: 42 +eval_traj: 25 +num_envs: 8 + +# Collector +env_per_collector: 1 +max_frames_per_traj: -1 +total_frames: 1000000 +init_random_frames: 25000 +frames_per_batch: 10 + +# Replay Buffer +prb: 0 +buffer_size: 100000 +buffer_scratch_dir: /tmp/ + +# Optimization +gamma: 0.99 +batch_size: 256 +lr: 3.0e-4 +weight_decay: 0.0 +target_update_polyak: 0.995 +utd_ratio: 1 + +hydra: + job: + name: sac_${task}_${seed} diff --git a/examples/sac.py b/examples/sac.py new file mode 100644 index 000000000..8c287b981 --- /dev/null +++ b/examples/sac.py @@ -0,0 +1,471 @@ +# TODO +# Simplify +# logger +# check SAC loss vs torchrl's +# Make all the necessary imports for training +import os + +os.environ["sim_backend"] = "MUJOCO" + +import gc +import os +from copy import deepcopy + +import hydra + +import numpy as np +import torch +import 
torch.cuda +import tqdm +from omegaconf import DictConfig +from rlhive.rl_envs import RoboHiveEnv + +# from torchrl.objectives import SACLoss +from tensordict import TensorDict + +from torch import nn, optim +from torchrl.data import TensorDictReplayBuffer + +from torchrl.data.replay_buffers.storages import LazyMemmapStorage + +# from torchrl.envs import SerialEnv as ParallelEnv, R3MTransform, SelectTransform, TransformedEnv +from torchrl.envs import ( + CatTensors, + ParallelEnv, + R3MTransform, + SelectTransform, + TransformedEnv, +) +from torchrl.envs.transforms import Compose, FlattenObservation, RewardScaling +from torchrl.envs.utils import set_exploration_mode, step_mdp +from torchrl.modules import MLP, NormalParamWrapper, SafeModule +from torchrl.modules.distributions import TanhNormal + +from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + +from torchrl.objectives import SACLoss, SoftUpdate +from torchrl.record.loggers import WandbLogger +from torchrl.trainers import Recorder + + +# =========================================================================================== +# Env constructor +# --------------- +# - Use the RoboHiveEnv class to wrap robohive envs in torchrl's GymWrapper +# - Add transforms immediately after that: +# - SelectTransform: selects the relevant kesy from our output +# - R3MTransform +# - FlattenObservation: The images delivered by robohive have a singleton dim to start with, we need to flatten that +# - RewardScaling +# +# One can also possibly use ObservationNorm. +# +# TIPS: +# - For faster execution, you should follow this abstract scheme, where we reduce the data +# to be passed from worker to worker to a minimum, we apply R3M to a batch and append the +# rest of the transforms afterward: +# +# >>> env = TransformedEnv( +# ... ParallelEnv(N, lambda: TransformedEnv(RoboHiveEnv(...), SelectTransform(...))), +# ... Compose( +# ... R3MTransform(...), +# ... FlattenObservation(...), +# ... *other_transforms, +# ... 
)) +# + + +def make_env(num_envs, task, visual_transform, reward_scaling, device): + assert visual_transform in ("rrl", "r3m") + if num_envs > 1: + base_env = ParallelEnv(num_envs, lambda: RoboHiveEnv(task, device=device)) + else: + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env( + env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform + ) + + return env + + +def make_transformed_env( + env, + reward_scaling=5.0, + visual_transform="r3m", +): + """ + Apply transforms to the env (such as reward scaling and state normalization) + """ + env = TransformedEnv( + env, + SelectTransform("solved", "pixels", "observation", "rwd_dense", "rwd_sparse"), + ) + if visual_transform == "r3m": + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + else: + raise NotImplementedError + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + return env + + +# =========================================================================================== +# Making a recorder +# ----------------- +# +# A `Recorder` is a dedicated torchrl class that will run the policy in the test env +# once every X steps (eg X=1M). +# + + +def make_recorder( + task: str, + frame_skip: int, + record_interval: int, + actor_model_explore: object, + eval_traj: int, + env_configs: dict, +): + test_env = make_env(num_envs=1, task=task, **env_configs) + recorder_obj = Recorder( + record_frames=eval_traj * test_env.horizon, + frame_skip=frame_skip, + policy_exploration=actor_model_explore, + recorder=test_env, + exploration_mode="mean", + record_interval=record_interval, + log_keys=["reward", "solved"], + out_keys={"reward": "r_evaluation", "solved": "success"}, + ) + return recorder_obj + + +# =========================================================================================== +# Relplay buffers +# --------------- +# +# TorchRL also provides prioritized RBs if needed. 
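+#
+# As a rough, untested sketch (the alpha/beta values below are illustrative,
+# not taken from this repo), the uniform buffer defined next could be swapped
+# for a prioritized one along these lines:
+#
+# >>> from torchrl.data import TensorDictPrioritizedReplayBuffer
+# >>> replay_buffer = TensorDictPrioritizedReplayBuffer(
+# ...     alpha=0.7,
+# ...     beta=0.5,
+# ...     priority_key="td_error",
+# ...     storage=LazyMemmapStorage(buffer_size, scratch_dir=buffer_scratch_dir, device=device),
+# ... )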
+# + + +def make_replay_buffer( + buffer_size: int, + buffer_scratch_dir: str, + device: torch.device, + make_replay_buffer: int = 3, +): + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=make_replay_buffer, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + device=device, + ), + ) + return replay_buffer + + +# =========================================================================================== +# Dataloader +# ---------- +# +# This is a simplified version of the dataloder +# + + +@torch.no_grad() +@set_exploration_mode("random") +def dataloader( + total_frames, fpb, train_env, actor, actor_collection, device_collection +): + params = TensorDict( + {k: v for k, v in actor.named_parameters()}, batch_size=[] + ).unflatten_keys(".") + params_collection = TensorDict( + {k: v for k, v in actor_collection.named_parameters()}, batch_size=[] + ).unflatten_keys(".") + _prev = None + + params_collection.update_(params) + collected_frames = 0 + while collected_frames < total_frames: + batch = TensorDict( + {}, batch_size=[fpb, *train_env.batch_size], device=device_collection + ) + for t in range(fpb): + if _prev is None: + _prev = train_env.reset() + _reset = _prev["_reset"] = _prev["done"].clone().squeeze(-1) + if _reset.any(): + _prev = train_env.reset(_prev) + _new = train_env.step(actor_collection(_prev)) + batch[t] = _new + _prev = step_mdp(_new, exclude_done=False) + collected_frames += batch.numel() + yield batch + + +@hydra.main(config_name="sac.yaml", config_path="config") +def main(args: DictConfig): + # customize device at will + device = "cpu" + device_collection = "cpu" + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + # Create Environment + env_configs = { + "reward_scaling": args.reward_scaling, + "visual_transform": args.visual_transform, + "device": args.device, + } + train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs) + + # Create Agent + # Define Actor Network + in_keys = ["observation_vector"] + action_spec = train_env.action_spec + actor_net_kwargs = { + "num_cells": [256, 256], + "out_features": 2 * action_spec.shape[-1], + "activation_class": nn.ReLU, + } + + actor_net = MLP(**actor_net_kwargs) + + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": False, + } + actor_net = NormalParamWrapper( + actor_net, + scale_mapping=f"biased_softplus_{1.0}", + scale_lb=0.1, + ) + in_keys_actor = in_keys + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "loc", + "scale", + ], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": nn.ReLU, + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # add forward pass for initialization with proof env + proof_env = make_env(num_envs=1, task=args.task, **env_configs) + # init nets + with torch.no_grad(), set_exploration_mode("random"): + td = proof_env.reset() + td = td.to(device) + for net in model: + net(td) + del td + proof_env.close() + + actor_collection = deepcopy(actor).to(device_collection) + + 
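+    # `actor_collection` acts in the environment on `device_collection`; the
+    # `dataloader` generator above copies the training actor's weights into it
+    # through an in-place TensorDict parameter update before collection starts.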
actor_model_explore = model[0] + + # Create SAC loss + loss_module = SACLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + gamma=args.gamma, + loss_function="smooth_l1", + ) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, args.target_update_polyak) + + # Make Replay Buffer + replay_buffer = make_replay_buffer( + buffer_size=args.buffer_size, + buffer_scratch_dir=args.buffer_scratch_dir, + device=device, + ) + + # Trajectory recorder for evaluation + recorder = make_recorder( + task=args.task, + frame_skip=args.frame_skip, + record_interval=args.record_interval, + actor_model_explore=actor_model_explore, + eval_traj=args.eval_traj, + env_configs=env_configs, + ) + + # Optimizers + params = list(loss_module.parameters()) + list([loss_module.log_alpha]) + optimizer_actor = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) + + rewards = [] + rewards_eval = [] + + # Main loop + target_net_updater.init_() + + collected_frames = 0 + episodes = 0 + optim_steps = 0 + pbar = tqdm.tqdm(total=args.total_frames) + r0 = None + loss = None + + total_frames = args.total_frames + frames_per_batch = args.frames_per_batch + + logger = WandbLogger( + exp_name=args.task, + project="SAC_TorchRL", + name=args.exp_name, + config=args, + entity="RLHive", + mode="offline", + ) + + for i, batch in enumerate( + dataloader( + total_frames, + frames_per_batch, + train_env, + actor, + actor_collection, + device_collection, + ) + ): + if r0 is None: + r0 = batch["reward"].sum(-1).mean().item() + pbar.update(batch.numel()) + + # extend the replay buffer with the new data + batch = batch.view(-1) + current_frames = batch.numel() + collected_frames += current_frames + episodes += args.env_per_collector + replay_buffer.extend(batch.cpu()) + + # optimization steps + if collected_frames >= args.init_random_frames: + ( + total_losses, + actor_losses, + q_losses, + alpha_losses, + alphas, + entropies, + ) = ([], [], [], [], [], []) + for _ in range( + args.env_per_collector * args.frames_per_batch * args.utd_ratio + ): + optim_steps += 1 + # sample from replay buffer + sampled_tensordict = replay_buffer.sample(args.batch_size).clone() + + loss_td = loss_module(sampled_tensordict) + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_qvalue"] + alpha_loss = loss_td["loss_alpha"] + + loss = actor_loss + q_loss + alpha_loss + optimizer_actor.zero_grad() + loss.backward() + optimizer_actor.step() + + # update qnet_target params + target_net_updater.step() + + # update priority + if args.prb: + replay_buffer.update_priority(sampled_tensordict) + + total_losses.append(loss.item()) + actor_losses.append(actor_loss.item()) + q_losses.append(q_loss.item()) + alpha_losses.append(alpha_loss.item()) + alphas.append(loss_td["alpha"].item()) + entropies.append(loss_td["entropy"].item()) + + rewards.append((i, batch["reward"].sum().item() / args.env_per_collector)) + logger.log_scalar("train_reward", rewards[-1][1], step=collected_frames) + logger.log_scalar("optim_steps", optim_steps, step=collected_frames) + logger.log_scalar("episodes", episodes, step=collected_frames) + + if loss is not None: + logger.log_scalar( + "total_loss", np.mean(total_losses), step=collected_frames + ) + logger.log_scalar( + "actor_loss", np.mean(actor_losses), step=collected_frames + ) + logger.log_scalar("q_loss", np.mean(q_losses), step=collected_frames) + logger.log_scalar( + "alpha_loss", np.mean(alpha_losses), step=collected_frames + ) + logger.log_scalar("alpha", 
np.mean(alphas), step=collected_frames) + logger.log_scalar("entropy", np.mean(entropies), step=collected_frames) + td_record = recorder(None) + # success_percentage = evaluate_success( + # env_success_fn=train_env.evaluate_success, + # td_record=td_record, + # eval_traj=args.eval_traj, + # ) + if td_record is not None: + rewards_eval.append( + ( + i, + td_record["total_r_evaluation"] + / 1, # divide by number of eval worker + ) + ) + logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames) + if len(rewards_eval): + pbar.set_description( + f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" + ) + del batch + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/rlhive/envs.py b/rlhive/envs.py index 54d0a2822..a8c0b2059 100644 --- a/rlhive/envs.py +++ b/rlhive/envs.py @@ -6,11 +6,13 @@ # Custom env reg for RoboHive usage in TorchRL # Pixel rendering will be queried by torchrl, so we don't include those keys in visual_obs_keys_wt import os +import warnings from pathlib import Path -import mj_envs.envs.env_variants.register_env_variant import mj_envs.envs.multi_task.substeps1 +from mj_envs.envs.env_variants import register_env_variant + visual_obs_keys_wt = mj_envs.envs.multi_task.substeps1.visual_obs_keys_wt @@ -70,7 +72,7 @@ def register_kitchen_envs(): "kitchen_rdoor_open-v3", "kitchen_micro_close-v3", "kitchen_micro_open-v3", - "kitchen_close-v3", + # "kitchen_close-v3", ] visual_obs_keys_wt = { @@ -80,10 +82,17 @@ def register_kitchen_envs(): "rgb:left_cam:224x224:2d": 1.0, } for env in env_list: - new_env_name = "visual_" + env - mj_envs.envs.env_variants.register_env_variant( - env, variants={"obs_keys_wt": visual_obs_keys_wt}, variant_id=new_env_name - ) + try: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, + variants={"obs_keys_wt": visual_obs_keys_wt}, + variant_id=new_env_name, + ) + except AssertionError as err: + warnings.warn( + f"Could not register {new_env_name}, the following error was raised: {err}" + ) @set_directory(CURR_DIR) @@ -106,7 +115,14 @@ def register_franka_envs(): "rgb:left_cam:224x224:2d": 1.0, } for env in env_list: - new_env_name = "visual_" + env - mj_envs.envs.env_variants.register_env_variant( - env, variants={"obs_keys_wt": visual_obs_keys_wt}, variant_id=new_env_name - ) + try: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, + variants={"obs_keys_wt": visual_obs_keys_wt}, + variant_id=new_env_name, + ) + except AssertionError as err: + warnings.warn( + f"Could not register {new_env_name}, the following error was raised: {err}" + ) diff --git a/rlhive/rl_envs.py b/rlhive/rl_envs.py index c922922f4..c9e7d3c48 100644 --- a/rlhive/rl_envs.py +++ b/rlhive/rl_envs.py @@ -173,7 +173,7 @@ def read_obs(self, observation): def read_info(self, info, tensordict_out): out = {} for key, value in info.items(): - if key in ("obs_dict",): + if key in ("obs_dict", "done", "reward"): continue if isinstance(value, dict): value = make_tensordict(value, batch_size=[]) @@ -181,14 +181,6 @@ def read_info(self, info, tensordict_out): tensordict_out.update(out) return tensordict_out - def _step(self, td): - td = super()._step(td) - return td - - def _reset(self, td=None, **kwargs): - td = super()._reset(td, **kwargs) - return td - def to(self, *args, **kwargs): out = super().to(*args, **kwargs) try: diff --git a/rlhive/sim_algos/helpers/rrl_transform.py b/rlhive/sim_algos/helpers/rrl_transform.py index b9b501485..b9b46f872 
100644 --- a/rlhive/sim_algos/helpers/rrl_transform.py +++ b/rlhive/sim_algos/helpers/rrl_transform.py @@ -104,8 +104,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec return observation_spec - #@staticmethod - #def _load_weights(model_name, r3m_instance, dir_prefix): + # @staticmethod + # def _load_weights(model_name, r3m_instance, dir_prefix): # if model_name not in ("r3m_50", "r3m_34", "r3m_18"): # raise ValueError( # "model_name should be one of 'r3m_50', 'r3m_34' or 'r3m_18'" @@ -123,7 +123,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec # state_dict = td_flatten.to_dict() # r3m_instance.convnet.load_state_dict(state_dict) - #def load_weights(self, dir_prefix=None): + # def load_weights(self, dir_prefix=None): # self._load_weights(self.model_name, self, dir_prefix) @@ -300,7 +300,7 @@ def _init(self): for transform in transforms: self.append(transform) - #if self.download: + # if self.download: # self[-1].load_weights(dir_prefix=self.download_path) if self._device is not None: diff --git a/scripts/redq/redq.py b/scripts/redq/redq.py index 9800bd3cb..df9fc8783 100644 --- a/scripts/redq/redq.py +++ b/scripts/redq/redq.py @@ -10,8 +10,25 @@ import hydra import torch.cuda from hydra.core.config_store import ConfigStore +from rlhive.rl_envs import RoboHiveEnv from torchrl.envs import EnvCreator, ParallelEnv -from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.envs.transforms import ( + Compose, + FlattenObservation, + RewardScaling, + TransformedEnv, +) + +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + ObservationNorm, + ParallelEnv, + R3MTransform, + SelectTransform, + TransformedEnv, +) from torchrl.envs.utils import set_exploration_mode from torchrl.modules import OrnsteinUhlenbeckProcessWrapper from torchrl.record import VideoRecorder @@ -25,7 +42,7 @@ initialize_observation_norm_transforms, parallel_env_constructor, retrieve_observation_norms_state_dict, - #transformed_env_constructor, + # transformed_env_constructor, ) from torchrl.trainers.helpers.logger import LoggerConfig from torchrl.trainers.helpers.losses import LossConfig, make_redq_loss @@ -33,30 +50,23 @@ from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig from torchrl.trainers.loggers.utils import generate_exp_name, get_logger -from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose -from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform, SelectTransform -from torchrl.envs import ( - CatTensors, - DoubleToFloat, - EnvCreator, - ObservationNorm, - ParallelEnv, -) -from rlhive.rl_envs import RoboHiveEnv + def make_env( - task, - reward_scaling, - device, - obs_norm_state_dict=None, - action_dim_gsde=None, - state_dim_gsde=None, - ): + task, + reward_scaling, + device, + obs_norm_state_dict=None, + action_dim_gsde=None, + state_dim_gsde=None, +): base_env = RoboHiveEnv(task, device=device) env = make_transformed_env(env=base_env, reward_scaling=reward_scaling) if not obs_norm_state_dict is None: - obs_norm = ObservationNorm(**obs_norm_state_dict, in_keys=["observation_vector"]) + obs_norm = ObservationNorm( + **obs_norm_state_dict, in_keys=["observation_vector"] + ) env.append_transform(obs_norm) if not action_dim_gsde is None: @@ -75,13 +85,17 @@ def make_transformed_env( Apply transforms to the env (such as reward scaling and state 
normalization) """ env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) - env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=["r3m_vec"]), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) selected_keys = ["r3m_vec", "observation"] out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - # we normalize the states if stats is None: _stats = {"loc": 0.0, "scale": 1.0} @@ -93,6 +107,7 @@ def make_transformed_env( env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) return env + config_fields = [ (config_field.name, config_field.type, config_field) for config_cls in ( @@ -190,13 +205,13 @@ def main(cfg: "DictConfig"): # noqa: F821 proof_env.close() create_env_fn = make_env( ## Pass EnvBase instead of the create_env_fn - task=cfg.env_name, - reward_scaling=cfg.reward_scaling, - device=device, - obs_norm_state_dict=obs_norm_state_dict, - action_dim_gsde=action_dim_gsde, - state_dim_gsde=state_dim_gsde - ) + task=cfg.env_name, + reward_scaling=cfg.reward_scaling, + device=device, + obs_norm_state_dict=obs_norm_state_dict, + action_dim_gsde=action_dim_gsde, + state_dim_gsde=state_dim_gsde, + ) collector = make_collector_offpolicy( make_env=create_env_fn, @@ -210,21 +225,21 @@ def main(cfg: "DictConfig"): # noqa: F821 replay_buffer = make_replay_buffer(device, cfg) - #recorder = transformed_env_constructor( + # recorder = transformed_env_constructor( # cfg, # video_tag=video_tag, # norm_obs_only=True, # obs_norm_state_dict=obs_norm_state_dict, # logger=logger, # use_env_creator=False, - #)() + # )() recorder = make_env( task=cfg.env_name, reward_scaling=cfg.reward_scaling, device=device, obs_norm_state_dict=obs_norm_state_dict, action_dim_gsde=action_dim_gsde, - state_dim_gsde=state_dim_gsde + state_dim_gsde=state_dim_gsde, ) # remove video recorder from recorder to have matching state_dict keys diff --git a/scripts/sac_mujoco/config/group/group1.yaml b/scripts/sac_mujoco/config/group/group1.yaml index 6730093ca..886d4fc95 100644 --- a/scripts/sac_mujoco/config/group/group1.yaml +++ b/scripts/sac_mujoco/config/group/group1.yaml @@ -1,4 +1,4 @@ # @package _group_ grp1a: 11 grp1b: aaa - gra1c: $group_seed{group.seed}_exp_seed{exp.seed} \ No newline at end of file + gra1c: $group_seed{group.seed}_exp_seed{exp.seed} diff --git a/scripts/sac_mujoco/config/group/group2.yaml b/scripts/sac_mujoco/config/group/group2.yaml index b2f47d6dd..8b95ac771 100644 --- a/scripts/sac_mujoco/config/group/group2.yaml +++ b/scripts/sac_mujoco/config/group/group2.yaml @@ -1,4 +1,4 @@ # @package _group_ grp2a: 22 grp2b: bbb - gra2c: $group_seed{group.seed}_exp_seed{exp.seed} \ No newline at end of file + gra2c: $group_seed{group.seed}_exp_seed{exp.seed} diff --git a/scripts/sac_mujoco/config/hydra/output/local.yaml b/scripts/sac_mujoco/config/hydra/output/local.yaml index aee5a513f..d3c95076e 100644 --- a/scripts/sac_mujoco/config/hydra/output/local.yaml +++ b/scripts/sac_mujoco/config/hydra/output/local.yaml @@ -5,4 +5,4 @@ hydra: subdir: 
${hydra.job.num}_${hydra.job.override_dirname} sweep: dir: ./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} - subdir: ${hydra.job.num}_${hydra.job.override_dirname} \ No newline at end of file + subdir: ${hydra.job.num}_${hydra.job.override_dirname} diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py index 56bb61f99..cc1a9167a 100644 --- a/scripts/sac_mujoco/sac.py +++ b/scripts/sac_mujoco/sac.py @@ -1,21 +1,24 @@ # Make all the necessary imports for training -import os -import gc import argparse -import yaml +import gc +import os from typing import Optional +import hydra + import numpy as np import torch import torch.cuda import tqdm - -import hydra -from omegaconf import DictConfig, OmegaConf, open_dict import wandb -#from torchrl.objectives import SACLoss +import yaml +from omegaconf import DictConfig, OmegaConf, open_dict +from rlhive.rl_envs import RoboHiveEnv +from rlhive.sim_algos.helpers.rrl_transform import RRLTransform + +# from torchrl.objectives import SACLoss from sac_loss import SACLoss from torch import nn, optim @@ -31,10 +34,15 @@ ObservationNorm, ParallelEnv, ) -from torchrl.envs import EnvCreator from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose +from torchrl.envs.transforms import ( + Compose, + FlattenObservation, + RewardScaling, + TransformedEnv, +) +from torchrl.envs import ParallelEnv, R3MTransform, SelectTransform, TransformedEnv from torchrl.envs.utils import set_exploration_mode from torchrl.modules import MLP, NormalParamWrapper, ProbabilisticActor, SafeModule from torchrl.modules.distributions import TanhNormal @@ -44,21 +52,15 @@ from torchrl.objectives import SoftUpdate from torchrl.trainers import Recorder -from rlhive.rl_envs import RoboHiveEnv -from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform, SelectTransform -from rlhive.sim_algos.helpers.rrl_transform import RRLTransform +os.environ["WANDB_MODE"] = "offline" ## offline sync. TODO: Remove this behavior -os.environ['WANDB_MODE'] = 'offline' ## offline sync. 
TODO: Remove this behavior -def make_env( - task, - visual_transform, - reward_scaling, - device - ): - assert visual_transform in ('rrl', 'r3m') +def make_env(task, visual_transform, reward_scaling, device): + assert visual_transform in ("rrl", "r3m") base_env = RoboHiveEnv(task, device=device) - env = make_transformed_env(env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform) + env = make_transformed_env( + env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform + ) print(env) return env @@ -67,28 +69,37 @@ def make_env( def make_transformed_env( env, reward_scaling=5.0, - visual_transform='r3m', + visual_transform="r3m", stats=None, ): """ Apply transforms to the env (such as reward scaling and state normalization) """ env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) - if visual_transform == 'rrl': + if visual_transform == "rrl": vec_keys = ["rrl_vec"] selected_keys = ["observation", "rrl_vec"] - env.append_transform(Compose(RRLTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 - elif visual_transform == 'r3m': + env.append_transform( + Compose( + RRLTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == "r3m": vec_keys = ["r3m_vec"] selected_keys = ["observation", "r3m_vec"] - env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 else: raise NotImplementedError env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - # we normalize the states if stats is None: _stats = {"loc": 0.0, "scale": 1.0} @@ -100,35 +111,36 @@ def make_transformed_env( env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) return env + def make_recorder( - task: str, - frame_skip: int, - record_interval: int, - actor_model_explore: object, - eval_traj: int, - env_configs: dict, - ): + task: str, + frame_skip: int, + record_interval: int, + actor_model_explore: object, + eval_traj: int, + env_configs: dict, +): test_env = make_env(task=task, **env_configs) recorder_obj = Recorder( - record_frames=eval_traj*test_env.horizon, + record_frames=eval_traj * test_env.horizon, frame_skip=frame_skip, policy_exploration=actor_model_explore, recorder=test_env, exploration_mode="mean", record_interval=record_interval, log_keys=["reward", "solved"], - out_keys={"reward": "r_evaluation", "solved" : "success"} + out_keys={"reward": "r_evaluation", "solved": "success"}, ) return recorder_obj def make_replay_buffer( - prb: bool, - buffer_size: int, - buffer_scratch_dir: str, - device: torch.device, - make_replay_buffer: int = 3 - ): + prb: bool, + buffer_size: int, + buffer_scratch_dir: str, + device: 
torch.device, + make_replay_buffer: int = 3, +): if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( alpha=0.7, @@ -154,11 +166,7 @@ def make_replay_buffer( return replay_buffer -def evaluate_success( - env_success_fn, - td_record: dict, - eval_traj: int - ): +def evaluate_success(env_success_fn, td_record: dict, eval_traj: int): td_record["success"] = td_record["success"].reshape((eval_traj, -1)) paths = [] for traj, solved_traj in zip(range(eval_traj), td_record["success"]): @@ -168,7 +176,6 @@ def evaluate_success( return success_percentage - @hydra.main(config_name="sac.yaml", config_path="config") def main(args: DictConfig): device = ( @@ -183,10 +190,10 @@ def main(args: DictConfig): # Create Environment env_configs = { - "reward_scaling": args.reward_scaling, - "visual_transform": args.visual_transform, - "device": args.device, - } + "reward_scaling": args.reward_scaling, + "visual_transform": args.visual_transform, + "device": args.device, + } train_env = make_env(task=args.task, **env_configs) # Create Agent @@ -298,22 +305,21 @@ def main(args: DictConfig): # Make Replay Buffer replay_buffer = make_replay_buffer( - prb=args.prb, - buffer_size=args.buffer_size, - buffer_scratch_dir=args.buffer_scratch_dir, - device=device, - ) - + prb=args.prb, + buffer_size=args.buffer_size, + buffer_scratch_dir=args.buffer_scratch_dir, + device=device, + ) # Trajectory recorder for evaluation recorder = make_recorder( - task=args.task, - frame_skip=args.frame_skip, - record_interval=args.record_interval, - actor_model_explore=actor_model_explore, - eval_traj=args.eval_traj, - env_configs=env_configs, - ) + task=args.task, + frame_skip=args.frame_skip, + record_interval=args.record_interval, + actor_model_explore=actor_model_explore, + eval_traj=args.eval_traj, + env_configs=env_configs, + ) # Optimizers params = list(loss_module.parameters()) + list([loss_module.log_alpha]) @@ -417,10 +423,10 @@ def main(args: DictConfig): ) td_record = recorder(None) success_percentage = evaluate_success( - env_success_fn=train_env.evaluate_success, - td_record=td_record, - eval_traj=args.eval_traj - ) + env_success_fn=train_env.evaluate_success, + td_record=td_record, + eval_traj=args.eval_traj, + ) if td_record is not None: rewards_eval.append( ( @@ -440,5 +446,6 @@ def main(args: DictConfig): collector.shutdown() + if __name__ == "__main__": main() diff --git a/scripts/sac_mujoco/test.py b/scripts/sac_mujoco/test.py index fec7a06d1..61e120811 100644 --- a/scripts/sac_mujoco/test.py +++ b/scripts/sac_mujoco/test.py @@ -1,27 +1,32 @@ import torch from rlhive.rl_envs import RoboHiveEnv -from torchrl.envs.utils import set_exploration_mode -from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose -from torchrl.envs import TransformedEnv, R3MTransform, SelectTransform +from rlhive.sim_algos.helpers.rrl_transform import RRLTransform +from torchrl.envs.transforms import ( + Compose, + FlattenObservation, + RewardScaling, + TransformedEnv, +) from torchrl.envs import ( CatTensors, DoubleToFloat, EnvCreator, ObservationNorm, + R3MTransform, + SelectTransform, + TransformedEnv, ) -from rlhive.sim_algos.helpers.rrl_transform import RRLTransform +from torchrl.envs.utils import set_exploration_mode + -def make_env( - task, - visual_transform, - reward_scaling, - device - ): - assert visual_transform in ('rrl', 'r3m') +def make_env(task, visual_transform, reward_scaling, device): + assert visual_transform in ("rrl", "r3m") base_env = RoboHiveEnv(task, device=device) - env = 
make_transformed_env(env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform) + env = make_transformed_env( + env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform + ) print(env) - #exit() + # exit() return env @@ -29,28 +34,37 @@ def make_env( def make_transformed_env( env, reward_scaling=5.0, - visual_transform='r3m', + visual_transform="r3m", stats=None, ): """ Apply transforms to the env (such as reward scaling and state normalization) """ env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) - if visual_transform == 'rrl': + if visual_transform == "rrl": vec_keys = ["rrl_vec"] selected_keys = ["observation", "rrl_vec"] - env.append_transform(Compose(RRLTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 - elif visual_transform == 'r3m': + env.append_transform( + Compose( + RRLTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == "r3m": vec_keys = ["r3m_vec"] selected_keys = ["observation", "r3m_vec"] - env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 else: raise NotImplementedError env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - # we normalize the states if stats is None: _stats = {"loc": 0.0, "scale": 1.0} @@ -62,7 +76,13 @@ def make_transformed_env( env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) return env -env = make_env(task="visual_franka_slide_random-v3", reward_scaling=5.0, device=torch.device('cuda:0'), visual_transform='rrl') + +env = make_env( + task="visual_franka_slide_random-v3", + reward_scaling=5.0, + device=torch.device("cuda:0"), + visual_transform="rrl", +) with torch.no_grad(), set_exploration_mode("random"): td = env.reset() td = env.rand_step() From 721394cc29dd2352d3e01b9bc2cc2a96faa34d70 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 15:32:21 +0000 Subject: [PATCH 17/58] amend --- examples/sac.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 8c287b981..14c733712 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -1,8 +1,3 @@ -# TODO -# Simplify -# logger -# check SAC loss vs torchrl's -# Make all the necessary imports for training import os os.environ["sim_backend"] = "MUJOCO" From 47dbc8a0fa03dd4e0fb4f85d124322125d9c4852 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 15:41:24 +0000 Subject: [PATCH 18/58] amend --- examples/install/install_rlhive.sh | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100755 examples/install/install_rlhive.sh diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh new 
file mode 100755 index 000000000..9615903d9 --- /dev/null +++ b/examples/install/install_rlhive.sh @@ -0,0 +1,20 @@ +#!/bin/zsh + +here=$(pwd) +module_path=$HOME/modules/ + +module load cuda/11.6 cudnn/v8.4.1.50-cuda.11.6 + +conda create -n rlhive -y python=3.8 +conda activate rlhive + +python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 + +cd $module_path +git clone -c submodule.mj_envs/sims/neuromuscular_sim.update=none --branch v0.4dev --recursive https://github.com/vikashplus/mj_envs.git mj_envs +python3 -mpip install . # one can also install it locally with the -e flag +cd $here + +python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) +python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) +python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or stable or nightly with pip install torchrl(-nightly) From e120d7b0b19213226d4f40fad7be0732dc6efde7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 16:01:26 +0000 Subject: [PATCH 19/58] amend --- examples/install/install_rlhive.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index 9615903d9..174852af8 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -1,17 +1,21 @@ #!/bin/zsh +set -e + here=$(pwd) module_path=$HOME/modules/ -module load cuda/11.6 cudnn/v8.4.1.50-cuda.11.6 +conda env remove -n rlhive -y conda create -n rlhive -y python=3.8 conda activate rlhive python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 +mkdir $module_path cd $module_path git clone -c submodule.mj_envs/sims/neuromuscular_sim.update=none --branch v0.4dev --recursive https://github.com/vikashplus/mj_envs.git mj_envs +cd mj_envs python3 -mpip install . 
# one can also install it locally with the -e flag cd $here From 582020c00c40fe97a29d8f132aecac772ce72ca4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 16:12:53 +0000 Subject: [PATCH 20/58] amend --- examples/install/install_rlhive.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index 174852af8..feec54cf7 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -2,12 +2,17 @@ set -e +conda_path=$(conda info | grep -i 'base environment' | awk '{ print $4 }') +source $conda_path/etc/profile.d/conda.sh + here=$(pwd) module_path=$HOME/modules/ conda env remove -n rlhive -y conda create -n rlhive -y python=3.8 + + conda activate rlhive python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 From ad060d87e4a8bd8c90b6d63fd27d9e91bf86208e Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 16:37:11 +0000 Subject: [PATCH 21/58] amend --- examples/install/install_rlhive.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index feec54cf7..e0da7d2e7 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -8,6 +8,9 @@ source $conda_path/etc/profile.d/conda.sh here=$(pwd) module_path=$HOME/modules/ +module purge +module load cuda/11.6 + conda env remove -n rlhive -y conda create -n rlhive -y python=3.8 From e1225d591de177fbb41b398f7984e599d4289d45 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 16:49:16 +0000 Subject: [PATCH 22/58] amend --- examples/config/sac.yaml | 2 +- examples/install/install_rlhive.sh | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml index aaaf96bf9..c90aecda2 100644 --- a/examples/config/sac.yaml +++ b/examples/config/sac.yaml @@ -6,7 +6,7 @@ default: exp_name: ${task}_sac_${visual_transform} visual_transform: r3m record_interval: 1 -device: "cpu" +device: "cuda:0" # Environment task: visual_franka_slide_random-v3 diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index e0da7d2e7..f602ad73c 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -1,5 +1,7 @@ #!/bin/zsh +# Instructions to install a fresh anaconda environment with RLHive + set -e conda_path=$(conda info | grep -i 'base environment' | awk '{ print $4 }') @@ -30,3 +32,5 @@ cd $here python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or stable or nightly with pip install torchrl(-nightly) + +pip install wandb tqdm hydra-core \ No newline at end of file From 79d1eae688d36f06f6158cc7cd5768fd5b9dca49 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 17:02:58 +0000 Subject: [PATCH 23/58] amend --- examples/config/sac.yaml | 1 + examples/install/install_rlhive.sh | 11 ++++++++--- examples/sac.py | 6 +++--- setup.py | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml index c90aecda2..dacd35ebe 100644 --- a/examples/config/sac.yaml +++ b/examples/config/sac.yaml @@ -7,6 +7,7 @@ exp_name: 
${task}_sac_${visual_transform} visual_transform: r3m record_interval: 1 device: "cuda:0" +device_collection: "cuda:1" # Environment task: visual_franka_slide_random-v3 diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index f602ad73c..6b2b0e70a 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -17,7 +17,6 @@ conda env remove -n rlhive -y conda create -n rlhive -y python=3.8 - conda activate rlhive python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 @@ -31,6 +30,12 @@ cd $here python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) -python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or stable or nightly with pip install torchrl(-nightly) -pip install wandb tqdm hydra-core \ No newline at end of file +# this +# python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or stable or nightly with pip install torchrl(-nightly) +# or this +cd ../.. +pip install -e . +cd $here + +pip install wandb tqdm hydra-core diff --git a/examples/sac.py b/examples/sac.py index 14c733712..8cdd20904 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -212,8 +212,8 @@ def dataloader( @hydra.main(config_name="sac.yaml", config_path="config") def main(args: DictConfig): # customize device at will - device = "cpu" - device_collection = "cpu" + device = args.device + device_collection = args.device_collection torch.manual_seed(args.seed) np.random.seed(args.seed) @@ -223,7 +223,7 @@ def main(args: DictConfig): "visual_transform": args.visual_transform, "device": args.device, } - train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs) + train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs).to(device_collection) # Create Agent # Define Actor Network diff --git a/setup.py b/setup.py index 8186a081c..86b104881 100644 --- a/setup.py +++ b/setup.py @@ -160,7 +160,7 @@ def _main(): # f"torchrl @ file://{rl_path}", "torchrl", "gym==0.13", - "mj_envs", + # "mj_envs", # f"mj_envs @ file://{mj_env_path}", "numpy", "packaging", From 5d87afc65053816878725fbc380b273566c30153 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:23:07 +0000 Subject: [PATCH 24/58] amend --- examples/config/sac.yaml | 3 ++- examples/sac.py | 29 +++++++++++++++-------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml index dacd35ebe..9daa16eb3 100644 --- a/examples/config/sac.yaml +++ b/examples/config/sac.yaml @@ -16,6 +16,7 @@ reward_scaling: 5.0 init_env_steps: 1000 seed: 42 eval_traj: 25 +eval_interval: 1000 num_envs: 8 # Collector @@ -23,7 +24,7 @@ env_per_collector: 1 max_frames_per_traj: -1 total_frames: 1000000 init_random_frames: 25000 -frames_per_batch: 10 +frames_per_batch: 1000 # Replay Buffer prb: 0 diff --git a/examples/sac.py b/examples/sac.py index 8cdd20904..d7b0966e3 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -439,21 +439,22 @@ def main(args: DictConfig): ) logger.log_scalar("alpha", np.mean(alphas), step=collected_frames) logger.log_scalar("entropy", np.mean(entropies), step=collected_frames) - td_record = recorder(None) - # success_percentage = evaluate_success( - # env_success_fn=train_env.evaluate_success, - # 
td_record=td_record, - # eval_traj=args.eval_traj, - # ) - if td_record is not None: - rewards_eval.append( - ( - i, - td_record["total_r_evaluation"] - / 1, # divide by number of eval worker + if i % args.eval_interval == 0: + td_record = recorder(None) + # success_percentage = evaluate_success( + # env_success_fn=train_env.evaluate_success, + # td_record=td_record, + # eval_traj=args.eval_traj, + # ) + if td_record is not None: + rewards_eval.append( + ( + i, + td_record["total_r_evaluation"] + / 1, # divide by number of eval worker + ) ) - ) - logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames) + logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames) if len(rewards_eval): pbar.set_description( f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" From 2af22a481d75124e36e0ce2b157c0935347d4edc Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:25:58 +0000 Subject: [PATCH 25/58] amend --- examples/sac.py | 3 +- examples/sac_loss.py | 311 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 examples/sac_loss.py diff --git a/examples/sac.py b/examples/sac.py index d7b0966e3..1a8fe2099 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -38,7 +38,8 @@ from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator -from torchrl.objectives import SACLoss, SoftUpdate +from sac_loss import SACLoss +from torchrl.objectives import SoftUpdate from torchrl.record.loggers import WandbLogger from torchrl.trainers import Recorder diff --git a/examples/sac_loss.py b/examples/sac_loss.py new file mode 100644 index 000000000..cebe7f2e9 --- /dev/null +++ b/examples/sac_loss.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from numbers import Number +from typing import Union + +import numpy as np +import torch + +from tensordict.nn import TensorDictSequential +from tensordict.tensordict import TensorDict, TensorDictBase +from torch import Tensor + +from torchrl.envs.utils import set_exploration_mode, step_mdp +from torchrl.modules import SafeModule +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import ( + distance_loss, + next_state_value as get_next_state_value, +) + +try: + from functorch import vmap + + FUNCTORCH_ERR = "" + _has_functorch = True +except ImportError as err: + FUNCTORCH_ERR = str(err) + _has_functorch = False + + +class SACLoss(LossModule): + """SAC Loss module. + Args: + actor_network (SafeModule): the actor to be trained + qvalue_network (SafeModule): a single Q-value network that will be multiplicated as many times as needed. + num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. + gamma (Number, optional): gamma decay factor. Default is 0.99. + priotity_key (str, optional): Key where to write the priority value for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", + "l1", Default is "smooth_l1". + alpha_init (float, optional): initial entropy multiplier. + Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is 0.1. + max_alpha (float, optional): max value of alpha. + Default is 10.0. 
+ fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is :obj:`False`. + target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". + delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used + for data collection. Default is :obj:`False`. + gSDE (bool, optional): Knowing if gSDE is used is necessary to create random noise variables. + Default is False + """ + + delay_actor: bool = False + + def __init__( + self, + actor_network: SafeModule, + qvalue_network: SafeModule, + num_qvalue_nets: int = 2, + gamma: Number = 0.99, + priotity_key: str = "td_error", + loss_function: str = "smooth_l1", + alpha_init: float = 1.0, + min_alpha: float = 0.1, + max_alpha: float = 10.0, + fixed_alpha: bool = False, + target_entropy: Union[str, Number] = "auto", + delay_qvalue: bool = True, + gSDE: bool = False, + ): + if not _has_functorch: + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" + ) + + super().__init__() + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + funs_to_decorate=["forward", "get_dist_params"], + ) + + # let's make sure that actor_network has `return_log_prob` set to True + self.actor_network.return_log_prob = True + + self.delay_qvalue = delay_qvalue + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + compare_against=list(actor_network.parameters()), + ) + self.num_qvalue_nets = num_qvalue_nets + self.register_buffer("gamma", torch.tensor(gamma)) + self.priority_key = priotity_key + self.loss_function = loss_function + + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + if target_entropy == "auto": + if actor_network.spec["action"] is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." 
+ ) + target_entropy = -float(np.prod(actor_network.spec["action"].shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + self.gSDE = gSDE + + @property + def alpha(self): + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + obs_keys = self.actor_network.in_keys + tensordict_select = tensordict.select( + "reward", "done", "next", *obs_keys, "action" + ) + + actor_params = torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) + + tensordict_actor_grad = tensordict_select.select( + *obs_keys + ) # to avoid overwriting keys + next_td_actor = step_mdp(tensordict_select).select( + *self.actor_network.in_keys + ) # next_observation -> + tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) + tensordict_actor = tensordict_actor.contiguous() + + with set_exploration_mode("random"): + if self.gSDE: + tensordict_actor.set( + "_eps_gSDE", + torch.zeros(tensordict_actor.shape, device=tensordict_actor.device), + ) + # vmap doesn't support sampling, so we take it out from the vmap + td_params = vmap(self.actor_network.get_dist_params)( + tensordict_actor, + actor_params, + ) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict_actor[sample_key] = tensordict_actor_dist.rsample() + tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict_actor[sample_key] + ) + + # repeat tensordict_actor to match the qvalue size + _actor_loss_td = ( + tensordict_actor[0] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) + ) # for actor loss + _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, + *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, + ) # for qvalue loss + _next_val_td = ( + tensordict_actor[1] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) + ) # for next value estimation + tensordict_qval = torch.cat( + [ + _actor_loss_td, + _next_val_td, + _qval_td, + ], + 0, + ) + + # cat params + q_params_detach = self.qvalue_network_params.detach() + qvalue_params = torch.cat( + [ + q_params_detach, + self.target_qvalue_network_params, + self.qvalue_network_params, + ], + 0, + ) + tensordict_qval = vmap(self.qvalue_network)( + tensordict_qval, + qvalue_params, + ) + + state_action_value = tensordict_qval.get("state_action_value").squeeze(-1) + ( + state_action_value_actor, + next_state_action_value_qvalue, + state_action_value_qvalue, + ) = state_action_value.split( + [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], + dim=0, + ) + sample_log_prob = tensordict_actor.get("sample_log_prob").squeeze(-1) + ( + action_log_prob_actor, + next_action_log_prob_qvalue, + ) = sample_log_prob.unbind(0) + + loss_actor = -( + state_action_value_actor.min(0)[0] - self.alpha * action_log_prob_actor + ).mean() + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + target_value = get_next_state_value( + tensordict, + 
gamma=self.gamma, + pred_next_val=next_state_value, + ) + pred_val = state_action_value_qvalue + td_error = (pred_val - target_value).pow(2) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + + tensordict.set("td_error", td_error.detach().max(0)[0]) + + loss_alpha = self._loss_alpha(sample_log_prob) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + "state_action_value_actor": state_action_value_actor.mean().detach(), + "action_log_prob_actor": action_log_prob_actor.mean().detach(), + "next.state_value": next_state_value.mean().detach(), + "target_value": target_value.mean().detach(), + }, + [], + ) + + return td_out + + def _loss_alpha(self, log_pi: Tensor) -> Tensor: + if torch.is_grad_enabled() and not log_pi.requires_grad: + raise RuntimeError( + "expected log_pi to require gradient for the alpha loss)" + ) + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss From 65bd6ef7ebf0ab458d39adb1ac5d5a4fbb5fbba9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:52:38 +0000 Subject: [PATCH 26/58] amend --- examples/sac.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 1a8fe2099..06d822d5b 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -1,5 +1,7 @@ import os +from torchrl.record import VideoRecorder + os.environ["sim_backend"] = "MUJOCO" import gc @@ -15,6 +17,8 @@ from omegaconf import DictConfig from rlhive.rl_envs import RoboHiveEnv +from sac_loss import SACLoss + # from torchrl.objectives import SACLoss from tensordict import TensorDict @@ -37,8 +41,6 @@ from torchrl.modules.distributions import TanhNormal from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator - -from sac_loss import SACLoss from torchrl.objectives import SoftUpdate from torchrl.record.loggers import WandbLogger from torchrl.trainers import Recorder @@ -129,8 +131,12 @@ def make_recorder( actor_model_explore: object, eval_traj: int, env_configs: dict, + wandb_logger: WandbLogger, ): test_env = make_env(num_envs=1, task=task, **env_configs) + test_env.insert_transform( + 0, VideoRecorder(wandb_logger, "test", in_keys=["pixels"]) + ) recorder_obj = Recorder( record_frames=eval_traj * test_env.horizon, frame_skip=frame_skip, @@ -224,7 +230,9 @@ def main(args: DictConfig): "visual_transform": args.visual_transform, "device": args.device, } - train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs).to(device_collection) + train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs).to( + device_collection + ) # Create Agent # Define Actor Network @@ -448,6 +456,7 @@ def main(args: DictConfig): # eval_traj=args.eval_traj, # ) if td_record is not None: + print("recorded", td_record) rewards_eval.append( ( i, @@ -455,7 +464,13 @@ def main(args: DictConfig): / 1, # divide by number of eval worker ) ) - logger.log_scalar("test_reward", 
rewards_eval[-1][1], step=collected_frames) + logger.log_scalar( + "test_reward", rewards_eval[-1][1], step=collected_frames + ) + logger.log_scalar( + "success", td_record["success"].any(), step=collected_frames + ) + if len(rewards_eval): pbar.set_description( f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" From bbb1d72f0792767527329fea7d9ea7ed3881f662 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:53:23 +0000 Subject: [PATCH 27/58] amend --- examples/sac.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 06d822d5b..1647956f4 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -467,13 +467,14 @@ def main(args: DictConfig): logger.log_scalar( "test_reward", rewards_eval[-1][1], step=collected_frames ) + solved = td_record["success"].any() logger.log_scalar( - "success", td_record["success"].any(), step=collected_frames + "success", solved, step=collected_frames ) if len(rewards_eval): pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" + f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}, solved: {solved}" ) del batch gc.collect() From ab22dec71f7d809a3cd8edf9bc6b398876216d77 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:54:32 +0000 Subject: [PATCH 28/58] amend --- examples/sac.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 1647956f4..3776d813e 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -328,16 +328,6 @@ def main(args: DictConfig): device=device, ) - # Trajectory recorder for evaluation - recorder = make_recorder( - task=args.task, - frame_skip=args.frame_skip, - record_interval=args.record_interval, - actor_model_explore=actor_model_explore, - eval_traj=args.eval_traj, - env_configs=env_configs, - ) - # Optimizers params = list(loss_module.parameters()) + list([loss_module.log_alpha]) optimizer_actor = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) @@ -367,6 +357,17 @@ def main(args: DictConfig): mode="offline", ) + # Trajectory recorder for evaluation + recorder = make_recorder( + task=args.task, + frame_skip=args.frame_skip, + record_interval=args.record_interval, + actor_model_explore=actor_model_explore, + eval_traj=args.eval_traj, + env_configs=env_configs, + wandb_logger=logger, + ) + for i, batch in enumerate( dataloader( total_frames, From 31689087bbcc772328e0038d7f948e483c1abdac Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:55:55 +0000 Subject: [PATCH 29/58] amend --- examples/install/install_rlhive.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index 6b2b0e70a..ba5f78931 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -38,4 +38,4 @@ cd ../.. pip install -e . 
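# moviepy (added below) is presumably needed by the wandb/VideoRecorder-based
# evaluation videos introduced in examples/sac.py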
cd $here -pip install wandb tqdm hydra-core +pip install wandb tqdm hydra-core moviepy From c85a24d493e4c73a59450cd4bf4284201ade2b87 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 19:00:34 +0000 Subject: [PATCH 30/58] amend --- examples/sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sac.py b/examples/sac.py index 3776d813e..3447936eb 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -468,7 +468,7 @@ def main(args: DictConfig): logger.log_scalar( "test_reward", rewards_eval[-1][1], step=collected_frames ) - solved = td_record["success"].any() + solved = float(td_record["success"].any()) logger.log_scalar( "success", solved, step=collected_frames ) From bbcd73d65f286f8cfff9e2624cf652630538f6f0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 20:55:46 +0000 Subject: [PATCH 31/58] amend --- examples/sac.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 3447936eb..b41ec1ae4 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -329,8 +329,8 @@ def main(args: DictConfig): ) # Optimizers - params = list(loss_module.parameters()) + list([loss_module.log_alpha]) - optimizer_actor = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) + params = list(loss_module.parameters()) + optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) rewards = [] rewards_eval = [] @@ -407,15 +407,18 @@ def main(args: DictConfig): sampled_tensordict = replay_buffer.sample(args.batch_size).clone() loss_td = loss_module(sampled_tensordict) + print(f'value: {loss_td["state_action_value_actor"].mean():4.4f}') + print(f'log_prob: {loss_td["action_log_prob_actor"].mean():4.4f}') + print(f'next.state_value: {loss_td["state_value"].mean():4.4f}') actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] alpha_loss = loss_td["loss_alpha"] loss = actor_loss + q_loss + alpha_loss - optimizer_actor.zero_grad() + optimizer.zero_grad() loss.backward() - optimizer_actor.step() + optimizer.step() # update qnet_target params target_net_updater.step() From dc68e2e09ebfa2b334cbaa9360a3280b5d5a51cc Mon Sep 17 00:00:00 2001 From: rutavms Date: Sat, 28 Jan 2023 02:38:34 -0600 Subject: [PATCH 32/58] rl_env updated for state based experiments --- rlhive/rl_envs.py | 2 +- scripts/sac_mujoco/config/sac.yaml | 1 + scripts/sac_mujoco/sac.py | 32 ++++++++++++++++++------------ 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/rlhive/rl_envs.py b/rlhive/rl_envs.py index 9092584f9..08dec587f 100644 --- a/rlhive/rl_envs.py +++ b/rlhive/rl_envs.py @@ -164,7 +164,7 @@ def read_obs(self, observation): pixel_list.append(pix) elif key in self._env.obs_keys: obsvec.append( - observations[key] + observations[key].flatten() if observations[key].ndim == 0 else observations[key] ) # ravel helps with images if obsvec: obsvec = np.concatenate(obsvec, 0) diff --git a/scripts/sac_mujoco/config/sac.yaml b/scripts/sac_mujoco/config/sac.yaml index bb914f9ec..39a201706 100644 --- a/scripts/sac_mujoco/config/sac.yaml +++ b/scripts/sac_mujoco/config/sac.yaml @@ -1,6 +1,7 @@ default: - override hydra/output: local - override hydra/launcher: local +from_pixels: True # Logger exp_name: ${task}_sac_${visual_transform} diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py index 885c6d5ff..a9daa23c8 100644 --- a/scripts/sac_mujoco/sac.py +++ b/scripts/sac_mujoco/sac.py @@ -55,18 +55,19 @@ def make_env( task, visual_transform, reward_scaling, + from_pixels, device ): assert 
visual_transform in ('rrl', 'r3m') - base_env = RoboHiveEnv(task, device=device) - env = make_transformed_env(env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform) - print(env) + base_env = RoboHiveEnv(task, from_pixels=from_pixels, device=device) + env = make_transformed_env(env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform, from_pixels=from_pixels) return env def make_transformed_env( env, + from_pixels, reward_scaling=5.0, visual_transform='r3m', stats=None, @@ -74,17 +75,21 @@ def make_transformed_env( """ Apply transforms to the env (such as reward scaling and state normalization) """ - env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) - if visual_transform == 'rrl': - vec_keys = ["r3m_vec"] - selected_keys = ["observation", "r3m_vec"] - env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download="IMAGENET1K_V1"), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 - elif visual_transform == 'r3m': - vec_keys = ["r3m_vec"] - selected_keys = ["observation", "r3m_vec"] - env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + if from_pixels: + env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) + if visual_transform == 'rrl': + vec_keys = ["rrl_vec"] + selected_keys = ["observation", "rrl_vec"] + env.append_transform(Compose(RRLTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == 'r3m': + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + else: + raise NotImplementedError else: - raise NotImplementedError + env = TransformedEnv(env, SelectTransform("solved", "observation")) + selected_keys = ["observation"] env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) @@ -187,6 +192,7 @@ def main(args: DictConfig): "reward_scaling": args.reward_scaling, "visual_transform": args.visual_transform, "device": args.device, + "from_pixels": args.from_pixels, } train_env = make_env(task=args.task, **env_configs) From faa46de5bde036cfe7a66dd44ab51efc49eee801 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 28 Jan 2023 20:57:46 +0000 Subject: [PATCH 33/58] amend --- examples/config/sac.yaml | 2 + examples/sac.py | 12 +-- examples/sac_loss.py | 187 ++++++++++++++++++++++++++++++++++++--- 3 files changed, 183 insertions(+), 18 deletions(-) diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml index 9daa16eb3..83f638245 100644 --- a/examples/config/sac.yaml +++ b/examples/config/sac.yaml @@ -8,6 +8,8 @@ visual_transform: r3m record_interval: 1 device: "cuda:0" device_collection: "cuda:1" +wandb_entity: "RLHive" +wandb_mode: "offline" # Environment task: visual_franka_slide_random-v3 diff --git 
a/examples/sac.py b/examples/sac.py index b41ec1ae4..a256e8635 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -197,9 +197,9 @@ def dataloader( ).unflatten_keys(".") _prev = None - params_collection.update_(params) collected_frames = 0 while collected_frames < total_frames: + params_collection.update_(params) batch = TensorDict( {}, batch_size=[fpb, *train_env.batch_size], device=device_collection ) @@ -250,7 +250,7 @@ def main(args: DictConfig): dist_kwargs = { "min": action_spec.space.minimum, "max": action_spec.space.maximum, - "tanh_loc": False, + "tanh_loc": True, } actor_net = NormalParamWrapper( actor_net, @@ -353,8 +353,8 @@ def main(args: DictConfig): project="SAC_TorchRL", name=args.exp_name, config=args, - entity="RLHive", - mode="offline", + entity=args.wandb_entity, + mode=args.wandb_mode, ) # Trajectory recorder for evaluation @@ -386,7 +386,7 @@ def main(args: DictConfig): batch = batch.view(-1) current_frames = batch.numel() collected_frames += current_frames - episodes += args.env_per_collector + episodes += batch["done"].sum() replay_buffer.extend(batch.cpu()) # optimization steps @@ -434,7 +434,7 @@ def main(args: DictConfig): alphas.append(loss_td["alpha"].item()) entropies.append(loss_td["entropy"].item()) - rewards.append((i, batch["reward"].sum().item() / args.env_per_collector)) + rewards.append((i, batch["reward"].mean().item())) logger.log_scalar("train_reward", rewards[-1][1], step=collected_frames) logger.log_scalar("optim_steps", optim_steps, step=collected_frames) logger.log_scalar("episodes", episodes, step=collected_frames) diff --git a/examples/sac_loss.py b/examples/sac_loss.py index cebe7f2e9..07d5ab1c4 100644 --- a/examples/sac_loss.py +++ b/examples/sac_loss.py @@ -58,6 +58,7 @@ class SACLoss(LossModule): """ delay_actor: bool = False + _explicit: bool = True def __init__( self, @@ -148,6 +149,26 @@ def alpha(self): return alpha def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if self._explicit: + # slow but explicit version + return self._forward_explicit(tensordict) + else: + return self._forward_vectorized(tensordict) + + def _loss_alpha(self, log_pi: Tensor) -> Tensor: + if torch.is_grad_enabled() and not log_pi.requires_grad: + raise RuntimeError( + "expected log_pi to require gradient for the alpha loss)" + ) + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss + + def _forward_vectorized(self, tensordict: TensorDictBase) -> TensorDictBase: obs_keys = self.actor_network.in_keys tensordict_select = tensordict.select( "reward", "done", "next", *obs_keys, "action" @@ -187,7 +208,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_actor_dist = self.actor_network.build_dist_from_params( td_params ) - tensordict_actor[sample_key] = tensordict_actor_dist.rsample() + tensordict_actor[sample_key] = self._rsample(tensordict_actor_dist) tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( tensordict_actor[sample_key] ) @@ -246,6 +267,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: next_action_log_prob_qvalue, ) = sample_log_prob.unbind(0) + # E[alpha * log_pi(a) - Q(s, a)] where a is reparameterized loss_actor = -( state_action_value_actor.min(0)[0] - self.alpha * action_log_prob_actor ).mean() @@ -297,15 +319,156 @@ def forward(self, tensordict: 
TensorDictBase) -> TensorDictBase: return td_out - def _loss_alpha(self, log_pi: Tensor) -> Tensor: - if torch.is_grad_enabled() and not log_pi.requires_grad: - raise RuntimeError( - "expected log_pi to require gradient for the alpha loss)" + def _forward_explicit(self, tensordict: TensorDictBase) -> TensorDictBase: + loss_actor, sample_log_prob = self._loss_actor_explicit(tensordict.clone(False)) + loss_qval, td_error = self._loss_qval_explicit(tensordict.clone(False)) + tensordict.set("td_error", td_error.detach().max(0)[0]) + loss_alpha = self._loss_alpha(sample_log_prob) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + # "state_action_value_actor": state_action_value_actor.mean().detach(), + # "action_log_prob_actor": action_log_prob_actor.mean().detach(), + # "next.state_value": next_state_value.mean().detach(), + # "target_value": target_value.mean().detach(), + }, + [], + ) + return td_out + + def _rsample(self, dist, ): + # separated only for the purpose of making the sampling + # deterministic to compare methods + return dist.rsample() + + + def _sample_reparam(self, tensordict, params): + """Given a policy param batch and input data in a tensordict, writes a reparam sample and log-prob key.""" + with set_exploration_mode("random"): + if self.gSDE: + raise NotImplementedError + # vmap doesn't support sampling, so we take it out from the vmap + td_params = self.actor_network.get_dist_params(tensordict, params,) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict[sample_key] = self._rsample(tensordict_actor_dist) + tensordict["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict[sample_key] ) - if self.target_entropy is not None: - # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) - else: - # placeholder - alpha_loss = torch.zeros_like(log_pi) - return alpha_loss + return tensordict + + def _loss_actor_explicit(self, tensordict): + tensordict_actor = tensordict.clone(False) + actor_params = self.actor_network_params + tensordict_actor = self._sample_reparam(tensordict_actor, actor_params) + action_log_prob_actor = tensordict_actor["sample_log_prob"] + + tensordict_qval = ( + tensordict_actor + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor.batch_size) + ) # for actor loss + qvalue_params = self.qvalue_network_params.detach() + tensordict_qval = vmap(self.qvalue_network)(tensordict_qval, qvalue_params,) + state_action_value_actor = tensordict_qval.get("state_action_value").squeeze(-1) + state_action_value_actor = state_action_value_actor.min(0)[0] + + # E[alpha * log_pi(a) - Q(s, a)] where a is reparameterized + loss_actor = (self.alpha * action_log_prob_actor - state_action_value_actor).mean() + + return loss_actor, action_log_prob_actor + + def _loss_qval_explicit(self, tensordict): + next_tensordict = step_mdp(tensordict) + next_tensordict = self._sample_reparam(next_tensordict, self.target_actor_network_params) + next_action_log_prob_qvalue = next_tensordict["sample_log_prob"] + 
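+        # each target Q-network scores the re-sampled next action a' ~ pi(.|s');
+        # the soft TD target assembled below is
+        #   y = r + gamma * (1 - done) * (min_i Q_target_i(s', a') - alpha * log pi(a'|s'))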
next_state_action_value_qvalue = vmap(self.qvalue_network, (None, 0))( + next_tensordict, + self.target_qvalue_network_params, + )["state_action_value"].squeeze(-1) + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + pred_val = vmap(self.qvalue_network, (None, 0))( + tensordict, + self.qvalue_network_params, + )["state_action_value"].squeeze(-1) + + target_value = get_next_state_value( + tensordict, + gamma=self.gamma, + pred_next_val=next_state_value, + ) + + # 1/2 * E[Q(s,a) - (r + gamma * (Q(s,a)-alpha log pi(s, a))) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + td_error = (pred_val - target_value).pow(2) + return loss_qval, td_error + +if __name__ == "__main__": + # Tests the vectorized version of SAC-v2 against plain implementation + from torchrl.modules import ProbabilisticActor, ValueOperator + from torchrl.data import BoundedTensorSpec + from torch import nn + from tensordict.nn import TensorDictModule + from torchrl.modules.distributions import TanhNormal + + torch.manual_seed(0) + + action_spec = BoundedTensorSpec(-1, 1, shape=(3,)) + class Splitter(nn.Linear): + def forward(self, x): + loc, scale = super().forward(x).chunk(2, -1) + return loc, scale.exp() + actor_module = TensorDictModule(Splitter(6, 6), in_keys=["obs"], out_keys=["loc", "scale"]) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=TanhNormal, + default_interaction_mode="random", + return_log_prob=False, + ) + class QVal(nn.Linear): + def forward(self, s: Tensor, a: Tensor) -> Tensor: + return super().forward(torch.cat([s, a], -1)) + + qvalue = ValueOperator(QVal(9, 1), in_keys=["obs", "action"]) + _rsample_old = SACLoss._rsample + def _rsample_new(self, dist): + return torch.ones_like(_rsample_old(self, dist)) + SACLoss._rsample = _rsample_new + loss = SACLoss(actor, qvalue) + + for batch in ((), (2, 3)): + td_input = TensorDict({"obs": torch.rand(*batch, 6), "action": torch.rand(*batch, 3).clamp(-1, 1), "next": {"obs": torch.rand(*batch, 6)}, "reward": torch.rand(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch) + loss._explicit = True + loss0 = loss(td_input) + loss._explicit = False + loss1 = loss(td_input) + print("a", loss0["loss_actor"]-loss1["loss_actor"]) + print("q", loss0["loss_qvalue"]-loss1["loss_qvalue"]) From e8959128ab23cf92a00d2a8e30a17954ec24b1a2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 13 Jan 2023 17:00:31 +0000 Subject: [PATCH 34/58] init --- rlhive/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/rlhive/__init__.py b/rlhive/__init__.py index d79196148..0dfd05c78 100644 --- a/rlhive/__init__.py +++ b/rlhive/__init__.py @@ -3,8 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
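# importing rlhive registers the Franka/kitchen task variants as a side effect;
# the registration calls are commented out in this commit and restored in the next one below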
-from .envs import register_franka_envs, register_kitchen_envs +# from .envs import register_franka_envs, register_kitchen_envs +# +# register_franka_envs() +# register_kitchen_envs() -register_franka_envs() -register_kitchen_envs() from .rl_envs import RoboHiveEnv From 3da5e5c245fd2567ebeacd9645f1ea22b7980291 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 13 Jan 2023 18:17:20 +0000 Subject: [PATCH 35/58] amend --- rlhive/__init__.py | 8 +- rlhive/envs.py | 435 +++++---------------------------------------- rlhive/rl_envs.py | 17 +- 3 files changed, 52 insertions(+), 408 deletions(-) diff --git a/rlhive/__init__.py b/rlhive/__init__.py index 0dfd05c78..88aff9f05 100644 --- a/rlhive/__init__.py +++ b/rlhive/__init__.py @@ -3,9 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -# from .envs import register_franka_envs, register_kitchen_envs -# -# register_franka_envs() -# register_kitchen_envs() +from .envs import register_franka_envs, register_kitchen_envs + +register_franka_envs() +register_kitchen_envs() from .rl_envs import RoboHiveEnv diff --git a/rlhive/envs.py b/rlhive/envs.py index 00fac8324..54d0a2822 100644 --- a/rlhive/envs.py +++ b/rlhive/envs.py @@ -8,9 +8,8 @@ import os from pathlib import Path +import mj_envs.envs.env_variants.register_env_variant import mj_envs.envs.multi_task.substeps1 -from gym.envs.registration import register -from mj_envs.envs.multi_task.common.franka_kitchen_v1 import KitchenFrankaFixed visual_obs_keys_wt = mj_envs.envs.multi_task.substeps1.visual_obs_keys_wt @@ -52,11 +51,27 @@ def new_fun(*args, **kwargs): def register_kitchen_envs(): print("RLHive:> Registering Kitchen Envs") - # ======================================================== - - # V3 environments - # In this version of the environment, the observations consist of the - # distance between end effector and all relevent objects in the scene + env_list = [ + "kitchen_knob1_off-v3", + "kitchen_knob1_on-v3", + "kitchen_knob2_off-v3", + "kitchen_knob2_on-v3", + "kitchen_knob3_off-v3", + "kitchen_knob3_on-v3", + "kitchen_knob4_off-v3", + "kitchen_knob4_on-v3", + "kitchen_light_off-v3", + "kitchen_light_on-v3", + "kitchen_sdoor_close-v3", + "kitchen_sdoor_open-v3", + "kitchen_ldoor_close-v3", + "kitchen_ldoor_open-v3", + "kitchen_rdoor_close-v3", + "kitchen_rdoor_open-v3", + "kitchen_micro_close-v3", + "kitchen_micro_open-v3", + "kitchen_close-v3", + ] visual_obs_keys_wt = { "robot_jnt": 1.0, @@ -64,289 +79,25 @@ def register_kitchen_envs(): "rgb:right_cam:224x224:2d": 1.0, "rgb:left_cam:224x224:2d": 1.0, } - obs_keys_wt = visual_obs_keys_wt - for site in KitchenFrankaFixed.OBJ_INTERACTION_SITES: - obs_keys_wt[site + "_err"] = 1.0 - - # Kitchen - register( - id="visual_kitchen_close-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_goal": {}, - "obj_init": { - "knob1_joint": -1.57, - "knob2_joint": -1.57, - "knob3_joint": -1.57, - "knob4_joint": -1.57, - "lightswitch_joint": -0.7, - "slidedoor_joint": 0.44, - "micro0joint": -1.25, - "rightdoorhinge": 1.57, - "leftdoorhinge": -1.25, - }, - "obs_keys_wt": obs_keys_wt, - }, - ) - - # Microwave door - register( - id="visual_kitchen_micro_open-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"micro0joint": 0}, - "obj_goal": {"micro0joint": -1.25}, - "obs_keys_wt": obs_keys_wt, - "interact_site": 
"microhandle_site", - }, - ) - register( - id="visual_kitchen_micro_close-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"micro0joint": -1.25}, - "obj_goal": {"micro0joint": 0}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "microhandle_site", - }, - ) - - # Right hinge cabinet - register( - id="visual_kitchen_rdoor_open-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"rightdoorhinge": 0}, - "obj_goal": {"rightdoorhinge": 1.57}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "rightdoor_site", - }, - ) - register( - id="visual_kitchen_rdoor_close-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"rightdoorhinge": 1.57}, - "obj_goal": {"rightdoorhinge": 0}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "rightdoor_site", - }, - ) - - # Left hinge cabinet - register( - id="visual_kitchen_ldoor_open-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"leftdoorhinge": 0}, - "obj_goal": {"leftdoorhinge": -1.25}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "leftdoor_site", - }, - ) - register( - id="visual_kitchen_ldoor_close-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"leftdoorhinge": -1.25}, - "obj_goal": {"leftdoorhinge": 0}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "leftdoor_site", - }, - ) - - # Slide cabinet - register( - id="visual_kitchen_sdoor_open-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"slidedoor_joint": 0}, - "obj_goal": {"slidedoor_joint": 0.44}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "slide_site", - }, - ) - register( - id="visual_kitchen_sdoor_close-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"slidedoor_joint": 0.44}, - "obj_goal": {"slidedoor_joint": 0}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "slide_site", - }, - ) - - # Lights - register( - id="visual_kitchen_light_on-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"lightswitch_joint": 0}, - "obj_goal": {"lightswitch_joint": -0.7}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "light_site", - }, - ) - register( - id="visual_kitchen_light_off-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"lightswitch_joint": -0.7}, - "obj_goal": {"lightswitch_joint": 0}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "light_site", - }, - ) - - # Knob4 - register( - id="visual_kitchen_knob4_on-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"knob4_joint": 0}, - "obj_goal": {"knob4_joint": -1.57}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "knob4_site", - }, - ) - register( - id="visual_kitchen_knob4_off-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"knob4_joint": -1.57}, - "obj_goal": 
{"knob4_joint": 0}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "knob4_site", - }, - ) - - # Knob3 - register( - id="visual_kitchen_knob3_on-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"knob3_joint": 0}, - "obj_goal": {"knob3_joint": -1.57}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "knob3_site", - }, - ) - register( - id="visual_kitchen_knob3_off-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"knob3_joint": -1.57}, - "obj_goal": {"knob3_joint": 0}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "knob3_site", - }, - ) - - # Knob2 - register( - id="visual_kitchen_knob2_on-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"knob2_joint": 0}, - "obj_goal": {"knob2_joint": -1.57}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "knob2_site", - }, - ) - register( - id="visual_kitchen_knob2_off-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"knob2_joint": -1.57}, - "obj_goal": {"knob2_joint": 0}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "knob2_site", - }, - ) - - # Knob1 - register( - id="visual_kitchen_knob1_on-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"knob1_joint": 0}, - "obj_goal": {"knob1_joint": -1.57}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "knob1_site", - }, - ) - register( - id="visual_kitchen_knob1_off-v3", - entry_point=ENTRY_POINT, - max_episode_steps=50, - kwargs={ - "model_path": MODEL_PATH, - "config_path": CONFIG_PATH, - "obj_init": {"knob1_joint": -1.57}, - "obj_goal": {"knob1_joint": 0}, - "obs_keys_wt": obs_keys_wt, - "interact_site": "knob1_site", - }, - ) + for env in env_list: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, variants={"obs_keys_wt": visual_obs_keys_wt}, variant_id=new_env_name + ) @set_directory(CURR_DIR) def register_franka_envs(): + print("RLHive:> Registering Franka Envs") + env_list = [ + "franka_slide_random-v3", + "franka_slide_close-v3", + "franka_slide_open-v3", + "franka_micro_random-v3", + "franka_micro_close-v3", + "franka_micro_open-v3", + ] + # Franka Appliance ====================================================================== visual_obs_keys_wt = { "robot_jnt": 1.0, @@ -354,110 +105,8 @@ def register_franka_envs(): "rgb:right_cam:224x224:2d": 1.0, "rgb:left_cam:224x224:2d": 1.0, } - - # MICROWAVE - # obs_keys_wt = { - # "robot_jnt": 1.0, - # "end_effector": 1.0, - # } - register( - id="visual_franka_micro_open-v3", - entry_point="mj_envs.envs.multi_task.common.franka_appliance_v1:FrankaAppliance", - max_episode_steps=75, - kwargs={ - "model_path": CURR_DIR + "/../common/microwave/franka_microwave.xml", - "config_path": CURR_DIR + "/../common/microwave/franka_microwave.config", - "obj_init": {"micro0joint": 0}, - "obj_goal": {"micro0joint": -1.25}, - "obj_interaction_site": ("microhandle_site",), - "obj_jnt_names": ("micro0joint",), - "interact_site": "microhandle_site", - "obs_keys_wt": visual_obs_keys_wt, - }, - ) - register( - id="visual_franka_micro_close-v3", - entry_point="mj_envs.envs.multi_task.common.franka_appliance_v1:FrankaAppliance", - max_episode_steps=50, - kwargs={ - "model_path": CURR_DIR + 
"/../common/microwave/franka_microwave.xml", - "config_path": CURR_DIR + "/../common/microwave/franka_microwave.config", - "obj_init": {"micro0joint": -1.25}, - "obj_goal": {"micro0joint": 0}, - "obj_interaction_site": ("microhandle_site",), - "obj_jnt_names": ("micro0joint",), - "interact_site": "microhandle_site", - "obs_keys_wt": visual_obs_keys_wt, - }, - ) - register( - id="visual_franka_micro_random-v3", - entry_point="mj_envs.envs.multi_task.common.franka_appliance_v1:FrankaAppliance", - max_episode_steps=50, - kwargs={ - "model_path": CURR_DIR + "/../common/microwave/franka_microwave.xml", - "config_path": CURR_DIR + "/../common/microwave/franka_microwave.config", - "obj_init": {"micro0joint": (-1.25, 0)}, - "obj_goal": {"micro0joint": (-1.25, 0)}, - "obj_interaction_site": ("microhandle_site",), - "obj_jnt_names": ("micro0joint",), - "obj_body_randomize": ("microwave",), - "interact_site": "microhandle_site", - "obs_keys_wt": visual_obs_keys_wt, - }, - ) - - # SLIDE-CABINET - # obs_keys_wt = { - # "robot_jnt": 1.0, - # "end_effector": 1.0, - # } - register( - id="visual_franka_slide_open-v3", - entry_point="mj_envs.envs.multi_task.common.franka_appliance_v1:FrankaAppliance", - max_episode_steps=50, - kwargs={ - "model_path": CURR_DIR + "/../common/slidecabinet/franka_slidecabinet.xml", - "config_path": CURR_DIR - + "/../common/slidecabinet/franka_slidecabinet.config", - "obj_init": {"slidedoor_joint": 0}, - "obj_goal": {"slidedoor_joint": 0.44}, - "obj_interaction_site": ("slide_site",), - "obj_jnt_names": ("slidedoor_joint",), - "interact_site": "slide_site", - "obs_keys_wt": visual_obs_keys_wt, - }, - ) - register( - id="visual_franka_slide_close-v3", - entry_point="mj_envs.envs.multi_task.common.franka_appliance_v1:FrankaAppliance", - max_episode_steps=50, - kwargs={ - "model_path": CURR_DIR + "/../common/slidecabinet/franka_slidecabinet.xml", - "config_path": CURR_DIR - + "/../common/slidecabinet/franka_slidecabinet.config", - "obj_init": {"slidedoor_joint": 0.44}, - "obj_goal": {"slidedoor_joint": 0}, - "obj_interaction_site": ("slide_site",), - "obj_jnt_names": ("slidedoor_joint",), - "interact_site": "slide_site", - "obs_keys_wt": visual_obs_keys_wt, - }, - ) - register( - id="visual_franka_slide_random-v3", - entry_point="mj_envs.envs.multi_task.common.franka_appliance_v1:FrankaAppliance", - max_episode_steps=50, - kwargs={ - "model_path": CURR_DIR + "/../common/slidecabinet/franka_slidecabinet.xml", - "config_path": CURR_DIR - + "/../common/slidecabinet/franka_slidecabinet.config", - "obj_init": {"slidedoor_joint": (0, 0.44)}, - "obj_goal": {"slidedoor_joint": (0, 0.44)}, - "obj_interaction_site": ("slide_site",), - "obj_jnt_names": ("slidedoor_joint",), - "obj_body_randomize": ("slidecabinet",), - "interact_site": "slide_site", - "obs_keys_wt": visual_obs_keys_wt, - }, - ) + for env in env_list: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, variants={"obs_keys_wt": visual_obs_keys_wt}, variant_id=new_env_name + ) diff --git a/rlhive/rl_envs.py b/rlhive/rl_envs.py index 08dec587f..37ce4547b 100644 --- a/rlhive/rl_envs.py +++ b/rlhive/rl_envs.py @@ -7,11 +7,7 @@ import numpy as np import torch from tensordict.tensordict import make_tensordict, TensorDictBase -from torchrl.data import ( - CompositeSpec, - BoundedTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, _has_gym, GymEnv 
from torchrl.envs.transforms import CatTensors, Compose, R3MTransform, TransformedEnv from torchrl.trainers.helpers.envs import LIBS @@ -25,7 +21,9 @@ def make_extra_spec(tensordict, obsspec: CompositeSpec): tensordict = tensordict.view(-1)[0] c = CompositeSpec() for key, value in tensordict.items(): - if obsspec is not None and (key in ("next", "action", "done", "reward") or key in obsspec.keys()): + if obsspec is not None and ( + key in ("next", "action", "done", "reward") or key in obsspec.keys() + ): continue if isinstance(value, TensorDictBase): spec = make_extra_spec(value, None) @@ -37,6 +35,7 @@ def make_extra_spec(tensordict, obsspec: CompositeSpec): return obsspec return c + class RoboHiveEnv(GymEnv): # info_keys = ["time", "rwd_dense", "rwd_sparse", "solved"] @@ -174,11 +173,7 @@ def read_obs(self, observation): out = {"observation": obsvec} return super().read_obs(out) - def read_info( - self, - info, - tensordict_out - ): + def read_info(self, info, tensordict_out): out = {} for key, value in info.items(): if key in ("obs_dict",): From eee0d4bdb81ca380dd840cafa9867d7858122438 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 13 Jan 2023 18:17:52 +0000 Subject: [PATCH 36/58] amend --- .circleci/unittest/linux/scripts/run_test.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.circleci/unittest/linux/scripts/run_test.sh b/.circleci/unittest/linux/scripts/run_test.sh index d4266d838..9dd49b427 100755 --- a/.circleci/unittest/linux/scripts/run_test.sh +++ b/.circleci/unittest/linux/scripts/run_test.sh @@ -30,10 +30,10 @@ export MUJOCO_GL=$PRIVATE_MUJOCO_GL export PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL export sim_backend=MUJOCO -#python ./third_party/mj_envs/mj_envs/tests/test_arms.py -#python ./third_party/mj_envs/mj_envs/tests/test_claws.py -#python ./third_party/mj_envs/mj_envs/tests/test_envs.py -#python ./third_party/mj_envs/mj_envs/tests/test_fm.py -#python ./third_party/mj_envs/mj_envs/tests/test_hand_manipulation_suite.py -#python ./third_party/mj_envs/mj_envs/tests/test_multitask.py +python ./third_party/mj_envs/mj_envs/tests/test_arms.py +python ./third_party/mj_envs/mj_envs/tests/test_claws.py +python ./third_party/mj_envs/mj_envs/tests/test_envs.py +python ./third_party/mj_envs/mj_envs/tests/test_fm.py +python ./third_party/mj_envs/mj_envs/tests/test_hand_manipulation_suite.py +python ./third_party/mj_envs/mj_envs/tests/test_multitask.py python test/test_envs.py From 2e5e1e6518cbc6e3b05295059a99885d6084b23c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 13 Jan 2023 18:18:09 +0000 Subject: [PATCH 37/58] minor --- rlhive/rl_envs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rlhive/rl_envs.py b/rlhive/rl_envs.py index 37ce4547b..ab070d7d0 100644 --- a/rlhive/rl_envs.py +++ b/rlhive/rl_envs.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from copy import copy import numpy as np import torch From caa66e104122238cbd62c153eb1ceccd376502b8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 23 Jan 2023 17:30:56 +0000 Subject: [PATCH 38/58] Some more info in GET_STARTED.md --- GET_STARTED.md | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/GET_STARTED.md b/GET_STARTED.md index b78115652..8936bfdbf 100644 --- a/GET_STARTED.md +++ b/GET_STARTED.md @@ -4,6 +4,8 @@ ## Installing dependencies The following code snippet installs the nightly versions of the libraries. 
 For a faster installation, simply install `torchrl-nightly` and `tensordict-nightly`.
+However, we recommend using the `git` versions, as they are more likely to be up to date with the latest features; since we are
+actively working on fine-tuning torchrl for RoboHive usage, keeping the latest version of the library may be beneficial.
 
 ```shell
 module load cuda/11.6 cudnn/v8.4.1.50-cuda.11.6
@@ -78,3 +80,86 @@ if __name__ == "__main__":
 
     print(data)
 ```
+
+## Designing experiments and logging values
+
+TorchRL provides a series of wrappers around common loggers (tensorboard, mlflow, wandb, etc.).
+We generally default to wandb.
+Here are the details on how to set up your logger. wandb can work in one of two
+modes: `online`, where you need an account and the machine running your experiment must be
+connected to the cloud, and `offline`, where the logs are stored locally.
+The latter is more general and makes the logs easier to collect, so we suggest you use this mode.
+To configure and use your logger with TorchRL, proceed as follows (note that
+using the plain wandb API is very similar; TorchRL's convenience lies mainly in the
+interchangeability with other loggers):
+
+```python
+import argparse
+import os
+
+from torchrl.trainers.loggers import WandbLogger
+import torch
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--total_frames", default=300, type=int)
+parser.add_argument("--training_steps", default=3, type=int)
+parser.add_argument("--wandb_exp_name", default="a2c")
+parser.add_argument("--wandb_save_dir", default="./mylogs")
+parser.add_argument("--wandb_project", default="rlhive")
+parser.add_argument("--wandb_mode", default="offline",
+                    choices=["online", "offline"])
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    training_steps = args.training_steps
+    dest_dir = args.wandb_save_dir
+    if args.wandb_mode == "offline":
+        # This will be integrated in torchrl
+        os.makedirs(dest_dir, exist_ok=True)
+    logger = WandbLogger(
+        exp_name=args.wandb_exp_name,
+        save_dir=dest_dir,
+        project=args.wandb_project,
+        mode=args.wandb_mode,
+    )
+
+    # we collect 3 frames in each batch
+    collector = (torch.randn(3, 4, 0) for _ in range(args.total_frames // 3))
+    total_frames = 0
+    # main loop: collection of batches
+    for batch in collector:
+        for step in range(training_steps):
+            pass
+        total_frames += batch.shape[0]
+        # We log against the number of frames collected, which we believe is the quantity
+        # least subject to experiment hyperparameters
+        logger.log_scalar("loss_value", torch.randn([]).item(),
+                          step=total_frames)
+    # one can log videos too! But custom steps do not work as expected :(
+    video = torch.randint(255, (10, 11, 3, 64, 64))  # 10 videos of 11 frames, 64x64 pixels
+    logger.log_video("demo", video)
+
+```
+
+
+This script will save your logs in `./mylogs`. Don't worry too much about `project` or `entity`, which can be [overwritten
+at upload time](https://docs.wandb.ai/ref/cli/wandb-sync).
+
+Once the logs have been collected, we will upload them to a wandb account using `wandb sync path/to/log --entity someone --project something`.
+
+## What to log
+
+In general, experiments should log the following items:
+- dense reward (train and test)
+- sparse reward (train and test)
+- success percentage (train and test)
+- video: after every 1M frames or so, a test run should be performed. A video recorder should be appended
+  to the test env to log the behaviour.
+- number of training steps: since our "x"-axis will be the number of frames collected, keeping track of the + training steps will help us interpolate one with the other. +- For behavioural cloning we should log the number of epochs instead. + +## A more concrete example + +TODO From c935d24cd8b7662216f43e9e0f14f2800e3ddad6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 23 Jan 2023 17:31:41 +0000 Subject: [PATCH 39/58] Fix ref to wandb --- GET_STARTED.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GET_STARTED.md b/GET_STARTED.md index 8936bfdbf..a2492b551 100644 --- a/GET_STARTED.md +++ b/GET_STARTED.md @@ -97,7 +97,7 @@ interchangeability with other loggers): import argparse import os -from torchrl.trainers.loggers import WandbLogger +from torchrl.record.loggers import WandbLogger import torch parser = argparse.ArgumentParser() From 1af25a96b8475a26f4cfd63f69ef6b0ebc2b596e Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 24 Jan 2023 11:38:46 +0000 Subject: [PATCH 40/58] cleanup --- GET_STARTED.md | 2 ++ README.md | 49 +++---------------------------------------------- 2 files changed, 5 insertions(+), 46 deletions(-) diff --git a/GET_STARTED.md b/GET_STARTED.md index a2492b551..2390df9c3 100644 --- a/GET_STARTED.md +++ b/GET_STARTED.md @@ -28,6 +28,8 @@ python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or s ``` +For more complete instructions, check the installation pipeline in `.circleci/unittest/linux/script/install.sh` + You can run these two commands to check that installation was successful: ```shell diff --git a/README.md b/README.md index 7b70ccb17..6f8747717 100644 --- a/README.md +++ b/README.md @@ -111,19 +111,7 @@ torchrl examples: - [torchrl](https://github.com/pytorch/rl/tree/main/examples) - [torchrl_examples](https://github.com/compsciencelab/torchrl_examples) -*UNSTABLE*: One can train a model in state-only or with pixels. -When using pixels, the current API covers the R3M pipeline. Using a plain CNN is -an upcoming feature. -The entry point for all model trainings is `rlhive/sim_algos/run.py`. - -#### State-only - -As of now, one needs to specify the model and exploration method. -```bash -python run.py +model=sac +exploration=gaussian -``` - -## Exexution +## Execution RLHive is optimized for the `MUJOCO` backend. Make sure to set the `sim_backend` environment variable to `"MUJOCO"` before running the code: @@ -135,40 +123,9 @@ sim_backend=MUJOCO python script.py RLHive has two core dependencies: torchrl and RoboHive. RoboHive relies on mujoco and mujoco-py for physics simulation and rendering. As of now, RoboHive requires you to use the old mujoco bindings as well as the v0.13 of gym. -TorchRL provides [detailed instructions](https://github.com/facebookresearch/rl/pull/375) +TorchRL provides [detailed instructions](https://pytorch.org/rl/reference/generated/knowledge_base/MUJOCO_INSTALLATION.html#installing-mujoco). on how to setup an environment with the old mujoco bindings. -### Create conda env -```bash -$ conda create -n rlhive python=3.9 -$ conda activate rlhive -``` - -### Installing TorchRL -```bash -$ # cuda 11.6 -$ pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116 -``` -For other cuda versions or nightly builds, check the [torch installation instructions](https://pytorch.org/get-started/locally/). 
-To install torchrl, run -``` -pip install torchrl-nightly -``` -or install directly from github: -``` -pip install git+https://github.com/pytorch/rl -``` - -### Installing RoboHive -First clone the repo and install it locally: -```bash -$ cd path/to/root -$ # follow the getting started instructions: https://github.com/vikashplus/mj_envs/tree/v0.3dev#getting-started -$ cd mj_envs -$ git checkout v0.3dev -$ git clone -c submodule.mj_envs/sims/neuromuscular_sim.update=none --branch v0.3dev --recursive https://github.com/vikashplus/mj_envs.git -$ cd mj_envs -$ pip install -e . -``` +See also the [Getting Started](GET_STARTED.md) markdown for more info on setting up your env. For more complete instructions, check the installation pipeline in `.circleci/unittest/linux/script/install.sh` From ad202067ca797b655295c7041abd2c855a93e03b Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 15:32:01 +0000 Subject: [PATCH 41/58] init --- examples/config/sac.yaml | 42 ++ examples/sac.py | 471 ++++++++++++++++++ rlhive/envs.py | 36 +- rlhive/rl_envs.py | 10 +- rlhive/sim_algos/helpers/rrl_transform.py | 8 +- scripts/redq/redq.py | 79 +-- scripts/sac_mujoco/config/group/group1.yaml | 2 +- scripts/sac_mujoco/config/group/group2.yaml | 2 +- .../sac_mujoco/config/hydra/output/local.yaml | 2 +- scripts/sac_mujoco/sac.py | 124 +++-- scripts/sac_mujoco/test.py | 60 ++- 11 files changed, 695 insertions(+), 141 deletions(-) create mode 100644 examples/config/sac.yaml create mode 100644 examples/sac.py diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml new file mode 100644 index 000000000..aaaf96bf9 --- /dev/null +++ b/examples/config/sac.yaml @@ -0,0 +1,42 @@ +default: + - override hydra/output: local + - override hydra/launcher: local + +# Logger +exp_name: ${task}_sac_${visual_transform} +visual_transform: r3m +record_interval: 1 +device: "cpu" + +# Environment +task: visual_franka_slide_random-v3 +frame_skip: 1 +reward_scaling: 5.0 +init_env_steps: 1000 +seed: 42 +eval_traj: 25 +num_envs: 8 + +# Collector +env_per_collector: 1 +max_frames_per_traj: -1 +total_frames: 1000000 +init_random_frames: 25000 +frames_per_batch: 10 + +# Replay Buffer +prb: 0 +buffer_size: 100000 +buffer_scratch_dir: /tmp/ + +# Optimization +gamma: 0.99 +batch_size: 256 +lr: 3.0e-4 +weight_decay: 0.0 +target_update_polyak: 0.995 +utd_ratio: 1 + +hydra: + job: + name: sac_${task}_${seed} diff --git a/examples/sac.py b/examples/sac.py new file mode 100644 index 000000000..8c287b981 --- /dev/null +++ b/examples/sac.py @@ -0,0 +1,471 @@ +# TODO +# Simplify +# logger +# check SAC loss vs torchrl's +# Make all the necessary imports for training +import os + +os.environ["sim_backend"] = "MUJOCO" + +import gc +import os +from copy import deepcopy + +import hydra + +import numpy as np +import torch +import torch.cuda +import tqdm +from omegaconf import DictConfig +from rlhive.rl_envs import RoboHiveEnv + +# from torchrl.objectives import SACLoss +from tensordict import TensorDict + +from torch import nn, optim +from torchrl.data import TensorDictReplayBuffer + +from torchrl.data.replay_buffers.storages import LazyMemmapStorage + +# from torchrl.envs import SerialEnv as ParallelEnv, R3MTransform, SelectTransform, TransformedEnv +from torchrl.envs import ( + CatTensors, + ParallelEnv, + R3MTransform, + SelectTransform, + TransformedEnv, +) +from torchrl.envs.transforms import Compose, FlattenObservation, RewardScaling +from torchrl.envs.utils import set_exploration_mode, step_mdp +from torchrl.modules import MLP, 
NormalParamWrapper, SafeModule +from torchrl.modules.distributions import TanhNormal + +from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + +from torchrl.objectives import SACLoss, SoftUpdate +from torchrl.record.loggers import WandbLogger +from torchrl.trainers import Recorder + + +# =========================================================================================== +# Env constructor +# --------------- +# - Use the RoboHiveEnv class to wrap robohive envs in torchrl's GymWrapper +# - Add transforms immediately after that: +# - SelectTransform: selects the relevant kesy from our output +# - R3MTransform +# - FlattenObservation: The images delivered by robohive have a singleton dim to start with, we need to flatten that +# - RewardScaling +# +# One can also possibly use ObservationNorm. +# +# TIPS: +# - For faster execution, you should follow this abstract scheme, where we reduce the data +# to be passed from worker to worker to a minimum, we apply R3M to a batch and append the +# rest of the transforms afterward: +# +# >>> env = TransformedEnv( +# ... ParallelEnv(N, lambda: TransformedEnv(RoboHiveEnv(...), SelectTransform(...))), +# ... Compose( +# ... R3MTransform(...), +# ... FlattenObservation(...), +# ... *other_transforms, +# ... )) +# + + +def make_env(num_envs, task, visual_transform, reward_scaling, device): + assert visual_transform in ("rrl", "r3m") + if num_envs > 1: + base_env = ParallelEnv(num_envs, lambda: RoboHiveEnv(task, device=device)) + else: + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env( + env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform + ) + + return env + + +def make_transformed_env( + env, + reward_scaling=5.0, + visual_transform="r3m", +): + """ + Apply transforms to the env (such as reward scaling and state normalization) + """ + env = TransformedEnv( + env, + SelectTransform("solved", "pixels", "observation", "rwd_dense", "rwd_sparse"), + ) + if visual_transform == "r3m": + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + else: + raise NotImplementedError + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + return env + + +# =========================================================================================== +# Making a recorder +# ----------------- +# +# A `Recorder` is a dedicated torchrl class that will run the policy in the test env +# once every X steps (eg X=1M). 
+# + + +def make_recorder( + task: str, + frame_skip: int, + record_interval: int, + actor_model_explore: object, + eval_traj: int, + env_configs: dict, +): + test_env = make_env(num_envs=1, task=task, **env_configs) + recorder_obj = Recorder( + record_frames=eval_traj * test_env.horizon, + frame_skip=frame_skip, + policy_exploration=actor_model_explore, + recorder=test_env, + exploration_mode="mean", + record_interval=record_interval, + log_keys=["reward", "solved"], + out_keys={"reward": "r_evaluation", "solved": "success"}, + ) + return recorder_obj + + +# =========================================================================================== +# Relplay buffers +# --------------- +# +# TorchRL also provides prioritized RBs if needed. +# + + +def make_replay_buffer( + buffer_size: int, + buffer_scratch_dir: str, + device: torch.device, + make_replay_buffer: int = 3, +): + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=make_replay_buffer, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + device=device, + ), + ) + return replay_buffer + + +# =========================================================================================== +# Dataloader +# ---------- +# +# This is a simplified version of the dataloder +# + + +@torch.no_grad() +@set_exploration_mode("random") +def dataloader( + total_frames, fpb, train_env, actor, actor_collection, device_collection +): + params = TensorDict( + {k: v for k, v in actor.named_parameters()}, batch_size=[] + ).unflatten_keys(".") + params_collection = TensorDict( + {k: v for k, v in actor_collection.named_parameters()}, batch_size=[] + ).unflatten_keys(".") + _prev = None + + params_collection.update_(params) + collected_frames = 0 + while collected_frames < total_frames: + batch = TensorDict( + {}, batch_size=[fpb, *train_env.batch_size], device=device_collection + ) + for t in range(fpb): + if _prev is None: + _prev = train_env.reset() + _reset = _prev["_reset"] = _prev["done"].clone().squeeze(-1) + if _reset.any(): + _prev = train_env.reset(_prev) + _new = train_env.step(actor_collection(_prev)) + batch[t] = _new + _prev = step_mdp(_new, exclude_done=False) + collected_frames += batch.numel() + yield batch + + +@hydra.main(config_name="sac.yaml", config_path="config") +def main(args: DictConfig): + # customize device at will + device = "cpu" + device_collection = "cpu" + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + # Create Environment + env_configs = { + "reward_scaling": args.reward_scaling, + "visual_transform": args.visual_transform, + "device": args.device, + } + train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs) + + # Create Agent + # Define Actor Network + in_keys = ["observation_vector"] + action_spec = train_env.action_spec + actor_net_kwargs = { + "num_cells": [256, 256], + "out_features": 2 * action_spec.shape[-1], + "activation_class": nn.ReLU, + } + + actor_net = MLP(**actor_net_kwargs) + + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": False, + } + actor_net = NormalParamWrapper( + actor_net, + scale_mapping=f"biased_softplus_{1.0}", + scale_lb=0.1, + ) + in_keys_actor = in_keys + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "loc", + "scale", + ], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + 
default_interaction_mode="random", + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": nn.ReLU, + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # add forward pass for initialization with proof env + proof_env = make_env(num_envs=1, task=args.task, **env_configs) + # init nets + with torch.no_grad(), set_exploration_mode("random"): + td = proof_env.reset() + td = td.to(device) + for net in model: + net(td) + del td + proof_env.close() + + actor_collection = deepcopy(actor).to(device_collection) + + actor_model_explore = model[0] + + # Create SAC loss + loss_module = SACLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + gamma=args.gamma, + loss_function="smooth_l1", + ) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, args.target_update_polyak) + + # Make Replay Buffer + replay_buffer = make_replay_buffer( + buffer_size=args.buffer_size, + buffer_scratch_dir=args.buffer_scratch_dir, + device=device, + ) + + # Trajectory recorder for evaluation + recorder = make_recorder( + task=args.task, + frame_skip=args.frame_skip, + record_interval=args.record_interval, + actor_model_explore=actor_model_explore, + eval_traj=args.eval_traj, + env_configs=env_configs, + ) + + # Optimizers + params = list(loss_module.parameters()) + list([loss_module.log_alpha]) + optimizer_actor = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) + + rewards = [] + rewards_eval = [] + + # Main loop + target_net_updater.init_() + + collected_frames = 0 + episodes = 0 + optim_steps = 0 + pbar = tqdm.tqdm(total=args.total_frames) + r0 = None + loss = None + + total_frames = args.total_frames + frames_per_batch = args.frames_per_batch + + logger = WandbLogger( + exp_name=args.task, + project="SAC_TorchRL", + name=args.exp_name, + config=args, + entity="RLHive", + mode="offline", + ) + + for i, batch in enumerate( + dataloader( + total_frames, + frames_per_batch, + train_env, + actor, + actor_collection, + device_collection, + ) + ): + if r0 is None: + r0 = batch["reward"].sum(-1).mean().item() + pbar.update(batch.numel()) + + # extend the replay buffer with the new data + batch = batch.view(-1) + current_frames = batch.numel() + collected_frames += current_frames + episodes += args.env_per_collector + replay_buffer.extend(batch.cpu()) + + # optimization steps + if collected_frames >= args.init_random_frames: + ( + total_losses, + actor_losses, + q_losses, + alpha_losses, + alphas, + entropies, + ) = ([], [], [], [], [], []) + for _ in range( + args.env_per_collector * args.frames_per_batch * args.utd_ratio + ): + optim_steps += 1 + # sample from replay buffer + sampled_tensordict = replay_buffer.sample(args.batch_size).clone() + + loss_td = loss_module(sampled_tensordict) + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_qvalue"] + alpha_loss = loss_td["loss_alpha"] + + loss = actor_loss + q_loss + alpha_loss + optimizer_actor.zero_grad() + loss.backward() + optimizer_actor.step() + + # update qnet_target params + target_net_updater.step() + + # update priority + if args.prb: + replay_buffer.update_priority(sampled_tensordict) + + total_losses.append(loss.item()) + actor_losses.append(actor_loss.item()) + q_losses.append(q_loss.item()) + alpha_losses.append(alpha_loss.item()) + 
alphas.append(loss_td["alpha"].item()) + entropies.append(loss_td["entropy"].item()) + + rewards.append((i, batch["reward"].sum().item() / args.env_per_collector)) + logger.log_scalar("train_reward", rewards[-1][1], step=collected_frames) + logger.log_scalar("optim_steps", optim_steps, step=collected_frames) + logger.log_scalar("episodes", episodes, step=collected_frames) + + if loss is not None: + logger.log_scalar( + "total_loss", np.mean(total_losses), step=collected_frames + ) + logger.log_scalar( + "actor_loss", np.mean(actor_losses), step=collected_frames + ) + logger.log_scalar("q_loss", np.mean(q_losses), step=collected_frames) + logger.log_scalar( + "alpha_loss", np.mean(alpha_losses), step=collected_frames + ) + logger.log_scalar("alpha", np.mean(alphas), step=collected_frames) + logger.log_scalar("entropy", np.mean(entropies), step=collected_frames) + td_record = recorder(None) + # success_percentage = evaluate_success( + # env_success_fn=train_env.evaluate_success, + # td_record=td_record, + # eval_traj=args.eval_traj, + # ) + if td_record is not None: + rewards_eval.append( + ( + i, + td_record["total_r_evaluation"] + / 1, # divide by number of eval worker + ) + ) + logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames) + if len(rewards_eval): + pbar.set_description( + f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" + ) + del batch + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/rlhive/envs.py b/rlhive/envs.py index 54d0a2822..a8c0b2059 100644 --- a/rlhive/envs.py +++ b/rlhive/envs.py @@ -6,11 +6,13 @@ # Custom env reg for RoboHive usage in TorchRL # Pixel rendering will be queried by torchrl, so we don't include those keys in visual_obs_keys_wt import os +import warnings from pathlib import Path -import mj_envs.envs.env_variants.register_env_variant import mj_envs.envs.multi_task.substeps1 +from mj_envs.envs.env_variants import register_env_variant + visual_obs_keys_wt = mj_envs.envs.multi_task.substeps1.visual_obs_keys_wt @@ -70,7 +72,7 @@ def register_kitchen_envs(): "kitchen_rdoor_open-v3", "kitchen_micro_close-v3", "kitchen_micro_open-v3", - "kitchen_close-v3", + # "kitchen_close-v3", ] visual_obs_keys_wt = { @@ -80,10 +82,17 @@ def register_kitchen_envs(): "rgb:left_cam:224x224:2d": 1.0, } for env in env_list: - new_env_name = "visual_" + env - mj_envs.envs.env_variants.register_env_variant( - env, variants={"obs_keys_wt": visual_obs_keys_wt}, variant_id=new_env_name - ) + try: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, + variants={"obs_keys_wt": visual_obs_keys_wt}, + variant_id=new_env_name, + ) + except AssertionError as err: + warnings.warn( + f"Could not register {new_env_name}, the following error was raised: {err}" + ) @set_directory(CURR_DIR) @@ -106,7 +115,14 @@ def register_franka_envs(): "rgb:left_cam:224x224:2d": 1.0, } for env in env_list: - new_env_name = "visual_" + env - mj_envs.envs.env_variants.register_env_variant( - env, variants={"obs_keys_wt": visual_obs_keys_wt}, variant_id=new_env_name - ) + try: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, + variants={"obs_keys_wt": visual_obs_keys_wt}, + variant_id=new_env_name, + ) + except AssertionError as err: + warnings.warn( + f"Could not register {new_env_name}, the following error was raised: {err}" + ) diff --git a/rlhive/rl_envs.py b/rlhive/rl_envs.py index ab070d7d0..0f741d023 100644 --- a/rlhive/rl_envs.py +++ 
b/rlhive/rl_envs.py @@ -175,7 +175,7 @@ def read_obs(self, observation): def read_info(self, info, tensordict_out): out = {} for key, value in info.items(): - if key in ("obs_dict",): + if key in ("obs_dict", "done", "reward"): continue if isinstance(value, dict): value = make_tensordict(value, batch_size=[]) @@ -183,14 +183,6 @@ def read_info(self, info, tensordict_out): tensordict_out.update(out) return tensordict_out - def _step(self, td): - td = super()._step(td) - return td - - def _reset(self, td=None, **kwargs): - td = super()._reset(td, **kwargs) - return td - def to(self, *args, **kwargs): out = super().to(*args, **kwargs) try: diff --git a/rlhive/sim_algos/helpers/rrl_transform.py b/rlhive/sim_algos/helpers/rrl_transform.py index b9b501485..b9b46f872 100644 --- a/rlhive/sim_algos/helpers/rrl_transform.py +++ b/rlhive/sim_algos/helpers/rrl_transform.py @@ -104,8 +104,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec return observation_spec - #@staticmethod - #def _load_weights(model_name, r3m_instance, dir_prefix): + # @staticmethod + # def _load_weights(model_name, r3m_instance, dir_prefix): # if model_name not in ("r3m_50", "r3m_34", "r3m_18"): # raise ValueError( # "model_name should be one of 'r3m_50', 'r3m_34' or 'r3m_18'" @@ -123,7 +123,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec # state_dict = td_flatten.to_dict() # r3m_instance.convnet.load_state_dict(state_dict) - #def load_weights(self, dir_prefix=None): + # def load_weights(self, dir_prefix=None): # self._load_weights(self.model_name, self, dir_prefix) @@ -300,7 +300,7 @@ def _init(self): for transform in transforms: self.append(transform) - #if self.download: + # if self.download: # self[-1].load_weights(dir_prefix=self.download_path) if self._device is not None: diff --git a/scripts/redq/redq.py b/scripts/redq/redq.py index 9800bd3cb..df9fc8783 100644 --- a/scripts/redq/redq.py +++ b/scripts/redq/redq.py @@ -10,8 +10,25 @@ import hydra import torch.cuda from hydra.core.config_store import ConfigStore +from rlhive.rl_envs import RoboHiveEnv from torchrl.envs import EnvCreator, ParallelEnv -from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.envs.transforms import ( + Compose, + FlattenObservation, + RewardScaling, + TransformedEnv, +) + +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + ObservationNorm, + ParallelEnv, + R3MTransform, + SelectTransform, + TransformedEnv, +) from torchrl.envs.utils import set_exploration_mode from torchrl.modules import OrnsteinUhlenbeckProcessWrapper from torchrl.record import VideoRecorder @@ -25,7 +42,7 @@ initialize_observation_norm_transforms, parallel_env_constructor, retrieve_observation_norms_state_dict, - #transformed_env_constructor, + # transformed_env_constructor, ) from torchrl.trainers.helpers.logger import LoggerConfig from torchrl.trainers.helpers.losses import LossConfig, make_redq_loss @@ -33,30 +50,23 @@ from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig from torchrl.trainers.loggers.utils import generate_exp_name, get_logger -from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose -from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform, SelectTransform -from torchrl.envs import ( - CatTensors, - DoubleToFloat, - EnvCreator, - ObservationNorm, - ParallelEnv, -) -from rlhive.rl_envs 
import RoboHiveEnv + def make_env( - task, - reward_scaling, - device, - obs_norm_state_dict=None, - action_dim_gsde=None, - state_dim_gsde=None, - ): + task, + reward_scaling, + device, + obs_norm_state_dict=None, + action_dim_gsde=None, + state_dim_gsde=None, +): base_env = RoboHiveEnv(task, device=device) env = make_transformed_env(env=base_env, reward_scaling=reward_scaling) if not obs_norm_state_dict is None: - obs_norm = ObservationNorm(**obs_norm_state_dict, in_keys=["observation_vector"]) + obs_norm = ObservationNorm( + **obs_norm_state_dict, in_keys=["observation_vector"] + ) env.append_transform(obs_norm) if not action_dim_gsde is None: @@ -75,13 +85,17 @@ def make_transformed_env( Apply transforms to the env (such as reward scaling and state normalization) """ env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) - env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=["r3m_vec"]))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=["r3m_vec"]), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) selected_keys = ["r3m_vec", "observation"] out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - # we normalize the states if stats is None: _stats = {"loc": 0.0, "scale": 1.0} @@ -93,6 +107,7 @@ def make_transformed_env( env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) return env + config_fields = [ (config_field.name, config_field.type, config_field) for config_cls in ( @@ -190,13 +205,13 @@ def main(cfg: "DictConfig"): # noqa: F821 proof_env.close() create_env_fn = make_env( ## Pass EnvBase instead of the create_env_fn - task=cfg.env_name, - reward_scaling=cfg.reward_scaling, - device=device, - obs_norm_state_dict=obs_norm_state_dict, - action_dim_gsde=action_dim_gsde, - state_dim_gsde=state_dim_gsde - ) + task=cfg.env_name, + reward_scaling=cfg.reward_scaling, + device=device, + obs_norm_state_dict=obs_norm_state_dict, + action_dim_gsde=action_dim_gsde, + state_dim_gsde=state_dim_gsde, + ) collector = make_collector_offpolicy( make_env=create_env_fn, @@ -210,21 +225,21 @@ def main(cfg: "DictConfig"): # noqa: F821 replay_buffer = make_replay_buffer(device, cfg) - #recorder = transformed_env_constructor( + # recorder = transformed_env_constructor( # cfg, # video_tag=video_tag, # norm_obs_only=True, # obs_norm_state_dict=obs_norm_state_dict, # logger=logger, # use_env_creator=False, - #)() + # )() recorder = make_env( task=cfg.env_name, reward_scaling=cfg.reward_scaling, device=device, obs_norm_state_dict=obs_norm_state_dict, action_dim_gsde=action_dim_gsde, - state_dim_gsde=state_dim_gsde + state_dim_gsde=state_dim_gsde, ) # remove video recorder from recorder to have matching state_dict keys diff --git a/scripts/sac_mujoco/config/group/group1.yaml b/scripts/sac_mujoco/config/group/group1.yaml index 6730093ca..886d4fc95 100644 --- a/scripts/sac_mujoco/config/group/group1.yaml +++ b/scripts/sac_mujoco/config/group/group1.yaml @@ -1,4 +1,4 @@ # @package _group_ grp1a: 11 grp1b: aaa - gra1c: $group_seed{group.seed}_exp_seed{exp.seed} \ No newline at end of file + gra1c: 
$group_seed{group.seed}_exp_seed{exp.seed} diff --git a/scripts/sac_mujoco/config/group/group2.yaml b/scripts/sac_mujoco/config/group/group2.yaml index b2f47d6dd..8b95ac771 100644 --- a/scripts/sac_mujoco/config/group/group2.yaml +++ b/scripts/sac_mujoco/config/group/group2.yaml @@ -1,4 +1,4 @@ # @package _group_ grp2a: 22 grp2b: bbb - gra2c: $group_seed{group.seed}_exp_seed{exp.seed} \ No newline at end of file + gra2c: $group_seed{group.seed}_exp_seed{exp.seed} diff --git a/scripts/sac_mujoco/config/hydra/output/local.yaml b/scripts/sac_mujoco/config/hydra/output/local.yaml index aee5a513f..d3c95076e 100644 --- a/scripts/sac_mujoco/config/hydra/output/local.yaml +++ b/scripts/sac_mujoco/config/hydra/output/local.yaml @@ -5,4 +5,4 @@ hydra: subdir: ${hydra.job.num}_${hydra.job.override_dirname} sweep: dir: ./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} - subdir: ${hydra.job.num}_${hydra.job.override_dirname} \ No newline at end of file + subdir: ${hydra.job.num}_${hydra.job.override_dirname} diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py index a9daa23c8..d7d2be539 100644 --- a/scripts/sac_mujoco/sac.py +++ b/scripts/sac_mujoco/sac.py @@ -4,21 +4,25 @@ # LICENSE file in the root directory of this source tree. -import os -import gc import argparse -import yaml +import gc +import os from typing import Optional +import hydra + import numpy as np import torch import torch.cuda import tqdm - -import hydra -from omegaconf import DictConfig, OmegaConf, open_dict import wandb -from torchrl.objectives import SACLoss +import yaml +from omegaconf import DictConfig, OmegaConf, open_dict +from rlhive.rl_envs import RoboHiveEnv +from rlhive.sim_algos.helpers.rrl_transform import RRLTransform + +# from torchrl.objectives import SACLoss +from sac_loss import SACLoss from torch import nn, optim from torchrl.collectors import MultiaSyncDataCollector @@ -33,10 +37,15 @@ ObservationNorm, ParallelEnv, ) -from torchrl.envs import EnvCreator from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose +from torchrl.envs.transforms import ( + Compose, + FlattenObservation, + RewardScaling, + TransformedEnv, +) +from torchrl.envs import ParallelEnv, R3MTransform, SelectTransform, TransformedEnv from torchrl.envs.utils import set_exploration_mode from torchrl.modules import MLP, NormalParamWrapper, ProbabilisticActor, SafeModule from torchrl.modules.distributions import TanhNormal @@ -46,21 +55,15 @@ from torchrl.objectives import SoftUpdate from torchrl.trainers import Recorder -from rlhive.rl_envs import RoboHiveEnv -from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform, SelectTransform - -os.environ['WANDB_MODE'] = 'offline' ## offline sync. TODO: Remove this behavior +os.environ["WANDB_MODE"] = "offline" ## offline sync. 
TODO: Remove this behavior -def make_env( - task, - visual_transform, - reward_scaling, - from_pixels, - device - ): - assert visual_transform in ('rrl', 'r3m') - base_env = RoboHiveEnv(task, from_pixels=from_pixels, device=device) - env = make_transformed_env(env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform, from_pixels=from_pixels) +def make_env(task, visual_transform, reward_scaling, device, from_pixels): + assert visual_transform in ("rrl", "r3m") + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env( + env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform, from_pixels=from_pixels + ) + print(env) return env @@ -69,7 +72,7 @@ def make_transformed_env( env, from_pixels, reward_scaling=5.0, - visual_transform='r3m', + visual_transform="r3m", stats=None, ): """ @@ -94,7 +97,6 @@ def make_transformed_env( out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - # we normalize the states if stats is None: _stats = {"loc": 0.0, "scale": 1.0} @@ -106,35 +108,36 @@ def make_transformed_env( env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) return env + def make_recorder( - task: str, - frame_skip: int, - record_interval: int, - actor_model_explore: object, - eval_traj: int, - env_configs: dict, - ): + task: str, + frame_skip: int, + record_interval: int, + actor_model_explore: object, + eval_traj: int, + env_configs: dict, +): test_env = make_env(task=task, **env_configs) recorder_obj = Recorder( - record_frames=eval_traj*test_env.horizon, + record_frames=eval_traj * test_env.horizon, frame_skip=frame_skip, policy_exploration=actor_model_explore, recorder=test_env, exploration_mode="mean", record_interval=record_interval, log_keys=["reward", "solved"], - out_keys={"reward": "r_evaluation", "solved" : "success"} + out_keys={"reward": "r_evaluation", "solved": "success"}, ) return recorder_obj def make_replay_buffer( - prb: bool, - buffer_size: int, - buffer_scratch_dir: str, - device: torch.device, - make_replay_buffer: int = 3 - ): + prb: bool, + buffer_size: int, + buffer_scratch_dir: str, + device: torch.device, + make_replay_buffer: int = 3, +): if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( alpha=0.7, @@ -160,11 +163,7 @@ def make_replay_buffer( return replay_buffer -def evaluate_success( - env_success_fn, - td_record: dict, - eval_traj: int - ): +def evaluate_success(env_success_fn, td_record: dict, eval_traj: int): td_record["success"] = td_record["success"].reshape((eval_traj, -1)) paths = [] for traj, solved_traj in zip(range(eval_traj), td_record["success"]): @@ -174,7 +173,6 @@ def evaluate_success( return success_percentage - @hydra.main(config_name="sac.yaml", config_path="config") def main(args: DictConfig): device = ( @@ -305,22 +303,21 @@ def main(args: DictConfig): # Make Replay Buffer replay_buffer = make_replay_buffer( - prb=args.prb, - buffer_size=args.buffer_size, - buffer_scratch_dir=args.buffer_scratch_dir, - device=device, - ) - + prb=args.prb, + buffer_size=args.buffer_size, + buffer_scratch_dir=args.buffer_scratch_dir, + device=device, + ) # Trajectory recorder for evaluation recorder = make_recorder( - task=args.task, - frame_skip=args.frame_skip, - record_interval=args.record_interval, - actor_model_explore=actor_model_explore, - eval_traj=args.eval_traj, - env_configs=env_configs, - ) + task=args.task, + frame_skip=args.frame_skip, + record_interval=args.record_interval, + 
actor_model_explore=actor_model_explore, + eval_traj=args.eval_traj, + env_configs=env_configs, + ) # Optimizers params = list(loss_module.parameters()) + list([loss_module.log_alpha]) @@ -424,10 +421,10 @@ def main(args: DictConfig): ) td_record = recorder(None) success_percentage = evaluate_success( - env_success_fn=train_env.evaluate_success, - td_record=td_record, - eval_traj=args.eval_traj - ) + env_success_fn=train_env.evaluate_success, + td_record=td_record, + eval_traj=args.eval_traj, + ) if td_record is not None: rewards_eval.append( ( @@ -447,5 +444,6 @@ def main(args: DictConfig): collector.shutdown() + if __name__ == "__main__": main() diff --git a/scripts/sac_mujoco/test.py b/scripts/sac_mujoco/test.py index fec7a06d1..61e120811 100644 --- a/scripts/sac_mujoco/test.py +++ b/scripts/sac_mujoco/test.py @@ -1,27 +1,32 @@ import torch from rlhive.rl_envs import RoboHiveEnv -from torchrl.envs.utils import set_exploration_mode -from torchrl.envs.transforms import RewardScaling, TransformedEnv, FlattenObservation, Compose -from torchrl.envs import TransformedEnv, R3MTransform, SelectTransform +from rlhive.sim_algos.helpers.rrl_transform import RRLTransform +from torchrl.envs.transforms import ( + Compose, + FlattenObservation, + RewardScaling, + TransformedEnv, +) from torchrl.envs import ( CatTensors, DoubleToFloat, EnvCreator, ObservationNorm, + R3MTransform, + SelectTransform, + TransformedEnv, ) -from rlhive.sim_algos.helpers.rrl_transform import RRLTransform +from torchrl.envs.utils import set_exploration_mode + -def make_env( - task, - visual_transform, - reward_scaling, - device - ): - assert visual_transform in ('rrl', 'r3m') +def make_env(task, visual_transform, reward_scaling, device): + assert visual_transform in ("rrl", "r3m") base_env = RoboHiveEnv(task, device=device) - env = make_transformed_env(env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform) + env = make_transformed_env( + env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform + ) print(env) - #exit() + # exit() return env @@ -29,28 +34,37 @@ def make_env( def make_transformed_env( env, reward_scaling=5.0, - visual_transform='r3m', + visual_transform="r3m", stats=None, ): """ Apply transforms to the env (such as reward scaling and state normalization) """ env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) - if visual_transform == 'rrl': + if visual_transform == "rrl": vec_keys = ["rrl_vec"] selected_keys = ["observation", "rrl_vec"] - env.append_transform(Compose(RRLTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 - elif visual_transform == 'r3m': + env.append_transform( + Compose( + RRLTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == "r3m": vec_keys = ["r3m_vec"] selected_keys = ["observation", "r3m_vec"] - env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + 
FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 else: raise NotImplementedError env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) - # we normalize the states if stats is None: _stats = {"loc": 0.0, "scale": 1.0} @@ -62,7 +76,13 @@ def make_transformed_env( env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) return env -env = make_env(task="visual_franka_slide_random-v3", reward_scaling=5.0, device=torch.device('cuda:0'), visual_transform='rrl') + +env = make_env( + task="visual_franka_slide_random-v3", + reward_scaling=5.0, + device=torch.device("cuda:0"), + visual_transform="rrl", +) with torch.no_grad(), set_exploration_mode("random"): td = env.reset() td = env.rand_step() From 1bbddd4ad8abcd017187a8017941bc6e8d96afa9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 15:32:21 +0000 Subject: [PATCH 42/58] amend --- examples/sac.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 8c287b981..14c733712 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -1,8 +1,3 @@ -# TODO -# Simplify -# logger -# check SAC loss vs torchrl's -# Make all the necessary imports for training import os os.environ["sim_backend"] = "MUJOCO" From fea42b2236fcb18df76f5e08f8f806112ec771ae Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 15:41:24 +0000 Subject: [PATCH 43/58] amend --- examples/install/install_rlhive.sh | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100755 examples/install/install_rlhive.sh diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh new file mode 100755 index 000000000..9615903d9 --- /dev/null +++ b/examples/install/install_rlhive.sh @@ -0,0 +1,20 @@ +#!/bin/zsh + +here=$(pwd) +module_path=$HOME/modules/ + +module load cuda/11.6 cudnn/v8.4.1.50-cuda.11.6 + +conda create -n rlhive -y python=3.8 +conda activate rlhive + +python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 + +cd $module_path +git clone -c submodule.mj_envs/sims/neuromuscular_sim.update=none --branch v0.4dev --recursive https://github.com/vikashplus/mj_envs.git mj_envs +python3 -mpip install . 
# one can also install it locally with the -e flag +cd $here + +python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) +python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) +python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or stable or nightly with pip install torchrl(-nightly) From a43e2a488c174e0dc677cd433c9b0179db3cfb70 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 16:01:26 +0000 Subject: [PATCH 44/58] amend --- examples/install/install_rlhive.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index 9615903d9..174852af8 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -1,17 +1,21 @@ #!/bin/zsh +set -e + here=$(pwd) module_path=$HOME/modules/ -module load cuda/11.6 cudnn/v8.4.1.50-cuda.11.6 +conda env remove -n rlhive -y conda create -n rlhive -y python=3.8 conda activate rlhive python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 +mkdir $module_path cd $module_path git clone -c submodule.mj_envs/sims/neuromuscular_sim.update=none --branch v0.4dev --recursive https://github.com/vikashplus/mj_envs.git mj_envs +cd mj_envs python3 -mpip install . # one can also install it locally with the -e flag cd $here From 8cb852d3e527e03929057da6914149104c0e4e5c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 16:12:53 +0000 Subject: [PATCH 45/58] amend --- examples/install/install_rlhive.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index 174852af8..feec54cf7 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -2,12 +2,17 @@ set -e +conda_path=$(conda info | grep -i 'base environment' | awk '{ print $4 }') +source $conda_path/etc/profile.d/conda.sh + here=$(pwd) module_path=$HOME/modules/ conda env remove -n rlhive -y conda create -n rlhive -y python=3.8 + + conda activate rlhive python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 From deeb272bef3926a90c5cdb5f1b5665ef90b2b1f1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 16:37:11 +0000 Subject: [PATCH 46/58] amend --- examples/install/install_rlhive.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index feec54cf7..e0da7d2e7 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -8,6 +8,9 @@ source $conda_path/etc/profile.d/conda.sh here=$(pwd) module_path=$HOME/modules/ +module purge +module load cuda/11.6 + conda env remove -n rlhive -y conda create -n rlhive -y python=3.8 From ff4895a3313f2964ea6009f42868cce858d422a6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 16:49:16 +0000 Subject: [PATCH 47/58] amend --- examples/config/sac.yaml | 2 +- examples/install/install_rlhive.sh | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml index aaaf96bf9..c90aecda2 100644 --- a/examples/config/sac.yaml +++ b/examples/config/sac.yaml @@ -6,7 +6,7 @@ default: exp_name: ${task}_sac_${visual_transform} visual_transform: r3m record_interval: 1 -device: "cpu" +device: 
"cuda:0" # Environment task: visual_franka_slide_random-v3 diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index e0da7d2e7..f602ad73c 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -1,5 +1,7 @@ #!/bin/zsh +# Instructions to install a fresh anaconda environment with RLHive + set -e conda_path=$(conda info | grep -i 'base environment' | awk '{ print $4 }') @@ -30,3 +32,5 @@ cd $here python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or stable or nightly with pip install torchrl(-nightly) + +pip install wandb tqdm hydra-core \ No newline at end of file From 3224ec282db6bf5242c9295d05c71618dae849ed Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 17:02:58 +0000 Subject: [PATCH 48/58] amend --- examples/config/sac.yaml | 1 + examples/install/install_rlhive.sh | 11 ++++++++--- examples/sac.py | 6 +++--- setup.py | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml index c90aecda2..dacd35ebe 100644 --- a/examples/config/sac.yaml +++ b/examples/config/sac.yaml @@ -7,6 +7,7 @@ exp_name: ${task}_sac_${visual_transform} visual_transform: r3m record_interval: 1 device: "cuda:0" +device_collection: "cuda:1" # Environment task: visual_franka_slide_random-v3 diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index f602ad73c..6b2b0e70a 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -17,7 +17,6 @@ conda env remove -n rlhive -y conda create -n rlhive -y python=3.8 - conda activate rlhive python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 @@ -31,6 +30,12 @@ cd $here python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) -python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or stable or nightly with pip install torchrl(-nightly) -pip install wandb tqdm hydra-core \ No newline at end of file +# this +# python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or stable or nightly with pip install torchrl(-nightly) +# or this +cd ../.. +pip install -e . 
+cd $here + +pip install wandb tqdm hydra-core diff --git a/examples/sac.py b/examples/sac.py index 14c733712..8cdd20904 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -212,8 +212,8 @@ def dataloader( @hydra.main(config_name="sac.yaml", config_path="config") def main(args: DictConfig): # customize device at will - device = "cpu" - device_collection = "cpu" + device = args.device + device_collection = args.device_collection torch.manual_seed(args.seed) np.random.seed(args.seed) @@ -223,7 +223,7 @@ def main(args: DictConfig): "visual_transform": args.visual_transform, "device": args.device, } - train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs) + train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs).to(device_collection) # Create Agent # Define Actor Network diff --git a/setup.py b/setup.py index 8186a081c..86b104881 100644 --- a/setup.py +++ b/setup.py @@ -160,7 +160,7 @@ def _main(): # f"torchrl @ file://{rl_path}", "torchrl", "gym==0.13", - "mj_envs", + # "mj_envs", # f"mj_envs @ file://{mj_env_path}", "numpy", "packaging", From 4573419ebaff423be2f020b69b8ad719615826c1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:23:07 +0000 Subject: [PATCH 49/58] amend --- examples/config/sac.yaml | 3 ++- examples/sac.py | 29 +++++++++++++++-------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml index dacd35ebe..9daa16eb3 100644 --- a/examples/config/sac.yaml +++ b/examples/config/sac.yaml @@ -16,6 +16,7 @@ reward_scaling: 5.0 init_env_steps: 1000 seed: 42 eval_traj: 25 +eval_interval: 1000 num_envs: 8 # Collector @@ -23,7 +24,7 @@ env_per_collector: 1 max_frames_per_traj: -1 total_frames: 1000000 init_random_frames: 25000 -frames_per_batch: 10 +frames_per_batch: 1000 # Replay Buffer prb: 0 diff --git a/examples/sac.py b/examples/sac.py index 8cdd20904..d7b0966e3 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -439,21 +439,22 @@ def main(args: DictConfig): ) logger.log_scalar("alpha", np.mean(alphas), step=collected_frames) logger.log_scalar("entropy", np.mean(entropies), step=collected_frames) - td_record = recorder(None) - # success_percentage = evaluate_success( - # env_success_fn=train_env.evaluate_success, - # td_record=td_record, - # eval_traj=args.eval_traj, - # ) - if td_record is not None: - rewards_eval.append( - ( - i, - td_record["total_r_evaluation"] - / 1, # divide by number of eval worker + if i % args.eval_interval == 0: + td_record = recorder(None) + # success_percentage = evaluate_success( + # env_success_fn=train_env.evaluate_success, + # td_record=td_record, + # eval_traj=args.eval_traj, + # ) + if td_record is not None: + rewards_eval.append( + ( + i, + td_record["total_r_evaluation"] + / 1, # divide by number of eval worker + ) ) - ) - logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames) + logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames) if len(rewards_eval): pbar.set_description( f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" From 97180aeab88492c8d9be26db41d8d9cd5e48e2fe Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:25:58 +0000 Subject: [PATCH 50/58] amend --- examples/sac.py | 3 +- examples/sac_loss.py | 311 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 examples/sac_loss.py diff --git a/examples/sac.py b/examples/sac.py index d7b0966e3..1a8fe2099 
100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -38,7 +38,8 @@ from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator -from torchrl.objectives import SACLoss, SoftUpdate +from sac_loss import SACLoss +from torchrl.objectives import SoftUpdate from torchrl.record.loggers import WandbLogger from torchrl.trainers import Recorder diff --git a/examples/sac_loss.py b/examples/sac_loss.py new file mode 100644 index 000000000..cebe7f2e9 --- /dev/null +++ b/examples/sac_loss.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from numbers import Number +from typing import Union + +import numpy as np +import torch + +from tensordict.nn import TensorDictSequential +from tensordict.tensordict import TensorDict, TensorDictBase +from torch import Tensor + +from torchrl.envs.utils import set_exploration_mode, step_mdp +from torchrl.modules import SafeModule +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import ( + distance_loss, + next_state_value as get_next_state_value, +) + +try: + from functorch import vmap + + FUNCTORCH_ERR = "" + _has_functorch = True +except ImportError as err: + FUNCTORCH_ERR = str(err) + _has_functorch = False + + +class SACLoss(LossModule): + """SAC Loss module. + Args: + actor_network (SafeModule): the actor to be trained + qvalue_network (SafeModule): a single Q-value network that will be multiplicated as many times as needed. + num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. + gamma (Number, optional): gamma decay factor. Default is 0.99. + priotity_key (str, optional): Key where to write the priority value for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", + "l1", Default is "smooth_l1". + alpha_init (float, optional): initial entropy multiplier. + Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is 0.1. + max_alpha (float, optional): max value of alpha. + Default is 10.0. + fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is :obj:`False`. + target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". + delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used + for data collection. Default is :obj:`False`. + gSDE (bool, optional): Knowing if gSDE is used is necessary to create random noise variables. 
+ Default is False + """ + + delay_actor: bool = False + + def __init__( + self, + actor_network: SafeModule, + qvalue_network: SafeModule, + num_qvalue_nets: int = 2, + gamma: Number = 0.99, + priotity_key: str = "td_error", + loss_function: str = "smooth_l1", + alpha_init: float = 1.0, + min_alpha: float = 0.1, + max_alpha: float = 10.0, + fixed_alpha: bool = False, + target_entropy: Union[str, Number] = "auto", + delay_qvalue: bool = True, + gSDE: bool = False, + ): + if not _has_functorch: + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" + ) + + super().__init__() + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + funs_to_decorate=["forward", "get_dist_params"], + ) + + # let's make sure that actor_network has `return_log_prob` set to True + self.actor_network.return_log_prob = True + + self.delay_qvalue = delay_qvalue + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + compare_against=list(actor_network.parameters()), + ) + self.num_qvalue_nets = num_qvalue_nets + self.register_buffer("gamma", torch.tensor(gamma)) + self.priority_key = priotity_key + self.loss_function = loss_function + + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + if target_entropy == "auto": + if actor_network.spec["action"] is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." 
+ ) + target_entropy = -float(np.prod(actor_network.spec["action"].shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + self.gSDE = gSDE + + @property + def alpha(self): + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + obs_keys = self.actor_network.in_keys + tensordict_select = tensordict.select( + "reward", "done", "next", *obs_keys, "action" + ) + + actor_params = torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) + + tensordict_actor_grad = tensordict_select.select( + *obs_keys + ) # to avoid overwriting keys + next_td_actor = step_mdp(tensordict_select).select( + *self.actor_network.in_keys + ) # next_observation -> + tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) + tensordict_actor = tensordict_actor.contiguous() + + with set_exploration_mode("random"): + if self.gSDE: + tensordict_actor.set( + "_eps_gSDE", + torch.zeros(tensordict_actor.shape, device=tensordict_actor.device), + ) + # vmap doesn't support sampling, so we take it out from the vmap + td_params = vmap(self.actor_network.get_dist_params)( + tensordict_actor, + actor_params, + ) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict_actor[sample_key] = tensordict_actor_dist.rsample() + tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict_actor[sample_key] + ) + + # repeat tensordict_actor to match the qvalue size + _actor_loss_td = ( + tensordict_actor[0] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) + ) # for actor loss + _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, + *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, + ) # for qvalue loss + _next_val_td = ( + tensordict_actor[1] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) + ) # for next value estimation + tensordict_qval = torch.cat( + [ + _actor_loss_td, + _next_val_td, + _qval_td, + ], + 0, + ) + + # cat params + q_params_detach = self.qvalue_network_params.detach() + qvalue_params = torch.cat( + [ + q_params_detach, + self.target_qvalue_network_params, + self.qvalue_network_params, + ], + 0, + ) + tensordict_qval = vmap(self.qvalue_network)( + tensordict_qval, + qvalue_params, + ) + + state_action_value = tensordict_qval.get("state_action_value").squeeze(-1) + ( + state_action_value_actor, + next_state_action_value_qvalue, + state_action_value_qvalue, + ) = state_action_value.split( + [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], + dim=0, + ) + sample_log_prob = tensordict_actor.get("sample_log_prob").squeeze(-1) + ( + action_log_prob_actor, + next_action_log_prob_qvalue, + ) = sample_log_prob.unbind(0) + + loss_actor = -( + state_action_value_actor.min(0)[0] - self.alpha * action_log_prob_actor + ).mean() + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + target_value = get_next_state_value( + tensordict, + 
gamma=self.gamma, + pred_next_val=next_state_value, + ) + pred_val = state_action_value_qvalue + td_error = (pred_val - target_value).pow(2) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + + tensordict.set("td_error", td_error.detach().max(0)[0]) + + loss_alpha = self._loss_alpha(sample_log_prob) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + "state_action_value_actor": state_action_value_actor.mean().detach(), + "action_log_prob_actor": action_log_prob_actor.mean().detach(), + "next.state_value": next_state_value.mean().detach(), + "target_value": target_value.mean().detach(), + }, + [], + ) + + return td_out + + def _loss_alpha(self, log_pi: Tensor) -> Tensor: + if torch.is_grad_enabled() and not log_pi.requires_grad: + raise RuntimeError( + "expected log_pi to require gradient for the alpha loss)" + ) + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss From 7106f01cbbcdc44fd3a9a29b69000737edfcde3b Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:52:38 +0000 Subject: [PATCH 51/58] amend --- examples/sac.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 1a8fe2099..06d822d5b 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -1,5 +1,7 @@ import os +from torchrl.record import VideoRecorder + os.environ["sim_backend"] = "MUJOCO" import gc @@ -15,6 +17,8 @@ from omegaconf import DictConfig from rlhive.rl_envs import RoboHiveEnv +from sac_loss import SACLoss + # from torchrl.objectives import SACLoss from tensordict import TensorDict @@ -37,8 +41,6 @@ from torchrl.modules.distributions import TanhNormal from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator - -from sac_loss import SACLoss from torchrl.objectives import SoftUpdate from torchrl.record.loggers import WandbLogger from torchrl.trainers import Recorder @@ -129,8 +131,12 @@ def make_recorder( actor_model_explore: object, eval_traj: int, env_configs: dict, + wandb_logger: WandbLogger, ): test_env = make_env(num_envs=1, task=task, **env_configs) + test_env.insert_transform( + 0, VideoRecorder(wandb_logger, "test", in_keys=["pixels"]) + ) recorder_obj = Recorder( record_frames=eval_traj * test_env.horizon, frame_skip=frame_skip, @@ -224,7 +230,9 @@ def main(args: DictConfig): "visual_transform": args.visual_transform, "device": args.device, } - train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs).to(device_collection) + train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs).to( + device_collection + ) # Create Agent # Define Actor Network @@ -448,6 +456,7 @@ def main(args: DictConfig): # eval_traj=args.eval_traj, # ) if td_record is not None: + print("recorded", td_record) rewards_eval.append( ( i, @@ -455,7 +464,13 @@ def main(args: DictConfig): / 1, # divide by number of eval worker ) ) - logger.log_scalar("test_reward", 
rewards_eval[-1][1], step=collected_frames) + logger.log_scalar( + "test_reward", rewards_eval[-1][1], step=collected_frames + ) + logger.log_scalar( + "success", td_record["success"].any(), step=collected_frames + ) + if len(rewards_eval): pbar.set_description( f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" From f71a155086bcd7d8c9e032f05dad39b8bc905372 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:53:23 +0000 Subject: [PATCH 52/58] amend --- examples/sac.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 06d822d5b..1647956f4 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -467,13 +467,14 @@ def main(args: DictConfig): logger.log_scalar( "test_reward", rewards_eval[-1][1], step=collected_frames ) + solved = td_record["success"].any() logger.log_scalar( - "success", td_record["success"].any(), step=collected_frames + "success", solved, step=collected_frames ) if len(rewards_eval): pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" + f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}, solved: {solved}" ) del batch gc.collect() From a28404b55df4ddd3d8f8c4ce7c39b89f41dada17 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:54:32 +0000 Subject: [PATCH 53/58] amend --- examples/sac.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 1647956f4..3776d813e 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -328,16 +328,6 @@ def main(args: DictConfig): device=device, ) - # Trajectory recorder for evaluation - recorder = make_recorder( - task=args.task, - frame_skip=args.frame_skip, - record_interval=args.record_interval, - actor_model_explore=actor_model_explore, - eval_traj=args.eval_traj, - env_configs=env_configs, - ) - # Optimizers params = list(loss_module.parameters()) + list([loss_module.log_alpha]) optimizer_actor = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) @@ -367,6 +357,17 @@ def main(args: DictConfig): mode="offline", ) + # Trajectory recorder for evaluation + recorder = make_recorder( + task=args.task, + frame_skip=args.frame_skip, + record_interval=args.record_interval, + actor_model_explore=actor_model_explore, + eval_traj=args.eval_traj, + env_configs=env_configs, + wandb_logger=logger, + ) + for i, batch in enumerate( dataloader( total_frames, From 1ac5466e57746ec9accd7cc3af8fb5f7ac33aef8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 18:55:55 +0000 Subject: [PATCH 54/58] amend --- examples/install/install_rlhive.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh index 6b2b0e70a..ba5f78931 100755 --- a/examples/install/install_rlhive.sh +++ b/examples/install/install_rlhive.sh @@ -38,4 +38,4 @@ cd ../.. pip install -e . 
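
Side note — a minimal sketch, not part of the patches themselves: PATCHES 51-53 insert a torchrl VideoRecorder into the evaluation environment so the logged "pixels" stream is rendered through the WandbLogger, which appears to be why PATCH 54 here adds moviepy to the install script. The helper name attach_video_recorder is illustrative only; the calls mirror make_recorder in examples/sac.py.

    from torchrl.record import VideoRecorder
    from torchrl.record.loggers import WandbLogger

    def attach_video_recorder(test_env, wandb_logger: WandbLogger):
        # prepend the recorder so it sees "pixels" before later transforms
        # (e.g. R3M) consume them, as done in make_recorder
        test_env.insert_transform(
            0, VideoRecorder(wandb_logger, "test", in_keys=["pixels"])
        )
        return test_env
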
cd $here -pip install wandb tqdm hydra-core +pip install wandb tqdm hydra-core moviepy From a7be171de48a42f010c87d2de2b965f9464fffdb Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 19:00:34 +0000 Subject: [PATCH 55/58] amend --- examples/sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sac.py b/examples/sac.py index 3776d813e..3447936eb 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -468,7 +468,7 @@ def main(args: DictConfig): logger.log_scalar( "test_reward", rewards_eval[-1][1], step=collected_frames ) - solved = td_record["success"].any() + solved = float(td_record["success"].any()) logger.log_scalar( "success", solved, step=collected_frames ) From 22d91cb919692149fefa0f06af85ca94d502f8c1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 27 Jan 2023 20:55:46 +0000 Subject: [PATCH 56/58] amend --- examples/sac.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/sac.py b/examples/sac.py index 3447936eb..b41ec1ae4 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -329,8 +329,8 @@ def main(args: DictConfig): ) # Optimizers - params = list(loss_module.parameters()) + list([loss_module.log_alpha]) - optimizer_actor = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) + params = list(loss_module.parameters()) + optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) rewards = [] rewards_eval = [] @@ -407,15 +407,18 @@ def main(args: DictConfig): sampled_tensordict = replay_buffer.sample(args.batch_size).clone() loss_td = loss_module(sampled_tensordict) + print(f'value: {loss_td["state_action_value_actor"].mean():4.4f}') + print(f'log_prob: {loss_td["action_log_prob_actor"].mean():4.4f}') + print(f'next.state_value: {loss_td["state_value"].mean():4.4f}') actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] alpha_loss = loss_td["loss_alpha"] loss = actor_loss + q_loss + alpha_loss - optimizer_actor.zero_grad() + optimizer.zero_grad() loss.backward() - optimizer_actor.step() + optimizer.step() # update qnet_target params target_net_updater.step() From 1a6e527df417486f45ad0ab7007a25756d63624f Mon Sep 17 00:00:00 2001 From: rutavms Date: Sat, 28 Jan 2023 17:27:33 -0600 Subject: [PATCH 57/58] moving the sac_loss to local file --- scripts/sac_mujoco/sac.py | 2 +- scripts/sac_mujoco/sac_loss.py | 474 +++++++++++++++++++++++++++++++++ 2 files changed, 475 insertions(+), 1 deletion(-) create mode 100644 scripts/sac_mujoco/sac_loss.py diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py index d7d2be539..aa14fb800 100644 --- a/scripts/sac_mujoco/sac.py +++ b/scripts/sac_mujoco/sac.py @@ -21,7 +21,7 @@ from rlhive.rl_envs import RoboHiveEnv from rlhive.sim_algos.helpers.rrl_transform import RRLTransform -# from torchrl.objectives import SACLoss +#from torchrl.objectives import SACLoss from sac_loss import SACLoss from torch import nn, optim diff --git a/scripts/sac_mujoco/sac_loss.py b/scripts/sac_mujoco/sac_loss.py new file mode 100644 index 000000000..07d5ab1c4 --- /dev/null +++ b/scripts/sac_mujoco/sac_loss.py @@ -0,0 +1,474 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
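
Side note on PATCH 56 above — it collapses the optimizer setup into a single Adam over loss_module.parameters(). This works because SACLoss registers log_alpha via register_parameter() whenever fixed_alpha=False, so it is already returned by .parameters(); appending it again, as the removed line did, handed the same tensor to the optimizer twice. A minimal self-contained sketch of that behavior (TinyLoss is a stand-in, not the real SACLoss):

    import torch
    from torch import nn

    class TinyLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.qvalue = nn.Linear(4, 1)
            # mirrors SACLoss: log_alpha is registered as a trainable parameter
            self.register_parameter("log_alpha", nn.Parameter(torch.zeros(())))

    loss_module = TinyLoss()
    params = list(loss_module.parameters())
    assert any(p is loss_module.log_alpha for p in params)  # already included
    optimizer = torch.optim.Adam(params, lr=3e-4)
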
+ +import math +from numbers import Number +from typing import Union + +import numpy as np +import torch + +from tensordict.nn import TensorDictSequential +from tensordict.tensordict import TensorDict, TensorDictBase +from torch import Tensor + +from torchrl.envs.utils import set_exploration_mode, step_mdp +from torchrl.modules import SafeModule +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import ( + distance_loss, + next_state_value as get_next_state_value, +) + +try: + from functorch import vmap + + FUNCTORCH_ERR = "" + _has_functorch = True +except ImportError as err: + FUNCTORCH_ERR = str(err) + _has_functorch = False + + +class SACLoss(LossModule): + """SAC Loss module. + Args: + actor_network (SafeModule): the actor to be trained + qvalue_network (SafeModule): a single Q-value network that will be multiplicated as many times as needed. + num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. + gamma (Number, optional): gamma decay factor. Default is 0.99. + priotity_key (str, optional): Key where to write the priority value for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", + "l1", Default is "smooth_l1". + alpha_init (float, optional): initial entropy multiplier. + Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is 0.1. + max_alpha (float, optional): max value of alpha. + Default is 10.0. + fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is :obj:`False`. + target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". + delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used + for data collection. Default is :obj:`False`. + gSDE (bool, optional): Knowing if gSDE is used is necessary to create random noise variables. 
+ Default is False + """ + + delay_actor: bool = False + _explicit: bool = True + + def __init__( + self, + actor_network: SafeModule, + qvalue_network: SafeModule, + num_qvalue_nets: int = 2, + gamma: Number = 0.99, + priotity_key: str = "td_error", + loss_function: str = "smooth_l1", + alpha_init: float = 1.0, + min_alpha: float = 0.1, + max_alpha: float = 10.0, + fixed_alpha: bool = False, + target_entropy: Union[str, Number] = "auto", + delay_qvalue: bool = True, + gSDE: bool = False, + ): + if not _has_functorch: + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" + ) + + super().__init__() + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + funs_to_decorate=["forward", "get_dist_params"], + ) + + # let's make sure that actor_network has `return_log_prob` set to True + self.actor_network.return_log_prob = True + + self.delay_qvalue = delay_qvalue + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + compare_against=list(actor_network.parameters()), + ) + self.num_qvalue_nets = num_qvalue_nets + self.register_buffer("gamma", torch.tensor(gamma)) + self.priority_key = priotity_key + self.loss_function = loss_function + + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + if target_entropy == "auto": + if actor_network.spec["action"] is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." 
+ ) + target_entropy = -float(np.prod(actor_network.spec["action"].shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + self.gSDE = gSDE + + @property + def alpha(self): + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if self._explicit: + # slow but explicit version + return self._forward_explicit(tensordict) + else: + return self._forward_vectorized(tensordict) + + def _loss_alpha(self, log_pi: Tensor) -> Tensor: + if torch.is_grad_enabled() and not log_pi.requires_grad: + raise RuntimeError( + "expected log_pi to require gradient for the alpha loss)" + ) + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss + + def _forward_vectorized(self, tensordict: TensorDictBase) -> TensorDictBase: + obs_keys = self.actor_network.in_keys + tensordict_select = tensordict.select( + "reward", "done", "next", *obs_keys, "action" + ) + + actor_params = torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) + + tensordict_actor_grad = tensordict_select.select( + *obs_keys + ) # to avoid overwriting keys + next_td_actor = step_mdp(tensordict_select).select( + *self.actor_network.in_keys + ) # next_observation -> + tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) + tensordict_actor = tensordict_actor.contiguous() + + with set_exploration_mode("random"): + if self.gSDE: + tensordict_actor.set( + "_eps_gSDE", + torch.zeros(tensordict_actor.shape, device=tensordict_actor.device), + ) + # vmap doesn't support sampling, so we take it out from the vmap + td_params = vmap(self.actor_network.get_dist_params)( + tensordict_actor, + actor_params, + ) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict_actor[sample_key] = self._rsample(tensordict_actor_dist) + tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict_actor[sample_key] + ) + + # repeat tensordict_actor to match the qvalue size + _actor_loss_td = ( + tensordict_actor[0] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) + ) # for actor loss + _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, + *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, + ) # for qvalue loss + _next_val_td = ( + tensordict_actor[1] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) + ) # for next value estimation + tensordict_qval = torch.cat( + [ + _actor_loss_td, + _next_val_td, + _qval_td, + ], + 0, + ) + + # cat params + q_params_detach = self.qvalue_network_params.detach() + qvalue_params = torch.cat( + [ + q_params_detach, + self.target_qvalue_network_params, + self.qvalue_network_params, + ], + 0, + ) + tensordict_qval = vmap(self.qvalue_network)( + tensordict_qval, + qvalue_params, + ) + + state_action_value = 
tensordict_qval.get("state_action_value").squeeze(-1) + ( + state_action_value_actor, + next_state_action_value_qvalue, + state_action_value_qvalue, + ) = state_action_value.split( + [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], + dim=0, + ) + sample_log_prob = tensordict_actor.get("sample_log_prob").squeeze(-1) + ( + action_log_prob_actor, + next_action_log_prob_qvalue, + ) = sample_log_prob.unbind(0) + + # E[alpha * log_pi(a) - Q(s, a)] where a is reparameterized + loss_actor = -( + state_action_value_actor.min(0)[0] - self.alpha * action_log_prob_actor + ).mean() + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + target_value = get_next_state_value( + tensordict, + gamma=self.gamma, + pred_next_val=next_state_value, + ) + pred_val = state_action_value_qvalue + td_error = (pred_val - target_value).pow(2) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + + tensordict.set("td_error", td_error.detach().max(0)[0]) + + loss_alpha = self._loss_alpha(sample_log_prob) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + "state_action_value_actor": state_action_value_actor.mean().detach(), + "action_log_prob_actor": action_log_prob_actor.mean().detach(), + "next.state_value": next_state_value.mean().detach(), + "target_value": target_value.mean().detach(), + }, + [], + ) + + return td_out + + def _forward_explicit(self, tensordict: TensorDictBase) -> TensorDictBase: + loss_actor, sample_log_prob = self._loss_actor_explicit(tensordict.clone(False)) + loss_qval, td_error = self._loss_qval_explicit(tensordict.clone(False)) + tensordict.set("td_error", td_error.detach().max(0)[0]) + loss_alpha = self._loss_alpha(sample_log_prob) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + # "state_action_value_actor": state_action_value_actor.mean().detach(), + # "action_log_prob_actor": action_log_prob_actor.mean().detach(), + # "next.state_value": next_state_value.mean().detach(), + # "target_value": target_value.mean().detach(), + }, + [], + ) + return td_out + + def _rsample(self, dist, ): + # separated only for the purpose of making the sampling + # deterministic to compare methods + return dist.rsample() + + + def _sample_reparam(self, tensordict, params): + """Given a policy param batch and input data in a tensordict, writes a reparam sample and log-prob key.""" + with set_exploration_mode("random"): + if self.gSDE: + raise NotImplementedError + # vmap doesn't support sampling, so we take it out from the vmap + td_params = self.actor_network.get_dist_params(tensordict, params,) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict[sample_key] = 
self._rsample(tensordict_actor_dist) + tensordict["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict[sample_key] + ) + return tensordict + + def _loss_actor_explicit(self, tensordict): + tensordict_actor = tensordict.clone(False) + actor_params = self.actor_network_params + tensordict_actor = self._sample_reparam(tensordict_actor, actor_params) + action_log_prob_actor = tensordict_actor["sample_log_prob"] + + tensordict_qval = ( + tensordict_actor + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor.batch_size) + ) # for actor loss + qvalue_params = self.qvalue_network_params.detach() + tensordict_qval = vmap(self.qvalue_network)(tensordict_qval, qvalue_params,) + state_action_value_actor = tensordict_qval.get("state_action_value").squeeze(-1) + state_action_value_actor = state_action_value_actor.min(0)[0] + + # E[alpha * log_pi(a) - Q(s, a)] where a is reparameterized + loss_actor = (self.alpha * action_log_prob_actor - state_action_value_actor).mean() + + return loss_actor, action_log_prob_actor + + def _loss_qval_explicit(self, tensordict): + next_tensordict = step_mdp(tensordict) + next_tensordict = self._sample_reparam(next_tensordict, self.target_actor_network_params) + next_action_log_prob_qvalue = next_tensordict["sample_log_prob"] + next_state_action_value_qvalue = vmap(self.qvalue_network, (None, 0))( + next_tensordict, + self.target_qvalue_network_params, + )["state_action_value"].squeeze(-1) + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + pred_val = vmap(self.qvalue_network, (None, 0))( + tensordict, + self.qvalue_network_params, + )["state_action_value"].squeeze(-1) + + target_value = get_next_state_value( + tensordict, + gamma=self.gamma, + pred_next_val=next_state_value, + ) + + # 1/2 * E[Q(s,a) - (r + gamma * (Q(s,a)-alpha log pi(s, a))) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + td_error = (pred_val - target_value).pow(2) + return loss_qval, td_error + +if __name__ == "__main__": + # Tests the vectorized version of SAC-v2 against plain implementation + from torchrl.modules import ProbabilisticActor, ValueOperator + from torchrl.data import BoundedTensorSpec + from torch import nn + from tensordict.nn import TensorDictModule + from torchrl.modules.distributions import TanhNormal + + torch.manual_seed(0) + + action_spec = BoundedTensorSpec(-1, 1, shape=(3,)) + class Splitter(nn.Linear): + def forward(self, x): + loc, scale = super().forward(x).chunk(2, -1) + return loc, scale.exp() + actor_module = TensorDictModule(Splitter(6, 6), in_keys=["obs"], out_keys=["loc", "scale"]) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=TanhNormal, + default_interaction_mode="random", + return_log_prob=False, + ) + class QVal(nn.Linear): + def forward(self, s: Tensor, a: Tensor) -> Tensor: + return super().forward(torch.cat([s, a], -1)) + + qvalue = ValueOperator(QVal(9, 1), in_keys=["obs", "action"]) + _rsample_old = SACLoss._rsample + def _rsample_new(self, dist): + return torch.ones_like(_rsample_old(self, dist)) + SACLoss._rsample = _rsample_new + loss = SACLoss(actor, qvalue) + + for batch in ((), (2, 3)): + td_input = TensorDict({"obs": torch.rand(*batch, 6), "action": torch.rand(*batch, 3).clamp(-1, 1), "next": {"obs": torch.rand(*batch, 6)}, "reward": torch.rand(*batch, 1), 
"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch) + loss._explicit = True + loss0 = loss(td_input) + loss._explicit = False + loss1 = loss(td_input) + print("a", loss0["loss_actor"]-loss1["loss_actor"]) + print("q", loss0["loss_qvalue"]-loss1["loss_qvalue"]) From c521fcde65e8b07cc321affa58f5b8f44a615991 Mon Sep 17 00:00:00 2001 From: rutavms Date: Tue, 31 Jan 2023 14:05:57 -0600 Subject: [PATCH 58/58] updated with rrl,r3m,flatten transforms, added visual hand envs --- examples/config/sac.yaml | 2 +- examples/sac.py | 119 +++++++++++++++++++++++++++------------ rlhive/__init__.py | 3 +- rlhive/envs.py | 34 +++++++++++ scripts/installation.sh | 2 +- 5 files changed, 121 insertions(+), 39 deletions(-) diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml index 83f638245..e4b3909c4 100644 --- a/examples/config/sac.yaml +++ b/examples/config/sac.yaml @@ -13,7 +13,7 @@ wandb_mode: "offline" # Environment task: visual_franka_slide_random-v3 -frame_skip: 1 +#frame_skip: 1 reward_scaling: 5.0 init_env_steps: 1000 seed: 42 diff --git a/examples/sac.py b/examples/sac.py index a256e8635..e417b820c 100644 --- a/examples/sac.py +++ b/examples/sac.py @@ -15,6 +15,7 @@ import torch.cuda import tqdm from omegaconf import DictConfig +from torchvision.models import ResNet50_Weights from rlhive.rl_envs import RoboHiveEnv from sac_loss import SACLoss @@ -35,7 +36,7 @@ SelectTransform, TransformedEnv, ) -from torchrl.envs.transforms import Compose, FlattenObservation, RewardScaling +from torchrl.envs.transforms import Compose, FlattenObservation, RewardScaling, Resize, ToTensorImage from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.modules import MLP, NormalParamWrapper, SafeModule from torchrl.modules.distributions import TanhNormal @@ -45,7 +46,6 @@ from torchrl.record.loggers import WandbLogger from torchrl.trainers import Recorder - # =========================================================================================== # Env constructor # --------------- @@ -72,9 +72,20 @@ # ... 
)) # +def is_visual_env(task): + return task.startswith("visual_") + +def evaluate_success(env_success_fn, td_record: dict, eval_traj: int): + td_record["success"] = td_record["success"].reshape((eval_traj, -1)) + paths = [] + for traj, solved_traj in zip(range(eval_traj), td_record["success"]): + path = {"env_infos": {"solved": solved_traj.data.cpu().numpy()}} + paths.append(path) + success_percentage = env_success_fn(paths) + return success_percentage def make_env(num_envs, task, visual_transform, reward_scaling, device): - assert visual_transform in ("rrl", "r3m") + assert visual_transform in ("rrl", "r3m", "flatten", "state") if num_envs > 1: base_env = ParallelEnv(num_envs, lambda: RoboHiveEnv(task, device=device)) else: @@ -94,21 +105,46 @@ def make_transformed_env( """ Apply transforms to the env (such as reward scaling and state normalization) """ - env = TransformedEnv( - env, - SelectTransform("solved", "pixels", "observation", "rwd_dense", "rwd_sparse"), - ) - if visual_transform == "r3m": - vec_keys = ["r3m_vec"] - selected_keys = ["observation", "r3m_vec"] - env.append_transform( - Compose( - R3MTransform("resnet50", in_keys=["pixels"], download=True), - FlattenObservation(-2, -1, in_keys=vec_keys), + if visual_transform != "state": + env = TransformedEnv( + env, + SelectTransform("solved", "pixels", "observation", "rwd_dense", "rwd_sparse"), + ) + if visual_transform == "r3m": + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == "rrl": + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=ResNet50_Weights.IMAGENET1K_V2), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == "flatten": + vec_keys = ["pixels"] + out_keys = ["pixels"] + selected_keys = ["observation", "pixels"] + env.append_transform( + Compose( + ToTensorImage(), + Resize(64, 64, in_keys=vec_keys, out_keys=out_keys), ## TODO: Why is resize not working? + FlattenObservation(-4, -1, in_keys=out_keys), + ) ) - ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + else: + raise NotImplementedError else: - raise NotImplementedError + env = TransformedEnv(env, SelectTransform("solved", "observation", "rwd_dense", "rwd_sparse")) + selected_keys = ["observation"] + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) out_key = "observation_vector" env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) @@ -126,7 +162,7 @@ def make_transformed_env( def make_recorder( task: str, - frame_skip: int, + #frame_skip: int, record_interval: int, actor_model_explore: object, eval_traj: int, @@ -134,18 +170,20 @@ def make_recorder( wandb_logger: WandbLogger, ): test_env = make_env(num_envs=1, task=task, **env_configs) - test_env.insert_transform( - 0, VideoRecorder(wandb_logger, "test", in_keys=["pixels"]) - ) + if is_visual_env(task):## TODO(Rutav): Change this behavior. 
Record only when using visual env + test_env.insert_transform( + 0, VideoRecorder(wandb_logger, "test", in_keys=["pixels"]) + ) recorder_obj = Recorder( record_frames=eval_traj * test_env.horizon, - frame_skip=frame_skip, + #frame_skip=frame_skip, ## To maintain consistency and using default env frame_skip values + frame_skip=1, ## To maintain consistency and using default env frame_skip values policy_exploration=actor_model_explore, recorder=test_env, exploration_mode="mean", record_interval=record_interval, - log_keys=["reward", "solved"], - out_keys={"reward": "r_evaluation", "solved": "success"}, + log_keys=["reward", "solved", "rwd_dense", "rwd_sparse"], + out_keys={"reward": "r_evaluation", "solved": "success", "rwd_dense": "rwd_dense", "rwd_sparse": "rwd_sparse"}, ) return recorder_obj @@ -218,6 +256,7 @@ def dataloader( @hydra.main(config_name="sac.yaml", config_path="config") def main(args: DictConfig): + assert ((args.visual_transform == "state")^is_visual_env(args.task)), "Please use visual_transform=state if using state environment; else use visual_transform=r3m,rrl" # customize device at will device = args.device device_collection = args.device_collection @@ -334,6 +373,7 @@ def main(args: DictConfig): rewards = [] rewards_eval = [] + success_percentage_hist = [] # Main loop target_net_updater.init_() @@ -360,7 +400,7 @@ def main(args: DictConfig): # Trajectory recorder for evaluation recorder = make_recorder( task=args.task, - frame_skip=args.frame_skip, + #frame_skip=args.frame_skip, record_interval=args.record_interval, actor_model_explore=actor_model_explore, eval_traj=args.eval_traj, @@ -407,9 +447,10 @@ def main(args: DictConfig): sampled_tensordict = replay_buffer.sample(args.batch_size).clone() loss_td = loss_module(sampled_tensordict) - print(f'value: {loss_td["state_action_value_actor"].mean():4.4f}') - print(f'log_prob: {loss_td["action_log_prob_actor"].mean():4.4f}') - print(f'next.state_value: {loss_td["state_value"].mean():4.4f}') + ## Not returned in explicit forward loss + #print(f'value: {loss_td["state_action_value_actor"].mean():4.4f}') + #print(f'log_prob: {loss_td["action_log_prob_actor"].mean():4.4f}') + #print(f'next.state_value: {loss_td["state_value"].mean():4.4f}') actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] @@ -454,13 +495,7 @@ def main(args: DictConfig): logger.log_scalar("entropy", np.mean(entropies), step=collected_frames) if i % args.eval_interval == 0: td_record = recorder(None) - # success_percentage = evaluate_success( - # env_success_fn=train_env.evaluate_success, - # td_record=td_record, - # eval_traj=args.eval_traj, - # ) if td_record is not None: - print("recorded", td_record) rewards_eval.append( ( i, @@ -471,14 +506,26 @@ def main(args: DictConfig): logger.log_scalar( "test_reward", rewards_eval[-1][1], step=collected_frames ) - solved = float(td_record["success"].any()) logger.log_scalar( - "success", solved, step=collected_frames + "reward_sparse", td_record["rwd_sparse"].sum()/args.eval_traj, step=collected_frames + ) + logger.log_scalar( + "reward_dense", td_record["rwd_dense"].sum()/args.eval_traj, step=collected_frames + ) + success_percentage = evaluate_success( + env_success_fn=train_env.evaluate_success, + td_record=td_record, + eval_traj=args.eval_traj, + ) + success_percentage_hist.append(success_percentage) + #solved = float(td_record["success"].any()) + logger.log_scalar( + "success_rate", success_percentage, step=collected_frames ) if len(rewards_eval): pbar.set_description( - f"reward: 
{rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}, solved: {solved}" + f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}, Success: {success_percentage_hist[-1]}" ) del batch gc.collect() diff --git a/rlhive/__init__.py b/rlhive/__init__.py index 88aff9f05..56538b856 100644 --- a/rlhive/__init__.py +++ b/rlhive/__init__.py @@ -3,9 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .envs import register_franka_envs, register_kitchen_envs +from .envs import register_franka_envs, register_kitchen_envs, register_hand_envs register_franka_envs() register_kitchen_envs() +register_hand_envs() from .rl_envs import RoboHiveEnv diff --git a/rlhive/envs.py b/rlhive/envs.py index a8c0b2059..9606ebebc 100644 --- a/rlhive/envs.py +++ b/rlhive/envs.py @@ -126,3 +126,37 @@ def register_franka_envs(): warnings.warn( f"Could not register {new_env_name}, the following error was raised: {err}" ) + +@set_directory(CURR_DIR) +def register_hand_envs(): + print("RLHive:> Registering Franka Envs") + env_list = [ + "door-v1", + "hammer-v1", + "pen-v1", + "relocate-v1" + ] + + # Hand Manipulation Suite ====================================================================== + visual_obs_keys_wt = { + "hand_jnt": 1.0, + "rgb:vil_camera:224x224:2d": 1.0, + "rgb:fixed:224x224:2d": 1.0, + } + + for env in env_list: + try: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, + variants={'obs_keys': + ['hand_jnt', + "rgb:vil_camera:224x224:2d", + "rgb:fixed:224x224:2d"] + }, + variant_id=new_env_name, + ) + except AssertionError as err: + warnings.warn( + f"Could not register {new_env_name}, the following error was raised: {err}" + ) diff --git a/scripts/installation.sh b/scripts/installation.sh index aa89d5275..cd232abcc 100644 --- a/scripts/installation.sh +++ b/scripts/installation.sh @@ -10,5 +10,5 @@ cd $here python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) -pip install wandb +pip install wandb moviepy pip install hydra-submitit-launcher --upgrade
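
Side note on the evaluation bookkeeping added in PATCH 58 — evaluate_success() reshapes the recorder's "success" tensor into one row per evaluation trajectory, repackages each row as a path dict, and defers the scoring to the environment's evaluate_success hook. A hedged, self-contained toy of that flow; toy_env_success_fn is a stand-in for env.evaluate_success from mj_envs and simply counts trajectories solved at any step.

    import torch

    def toy_env_success_fn(paths):
        # stand-in for mj_envs' evaluate_success: a trajectory counts as
        # solved if "solved" is True at any step
        solved = [bool(p["env_infos"]["solved"].any()) for p in paths]
        return 100.0 * sum(solved) / len(solved)

    eval_traj = 4
    td_record = {"success": torch.tensor([0, 0, 0, 1, 0, 0, 1, 1], dtype=torch.bool)}
    # one row per evaluation trajectory, as in evaluate_success()
    td_record["success"] = td_record["success"].reshape((eval_traj, -1))
    paths = [
        {"env_infos": {"solved": traj.cpu().numpy()}}
        for traj in td_record["success"]
    ]
    print(toy_env_success_fn(paths))  # 50.0
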