From cc5bb8130ae1b6a93ef0cf7aa7bd0a489f70039d Mon Sep 17 00:00:00 2001 From: wzcai99 Date: Thu, 16 Oct 2025 11:32:33 +0000 Subject: [PATCH 1/4] [FEAT] Add support for navdp finetuning --- internnav/dataset/navdp_dataset_lerobot.py | 5 +- .../model/basemodel/navdp/navdp_policy.py | 11 +-- internnav/model/encoder/navdp_backbone.py | 4 +- scripts/train/configs/navdp.py | 9 +- scripts/train/train.py | 95 +++++++++---------- 5 files changed, 62 insertions(+), 62 deletions(-) diff --git a/internnav/dataset/navdp_dataset_lerobot.py b/internnav/dataset/navdp_dataset_lerobot.py index 4437e7c4..f5b3a2ff 100644 --- a/internnav/dataset/navdp_dataset_lerobot.py +++ b/internnav/dataset/navdp_dataset_lerobot.py @@ -41,6 +41,7 @@ def __init__( image_size=224, scene_data_scale=1.0, trajectory_data_scale=1.0, + pixel_channel=7, debug=False, preload=False, random_digit=False, @@ -61,6 +62,7 @@ def __init__( self.trajectory_afford_path = [] self.random_digit = random_digit self.prior_sample = prior_sample + self.pixel_channel = pixel_channel self.item_cnt = 0 self.batch_size = batch_size self.batch_time_sum = 0.0 @@ -509,7 +511,8 @@ def __getitem__(self, index): camera_intrinsic, trajectory_base_extrinsic, ) - pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1) + if self.pixel_channel == 7: + pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1) pred_actions = (pred_actions[1:] - pred_actions[:-1]) * 4.0 augment_actions = (augment_actions[1:] - augment_actions[:-1]) * 4.0 diff --git a/internnav/model/basemodel/navdp/navdp_policy.py b/internnav/model/basemodel/navdp/navdp_policy.py index 6a8da420..784044d7 100644 --- a/internnav/model/basemodel/navdp/navdp_policy.py +++ b/internnav/model/basemodel/navdp/navdp_policy.py @@ -51,9 +51,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): elif pretrained_model_name_or_path is None or len(pretrained_model_name_or_path) == 0: pass else: - incompatible_keys, _ = model.load_state_dict( - torch.load(pretrained_model_name_or_path)['state_dict'], strict=False - ) + incompatible_keys, _ = model.load_state_dict(torch.load(pretrained_model_name_or_path), strict=False) if len(incompatible_keys) > 0: print(f'Incompatible keys: {incompatible_keys}') @@ -66,13 +64,12 @@ def __init__(self, config: NavDPModelConfig): self.model_config = ModelCfg(**config.model_cfg['model']) else: self.model_config = config - self.config.model_cfg['il'] - self._device = torch.device(f"cuda:{config.model_cfg['local_rank']}") self.image_size = self.config.model_cfg['il']['image_size'] self.memory_size = self.config.model_cfg['il']['memory_size'] self.predict_size = self.config.model_cfg['il']['predict_size'] + self.pixel_channel = self.config.model_cfg['il']['pixel_channel'] self.temporal_depth = self.config.model_cfg['il']['temporal_depth'] self.attention_heads = self.config.model_cfg['il']['heads'] self.input_channels = self.config.model_cfg['il']['channels'] @@ -83,7 +80,9 @@ def __init__(self, config: NavDPModelConfig): self.rgbd_encoder = NavDP_RGBD_Backbone( self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device ) - self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size, self.token_dim, device=self._device) + self.pixel_encoder = NavDP_PixelGoal_Backbone( + self.image_size, self.token_dim, pixel_channel=self.pixel_channel, device=self._device + ) self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size, self.token_dim, device=self._device) self.point_encoder = nn.Linear(3, self.token_dim) diff --git a/internnav/model/encoder/navdp_backbone.py b/internnav/model/encoder/navdp_backbone.py index cd2e8794..680c67b9 100644 --- a/internnav/model/encoder/navdp_backbone.py +++ b/internnav/model/encoder/navdp_backbone.py @@ -377,7 +377,7 @@ def _get_device(self): class NavDP_PixelGoal_Backbone(nn.Module): - def __init__(self, image_size=224, embed_size=512, device='cuda:0'): + def __init__(self, image_size=224, embed_size=512, pixel_channel=7, device='cuda:0'): super().__init__() if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -392,7 +392,7 @@ def __init__(self, image_size=224, embed_size=512, device='cuda:0'): self.pixelgoal_encoder = DepthAnythingV2(**model_configs['vits']) self.pixelgoal_encoder = self.pixelgoal_encoder.pretrained.float() self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d( - in_channels=7, + in_channels=pixel_channel, out_channels=self.pixelgoal_encoder.patch_embed.proj.out_channels, kernel_size=self.pixelgoal_encoder.patch_embed.proj.kernel_size, stride=self.pixelgoal_encoder.patch_embed.proj.stride, diff --git a/scripts/train/configs/navdp.py b/scripts/train/configs/navdp.py index 4085f8a8..7858c175 100644 --- a/scripts/train/configs/navdp.py +++ b/scripts/train/configs/navdp.py @@ -38,8 +38,8 @@ inflection_weight_coef=3.2, save_interval_epochs=5, save_filter_frozen_weights=False, - load_from_ckpt=False, - ckpt_to_load='', + load_from_ckpt=True, + ckpt_to_load='/shared/smartbot_new/caiwenzhe/InternNav/checkpoints/cross-waic-final4-125.ckpt', lmdb_map_size=1e12, dataset_r2r_root_dir='data/vln_pe/raw_data/r2r', dataset_3dgs_root_dir='', @@ -48,9 +48,10 @@ lerobot_features_dir='data/vln_pe/traj_data/r2r', camera_name='pano_camera_0', report_to='tensorboard', # wandb, tensorboard, none - dataset_navdp='data/datasets/navdp_dataset_lerobot.json', - root_dir='data/datasets/InternData-N1/vln_n1/traj_data', + dataset_navdp='./navdp_dataset_lerobot.json', + root_dir='/shared/smartbot_new/liuyu/vln-n1-minival/', image_size=224, + pixel_channel=4, scene_scale=1.0, preload=False, random_digit=False, diff --git a/scripts/train/train.py b/scripts/train/train.py index 060af53f..538a9c2d 100755 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -3,40 +3,41 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) import logging +import sys +from datetime import datetime from pathlib import Path -import torch.distributed as dist import torch +import torch.distributed as dist import tyro from pydantic import BaseModel from transformers import TrainerCallback, TrainingArguments from internnav.dataset.cma_lerobot_dataset import CMALerobotDataset, cma_collate_fn -from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn from internnav.dataset.navdp_dataset_lerobot import NavDP_Base_Datset, navdp_collate_fn +from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn from internnav.model import ( CMAModelConfig, CMANet, + NavDPModelConfig, + NavDPNet, RDPModelConfig, RDPNet, Seq2SeqModelConfig, Seq2SeqNet, - NavDPNet, - NavDPModelConfig, ) from internnav.model.utils.logger import MyLogger from internnav.model.utils.utils import load_dataset -from internnav.trainer import CMATrainer, RDPTrainer, NavDPTrainer +from internnav.trainer import CMATrainer, NavDPTrainer, RDPTrainer from scripts.train.configs import ( cma_exp_cfg, cma_plus_exp_cfg, + navdp_exp_cfg, rdp_exp_cfg, seq2seq_exp_cfg, seq2seq_plus_exp_cfg, - navdp_exp_cfg, ) -import sys -from datetime import datetime + class TrainCfg(BaseModel): """Training configuration class""" @@ -68,16 +69,16 @@ def on_save(self, args, state, control, **kwargs): def _make_dir(config): - config.tensorboard_dir = config.tensorboard_dir % config.name + config.tensorboard_dir = config.tensorboard_dir % config.name config.checkpoint_folder = config.checkpoint_folder % config.name config.log_dir = config.log_dir % config.name config.output_dir = config.output_dir % config.name if not os.path.exists(config.tensorboard_dir): - os.makedirs(config.tensorboard_dir,exist_ok=True) + os.makedirs(config.tensorboard_dir, exist_ok=True) if not os.path.exists(config.checkpoint_folder): - os.makedirs(config.checkpoint_folder,exist_ok=True) + os.makedirs(config.checkpoint_folder, exist_ok=True) if not os.path.exists(config.log_dir): - os.makedirs(config.log_dir,exist_ok=True) + os.makedirs(config.log_dir, exist_ok=True) def main(config, model_class, model_config_class): @@ -85,12 +86,12 @@ def main(config, model_class, model_config_class): """Main training function.""" _make_dir(config) - print(f"=== Start training ===") + print("=== Start training ===") print(f"Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") print(f"PyTorch version: {torch.__version__}") print(f"CUDA available: {torch.cuda.is_available()}") print(f"CUDA device count: {torch.cuda.device_count()}") - print(f"Environment variables:") + print("Environment variables:") print(f" RANK: {os.getenv('RANK', 'Not set')}") print(f" LOCAL_RANK: {os.getenv('LOCAL_RANK', 'Not set')}") print(f" WORLD_SIZE: {os.getenv('WORLD_SIZE', 'Not set')}") @@ -101,28 +102,23 @@ def main(config, model_class, model_config_class): local_rank = int(os.getenv('LOCAL_RANK', '0')) world_size = int(os.getenv('WORLD_SIZE', '1')) rank = int(os.getenv('RANK', '0')) - + # Set CUDA device for each process device_id = local_rank torch.cuda.set_device(device_id) device = torch.device(f'cuda:{device_id}') print(f"World size: {world_size}, Local rank: {local_rank}, Global rank: {rank}") - + # Initialize distributed training environment if world_size > 1: try: - dist.init_process_group( - backend='nccl', - init_method='env://', - world_size=world_size, - rank=rank - ) + dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) print("Distributed initialization SUCCESS") except Exception as e: print(f"Distributed initialization FAILED: {str(e)}") world_size = 1 - print("="*50) + print("=" * 50) print("After distributed init:") print(f"LOCAL_RANK: {local_rank}") print(f"WORLD_SIZE: {world_size}") @@ -150,26 +146,24 @@ def main(config, model_class, model_config_class): if buffer.device != device: print(f"Buffer {name} is on wrong device {buffer.device}, should be moved to {device}") buffer.data = buffer.data.to(device) - + # If distributed training, wrap the model with DDP if world_size > 1: model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[local_rank], - output_device=local_rank, - find_unused_parameters=True + model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True ) # ------------ load logger ------------ train_logger_filename = os.path.join(config.log_dir, 'train.log') if dist.is_initialized() and dist.get_rank() == 0: train_logger = MyLogger( - name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s', filename=train_logger_filename + name='train', + level=logging.INFO, + format_str='%(asctime)-15s %(message)s', + filename=train_logger_filename, ) else: # Other processes use console logging - train_logger = MyLogger( - name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s' - ) + train_logger = MyLogger(name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s') transformers_logger = logging.getLogger("transformers") if transformers_logger.hasHandlers(): transformers_logger.handlers = [] @@ -177,19 +171,21 @@ def main(config, model_class, model_config_class): transformers_logger.addHandler(train_logger.handlers[0]) transformers_logger.setLevel(logging.INFO) - # ------------ load dataset ------------ if config.model_name == "navdp": - train_dataset_data = NavDP_Base_Datset(config.il.root_dir, - config.il.dataset_navdp, - config.il.memory_size, - config.il.predict_size, - config.il.batch_size, - config.il.image_size, - config.il.scene_scale, - preload = config.il.preload, - random_digit = config.il.random_digit, - prior_sample = config.il.prior_sample) + train_dataset_data = NavDP_Base_Datset( + config.il.root_dir, + config.il.dataset_navdp, + config.il.memory_size, + config.il.predict_size, + config.il.batch_size, + config.il.image_size, + config.il.scene_scale, + pixel_channel=config.il.pixel_channel, + preload=config.il.preload, + random_digit=config.il.random_digit, + prior_sample=config.il.prior_sample, + ) else: if '3dgs' in config.il.lmdb_features_dir or '3dgs' in config.il.lmdb_features_dir: dataset_root_dir = config.il.dataset_six_floor_root_dir @@ -223,7 +219,7 @@ def main(config, model_class, model_config_class): config, config.il.lerobot_features_dir, dataset_data=train_dataset_data, - batch_size=config.il.batch_size, + batch_size=config.il.batch_size, ) collate_fn = rdp_collate_fn(global_batch_size=global_batch_size) elif config.model_name == 'navdp': @@ -238,7 +234,7 @@ def main(config, model_class, model_config_class): remove_unused_columns=False, deepspeed='', gradient_checkpointing=False, - bf16=False,#fp16=False, + bf16=False, # fp16=False, tf32=False, per_device_train_batch_size=config.il.batch_size, gradient_accumulation_steps=1, @@ -249,7 +245,7 @@ def main(config, model_class, model_config_class): lr_scheduler_type='cosine', logging_steps=10.0, num_train_epochs=config.il.epochs, - save_strategy='epoch',# no + save_strategy='epoch', # no save_steps=config.il.save_interval_epochs, save_total_limit=8, report_to=config.il.report_to, @@ -260,7 +256,7 @@ def main(config, model_class, model_config_class): torch_compile_mode=None, dataloader_drop_last=True, disable_tqdm=True, - log_level="info" + log_level="info", ) # Create the trainer @@ -279,14 +275,15 @@ def main(config, model_class, model_config_class): handler.flush() except Exception as e: import traceback + print(f"Unhandled exception: {str(e)}") print("Stack trace:") traceback.print_exc() - + # If distributed environment, ensure all processes exit if dist.is_initialized(): dist.destroy_process_group() - + raise From ce6d62d23409b3d63e3ea46411be54a1ee9f9b2b Mon Sep 17 00:00:00 2001 From: wzcai99 Date: Wed, 29 Oct 2025 10:47:42 +0000 Subject: [PATCH 2/4] [FIX] NavDP Training Gradient --- internnav/model/basemodel/navdp/navdp_policy.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/internnav/model/basemodel/navdp/navdp_policy.py b/internnav/model/basemodel/navdp/navdp_policy.py index 784044d7..1d16017e 100644 --- a/internnav/model/basemodel/navdp/navdp_policy.py +++ b/internnav/model/basemodel/navdp/navdp_policy.py @@ -87,9 +87,9 @@ def __init__(self, config: NavDPModelConfig): self.point_encoder = nn.Linear(3, self.token_dim) if not self.finetune: - for p in self.rgbd_encoder.parameters(): + for p in self.rgbd_encoder.rgb_model.parameters(): p.requires_grad = False - self.rgbd_encoder.eval() + self.rgbd_encoder.rgb_model.eval() decoder_layer = nn.TransformerDecoderLayer( d_model=self.token_dim, @@ -348,3 +348,7 @@ def predict_nogoal_batch_action_vel(self, input_images, input_depths, sample_num negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]] positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]] return negative_trajectory, positive_trajectory + + +# if __name__ == "__main__": +# policy = NavDPNet(config=) \ No newline at end of file From 8e66e4f5274ab90c6a9b22cf14a2b2a703f3b96c Mon Sep 17 00:00:00 2001 From: wzcai99 Date: Thu, 30 Oct 2025 05:51:41 +0000 Subject: [PATCH 3/4] [FIX] Support NavDP Finetune --- internnav/dataset/navdp_dataset_lerobot.py | 69 ++++++++++++---------- scripts/train/train.py | 3 +- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/internnav/dataset/navdp_dataset_lerobot.py b/internnav/dataset/navdp_dataset_lerobot.py index f5b3a2ff..c721d49d 100644 --- a/internnav/dataset/navdp_dataset_lerobot.py +++ b/internnav/dataset/navdp_dataset_lerobot.py @@ -70,6 +70,8 @@ def __init__( if preload is False: for group_dir in self.dataset_dirs: # gibson_zed, 3dfront ... + if 'fixed' in group_dir: + continue all_scene_dirs = np.array([p for p in os.listdir(os.path.join(root_dirs, group_dir))]) select_scene_dirs = all_scene_dirs[ np.arange(0, all_scene_dirs.shape[0], 1 / self.scene_scale_size).astype(np.int32) @@ -81,6 +83,9 @@ def __init__( ] for traj_dir in tqdm(select_traj_dirs): entire_task_dir = os.path.join(root_dirs, group_dir, scene_dir, traj_dir) + video_dir = os.path.join(entire_task_dir, "videos/") + if not os.path.exists(video_dir): + continue rgb_dir = os.path.join(entire_task_dir, "videos/chunk-000/observation.images.rgb/") depth_dir = os.path.join(entire_task_dir, "videos/chunk-000/observation.images.depth/") data_path = os.path.join( @@ -116,11 +121,11 @@ def __init__( json.dump(save_dict, f, indent=4) else: load_dict = json.load(open(preload_path, 'r')) - self.trajectory_dirs = load_dict['trajectory_dirs'] * 50 - self.trajectory_data_dir = load_dict['trajectory_data_dir'] * 50 - self.trajectory_rgb_path = load_dict['trajectory_rgb_path'] * 50 - self.trajectory_depth_path = load_dict['trajectory_depth_path'] * 50 - self.trajectory_afford_path = load_dict['trajectory_afford_path'] * 50 + self.trajectory_dirs = load_dict['trajectory_dirs'] #* 50 + self.trajectory_data_dir = load_dict['trajectory_data_dir'] #* 50 + self.trajectory_rgb_path = load_dict['trajectory_rgb_path'] #* 50 + self.trajectory_depth_path = load_dict['trajectory_depth_path']#* 50 + self.trajectory_afford_path = load_dict['trajectory_afford_path'] #* 50 def __len__(self): return len(self.trajectory_dirs) @@ -192,23 +197,26 @@ def process_path_points(self, index): return np.array(trajectory_path.points), trajectory_path def process_obstacle_points(self, index, path_points): - trajectory_pcd = self.load_pointcloud(self.trajectory_afford_path[index]) - trajectory_color = np.array(trajectory_pcd.colors) - trajectory_points = np.array(trajectory_pcd.points) - color_distance = np.abs(trajectory_color - np.array([0, 0, 0.5])).sum(axis=-1) # the obstacles are save in blue - path_lower_bound = path_points.min(axis=0) - path_upper_bound = path_points.max(axis=0) - condition_x = (trajectory_points[:, 0] >= path_lower_bound[0] - 2.0) & ( - trajectory_points[:, 0] <= path_upper_bound[0] + 2.0 - ) - condition_y = (trajectory_points[:, 1] >= path_lower_bound[1] - 2.0) & ( - trajectory_points[:, 1] <= path_upper_bound[1] + 2.0 - ) - select_index = np.where((color_distance < 0.05) & condition_x & condition_y)[0] - trajectory_obstacle = o3d.geometry.PointCloud() - trajectory_obstacle.points = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.points)[select_index]) - trajectory_obstacle.colors = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.colors)[select_index]) - return np.array(trajectory_obstacle.points), trajectory_obstacle + try: + trajectory_pcd = self.load_pointcloud(self.trajectory_afford_path[index]) + trajectory_color = np.array(trajectory_pcd.colors) + trajectory_points = np.array(trajectory_pcd.points) + color_distance = np.abs(trajectory_color - np.array([0, 0, 0.5])).sum(axis=-1) # the obstacles are save in blue + path_lower_bound = path_points.min(axis=0) + path_upper_bound = path_points.max(axis=0) + condition_x = (trajectory_points[:, 0] >= path_lower_bound[0] - 2.0) & ( + trajectory_points[:, 0] <= path_upper_bound[0] + 2.0 + ) + condition_y = (trajectory_points[:, 1] >= path_lower_bound[1] - 2.0) & ( + trajectory_points[:, 1] <= path_upper_bound[1] + 2.0 + ) + select_index = np.where((color_distance < 0.05) & condition_x & condition_y)[0] + trajectory_obstacle = o3d.geometry.PointCloud() + trajectory_obstacle.points = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.points)[select_index]) + trajectory_obstacle.colors = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.colors)[select_index]) + return np.array(trajectory_obstacle.points), trajectory_obstacle + except: + return np.zeros((0,3)), None def process_memory(self, rgb_paths, depth_paths, start_step, memory_digit=1): memory_index = np.arange(start_step - (self.memory_size - 1) * memory_digit, start_step + 1, memory_digit) @@ -548,7 +556,8 @@ def __getitem__(self, index): augment_critic, float(pixel_flag), ) - + + def navdp_collate_fn(batch): @@ -569,17 +578,17 @@ def navdp_collate_fn(batch): if __name__ == "__main__": os.makedirs("./navdp_dataset_test/", exist_ok=True) dataset = NavDP_Base_Datset( - "/shared/smartbot_new/liuyu/vln-n1-minival/", - "./navdp_dataset_test/dataset_lerobot.json", + "/shared/smartbot_new/liuyu/vln-n1", + "/shared/smartbot_new/caiwenzhe/InternNav/internnav/dataset/navdp_dataset_test/dataset_lerobot_balanced.json", 8, 24, 224, - trajectory_data_scale=0.1, - scene_data_scale=0.1, - preload=False, + trajectory_data_scale=1.0, + scene_data_scale=1.0, + preload=True, ) - - for i in range(10): + from tqdm import tqdm + for i in tqdm(range(dataset.__len__())): ( point_goal, image_goal, diff --git a/scripts/train/train.py b/scripts/train/train.py index 538a9c2d..321164df 100755 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -15,7 +15,7 @@ from internnav.dataset.cma_lerobot_dataset import CMALerobotDataset, cma_collate_fn from internnav.dataset.navdp_dataset_lerobot import NavDP_Base_Datset, navdp_collate_fn -from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn +# from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn from internnav.model import ( CMAModelConfig, CMANet, @@ -258,7 +258,6 @@ def main(config, model_class, model_config_class): disable_tqdm=True, log_level="info", ) - # Create the trainer trainer = policy_trainer( config=config, model=model, args=training_args, train_dataset=train_dataset, data_collator=collate_fn From e2ed6617f580cd165c964b6342a117792c66fd7a Mon Sep 17 00:00:00 2001 From: wzcai99 Date: Thu, 30 Oct 2025 07:36:26 +0000 Subject: [PATCH 4/4] [FIX] support NavDP finetune --- scripts/train/train.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index 321164df..411d8406 100755 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -15,17 +15,12 @@ from internnav.dataset.cma_lerobot_dataset import CMALerobotDataset, cma_collate_fn from internnav.dataset.navdp_dataset_lerobot import NavDP_Base_Datset, navdp_collate_fn -# from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn -from internnav.model import ( - CMAModelConfig, - CMANet, - NavDPModelConfig, - NavDPNet, - RDPModelConfig, - RDPNet, - Seq2SeqModelConfig, - Seq2SeqNet, -) +from internnav.dataset.rdp_lerobot_dataset import RDP_LerobotDataset, rdp_collate_fn + +from internnav.model.basemodel.seq2seq.seq2seq_policy import Seq2SeqModelConfig, Seq2SeqNet +from internnav.model.basemodel.cma.cma_policy import CMAModelConfig, CMANet +from internnav.model.basemodel.rdp.rdp_policy import RDPModelConfig, RDPNet +from internnav.model.basemodel.navdp.navdp_policy import NavDPModelConfig, NavDPNet from internnav.model.utils.logger import MyLogger from internnav.model.utils.utils import load_dataset from internnav.trainer import CMATrainer, NavDPTrainer, RDPTrainer