diff --git a/examples/config/sac.yaml b/examples/config/sac.yaml new file mode 100644 index 000000000..e4b3909c4 --- /dev/null +++ b/examples/config/sac.yaml @@ -0,0 +1,46 @@ +default: + - override hydra/output: local + - override hydra/launcher: local + +# Logger +exp_name: ${task}_sac_${visual_transform} +visual_transform: r3m +record_interval: 1 +device: "cuda:0" +device_collection: "cuda:1" +wandb_entity: "RLHive" +wandb_mode: "offline" + +# Environment +task: visual_franka_slide_random-v3 +#frame_skip: 1 +reward_scaling: 5.0 +init_env_steps: 1000 +seed: 42 +eval_traj: 25 +eval_interval: 1000 +num_envs: 8 + +# Collector +env_per_collector: 1 +max_frames_per_traj: -1 +total_frames: 1000000 +init_random_frames: 25000 +frames_per_batch: 1000 + +# Replay Buffer +prb: 0 +buffer_size: 100000 +buffer_scratch_dir: /tmp/ + +# Optimization +gamma: 0.99 +batch_size: 256 +lr: 3.0e-4 +weight_decay: 0.0 +target_update_polyak: 0.995 +utd_ratio: 1 + +hydra: + job: + name: sac_${task}_${seed} diff --git a/examples/install/install_rlhive.sh b/examples/install/install_rlhive.sh new file mode 100755 index 000000000..ba5f78931 --- /dev/null +++ b/examples/install/install_rlhive.sh @@ -0,0 +1,41 @@ +#!/bin/zsh + +# Instructions to install a fresh anaconda environment with RLHive + +set -e + +conda_path=$(conda info | grep -i 'base environment' | awk '{ print $4 }') +source $conda_path/etc/profile.d/conda.sh + +here=$(pwd) +module_path=$HOME/modules/ + +module purge +module load cuda/11.6 + +conda env remove -n rlhive -y + +conda create -n rlhive -y python=3.8 + +conda activate rlhive + +python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 + +mkdir $module_path +cd $module_path +git clone -c submodule.mj_envs/sims/neuromuscular_sim.update=none --branch v0.4dev --recursive https://github.com/vikashplus/mj_envs.git mj_envs +cd mj_envs +python3 -mpip install . # one can also install it locally with the -e flag +cd $here + +python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) +python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) + +# this +# python3 -mpip install git+https://github.com/facebookresearch/rlhive.git # or stable or nightly with pip install torchrl(-nightly) +# or this +cd ../.. +pip install -e . 
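+ +# Optional smoke test: a sketch, assuming the visual franka env registered by the packages above is available (mirrors the check in scripts/README.md). +# MUJOCO_GL=egl sim_backend=MUJOCO python -c "from rlhive.rl_envs import RoboHiveEnv; print(RoboHiveEnv('visual_franka_slide_random-v3').rollout(3))"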
+cd $here + +pip install wandb tqdm hydra-core moviepy diff --git a/examples/sac.py b/examples/sac.py new file mode 100644 index 000000000..e417b820c --- /dev/null +++ b/examples/sac.py @@ -0,0 +1,535 @@ +import os + +from torchrl.record import VideoRecorder + +os.environ["sim_backend"] = "MUJOCO" + +import gc +import os +from copy import deepcopy + +import hydra + +import numpy as np +import torch +import torch.cuda +import tqdm +from omegaconf import DictConfig +from torchvision.models import ResNet50_Weights +from rlhive.rl_envs import RoboHiveEnv + +from sac_loss import SACLoss + +# from torchrl.objectives import SACLoss +from tensordict import TensorDict + +from torch import nn, optim +from torchrl.data import TensorDictReplayBuffer + +from torchrl.data.replay_buffers.storages import LazyMemmapStorage + +# from torchrl.envs import SerialEnv as ParallelEnv, R3MTransform, SelectTransform, TransformedEnv +from torchrl.envs import ( + CatTensors, + ParallelEnv, + R3MTransform, + SelectTransform, + TransformedEnv, +) +from torchrl.envs.transforms import Compose, FlattenObservation, RewardScaling, Resize, ToTensorImage +from torchrl.envs.utils import set_exploration_mode, step_mdp +from torchrl.modules import MLP, NormalParamWrapper, SafeModule +from torchrl.modules.distributions import TanhNormal + +from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator +from torchrl.objectives import SoftUpdate +from torchrl.record.loggers import WandbLogger +from torchrl.trainers import Recorder + +# =========================================================================================== +# Env constructor +# --------------- +# - Use the RoboHiveEnv class to wrap robohive envs in torchrl's GymWrapper +# - Add transforms immediately after that: +# - SelectTransform: selects the relevant kesy from our output +# - R3MTransform +# - FlattenObservation: The images delivered by robohive have a singleton dim to start with, we need to flatten that +# - RewardScaling +# +# One can also possibly use ObservationNorm. +# +# TIPS: +# - For faster execution, you should follow this abstract scheme, where we reduce the data +# to be passed from worker to worker to a minimum, we apply R3M to a batch and append the +# rest of the transforms afterward: +# +# >>> env = TransformedEnv( +# ... ParallelEnv(N, lambda: TransformedEnv(RoboHiveEnv(...), SelectTransform(...))), +# ... Compose( +# ... R3MTransform(...), +# ... FlattenObservation(...), +# ... *other_transforms, +# ... 
)) +# + +def is_visual_env(task): + return task.startswith("visual_") + +def evaluate_success(env_success_fn, td_record: dict, eval_traj: int): + td_record["success"] = td_record["success"].reshape((eval_traj, -1)) + paths = [] + for traj, solved_traj in zip(range(eval_traj), td_record["success"]): + path = {"env_infos": {"solved": solved_traj.data.cpu().numpy()}} + paths.append(path) + success_percentage = env_success_fn(paths) + return success_percentage + +def make_env(num_envs, task, visual_transform, reward_scaling, device): + assert visual_transform in ("rrl", "r3m", "flatten", "state") + if num_envs > 1: + base_env = ParallelEnv(num_envs, lambda: RoboHiveEnv(task, device=device)) + else: + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env( + env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform + ) + + return env + + +def make_transformed_env( + env, + reward_scaling=5.0, + visual_transform="r3m", +): + """ + Apply transforms to the env (such as reward scaling and state normalization) + """ + if visual_transform != "state": + env = TransformedEnv( + env, + SelectTransform("solved", "pixels", "observation", "rwd_dense", "rwd_sparse"), + ) + if visual_transform == "r3m": + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == "rrl": + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=ResNet50_Weights.IMAGENET1K_V2), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == "flatten": + vec_keys = ["pixels"] + out_keys = ["pixels"] + selected_keys = ["observation", "pixels"] + env.append_transform( + Compose( + ToTensorImage(), + Resize(64, 64, in_keys=vec_keys, out_keys=out_keys), ## TODO: Why is resize not working? + FlattenObservation(-4, -1, in_keys=out_keys), + ) + ) + else: + raise NotImplementedError + else: + env = TransformedEnv(env, SelectTransform("solved", "observation", "rwd_dense", "rwd_sparse")) + selected_keys = ["observation"] + + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + return env + + +# =========================================================================================== +# Making a recorder +# ----------------- +# +# A `Recorder` is a dedicated torchrl class that will run the policy in the test env +# once every X steps (eg X=1M). +# + + +def make_recorder( + task: str, + #frame_skip: int, + record_interval: int, + actor_model_explore: object, + eval_traj: int, + env_configs: dict, + wandb_logger: WandbLogger, +): + test_env = make_env(num_envs=1, task=task, **env_configs) + if is_visual_env(task):## TODO(Rutav): Change this behavior. 
Record only when using visual env + test_env.insert_transform( + 0, VideoRecorder(wandb_logger, "test", in_keys=["pixels"]) + ) + recorder_obj = Recorder( + record_frames=eval_traj * test_env.horizon, + #frame_skip=frame_skip, ## To maintain consistency and using default env frame_skip values + frame_skip=1, ## To maintain consistency and using default env frame_skip values + policy_exploration=actor_model_explore, + recorder=test_env, + exploration_mode="mean", + record_interval=record_interval, + log_keys=["reward", "solved", "rwd_dense", "rwd_sparse"], + out_keys={"reward": "r_evaluation", "solved": "success", "rwd_dense": "rwd_dense", "rwd_sparse": "rwd_sparse"}, + ) + return recorder_obj + + +# =========================================================================================== +# Relplay buffers +# --------------- +# +# TorchRL also provides prioritized RBs if needed. +# + + +def make_replay_buffer( + buffer_size: int, + buffer_scratch_dir: str, + device: torch.device, + make_replay_buffer: int = 3, +): + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=make_replay_buffer, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + device=device, + ), + ) + return replay_buffer + + +# =========================================================================================== +# Dataloader +# ---------- +# +# This is a simplified version of the dataloder +# + + +@torch.no_grad() +@set_exploration_mode("random") +def dataloader( + total_frames, fpb, train_env, actor, actor_collection, device_collection +): + params = TensorDict( + {k: v for k, v in actor.named_parameters()}, batch_size=[] + ).unflatten_keys(".") + params_collection = TensorDict( + {k: v for k, v in actor_collection.named_parameters()}, batch_size=[] + ).unflatten_keys(".") + _prev = None + + collected_frames = 0 + while collected_frames < total_frames: + params_collection.update_(params) + batch = TensorDict( + {}, batch_size=[fpb, *train_env.batch_size], device=device_collection + ) + for t in range(fpb): + if _prev is None: + _prev = train_env.reset() + _reset = _prev["_reset"] = _prev["done"].clone().squeeze(-1) + if _reset.any(): + _prev = train_env.reset(_prev) + _new = train_env.step(actor_collection(_prev)) + batch[t] = _new + _prev = step_mdp(_new, exclude_done=False) + collected_frames += batch.numel() + yield batch + + +@hydra.main(config_name="sac.yaml", config_path="config") +def main(args: DictConfig): + assert ((args.visual_transform == "state")^is_visual_env(args.task)), "Please use visual_transform=state if using state environment; else use visual_transform=r3m,rrl" + # customize device at will + device = args.device + device_collection = args.device_collection + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + # Create Environment + env_configs = { + "reward_scaling": args.reward_scaling, + "visual_transform": args.visual_transform, + "device": args.device, + } + train_env = make_env(num_envs=args.num_envs, task=args.task, **env_configs).to( + device_collection + ) + + # Create Agent + # Define Actor Network + in_keys = ["observation_vector"] + action_spec = train_env.action_spec + actor_net_kwargs = { + "num_cells": [256, 256], + "out_features": 2 * action_spec.shape[-1], + "activation_class": nn.ReLU, + } + + actor_net = MLP(**actor_net_kwargs) + + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": True, + } + actor_net = NormalParamWrapper( + actor_net, + 
scale_mapping=f"biased_softplus_{1.0}", + scale_lb=0.1, + ) + in_keys_actor = in_keys + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "loc", + "scale", + ], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": nn.ReLU, + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # add forward pass for initialization with proof env + proof_env = make_env(num_envs=1, task=args.task, **env_configs) + # init nets + with torch.no_grad(), set_exploration_mode("random"): + td = proof_env.reset() + td = td.to(device) + for net in model: + net(td) + del td + proof_env.close() + + actor_collection = deepcopy(actor).to(device_collection) + + actor_model_explore = model[0] + + # Create SAC loss + loss_module = SACLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + gamma=args.gamma, + loss_function="smooth_l1", + ) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, args.target_update_polyak) + + # Make Replay Buffer + replay_buffer = make_replay_buffer( + buffer_size=args.buffer_size, + buffer_scratch_dir=args.buffer_scratch_dir, + device=device, + ) + + # Optimizers + params = list(loss_module.parameters()) + optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) + + rewards = [] + rewards_eval = [] + success_percentage_hist = [] + + # Main loop + target_net_updater.init_() + + collected_frames = 0 + episodes = 0 + optim_steps = 0 + pbar = tqdm.tqdm(total=args.total_frames) + r0 = None + loss = None + + total_frames = args.total_frames + frames_per_batch = args.frames_per_batch + + logger = WandbLogger( + exp_name=args.task, + project="SAC_TorchRL", + name=args.exp_name, + config=args, + entity=args.wandb_entity, + mode=args.wandb_mode, + ) + + # Trajectory recorder for evaluation + recorder = make_recorder( + task=args.task, + #frame_skip=args.frame_skip, + record_interval=args.record_interval, + actor_model_explore=actor_model_explore, + eval_traj=args.eval_traj, + env_configs=env_configs, + wandb_logger=logger, + ) + + for i, batch in enumerate( + dataloader( + total_frames, + frames_per_batch, + train_env, + actor, + actor_collection, + device_collection, + ) + ): + if r0 is None: + r0 = batch["reward"].sum(-1).mean().item() + pbar.update(batch.numel()) + + # extend the replay buffer with the new data + batch = batch.view(-1) + current_frames = batch.numel() + collected_frames += current_frames + episodes += batch["done"].sum() + replay_buffer.extend(batch.cpu()) + + # optimization steps + if collected_frames >= args.init_random_frames: + ( + total_losses, + actor_losses, + q_losses, + alpha_losses, + alphas, + entropies, + ) = ([], [], [], [], [], []) + for _ in range( + args.env_per_collector * args.frames_per_batch * args.utd_ratio + ): + optim_steps += 1 + # sample from replay buffer + sampled_tensordict = replay_buffer.sample(args.batch_size).clone() + + loss_td = loss_module(sampled_tensordict) + ## Not returned in explicit forward loss + #print(f'value: {loss_td["state_action_value_actor"].mean():4.4f}') + #print(f'log_prob: 
{loss_td["action_log_prob_actor"].mean():4.4f}') + #print(f'next.state_value: {loss_td["state_value"].mean():4.4f}') + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_qvalue"] + alpha_loss = loss_td["loss_alpha"] + + loss = actor_loss + q_loss + alpha_loss + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # update qnet_target params + target_net_updater.step() + + # update priority + if args.prb: + replay_buffer.update_priority(sampled_tensordict) + + total_losses.append(loss.item()) + actor_losses.append(actor_loss.item()) + q_losses.append(q_loss.item()) + alpha_losses.append(alpha_loss.item()) + alphas.append(loss_td["alpha"].item()) + entropies.append(loss_td["entropy"].item()) + + rewards.append((i, batch["reward"].mean().item())) + logger.log_scalar("train_reward", rewards[-1][1], step=collected_frames) + logger.log_scalar("optim_steps", optim_steps, step=collected_frames) + logger.log_scalar("episodes", episodes, step=collected_frames) + + if loss is not None: + logger.log_scalar( + "total_loss", np.mean(total_losses), step=collected_frames + ) + logger.log_scalar( + "actor_loss", np.mean(actor_losses), step=collected_frames + ) + logger.log_scalar("q_loss", np.mean(q_losses), step=collected_frames) + logger.log_scalar( + "alpha_loss", np.mean(alpha_losses), step=collected_frames + ) + logger.log_scalar("alpha", np.mean(alphas), step=collected_frames) + logger.log_scalar("entropy", np.mean(entropies), step=collected_frames) + if i % args.eval_interval == 0: + td_record = recorder(None) + if td_record is not None: + rewards_eval.append( + ( + i, + td_record["total_r_evaluation"] + / 1, # divide by number of eval worker + ) + ) + logger.log_scalar( + "test_reward", rewards_eval[-1][1], step=collected_frames + ) + logger.log_scalar( + "reward_sparse", td_record["rwd_sparse"].sum()/args.eval_traj, step=collected_frames + ) + logger.log_scalar( + "reward_dense", td_record["rwd_dense"].sum()/args.eval_traj, step=collected_frames + ) + success_percentage = evaluate_success( + env_success_fn=train_env.evaluate_success, + td_record=td_record, + eval_traj=args.eval_traj, + ) + success_percentage_hist.append(success_percentage) + #solved = float(td_record["success"].any()) + logger.log_scalar( + "success_rate", success_percentage, step=collected_frames + ) + + if len(rewards_eval): + pbar.set_description( + f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}, Success: {success_percentage_hist[-1]}" + ) + del batch + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/examples/sac_loss.py b/examples/sac_loss.py new file mode 100644 index 000000000..07d5ab1c4 --- /dev/null +++ b/examples/sac_loss.py @@ -0,0 +1,474 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from numbers import Number +from typing import Union + +import numpy as np +import torch + +from tensordict.nn import TensorDictSequential +from tensordict.tensordict import TensorDict, TensorDictBase +from torch import Tensor + +from torchrl.envs.utils import set_exploration_mode, step_mdp +from torchrl.modules import SafeModule +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import ( + distance_loss, + next_state_value as get_next_state_value, +) + +try: + from functorch import vmap + + FUNCTORCH_ERR = "" + _has_functorch = True +except ImportError as err: + FUNCTORCH_ERR = str(err) + _has_functorch = False + + +class SACLoss(LossModule): + """SAC Loss module. + Args: + actor_network (SafeModule): the actor to be trained + qvalue_network (SafeModule): a single Q-value network that will be replicated as many times as needed. + num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 2. + gamma (Number, optional): gamma decay factor. Default is 0.99. + priotity_key (str, optional): Key where to write the priority value for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", + "l1". Default is "smooth_l1". + alpha_init (float, optional): initial entropy multiplier. + Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is 0.1. + max_alpha (float, optional): max value of alpha. + Default is 10.0. + fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is :obj:`False`. + target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". + delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used + for data collection. Default is :obj:`True`. + gSDE (bool, optional): Knowing if gSDE is used is necessary to create random noise variables.
+ Default is False + """ + + delay_actor: bool = False + _explicit: bool = True + + def __init__( + self, + actor_network: SafeModule, + qvalue_network: SafeModule, + num_qvalue_nets: int = 2, + gamma: Number = 0.99, + priotity_key: str = "td_error", + loss_function: str = "smooth_l1", + alpha_init: float = 1.0, + min_alpha: float = 0.1, + max_alpha: float = 10.0, + fixed_alpha: bool = False, + target_entropy: Union[str, Number] = "auto", + delay_qvalue: bool = True, + gSDE: bool = False, + ): + if not _has_functorch: + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" + ) + + super().__init__() + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + funs_to_decorate=["forward", "get_dist_params"], + ) + + # let's make sure that actor_network has `return_log_prob` set to True + self.actor_network.return_log_prob = True + + self.delay_qvalue = delay_qvalue + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + compare_against=list(actor_network.parameters()), + ) + self.num_qvalue_nets = num_qvalue_nets + self.register_buffer("gamma", torch.tensor(gamma)) + self.priority_key = priotity_key + self.loss_function = loss_function + + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + if target_entropy == "auto": + if actor_network.spec["action"] is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." 
+ ) + target_entropy = -float(np.prod(actor_network.spec["action"].shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + self.gSDE = gSDE + + @property + def alpha(self): + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if self._explicit: + # slow but explicit version + return self._forward_explicit(tensordict) + else: + return self._forward_vectorized(tensordict) + + def _loss_alpha(self, log_pi: Tensor) -> Tensor: + if torch.is_grad_enabled() and not log_pi.requires_grad: + raise RuntimeError( + "expected log_pi to require gradient for the alpha loss)" + ) + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss + + def _forward_vectorized(self, tensordict: TensorDictBase) -> TensorDictBase: + obs_keys = self.actor_network.in_keys + tensordict_select = tensordict.select( + "reward", "done", "next", *obs_keys, "action" + ) + + actor_params = torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) + + tensordict_actor_grad = tensordict_select.select( + *obs_keys + ) # to avoid overwriting keys + next_td_actor = step_mdp(tensordict_select).select( + *self.actor_network.in_keys + ) # next_observation -> + tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) + tensordict_actor = tensordict_actor.contiguous() + + with set_exploration_mode("random"): + if self.gSDE: + tensordict_actor.set( + "_eps_gSDE", + torch.zeros(tensordict_actor.shape, device=tensordict_actor.device), + ) + # vmap doesn't support sampling, so we take it out from the vmap + td_params = vmap(self.actor_network.get_dist_params)( + tensordict_actor, + actor_params, + ) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict_actor[sample_key] = self._rsample(tensordict_actor_dist) + tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict_actor[sample_key] + ) + + # repeat tensordict_actor to match the qvalue size + _actor_loss_td = ( + tensordict_actor[0] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) + ) # for actor loss + _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, + *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, + ) # for qvalue loss + _next_val_td = ( + tensordict_actor[1] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) + ) # for next value estimation + tensordict_qval = torch.cat( + [ + _actor_loss_td, + _next_val_td, + _qval_td, + ], + 0, + ) + + # cat params + q_params_detach = self.qvalue_network_params.detach() + qvalue_params = torch.cat( + [ + q_params_detach, + self.target_qvalue_network_params, + self.qvalue_network_params, + ], + 0, + ) + tensordict_qval = vmap(self.qvalue_network)( + tensordict_qval, + qvalue_params, + ) + + state_action_value = 
tensordict_qval.get("state_action_value").squeeze(-1) + ( + state_action_value_actor, + next_state_action_value_qvalue, + state_action_value_qvalue, + ) = state_action_value.split( + [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], + dim=0, + ) + sample_log_prob = tensordict_actor.get("sample_log_prob").squeeze(-1) + ( + action_log_prob_actor, + next_action_log_prob_qvalue, + ) = sample_log_prob.unbind(0) + + # E[alpha * log_pi(a) - Q(s, a)] where a is reparameterized + loss_actor = -( + state_action_value_actor.min(0)[0] - self.alpha * action_log_prob_actor + ).mean() + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + target_value = get_next_state_value( + tensordict, + gamma=self.gamma, + pred_next_val=next_state_value, + ) + pred_val = state_action_value_qvalue + td_error = (pred_val - target_value).pow(2) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + + tensordict.set("td_error", td_error.detach().max(0)[0]) + + loss_alpha = self._loss_alpha(sample_log_prob) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + "state_action_value_actor": state_action_value_actor.mean().detach(), + "action_log_prob_actor": action_log_prob_actor.mean().detach(), + "next.state_value": next_state_value.mean().detach(), + "target_value": target_value.mean().detach(), + }, + [], + ) + + return td_out + + def _forward_explicit(self, tensordict: TensorDictBase) -> TensorDictBase: + loss_actor, sample_log_prob = self._loss_actor_explicit(tensordict.clone(False)) + loss_qval, td_error = self._loss_qval_explicit(tensordict.clone(False)) + tensordict.set("td_error", td_error.detach().max(0)[0]) + loss_alpha = self._loss_alpha(sample_log_prob) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + # "state_action_value_actor": state_action_value_actor.mean().detach(), + # "action_log_prob_actor": action_log_prob_actor.mean().detach(), + # "next.state_value": next_state_value.mean().detach(), + # "target_value": target_value.mean().detach(), + }, + [], + ) + return td_out + + def _rsample(self, dist, ): + # separated only for the purpose of making the sampling + # deterministic to compare methods + return dist.rsample() + + + def _sample_reparam(self, tensordict, params): + """Given a policy param batch and input data in a tensordict, writes a reparam sample and log-prob key.""" + with set_exploration_mode("random"): + if self.gSDE: + raise NotImplementedError + # vmap doesn't support sampling, so we take it out from the vmap + td_params = self.actor_network.get_dist_params(tensordict, params,) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict[sample_key] = 
self._rsample(tensordict_actor_dist) + tensordict["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict[sample_key] + ) + return tensordict + + def _loss_actor_explicit(self, tensordict): + tensordict_actor = tensordict.clone(False) + actor_params = self.actor_network_params + tensordict_actor = self._sample_reparam(tensordict_actor, actor_params) + action_log_prob_actor = tensordict_actor["sample_log_prob"] + + tensordict_qval = ( + tensordict_actor + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor.batch_size) + ) # for actor loss + qvalue_params = self.qvalue_network_params.detach() + tensordict_qval = vmap(self.qvalue_network)(tensordict_qval, qvalue_params,) + state_action_value_actor = tensordict_qval.get("state_action_value").squeeze(-1) + state_action_value_actor = state_action_value_actor.min(0)[0] + + # E[alpha * log_pi(a) - Q(s, a)] where a is reparameterized + loss_actor = (self.alpha * action_log_prob_actor - state_action_value_actor).mean() + + return loss_actor, action_log_prob_actor + + def _loss_qval_explicit(self, tensordict): + next_tensordict = step_mdp(tensordict) + next_tensordict = self._sample_reparam(next_tensordict, self.target_actor_network_params) + next_action_log_prob_qvalue = next_tensordict["sample_log_prob"] + next_state_action_value_qvalue = vmap(self.qvalue_network, (None, 0))( + next_tensordict, + self.target_qvalue_network_params, + )["state_action_value"].squeeze(-1) + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + pred_val = vmap(self.qvalue_network, (None, 0))( + tensordict, + self.qvalue_network_params, + )["state_action_value"].squeeze(-1) + + target_value = get_next_state_value( + tensordict, + gamma=self.gamma, + pred_next_val=next_state_value, + ) + + # 1/2 * E[Q(s,a) - (r + gamma * (Q(s,a)-alpha log pi(s, a))) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + td_error = (pred_val - target_value).pow(2) + return loss_qval, td_error + +if __name__ == "__main__": + # Tests the vectorized version of SAC-v2 against plain implementation + from torchrl.modules import ProbabilisticActor, ValueOperator + from torchrl.data import BoundedTensorSpec + from torch import nn + from tensordict.nn import TensorDictModule + from torchrl.modules.distributions import TanhNormal + + torch.manual_seed(0) + + action_spec = BoundedTensorSpec(-1, 1, shape=(3,)) + class Splitter(nn.Linear): + def forward(self, x): + loc, scale = super().forward(x).chunk(2, -1) + return loc, scale.exp() + actor_module = TensorDictModule(Splitter(6, 6), in_keys=["obs"], out_keys=["loc", "scale"]) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=TanhNormal, + default_interaction_mode="random", + return_log_prob=False, + ) + class QVal(nn.Linear): + def forward(self, s: Tensor, a: Tensor) -> Tensor: + return super().forward(torch.cat([s, a], -1)) + + qvalue = ValueOperator(QVal(9, 1), in_keys=["obs", "action"]) + _rsample_old = SACLoss._rsample + def _rsample_new(self, dist): + return torch.ones_like(_rsample_old(self, dist)) + SACLoss._rsample = _rsample_new + loss = SACLoss(actor, qvalue) + + for batch in ((), (2, 3)): + td_input = TensorDict({"obs": torch.rand(*batch, 6), "action": torch.rand(*batch, 3).clamp(-1, 1), "next": {"obs": torch.rand(*batch, 6)}, "reward": torch.rand(*batch, 1), 
"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch) + loss._explicit = True + loss0 = loss(td_input) + loss._explicit = False + loss1 = loss(td_input) + print("a", loss0["loss_actor"]-loss1["loss_actor"]) + print("q", loss0["loss_qvalue"]-loss1["loss_qvalue"]) diff --git a/rlhive/__init__.py b/rlhive/__init__.py index 88aff9f05..56538b856 100644 --- a/rlhive/__init__.py +++ b/rlhive/__init__.py @@ -3,9 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .envs import register_franka_envs, register_kitchen_envs +from .envs import register_franka_envs, register_kitchen_envs, register_hand_envs register_franka_envs() register_kitchen_envs() +register_hand_envs() from .rl_envs import RoboHiveEnv diff --git a/rlhive/envs.py b/rlhive/envs.py index 54d0a2822..9606ebebc 100644 --- a/rlhive/envs.py +++ b/rlhive/envs.py @@ -6,11 +6,13 @@ # Custom env reg for RoboHive usage in TorchRL # Pixel rendering will be queried by torchrl, so we don't include those keys in visual_obs_keys_wt import os +import warnings from pathlib import Path -import mj_envs.envs.env_variants.register_env_variant import mj_envs.envs.multi_task.substeps1 +from mj_envs.envs.env_variants import register_env_variant + visual_obs_keys_wt = mj_envs.envs.multi_task.substeps1.visual_obs_keys_wt @@ -70,7 +72,7 @@ def register_kitchen_envs(): "kitchen_rdoor_open-v3", "kitchen_micro_close-v3", "kitchen_micro_open-v3", - "kitchen_close-v3", + # "kitchen_close-v3", ] visual_obs_keys_wt = { @@ -80,10 +82,17 @@ def register_kitchen_envs(): "rgb:left_cam:224x224:2d": 1.0, } for env in env_list: - new_env_name = "visual_" + env - mj_envs.envs.env_variants.register_env_variant( - env, variants={"obs_keys_wt": visual_obs_keys_wt}, variant_id=new_env_name - ) + try: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, + variants={"obs_keys_wt": visual_obs_keys_wt}, + variant_id=new_env_name, + ) + except AssertionError as err: + warnings.warn( + f"Could not register {new_env_name}, the following error was raised: {err}" + ) @set_directory(CURR_DIR) @@ -106,7 +115,48 @@ def register_franka_envs(): "rgb:left_cam:224x224:2d": 1.0, } for env in env_list: - new_env_name = "visual_" + env - mj_envs.envs.env_variants.register_env_variant( - env, variants={"obs_keys_wt": visual_obs_keys_wt}, variant_id=new_env_name - ) + try: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, + variants={"obs_keys_wt": visual_obs_keys_wt}, + variant_id=new_env_name, + ) + except AssertionError as err: + warnings.warn( + f"Could not register {new_env_name}, the following error was raised: {err}" + ) + +@set_directory(CURR_DIR) +def register_hand_envs(): + print("RLHive:> Registering Franka Envs") + env_list = [ + "door-v1", + "hammer-v1", + "pen-v1", + "relocate-v1" + ] + + # Hand Manipulation Suite ====================================================================== + visual_obs_keys_wt = { + "hand_jnt": 1.0, + "rgb:vil_camera:224x224:2d": 1.0, + "rgb:fixed:224x224:2d": 1.0, + } + + for env in env_list: + try: + new_env_name = "visual_" + env + mj_envs.envs.env_variants.register_env_variant( + env, + variants={'obs_keys': + ['hand_jnt', + "rgb:vil_camera:224x224:2d", + "rgb:fixed:224x224:2d"] + }, + variant_id=new_env_name, + ) + except AssertionError as err: + warnings.warn( + f"Could not register {new_env_name}, the following error was raised: {err}" + ) diff --git a/rlhive/rl_envs.py 
b/rlhive/rl_envs.py index c922922f4..0f741d023 100644 --- a/rlhive/rl_envs.py +++ b/rlhive/rl_envs.py @@ -161,7 +161,9 @@ def read_obs(self, observation): pix = pix[None] pixel_list.append(pix) elif key in self._env.obs_keys: - obsvec.append(observations[key]) # ravel helps with images + obsvec.append( + observations[key].flatten() if observations[key].ndim == 0 else observations[key] + ) # ravel helps with images if obsvec: obsvec = np.concatenate(obsvec, 0) if self.from_pixels: @@ -173,7 +175,7 @@ def read_obs(self, observation): def read_info(self, info, tensordict_out): out = {} for key, value in info.items(): - if key in ("obs_dict",): + if key in ("obs_dict", "done", "reward"): continue if isinstance(value, dict): value = make_tensordict(value, batch_size=[]) @@ -181,14 +183,6 @@ def read_info(self, info, tensordict_out): tensordict_out.update(out) return tensordict_out - def _step(self, td): - td = super()._step(td) - return td - - def _reset(self, td=None, **kwargs): - td = super()._reset(td, **kwargs) - return td - def to(self, *args, **kwargs): out = super().to(*args, **kwargs) try: diff --git a/rlhive/sim_algos/helpers/rrl_transform.py b/rlhive/sim_algos/helpers/rrl_transform.py new file mode 100644 index 000000000..b9b46f872 --- /dev/null +++ b/rlhive/sim_algos/helpers/rrl_transform.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Union + +import torch +from tensordict import TensorDict +from torch.hub import load_state_dict_from_url +from torch.nn import Identity + +from torchrl.data.tensor_specs import ( + CompositeSpec, + TensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs.transforms.transforms import ( + CatTensors, + Compose, + FlattenObservation, + ObservationNorm, + Resize, + ToTensorImage, + Transform, + UnsqueezeTransform, +) + +try: + from torchvision import models + + _has_tv = True +except ImportError: + _has_tv = False + + +class _RRLNet(Transform): + + inplace = False + + def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True): + if not _has_tv: + raise ImportError( + "Tried to instantiate RRL without torchvision. Make sure you have " + "torchvision installed in your environment." 
+ ) + if model_name == "resnet18": + self.model_name = "rrl_18" + self.outdim = 512 + convnet = models.resnet18(pretrained=True) + elif model_name == "resnet34": + self.model_name = "rrl_34" + self.outdim = 512 + convnet = models.resnet34(pretrained=True) + elif model_name == "resnet50": + self.model_name = "rrl_50" + self.outdim = 2048 + convnet = models.resnet50(pretrained=True) + else: + raise NotImplementedError( + f"model {model_name} is currently not supported by RRL" + ) + convnet.fc = Identity() + super().__init__(in_keys=in_keys, out_keys=out_keys) + self.convnet = convnet + self.del_keys = del_keys + + def _call(self, tensordict): + tensordict_view = tensordict.view(-1) + super()._call(tensordict_view) + if self.del_keys: + tensordict.exclude(*self.in_keys, inplace=True) + return tensordict + + @torch.no_grad() + def _apply_transform(self, obs: torch.Tensor) -> None: + shape = None + if obs.ndimension() > 4: + shape = obs.shape[:-3] + obs = obs.flatten(0, -4) + out = self.convnet(obs) + if shape is not None: + out = out.view(*shape, *out.shape[1:]) + return out + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if not isinstance(observation_spec, CompositeSpec): + raise ValueError("_RRLNet can only infer CompositeSpec") + + keys = [key for key in observation_spec._specs.keys() if key in self.in_keys] + device = observation_spec[keys[0]].device + dim = observation_spec[keys[0]].shape[:-3] + + observation_spec = CompositeSpec(observation_spec) + if self.del_keys: + for in_key in keys: + del observation_spec[in_key] + + for out_key in self.out_keys: + observation_spec[out_key] = UnboundedContinuousTensorSpec( + shape=torch.Size([*dim, self.outdim]), device=device + ) + + return observation_spec + + # @staticmethod + # def _load_weights(model_name, r3m_instance, dir_prefix): + # if model_name not in ("r3m_50", "r3m_34", "r3m_18"): + # raise ValueError( + # "model_name should be one of 'r3m_50', 'r3m_34' or 'r3m_18'" + # ) + # # url = "https://download.pytorch.org/models/rl/r3m/" + model_name + # url = "https://pytorch.s3.amazonaws.com/models/rl/r3m/" + model_name + ".pt" + # d = load_state_dict_from_url( + # url, + # progress=True, + # map_location=next(r3m_instance.parameters()).device, + # model_dir=dir_prefix, + # ) + # td = TensorDict(d["r3m"], []).unflatten_keys(".") + # td_flatten = td["module"]["convnet"].flatten_keys(".") + # state_dict = td_flatten.to_dict() + # r3m_instance.convnet.load_state_dict(state_dict) + + # def load_weights(self, dir_prefix=None): + # self._load_weights(self.model_name, self, dir_prefix) + + +def _init_first(fun): + def new_fun(self, *args, **kwargs): + if not self.initialized: + self._init() + return fun(self, *args, **kwargs) + + return new_fun + + +class RRLTransform(Compose): + """RRL Transform class. + + RRL provides pre-trained ResNet weights aimed at facilitating visual + embedding for robotic tasks. The models are trained using Ego4d. + + See the paper: + Shah, Rutav, and Vikash Kumar. "RRl: Resnet as representation for reinforcement learning." + arXiv preprint arXiv:2107.03380 (2021). + The RRLTransform is created in a lazy manner: the object will be initialized + only when an attribute (a spec or the forward method) will be queried. 
+ The reason for this is that the :obj:`_init()` method requires some attributes of + the parent environment (if any) to be accessed: by making the class lazy we + can ensure that the following code snippet works as expected: + + Examples: + >>> transform = RRLTransform("resnet50", in_keys=["pixels"]) + >>> env.append_transform(transform) + >>> # the forward method will first call _init which will look at env.observation_spec + >>> env.reset() + + Args: + model_name (str): one of resnet50, resnet34 or resnet18 + in_keys (list of str): list of input keys. If left empty, the + "pixels" key is assumed. + out_keys (list of str, optional): list of output keys. If left empty, + "rrl_vec" is assumed. + size (int, optional): Size of the image to feed to resnet. + Defaults to 244. + stack_images (bool, optional): if False, the images given in the :obj:`in_keys` + argument will be treaded separetely and each will be given a single, + separated entry in the output tensordict. Defaults to :obj:`True`. + download (bool, optional): if True, the weights will be downloaded using + the torch.hub download API (i.e. weights will be cached for future use). + Defaults to False. + download_path (str, optional): path where to download the models. + Default is None (cache path determined by torch.hub utils). + tensor_pixels_keys (list of str, optional): Optionally, one can keep the + original images (as collected from the env) in the output tensordict. + If no value is provided, this won't be collected. + """ + + @classmethod + def __new__(cls, *args, **kwargs): + cls.initialized = False + cls._device = None + cls._dtype = None + return super().__new__(cls) + + def __init__( + self, + model_name: str, + in_keys: List[str], + out_keys: List[str] = None, + size: int = 244, + stack_images: bool = True, + download: bool = False, + download_path: Optional[str] = None, + tensor_pixels_keys: List[str] = None, + ): + super().__init__() + self.in_keys = in_keys if in_keys is not None else ["pixels"] + self.download = download + self.download_path = download_path + self.model_name = model_name + self.out_keys = out_keys + self.size = size + self.stack_images = stack_images + self.tensor_pixels_keys = tensor_pixels_keys + self._init() + + def _init(self): + """Initializer for RRL.""" + self.initialized = True + in_keys = self.in_keys + model_name = self.model_name + out_keys = self.out_keys + size = self.size + stack_images = self.stack_images + tensor_pixels_keys = self.tensor_pixels_keys + + # ToTensor + transforms = [] + if tensor_pixels_keys: + for i in range(len(in_keys)): + transforms.append( + CatTensors( + in_keys=[in_keys[i]], + out_key=tensor_pixels_keys[i], + del_keys=False, + ) + ) + + totensor = ToTensorImage( + unsqueeze=False, + in_keys=in_keys, + ) + transforms.append(totensor) + + # Normalize + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + normalize = ObservationNorm( + in_keys=in_keys, + loc=torch.tensor(mean).view(3, 1, 1), + scale=torch.tensor(std).view(3, 1, 1), + standard_normal=True, + ) + transforms.append(normalize) + + # Resize: note that resize is a no-op if the tensor has the desired size already + resize = Resize(size, size, in_keys=in_keys) + transforms.append(resize) + + # RRL + if out_keys is None: + if stack_images: + out_keys = ["rrl_vec"] + else: + out_keys = [f"rrl_vec_{i}" for i in range(len(in_keys))] + self.out_keys = out_keys + elif stack_images and len(out_keys) != 1: + raise ValueError( + f"out_key must be of length 1 if stack_images is True. 
Got out_keys={out_keys}" + ) + elif not stack_images and len(out_keys) != len(in_keys): + raise ValueError( + "out_key must be of length equal to in_keys if stack_images is False." + ) + + if stack_images and len(in_keys) > 1: + + unsqueeze = UnsqueezeTransform( + in_keys=in_keys, + out_keys=in_keys, + unsqueeze_dim=-4, + ) + transforms.append(unsqueeze) + + cattensors = CatTensors( + in_keys, + out_keys[0], + dim=-4, + ) + network = _RRLNet( + in_keys=out_keys, + out_keys=out_keys, + model_name=model_name, + del_keys=False, + ) + flatten = FlattenObservation(-2, -1, out_keys) + transforms = [*transforms, cattensors, network, flatten] + + else: + network = _RRLNet( + in_keys=in_keys, + out_keys=out_keys, + model_name=model_name, + del_keys=True, + ) + transforms = [*transforms, network] + + for transform in transforms: + self.append(transform) + # if self.download: + # self[-1].load_weights(dir_prefix=self.download_path) + + if self._device is not None: + self.to(self._device) + if self._dtype is not None: + self.to(self._dtype) + + def to(self, dest: Union[DEVICE_TYPING, torch.dtype]): + if isinstance(dest, torch.dtype): + self._dtype = dest + else: + self._device = dest + return super().to(dest) + + @property + def device(self): + return self._device + + @property + def dtype(self): + return self._dtype diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 000000000..f2a8a9153 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,71 @@ +## Installation +``` +git clone --branch=sac_dev https://github.com/facebookresearch/rlhive.git +conda create -n rlhive -y python=3.8 +conda activate rlhive +bash rlhive/scripts/installation.sh +cd rlhive +pip install -e . +``` + +## Testing installation +``` +python -c "import mj_envs" +MUJOCO_GL=egl sim_backend=MUJOCO python -c """ +from rlhive.rl_envs import RoboHiveEnv +env_name = 'visual_franka_slide_random-v3' +base_env = RoboHiveEnv(env_name,) +print(base_env.rollout(3)) + +# check that the env specs are ok +from torchrl.envs.utils import check_env_specs +check_env_specs(base_env) +""" +``` + +## Launching experiments +[NOTE] Set ulimit for your shell (default 1024): `ulimit -n 4096` +Set your slurm configs especially `partition` and `hydra.run.dir` +Slurm files are located at `sac_mujoco/config/hydra/launcher/slurm.yaml` and `sac_mujoco/config/hydra/output/slurm.yaml` +``` +cd scripts/sac_mujoco +sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m hydra/launcher=slurm hydra/output=slurm +``` + +To run a small experiment for testing, run the following command: +``` +cd scripts/sac_mujoco +sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m total_frames=2000 init_random_frames=25 buffer_size=2000 hydra/launcher=slurm hydra/output=slurm +``` + +## Parameter Sweep +1. R3M and RRL experiments: `visual_transform=r3m,rrl` +2. Multiple seeds: `seed=42,43,44` +3. 
List of environments: + ``` +task=visual_franka_slide_random-v3,\ + visual_franka_slide_close-v3,\ + visual_franka_slide_open-v3,\ + visual_franka_micro_random-v3,\ + visual_franka_micro_close-v3,\ + visual_franka_micro_open-v3,\ + visual_kitchen_knob1_off-v3,\ + visual_kitchen_knob1_on-v3,\ + visual_kitchen_knob2_off-v3,\ + visual_kitchen_knob2_on-v3,\ + visual_kitchen_knob3_off-v3,\ + visual_kitchen_knob3_on-v3,\ + visual_kitchen_knob4_off-v3,\ + visual_kitchen_knob4_on-v3,\ + visual_kitchen_light_off-v3,\ + visual_kitchen_light_on-v3,\ + visual_kitchen_sdoor_close-v3,\ + visual_kitchen_sdoor_open-v3,\ + visual_kitchen_ldoor_close-v3,\ + visual_kitchen_ldoor_open-v3,\ + visual_kitchen_rdoor_close-v3,\ + visual_kitchen_rdoor_open-v3,\ + visual_kitchen_micro_close-v3,\ + visual_kitchen_micro_open-v3,\ + visual_kitchen_close-v3 + ``` diff --git a/scripts/installation.sh b/scripts/installation.sh new file mode 100644 index 000000000..cd232abcc --- /dev/null +++ b/scripts/installation.sh @@ -0,0 +1,14 @@ +export MJENV_LIB_PATH="mj_envs" + +python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116 + +here=$(pwd) +git clone -c submodule.mj_envs/sims/neuromuscular_sim.update=none --branch add_all_xmls --recursive https://github.com/vmoens/mj_envs.git $MJENV_LIB_PATH +cd $MJENV_LIB_PATH +python3 -mpip install . # one can also install it locally with the -e flag +cd $here + +python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly) +python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly) +pip install wandb moviepy +pip install hydra-submitit-launcher --upgrade diff --git a/scripts/redq/config.yaml b/scripts/redq/config.yaml new file mode 100644 index 000000000..37e0ce8d6 --- /dev/null +++ b/scripts/redq/config.yaml @@ -0,0 +1,40 @@ +# Environment +env_name: visual_franka_slide_random-v3 +env_task: "" +env_library: gym +async_collection: 1 +record_video: 0 +normalize_rewards_online: 1 +normalize_rewards_online_scale: 5 +frame_skip: 1 +reward_scaling: 5.0 +frames_per_batch: 1024 +optim_steps_per_batch: 1024 +batch_size: 256 +total_frames: 1000000 +prb: 1 +lr: 3e-4 +ou_exploration: 1 +multi_step: 1 +init_random_frames: 25000 +activation: elu +gSDE: 0 +# Internal assumption by make_redq, hard codes key values if from_pixels is True +from_pixels: 0 +#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1] +#collector_devices: [cpu,cpu] +collector_devices: [cuda:0] +env_per_collector: 1 +num_workers: 2 +lr_scheduler: "" +value_network_update_interval: 200 +record_interval: 10 +max_frames_per_traj: -1 +weight_decay: 0.0 +annealing_frames: 1000000 +init_env_steps: 10000 +record_frames: 10000 +loss_function: smooth_l1 +batch_transform: 1 +buffer_prefetch: 64 +norm_stats: 1 diff --git a/scripts/redq/redq.py b/scripts/redq/redq.py new file mode 100644 index 000000000..df9fc8783 --- /dev/null +++ b/scripts/redq/redq.py @@ -0,0 +1,287 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
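+ +# REDQ training entry point; hyperparameters are read from ./config.yaml above via hydra. A launch sketch, assuming the same sim_backend / MUJOCO_GL variables used for the SAC scripts in this PR apply here as well: +# +#     cd scripts/redq +#     sim_backend=MUJOCO MUJOCO_GL=egl python redq.py env_name=visual_franka_slide_random-v3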
+ +import dataclasses +import uuid +from datetime import datetime + +import hydra +import torch.cuda +from hydra.core.config_store import ConfigStore +from rlhive.rl_envs import RoboHiveEnv +from torchrl.envs import EnvCreator, ParallelEnv +from torchrl.envs.transforms import ( + Compose, + FlattenObservation, + RewardScaling, + TransformedEnv, +) + +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + ObservationNorm, + ParallelEnv, + R3MTransform, + SelectTransform, + TransformedEnv, +) +from torchrl.envs.utils import set_exploration_mode +from torchrl.modules import OrnsteinUhlenbeckProcessWrapper +from torchrl.record import VideoRecorder +from torchrl.trainers.helpers.collectors import ( + make_collector_offpolicy, + OffPolicyCollectorConfig, +) +from torchrl.trainers.helpers.envs import ( + correct_for_frame_skip, + EnvConfig, + initialize_observation_norm_transforms, + parallel_env_constructor, + retrieve_observation_norms_state_dict, + # transformed_env_constructor, +) +from torchrl.trainers.helpers.logger import LoggerConfig +from torchrl.trainers.helpers.losses import LossConfig, make_redq_loss +from torchrl.trainers.helpers.models import make_redq_model, REDQModelConfig +from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig +from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig +from torchrl.trainers.loggers.utils import generate_exp_name, get_logger + + +def make_env( + task, + reward_scaling, + device, + obs_norm_state_dict=None, + action_dim_gsde=None, + state_dim_gsde=None, +): + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env(env=base_env, reward_scaling=reward_scaling) + + if not obs_norm_state_dict is None: + obs_norm = ObservationNorm( + **obs_norm_state_dict, in_keys=["observation_vector"] + ) + env.append_transform(obs_norm) + + if not action_dim_gsde is None: + env.append_transform( + gSDENoise(action_dim=action_dim_gsde, state_dim=state_dim_gsde) + ) + return env + + +def make_transformed_env( + env, + reward_scaling=5.0, + stats=None, +): + """ + Apply transforms to the env (such as reward scaling and state normalization) + """ + env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=["r3m_vec"]), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) + selected_keys = ["r3m_vec", "observation"] + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + + # we normalize the states + if stats is None: + _stats = {"loc": 0.0, "scale": 1.0} + else: + _stats = stats + env.append_transform( + ObservationNorm(**_stats, in_keys=[out_key], standard_normal=True) + ) + env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) + return env + + +config_fields = [ + (config_field.name, config_field.type, config_field) + for config_cls in ( + TrainerConfig, + OffPolicyCollectorConfig, + EnvConfig, + LossConfig, + REDQModelConfig, + LoggerConfig, + ReplayArgsConfig, + ) + for config_field in dataclasses.fields(config_cls) +] + +Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) +cs = ConfigStore.instance() +cs.store(name="config", node=Config) + +DEFAULT_REWARD_SCALING = { + "Hopper-v1": 5, + "Walker2d-v1": 
5, + "HalfCheetah-v1": 5, + "cheetah": 5, + "Ant-v2": 5, + "Humanoid-v2": 20, + "humanoid": 100, +} + + +@hydra.main(version_base=None, config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda:0") + ) + + exp_name = generate_exp_name("REDQ", cfg.exp_name) + logger = get_logger( + logger_type=cfg.logger, logger_name="redq_logging", experiment_name=exp_name + ) + video_tag = exp_name if cfg.record_video else "" + + key, init_env_steps, stats = None, None, None + if not cfg.vecnorm and cfg.norm_stats: + if not hasattr(cfg, "init_env_steps"): + raise AttributeError("init_env_steps missing from arguments.") + key = ("next", "observation_vector") + init_env_steps = cfg.init_env_steps + stats = {"loc": None, "scale": None} + elif cfg.from_pixels: + stats = {"loc": 0.5, "scale": 0.5} + + proof_env = make_env( + task=cfg.env_name, + reward_scaling=cfg.reward_scaling, + device=device, + ) + initialize_observation_norm_transforms( + proof_environment=proof_env, num_iter=init_env_steps, key=key + ) + _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0] + + print(proof_env) + model = make_redq_model( + proof_env, + cfg=cfg, + device=device, + in_keys=["observation_vector"], + ) + loss_module, target_net_updater = make_redq_loss(model, cfg) + + actor_model_explore = model[0] + if cfg.ou_exploration: + if cfg.gSDE: + raise RuntimeError("gSDE and ou_exploration are incompatible") + actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + actor_model_explore, + annealing_num_steps=cfg.annealing_frames, + sigma=cfg.ou_sigma, + theta=cfg.ou_theta, + ).to(device) + if device == torch.device("cpu"): + # mostly for debugging + actor_model_explore.share_memory() + + if cfg.gSDE: + with torch.no_grad(), set_exploration_mode("random"): + # get dimensions to build the parallel env + proof_td = actor_model_explore(proof_env.reset().to(device)) + action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:] + del proof_td + else: + action_dim_gsde, state_dim_gsde = None, None + + proof_env.close() + create_env_fn = make_env( ## Pass EnvBase instead of the create_env_fn + task=cfg.env_name, + reward_scaling=cfg.reward_scaling, + device=device, + obs_norm_state_dict=obs_norm_state_dict, + action_dim_gsde=action_dim_gsde, + state_dim_gsde=state_dim_gsde, + ) + + collector = make_collector_offpolicy( + make_env=create_env_fn, + actor_model_explore=actor_model_explore, + cfg=cfg, + # make_env_kwargs=[ + # {"device": device} if device >= 0 else {} + # for device in args.env_rendering_devices + # ], + ) + + replay_buffer = make_replay_buffer(device, cfg) + + # recorder = transformed_env_constructor( + # cfg, + # video_tag=video_tag, + # norm_obs_only=True, + # obs_norm_state_dict=obs_norm_state_dict, + # logger=logger, + # use_env_creator=False, + # )() + recorder = make_env( + task=cfg.env_name, + reward_scaling=cfg.reward_scaling, + device=device, + obs_norm_state_dict=obs_norm_state_dict, + action_dim_gsde=action_dim_gsde, + state_dim_gsde=state_dim_gsde, + ) + + # remove video recorder from recorder to have matching state_dict keys + if cfg.record_video: + recorder_rm = TransformedEnv(recorder.base_env) + for transform in recorder.transform: + if not isinstance(transform, VideoRecorder): + recorder_rm.append_transform(transform.clone()) + else: + recorder_rm = recorder + + if isinstance(create_env_fn, ParallelEnv): + 
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"]) + create_env_fn.close() + elif isinstance(create_env_fn, EnvCreator): + recorder_rm.load_state_dict(create_env_fn().state_dict()) + else: + recorder_rm.load_state_dict(create_env_fn.state_dict()) + + # reset reward scaling + for t in recorder.transform: + if isinstance(t, RewardScaling): + t.scale.fill_(1.0) + t.loc.fill_(0.0) + + trainer = make_trainer( + collector, + loss_module, + recorder, + target_net_updater, + actor_model_explore, + replay_buffer, + logger, + cfg, + ) + + final_seed = collector.set_seed(cfg.seed) + print(f"init seed: {cfg.seed}, final seed: {final_seed}") + + trainer.train() + return (logger.log_dir, trainer._log_dict) + + +if __name__ == "__main__": + main() diff --git a/scripts/sac_mujoco/config/group/group1.yaml b/scripts/sac_mujoco/config/group/group1.yaml new file mode 100644 index 000000000..886d4fc95 --- /dev/null +++ b/scripts/sac_mujoco/config/group/group1.yaml @@ -0,0 +1,4 @@ +# @package _group_ + grp1a: 11 + grp1b: aaa + gra1c: $group_seed{group.seed}_exp_seed{exp.seed} diff --git a/scripts/sac_mujoco/config/group/group2.yaml b/scripts/sac_mujoco/config/group/group2.yaml new file mode 100644 index 000000000..8b95ac771 --- /dev/null +++ b/scripts/sac_mujoco/config/group/group2.yaml @@ -0,0 +1,4 @@ +# @package _group_ + grp2a: 22 + grp2b: bbb + gra2c: $group_seed{group.seed}_exp_seed{exp.seed} diff --git a/scripts/sac_mujoco/config/hydra/launcher/local.yaml b/scripts/sac_mujoco/config/hydra/launcher/local.yaml new file mode 100644 index 000000000..30a563c70 --- /dev/null +++ b/scripts/sac_mujoco/config/hydra/launcher/local.yaml @@ -0,0 +1,11 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: 12 + gpus_per_node: 1 + tasks_per_node: 1 + timeout_min: 4320 + mem_gb: 64 + name: ${hydra.job.name} + _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher + submitit_folder: ${hydra.sweep.dir}/.submitit/%j diff --git a/scripts/sac_mujoco/config/hydra/launcher/slurm.yaml b/scripts/sac_mujoco/config/hydra/launcher/slurm.yaml new file mode 100644 index 000000000..e9b58047c --- /dev/null +++ b/scripts/sac_mujoco/config/hydra/launcher/slurm.yaml @@ -0,0 +1,14 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: 16 + gpus_per_node: 1 + tasks_per_node: 1 + timeout_min: 4320 + mem_gb: 64 + name: ${hydra.job.name} + # partition: devlab + # array_parallelism: 256 + _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + submitit_folder: ${hydra.sweep.dir}/.submitit/%j + partition: dgx diff --git a/scripts/sac_mujoco/config/hydra/output/local.yaml b/scripts/sac_mujoco/config/hydra/output/local.yaml new file mode 100644 index 000000000..d3c95076e --- /dev/null +++ b/scripts/sac_mujoco/config/hydra/output/local.yaml @@ -0,0 +1,8 @@ +# @package _global_ +hydra: + run: + dir: ./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num}_${hydra.job.override_dirname} + sweep: + dir: ./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num}_${hydra.job.override_dirname} diff --git a/scripts/sac_mujoco/config/hydra/output/slurm.yaml b/scripts/sac_mujoco/config/hydra/output/slurm.yaml new file mode 100644 index 000000000..8b7afb41a --- /dev/null +++ b/scripts/sac_mujoco/config/hydra/output/slurm.yaml @@ -0,0 +1,8 @@ +# @package _global_ +hydra: + run: + dir: /scratch/cluster/rutavms/robohive/outputs_sac_robohive/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: 
${hydra.job.num}_${hydra.job.override_dirname} + sweep: + dir: /scratch/cluster/rutavms/robohive/outputs_sac_robohive/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num}_${hydra.job.override_dirname} diff --git a/scripts/sac_mujoco/config/sac.yaml b/scripts/sac_mujoco/config/sac.yaml new file mode 100644 index 000000000..39a201706 --- /dev/null +++ b/scripts/sac_mujoco/config/sac.yaml @@ -0,0 +1,42 @@ +default: + - override hydra/output: local + - override hydra/launcher: local +from_pixels: True + +# Logger +exp_name: ${task}_sac_${visual_transform} +visual_transform: r3m +record_interval: 1 +device: "cuda:0" + +# Environment +task: visual_franka_slide_random-v3 +frame_skip: 1 +reward_scaling: 5.0 +init_env_steps: 1000 +seed: 42 +eval_traj: 25 + +# Collector +env_per_collector: 1 +max_frames_per_traj: -1 +total_frames: 1000000 +init_random_frames: 25000 +frames_per_batch: 1000 + +# Replay Buffer +prb: 0 +buffer_size: 100000 +buffer_scratch_dir: /tmp/ + +# Optimization +gamma: 0.99 +batch_size: 256 +lr: 3.0e-4 +weight_decay: 0.0 +target_update_polyak: 0.995 +utd_ratio: 1 + +hydra: + job: + name: sac_${task}_${seed} diff --git a/scripts/sac_mujoco/sac.py b/scripts/sac_mujoco/sac.py new file mode 100644 index 000000000..aa14fb800 --- /dev/null +++ b/scripts/sac_mujoco/sac.py @@ -0,0 +1,449 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import gc +import os +from typing import Optional + +import hydra + +import numpy as np +import torch +import torch.cuda +import tqdm +import wandb +import yaml +from omegaconf import DictConfig, OmegaConf, open_dict +from rlhive.rl_envs import RoboHiveEnv +from rlhive.sim_algos.helpers.rrl_transform import RRLTransform + +#from torchrl.objectives import SACLoss +from sac_loss import SACLoss + +from torch import nn, optim +from torchrl.collectors import MultiaSyncDataCollector +from torchrl.collectors.collectors import RandomPolicy +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer + +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + ObservationNorm, + ParallelEnv, +) +from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.transforms import ( + Compose, + FlattenObservation, + RewardScaling, + TransformedEnv, +) +from torchrl.envs import ParallelEnv, R3MTransform, SelectTransform, TransformedEnv +from torchrl.envs.utils import set_exploration_mode +from torchrl.modules import MLP, NormalParamWrapper, ProbabilisticActor, SafeModule +from torchrl.modules.distributions import TanhNormal + +from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + +from torchrl.objectives import SoftUpdate +from torchrl.trainers import Recorder + +os.environ["WANDB_MODE"] = "offline" ## offline sync. 
TODO: Remove this behavior + +def make_env(task, visual_transform, reward_scaling, device, from_pixels): + assert visual_transform in ("rrl", "r3m") + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env( + env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform, from_pixels=from_pixels + ) + print(env) + + return env + + +def make_transformed_env( + env, + from_pixels, + reward_scaling=5.0, + visual_transform="r3m", + stats=None, +): + """ + Apply transforms to the env (such as reward scaling and state normalization) + """ + if from_pixels: + env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) + if visual_transform == 'rrl': + vec_keys = ["rrl_vec"] + selected_keys = ["observation", "rrl_vec"] + env.append_transform(Compose(RRLTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == 'r3m': + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform(Compose(R3MTransform('resnet50', in_keys=["pixels"], download=True), FlattenObservation(-2, -1, in_keys=vec_keys))) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + else: + raise NotImplementedError + else: + env = TransformedEnv(env, SelectTransform("solved", "observation")) + selected_keys = ["observation"] + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + + # we normalize the states + if stats is None: + _stats = {"loc": 0.0, "scale": 1.0} + else: + _stats = stats + env.append_transform( + ObservationNorm(**_stats, in_keys=[out_key], standard_normal=True) + ) + env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) + return env + + +def make_recorder( + task: str, + frame_skip: int, + record_interval: int, + actor_model_explore: object, + eval_traj: int, + env_configs: dict, +): + test_env = make_env(task=task, **env_configs) + recorder_obj = Recorder( + record_frames=eval_traj * test_env.horizon, + frame_skip=frame_skip, + policy_exploration=actor_model_explore, + recorder=test_env, + exploration_mode="mean", + record_interval=record_interval, + log_keys=["reward", "solved"], + out_keys={"reward": "r_evaluation", "solved": "success"}, + ) + return recorder_obj + + +def make_replay_buffer( + prb: bool, + buffer_size: int, + buffer_scratch_dir: str, + device: torch.device, + make_replay_buffer: int = 3, +): + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=make_replay_buffer, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + device=device, + ), + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=make_replay_buffer, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=buffer_scratch_dir, + device=device, + ), + ) + return replay_buffer + + +def evaluate_success(env_success_fn, td_record: dict, eval_traj: int): + td_record["success"] = td_record["success"].reshape((eval_traj, -1)) + paths = [] + for traj, solved_traj in zip(range(eval_traj), td_record["success"]): + path = {"env_infos": {"solved": solved_traj.data.cpu().numpy()}} + paths.append(path) + success_percentage = env_success_fn(paths) + return 
success_percentage + + +@hydra.main(config_name="sac.yaml", config_path="config") +def main(args: DictConfig): + device = ( + torch.device("cuda:0") + if torch.cuda.is_available() + and torch.cuda.device_count() > 0 + and args.device == "cuda:0" + else torch.device("cpu") + ) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + # Create Environment + env_configs = { + "reward_scaling": args.reward_scaling, + "visual_transform": args.visual_transform, + "device": args.device, + "from_pixels": args.from_pixels, + } + train_env = make_env(task=args.task, **env_configs) + + # Create Agent + + # Define Actor Network + in_keys = ["observation_vector"] + action_spec = train_env.action_spec + actor_net_kwargs = { + "num_cells": [256, 256], + "out_features": 2 * action_spec.shape[-1], + "activation_class": nn.ReLU, + } + + actor_net = MLP(**actor_net_kwargs) + + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": False, + } + actor_net = NormalParamWrapper( + actor_net, + scale_mapping=f"biased_softplus_{1.0}", + scale_lb=0.1, + ) + in_keys_actor = in_keys + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "loc", + "scale", + ], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": nn.ReLU, + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # add forward pass for initialization with proof env + proof_env = make_env(task=args.task, **env_configs) + # init nets + with torch.no_grad(), set_exploration_mode("random"): + td = proof_env.reset() + td = td.to(device) + for net in model: + net(td) + del td + proof_env.close() + + actor_model_explore = model[0] + + # Create SAC loss + loss_module = SACLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + gamma=args.gamma, + loss_function="smooth_l1", + ) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, args.target_update_polyak) + + # Make Off-Policy Collector + + collector = MultiaSyncDataCollector( + create_env_fn=[train_env], + policy=actor_model_explore, + total_frames=args.total_frames, + max_frames_per_traj=args.frames_per_batch, + frames_per_batch=args.env_per_collector * args.frames_per_batch, + init_random_frames=args.init_random_frames, + reset_at_each_iter=False, + postproc=None, + split_trajs=True, + devices=[device], # device for execution + passing_devices=[device], # device where data will be stored and passed + seed=None, + pin_memory=False, + update_at_each_batch=False, + exploration_mode="random", + ) + collector.set_seed(args.seed) + + # Make Replay Buffer + replay_buffer = make_replay_buffer( + prb=args.prb, + buffer_size=args.buffer_size, + buffer_scratch_dir=args.buffer_scratch_dir, + device=device, + ) + + # Trajectory recorder for evaluation + recorder = make_recorder( + task=args.task, + frame_skip=args.frame_skip, + record_interval=args.record_interval, + actor_model_explore=actor_model_explore, + eval_traj=args.eval_traj, + env_configs=env_configs, + ) + + # Optimizers + params = list(loss_module.parameters()) + 
list([loss_module.log_alpha]) + optimizer_actor = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) + + rewards = [] + rewards_eval = [] + + # Main loop + target_net_updater.init_() + + collected_frames = 0 + episodes = 0 + pbar = tqdm.tqdm(total=args.total_frames) + r0 = None + loss = None + + with wandb.init(project="SAC_TorchRL", name=args.exp_name, config=args): + for i, tensordict in enumerate(collector): + + # update weights of the inference policy + collector.update_policy_weights_() + + if r0 is None: + r0 = tensordict["reward"].sum(-1).mean().item() + pbar.update(tensordict.numel()) + + # extend the replay buffer with the new data + if "mask" in tensordict.keys(): + # if multi-step, a mask is present to help filter padded values + current_frames = tensordict["mask"].sum() + tensordict = tensordict[tensordict.get("mask").squeeze(-1)] + else: + tensordict = tensordict.view(-1) + current_frames = tensordict.numel() + collected_frames += current_frames + episodes += args.env_per_collector + replay_buffer.extend(tensordict.cpu()) + + # optimization steps + if collected_frames >= args.init_random_frames: + ( + total_losses, + actor_losses, + q_losses, + alpha_losses, + alphas, + entropies, + ) = ([], [], [], [], [], []) + for _ in range( + args.env_per_collector * args.frames_per_batch * args.utd_ratio + ): + # sample from replay buffer + sampled_tensordict = replay_buffer.sample(args.batch_size).clone() + + loss_td = loss_module(sampled_tensordict) + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_qvalue"] + alpha_loss = loss_td["loss_alpha"] + + loss = actor_loss + q_loss + alpha_loss + optimizer_actor.zero_grad() + loss.backward() + optimizer_actor.step() + + # update qnet_target params + target_net_updater.step() + + # update priority + if args.prb: + replay_buffer.update_priority(sampled_tensordict) + + total_losses.append(loss.item()) + actor_losses.append(actor_loss.item()) + q_losses.append(q_loss.item()) + alpha_losses.append(alpha_loss.item()) + alphas.append(loss_td["alpha"].item()) + entropies.append(loss_td["entropy"].item()) + + rewards.append( + (i, tensordict["reward"].sum().item() / args.env_per_collector) + ) + wandb.log( + { + "train_reward": rewards[-1][1], + "collected_frames": collected_frames, + "episodes": episodes, + } + ) + if loss is not None: + wandb.log( + { + "total_loss": np.mean(total_losses), + "actor_loss": np.mean(actor_losses), + "q_loss": np.mean(q_losses), + "alpha_loss": np.mean(alpha_losses), + "alpha": np.mean(alphas), + "entropy": np.mean(entropies), + } + ) + td_record = recorder(None) + success_percentage = evaluate_success( + env_success_fn=train_env.evaluate_success, + td_record=td_record, + eval_traj=args.eval_traj, + ) + if td_record is not None: + rewards_eval.append( + ( + i, + td_record["total_r_evaluation"] + / 1, # divide by number of eval worker + ) + ) + wandb.log({"test_reward": rewards_eval[-1][1]}) + wandb.log({"success": success_percentage}) + if len(rewards_eval): + pbar.set_description( + f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f}), test reward: {rewards_eval[-1][1]: 4.4f}" + ) + del tensordict + gc.collect() + + collector.shutdown() + + +if __name__ == "__main__": + main() diff --git a/scripts/sac_mujoco/sac_loss.py b/scripts/sac_mujoco/sac_loss.py new file mode 100644 index 000000000..07d5ab1c4 --- /dev/null +++ b/scripts/sac_mujoco/sac_loss.py @@ -0,0 +1,474 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from numbers import Number
+from typing import Union
+
+import numpy as np
+import torch
+
+from tensordict.nn import TensorDictSequential
+from tensordict.tensordict import TensorDict, TensorDictBase
+from torch import Tensor
+
+from torchrl.envs.utils import set_exploration_mode, step_mdp
+from torchrl.modules import SafeModule
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import (
+    distance_loss,
+    next_state_value as get_next_state_value,
+)
+
+try:
+    from functorch import vmap
+
+    FUNCTORCH_ERR = ""
+    _has_functorch = True
+except ImportError as err:
+    FUNCTORCH_ERR = str(err)
+    _has_functorch = False
+
+
+class SACLoss(LossModule):
+    """SAC Loss module.
+
+    Args:
+        actor_network (SafeModule): the actor to be trained.
+        qvalue_network (SafeModule): a single Q-value network that will be replicated as many times as needed.
+        num_qvalue_nets (int, optional): number of Q-value networks to be trained. Default is 2.
+        gamma (Number, optional): gamma decay factor. Default is 0.99.
+        priority_key (str, optional): key where to write the priority value for prioritized replay buffers.
+            Default is `"td_error"`.
+        loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`,
+            `"l2"` or `"l1"`. Default is `"smooth_l1"`.
+        alpha_init (float, optional): initial entropy multiplier.
+            Default is 1.0.
+        min_alpha (float, optional): min value of alpha.
+            Default is 0.1.
+        max_alpha (float, optional): max value of alpha.
+            Default is 10.0.
+        fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is :obj:`False`.
+        target_entropy (Union[str, Number], optional): target entropy for the stochastic policy. Default is "auto".
+        delay_qvalue (bool, optional): whether to separate the target Q-value networks from the Q-value networks used
+            for data collection. Default is :obj:`True`.
+        gSDE (bool, optional): knowing whether gSDE is used is necessary to create the random noise variables.
+            Default is False.
+    """
+
+    delay_actor: bool = False
+    _explicit: bool = True
+
+    def __init__(
+        self,
+        actor_network: SafeModule,
+        qvalue_network: SafeModule,
+        num_qvalue_nets: int = 2,
+        gamma: Number = 0.99,
+        priority_key: str = "td_error",
+        loss_function: str = "smooth_l1",
+        alpha_init: float = 1.0,
+        min_alpha: float = 0.1,
+        max_alpha: float = 10.0,
+        fixed_alpha: bool = False,
+        target_entropy: Union[str, Number] = "auto",
+        delay_qvalue: bool = True,
+        gSDE: bool = False,
+    ):
+        if not _has_functorch:
+            raise ImportError(
+                f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}"
+            )
+
+        super().__init__()
+        self.convert_to_functional(
+            actor_network,
+            "actor_network",
+            create_target_params=self.delay_actor,
+            funs_to_decorate=["forward", "get_dist_params"],
+        )
+
+        # let's make sure that actor_network has `return_log_prob` set to True
+        self.actor_network.return_log_prob = True
+
+        self.delay_qvalue = delay_qvalue
+        self.convert_to_functional(
+            qvalue_network,
+            "qvalue_network",
+            num_qvalue_nets,
+            create_target_params=self.delay_qvalue,
+            compare_against=list(actor_network.parameters()),
+        )
+        self.num_qvalue_nets = num_qvalue_nets
+        self.register_buffer("gamma", torch.tensor(gamma))
+        self.priority_key = priority_key
+        self.loss_function = loss_function
+
+        try:
+            device = next(self.parameters()).device
+        except AttributeError:
+            device = torch.device("cpu")
+
+        self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
+        self.register_buffer(
+            "min_log_alpha", torch.tensor(min_alpha, device=device).log()
+        )
+        self.register_buffer(
+            "max_log_alpha", torch.tensor(max_alpha, device=device).log()
+        )
+        self.fixed_alpha = fixed_alpha
+        if fixed_alpha:
+            self.register_buffer(
+                "log_alpha", torch.tensor(math.log(alpha_init), device=device)
+            )
+        else:
+            self.register_parameter(
+                "log_alpha",
+                torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
+            )
+
+        if target_entropy == "auto":
+            if actor_network.spec["action"] is None:
+                raise RuntimeError(
+                    "Cannot infer the dimensionality of the action. Consider providing "
+                    "the target entropy explicitly or provide the spec of the "
+                    "action tensor in the actor network."
+ ) + target_entropy = -float(np.prod(actor_network.spec["action"].shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + self.gSDE = gSDE + + @property + def alpha(self): + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if self._explicit: + # slow but explicit version + return self._forward_explicit(tensordict) + else: + return self._forward_vectorized(tensordict) + + def _loss_alpha(self, log_pi: Tensor) -> Tensor: + if torch.is_grad_enabled() and not log_pi.requires_grad: + raise RuntimeError( + "expected log_pi to require gradient for the alpha loss)" + ) + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss + + def _forward_vectorized(self, tensordict: TensorDictBase) -> TensorDictBase: + obs_keys = self.actor_network.in_keys + tensordict_select = tensordict.select( + "reward", "done", "next", *obs_keys, "action" + ) + + actor_params = torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) + + tensordict_actor_grad = tensordict_select.select( + *obs_keys + ) # to avoid overwriting keys + next_td_actor = step_mdp(tensordict_select).select( + *self.actor_network.in_keys + ) # next_observation -> + tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) + tensordict_actor = tensordict_actor.contiguous() + + with set_exploration_mode("random"): + if self.gSDE: + tensordict_actor.set( + "_eps_gSDE", + torch.zeros(tensordict_actor.shape, device=tensordict_actor.device), + ) + # vmap doesn't support sampling, so we take it out from the vmap + td_params = vmap(self.actor_network.get_dist_params)( + tensordict_actor, + actor_params, + ) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict_actor[sample_key] = self._rsample(tensordict_actor_dist) + tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict_actor[sample_key] + ) + + # repeat tensordict_actor to match the qvalue size + _actor_loss_td = ( + tensordict_actor[0] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) + ) # for actor loss + _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, + *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, + ) # for qvalue loss + _next_val_td = ( + tensordict_actor[1] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[1].batch_size) + ) # for next value estimation + tensordict_qval = torch.cat( + [ + _actor_loss_td, + _next_val_td, + _qval_td, + ], + 0, + ) + + # cat params + q_params_detach = self.qvalue_network_params.detach() + qvalue_params = torch.cat( + [ + q_params_detach, + self.target_qvalue_network_params, + self.qvalue_network_params, + ], + 0, + ) + tensordict_qval = vmap(self.qvalue_network)( + tensordict_qval, + qvalue_params, + ) + + state_action_value = 
tensordict_qval.get("state_action_value").squeeze(-1) + ( + state_action_value_actor, + next_state_action_value_qvalue, + state_action_value_qvalue, + ) = state_action_value.split( + [self.num_qvalue_nets, self.num_qvalue_nets, self.num_qvalue_nets], + dim=0, + ) + sample_log_prob = tensordict_actor.get("sample_log_prob").squeeze(-1) + ( + action_log_prob_actor, + next_action_log_prob_qvalue, + ) = sample_log_prob.unbind(0) + + # E[alpha * log_pi(a) - Q(s, a)] where a is reparameterized + loss_actor = -( + state_action_value_actor.min(0)[0] - self.alpha * action_log_prob_actor + ).mean() + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + target_value = get_next_state_value( + tensordict, + gamma=self.gamma, + pred_next_val=next_state_value, + ) + pred_val = state_action_value_qvalue + td_error = (pred_val - target_value).pow(2) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + + tensordict.set("td_error", td_error.detach().max(0)[0]) + + loss_alpha = self._loss_alpha(sample_log_prob) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + "state_action_value_actor": state_action_value_actor.mean().detach(), + "action_log_prob_actor": action_log_prob_actor.mean().detach(), + "next.state_value": next_state_value.mean().detach(), + "target_value": target_value.mean().detach(), + }, + [], + ) + + return td_out + + def _forward_explicit(self, tensordict: TensorDictBase) -> TensorDictBase: + loss_actor, sample_log_prob = self._loss_actor_explicit(tensordict.clone(False)) + loss_qval, td_error = self._loss_qval_explicit(tensordict.clone(False)) + tensordict.set("td_error", td_error.detach().max(0)[0]) + loss_alpha = self._loss_alpha(sample_log_prob) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha.detach(), + "entropy": -sample_log_prob.mean().detach(), + # "state_action_value_actor": state_action_value_actor.mean().detach(), + # "action_log_prob_actor": action_log_prob_actor.mean().detach(), + # "next.state_value": next_state_value.mean().detach(), + # "target_value": target_value.mean().detach(), + }, + [], + ) + return td_out + + def _rsample(self, dist, ): + # separated only for the purpose of making the sampling + # deterministic to compare methods + return dist.rsample() + + + def _sample_reparam(self, tensordict, params): + """Given a policy param batch and input data in a tensordict, writes a reparam sample and log-prob key.""" + with set_exploration_mode("random"): + if self.gSDE: + raise NotImplementedError + # vmap doesn't support sampling, so we take it out from the vmap + td_params = self.actor_network.get_dist_params(tensordict, params,) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = self.actor_network[-1].out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.out_keys[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict[sample_key] = 
self._rsample(tensordict_actor_dist) + tensordict["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict[sample_key] + ) + return tensordict + + def _loss_actor_explicit(self, tensordict): + tensordict_actor = tensordict.clone(False) + actor_params = self.actor_network_params + tensordict_actor = self._sample_reparam(tensordict_actor, actor_params) + action_log_prob_actor = tensordict_actor["sample_log_prob"] + + tensordict_qval = ( + tensordict_actor + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor.batch_size) + ) # for actor loss + qvalue_params = self.qvalue_network_params.detach() + tensordict_qval = vmap(self.qvalue_network)(tensordict_qval, qvalue_params,) + state_action_value_actor = tensordict_qval.get("state_action_value").squeeze(-1) + state_action_value_actor = state_action_value_actor.min(0)[0] + + # E[alpha * log_pi(a) - Q(s, a)] where a is reparameterized + loss_actor = (self.alpha * action_log_prob_actor - state_action_value_actor).mean() + + return loss_actor, action_log_prob_actor + + def _loss_qval_explicit(self, tensordict): + next_tensordict = step_mdp(tensordict) + next_tensordict = self._sample_reparam(next_tensordict, self.target_actor_network_params) + next_action_log_prob_qvalue = next_tensordict["sample_log_prob"] + next_state_action_value_qvalue = vmap(self.qvalue_network, (None, 0))( + next_tensordict, + self.target_qvalue_network_params, + )["state_action_value"].squeeze(-1) + + next_state_value = ( + next_state_action_value_qvalue.min(0)[0] + - self.alpha * next_action_log_prob_qvalue + ) + + pred_val = vmap(self.qvalue_network, (None, 0))( + tensordict, + self.qvalue_network_params, + )["state_action_value"].squeeze(-1) + + target_value = get_next_state_value( + tensordict, + gamma=self.gamma, + pred_next_val=next_state_value, + ) + + # 1/2 * E[Q(s,a) - (r + gamma * (Q(s,a)-alpha log pi(s, a))) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum() + * 0.5 + ) + td_error = (pred_val - target_value).pow(2) + return loss_qval, td_error + +if __name__ == "__main__": + # Tests the vectorized version of SAC-v2 against plain implementation + from torchrl.modules import ProbabilisticActor, ValueOperator + from torchrl.data import BoundedTensorSpec + from torch import nn + from tensordict.nn import TensorDictModule + from torchrl.modules.distributions import TanhNormal + + torch.manual_seed(0) + + action_spec = BoundedTensorSpec(-1, 1, shape=(3,)) + class Splitter(nn.Linear): + def forward(self, x): + loc, scale = super().forward(x).chunk(2, -1) + return loc, scale.exp() + actor_module = TensorDictModule(Splitter(6, 6), in_keys=["obs"], out_keys=["loc", "scale"]) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=TanhNormal, + default_interaction_mode="random", + return_log_prob=False, + ) + class QVal(nn.Linear): + def forward(self, s: Tensor, a: Tensor) -> Tensor: + return super().forward(torch.cat([s, a], -1)) + + qvalue = ValueOperator(QVal(9, 1), in_keys=["obs", "action"]) + _rsample_old = SACLoss._rsample + def _rsample_new(self, dist): + return torch.ones_like(_rsample_old(self, dist)) + SACLoss._rsample = _rsample_new + loss = SACLoss(actor, qvalue) + + for batch in ((), (2, 3)): + td_input = TensorDict({"obs": torch.rand(*batch, 6), "action": torch.rand(*batch, 3).clamp(-1, 1), "next": {"obs": torch.rand(*batch, 6)}, "reward": torch.rand(*batch, 1), 
"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch) + loss._explicit = True + loss0 = loss(td_input) + loss._explicit = False + loss1 = loss(td_input) + print("a", loss0["loss_actor"]-loss1["loss_actor"]) + print("q", loss0["loss_qvalue"]-loss1["loss_qvalue"]) diff --git a/scripts/sac_mujoco/test.py b/scripts/sac_mujoco/test.py new file mode 100644 index 000000000..61e120811 --- /dev/null +++ b/scripts/sac_mujoco/test.py @@ -0,0 +1,89 @@ +import torch +from rlhive.rl_envs import RoboHiveEnv +from rlhive.sim_algos.helpers.rrl_transform import RRLTransform +from torchrl.envs.transforms import ( + Compose, + FlattenObservation, + RewardScaling, + TransformedEnv, +) +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvCreator, + ObservationNorm, + R3MTransform, + SelectTransform, + TransformedEnv, +) +from torchrl.envs.utils import set_exploration_mode + + +def make_env(task, visual_transform, reward_scaling, device): + assert visual_transform in ("rrl", "r3m") + base_env = RoboHiveEnv(task, device=device) + env = make_transformed_env( + env=base_env, reward_scaling=reward_scaling, visual_transform=visual_transform + ) + print(env) + # exit() + + return env + + +def make_transformed_env( + env, + reward_scaling=5.0, + visual_transform="r3m", + stats=None, +): + """ + Apply transforms to the env (such as reward scaling and state normalization) + """ + env = TransformedEnv(env, SelectTransform("solved", "pixels", "observation")) + if visual_transform == "rrl": + vec_keys = ["rrl_vec"] + selected_keys = ["observation", "rrl_vec"] + env.append_transform( + Compose( + RRLTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + elif visual_transform == "r3m": + vec_keys = ["r3m_vec"] + selected_keys = ["observation", "r3m_vec"] + env.append_transform( + Compose( + R3MTransform("resnet50", in_keys=["pixels"], download=True), + FlattenObservation(-2, -1, in_keys=vec_keys), + ) + ) # Necessary to Compose R3MTransform with FlattenObservation; Track bug: https://github.com/pytorch/rl/issues/802 + else: + raise NotImplementedError + env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling)) + out_key = "observation_vector" + env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key)) + + # we normalize the states + if stats is None: + _stats = {"loc": 0.0, "scale": 1.0} + else: + _stats = stats + env.append_transform( + ObservationNorm(**_stats, in_keys=[out_key], standard_normal=True) + ) + env.append_transform(DoubleToFloat(in_keys=[out_key], in_keys_inv=[])) + return env + + +env = make_env( + task="visual_franka_slide_random-v3", + reward_scaling=5.0, + device=torch.device("cuda:0"), + visual_transform="rrl", +) +with torch.no_grad(), set_exploration_mode("random"): + td = env.reset() + td = env.rand_step() + print(td) diff --git a/setup.py b/setup.py index 8186a081c..86b104881 100644 --- a/setup.py +++ b/setup.py @@ -160,7 +160,7 @@ def _main(): # f"torchrl @ file://{rl_path}", "torchrl", "gym==0.13", - "mj_envs", + # "mj_envs", # f"mj_envs @ file://{mj_env_path}", "numpy", "packaging",