diff --git a/internnav/dataset/navdp_dataset_lerobot.py b/internnav/dataset/navdp_dataset_lerobot.py
index 4437e7c4..f5b3a2ff 100644
--- a/internnav/dataset/navdp_dataset_lerobot.py
+++ b/internnav/dataset/navdp_dataset_lerobot.py
@@ -41,6 +41,7 @@ def __init__(
         image_size=224,
         scene_data_scale=1.0,
         trajectory_data_scale=1.0,
+        pixel_channel=7,
         debug=False,
         preload=False,
         random_digit=False,
@@ -61,6 +62,7 @@ def __init__(
         self.trajectory_afford_path = []
         self.random_digit = random_digit
         self.prior_sample = prior_sample
+        self.pixel_channel = pixel_channel
         self.item_cnt = 0
         self.batch_size = batch_size
         self.batch_time_sum = 0.0
@@ -509,7 +511,8 @@ def __getitem__(self, index):
             camera_intrinsic,
             trajectory_base_extrinsic,
         )
-        pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1)
+        if self.pixel_channel == 7:
+            pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1)
         pred_actions = (pred_actions[1:] - pred_actions[:-1]) * 4.0
         augment_actions = (augment_actions[1:] - augment_actions[:-1]) * 4.0
 
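Note on the dataset change above: with the default `pixel_channel=7`, `__getitem__` stacks the pixel-goal encoding with the last RGB memory frame along the channel axis; any other value skips the concat and leaves `pixel_goal` as produced. A minimal shape sketch of that guard (the 4-channel layout of `pixel_goal` before the concat is an assumption, not something the diff states):

```python
# Shape sketch of the guarded concat added to __getitem__.
# Assumption: pixel_goal carries 4 channels before the RGB frame is appended.
import numpy as np

H = W = 224
pixel_goal = np.zeros((H, W, 4), dtype=np.float32)  # goal encoding (assumed 4 channels)
last_rgb = np.zeros((H, W, 3), dtype=np.float32)    # stand-in for memory_images[-1]

pixel_channel = 7
if pixel_channel == 7:
    pixel_goal = np.concatenate((pixel_goal, last_rgb), axis=-1)
assert pixel_goal.shape[-1] == pixel_channel  # dataset output must match the model input
```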
diff --git a/internnav/model/basemodel/navdp/navdp_policy.py b/internnav/model/basemodel/navdp/navdp_policy.py
index 6a8da420..1d16017e 100644
--- a/internnav/model/basemodel/navdp/navdp_policy.py
+++ b/internnav/model/basemodel/navdp/navdp_policy.py
@@ -51,9 +51,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         elif pretrained_model_name_or_path is None or len(pretrained_model_name_or_path) == 0:
             pass
         else:
-            incompatible_keys, _ = model.load_state_dict(
-                torch.load(pretrained_model_name_or_path)['state_dict'], strict=False
-            )
+            incompatible_keys, _ = model.load_state_dict(torch.load(pretrained_model_name_or_path), strict=False)
             if len(incompatible_keys) > 0:
                 print(f'Incompatible keys: {incompatible_keys}')
 
@@ -66,13 +64,12 @@ def __init__(self, config: NavDPModelConfig):
             self.model_config = ModelCfg(**config.model_cfg['model'])
         else:
             self.model_config = config
 
-        self.config.model_cfg['il']
-
         self._device = torch.device(f"cuda:{config.model_cfg['local_rank']}")
         self.image_size = self.config.model_cfg['il']['image_size']
         self.memory_size = self.config.model_cfg['il']['memory_size']
         self.predict_size = self.config.model_cfg['il']['predict_size']
+        self.pixel_channel = self.config.model_cfg['il']['pixel_channel']
         self.temporal_depth = self.config.model_cfg['il']['temporal_depth']
         self.attention_heads = self.config.model_cfg['il']['heads']
         self.input_channels = self.config.model_cfg['il']['channels']
@@ -83,14 +80,16 @@ def __init__(self, config: NavDPModelConfig):
         self.rgbd_encoder = NavDP_RGBD_Backbone(
             self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device
         )
-        self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size, self.token_dim, device=self._device)
+        self.pixel_encoder = NavDP_PixelGoal_Backbone(
+            self.image_size, self.token_dim, pixel_channel=self.pixel_channel, device=self._device
+        )
         self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size, self.token_dim, device=self._device)
         self.point_encoder = nn.Linear(3, self.token_dim)
 
         if not self.finetune:
-            for p in self.rgbd_encoder.parameters():
+            for p in self.rgbd_encoder.rgb_model.parameters():
                 p.requires_grad = False
-            self.rgbd_encoder.eval()
+            self.rgbd_encoder.rgb_model.eval()
 
         decoder_layer = nn.TransformerDecoderLayer(
             d_model=self.token_dim,
diff --git a/internnav/model/encoder/navdp_backbone.py b/internnav/model/encoder/navdp_backbone.py
index cd2e8794..680c67b9 100644
--- a/internnav/model/encoder/navdp_backbone.py
+++ b/internnav/model/encoder/navdp_backbone.py
@@ -377,7 +377,7 @@ def _get_device(self):
 
 
 class NavDP_PixelGoal_Backbone(nn.Module):
-    def __init__(self, image_size=224, embed_size=512, device='cuda:0'):
+    def __init__(self, image_size=224, embed_size=512, pixel_channel=7, device='cuda:0'):
         super().__init__()
         if device is None:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -392,7 +392,7 @@ def __init__(self, image_size=224, embed_size=512, pixel_channel=7, device='cuda:0'):
         self.pixelgoal_encoder = DepthAnythingV2(**model_configs['vits'])
         self.pixelgoal_encoder = self.pixelgoal_encoder.pretrained.float()
         self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d(
-            in_channels=7,
+            in_channels=pixel_channel,
             out_channels=self.pixelgoal_encoder.patch_embed.proj.out_channels,
             kernel_size=self.pixelgoal_encoder.patch_embed.proj.kernel_size,
             stride=self.pixelgoal_encoder.patch_embed.proj.stride,
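Note on the backbone change above: the patch-embed projection of the DepthAnythingV2 ViT is rebuilt with `in_channels=pixel_channel`, so its weights come back randomly initialized. A sketch of the same surgery on a plain `nn.Conv2d`; copying the pretrained RGB filters into the first three channels is an optional refinement the diff does not perform:

```python
# Rebuild a pretrained patch-embed conv for `pixel_channel` inputs.
import torch
import torch.nn as nn

def widen_patch_embed(old_proj: nn.Conv2d, pixel_channel: int) -> nn.Conv2d:
    new_proj = nn.Conv2d(
        in_channels=pixel_channel,
        out_channels=old_proj.out_channels,
        kernel_size=old_proj.kernel_size,
        stride=old_proj.stride,
        padding=old_proj.padding,
        bias=old_proj.bias is not None,
    )
    with torch.no_grad():
        new_proj.weight[:, :3] = old_proj.weight  # reuse RGB filters (assumed layout: RGB first)
        if old_proj.bias is not None:
            new_proj.bias.copy_(old_proj.bias)
    return new_proj

proj = widen_patch_embed(nn.Conv2d(3, 384, kernel_size=14, stride=14), pixel_channel=7)
print(proj.weight.shape)  # torch.Size([384, 7, 14, 14])
```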
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 060af53f..d6745a9b 100755
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -3,40 +3,38 @@
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
 
 import logging
+import sys
+from datetime import datetime
 from pathlib import Path
 
-import torch.distributed as dist
 import torch
+import torch.distributed as dist
 import tyro
 from pydantic import BaseModel
 from transformers import TrainerCallback, TrainingArguments
 
 from internnav.dataset.cma_lerobot_dataset import CMALerobotDataset, cma_collate_fn
-from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
 from internnav.dataset.navdp_dataset_lerobot import NavDP_Base_Datset, navdp_collate_fn
-from internnav.model import (
-    CMAModelConfig,
-    CMANet,
-    RDPModelConfig,
-    RDPNet,
+from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn
+from internnav.model.basemodel.cma.cma_policy import CMAModelConfig, CMANet
+from internnav.model.basemodel.navdp.navdp_policy import NavDPModelConfig, NavDPNet
+from internnav.model.basemodel.rdp.rdp_policy import RDPModelConfig, RDPNet
+from internnav.model.basemodel.seq2seq.seq2seq_policy import (
     Seq2SeqModelConfig,
     Seq2SeqNet,
-    NavDPNet,
-    NavDPModelConfig,
 )
 from internnav.model.utils.logger import MyLogger
 from internnav.model.utils.utils import load_dataset
-from internnav.trainer import CMATrainer, RDPTrainer, NavDPTrainer
+from internnav.trainer import CMATrainer, NavDPTrainer, RDPTrainer
 from scripts.train.configs import (
     cma_exp_cfg,
     cma_plus_exp_cfg,
+    navdp_exp_cfg,
     rdp_exp_cfg,
     seq2seq_exp_cfg,
     seq2seq_plus_exp_cfg,
-    navdp_exp_cfg,
 )
-import sys
-from datetime import datetime
+
 
 class TrainCfg(BaseModel):
     """Training configuration class"""
@@ -68,16 +66,16 @@ def on_save(self, args, state, control, **kwargs):
 
 
 def _make_dir(config):
-    config.tensorboard_dir = config.tensorboard_dir % config.name 
+    config.tensorboard_dir = config.tensorboard_dir % config.name
     config.checkpoint_folder = config.checkpoint_folder % config.name
     config.log_dir = config.log_dir % config.name
     config.output_dir = config.output_dir % config.name
     if not os.path.exists(config.tensorboard_dir):
-        os.makedirs(config.tensorboard_dir,exist_ok=True)
+        os.makedirs(config.tensorboard_dir, exist_ok=True)
     if not os.path.exists(config.checkpoint_folder):
-        os.makedirs(config.checkpoint_folder,exist_ok=True)
+        os.makedirs(config.checkpoint_folder, exist_ok=True)
     if not os.path.exists(config.log_dir):
-        os.makedirs(config.log_dir,exist_ok=True)
+        os.makedirs(config.log_dir, exist_ok=True)
 
 
 def main(config, model_class, model_config_class):
@@ -85,12 +83,12 @@ def main(config, model_class, model_config_class):
     """Main training function."""
     _make_dir(config)
 
-    print(f"=== Start training ===")
+    print("=== Start training ===")
     print(f"Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
     print(f"PyTorch version: {torch.__version__}")
     print(f"CUDA available: {torch.cuda.is_available()}")
     print(f"CUDA device count: {torch.cuda.device_count()}")
-    print(f"Environment variables:")
+    print("Environment variables:")
     print(f"  RANK: {os.getenv('RANK', 'Not set')}")
     print(f"  LOCAL_RANK: {os.getenv('LOCAL_RANK', 'Not set')}")
     print(f"  WORLD_SIZE: {os.getenv('WORLD_SIZE', 'Not set')}")
@@ -101,28 +99,23 @@
     local_rank = int(os.getenv('LOCAL_RANK', '0'))
     world_size = int(os.getenv('WORLD_SIZE', '1'))
     rank = int(os.getenv('RANK', '0'))
-    
+
     # Set CUDA device for each process
     device_id = local_rank
     torch.cuda.set_device(device_id)
     device = torch.device(f'cuda:{device_id}')
     print(f"World size: {world_size}, Local rank: {local_rank}, Global rank: {rank}")
-    
+
     # Initialize distributed training environment
     if world_size > 1:
         try:
-            dist.init_process_group(
-                backend='nccl',
-                init_method='env://',
-                world_size=world_size,
-                rank=rank
-            )
+            dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
             print("Distributed initialization SUCCESS")
         except Exception as e:
             print(f"Distributed initialization FAILED: {str(e)}")
             world_size = 1
 
-    print("="*50)
+    print("=" * 50)
     print("After distributed init:")
     print(f"LOCAL_RANK: {local_rank}")
     print(f"WORLD_SIZE: {world_size}")
@@ -150,26 +143,24 @@
         if buffer.device != device:
             print(f"Buffer {name} is on wrong device {buffer.device}, should be moved to {device}")
             buffer.data = buffer.data.to(device)
-    
+
     # If distributed training, wrap the model with DDP
     if world_size > 1:
         model = torch.nn.parallel.DistributedDataParallel(
-            model,
-            device_ids=[local_rank],
-            output_device=local_rank,
-            find_unused_parameters=True
+            model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True
        )
    # ------------ load logger ------------
    train_logger_filename = os.path.join(config.log_dir, 'train.log')
    if dist.is_initialized() and dist.get_rank() == 0:
        train_logger = MyLogger(
-            name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s', filename=train_logger_filename
+            name='train',
+            level=logging.INFO,
+            format_str='%(asctime)-15s %(message)s',
+            filename=train_logger_filename,
        )
    else:
        # Other processes use console logging
-        train_logger = MyLogger(
-            name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s'
-        )
+        train_logger = MyLogger(name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s')
    transformers_logger = logging.getLogger("transformers")
    if transformers_logger.hasHandlers():
        transformers_logger.handlers = []
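Note on the setup above: rank, world size, and device all come from the environment variables a `torchrun`-style launcher exports; the diff only collapses the `init_process_group` call into one line. A condensed, runnable sketch of that handshake, with a single-process fallback:

```python
import os

import torch
import torch.distributed as dist

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
rank = int(os.getenv('RANK', '0'))

if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)  # one GPU per process, as in main()

if world_size > 1:
    # same single-call form the diff settles on
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
```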
@@ -177,19 +168,21 @@
         transformers_logger.addHandler(train_logger.handlers[0])
     transformers_logger.setLevel(logging.INFO)
 
-
     # ------------ load dataset ------------
     if config.model_name == "navdp":
-        train_dataset_data = NavDP_Base_Datset(config.il.root_dir,
-                                               config.il.dataset_navdp,
-                                               config.il.memory_size,
-                                               config.il.predict_size,
-                                               config.il.batch_size,
-                                               config.il.image_size,
-                                               config.il.scene_scale,
-                                               preload = config.il.preload,
-                                               random_digit = config.il.random_digit,
-                                               prior_sample = config.il.prior_sample)
+        train_dataset_data = NavDP_Base_Datset(
+            config.il.root_dir,
+            config.il.dataset_navdp,
+            config.il.memory_size,
+            config.il.predict_size,
+            config.il.batch_size,
+            config.il.image_size,
+            config.il.scene_scale,
+            pixel_channel=config.il.pixel_channel,
+            preload=config.il.preload,
+            random_digit=config.il.random_digit,
+            prior_sample=config.il.prior_sample,
+        )
     else:
         if '3dgs' in config.il.lmdb_features_dir or '3dgs' in config.il.lmdb_features_dir:
             dataset_root_dir = config.il.dataset_six_floor_root_dir
@@ -223,7 +216,7 @@
             config,
             config.il.lerobot_features_dir,
             dataset_data=train_dataset_data,
-            batch_size=config.il.batch_size, 
+            batch_size=config.il.batch_size,
         )
         collate_fn = rdp_collate_fn(global_batch_size=global_batch_size)
     elif config.model_name == 'navdp':
@@ -238,7 +231,7 @@
         remove_unused_columns=False,
         deepspeed='',
         gradient_checkpointing=False,
-        bf16=False,#fp16=False,
+        bf16=False,  # fp16=False,
         tf32=False,
         per_device_train_batch_size=config.il.batch_size,
         gradient_accumulation_steps=1,
@@ -249,7 +242,7 @@
         lr_scheduler_type='cosine',
         logging_steps=10.0,
         num_train_epochs=config.il.epochs,
-        save_strategy='epoch',# no
+        save_strategy='epoch',  # no
         save_steps=config.il.save_interval_epochs,
         save_total_limit=8,
         report_to=config.il.report_to,
@@ -260,7 +253,7 @@
         torch_compile_mode=None,
         dataloader_drop_last=True,
         disable_tqdm=True,
-        log_level="info"
+        log_level="info",
     )
 
     # Create the trainer
@@ -279,14 +272,15 @@
             handler.flush()
     except Exception as e:
         import traceback
+
         print(f"Unhandled exception: {str(e)}")
         print("Stack trace:")
         traceback.print_exc()
-        
+
         # If distributed environment, ensure all processes exit
         if dist.is_initialized():
             dist.destroy_process_group()
-        
+
         raise
 
 
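Note on the end-to-end plumbing: `config.il.pixel_channel` now has to agree at both ends — the channel count of each `pixel_goal` sample the dataset emits, and the patch-embed `in_channels` the backbone is built with. A small sanity-check sketch; the `il` dict literal is illustrative, the real values live in `navdp_exp_cfg`:

```python
# Invariant the new pixel_channel plumbing must keep: dataset output channels
# == backbone patch-embed input channels. The config dict below is a stand-in.
il = {'pixel_channel': 7, 'image_size': 224}

def check_pixel_channels(il_cfg: dict, sample_goal_shape: tuple) -> None:
    """sample_goal_shape: shape of one pixel_goal item, e.g. (224, 224, 7)."""
    assert sample_goal_shape[-1] == il_cfg['pixel_channel'], (
        f"dataset emits {sample_goal_shape[-1]} channels, "
        f"model expects {il_cfg['pixel_channel']}"
    )

check_pixel_channels(il, (il['image_size'], il['image_size'], 7))
```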