diff --git a/.gitignore b/.gitignore index 36aa1d25..42370a6a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,12 @@ dist/ .DS_Store .idea/ .ruff_cache/ +taming/ +lightning_logs/ +logs/ +wandb/ +data/ +*.pth +*.gif +*.npy +*.npz \ No newline at end of file diff --git a/data_gen_scripts/4060_1510.gif b/data_gen_scripts/4060_1510.gif new file mode 100644 index 00000000..6c0dd212 Binary files /dev/null and b/data_gen_scripts/4060_1510.gif differ diff --git a/data_gen_scripts/examples.gif b/data_gen_scripts/examples.gif new file mode 100644 index 00000000..0ece5924 Binary files /dev/null and b/data_gen_scripts/examples.gif differ diff --git a/data_gen_scripts/generate_locomaze.py b/data_gen_scripts/generate_locomaze.py index 2f42d15b..6ad3a3a5 100644 --- a/data_gen_scripts/generate_locomaze.py +++ b/data_gen_scripts/generate_locomaze.py @@ -1,3 +1,17 @@ +import sys +import os +sys.path.insert(0, "/home/hyeons/workspace/ogbench") # Path to local ogbench +sys.path.append("../impls") +os.environ["MUJOCO_GL"] = "egl" +# python generate_locomaze.py --env_name=visual-pointmaze-medium-v0 --save_path=data/visual-pointmaze-stitch-navigate-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --noise=0.5 +# python generate_locomaze.py --env_name=visual-pointmaze-giant-v0 --save_path=data/visual-pointmaze-giant-stitch-v0.npz --dataset_type=stitch --num_episodes=5000 --max_episode_steps=201 --noise=0.5 +# python generate_locomaze.py --env_name=visual-pointmaze-medium-v0 --save_path=data/visual-pointmaze-medium-navigate-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --noise=0.5 +# python generate_locomaze.py --env_name=visual-pointmaze-large-v0 --save_path=data/visual-pointmaze-medium-large-v0.npz --dataset_type=navigate --num_episodes=1000 --max_episode_steps=1001 --noise=0.5 +# python generate_locomaze.py --env_name=visual-pointmaze-giant-v0 --save_path=data/visual-pointmaze-giant-large-v0.npz --dataset_type=navigate --num_episodes=500 --max_episode_steps=2001 --noise=0.5 + + +# for test +# python generate_locomaze.py --env_name=visual-pointmaze-giant-v0 --save_path=./data/0331.npz --dataset_type=navigate --num_episodes=1 --max_episode_steps=2001 --noise=0.5 import glob import json from collections import defaultdict @@ -190,7 +204,9 @@ def actor_fn(ob, temperature): train_path = FLAGS.save_path val_path = FLAGS.save_path.replace('.npz', '-val.npz') - + # Create the output directory if it does not exist. + os.makedirs('/'.join(train_path.split('/')[:-1]), exist_ok=True) + print(f'Saving to {train_path} and {val_path}') # Split the dataset into training and validation sets.
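# Editorial note on the os.makedirs call above: if --save_path has no directory component,
# '/'.join(train_path.split('/')[:-1]) evaluates to '' and os.makedirs('') raises FileNotFoundError;
# os.makedirs(os.path.dirname(train_path) or '.', exist_ok=True) would be a safer equivalent.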
train_dataset = {} val_dataset = {} diff --git a/data_gen_scripts/medium_maze.gif b/data_gen_scripts/medium_maze.gif new file mode 100644 index 00000000..66bc08db Binary files /dev/null and b/data_gen_scripts/medium_maze.gif differ diff --git a/data_gen_scripts/vis.ipynb b/data_gen_scripts/vis.ipynb new file mode 100644 index 00000000..77ced6d0 --- /dev/null +++ b/data_gen_scripts/vis.ipynb @@ -0,0 +1,78 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import imageio\n", + "import IPython.display as display\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "def sample_gif(f_name):\n", + " data = np.load('./data/'+ f_name + '.npz') # load만 26초 정도 걸림\n", + " print('Loaded data from','./data/'+ f_name + '.npz')\n", + " obs = data['observations'][:1000]\n", + " with imageio.get_writer(f_name + '.gif', mode='I') as writer:\n", + " for i in range(0, len(obs), 3): # Skip every 2 frames to speed up by 3x\n", + " writer.append_data(obs[i])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "f_name = '0331'\n", + "sample_gif(f_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import ogbench\n", + "env = ogbench.make('visual-pointmaze-giant-stitch-v0')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "og_game", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/ogbench/locomaze/__init__.py b/ogbench/locomaze/__init__.py index 0fe978a8..9a85fdff 100644 --- a/ogbench/locomaze/__init__.py +++ b/ogbench/locomaze/__init__.py @@ -19,6 +19,17 @@ maze_type='medium', ), ) +register( + id='visual-pointmaze-medium-v0', + entry_point='ogbench.locomaze.maze:make_maze_env', + max_episode_steps=1000, + kwargs=dict( + loco_env_type='point', + maze_env_type='maze', + maze_type='medium', + **visual_dict, + ), +) register( id='pointmaze-large-v0', entry_point='ogbench.locomaze.maze:make_maze_env', @@ -29,6 +40,17 @@ maze_type='large', ), ) +register( + id='visual-pointmaze-large-v0', + entry_point='ogbench.locomaze.maze:make_maze_env', + max_episode_steps=1000, + kwargs=dict( + loco_env_type='point', + maze_env_type='maze', + maze_type='large', + **visual_dict, + ), +) register( id='pointmaze-giant-v0', entry_point='ogbench.locomaze.maze:make_maze_env', @@ -39,6 +61,17 @@ maze_type='giant', ), ) +register( + id='visual-pointmaze-giant-v0', + entry_point='ogbench.locomaze.maze:make_maze_env', + max_episode_steps=1000, + kwargs=dict( + loco_env_type='point', + maze_env_type='maze', + maze_type='giant', + **visual_dict, + ), +) register( id='pointmaze-teleport-v0', entry_point='ogbench.locomaze.maze:make_maze_env', @@ -49,7 +82,17 @@ maze_type='teleport', ), ) - +register( + id='visual-pointmaze-teleport-v0', + entry_point='ogbench.locomaze.maze:make_maze_env', + max_episode_steps=1000, + kwargs=dict( + loco_env_type='point', + maze_env_type='maze', + maze_type='teleport', + **visual_dict, + ), +) 
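For reference, a minimal usage sketch (editorial addition, not part of the diff) for the visual variants registered above. It assumes that register here is Gymnasium's standard registration helper, that importing ogbench executes this module, and that visual_dict (defined earlier in this file) switches the environment to image observations; the exact observation shape depends on what visual_dict requests.

import gymnasium
import ogbench  # assumption: importing the package runs the register() calls in this module

env = gymnasium.make('visual-pointmaze-medium-v0')
ob, info = env.reset(seed=0)
print(getattr(ob, 'shape', None))  # expected to be an image observation via visual_dict
env.close()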
register( id='antmaze-medium-v0', entry_point='ogbench.locomaze.maze:make_maze_env', diff --git a/ogbench/locomaze/assets/point.xml b/ogbench/locomaze/assets/point.xml index 9e5ccd68..387c0601 100644 --- a/ogbench/locomaze/assets/point.xml +++ b/ogbench/locomaze/assets/point.xml @@ -27,6 +27,10 @@ + + + + diff --git a/ogbench/locomaze/maze.py b/ogbench/locomaze/maze.py index c92f65af..8b7c3869 100644 --- a/ogbench/locomaze/maze.py +++ b/ogbench/locomaze/maze.py @@ -327,18 +327,18 @@ def set_tasks(self): raise ValueError(f'Unknown maze type: {self._maze_type}') # More diverse task generation based on the maze map - visitable_positions = [] - for i in range(self.maze_map.shape[0]): - for j in range(self.maze_map.shape[1]): - if self.maze_map[i, j] == 0: - visitable_positions.append((i, j)) - # Combinations - tasks = [] - for i in range(len(visitable_positions)): - for j in range(i + 1, len(visitable_positions)): - tasks.append([visitable_positions[i], visitable_positions[j]]) - tasks.append([visitable_positions[j], visitable_positions[i]]) - print(f"The number of tasks is {len(tasks)}. The Task ID should be in [1, {len(tasks)}].") + # visitable_positions = [] + # for i in range(self.maze_map.shape[0]): + # for j in range(self.maze_map.shape[1]): + # if self.maze_map[i, j] == 0: + # visitable_positions.append((i, j)) + # # Combinations + # tasks = [] + # for i in range(len(visitable_positions)): + # for j in range(i + 1, len(visitable_positions)): + # tasks.append([visitable_positions[i], visitable_positions[j]]) + # tasks.append([visitable_positions[j], visitable_positions[i]]) + # print(f"The number of tasks is {len(tasks)}. The Task ID should be in [1, {len(tasks)}].") self.task_infos = [] for i, task in enumerate(tasks): diff --git a/ogbench/manipspace/envs/manipspace_env.py b/ogbench/manipspace/envs/manipspace_env.py index 919e02c6..8ce769e7 100644 --- a/ogbench/manipspace/envs/manipspace_env.py +++ b/ogbench/manipspace/envs/manipspace_env.py @@ -453,5 +453,5 @@ def render( ): if camera is None: camera = 'front' if self._ob_type == 'states' else 'front_pixels' - + return super().render(camera=camera, *args, **kwargs) diff --git a/ogbench/pretrain/dataset/vog_maze.py b/ogbench/pretrain/dataset/vog_maze.py new file mode 100644 index 00000000..e9ed98a3 --- /dev/null +++ b/ogbench/pretrain/dataset/vog_maze.py @@ -0,0 +1,62 @@ +import torch +import numpy as np + +class VOGMaze2dOfflineRLDataset(torch.utils.data.Dataset): + ''' + Offline RL dataset for 2D maze environments from OG-Bench. 
+ Large + Mean of obs: 139.58089505083132 + std of obs: 71.31185013523307 + Mean of pos: [16.702621 10.974173] + std of pos: [10.050303 6.8203936] + Giant + Mean of obs: 141.01873851323037 + std of obs: 73.4250522212486 + Mean of pos: [24.888689 17.158426] + std of pos: [14.732276 11.651127] + ''' + + def __init__(self, dataset_url='/home/hyeons/workspace/ogbench/ogbench/dataset/visual-pointmaze-medium-navigate-v0.npz' , split: str = "training"): + + super().__init__() + self.dataset_url = dataset_url + self.split = split + dataset = self.get_dataset(self.dataset_url) + self.observations = dataset["observations"] + self.pos = dataset["qpos"] + self.actions = dataset["actions"] + + # Normalizations + if 'large' in dataset_url: + self.observations = (self.observations - 139.58089505083132) / 71.31185013523307 + elif 'giant' in dataset_url: + self.observations = (self.observations - 141.01873851323037) / 73.4250522212486 + + def __getitem__(self, idx): + observation = torch.from_numpy(self.observations[idx]).float().permute(2, 0, 1) + pos = torch.from_numpy(self.pos[idx]).float() + action = torch.from_numpy(self.actions[idx]).float() + return observation, pos, action + + def __len__(self): + return len(self.observations) + + def get_dataset(self, path): + if self.split == "validation": + path = path.replace(".npz", "-val.npz") + dataset = np.load(path, allow_pickle=True, mmap_mode='r') # 메모리 매핑 적용 + return dataset + +if __name__ == '__main__': + dataset = VOGMaze2dOfflineRLDataset(dataset_url='/home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-large-navigate-v0.npz', split='training') + print('large Mean of obs:', np.mean(dataset.observations)) + print('large std of obs:', np.std(dataset.observations)) + print("large Mean of pos:", np.mean(dataset.pos, axis=0)) + print("large std of pos:", np.std(dataset.pos, axis=0)) + + dataset = VOGMaze2dOfflineRLDataset(dataset_url='/home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-giant-navigate-v0.npz', split='training') + print('Giant Mean of obs:', np.mean(dataset.observations)) + print('Giant std of obs:', np.std(dataset.observations)) + print("Giant Mean of pos:", np.mean(dataset.pos, axis=0)) + print("Giant std of pos:", np.std(dataset.pos, axis=0)) + diff --git a/ogbench/pretrain/dataset/vog_maze_emb.py b/ogbench/pretrain/dataset/vog_maze_emb.py new file mode 100644 index 00000000..108572ed --- /dev/null +++ b/ogbench/pretrain/dataset/vog_maze_emb.py @@ -0,0 +1,97 @@ +import torch +import numpy as np +import os + +class VOGEmbeddingDataset(torch.utils.data.Dataset): + ''' + Medium + - pos mean: [10.273524 9.648321] + - pos std: [5.627576 4.897987] + - Action mean: [-0.00524961 -0.00168911] + - Action std: [0.70124096 0.6971626] + Large + emb mean: [ 0.56942457 -0.7748614 0.03422518 -0.05451093 0.00234696 -0.4766496 + -0.53628725 -1.1162051 ] + emb std: [2.0956497 2.2737527 2.3882532 2.6977062 2.1805387 2.6994274 2.4300833 + 2.137858 ] + pos mean: [16.702621 10.974173] + pos std: [10.050303 6.8203936] + actions mean: [-0.01116096 0.00125011] + actions std: [0.7068106 0.6878459] + Giant + emb mean: [ 0.67540884 -0.6614879 -0.30717567 0.09488879 -0.27652553 -0.8268671 + -1.1487181 -0.662139 ] + emb std: [2.1224945 2.1889937 2.3098729 2.5563455 2.3711634 2.4826827 2.3423505 + 2.6387706] + pos mean: [24.888689 17.158426] + pos std: [14.732276 11.651127] + actions mean: [-0.00714872 -0.00213099] + actions std: [0.70283055 0.69673675] + ''' + + def __init__(self, 
dataset_url='/home/hyeons/workspace/HierarchicalDiffusionForcing/data/embedded_data' , split: str = "training"): + + super().__init__() + self.dataset_url = dataset_url + self.split = split + self.emb, self.pos, self.actions = self.get_dataset() + + # Normalizations + if 'medium' in dataset_url: + self.emb_mean = np.array([0.53332597, -0.57663816, -0.15480594, -0.10989726, 0.13822828, -0.7565398 , -0.67368555, -0.5261524]) + self.emb_std = np.array([2.230295 , 1.8695153, 2.5765393, 2.5024776, 2.409886 , 2.3264396, 2.2680814, 2.1177504]) + self.pos_mean = np.array([10.273524, 9.648321]) + self.pos_std = np.array([5.627576, 4.897987]) + self.actions_mean = np.array([-0.00524961, -0.00168911]) + self.actions_std = np.array([0.70124096, 0.6971626]) + elif 'large' in dataset_url: + self.emb_mean = np.array([0.56942457, -0.7748614, 0.03422518, -0.05451093, 0.00234696, -0.4766496, -0.53628725, -1.1162051]) + self.emb_std = np.array([2.0956497, 2.2737527, 2.3882532, 2.6977062, 2.1805387, 2.6994274, 2.4300833, 2.137858]) + self.pos_mean = np.array([16.702621, 10.974173]) + self.pos_std = np.array([10.050303, 6.8203936]) + self.actions_mean = np.array([-0.01116096, 0.00125011]) + self.actions_std = np.array([0.7068106, 0.6878459]) + elif 'giant' in dataset_url: + self.emb_mean = np.array([ 0.67540884, -0.6614879, -0.30717567, 0.09488879, -0.27652553, -0.8268671, -1.1487181, -0.662139]) + self.emb_std = np.array([2.1224945, 2.1889937, 2.3098729, 2.5563455, 2.3711634, 2.4826827, 2.3423505, 2.6387706]) + self.pos_mean = np.array([24.888689, 17.158426]) + self.pos_std = np.array([14.732276, 11.651127]) + self.actions_mean = np.array([-0.00714872, -0.00213099]) + self.actions_std = np.array([0.70283055, 0.69673675]) + + # self.emb = (self.emb - self.emb_mean) / self.emb_std + self.pos = (self.pos - self.pos_mean) / self.pos_std + # self.actions = (self.actions - self.actions_mean) / self.actions_std + + + def __getitem__(self, idx): + emb = torch.from_numpy(self.emb[idx]).float() + pos = torch.from_numpy(self.pos[idx]).float() + action = torch.from_numpy(self.actions[idx]).float() + return emb, pos, action + + def __len__(self): + return len(self.emb) + + def get_dataset(self): + path = os.path.join(self.dataset_url, self.split) + emb = np.load(os.path.join(path, 'latent.npy')) + pos = np.load(os.path.join(path, 'positions.npy')) + act = np.load(os.path.join(path, 'actions.npy')) + return emb, pos, act +if __name__ == '__main__': + dataset = VOGEmbeddingDataset(dataset_url='/home/hyeons/workspace/ogbench/ogbench/pretrain/embedded_data/large', split='training') + print('large Mean of emb:', np.mean(dataset.emb, axis=0)) + print('large std of emb:', np.std(dataset.emb, axis=0)) + print("large Mean of pos:", np.mean(dataset.pos, axis=0)) + print("large std of pos:", np.std(dataset.pos, axis=0)) + print("large Mean of actions:", np.mean(dataset.actions, axis=0)) + print("Medium std of actions:", np.std(dataset.actions, axis=0)) + + dataset = VOGEmbeddingDataset(dataset_url='/home/hyeons/workspace/ogbench/ogbench/pretrain/embedded_data/giant', split='training') + print('Giant Mean of emb:', np.mean(dataset.emb, axis=0)) + print('Giant std of emb:', np.std(dataset.emb, axis=0)) + print("Giant Mean of pos:", np.mean(dataset.pos, axis=0)) + print("Giant std of pos:", np.std(dataset.pos, axis=0)) + print("Giant Mean of actions:", np.mean(dataset.actions, axis=0)) + print("Giant std of actions:", np.std(dataset.actions, axis=0)) diff --git a/ogbench/pretrain/embed.py b/ogbench/pretrain/embed.py new file mode 
100644 index 00000000..3e3912b7 --- /dev/null +++ b/ogbench/pretrain/embed.py @@ -0,0 +1,66 @@ +''' +This file is used to embed the data into the latent space using the pre-trained VAE model. +command +CUDA_VISIBLE_DEVICES=2 python embed.py +''' + +import os +import numpy as np +import torch +from torch.utils.data import DataLoader +from dataset.vog_maze import VOGMaze2dOfflineRLDataset +from models.bvae import BetaVAE +from tqdm import tqdm +import argparse + +# configs +parser = argparse.ArgumentParser(description='Beta VAE Training Configuration') +parser.add_argument('--dataset_url', type=str, default='None', help='dataset url') +parser.add_argument('--load', type=str, default=None, help='Path to load model') +args = parser.parse_args() +config = vars(args) +# Load Model +path_to_model = config['load'] +model = BetaVAE().cuda() +model.load_state_dict(torch.load(path_to_model)) +model.eval() +print("Model loaded successfully") + +maze_type = 'large' if 'large' in path_to_model else 'giant' +# Load Dataset +path_to_dataset = config['dataset_url'] +splits = ['training', 'validation'] +for split in splits: + dataset = VOGMaze2dOfflineRLDataset(path_to_dataset, split) + dataloader = DataLoader(dataset, batch_size=4096, shuffle=False) + print("Dataset loaded successfully") + print(dataset.observations.shape) + # Embed data and save in chunks + path_to_save = './embedded_data/' + maze_type + split + os.makedirs(path_to_save, exist_ok=True) + + embeddings = [] + positions = [] + actions = [] + chunk_size = 10000 # Adjust the chunk size as needed + for i, (obs, pos, act) in enumerate(tqdm(dataloader, desc="Embedding data")): + with torch.no_grad(): + obs = obs.cuda() + mean, log_var = model.encode(obs) + z = model.reparameterize(mean, log_var) + embeddings.append(z.cpu().numpy()) + positions.append(pos.cpu().numpy()) + actions.append(act.cpu().numpy()) + + + if embeddings: + embeddings = np.concatenate(embeddings, axis=0) + print(embeddings.shape) + positions = np.concatenate(positions, axis=0) + actions = np.concatenate(actions, axis=0) + with open(os.path.join(path_to_save, f'latent_{split}.npy'), 'ab') as f: + np.save(f, embeddings) + with open(os.path.join(path_to_save, f'positions_{split}.npy'), 'ab') as f: + np.save(f, positions) + with open(os.path.join(path_to_save, f'actions_{split}.npy'), 'ab') as f: + np.save(f, actions) \ No newline at end of file diff --git a/ogbench/pretrain/embed.sh b/ogbench/pretrain/embed.sh new file mode 100644 index 00000000..2dd46f1c --- /dev/null +++ b/ogbench/pretrain/embed.sh @@ -0,0 +1,32 @@ +# SESSION_NAME=0 +# tmux new-session -d -s $SESSION_NAME +# tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +# tmux send-keys -t $SESSION_NAME " +# export CUDA_VISIBLE_DEVICES=0 +# python train_vae.py --dataset_url /home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-large-navigate-v0.npz +# " C-m + +# SESSION_NAME=1 +# tmux new-session -d -s $SESSION_NAME +# tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +# tmux send-keys -t $SESSION_NAME " +# export CUDA_VISIBLE_DEVICES=1 +# python train_vae.py --dataset_url /home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-gaint-navigate-v0.npz +# " C-m + +# SESSION_NAME=6 +# tmux new-session -d -s $SESSION_NAME +# tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +# tmux send-keys -t $SESSION_NAME " +# export CUDA_VISIBLE_DEVICES=6 +# python embed.py --dataset_url /home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-large-navigate-v0.npz 
--load /home/hyeons/workspace/ogbench/ogbench/pretrain/large_vae.pth +# " C-m + +SESSION_NAME=5 +tmux new-session -d -s $SESSION_NAME +tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +tmux send-keys -t $SESSION_NAME " +export CUDA_VISIBLE_DEVICES=5 +python embed.py --dataset_url /home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-giant-navigate-v0.npz --load /home/hyeons/workspace/ogbench/ogbench/pretrain/giant_vae.pth +" C-m + diff --git a/ogbench/pretrain/initial_observation2.png b/ogbench/pretrain/initial_observation2.png new file mode 100644 index 00000000..320c3449 Binary files /dev/null and b/ogbench/pretrain/initial_observation2.png differ diff --git a/ogbench/pretrain/models/bvae.py b/ogbench/pretrain/models/bvae.py new file mode 100644 index 00000000..c0003803 --- /dev/null +++ b/ogbench/pretrain/models/bvae.py @@ -0,0 +1,177 @@ +''' +Reference +https://github.com/AntixK/PyTorch-VAE/tree/master +''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Callable, Union, Any, TypeVar, Tuple +Tensor = TypeVar('torch.tensor') + + +class BetaVAE(nn.Module): + + num_iter = 0 # Global static variable to keep track of iterations + + def __init__(self, + in_channels= 3, + latent_dim= 8, + kld_weight= 1e-6, + ) -> None: + super(BetaVAE, self).__init__() + + self.latent_dim = latent_dim + self.kld_weight = kld_weight + modules = [] + hidden_dims = [32, 64, 128, 256, 512] + + # Build Encoder + for h_dim in hidden_dims: + modules.append( + nn.Sequential( + nn.Conv2d(in_channels, out_channels=h_dim, + kernel_size= 3, stride= 2, padding = 1), + nn.BatchNorm2d(h_dim), + nn.LeakyReLU(), + nn.Conv2d(h_dim, out_channels=h_dim, + kernel_size= 3, stride=1, padding=1), + nn.BatchNorm2d(h_dim), + nn.LeakyReLU()) + ) + in_channels = h_dim + + self.encoder = nn.Sequential(*modules) + self.fc_mu = nn.Sequential( + nn.Linear(hidden_dims[-1]*4, hidden_dims[-1]), + nn.BatchNorm1d(hidden_dims[-1]), + nn.LeakyReLU(), + nn.Linear(hidden_dims[-1], latent_dim) + ) + self.fc_var = nn.Sequential( + nn.Linear(hidden_dims[-1]*4, hidden_dims[-1]), + nn.BatchNorm1d(hidden_dims[-1]), + nn.LeakyReLU(), + nn.Linear(hidden_dims[-1], latent_dim) + ) + + + # Build Decoder + modules = [] + + self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) + + hidden_dims.reverse() + + for i in range(len(hidden_dims) - 1): + modules.append( + nn.Sequential( + nn.ConvTranspose2d(hidden_dims[i], + hidden_dims[i + 1], + kernel_size=3, + stride = 2, + padding=1, + output_padding=1), + nn.BatchNorm2d(hidden_dims[i + 1]), + nn.LeakyReLU()) + ) + + + + self.decoder = nn.Sequential(*modules) + + self.final_layer = nn.Sequential( + nn.ConvTranspose2d(hidden_dims[-1], + hidden_dims[-1], + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + nn.BatchNorm2d(hidden_dims[-1]), + nn.LeakyReLU(), + nn.Conv2d(hidden_dims[-1], out_channels= 3, + kernel_size= 3, padding= 1), + nn.Tanh()) + + def encode(self, input: Tensor) -> List[Tensor]: + """ + Encodes the input by passing through the encoder network + and returns the latent codes. 
+ :param input: (Tensor) Input tensor to encoder [N x C x H x W] + :return: (Tensor) List of latent codes + """ + result = self.encoder(input) + result = torch.flatten(result, start_dim=1) + + # Split the result into mu and var components + # of the latent Gaussian distribution + mu = self.fc_mu(result) + log_var = self.fc_var(result) + + return [mu, log_var] + + def decode(self, z: Tensor) -> Tensor: + result = self.decoder_input(z) + result = result.view(-1, 512, 2, 2) + result = self.decoder(result) + result = self.final_layer(result) + return result + + def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: + """ + Will a single z be enough ti compute the expectation + for the loss?? + :param mu: (Tensor) Mean of the latent Gaussian + :param logvar: (Tensor) Standard deviation of the latent Gaussian + :return: + """ + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps * std + mu + + def forward(self, input: Tensor, **kwargs) -> Tensor: + mu, log_var = self.encode(input) + z = self.reparameterize(mu, log_var) + return [self.decode(z), input, mu, log_var] + + def loss_function(self, + *args, + **kwargs) -> dict: + self.num_iter += 1 + recons = args[0] + input = args[1] + mu = args[2] + log_var = args[3] + kld_weight = self.kld_weight + + recons_loss =F.mse_loss(recons, input) + + kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) + loss = recons_loss + kld_weight* kld_loss.abs() + return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss, 'loss - recons_loss': loss - recons_loss} + + def sample(self, + num_samples:int, + current_device: int, **kwargs) -> Tensor: + """ + Samples from the latent space and return the corresponding + image space map. + :param num_samples: (Int) Number of samples + :param current_device: (Int) Device to run the model + :return: (Tensor) + """ + z = torch.randn(num_samples, + self.latent_dim) + + z = z.to(current_device) + + samples = self.decode(z) + return samples + + def generate(self, x: Tensor, **kwargs) -> Tensor: + """ + Given an input image x, returns the reconstructed image + :param x: (Tensor) [B x C x H x W] + :return: (Tensor) [B x C x H x W] + """ + + return self.forward(x)[0] \ No newline at end of file diff --git a/ogbench/pretrain/models/mlp.py b/ogbench/pretrain/models/mlp.py new file mode 100644 index 00000000..03fe5ed4 --- /dev/null +++ b/ogbench/pretrain/models/mlp.py @@ -0,0 +1,24 @@ +import torch +import torch.nn as nn + +class MLP(nn.Module): + def __init__(self, input_dim, output_dim, hidden_dim=1024, num_layers=4): + super(MLP, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + + self.layers = nn.ModuleList() + self.layers.append(nn.Linear(input_dim, hidden_dim)) + for _ in range(num_layers - 2): + self.layers.append(nn.Linear(hidden_dim, hidden_dim)) + self.layers.append(nn.Linear(hidden_dim, output_dim)) + + def forward(self, x): + for layer in self.layers[:-1]: + x = torch.relu(layer(x)) + x = self.layers[-1](x) + return x + + diff --git a/ogbench/pretrain/run.sh b/ogbench/pretrain/run.sh new file mode 100644 index 00000000..3bdcc14b --- /dev/null +++ b/ogbench/pretrain/run.sh @@ -0,0 +1,61 @@ +# SESSION_NAME=vmalv +# tmux new-session -d -s $SESSION_NAME +# tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +# tmux send-keys -t $SESSION_NAME " +# export CUDA_VISIBLE_DEVICES=0 +# python train_id.py --lr 1e-4 --weight_decay 1e-5 --hidden_size 4096 
--num_layers 5 +# " C-m + +# SESSION_NAME=vmalv1 +# tmux new-session -d -s $SESSION_NAME +# tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +# tmux send-keys -t $SESSION_NAME " +# export CUDA_VISIBLE_DEVICES=1 +# python train_id.py --lr 1e-4 --weight_decay 1e-4 --hidden_size 4096 --num_layers 5 +# " C-m +# #!/bin/bash + +# Hyperparameter grids +lrs=(5e-4 1e-3 1e-4) +weight_decays=(0) +hidden_sizes=(128 256 512 1024) # 256 1024 2048 4096 +num_layers=(4) +frame_stack=(2) + +# GPU configuration +gpus=(0 1 2 3 4 5 6 7) +experiments_per_gpu=4 + +# Total experiments +total_experiments=32 +experiment=0 + +# Kill any existing tmux server +tmux kill-server + +for lr in "${lrs[@]}"; do + for weight_decay in "${weight_decays[@]}"; do + for hidden_size in "${hidden_sizes[@]}"; do + for layer in "${num_layers[@]}"; do + for stack in "${frame_stack[@]}"; do + if [ $experiment -ge $total_experiments ]; then + break 4 + fi + + gpu=${gpus[$((experiment / experiments_per_gpu))]} + SESSION_NAME="exp_${experiment}" + + tmux new-session -d -s $SESSION_NAME + tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m + tmux send-keys -t $SESSION_NAME " + export CUDA_VISIBLE_DEVICES=${gpu} + python train_id.py --lr ${lr} --weight_decay ${weight_decay} --hidden_dim ${hidden_size} --num_layers ${layer} --frame_stack ${stack} + " C-m + + echo "Launched ${SESSION_NAME} on GPU ${gpu} with lr=${lr}, weight_decay=${weight_decay}, hidden_size=${hidden_size}, num_layers=${layer}" + experiment=$((experiment + 1)) + done + done + done + done +done \ No newline at end of file diff --git a/ogbench/pretrain/temp_points.png b/ogbench/pretrain/temp_points.png new file mode 100644 index 00000000..884a63db Binary files /dev/null and b/ogbench/pretrain/temp_points.png differ diff --git a/ogbench/pretrain/train_e2s_id.py b/ogbench/pretrain/train_e2s_id.py new file mode 100644 index 00000000..64b999ae --- /dev/null +++ b/ogbench/pretrain/train_e2s_id.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from dataset.vog_maze_emb import VOGEmbeddingDataset +from models.mlp import MLP +from tqdm import tqdm +import argparse +import numpy as np +from collections import deque + +# configs +parser = argparse.ArgumentParser(description='Training MLP on VOG embeddings') +parser.add_argument('--mode', type=str, choices=['invd', 'e2s'], required=True, help='Training mode: invd or e2s') +parser.add_argument('--num_epochs', type=int, default=60) +parser.add_argument('--batch_size', type=int, default=128) +parser.add_argument('--lr', type=float, default=1e-4) +parser.add_argument('--weight_decay', type=float, default=1e-5) +parser.add_argument('--seed', type=int, default=42) +parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') +parser.add_argument('--name', type=str, default='mlp_model') +parser.add_argument('--load', type=str, default=None) +parser.add_argument('--group_name', type=str, default='default_group') +parser.add_argument('--log', type=bool, default=True) +parser.add_argument('--hidden_dim', type=int, default=1024) +parser.add_argument('--num_layers', type=int, default=3) +parser.add_argument('--frame_stack', type=int, default=3) +parser.add_argument('--dataset_url', type=str, default='None') +args = parser.parse_args() +config = vars(args) + +# Set seed +torch.manual_seed(config["seed"]) +if config["device"] == "cuda": + torch.cuda.manual_seed(config["seed"]) + +# Initialize wandb +if config["log"]: + import wandb + 
wandb.init(project=f'{config["mode"]}_project', entity='Hierarchical-Diffusion-Forcing', config=config) + +# Load raw datasets +train_raw = VOGEmbeddingDataset(config['dataset_url'], split='training') +val_raw = VOGEmbeddingDataset(config['dataset_url'], split='validation') + +# Preprocessing +if config["mode"] == 'invd': + def preprocess_inverse(dataset): + processed = [] + emb_q = deque(maxlen=config["frame_stack"]) + actions = deque(maxlen=2) + for i in range(len(dataset)): + emb, _, action = dataset[i] + emb_q.append(emb) + actions.append(action) + if i % 1001 == 0: + if i == 0: + emb_q.append(emb) + continue + else: + emb_q.append(emb) + if len(emb_q) == config["frame_stack"]: + processed.append((np.concatenate(list(emb_q)), actions[0])) + return processed + + train_data = preprocess_inverse(train_raw) + val_data = preprocess_inverse(val_raw) + + input_dim = 8 * config["frame_stack"] + output_dim = 2 +else: # e2s + train_data = [(emb, pos) for emb, pos, _ in train_raw] + val_data = [(emb, pos) for emb, pos, _ in val_raw] + input_dim = 8 + output_dim = 2 + +train_loader = DataLoader(train_data, batch_size=config["batch_size"], shuffle=True, pin_memory=True) +val_loader = DataLoader(val_data, batch_size=config["batch_size"], shuffle=False, pin_memory=True) + +# Initialize model +model = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dim=config['hidden_dim'], num_layers=config['num_layers']) +model.to(config["device"]) + +# Load weights +if config["load"]: + model.load_state_dict(torch.load(config["load"])) + +# Optimizer +optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"]) + +# Training loop +print("Training started") +for epoch in tqdm(range(config["num_epochs"])): + model.train() + train_loss_list = [] + for batch in train_loader: + inputs, targets = batch + inputs = inputs.to(config["device"]) + targets = targets.to(config["device"]) + + optimizer.zero_grad() + preds = model(inputs) + loss = nn.functional.mse_loss(preds, targets) + loss.backward() + optimizer.step() + train_loss_list.append(loss.item()) + + # Validation + if epoch % 5 == 0: + model.eval() + val_loss_list = [] + with torch.no_grad(): + for batch in val_loader: + inputs, targets = batch + inputs = inputs.to(config["device"]) + targets = targets.to(config["device"]) + preds = model(inputs) + loss = nn.functional.mse_loss(preds, targets) + val_loss_list.append(loss.item()) + + avg_train_loss = sum(train_loss_list) / len(train_loss_list) + avg_val_loss = sum(val_loss_list) / len(val_loss_list) + print(f"[Epoch {epoch}] Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}") + + if config["log"]: + wandb.log({"Train Loss": avg_train_loss, "Val Loss": avg_val_loss}) + +# Save final model +maze_type = 'large' if 'large' in config['dataset_url'] else 'giant' +import datetime +date = datetime.datetime.today().strftime('%m%d') +torch.save(model.state_dict(), f"{maze_type}_{config['mode']}_{date}.pth") \ No newline at end of file diff --git a/ogbench/pretrain/train_e2s_id.sh b/ogbench/pretrain/train_e2s_id.sh new file mode 100644 index 00000000..cf79f8f5 --- /dev/null +++ b/ogbench/pretrain/train_e2s_id.sh @@ -0,0 +1,16 @@ +# SESSION_NAME=3 +# tmux new-session -d -s $SESSION_NAME +# tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +# tmux send-keys -t $SESSION_NAME " +# export CUDA_VISIBLE_DEVICES=0 +# python train_e2s_id.py --dataset_url /home/hyeons/workspace/ogbench/ogbench/pretrain/embedded_data/large --mode invd +# " C-m + +SESSION_NAME=2 +tmux 
new-session -d -s $SESSION_NAME +tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +tmux send-keys -t $SESSION_NAME " +export CUDA_VISIBLE_DEVICES=0 +python train_e2s_id.py --dataset_url /home/hyeons/workspace/ogbench/ogbench/pretrain/embedded_data/large --mode e2s +" C-m + diff --git a/ogbench/pretrain/train_e2s_n_id.sh b/ogbench/pretrain/train_e2s_n_id.sh new file mode 100644 index 00000000..ba033f60 --- /dev/null +++ b/ogbench/pretrain/train_e2s_n_id.sh @@ -0,0 +1,33 @@ +SESSION_NAME=0 +tmux new-session -d -s $SESSION_NAME +tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +tmux send-keys -t $SESSION_NAME " +export CUDA_VISIBLE_DEVICES=0 +python train_id.py --dataset_url /home/hyeons/workspace/ogbench/ogbench/pretrain/embedded_data/large +" C-m + +SESSION_NAME=1 +tmux new-session -d -s $SESSION_NAME +tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +tmux send-keys -t $SESSION_NAME " +export CUDA_VISIBLE_DEVICES=1 +python train_id.py --dataset_url /home/hyeons/workspace/ogbench/ogbench/pretrain/embedded_data/giant +" C-m + +SESSION_NAME=2 +tmux new-session -d -s $SESSION_NAME +tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +tmux send-keys -t $SESSION_NAME " +export CUDA_VISIBLE_DEVICES=2 +python train_emb2state.py --dataset_url /home/hyeons/workspace/ogbench/ogbench/pretrain/embedded_data/large +" C-m + +SESSION_NAME=3 +tmux new-session -d -s $SESSION_NAME +tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +tmux send-keys -t $SESSION_NAME " +export CUDA_VISIBLE_DEVICES=3 +python train_emb2state.py --dataset_url /home/hyeons/workspace/ogbench/ogbench/pretrain/embedded_data/giant +" C-m + + diff --git a/ogbench/pretrain/train_emb2state.py b/ogbench/pretrain/train_emb2state.py new file mode 100644 index 00000000..72adb1d4 --- /dev/null +++ b/ogbench/pretrain/train_emb2state.py @@ -0,0 +1,96 @@ + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from dataset.vog_maze_emb import VOGEmbeddingDataset +from models.mlp import MLP +from tqdm import tqdm +import argparse + +# configs +parser = argparse.ArgumentParser(description='Training 3 layer MLP to predict the state from the embedding') + +parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs') +parser.add_argument('--batch_size', type=int, default=128, help='Batch size') +parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate') +parser.add_argument('--seed', type=int, default=42, help='Random seed') +parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use') +parser.add_argument('--name', type=str, default='beta_vae', help='Model name') +parser.add_argument('--load', type=str, default=None, help='Path to load model') +parser.add_argument('--group_name', type=str, default='None', help='gropu name') +parser.add_argument('--log', type=bool, default=True, help='log') +parser.add_argument('--hidden_dim', type=int, default=1024, help='log') +parser.add_argument('--num_layers', type=int, default=3, help='log') +parser.add_argument('--dataset_url', type=str, default='None', help='dataset url') + +args = parser.parse_args() +config = vars(args) + + +# set seed +torch.manual_seed(config["seed"]) +if config["device"] == "cuda": + torch.cuda.manual_seed(config["seed"]) + +#initialize wandb + +if config["log"]: + import wandb + wandb.init(project='hs_e2s', entity='Hierarchical-Diffusion-Forcing', config=config) + +# Load the dataset +train_dataset = 
VOGEmbeddingDataset(config['dataset_url'], split='training') +validation_dataset = VOGEmbeddingDataset(config['dataset_url'], split='validation') + +# Create the dataloaders +train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True) +validation_loader = DataLoader(validation_dataset, batch_size=config["batch_size"], shuffle=False) + +print("Dataset loaded") + +# Initialize the model +model = MLP(input_dim=8, output_dim=2, hidden_dim=config['hidden_dim'], num_layers=config['num_layers']) +model.to(config["device"]) + +# Load the model if specified +if config["load"]: + model.load_state_dict(torch.load(config["load"])) + +# Initialize the optimizer +optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"]) + +# Train +print("Training started") +for epoch in tqdm(range(config["num_epochs"])): + model.train() + train_loss_list = [] + for i, (emb, pos, _) in enumerate(train_loader): + emb, pos = emb.to(config["device"]), pos.to(config["device"]) + optimizer.zero_grad() + pred = model(emb) + loss = nn.functional.mse_loss(pred, pos) + loss.backward() + optimizer.step() + train_loss_list.append(loss.mean().item()) + + # Validation + if epoch % 5 == 0: + model.eval() + with torch.no_grad(): + loss_list = [] + for i, (emb, pos, _) in enumerate(validation_loader): + emb, pos = emb.to(config["device"]), pos.to(config["device"]) + pred = model(emb) + loss = nn.functional.mse_loss(pred, pos) + loss_list.append(loss.mean().item()) + print(f"Epoch: {epoch}, Train Loss: {sum(train_loss_list)/len(train_loss_list)}, Validation Loss: {sum(loss_list)/len(loss_list)}") + + if config["log"]: + wandb.log({"Training Loss": sum(train_loss_list)/len(train_loss_list)}) + wandb.log({"Validation Loss": sum(loss_list)/len(loss_list)}) + + # Save the model with time + # if epoch % 10 == 0: + # torch.save(model.state_dict(), f"e2s_loss{sum(loss_list)/len(loss_list)}_{epoch}.pth") +maze_type = 'large' if 'large' in config['dataset_url'] else 'giant' +torch.save(model.state_dict(), f"{maze_type}_e2s.pth") diff --git a/ogbench/pretrain/train_id.py b/ogbench/pretrain/train_id.py new file mode 100644 index 00000000..cddd3a59 --- /dev/null +++ b/ogbench/pretrain/train_id.py @@ -0,0 +1,135 @@ + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from dataset.vog_maze_emb import VOGEmbeddingDataset +from models.mlp import MLP +from tqdm import tqdm +import argparse +import numpy as np +from collections import deque + +# configs +parser = argparse.ArgumentParser(description='Training 3 layer MLP to predict the state from the embedding') + +parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs') +parser.add_argument('--batch_size', type=int, default=128, help='Batch size') +parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate') +parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay') +parser.add_argument('--seed', type=int, default=42, help='Random seed') +parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use') +parser.add_argument('--name', type=str, default='beta_vae', help='Model name') +parser.add_argument('--load', type=str, default=None, help='Path to load model') +parser.add_argument('--group_name', type=str, default='after_sanity_check', help='gropu name') +parser.add_argument('--log', type=bool, default=True, help='log') +parser.add_argument('--hidden_dim', type=int, default=1024, help='log') 
+parser.add_argument('--num_layers', type=int, default=3) +parser.add_argument('--frame_stack', type=int, default=3, help='') +parser.add_argument('--dataset_url', type=str, default='None', help='dataset url') + +args = parser.parse_args() +config = vars(args) + +# set seed +torch.manual_seed(config["seed"]) +if config["device"] == "cuda": + torch.cuda.manual_seed(config["seed"]) + +#initialize wandb + +if config["log"]: + import wandb + wandb.init(project='inverse_dynamics', entity='Hierarchical-Diffusion-Forcing', config=config) + +# Load the dataset +train_dataset = VOGEmbeddingDataset(dataset_url=config['dataset_url'], split='training') +validation_dataset = VOGEmbeddingDataset(dataset_url= config['dataset_url'],split='validation') + +print("Preprocessing dataset") +train_inverse = [] +emb_q = deque(maxlen=config["frame_stack"]) +actions = deque(maxlen=2) +for i in range(len(train_dataset)): + emb, pos, action = train_dataset[i] + emb_q.append(emb) + actions.append(action) + if i % 1001 == 0: + if i == 0: + emb_q.append(emb) + continue + else: + emb_q.append(emb) + train_inverse.append((np.concatenate(list(emb_q)), actions[0])) # emb[-2]에서 emb[-1]로 가는 action이 target +train_dataset = train_inverse + +validation_inverse = [] +emb_q = deque(maxlen=config['frame_stack']) +actions = deque(maxlen=2) +for i in range(len(validation_dataset)): + emb, pos, action = validation_dataset[i] + emb_q.append(emb) + actions.append(action) + if i % 1001 == 0: + if i == 0: + emb_q.append(emb) + continue + else: + emb_q.append(emb) + validation_inverse.append((np.concatenate(list(emb_q)), actions[0])) +validation_dataset = validation_inverse + +print("Dataset preprocessed") + +train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, pin_memory=True) +validation_loader = DataLoader(validation_dataset, batch_size=config["batch_size"], shuffle=False, pin_memory=True) + +print("Dataset loaded") + +# Initialize the model +model = MLP(input_dim=8*config['frame_stack'], output_dim=2, hidden_dim=config['hidden_dim'], num_layers=config['num_layers']) +model.to(config["device"]) + +# Load the model if specified +if config["load"]: + model.load_state_dict(torch.load(config["load"])) + +# Initialize the optimizer +optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"]) + +# Train +print("Training started") +for epoch in tqdm(range(config["num_epochs"])): + model.train() + train_loss_list = [] + for i, (emb, action) in enumerate(train_loader): + emb, action = emb.to(config["device"]), action.to(config["device"]) + optimizer.zero_grad() + pred = model(emb) + loss = nn.functional.mse_loss(pred, action) + train_loss_list.append(loss.mean().item()) + loss.backward() + optimizer.step() + + # Validation + if epoch % 10 == 0: + model.eval() + with torch.no_grad(): + loss_list = [] + for i, (emb, action) in enumerate(validation_loader): + emb, action = emb.to(config["device"]), action.to(config["device"]) + pred = model(emb) + loss = nn.functional.mse_loss(pred, action) + loss_list.append(loss.mean().item()) + print(f"Epoch: {epoch}, Train Loss: {sum(train_loss_list)/len(train_loss_list)}, Validation Loss: {sum(loss_list)/len(loss_list)}") + + if config["log"]: + wandb.log({"Training Loss": sum(train_loss_list)/len(train_loss_list)}) + wandb.log({"Validation Loss": sum(loss_list)/len(loss_list)}) + + # Save the model with time + # if epoch % 10 == 0: + # torch.save(model.state_dict(), f"e2s_loss{sum(loss_list)/len(loss_list)}_{epoch}.pth") 
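# Editorial note (illustrative, not part of the original script): the weights saved below act as an
# inverse-dynamics model. Mirroring the preprocessing above, the input is frame_stack consecutive
# embeddings concatenated along the feature axis (8 * frame_stack values), and the output approximates
# the action that moves the agent from the second-to-last to the last embedding in the stack, e.g.:
#   stacked = np.concatenate([emb_tm2, emb_tm1, emb_t])                  # hypothetical names, frame_stack=3
#   action_hat = model(torch.from_numpy(stacked).float()[None]).squeeze(0)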
+maze_type = 'large' if 'large' in config['dataset_url'] else 'giant' +torch.save(model.state_dict(), f"{maze_type}_invd.pth") + + diff --git a/ogbench/pretrain/train_vae.py b/ogbench/pretrain/train_vae.py new file mode 100644 index 00000000..9ddd7aae --- /dev/null +++ b/ogbench/pretrain/train_vae.py @@ -0,0 +1,105 @@ +import torch +from torch.utils.data import DataLoader +from ogbench.pretrain.dataset.vog_maze import VOGMaze2dOfflineRLDataset +from ogbench.pretrain.models.bvae import BetaVAE +from tqdm import tqdm +import argparse + +# configs +parser = argparse.ArgumentParser(description='Beta VAE Training Configuration') + +parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs') +parser.add_argument('--batch_size', type=int, default=4096, help='Batch size') +parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate') +parser.add_argument('--seed', type=int, default=42, help='Random seed') +parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use') +parser.add_argument('--name', type=str, default='beta_vae', help='Model name') +parser.add_argument('--load', type=str, default=None, help='Path to load model') +parser.add_argument('--kld_weight', type=float, default=0.1, help='KLD weight') +parser.add_argument('--group_name', type=str, default='None', help='gropu name') +parser.add_argument('--dataset_url', type=str, default='None', help='dataset url') + +args = parser.parse_args() +config = vars(args) + + +# set seed +torch.manual_seed(config["seed"]) +if config["device"] == "cuda": + torch.cuda.manual_seed(config["seed"]) + +#initialize wandb +import wandb +wandb.init(project='hs_vae', entity='Hierarchical-Diffusion-Forcing', config=config) + +# Load the dataset +train_dataset = VOGMaze2dOfflineRLDataset(dataset_url=config['dataset_url'], split='training') +validation_dataset = VOGMaze2dOfflineRLDataset(dataset_url=config['dataset_url'], split='validation') + +# Create the dataloaders +train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True) +validation_loader = DataLoader(validation_dataset, batch_size=config["batch_size"], shuffle=False) + +print("Dataset loaded") + +# Initialize the model +model = BetaVAE() +model.to(config["device"]) + +# Load the model if specified +if config["load"]: + model.load_state_dict(torch.load(config["load"])) + +# Initialize the optimizer +optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"]) + +# Train +print("Training started") +for epoch in tqdm(range(config["num_epochs"])): + model.train() + for i, (obs, pos, _) in enumerate(train_loader): + obs = obs.to(config["device"]) + optimizer.zero_grad() + recons, inputs, mu, log_var = model(obs) + loss_dict = model.loss_function(recons, inputs, mu, log_var) + loss = loss_dict["loss"] + loss.backward() + optimizer.step() + + if i % 10 == 0: + print(f"Epoch: {epoch}, Iter: {model.num_iter}", end=" " ) + for k, v in loss_dict.items(): + wandb.log({k: v}) + print(f"{k}: {v}", end=", ") + print(end="\r") + + # Validation + model.eval() + with torch.no_grad(): + for i, (obs, pos, _) in enumerate(validation_loader): + obs = obs.to(config["device"]) + recons, inputs, mu, log_var = model(obs) + loss_dict = model.loss_function(recons, inputs, mu, log_var) + if i == 0: + idx = torch.randint(0, obs.size(0), (32,)) + # log the base input and reconstruction + inputs = inputs[idx[:8]] + recons = recons[idx[:8]] + inputs = inputs * 71.0288272312382 + 141.785487953533 + recons = 
recons * 71.0288272312382 + 141.785487953533 + wandb.log({"input": wandb.Image(inputs)}) + wandb.log({"recons": wandb.Image(recons)}) + + + for k, v in loss_dict.items(): + wandb.log({k+'val': v}) + print("Epoch:", epoch,k+'val', ":", v, end=", ") + + # Save the model with time + if epoch % 10 == 0: + maze_type = 'large' if 'large' in config['dataset_url'] else 'giant' + torch.save(model.state_dict(), f"{maze_type}_loss{loss_dict['Reconstruction_Loss']}_{epoch}.pth") +torch.save(model.state_dict(), f"{maze_type}_loss{loss_dict['Reconstruction_Loss']}_last.pth") + + + diff --git a/ogbench/pretrain/train_vae.sh b/ogbench/pretrain/train_vae.sh new file mode 100644 index 00000000..1cfaabcf --- /dev/null +++ b/ogbench/pretrain/train_vae.sh @@ -0,0 +1,31 @@ +# SESSION_NAME=0 +# tmux new-session -d -s $SESSION_NAME +# tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +# tmux send-keys -t $SESSION_NAME " +# export CUDA_VISIBLE_DEVICES=0 +# python train_vae.py --dataset_url /home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-large-navigate-v0.npz +# " C-m + +# SESSION_NAME=1 +# tmux new-session -d -s $SESSION_NAME +# tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +# tmux send-keys -t $SESSION_NAME " +# export CUDA_VISIBLE_DEVICES=1 +# python train_vae.py --dataset_url /home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-gaint-navigate-v0.npz +# " C-m +SESSION_NAME=4 +tmux new-session -d -s $SESSION_NAME +tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +tmux send-keys -t $SESSION_NAME " +export CUDA_VISIBLE_DEVICES=4 +python train_vae.py --dataset_url /home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-large-navigate-v0.npz --load /home/hyeons/workspace/ogbench/ogbench/pretrain/large_100.pth +" C-m + +SESSION_NAME=6 +tmux new-session -d -s $SESSION_NAME +tmux send-keys -t $SESSION_NAME "conda activate og_game" C-m +tmux send-keys -t $SESSION_NAME " +export CUDA_VISIBLE_DEVICES=6 +python train_vae.py --dataset_url /home/hyeons/workspace/ogbench/data_gen_scripts/data/visual-pointmaze-giant-navigate-v0.npz --load /home/hyeons/workspace/ogbench/ogbench/pretrain/giant_100.pth +" C-m + diff --git a/ogbench/pretrain/visualize.ipynb b/ogbench/pretrain/visualize.ipynb new file mode 100644 index 00000000..516333b0 --- /dev/null +++ b/ogbench/pretrain/visualize.ipynb @@ -0,0 +1,619 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "from models.bvae import BetaVAE\n", + "from models.mlp import MLP\n", + "from dataset.vog_maze import VOGMaze2dOfflineRLDataset\n", + "from dataset.vog_maze_emb import VOGEmbeddingDataset\n", + "import sys\n", + "sys.path.append('../data_gen_scripts')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Load Model and Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_461393/2780672498.py:11: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. 
This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " model.load_state_dict(torch.load(path_to_model))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset loaded\n" + ] + } + ], + "source": [ + "# load model & dataset\\||\n", + "dataset_type = 'emb'\n", + "\n", + "# path_to_model = './weights/vae_weight.pth'\n", + "# path_to_dataset = './data/original_data'\n", + "path_to_model = './weights/e2s_weight.pth'\n", + "path_to_dataset = './data/embedded_data'\n", + "\n", + "if dataset_type == 'emb':\n", + " model = MLP(8, 2)\n", + " model.load_state_dict(torch.load(path_to_model))\n", + " data = VOGEmbeddingDataset(dataset_url=path_to_dataset, split='validation')\n", + " print('Dataset loaded')\n", + "else: \n", + " model = BetaVAE().to('cuda:0')\n", + " model.load_state_dict(torch.load(path_to_model))\n", + " model.eval()\n", + " print('Model loaded')\n", + "\n", + " # load dataset - this will take a while\n", + " data = VOGMaze2dOfflineRLDataset(dataset_url=path_to_dataset) \n", + " print('Dataset loaded')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Visualize Map" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8 8\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZgAAAGdCAYAAAAv9mXmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAX/klEQVR4nO3df2zUhf3H8dfB2UOxPQUptuGAhjbyo4DQMlbA+QPsckEi2cZ0QVZHtqxL+WVj4qp/SPaDY38smcTZrMx0EoIlywQxW8GSSXEx3Uq1kaHBMoi9CawB5a70jyO2n+8f33BZh5R+rn33w6c8H8kn2V3u/LxCmj73uWt7AcdxHAEAMMzGeD0AADA6ERgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGAiONIn7Ovr09mzZ5Wdna1AIDDSpwcADIHjOOru7lZ+fr7GjBn4GmXEA3P27FlFIpGRPi0AYBjF43FNmTJlwMeMeGCys7Ml/f+4nJyckT49AGAIksmkIpFI+nv5QEY8MFdfFsvJySEwAOBTg3mLgzf5AQAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwkVFgXnnlFRUUFGjcuHEqKSnRu+++O9y7AAA+5zowe/fu1ZYtW/TCCy/ogw8+0AMPPKBoNKrOzk6LfQAAnwo4juO4ecLixYu1cOFC1dbWpu+bNWuWVq9erVgsdsPnJ5NJhcNhJRIJPjIZAHzGzfdwV1cwV65cUVtbm8rLy/vdX15ervfee+8rn5NKpZRMJvsdAIDRz1VgLly4oN7eXk2ePLnf/ZMnT9b58+e/8jmxWEzhcDh9RCKRzNcCAHwjozf5A4FAv9uO41xz31U1NTVKJBLpIx6PZ3JKAIDPBN08+J577tHYsWOvuVrp6uq65qrmqlAopFAolPlCAIAvubqCycrKUklJiZqamvrd39TUpCVLlgzrMACAv7m6gpGk6upqrVu3TqWlpSorK1NdXZ06OztVWVlpsQ8A4FOuA/PEE0/o4sWL+tnPfqZz586puLhYf/nLXzRt2jSLfQAAn3L9ezBDxe/BAIB/mf0eDAAAg0VgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMuP7AsVtZUVGH1xMyU/i51wsy1tG42OsJt5Sixh1eT8iMn7/Gi7Z6PcEMVzAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATLgOzNGjR7Vq1Srl5+crEAho//79BrMAAH7nOjA9PT2aP3++Xn75ZYs9AIBRIuj2CdFoVNFo1GILAGAUcR0Yt1KplFKpVPp2Mpm0PiUA4CZg/iZ/LBZTOBxOH5FIxPqUAICbgHlgampqlEgk0kc8Hrc+JQDgJmD+ElkoFFIoFLI+DQDgJsPvwQAATLi+grl8+bJOnTqVvn3mzBm1t7drwoQJmjp16rCOAwD4l+vAHDt2TA8//HD6dnV1tSSpoqJCf/jDH4ZtGADA31wH5qGHHpLjOBZbAACjCO/BAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABOuPw/mVlaoz72ekCG/7vav6I4dXk/ISGGRX79W/Lp7dOMKBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJV4GJxWJatGiRsrOzlZubq9WrV+vkyZNW2wAAPuYqMM3NzaqqqlJLS4uampr05Zdfqry8XD09PVb7AAA+FXTz4IMHD/a7XV9fr9zcXLW1tekb3/jGsA4DAPibq8D8r0QiIUmaMGHCdR+TSqWUSqXSt5PJ5FBOCQDwiYzf5HccR9XV1Vq2bJmKi4uv+7hYLKZwOJw+IpFIpqcEAPhIxoHZsGGDPvzwQ73++usDPq6mpkaJRCJ9xOPxTE8JAPCRjF4i27hxow4cOKCjR49qypQpAz42FAopFAplNA4A4F+uAuM4jjZu3Kh9+/bpyJEjKigosNoFAPA5V4GpqqrSnj179Oabbyo7O1vnz5+XJIXDYd1+++0mAwEA/uTqPZja2lolEgk99NBDysvLSx979+612gcA8CnXL5EBADAY/C0yAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMBJwR/hSxZDKpcDisRCKhnJyckTz1LWtTdLfXEzJX9LnXCzJT6M/dOzZt9XoCbnJuvodzBQMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACZcBaa2tlbz5s1TTk6OcnJyVFZWpsbGRqttAAAfcxWYKVOmaPv27Tp27JiOHTumRx55RI8//rhOnDhhtQ8A4FNBNw9etWpVv9u//OUvVVtbq5aWFs2ZM2dYhwEA/M1VYP5bb2+v/vjHP6qnp0dlZWXXfVwqlVIqlUrfTiaTmZ4SAOAjrt/kP378uO68806FQiFVVlZq3759mj179nUfH4vFFA6H00ckEhnSYACAP7gOzH333af29na1tLToJz/5iSoqKvTRRx9d9/E1NTVKJBLpIx6PD2kwAMAfXL9ElpWVpcLCQklSaWmpWltb9dJLL+l3v/vdVz4+FAopFAoNbSUAwHeG/HswjuP0e48FAADJ5RXM888/r2g0qkgkou7ubjU0NOjIkSM6ePCg1T4AgE+5Csx//vMfrVu3TufOnVM4HNa8efN08OBBPfroo1b7AAA+5Sowr776qtUOAMAow98iAwCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADAhKsPHLvV7Sja7fWEjBQWfu71hIxt2rHJ6wkZ2eHX3Y1PeT0hM4UXvV6QsU1FjV5PMMMVDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmBhSYGKxmAKBgLZs2TJMcwAAo0XGgWltbVVdXZ3mzZs3nHsAAKNERoG5fPmy1q5dq507d+ruu+8e7k0AgFEgo8BUVVVp5cqVWrFixXDvAQCMEkG3T2hoaND777+v1tbWQT0+lUoplUqlbyeTSbenBAD4kKsrmHg8rs2bN2v37t0aN2
7coJ4Ti8UUDofTRyQSyWgoAMBfXAWmra1NXV1dKikpUTAYVDAYVHNzs3bs2KFgMKje3t5rnlNTU6NEIpE+4vH4sI0HANy8XL1Etnz5ch0/frzffT/4wQ80c+ZMPffccxo7duw1zwmFQgqFQkNbCQDwHVeByc7OVnFxcb/7xo8fr4kTJ15zPwDg1sZv8gMATLj+KbL/deTIkWGYAQAYbbiCAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADAxJA/cOxWUqTPvZ6QkWjjJq8nZKxxkz+3FxX69GslutvrCRlp7Fjs9QR8Ba5gAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJhwFZitW7cqEAj0O+69916rbQAAHwu6fcKcOXN0+PDh9O2xY8cO6yAAwOjgOjDBYJCrFgDADbl+D6ajo0P5+fkqKCjQk08+qdOnTw/4+FQqpWQy2e8AAIx+rgKzePFi7dq1S4cOHdLOnTt1/vx5LVmyRBcvXrzuc2KxmMLhcPqIRCJDHg0AuPm5Ckw0GtW3v/1tzZ07VytWrNCf//xnSdJrr7123efU1NQokUikj3g8PrTFAABfcP0ezH8bP3685s6dq46Ojus+JhQKKRQKDeU0AAAfGtLvwaRSKX388cfKy8sbrj0AgFHCVWCeffZZNTc368yZM/r73/+u73znO0omk6qoqLDaBwDwKVcvkf373//W9773PV24cEGTJk3S17/+dbW0tGjatGlW+wAAPuUqMA0NDVY7AACjDH+LDABggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJhw9Xkwt7pCfe71hIx0RDd5PSFjhUX+/Dcv2rTb6wkZ6Whc7PWEjBQW+vPrZLTjCgYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJggMAAAEwQGAGCCwAAATBAYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACdeB+eyzz/TUU09p4sSJuuOOO3T//ferra3NYhsAwMeCbh78xRdfaOnSpXr44YfV2Nio3Nxc/etf/9Jdd91lNA8A4FeuAvOrX/1KkUhE9fX16fumT58+3JsAAKOAq5fIDhw4oNLSUq1Zs0a5ublasGCBdu7cOeBzUqmUkslkvwMAMPq5Cszp06dVW1uroqIiHTp0SJWVldq0aZN27dp13efEYjGFw+H0EYlEhjwaAHDzcxWYvr4+LVy4UNu2bdOCBQv04x//WD/60Y9UW1t73efU1NQokUikj3g8PuTRAICbn6vA5OXlafbs2f3umzVrljo7O6/7nFAopJycnH4HAGD0cxWYpUuX6uTJk/3u++STTzRt2rRhHQUA8D9XgXnmmWfU0tKibdu26dSpU9qzZ4/q6upUVVVltQ8A4FOuArNo0SLt27dPr7/+uoqLi/Xzn/9cv/nNb7R27VqrfQAAn3L1ezCS9Nhjj+mxxx6z2AIAGEX4W2QAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAwQWAAACYIDADABIEBAJgIOI7jjOQJk8mkwuGwEomEcnJyRvLUAIAhcvM9nCsYAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABMEBgAgAkCAwAw4Sow06dPVyAQuOaoqqqy2gcA8Kmgmwe3traqt7c3ffuf//ynHn30Ua1Zs2bYhwEA/M1VYCZNmtTv9vbt2zVjxgw9+OCDwzoKAOB/rgLz365cuaLdu3erurpagUDguo9LpVJKpVLp28lkMtNTAgB8JOM3+ffv369Lly7p6aefHvBxsVhM4XA4fUQikUxPCQDwkYDjOE4mT/zmN7+prKwsvfXWWwM+7quuYCKRiBKJhHJycjI5NQDAI8lkUuFweFDfwzN6iezTTz/V4cOH9cYbb9zwsaFQSKFQKJPTAAB8LKOXyOrr65Wbm6uVK1cO9x4AwCjhOjB9fX2qr69XRUWFgsGMf0YAADDKuQ7M4cOH1dnZqfXr11vsAQCMEq4vQcrLy5XhzwUAAG4h/C0yAIAJAgMAMEFgAAAmCAwAwASBAQCYIDAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYGLEP5Ly6mfJJJPJkT41AGCIrn7vHszngo14YLq7uyVJkUhkpE8NABgm3d3dCofDAz4m4Izwx1P29fXp7Nmzys7OViAQGNb/djKZVCQSUTweV05OzrD+ty2xe2Sxe+T5dTu7r+U4jrq7u5Wfn68xYwZ+l2XEr2DGjBmjKVOmmJ4jJyfHV18MV7F7ZLF75Pl1O7v7u9GVy1W8yQ8AMEFgAAAmRlVgQqGQXnzxRYVCIa+nuMLukcXukefX7ewemhF/kx8AcGsYVVcwAICbB4EBAJggMAAAEwQGAGBi1ATmlVdeUUFBgcaNG6eSkhK9++67Xk+6oaNHj2rVqlXKz89XIBDQ/v37vZ40KLFYTIsWLVJ2drZyc3O1evVqnTx50utZN1RbW6t58+alf/msrKxMjY2NXs9yLRaLKRAIaMuWLV5PGdDWrVsVCAT6Hffee6/Xswbls88+01NPPaWJEyfqjjvu0P3336+2tjavZ93Q9OnTr/k3DwQCqqqq8mTPqAjM3r17tWXLFr3wwgv64IMP9MADDygajaqzs9PraQPq6enR/Pnz9fLLL3s9xZXm5mZVVVWppaVFTU1N+vLLL1VeXq6enh6vpw1oypQp2r59u44dO6Zjx47pkUce0eOPP64TJ054PW3QWltbVVdXp3nz5nk9ZVDmzJmjc+fOpY/jx497PemGvvjiCy1dulS33XabGhsb9dFHH+nXv/617rrrLq+n3VBra2u/f++mpiZJ0po1a7wZ5IwCX/va15zKysp+982cOdP56U9/6tEi9yQ5+/bt83pGRrq6uhxJTnNzs9dTXLv77rud3//+917PGJTu7m6nqKjIaWpqch588EFn8+bNXk8a0IsvvujMnz/f6xmuPffcc86yZcu8njEsNm/e7MyYMcPp6+vz5Py+v4K5cuWK2traVF5e3u/+8vJyvffeex6turUkEglJ0oQJEzxeMni9vb1qaGhQT0+PysrKvJ4zKFVVVVq5cqVWrFjh9ZRB6+joUH5+vgoKCvTkk0/q9OnTXk+6oQMHDqi0tFRr1qxRbm6uFixYoJ07d3o9y7UrV65o9+7dWr9+/bD/YeHB8n1gLly4oN7eXk2ePLnf/ZMnT9b58+c9WnXrcBxH1dXVWrZsmYqLi72ec0PHj
x/XnXfeqVAopMrKSu3bt0+zZ8/2etYNNTQ06P3331csFvN6yqAtXrxYu3bt0qFDh7Rz506dP39eS5Ys0cWLF72eNqDTp0+rtrZWRUVFOnTokCorK7Vp0ybt2rXL62mu7N+/X5cuXdLTTz/t2YYR/2vKVv630I7jeFbtW8mGDRv04Ycf6m9/+5vXUwblvvvuU3t7uy5duqQ//elPqqioUHNz800dmXg8rs2bN+vtt9/WuHHjvJ4zaNFoNP2/586dq7KyMs2YMUOvvfaaqqurPVw2sL6+PpWWlmrbtm2SpAULFujEiROqra3V97//fY/XDd6rr76qaDSq/Px8zzb4/grmnnvu0dixY6+5Wunq6rrmqgbDa+PGjTpw4IDeeecd849gGC5ZWVkqLCxUaWmpYrGY5s+fr5deesnrWQNqa2tTV1eXSkpKFAwGFQwG1dzcrB07digYDKq3t9friYMyfvx4zZ07Vx0dHV5PGVBeXt41/4dj1qxZN/0PDf23Tz/9VIcPH9YPf/hDT3f4PjBZWVkqKSlJ/7TEVU1NTVqyZIlHq0Y3x3G0YcMGvfHGG/rrX/+qgoICrydlzHEcpVIpr2cMaPny5Tp+/Lja29vTR2lpqdauXav29naNHTvW64mDkkql9PHHHysvL8/rKQNaunTpNT92/8knn2jatGkeLXKvvr5eubm5Wrlypac7RsVLZNXV1Vq3bp1KS0tVVlamuro6dXZ2qrKy0utpA7p8+bJOnTqVvn3mzBm1t7drwoQJmjp1qofLBlZVVaU9e/bozTffVHZ2dvrqMRwO6/bbb/d43fU9//zzikajikQi6u7uVkNDg44cOaKDBw96PW1A2dnZ17y/NX78eE2cOPGmft/r2Wef1apVqzR16lR1dXXpF7/4hZLJpCoqKryeNqBnnnlGS5Ys0bZt2/Td735X//jHP1RXV6e6ujqvpw1KX1+f6uvrVVFRoWDQ42/xnvzsmoHf/va3zrRp05ysrCxn4cKFvviR2XfeeceRdM1RUVHh9bQBfdVmSU59fb3X0wa0fv369NfIpEmTnOXLlztvv/2217My4ocfU37iiSecvLw857bbbnPy8/Odb33rW86JEye8njUob731llNcXOyEQiFn5syZTl1dndeTBu3QoUOOJOfkyZNeT3H4c/0AABO+fw8GAHBzIjAAABMEBgBggsAAAEwQGACACQIDADBBYAAAJggMAMAEgQEAmCAwAAATBAYAYILAAABM/B/IT73Jw2IkmQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "maze_type = 'medium'\n", + "\n", + "def get_2d_colors(points, min_point, max_point):\n", + " \"\"\"Get colors corresponding to 2-D points.\"\"\"\n", + " points = np.array(points)\n", + " min_point = np.array(min_point)\n", + " max_point = np.array(max_point)\n", + "\n", + " colors = (points - min_point) / (max_point - min_point)\n", + " colors = np.hstack((colors, (2 - np.sum(colors, axis=1, keepdims=True)) / 2))\n", + " colors = np.clip(colors, 0, 1)\n", + " colors = np.c_[colors, np.full(len(colors), 0.8)]\n", + "\n", + " return colors\n", + "if maze_type == 'medium':\n", + " maze_map = [\n", + " [1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 0, 0, 1, 1, 0, 0, 1],\n", + " [1, 0, 0, 1, 0, 0, 0, 1],\n", + " [1, 1, 0, 0, 0, 1, 1, 1],\n", + " [1, 0, 0, 1, 0, 0, 0, 1],\n", + " [1, 0, 1, 0, 0, 1, 0, 1],\n", + " [1, 0, 0, 0, 1, 0, 0, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1],\n", + " ]\n", + "elif maze_type == 'large':\n", + " maze_map = [\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],\n", + " [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],\n", + " [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],\n", + " [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1],\n", + " [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],\n", + " [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],\n", + " [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " ]\n", + "elif maze_type == 'giant':\n", + " maze_map = [\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1],\n", + " [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1],\n", + " [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],\n", + " [1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1],\n", + " [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1],\n", + " [1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1],\n", + " [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],\n", + " [1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1],\n", + " [1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1],\n", + " [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " ]\n", + "elif maze_type == 'teleport':\n", + " maze_map = [\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1],\n", + " [1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1],\n", + " [1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1],\n", + " [1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1],\n", + " [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],\n", + " [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],\n", + " [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " ]\n", + "\n", + "height, width = len(maze_map), len(maze_map[0])\n", + "print(height, width)\n", + "map = np.zeros((height, width, 3))\n", + "for i in range(height):\n", + " for j in range(width):\n", + " if maze_map[i][j] == 1:\n", + " map[i, j] = [1, 1, 1]\n", + " else:\n", + " map[i, j] = get_2d_colors([[i, j]], [0, 0], [height-1, width -1])[0, :3]\n", + "\n", + "import matplotlib.pyplot as plt\n", + "plt.imshow(map)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def modify_pos(pos):\n", + " pos_modif = pos / 4 + 1\n", + " return pos_modif\n", + "\n", + "def unnormalize_pos(pos):\n", + " # Normalizations\n", + " pos_mean = np.array([10.273524, 9.648321])\n", + " pos_std = np.array([5.627576, 4.897987])\n", + " pos_unnorm = pos * pos_std + pos_mean\n", + " return pos_unnorm" + ] + }, + 
{ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot e2s" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_461393/217026334.py:9: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n", + " pos_unnorm = pos * pos_std + pos_mean\n" + ] + } + ], + "source": [ + "import imageio\n", + "import cv2\n", + "frames = []\n", + "for i in range(0, 5001, 50):\n", + " emb, pos, act = data.__getitem__(i)\n", + " pos_est = modify_pos(unnormalize_pos(model(emb).detach().numpy()))\n", + " pos = modify_pos(unnormalize_pos(pos))\n", + "\n", + " fig, ax = plt.subplots()\n", + " ax.imshow(map)\n", + " ax.scatter(pos_est[0], pos_est[1], c='red', s=20, label='Estimated', alpha=0.8)\n", + " ax.scatter(pos[0], pos[1], c='blue', s=20, label='True', alpha=0.8)\n", + " ax.invert_yaxis()\n", + " ax.legend()\n", + "\n", + " plt.savefig('temp_points.png')\n", + " plt.close(fig)\n", + "\n", + " frame = cv2.imread('temp_points.png')\n", + " frames.append(frame)\n", + "\n", + "imageio.mimsave('embed_2_pos.gif', frames, fps=1)\n", + "cv2.destroyAllWindows()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Plot" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 101/101 [00:00<00:00, 1052.39it/s]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import cv2\n", + "import imageio\n", + "from tqdm import tqdm\n", + "\n", + "# Update font size\n", + "plt.rcParams.update({'font.size': 24})\n", + "\n", + "# Function to calculate L2 distance\n", + "def l2_distance(a, b):\n", + " return np.linalg.norm(a - b)\n", + "\n", + "# Create a list to store frames\n", + "frames = []\n", + "\n", + "# Example: load the base (reference) sample from the full map\n", + "base_idx = 1001\n", + "base_obs, base_pos = data.__getitem__(base_idx)\n", + "base_obs = base_obs.unsqueeze(0) # (1, C, H, W)\n", + "base_emb = model.encode(base_obs.to('cuda:0'))[0].detach().cpu().numpy()\n", + "base_pos_x, base_pos_y = modify_pos(base_pos)\n", + "\n", + "max_latent_l2_dist = 0\n", + "max_pos_l2_dist = 0\n", + "\n", + "# find max l2 distance\n", + "for i in tqdm(range(base_idx, base_idx + 501, 5)):\n", + " obs, pos = data.__getitem__(i)\n", + " obs = obs.unsqueeze(0) # (1, C, H, W)\n", + " pos_l2_dist = l2_distance(pos, base_pos) \n", + " latent = model.encode(obs.to('cuda:0'))[0].detach().cpu().numpy() \n", + " latent_l2_dist = l2_distance(latent, base_emb) \n", + " max_latent_l2_dist = max(max_latent_l2_dist, latent_l2_dist)\n", + " max_pos_l2_dist = max(max_pos_l2_dist, pos_l2_dist)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 101/101 [00:21<00:00, 4.67it/s]\n" + ] + } + ], + "source": [ + "for i in tqdm(range(base_idx, base_idx + 501, 5)):\n", + " # New figure & subplots\n", + " fig, axs = plt.subplots(1, 4, figsize=(24, 9))\n", + "\n", + " # -------------------------------\n", + " # 1) Entire map + start position + current position\n", + " # -------------------------------\n", + " # (Example) Show full_map in the first subplot\n", + " axs[0].imshow(map)\n", + " axs[0].set_title('Entire Map')\n", + "\n", + " # i-th data sample\n", + " obs, pos = data.__getitem__(i)\n", + " \n", + " # Mark the start position and the current position on the map\n", + " # Note: base_pos and pos must use the same (x, y) pixel coordinates as full_map to be displayed correctly.\n", + " pos_x, pos_y = modify_pos(pos)\n", + " axs[0].scatter(base_pos_x, base_pos_y, color='red', s=200, marker='o', label=\"Base Position\")\n", + " axs[0].scatter(pos_x, pos_y, color='blue', s=200, marker='o', label=\"Current Position\")\n", + " axs[0].invert_yaxis()\n", + " # -------------------------------\n", + " # 2) Show the observation image\n", + " # -------------------------------\n", + " obs_unnormalized = (obs * 71.0288272312382 + 141.785487953533) / 255.0\n", + " obs_unnormalized = obs_unnormalized.permute(1, 2, 0).numpy()\n", + " axs[1].imshow(cv2.cvtColor(obs_unnormalized, cv2.COLOR_BGR2RGB))\n", + " axs[1].set_title(f'Observation {i - base_idx}')\n", + "\n", + " # -------------------------------\n", + " # 3) Show the reconstruction image\n", + " # -----------------------------\n", + " with torch.no_grad():\n", + " recon, _, mu, log_var = model(obs.unsqueeze(0).to('cuda:0'))\n", + " recon_normalized = (recon * 71.0288272312382 + 141.785487953533) / 255.0\n", + " recon_normalized = recon_normalized.squeeze(0).permute(1, 2, 0).cpu().numpy()\n", + " axs[2].imshow(cv2.cvtColor(recon_normalized, cv2.COLOR_BGR2RGB))\n", + " axs[2].set_title(f'Reconstruction {i - base_idx}')\n", + "\n", + " # -------------------------------\n", + " # 4) L2 distance bar chart\n", + " # -------------------------------\n", + " obs = obs.unsqueeze(0) # (1, C, H, W)\n", + " pos_l2_dist = l2_distance(pos, base_pos) / max_pos_l2_dist\n", + " latent = model.reparameterize(mu, log_var).detach().cpu().numpy() \n", + " latent_l2_dist = l2_distance(latent, base_emb) / max_latent_l2_dist \n", + " axs[3].bar(['Position', 'Latent'],\n", + " [pos_l2_dist, latent_l2_dist],\n", + " color=['blue', 'red'])\n", + " axs[3].set_ylim(0, 1)\n", + " axs[3].set_title('L2 Distances (Normalized)')\n", + "\n", + " plt.savefig('temp_plot.png')\n", + " plt.close(fig)\n", + "\n", + " frame = cv2.imread('temp_plot.png')\n", + " if 340 < (i - base_idx) < 365: \n", + " cv2.imwrite(f'frame_{i-base_idx}.png', cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n", + " frames.append(frame)\n", + "\n", + "imageio.mimsave(f'{path_to_model}.gif', frames, fps=5)\n", + "cv2.destroyAllWindows()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(np.float32(12.780803), np.float32(20.188135))" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "max_latent_l2_dist, max_pos_l2_dist" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Observations" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/101 [00:00 0 else -1],\n", + " # color=['green', 'yellow', 'red' if pos_diff * obs_diff > 0 else 'blue'])\n", + " # axs[3].set_ylim(-1, 1)\n", + " # axs[3].set_title('L2 Distances (Change)')\n", + " # prev_obs_l2_dist = obs_l2_dist\n", + " # prev_pos_l2_dist = pos_l2_dist\n", + "\n", + " plt.savefig('temp_plot.png')\n", + " plt.close(fig)\n", + "\n", + " frame = cv2.imread('temp_plot.png')\n", + " # if 340 < (i - base_idx) < 365: \n", + " # cv2.imwrite(f'frame_{i-base_idx}.png', cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n", + " frames.append(frame)\n", + "\n", + "imageio.mimsave(f'obs.gif', frames, fps=5)\n", + "cv2.destroyAllWindows()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": 
{}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.42555568, 0.07089323)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "max_obs_diff, max_pos_diff" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "215.95131" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "max_obs_l2_dist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "og_game", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/ogbench/pretrain/visualize_id.ipynb b/ogbench/pretrain/visualize_id.ipynb new file mode 100644 index 00000000..f2f853db --- /dev/null +++ b/ogbench/pretrain/visualize_id.ipynb @@ -0,0 +1,109 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np \n", + "from models.bvae import BetaVAE\n", + "from models.mlp import MLP\n", + "from dataset.vog_maze import VOGMaze2dOfflineRLDataset\n", + "from dataset.vog_maze_emb import VOGEmbeddingDataset\n", + "import sys\n", + "sys.path.append('../data_gen_scripts')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ogbench\n", + "\n", + "env = ogbench.locomaze.maze.make_maze_env(\n", + " loco_env_type='point', maze_env_type='maze',maze_type='giant', ob_type='pixels', render_mode='rgb_array',width=64,height=64, camera_name='back'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from PIL import Image\n", + "\n", + "# Reset the environment and get the initial observation\n", + "observation = env.reset()[0]\n", + "print(observation.shape)\n", + "# Convert the observation array to an image\n", + "image = Image.fromarray(observation)\n", + "# Save the image \n", + "\n", + "image.save('initial_observation2.png')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load models & dataset\n", + "dataset_type = 'emb'\n", + "\n", + "path_to_model = './weights/vae_weight.pth'\n", + "path_to_dataset = './data/original_data'\n", + "\n", + "path_to_e2s_model = './weights/e2s_weight.pth'\n", + "path_to_dataset = './data/embedded_data'\n", + "\n", + "\n", + "e2s_model = MLP(8, 2)\n", + "e2s_model.load_state_dict(torch.load(path_to_e2s_model))\n", + "print('E2S model loaded')\n", + "\n", + "model = BetaVAE().to('cuda:0')\n", + "model.load_state_dict(torch.load(path_to_model))\n", + "model.eval()\n", + "print('Model loaded')\n", + "\n", + "# load dataset - this will take a while\n", + "data = VOGMaze2dOfflineRLDataset(dataset_url=path_to_dataset, split='validation') \n", + "print('Dataset loaded')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "og_game", + "language": "python", + "name": "python3" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/temp.ipynb b/temp.ipynb new file mode 100644 index 00000000..2297f943 --- /dev/null +++ b/temp.ipynb @@ -0,0 +1,140 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# load sample.gif\n", + "\n", + "import sys\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import cv2\n", + "\n", + "def load_gif(filename):\n", + " cap = cv2.VideoCapture(filename)\n", + " if not cap.isOpened():\n", + " print(\"Error: could not open video.\")\n", + " sys.exit(1)\n", + "\n", + " frames = []\n", + " while True:\n", + " ret, frame = cap.read()\n", + " if not ret:\n", + " break\n", + " frames.append(frame)\n", + " cap.release()\n", + "\n", + " return frames\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "gif = load_gif(\"sample.gif\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 89, 102, 135],\n", + " [ 92, 107, 159],\n", + " [ 92, 107, 159],\n", + " ...,\n", + " [123, 161, 196],\n", + " [169, 183, 204],\n", + " [173, 197, 201]],\n", + "\n", + " [[ 70, 74, 86],\n", + " [ 71, 71, 93],\n", + " [ 70, 70, 99],\n", + " ...,\n", + " [ 92, 107, 159],\n", + " [158, 178, 203],\n", + " [173, 197, 201]],\n", + "\n", + " [[ 69, 69, 74],\n", + " [ 68, 68, 72],\n", + " [ 68, 68, 74],\n", + " ...,\n", + " [ 85, 94, 152],\n", + " [158, 178, 203],\n", + " [184, 204, 210]],\n", + "\n", + " ...,\n", + "\n", + " [[207, 209, 210],\n", + " [209, 210, 210],\n", + " [209, 210, 210],\n", + " ...,\n", + " [209, 210, 210],\n", + " [209, 210, 210],\n", + " [208, 210, 210]],\n", + "\n", + " [[205, 209, 210],\n", + " [208, 210, 210],\n", + " [208, 210, 210],\n", + " ...,\n", + " [209, 210, 210],\n", + " [209, 210, 210],\n", + " [206, 210, 210]],\n", + "\n", + " [[198, 206, 208],\n", + " [202, 209, 210],\n", + " [205, 210, 210],\n", + " ...,\n", + " [208, 210, 210],\n", + " [207, 210, 210],\n", + " [203, 210, 210]]], shape=(64, 64, 3), dtype=uint8)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gif[-1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "og_game", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}