scripts/eval/configs/habitat_r2r_pix.yaml (2 changes: 1 addition & 1 deletion)
@@ -79,5 +79,5 @@ habitat:
 dataset:
   type: R2RVLN-v1
   split: val_seen
-  scenes_dir: data/scene_datasets/
+  scenes_dir: data/scene_data/mp3d_ce
   data_path: data/datasets/vln/mp3d/r2r/v1/{split}/{split}.json.gz
scripts/eval/configs/vln_r2r.yaml (6 changes: 3 additions & 3 deletions)
@@ -69,9 +69,9 @@ habitat:
 look_down:
   type: LookDownAction
   agent_index: 0

 dataset:
   type: R2RVLN-v1
   split: val_seen
-  scenes_dir: data/scene_data/
-  data_path: data/vln_ce/raw_data/r2r/{split}/{split}.json.gz
+  scenes_dir: data/scene_data/mp3d_ce
+  data_path: data/vln_ce/raw_data/r2r/{split}/{split}.json.gz
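Both eval configs above now resolve MP3D scenes from data/scene_data/mp3d_ce instead of the old scene directories. A minimal pre-flight sketch for checking that layout before launching evaluation, assuming paths relative to the repository root; the val_seen file name is illustrative, not part of this change.

from pathlib import Path

# Paths taken from the updated configs; swap the split name if you evaluate a different one.
scenes_dir = Path('data/scene_data/mp3d_ce')
r2r_annotations = Path('data/vln_ce/raw_data/r2r/val_seen/val_seen.json.gz')

missing = [p for p in (scenes_dir, r2r_annotations) if not p.exists()]
if missing:
    raise FileNotFoundError(f'Missing dataset paths: {missing}')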
scripts/train/train.py (115 changes: 52 additions & 63 deletions)
@@ -1,42 +1,34 @@
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
sys.path.append('./src/diffusion-policy')
import logging
import sys
from datetime import datetime
from pathlib import Path
import torch.distributed as dist

import torch
import torch.distributed as dist
import tyro
from pydantic import BaseModel
from transformers import TrainerCallback, TrainingArguments

from internnav.dataset.cma_lerobot_dataset import CMALerobotDataset, cma_collate_fn
from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
from internnav.dataset.navdp_dataset_lerobot import NavDP_Base_Datset, navdp_collate_fn
from internnav.model import (
CMAModelConfig,
CMANet,
RDPModelConfig,
RDPNet,
Seq2SeqModelConfig,
Seq2SeqNet,
NavDPNet,
NavDPModelConfig,
)
from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
from internnav.model import get_config, get_policy
from internnav.model.utils.logger import MyLogger
from internnav.model.utils.utils import load_dataset
from internnav.trainer import CMATrainer, RDPTrainer, NavDPTrainer
from internnav.trainer import CMATrainer, NavDPTrainer, RDPTrainer
from scripts.train.configs import (
cma_exp_cfg,
cma_plus_exp_cfg,
navdp_exp_cfg,
rdp_exp_cfg,
seq2seq_exp_cfg,
seq2seq_plus_exp_cfg,
navdp_exp_cfg,
)
import sys
from datetime import datetime


class TrainCfg(BaseModel):
"""Training configuration class"""
@@ -68,16 +60,16 @@ def on_save(self, args, state, control, **kwargs):


def _make_dir(config):
config.tensorboard_dir = config.tensorboard_dir % config.name
config.tensorboard_dir = config.tensorboard_dir % config.name
config.checkpoint_folder = config.checkpoint_folder % config.name
config.log_dir = config.log_dir % config.name
config.output_dir = config.output_dir % config.name
if not os.path.exists(config.tensorboard_dir):
os.makedirs(config.tensorboard_dir,exist_ok=True)
os.makedirs(config.tensorboard_dir, exist_ok=True)
if not os.path.exists(config.checkpoint_folder):
os.makedirs(config.checkpoint_folder,exist_ok=True)
os.makedirs(config.checkpoint_folder, exist_ok=True)
if not os.path.exists(config.log_dir):
os.makedirs(config.log_dir,exist_ok=True)
os.makedirs(config.log_dir, exist_ok=True)


def main(config, model_class, model_config_class):
@@ -101,28 +93,23 @@ def main(config, model_class, model_config_class):
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
rank = int(os.getenv('RANK', '0'))

# Set CUDA device for each process
device_id = local_rank
torch.cuda.set_device(device_id)
device = torch.device(f'cuda:{device_id}')
print(f"World size: {world_size}, Local rank: {local_rank}, Global rank: {rank}")

# Initialize distributed training environment
if world_size > 1:
try:
dist.init_process_group(
backend='nccl',
init_method='env://',
world_size=world_size,
rank=rank
)
dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
print("Distributed initialization SUCCESS")
except Exception as e:
print(f"Distributed initialization FAILED: {str(e)}")
world_size = 1

print("="*50)
print("=" * 50)
print("After distributed init:")
print(f"LOCAL_RANK: {local_rank}")
print(f"WORLD_SIZE: {world_size}")
@@ -150,46 +137,45 @@ def main(config, model_class, model_config_class):
if buffer.device != device:
print(f"Buffer {name} is on wrong device {buffer.device}, should be moved to {device}")
buffer.data = buffer.data.to(device)

# If distributed training, wrap the model with DDP
if world_size > 1:
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[local_rank],
output_device=local_rank,
find_unused_parameters=True
model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True
)
# ------------ load logger ------------
train_logger_filename = os.path.join(config.log_dir, 'train.log')
if dist.is_initialized() and dist.get_rank() == 0:
train_logger = MyLogger(
name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s', filename=train_logger_filename
name='train',
level=logging.INFO,
format_str='%(asctime)-15s %(message)s',
filename=train_logger_filename,
)
else:
# Other processes use console logging
train_logger = MyLogger(
name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s'
)
train_logger = MyLogger(name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s')
transformers_logger = logging.getLogger("transformers")
if transformers_logger.hasHandlers():
transformers_logger.handlers = []
if config.model_name == "navdp" and local_rank in [0, -1]: # Only main process or non-distributed
transformers_logger.addHandler(train_logger.handlers[0])
transformers_logger.setLevel(logging.INFO)


# ------------ load dataset ------------
if config.model_name == "navdp":
train_dataset_data = NavDP_Base_Datset(config.il.root_dir,
config.il.dataset_navdp,
config.il.memory_size,
config.il.predict_size,
config.il.batch_size,
config.il.image_size,
config.il.scene_scale,
preload = config.il.preload,
random_digit = config.il.random_digit,
prior_sample = config.il.prior_sample)
train_dataset_data = NavDP_Base_Datset(
config.il.root_dir,
config.il.dataset_navdp,
config.il.memory_size,
config.il.predict_size,
config.il.batch_size,
config.il.image_size,
config.il.scene_scale,
preload=config.il.preload,
random_digit=config.il.random_digit,
prior_sample=config.il.prior_sample,
)
else:
if '3dgs' in config.il.lmdb_features_dir or '3dgs' in config.il.lmdb_features_dir:
dataset_root_dir = config.il.dataset_six_floor_root_dir
@@ -223,7 +209,7 @@ def main(config, model_class, model_config_class):
config,
config.il.lerobot_features_dir,
dataset_data=train_dataset_data,
batch_size=config.il.batch_size,
batch_size=config.il.batch_size,
)
collate_fn = rdp_collate_fn(global_batch_size=global_batch_size)
elif config.model_name == 'navdp':
@@ -238,7 +224,7 @@ def main(config, model_class, model_config_class):
remove_unused_columns=False,
deepspeed='',
gradient_checkpointing=False,
bf16=False,#fp16=False,
bf16=False, # fp16=False,
tf32=False,
per_device_train_batch_size=config.il.batch_size,
gradient_accumulation_steps=1,
@@ -249,7 +235,7 @@ def main(config, model_class, model_config_class):
lr_scheduler_type='cosine',
logging_steps=10.0,
num_train_epochs=config.il.epochs,
save_strategy='epoch',# no
save_strategy='epoch', # no
save_steps=config.il.save_interval_epochs,
save_total_limit=8,
report_to=config.il.report_to,
@@ -260,7 +246,7 @@ def main(config, model_class, model_config_class):
torch_compile_mode=None,
dataloader_drop_last=True,
disable_tqdm=True,
log_level="info"
log_level="info",
)

# Create the trainer
@@ -279,14 +265,15 @@ def main(config, model_class, model_config_class):
handler.flush()
except Exception as e:
import traceback

print(f"Unhandled exception: {str(e)}")
print("Stack trace:")
traceback.print_exc()

# If distributed environment, ensure all processes exit
if dist.is_initialized():
dist.destroy_process_group()

raise


@@ -304,18 +291,20 @@ def main(config, model_class, model_config_class):

# Select configuration based on model_name
supported_cfg = {
'seq2seq': [seq2seq_exp_cfg, Seq2SeqNet, Seq2SeqModelConfig],
'seq2seq_plus': [seq2seq_plus_exp_cfg, Seq2SeqNet, Seq2SeqModelConfig],
'cma': [cma_exp_cfg, CMANet, CMAModelConfig],
'cma_plus': [cma_plus_exp_cfg, CMANet, CMAModelConfig],
'rdp': [rdp_exp_cfg, RDPNet, RDPModelConfig],
'navdp': [navdp_exp_cfg, NavDPNet, NavDPModelConfig],
'seq2seq': [seq2seq_exp_cfg, "Seq2Seq_Policy"],
'seq2seq_plus': [seq2seq_plus_exp_cfg, 'Seq2Seq_Policy'],
'cma': [cma_exp_cfg, "CMA_Policy"],
'cma_plus': [cma_plus_exp_cfg, "CMA_Policy"],
'rdp': [rdp_exp_cfg, "RDP_Policy"],
'navdp': [navdp_exp_cfg, "NavDP_Policy"],
}

if config.model_name not in supported_cfg:
raise ValueError(f'Invalid model name: {config.model_name}. Supported models are: {list(supported_cfg.keys())}')

exp_cfg, model_class, model_config_class = supported_cfg[config.model_name]
exp_cfg, policy_name = supported_cfg[config.model_name]
model_class, model_config_class = get_policy(policy_name), get_config(policy_name)

exp_cfg.name = config.name
exp_cfg.num_gpus = len(exp_cfg.torch_gpu_ids)
exp_cfg.world_size = exp_cfg.num_gpus
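Model selection now goes through get_policy / get_config from internnav.model, so supported_cfg only carries policy-name strings instead of importing every Net and ModelConfig class. A hypothetical sketch of the name-to-class registry such an API implies, assuming a plain dict keyed by policy name; the actual internnav implementation may differ.

from dataclasses import dataclass


@dataclass
class _PolicyEntry:
    policy_cls: type  # e.g. CMANet
    config_cls: type  # e.g. CMAModelConfig


_REGISTRY: dict = {}


def register_policy(name: str, policy_cls: type, config_cls: type) -> None:
    _REGISTRY[name] = _PolicyEntry(policy_cls, config_cls)


def get_policy(name: str) -> type:
    return _REGISTRY[name].policy_cls


def get_config(name: str) -> type:
    return _REGISTRY[name].config_cls


# Usage mirroring the updated train.py:
#   exp_cfg, policy_name = supported_cfg[config.model_name]
#   model_class, model_config_class = get_policy(policy_name), get_config(policy_name)

With the name indirection, registering a new policy once is enough; the training entry point does not need another import or branch.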