Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion internnav/dataset/navdp_dataset_lerobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
image_size=224,
scene_data_scale=1.0,
trajectory_data_scale=1.0,
pixel_channel=7,
debug=False,
preload=False,
random_digit=False,
Expand All @@ -61,6 +62,7 @@ def __init__(
self.trajectory_afford_path = []
self.random_digit = random_digit
self.prior_sample = prior_sample
self.pixel_channel = pixel_channel
self.item_cnt = 0
self.batch_size = batch_size
self.batch_time_sum = 0.0
Expand Down Expand Up @@ -509,7 +511,8 @@ def __getitem__(self, index):
camera_intrinsic,
trajectory_base_extrinsic,
)
pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1)
if self.pixel_channel == 7:
pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1)

pred_actions = (pred_actions[1:] - pred_actions[:-1]) * 4.0
augment_actions = (augment_actions[1:] - augment_actions[:-1]) * 4.0
Expand Down
11 changes: 5 additions & 6 deletions internnav/model/basemodel/navdp/navdp_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
elif pretrained_model_name_or_path is None or len(pretrained_model_name_or_path) == 0:
pass
else:
incompatible_keys, _ = model.load_state_dict(
torch.load(pretrained_model_name_or_path)['state_dict'], strict=False
)
incompatible_keys, _ = model.load_state_dict(torch.load(pretrained_model_name_or_path), strict=False)
if len(incompatible_keys) > 0:
print(f'Incompatible keys: {incompatible_keys}')

Expand All @@ -66,13 +64,12 @@ def __init__(self, config: NavDPModelConfig):
self.model_config = ModelCfg(**config.model_cfg['model'])
else:
self.model_config = config

self.config.model_cfg['il']

self._device = torch.device(f"cuda:{config.model_cfg['local_rank']}")
self.image_size = self.config.model_cfg['il']['image_size']
self.memory_size = self.config.model_cfg['il']['memory_size']
self.predict_size = self.config.model_cfg['il']['predict_size']
self.pixel_channel = self.config.model_cfg['il']['pixel_channel']
self.temporal_depth = self.config.model_cfg['il']['temporal_depth']
self.attention_heads = self.config.model_cfg['il']['heads']
self.input_channels = self.config.model_cfg['il']['channels']
Expand All @@ -83,7 +80,9 @@ def __init__(self, config: NavDPModelConfig):
self.rgbd_encoder = NavDP_RGBD_Backbone(
self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device
)
self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size, self.token_dim, device=self._device)
self.pixel_encoder = NavDP_PixelGoal_Backbone(
self.image_size, self.token_dim, pixel_channel=self.pixel_channel, device=self._device
)
self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size, self.token_dim, device=self._device)
self.point_encoder = nn.Linear(3, self.token_dim)

Expand Down
4 changes: 2 additions & 2 deletions internnav/model/encoder/navdp_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def _get_device(self):


