From a4b8157d556a6e8222a5bb441a41df6754602ca0 Mon Sep 17 00:00:00 2001 From: wzcai99 Date: Thu, 16 Oct 2025 08:33:52 +0000 Subject: [PATCH] [FEAT] Add support for the navdp fine-tuning --- internnav/dataset/navdp_dataset_lerobot.py | 5 ++++- internnav/model/basemodel/navdp/navdp_policy.py | 7 +++---- internnav/model/encoder/navdp_backbone.py | 4 ++-- scripts/train/configs/navdp.py | 1 + scripts/train/train.py | 1 + 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/internnav/dataset/navdp_dataset_lerobot.py b/internnav/dataset/navdp_dataset_lerobot.py index 4437e7c4..f5b3a2ff 100644 --- a/internnav/dataset/navdp_dataset_lerobot.py +++ b/internnav/dataset/navdp_dataset_lerobot.py @@ -41,6 +41,7 @@ def __init__( image_size=224, scene_data_scale=1.0, trajectory_data_scale=1.0, + pixel_channel=7, debug=False, preload=False, random_digit=False, @@ -61,6 +62,7 @@ def __init__( self.trajectory_afford_path = [] self.random_digit = random_digit self.prior_sample = prior_sample + self.pixel_channel = pixel_channel self.item_cnt = 0 self.batch_size = batch_size self.batch_time_sum = 0.0 @@ -509,7 +511,8 @@ def __getitem__(self, index): camera_intrinsic, trajectory_base_extrinsic, ) - pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1) + if self.pixel_channel == 7: + pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1) pred_actions = (pred_actions[1:] - pred_actions[:-1]) * 4.0 augment_actions = (augment_actions[1:] - augment_actions[:-1]) * 4.0 diff --git a/internnav/model/basemodel/navdp/navdp_policy.py b/internnav/model/basemodel/navdp/navdp_policy.py index 6a8da420..0eb2cfd0 100644 --- a/internnav/model/basemodel/navdp/navdp_policy.py +++ b/internnav/model/basemodel/navdp/navdp_policy.py @@ -52,7 +52,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): pass else: incompatible_keys, _ = model.load_state_dict( - torch.load(pretrained_model_name_or_path)['state_dict'], strict=False + torch.load(pretrained_model_name_or_path), strict=False ) if len(incompatible_keys) > 0: print(f'Incompatible keys: {incompatible_keys}') @@ -66,13 +66,12 @@ def __init__(self, config: NavDPModelConfig): self.model_config = ModelCfg(**config.model_cfg['model']) else: self.model_config = config - self.config.model_cfg['il'] - self._device = torch.device(f"cuda:{config.model_cfg['local_rank']}") self.image_size = self.config.model_cfg['il']['image_size'] self.memory_size = self.config.model_cfg['il']['memory_size'] self.predict_size = self.config.model_cfg['il']['predict_size'] + self.pixel_channel = self.config.model_cfg['il']['pixel_channel'] self.temporal_depth = self.config.model_cfg['il']['temporal_depth'] self.attention_heads = self.config.model_cfg['il']['heads'] self.input_channels = self.config.model_cfg['il']['channels'] @@ -83,7 +82,7 @@ def __init__(self, config: NavDPModelConfig): self.rgbd_encoder = NavDP_RGBD_Backbone( self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device ) - self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size, self.token_dim, device=self._device) + self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size, self.token_dim, pixel_channel=self.pixel_channel, device=self._device) self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size, self.token_dim, device=self._device) self.point_encoder = nn.Linear(3, self.token_dim) diff --git a/internnav/model/encoder/navdp_backbone.py b/internnav/model/encoder/navdp_backbone.py index cd2e8794..680c67b9 100644 --- a/internnav/model/encoder/navdp_backbone.py +++ b/internnav/model/encoder/navdp_backbone.py @@ -377,7 +377,7 @@ def _get_device(self): class NavDP_PixelGoal_Backbone(nn.Module): - def __init__(self, image_size=224, embed_size=512, device='cuda:0'): + def __init__(self, image_size=224, embed_size=512, pixel_channel=7, device='cuda:0'): super().__init__() if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -392,7 +392,7 @@ def __init__(self, image_size=224, embed_size=512, device='cuda:0'): self.pixelgoal_encoder = DepthAnythingV2(**model_configs['vits']) self.pixelgoal_encoder = self.pixelgoal_encoder.pretrained.float() self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d( - in_channels=7, + in_channels=pixel_channel, out_channels=self.pixelgoal_encoder.patch_embed.proj.out_channels, kernel_size=self.pixelgoal_encoder.patch_embed.proj.kernel_size, stride=self.pixelgoal_encoder.patch_embed.proj.stride, diff --git a/scripts/train/configs/navdp.py b/scripts/train/configs/navdp.py index 4085f8a8..7d7924fb 100644 --- a/scripts/train/configs/navdp.py +++ b/scripts/train/configs/navdp.py @@ -51,6 +51,7 @@ dataset_navdp='data/datasets/navdp_dataset_lerobot.json', root_dir='data/datasets/InternData-N1/vln_n1/traj_data', image_size=224, + pixel_channel=7, scene_scale=1.0, preload=False, random_digit=False, diff --git a/scripts/train/train.py b/scripts/train/train.py index 060af53f..6182ca42 100755 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -187,6 +187,7 @@ def main(config, model_class, model_config_class): config.il.batch_size, config.il.image_size, config.il.scene_scale, + pixel_channel=config.il.pixel_channel, preload = config.il.preload, random_digit = config.il.random_digit, prior_sample = config.il.prior_sample)