diff --git a/scripts/eval/configs/habitat_r2r_pix.yaml b/scripts/eval/configs/habitat_r2r_pix.yaml
index b0cadaea..454cdef9 100644
--- a/scripts/eval/configs/habitat_r2r_pix.yaml
+++ b/scripts/eval/configs/habitat_r2r_pix.yaml
@@ -79,5 +79,5 @@ habitat:
   dataset:
     type: R2RVLN-v1
     split: val_seen
-    scenes_dir: data/scene_datasets/
+    scenes_dir: data/scene_data/mp3d_ce
     data_path: data/datasets/vln/mp3d/r2r/v1/{split}/{split}.json.gz
diff --git a/scripts/eval/configs/vln_r2r.yaml b/scripts/eval/configs/vln_r2r.yaml
index 379f6d7a..ed8361c8 100644
--- a/scripts/eval/configs/vln_r2r.yaml
+++ b/scripts/eval/configs/vln_r2r.yaml
@@ -69,9 +69,9 @@ habitat:
       look_down:
         type: LookDownAction
         agent_index: 0
-
+
   dataset:
     type: R2RVLN-v1
     split: val_seen
-    scenes_dir: data/scene_data/
-    data_path: data/vln_ce/raw_data/r2r/{split}/{split}.json.gz
\ No newline at end of file
+    scenes_dir: data/scene_data/mp3d_ce
+    data_path: data/vln_ce/raw_data/r2r/{split}/{split}.json.gz
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 060af53f..becab271 100755
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -1,42 +1,34 @@
 import os
 import sys
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+sys.path.append('./src/diffusion-policy')
 import logging
+import sys
+from datetime import datetime
 from pathlib import Path
-import torch.distributed as dist
 import torch
+import torch.distributed as dist
 import tyro
 from pydantic import BaseModel
 from transformers import TrainerCallback, TrainingArguments
 
 from internnav.dataset.cma_lerobot_dataset import CMALerobotDataset, cma_collate_fn
-from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
 from internnav.dataset.navdp_dataset_lerobot import NavDP_Base_Datset, navdp_collate_fn
-from internnav.model import (
-    CMAModelConfig,
-    CMANet,
-    RDPModelConfig,
-    RDPNet,
-    Seq2SeqModelConfig,
-    Seq2SeqNet,
-    NavDPNet,
-    NavDPModelConfig,
-)
+from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
+from internnav.model import get_config, get_policy
 from internnav.model.utils.logger import MyLogger
 from internnav.model.utils.utils import load_dataset
-from internnav.trainer import CMATrainer, RDPTrainer, NavDPTrainer
+from internnav.trainer import CMATrainer, NavDPTrainer, RDPTrainer
 from scripts.train.configs import (
     cma_exp_cfg,
     cma_plus_exp_cfg,
+    navdp_exp_cfg,
     rdp_exp_cfg,
     seq2seq_exp_cfg,
     seq2seq_plus_exp_cfg,
-    navdp_exp_cfg,
 )
-import sys
-from datetime import datetime
+
 
 class TrainCfg(BaseModel):
     """Training configuration class"""
@@ -68,16 +60,16 @@ def on_save(self, args, state, control, **kwargs):
 
 
 def _make_dir(config):
-    config.tensorboard_dir = config.tensorboard_dir % config.name
+    config.tensorboard_dir = config.tensorboard_dir % config.name
     config.checkpoint_folder = config.checkpoint_folder % config.name
     config.log_dir = config.log_dir % config.name
     config.output_dir = config.output_dir % config.name
     if not os.path.exists(config.tensorboard_dir):
-        os.makedirs(config.tensorboard_dir,exist_ok=True)
+        os.makedirs(config.tensorboard_dir, exist_ok=True)
     if not os.path.exists(config.checkpoint_folder):
-        os.makedirs(config.checkpoint_folder,exist_ok=True)
+        os.makedirs(config.checkpoint_folder, exist_ok=True)
     if not os.path.exists(config.log_dir):
-        os.makedirs(config.log_dir,exist_ok=True)
+        os.makedirs(config.log_dir, exist_ok=True)
 
 
 def main(config, model_class, model_config_class):
@@ -101,28 +93,23 @@ def main(config, model_class, model_config_class):
     local_rank = int(os.getenv('LOCAL_RANK', '0'))
     world_size = int(os.getenv('WORLD_SIZE', '1'))
     rank = int(os.getenv('RANK', '0'))
-
+
     # Set CUDA device for each process
    device_id = local_rank
     torch.cuda.set_device(device_id)
     device = torch.device(f'cuda:{device_id}')
     print(f"World size: {world_size}, Local rank: {local_rank}, Global rank: {rank}")
-
+
     # Initialize distributed training environment
     if world_size > 1:
         try:
-            dist.init_process_group(
-                backend='nccl',
-                init_method='env://',
-                world_size=world_size,
-                rank=rank
-            )
+            dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
             print("Distributed initialization SUCCESS")
         except Exception as e:
             print(f"Distributed initialization FAILED: {str(e)}")
             world_size = 1
 
-    print("="*50)
+    print("=" * 50)
     print("After distributed init:")
     print(f"LOCAL_RANK: {local_rank}")
     print(f"WORLD_SIZE: {world_size}")
@@ -150,26 +137,24 @@ def main(config, model_class, model_config_class):
             if buffer.device != device:
                 print(f"Buffer {name} is on wrong device {buffer.device}, should be moved to {device}")
                 buffer.data = buffer.data.to(device)
-
+
     # If distributed training, wrap the model with DDP
     if world_size > 1:
         model = torch.nn.parallel.DistributedDataParallel(
-            model,
-            device_ids=[local_rank],
-            output_device=local_rank,
-            find_unused_parameters=True
+            model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True
         )
 
     # ------------ load logger ------------
     train_logger_filename = os.path.join(config.log_dir, 'train.log')
     if dist.is_initialized() and dist.get_rank() == 0:
         train_logger = MyLogger(
-            name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s', filename=train_logger_filename
+            name='train',
+            level=logging.INFO,
+            format_str='%(asctime)-15s %(message)s',
+            filename=train_logger_filename,
         )
     else:
         # Other processes use console logging
-        train_logger = MyLogger(
-            name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s'
-        )
+        train_logger = MyLogger(name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s')
     transformers_logger = logging.getLogger("transformers")
     if transformers_logger.hasHandlers():
         transformers_logger.handlers = []
@@ -177,19 +162,20 @@ def main(config, model_class, model_config_class):
         transformers_logger.addHandler(train_logger.handlers[0])
     transformers_logger.setLevel(logging.INFO)
 
-
     # ------------ load dataset ------------
     if config.model_name == "navdp":
-        train_dataset_data = NavDP_Base_Datset(config.il.root_dir,
-                                               config.il.dataset_navdp,
-                                               config.il.memory_size,
-                                               config.il.predict_size,
-                                               config.il.batch_size,
-                                               config.il.image_size,
-                                               config.il.scene_scale,
-                                               preload = config.il.preload,
-                                               random_digit = config.il.random_digit,
-                                               prior_sample = config.il.prior_sample)
+        train_dataset_data = NavDP_Base_Datset(
+            config.il.root_dir,
+            config.il.dataset_navdp,
+            config.il.memory_size,
+            config.il.predict_size,
+            config.il.batch_size,
+            config.il.image_size,
+            config.il.scene_scale,
+            preload=config.il.preload,
+            random_digit=config.il.random_digit,
+            prior_sample=config.il.prior_sample,
+        )
     else:
         if '3dgs' in config.il.lmdb_features_dir or '3dgs' in config.il.lmdb_features_dir:
             dataset_root_dir = config.il.dataset_six_floor_root_dir
@@ -223,7 +209,7 @@ def main(config, model_class, model_config_class):
             config,
             config.il.lerobot_features_dir,
             dataset_data=train_dataset_data,
-            batch_size=config.il.batch_size, 
+            batch_size=config.il.batch_size,
         )
         collate_fn = rdp_collate_fn(global_batch_size=global_batch_size)
     elif config.model_name == 'navdp':
@@ -238,7 +224,7 @@ def main(config, model_class, model_config_class):
         remove_unused_columns=False,
         deepspeed='',
         gradient_checkpointing=False,
-        bf16=False,#fp16=False,
+        bf16=False,  # fp16=False,
         tf32=False,
         per_device_train_batch_size=config.il.batch_size,
         gradient_accumulation_steps=1,
@@ -249,7 +235,7 @@ def main(config, model_class, model_config_class):
         lr_scheduler_type='cosine',
         logging_steps=10.0,
         num_train_epochs=config.il.epochs,
-        save_strategy='epoch',# no
+        save_strategy='epoch',  # no
         save_steps=config.il.save_interval_epochs,
         save_total_limit=8,
         report_to=config.il.report_to,
@@ -260,7 +246,7 @@ def main(config, model_class, model_config_class):
         torch_compile_mode=None,
         dataloader_drop_last=True,
         disable_tqdm=True,
-        log_level="info"
+        log_level="info",
     )
 
     # Create the trainer
@@ -279,14 +265,15 @@ def main(config, model_class, model_config_class):
                 handler.flush()
     except Exception as e:
         import traceback
+
         print(f"Unhandled exception: {str(e)}")
         print("Stack trace:")
         traceback.print_exc()
-
+
         # If distributed environment, ensure all processes exit
         if dist.is_initialized():
             dist.destroy_process_group()
-
+
         raise
 
 
@@ -304,18 +291,20 @@ def main(config, model_class, model_config_class):
 
     # Select configuration based on model_name
     supported_cfg = {
-        'seq2seq': [seq2seq_exp_cfg, Seq2SeqNet, Seq2SeqModelConfig],
-        'seq2seq_plus': [seq2seq_plus_exp_cfg, Seq2SeqNet, Seq2SeqModelConfig],
-        'cma': [cma_exp_cfg, CMANet, CMAModelConfig],
-        'cma_plus': [cma_plus_exp_cfg, CMANet, CMAModelConfig],
-        'rdp': [rdp_exp_cfg, RDPNet, RDPModelConfig],
-        'navdp': [navdp_exp_cfg, NavDPNet, NavDPModelConfig],
+        'seq2seq': [seq2seq_exp_cfg, "Seq2Seq_Policy"],
+        'seq2seq_plus': [seq2seq_plus_exp_cfg, 'Seq2Seq_Policy'],
+        'cma': [cma_exp_cfg, "CMA_Policy"],
+        'cma_plus': [cma_plus_exp_cfg, "CMA_Policy"],
+        'rdp': [rdp_exp_cfg, "RDP_Policy"],
+        'navdp': [navdp_exp_cfg, "NavDP_Policy"],
     }
     if config.model_name not in supported_cfg:
         raise ValueError(f'Invalid model name: {config.model_name}. Supported models are: {list(supported_cfg.keys())}')
 
-    exp_cfg, model_class, model_config_class = supported_cfg[config.model_name]
+    exp_cfg, policy_name = supported_cfg[config.model_name]
+    model_class, model_config_class = get_policy(policy_name), get_config(policy_name)
+
     exp_cfg.name = config.name
     exp_cfg.num_gpus = len(exp_cfg.torch_gpu_ids)
     exp_cfg.world_size = exp_cfg.num_gpus
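
Note on the `get_policy` / `get_config` refactor: `supported_cfg` now maps model names to policy-name strings, and the train script resolves the actual model and config classes at runtime instead of importing them all up front. A minimal sketch of what such a name-based registry could look like, assuming a plain dict lookup (illustrative only; the real implementation lives in `internnav/model`, and the class names below are simply the ones this diff stops importing in train.py):

    # Hypothetical registry sketch -- the actual internnav.model code may differ.
    # The classes referenced here (Seq2SeqNet, CMANet, RDPNet, NavDPNet and their
    # config counterparts) would be imported from their defining submodules.
    _POLICIES = {
        "Seq2Seq_Policy": (Seq2SeqNet, Seq2SeqModelConfig),
        "CMA_Policy": (CMANet, CMAModelConfig),
        "RDP_Policy": (RDPNet, RDPModelConfig),
        "NavDP_Policy": (NavDPNet, NavDPModelConfig),
    }

    def get_policy(policy_name: str):
        # Return the model class registered under policy_name.
        return _POLICIES[policy_name][0]

    def get_config(policy_name: str):
        # Return the matching model-config class.
        return _POLICIES[policy_name][1]

A registry like this keeps train.py free of per-model imports, so adding a new policy only touches the package that defines it.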
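Launch note: the script reads `LOCAL_RANK`, `WORLD_SIZE`, and `RANK` from the environment, which are the variables `torchrun` exports to each worker, so a multi-GPU run would typically be started with something like `torchrun --nproc_per_node=4 scripts/train/train.py --name my_run --model-name cma` (the flag names are illustrative, inferred from the tyro-parsed `TrainCfg`). Single-process runs fall back to the `'0'`/`'1'` defaults in `os.getenv`, so the same entry point works without a launcher.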