class NavDP_PixelGoal_Backbone(nn.Module):
def __init__(self, image_size=224, embed_size=512, device='cuda:0'):
def __init__(self, image_size=224, embed_size=512, pixel_channel=7, device='cuda:0'):
super().__init__()
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -392,7 +392,7 @@ def __init__(self, image_size=224, embed_size=512, device='cuda:0'):
self.pixelgoal_encoder = DepthAnythingV2(**model_configs['vits'])
self.pixelgoal_encoder = self.pixelgoal_encoder.pretrained.float()
self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d(
in_channels=7,
in_channels=pixel_channel,
out_channels=self.pixelgoal_encoder.patch_embed.proj.out_channels,
kernel_size=self.pixelgoal_encoder.patch_embed.proj.kernel_size,
stride=self.pixelgoal_encoder.patch_embed.proj.stride,
Expand Down
9 changes: 5 additions & 4 deletions scripts/train/configs/navdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
inflection_weight_coef=3.2,
save_interval_epochs=5,
save_filter_frozen_weights=False,
load_from_ckpt=False,
ckpt_to_load='',
load_from_ckpt=True,
ckpt_to_load='/shared/smartbot_new/caiwenzhe/InternNav/checkpoints/cross-waic-final4-125.ckpt',
lmdb_map_size=1e12,
dataset_r2r_root_dir='data/vln_pe/raw_data/r2r',
dataset_3dgs_root_dir='',
Expand All @@ -48,9 +48,10 @@
lerobot_features_dir='data/vln_pe/traj_data/r2r',
camera_name='pano_camera_0',
report_to='tensorboard', # wandb, tensorboard, none
dataset_navdp='data/datasets/navdp_dataset_lerobot.json',
root_dir='data/datasets/InternData-N1/vln_n1/traj_data',
dataset_navdp='./navdp_dataset_lerobot.json',
root_dir='/shared/smartbot_new/liuyu/vln-n1-minival/',
image_size=224,
pixel_channel=4,
scene_scale=1.0,
preload=False,
random_digit=False,
Expand Down
95 changes: 46 additions & 49 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,41 @@

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
import logging
import sys
from datetime import datetime
from pathlib import Path
import torch.distributed as dist

import torch
import torch.distributed as dist
import tyro
from pydantic import BaseModel
from transformers import TrainerCallback, TrainingArguments

from internnav.dataset.cma_lerobot_dataset import CMALerobotDataset, cma_collate_fn
from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
from internnav.dataset.navdp_dataset_lerobot import NavDP_Base_Datset, navdp_collate_fn
from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
from internnav.model import (
CMAModelConfig,
CMANet,
NavDPModelConfig,
NavDPNet,
RDPModelConfig,
RDPNet,
Seq2SeqModelConfig,
Seq2SeqNet,
NavDPNet,
NavDPModelConfig,
)
from internnav.model.utils.logger import MyLogger
from internnav.model.utils.utils import load_dataset
from internnav.trainer import CMATrainer, RDPTrainer, NavDPTrainer
from internnav.trainer import CMATrainer, NavDPTrainer, RDPTrainer
from scripts.train.configs import (
cma_exp_cfg,
cma_plus_exp_cfg,
navdp_exp_cfg,
rdp_exp_cfg,
seq2seq_exp_cfg,
seq2seq_plus_exp_cfg,
navdp_exp_cfg,
)
import sys
from datetime import datetime


class TrainCfg(BaseModel):
"""Training configuration class"""
Expand Down Expand Up @@ -68,29 +69,29 @@ def on_save(self, args, state, control, **kwargs):


def _make_dir(config):
config.tensorboard_dir = config.tensorboard_dir % config.name
config.tensorboard_dir = config.tensorboard_dir % config.name
config.checkpoint_folder = config.checkpoint_folder % config.name
config.log_dir = config.log_dir % config.name
config.output_dir = config.output_dir % config.name
if not os.path.exists(config.tensorboard_dir):
os.makedirs(config.tensorboard_dir,exist_ok=True)
os.makedirs(config.tensorboard_dir, exist_ok=True)
if not os.path.exists(config.checkpoint_folder):
os.makedirs(config.checkpoint_folder,exist_ok=True)
os.makedirs(config.checkpoint_folder, exist_ok=True)
if not os.path.exists(config.log_dir):
os.makedirs(config.log_dir,exist_ok=True)
os.makedirs(config.log_dir, exist_ok=True)


def main(config, model_class, model_config_class):
try:
"""Main training function."""
_make_dir(config)

print(f"=== Start training ===")
print("=== Start training ===")
print(f"Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
print(f"Environment variables:")
print("Environment variables:")
print(f" RANK: {os.getenv('RANK', 'Not set')}")
print(f" LOCAL_RANK: {os.getenv('LOCAL_RANK', 'Not set')}")
print(f" WORLD_SIZE: {os.getenv('WORLD_SIZE', 'Not set')}")
Expand All @@ -101,28 +102,23 @@ def main(config, model_class, model_config_class):
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
rank = int(os.getenv('RANK', '0'))

# Set CUDA device for each process
device_id = local_rank
torch.cuda.set_device(device_id)
device = torch.device(f'cuda:{device_id}')
print(f"World size: {world_size}, Local rank: {local_rank}, Global rank: {rank}")

# Initialize distributed training environment
if world_size > 1:
try:
dist.init_process_group(
backend='nccl',
init_method='env://',
world_size=world_size,
rank=rank
)
dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
print("Distributed initialization SUCCESS")
except Exception as e:
print(f"Distributed initialization FAILED: {str(e)}")
world_size = 1

print("="*50)
print("=" * 50)
print("After distributed init:")
print(f"LOCAL_RANK: {local_rank}")
print(f"WORLD_SIZE: {world_size}")
Expand Down Expand Up @@ -150,46 +146,46 @@ def main(config, model_class, model_config_class):
if buffer.device != device:
print(f"Buffer {name} is on wrong device {buffer.device}, should be moved to {device}")
buffer.data = buffer.data.to(device)

# If distributed training, wrap the model with DDP
if world_size > 1:
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[local_rank],
output_device=local_rank,
find_unused_parameters=True
model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True
)
# ------------ load logger ------------
train_logger_filename = os.path.join(config.log_dir, 'train.log')
if dist.is_initialized() and dist.get_rank() == 0:
train_logger = MyLogger(
name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s', filename=train_logger_filename
name='train',
level=logging.INFO,
format_str='%(asctime)-15s %(message)s',
filename=train_logger_filename,
)
else:
# Other processes use console logging
train_logger = MyLogger(
name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s'
)
train_logger = MyLogger(name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s')
transformers_logger = logging.getLogger("transformers")
if transformers_logger.hasHandlers():
transformers_logger.handlers = []
if config.model_name == "navdp" and local_rank in [0, -1]: # Only main process or non-distributed
transformers_logger.addHandler(train_logger.handlers[0])
transformers_logger.setLevel(logging.INFO)


# ------------ load dataset ------------
if config.model_name == "navdp":
train_dataset_data = NavDP_Base_Datset(config.il.root_dir,
config.il.dataset_navdp,
config.il.memory_size,
config.il.predict_size,
config.il.batch_size,
config.il.image_size,
config.il.scene_scale,
preload = config.il.preload,
random_digit = config.il.random_digit,
prior_sample = config.il.prior_sample)
train_dataset_data = NavDP_Base_Datset(
config.il.root_dir,
config.il.dataset_navdp,
config.il.memory_size,
config.il.predict_size,
config.il.batch_size,
config.il.image_size,
config.il.scene_scale,
pixel_channel=config.il.pixel_channel,
preload=config.il.preload,
random_digit=config.il.random_digit,
prior_sample=config.il.prior_sample,
)
else:
if '3dgs' in config.il.lmdb_features_dir or '3dgs' in config.il.lmdb_features_dir:
dataset_root_dir = config.il.dataset_six_floor_root_dir
Expand Down Expand Up @@ -223,7 +219,7 @@ def main(config, model_class, model_config_class):
config,
config.il.lerobot_features_dir,
dataset_data=train_dataset_data,
batch_size=config.il.batch_size,
batch_size=config.il.batch_size,
)
collate_fn = rdp_collate_fn(global_batch_size=global_batch_size)
elif config.model_name == 'navdp':
Expand All @@ -238,7 +234,7 @@ def main(config, model_class, model_config_class):
remove_unused_columns=False,
deepspeed='',
gradient_checkpointing=False,
bf16=False,#fp16=False,
bf16=False, # fp16=False,
tf32=False,
per_device_train_batch_size=config.il.batch_size,
gradient_accumulation_steps=1,
Expand All @@ -249,7 +245,7 @@ def main(config, model_class, model_config_class):
lr_scheduler_type='cosine',
logging_steps=10.0,
num_train_epochs=config.il.epochs,
save_strategy='epoch',# no
save_strategy='epoch', # no
save_steps=config.il.save_interval_epochs,
save_total_limit=8,
report_to=config.il.report_to,
Expand All @@ -260,7 +256,7 @@ def main(config, model_class, model_config_class):
torch_compile_mode=None,
dataloader_drop_last=True,
disable_tqdm=True,
log_level="info"
log_level="info",
)

# Create the trainer
Expand All @@ -279,14 +275,15 @@ def main(config, model_class, model_config_class):
handler.flush()
except Exception as e:
import traceback

print(f"Unhandled exception: {str(e)}")
print("Stack trace:")
traceback.print_exc()

# If distributed environment, ensure all processes exit
if dist.is_initialized():
dist.destroy_process_group()

raise


Expand Down