From 4753396738417f82aad54aa46c2104b803daefbe Mon Sep 17 00:00:00 2001
From: wzcai99
Date: Mon, 13 Oct 2025 09:27:19 +0000
Subject: [PATCH 1/2] [Fix] update the minimal test for navdp dataset

---
 internnav/dataset/navdp_dataset_lerobot.py | 45 +++++++++++-----------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/internnav/dataset/navdp_dataset_lerobot.py b/internnav/dataset/navdp_dataset_lerobot.py
index 3e6ee7ca..1b0dbbfd 100644
--- a/internnav/dataset/navdp_dataset_lerobot.py
+++ b/internnav/dataset/navdp_dataset_lerobot.py
@@ -431,25 +431,26 @@ def navdp_collate_fn(batch):
 
 if __name__ == "__main__":
-    # Debug
-    dataset = NavDP_Base_Datset("/path/to/nav_20w_lerobot/",
-                                "/path/to/navdp_trainer/output_test/multiview_dataset_lerobot.json",
-                                8,24,224,trajectory_data_scale=1.0,scene_data_scale=1.0,preload=True)
-    for i in range(200):
-        point_goal,image_goal,pixel_goal,memory_images,depth_image,pred_actions,augment_actions,pred_critic,augment_critic,pixel_flag = dataset.__getitem__(i)
-        pixel_obs = pixel_goal[:,:,0:3] * 255
-        pixel_obs[pixel_goal[:,:,3]==1] = np.array([0,0,255])
-
-        draw_current_image = image_goal[:,:,3:6].copy()*255
-        draw_current_image = cv2.putText(draw_current_image,"Current-Image",(50,30),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255))
-
-        draw_goal_image = image_goal[:,:,0:3].copy()*255
-        draw_goal_image = cv2.putText(draw_goal_image,"Image-Goal",(50,30),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255))
-
-        draw_pixel_image = pixel_obs.copy()
-        draw_pixel_image = cv2.putText(draw_pixel_image,"Pixel-Goal",(50,30),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255))
-
-        goal_info_image = np.concatenate((draw_current_image,draw_goal_image,draw_pixel_image),axis=1)
-        goal_info_image = cv2.putText(goal_info_image,"PointGoal=[{:.3f}, {:.3f}, {:.3f}]".format(*point_goal),(190,210),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255))
-        cv2.imwrite("./output_test/goal_information.png",goal_info_image)
-
\ No newline at end of file
+    os.makedirs("./navdp_dataset_test/", exist_ok=True)
+    dataset = NavDP_Base_Datset("/shared/smartbot_new/liuyu/vln-n1-minival/",
+                                "./navdp_dataset_test/dataset_lerobot.json",
+                                8,24,224,trajectory_data_scale=0.1,scene_data_scale=0.1,preload=False)
+
+    for i in range(10):
+        point_goal,image_goal,pixel_goal,memory_images,depth_image,pred_actions,augment_actions,pred_critic,augment_critic,pixel_flag = dataset.__getitem__(i)
+        if pixel_flag == 1.0:
+            pixel_obs = pixel_goal.numpy()[:,:,0:3] * 255
+            pixel_obs[pixel_goal[:,:,3]==1] = np.array([0,0,255])
+
+            draw_current_image = cv2.cvtColor(image_goal[:,:,3:6].numpy()*255,cv2.COLOR_BGR2RGB)
+            draw_current_image = cv2.putText(draw_current_image,"Current-Image",(50,30),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255))
+
+            draw_goal_image = cv2.cvtColor(image_goal[:,:,0:3].numpy()*255,cv2.COLOR_BGR2RGB)
+            draw_goal_image = cv2.putText(draw_goal_image,"Image-Goal",(50,30),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255))
+
+            draw_pixel_image = cv2.cvtColor(pixel_obs.copy(),cv2.COLOR_BGR2RGB)
+            draw_pixel_image = cv2.putText(draw_pixel_image,"Pixel-Goal",(50,30),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255))
+
+            goal_info_image = np.concatenate((draw_current_image,draw_goal_image,draw_pixel_image),axis=1)
+            goal_info_image = cv2.putText(goal_info_image,"PointGoal=[{:.3f}, {:.3f}, {:.3f}]".format(*point_goal),(190,210),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255))
+            cv2.imwrite("./navdp_dataset_test/goal_information_%d.png"%i,goal_info_image)
\ No newline at end of file
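Note: the revised smoke test only renders samples with a valid pixel goal (`pixel_flag == 1.0`) and converts RGB tensors to BGR before OpenCV drawing. A minimal standalone sketch of the same overlay logic, assuming `pixel_goal` is an HxWx4 float tensor with an RGB crop in channels 0-2 and a binary mask in channel 3 (the helper name is illustrative, not from the repo):

```python
import cv2
import numpy as np
import torch

def render_pixel_goal(pixel_goal: torch.Tensor) -> np.ndarray:
    """Overlay the channel-3 goal mask in red on the RGB crop in channels 0-2."""
    arr = pixel_goal.numpy()
    bgr = cv2.cvtColor((arr[:, :, 0:3] * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR)
    bgr[arr[:, :, 3] == 1] = (0, 0, 255)  # red in BGR order
    return cv2.putText(bgr, "Pixel-Goal", (50, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255))

demo = torch.zeros(224, 224, 4)
demo[100:124, 100:124, 3] = 1.0  # fake square mask
cv2.imwrite("pixel_goal_demo.png", render_pixel_goal(demo))
```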
From 031e3d8058fd566c1b5ed3bd07e927a0bf54f899 Mon Sep 17 00:00:00 2001
From: wzcai99
Date: Mon, 13 Oct 2025 09:40:41 +0000
Subject: [PATCH 2/2] [FIX] update the code for multi-gpu training and checkpoint save

---
 .../model/basemodel/navdp/navdp_policy.py     | 84 +++++++++++--------
 internnav/model/encoder/navdp_backbone.py     | 81 +++++++++++++++++-
 internnav/trainer/navdp_trainer.py            | 43 +++++++---
 3 files changed, 157 insertions(+), 51 deletions(-)

diff --git a/internnav/model/basemodel/navdp/navdp_policy.py b/internnav/model/basemodel/navdp/navdp_policy.py
index 2d290869..699376d9 100644
--- a/internnav/model/basemodel/navdp/navdp_policy.py
+++ b/internnav/model/basemodel/navdp/navdp_policy.py
@@ -7,6 +7,7 @@
 import torch.nn.functional as F
 import numpy as np
 import os
+import random
 from scipy.signal import savgol_filter
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from internnav.model.encoder.navdp_backbone import *
@@ -15,17 +16,13 @@
 from internnav.configs.model.base_encoders import ModelCfg
 from internnav.configs.trainer.exp import ExpCfg
 
-
-
 class NavDPModelConfig(PretrainedConfig):
     model_type = 'navdp'
-
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         # pass in navdp_exp_cfg
         self.model_cfg = kwargs.get('model_cfg', None)
-
     @classmethod
     def from_dict(cls, config_dict):
         if 'model_cfg' in config_dict:
@@ -89,7 +86,15 @@ def __init__(self, config: NavDPModelConfig):
         self.scratch=self.config.model_cfg['il']['scratch']
         self.finetune=self.config.model_cfg['il']['finetune']
         self.rgbd_encoder = NavDP_RGBD_Backbone(self.image_size,self.token_dim,memory_size=self.memory_size,finetune=self.finetune,device=self._device)
+        self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size,self.token_dim,device=self._device)
+        self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size,self.token_dim,device=self._device)
         self.point_encoder = nn.Linear(3,self.token_dim)
+
+        if not self.finetune:
+            for p in self.rgbd_encoder.parameters():
+                p.requires_grad = False
+            self.rgbd_encoder.eval()
+
         decoder_layer = nn.TransformerDecoderLayer(d_model = self.token_dim,
                                                    nhead = self.attention_heads,
                                                    dim_feedforward = 4 * self.token_dim,
@@ -101,7 +106,8 @@
                                              num_layers = self.temporal_depth)
         self.input_embed = nn.Linear(3,self.token_dim)
-        self.cond_pos_embed = LearnablePositionalEncoding(self.token_dim, self.memory_size * 16 + 2)
+
+        self.cond_pos_embed = LearnablePositionalEncoding(self.token_dim, self.memory_size * 16 + 4)
         self.out_pos_embed = LearnablePositionalEncoding(self.token_dim, self.predict_size)
         self.drop = nn.Dropout(self.dropout)
         self.time_emb = SinusoidalPosEmb(self.token_dim)
 
@@ -114,9 +120,13 @@
                                              prediction_type='epsilon')
         self.tgt_mask = (torch.triu(torch.ones(self.predict_size, self.predict_size)) == 1).transpose(0, 1)
         self.tgt_mask = self.tgt_mask.float().masked_fill(self.tgt_mask == 0, float('-inf')).masked_fill(self.tgt_mask == 1, float(0.0))
-        self.cond_critic_mask = torch.zeros((self.predict_size,2 + self.memory_size * 16))
-        self.cond_critic_mask[:,0:2] = float('-inf')
         self.tgt_mask = self.tgt_mask.to(self._device)
+
+        self.cond_critic_mask = torch.zeros((self.predict_size,4 + self.memory_size * 16))
+        self.cond_critic_mask[:,0:4] = float('-inf')
+
+        self.pixel_aux_head = nn.Linear(self.token_dim,3)
+        self.image_aux_head = nn.Linear(self.token_dim,3)
 
     def to(self, device, *args, **kwargs):
         # first call the to method of the parent class
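Note: after this change the conditioning sequence is [1 time token | 3 goal tokens | memory_size x 16 RGB-D tokens], which is why the positional table and the critic mask both grow from `memory_size * 16 + 2` to `memory_size * 16 + 4`. A small sketch of that bookkeeping (`LearnablePositionalEncoding` here is a stand-in, not the repo's implementation; sizes are illustrative):

```python
import torch
import torch.nn as nn

class LearnablePositionalEncoding(nn.Module):
    """Stand-in: one learned vector per conditioning-token position."""
    def __init__(self, dim, length):
        super().__init__()
        self.pos = nn.Parameter(torch.zeros(1, length, dim))
    def forward(self, x):
        return self.pos[:, :x.shape[1]]

memory_size, token_dim, predict_size = 8, 384, 24
n_cond = 1 + 3 + memory_size * 16        # time + three goal slots + RGB-D memory
cond_pos_embed = LearnablePositionalEncoding(token_dim, n_cond)

# The critic must ignore time/goal conditioning: -inf over the first 4 columns.
cond_critic_mask = torch.zeros(predict_size, n_cond)
cond_critic_mask[:, 0:4] = float('-inf')

cond = torch.randn(2, n_cond, token_dim)
print((cond + cond_pos_embed(cond)).shape)  # torch.Size([2, 132, 384])
```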
@@ -131,10 +141,6 @@
         return self
 
     def sample_noise(self,action):
-        # device = next(self.parameters()).device
-        # if device is None:
-        #     device = action.device
-        # action = action.to(self._device)
         device = action.device
         noise = torch.randn(action.shape, device=device)
         timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps,(action.shape[0],), device=device).long()
@@ -146,7 +152,7 @@
     def predict_noise(self,last_actions,timestep,goal_embed,rgbd_embed):
         action_embeds = self.input_embed(last_actions)
         time_embeds = self.time_emb(timestep.to(self._device)).unsqueeze(1)
-        cond_embedding = torch.cat([time_embeds,goal_embed,rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([time_embeds,goal_embed,rgbd_embed],dim=1))
+        cond_embedding = torch.cat([time_embeds,goal_embed,goal_embed,goal_embed,rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([time_embeds,goal_embed,goal_embed,goal_embed,rgbd_embed],dim=1))
         cond_embedding = cond_embedding.repeat(action_embeds.shape[0],1,1)
         input_embedding = action_embeds + self.out_pos_embed(action_embeds)
         output = self.decoder(tgt = input_embedding,memory = cond_embedding, tgt_mask = self.tgt_mask.to(self._device))
@@ -159,13 +165,13 @@ def predict_critic(self,predict_trajectory,rgbd_embed):
         nogoal_embed = torch.zeros_like(repeat_rgbd_embed[:,0:1])
         action_embeddings = self.input_embed(predict_trajectory)
         action_embeddings = action_embeddings + self.out_pos_embed(action_embeddings)
-        cond_embeddings = torch.cat([nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1))
+        cond_embeddings = torch.cat([nogoal_embed,nogoal_embed,nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([nogoal_embed,nogoal_embed,nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1))
         critic_output = self.decoder(tgt = action_embeddings, memory = cond_embeddings, memory_mask = self.cond_critic_mask)
         critic_output = self.layernorm(critic_output)
         critic_output = self.critic_head(critic_output.mean(dim=1))[:,0]
         return critic_output
 
-    def forward(self,goal_point,goal_image,input_images,input_depths,output_actions,augment_actions):
+    def forward(self,goal_point,goal_image,goal_pixel,input_images,input_depths,output_actions,augment_actions):
         # """get device safely"""
         # # get device safely
         # try:
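Note: `sample_noise` pairs each ground-truth action chunk with a random timestep and a Gaussian noise sample via the DDPM scheduler, and the decoder is then trained to recover that injected noise (`prediction_type='epsilon'`). A condensed sketch of the pattern, independent of the NavDP classes (scheduler settings are toy values; the linear layer is only a placeholder for the transformer decoder):

```python
import torch
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=10, beta_schedule='squaredcos_cap_v2',
                          clip_sample=True, prediction_type='epsilon')

actions = torch.randn(4, 24, 3)  # (batch, predict_size, waypoint xyz) -- toy values
noise = torch.randn_like(actions)
t = torch.randint(0, scheduler.config.num_train_timesteps, (actions.shape[0],)).long()
noisy_actions = scheduler.add_noise(actions, noise, t)

# Training minimizes MSE between predicted and injected noise, exactly like
# ng_action_loss / mg_action_loss in the trainer below.
eps_model = torch.nn.Linear(3, 3)
loss = (eps_model(noisy_actions) - noise).square().mean()
loss.backward()
```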
@@ -193,53 +199,61 @@
         input_depths = input_depths.to(device)
 
         ng_noise,ng_time_embed,ng_noisy_action_embed = self.sample_noise(tensor_label_actions)
-        pg_noise,pg_time_embed,pg_noisy_action_embed = self.sample_noise(tensor_label_actions)
-        # ig_noise,ig_time_embed,ig_noisy_action_embed = self.sample_noise(tensor_label_actions)
+        mg_noise,mg_time_embed,mg_noisy_action_embed = self.sample_noise(tensor_label_actions)
 
         rgbd_embed = self.rgbd_encoder(input_images,input_depths)
         pointgoal_embed = self.point_encoder(tensor_point_goal).unsqueeze(1)
         nogoal_embed = torch.zeros_like(pointgoal_embed)
-        # imagegoal_embed = torch.zeros_like(pointgoal_embed)
+        imagegoal_embed = self.image_encoder(goal_image).unsqueeze(1)
+        pixelgoal_embed = self.pixel_encoder(goal_pixel).unsqueeze(1)
+
+        imagegoal_aux_pred = self.image_aux_head(imagegoal_embed[:,0])
+        pixelgoal_aux_pred = self.pixel_aux_head(pixelgoal_embed[:,0])
 
         label_embed = self.input_embed(tensor_label_actions).detach()
         augment_embed = self.input_embed(tensor_augment_actions).detach()
 
-        cond_pos_embed = self.cond_pos_embed(torch.cat([ng_time_embed,nogoal_embed,rgbd_embed],dim=1))
-        ng_cond_embeddings = self.drop(torch.cat([ng_time_embed,nogoal_embed,rgbd_embed],dim=1) + cond_pos_embed)
-        pg_cond_embeddings = self.drop(torch.cat([pg_time_embed,pointgoal_embed,rgbd_embed],dim=1) + cond_pos_embed)
-        # ig_cond_embeddings = self.drop(torch.cat([ig_time_embed,imagegoal_embed,rgbd_embed],dim=1) + cond_pos_embed)
+        cond_pos_embed = self.cond_pos_embed(torch.cat([ng_time_embed,nogoal_embed,imagegoal_embed,pixelgoal_embed,rgbd_embed],dim=1))
+        ng_cond_embeddings = self.drop(torch.cat([ng_time_embed,nogoal_embed,nogoal_embed,nogoal_embed,rgbd_embed],dim=1) + cond_pos_embed)
+
+        cand_goal_embed = [pointgoal_embed,imagegoal_embed,pixelgoal_embed]
+        batch_size = pointgoal_embed.shape[0]
+
+        # Generate deterministic selections for each sample in the batch using vectorized operations
+        batch_indices = torch.arange(batch_size, device=pointgoal_embed.device)
+        pattern_indices = batch_indices % 27 # 3^3 = 27 possible combinations
+        selections_0 = pattern_indices % 3
+        selections_1 = (pattern_indices // 3) % 3
+        selections_2 = (pattern_indices // 9) % 3
+        goal_embeds = torch.stack(cand_goal_embed, dim=0) # [3, batch_size, 1, token_dim]
+        selected_goals_0 = goal_embeds[selections_0, torch.arange(batch_size), :, :] # [batch_size, 1, token_dim]
+        selected_goals_1 = goal_embeds[selections_1, torch.arange(batch_size), :, :]
+        selected_goals_2 = goal_embeds[selections_2, torch.arange(batch_size), :, :]
+        mg_cond_embed_tensor = torch.cat([mg_time_embed, selected_goals_0, selected_goals_1, selected_goals_2, rgbd_embed], dim=1)
+        mg_cond_embeddings = self.drop(mg_cond_embed_tensor + cond_pos_embed)
 
         out_pos_embed = self.out_pos_embed(ng_noisy_action_embed)
         ng_action_embeddings = self.drop(ng_noisy_action_embed + out_pos_embed)
-        pg_action_embeddings = self.drop(pg_noisy_action_embed + out_pos_embed)
-        # ig_action_embeddings = self.drop(ig_noisy_action_embed + out_pos_embed)
+        mg_action_embeddings = self.drop(mg_noisy_action_embed + out_pos_embed)
         label_action_embeddings = self.drop(label_embed + out_pos_embed)
         augment_action_embeddings = self.drop(augment_embed + out_pos_embed)
 
-        # ng_output = self.decoder(tgt = ng_action_embeddings,memory = ng_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
         ng_output = self.decoder(tgt = ng_action_embeddings,memory = ng_cond_embeddings, tgt_mask = self.tgt_mask)
         ng_output = self.layernorm(ng_output)
         noise_pred_ng = self.action_head(ng_output)
 
-        pg_output = self.decoder(tgt = pg_action_embeddings,memory = pg_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
-        # pg_output = self.decoder(tgt = pg_action_embeddings,memory = pg_cond_embeddings, tgt_mask = self.tgt_mask)
-        pg_output = self.layernorm(pg_output)
-        noise_pred_pg = self.action_head(pg_output)
-
-        # ig_output = self.decoder(tgt = ig_action_embeddings,memory = ig_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
-        # ig_output = self.decoder(tgt = ig_action_embeddings,memory = ig_cond_embeddings, tgt_mask = self.tgt_mask)
-        # ig_output = self.layernorm(ig_output)
-        # noise_pred_ig = self.action_head(ig_output)
+        mg_output = self.decoder(tgt = mg_action_embeddings,memory = mg_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
+        mg_output = self.layernorm(mg_output)
+        noise_pred_mg = self.action_head(mg_output)
 
         cr_label_output = self.decoder(tgt = label_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask.to(self._device))
-        # cr_label_output = self.decoder(tgt = label_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask)
         cr_label_output = self.layernorm(cr_label_output)
         cr_label_pred = self.critic_head(cr_label_output.mean(dim=1))[:,0]
 
         cr_augment_output = self.decoder(tgt = augment_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask.to(self._device))
         cr_augment_output = self.layernorm(cr_augment_output)
         cr_augment_pred = self.critic_head(cr_augment_output.mean(dim=1))[:,0]
-        return noise_pred_ng,noise_pred_pg,cr_label_pred,cr_augment_pred,[ng_noise,pg_noise]
+        return noise_pred_ng,noise_pred_mg,cr_label_pred,cr_augment_pred,[ng_noise,mg_noise],[imagegoal_aux_pred,pixelgoal_aux_pred]
 
     def _get_device(self):
         """Safe get device information"""
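Note: the new multi-goal branch enumerates all 3^3 = 27 assignments of {point, image, pixel} goals over the three goal slots, cycling deterministically with the sample index rather than drawing with `random`, which sidesteps rank-dependent randomness in multi-GPU runs. A self-contained sketch of the indexing (shapes chosen for illustration):

```python
import torch

batch_size, token_dim = 6, 8
point = torch.zeros(batch_size, 1, token_dim)            # stand-ins for the goal embeddings
image = torch.ones(batch_size, 1, token_dim)
pixel = torch.full((batch_size, 1, token_dim), 2.0)

goal_embeds = torch.stack([point, image, pixel], dim=0)  # [3, B, 1, D]
idx = torch.arange(batch_size)
pattern = idx % 27                                       # base-3 digits pick one goal per slot
slots = [(pattern // 3**k) % 3 for k in range(3)]        # digit k -> goal type for slot k
selected = [goal_embeds[s, idx] for s in slots]          # each: [B, 1, D]
mg_cond = torch.cat(selected, dim=1)                     # [B, 3, D] goal slots for the decoder

# sample 0 -> (point, point, point), sample 5 -> (pixel, image, point), ...
print([tuple(int(s[i]) for s in slots) for i in range(batch_size)])
```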
diff --git a/internnav/model/encoder/navdp_backbone.py b/internnav/model/encoder/navdp_backbone.py
index a1606b51..34f785c1 100644
--- a/internnav/model/encoder/navdp_backbone.py
+++ b/internnav/model/encoder/navdp_backbone.py
@@ -260,6 +260,7 @@ def forward(self,images,depths):
         memory_token = self.former_net(former_query,former_token)
         memory_token = self.project_layer(memory_token)
         return memory_token
+
     def _get_device(self):
         """get device safely"""
         # try to get device through model parameters
@@ -293,6 +294,12 @@ def __init__(self,
                  embed_size=512,
                  device='cuda:0'):
         super().__init__()
+        if device is None:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
         self.device = device
         self.image_size = image_size
         self.embed_size = embed_size
@@ -306,13 +313,43 @@ def __init__(self,
                 padding = self.imagegoal_encoder.patch_embed.proj.padding)
         self.imagegoal_encoder.train()
         self.project_layer = nn.Linear(384,embed_size)
+        self.to(device)
 
     def forward(self,images):
         assert len(images.shape) == 4 # B,C,H,W
-        tensor_images = torch.as_tensor(images,dtype=torch.float32,device=self.device).permute(0,3,1,2)
+        device = self._get_device()
+        images = images.to(device)
+        tensor_images = torch.as_tensor(images,dtype=torch.float32,device=device).permute(0,3,1,2)
         image_token = self.imagegoal_encoder.get_intermediate_layers(tensor_images)[0].mean(dim=1)
         image_token = self.project_layer(image_token)
         return image_token
+
+    def _get_device(self):
+        """get device safely"""
+        # try to get device through model parameters
+        try:
+            for param in self.parameters():
+                return param.device
+        except StopIteration:
+            pass
+
+        # try to get device through buffer
+        try:
+            for buffer in self.buffers():
+                return buffer.device
+        except StopIteration:
+            pass
+
+        # try to get device through submodule
+        for module in self.children():
+            try:
+                for param in module.parameters():
+                    return param.device
+            except StopIteration:
+                continue
+
+        # finally revert to default device
+        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 class NavDP_PixelGoal_Backbone(nn.Module):
     def __init__(self,
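Note: this `_get_device` fallback chain is now copy-pasted into three backbone classes. Two details worth knowing: a `for` loop over an exhausted generator simply does not run (it never raises `StopIteration` to the caller, so the `try/except` guards are inert), and `parameters()` already recurses into children, making the submodule pass redundant. A compact equivalent as an optional mixin (`DeviceMixin` is my name, not the repo's):

```python
import torch
import torch.nn as nn

class DeviceMixin:
    """Resolve a module's device: first parameter, else first buffer, else a default."""
    def _get_device(self) -> torch.device:
        for tensor in self.parameters():  # recurses into submodules already
            return tensor.device
        for tensor in self.buffers():
            return tensor.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TinyEncoder(DeviceMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

print(TinyEncoder()._get_device())  # cpu, until .to(...) moves it
```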
@@ -320,23 +357,59 @@
                  embed_size=512,
                  device='cuda:0'):
         super().__init__()
+        if device is None:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
         self.device = device
         self.image_size = image_size
         self.embed_size = embed_size
         model_configs = {'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}}
         self.pixelgoal_encoder = DepthAnythingV2(**model_configs['vits'])
         self.pixelgoal_encoder = self.pixelgoal_encoder.pretrained.float()
-        self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d(in_channels=4,
+        self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d(in_channels=7,
                 out_channels = self.pixelgoal_encoder.patch_embed.proj.out_channels,
                 kernel_size = self.pixelgoal_encoder.patch_embed.proj.kernel_size,
                 stride = self.pixelgoal_encoder.patch_embed.proj.stride,
                 padding = self.pixelgoal_encoder.patch_embed.proj.padding)
         self.pixelgoal_encoder.train()
         self.project_layer = nn.Linear(384,embed_size)
+        self.to(device)
 
     def forward(self,images):
         assert len(images.shape) == 4 # B,C,H,W
-        tensor_images = torch.as_tensor(images,dtype=torch.float32,device=self.device).permute(0,3,1,2)
+        device = self._get_device()
+        images = images.to(device)
+        tensor_images = torch.as_tensor(images,dtype=torch.float32,device=device).permute(0,3,1,2)
         image_token = self.pixelgoal_encoder.get_intermediate_layers(tensor_images)[0].mean(dim=1)
         image_token = self.project_layer(image_token)
-        return image_token
\ No newline at end of file
+        return image_token
+
+    def _get_device(self):
+        """get device safely"""
+        # try to get device through model parameters
+        try:
+            for param in self.parameters():
+                return param.device
+        except StopIteration:
+            pass
+
+        # try to get device through buffer
+        try:
+            for buffer in self.buffers():
+                return buffer.device
+        except StopIteration:
+            pass
+
+        # try to get device through submodule
+        for module in self.children():
+            try:
+                for param in module.parameters():
+                    return param.device
+            except StopIteration:
+                continue
+
+        # finally revert to default device
+        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
\ No newline at end of file
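Note: rebuilding `patch_embed.proj` with `in_channels=7` re-initializes it randomly, so the pretrained kernels of the first conv are discarded. If that matters, a common alternative is to copy the pretrained weights into the first input channels and zero-init the extra ones. A sketch under assumed ViT-S-like shapes (the 7-channel input layout itself is this patch's convention, not standard):

```python
import torch
import torch.nn as nn

def widen_patch_embed(old: nn.Conv2d, new_in_channels: int) -> nn.Conv2d:
    """Rebuild a conv patch embedding with more input channels, keeping pretrained kernels."""
    new = nn.Conv2d(new_in_channels, old.out_channels, kernel_size=old.kernel_size,
                    stride=old.stride, padding=old.padding)
    with torch.no_grad():
        new.weight.zero_()                            # extra channels start as no-ops
        new.weight[:, :old.in_channels] = old.weight  # reuse pretrained kernels
        if old.bias is not None and new.bias is not None:
            new.bias.copy_(old.bias)
    return new

old = nn.Conv2d(3, 384, kernel_size=14, stride=14)    # ViT-S-like patch embed
print(widen_patch_embed(old, 7).weight.shape)         # torch.Size([384, 7, 14, 14])
```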
diff --git a/internnav/trainer/navdp_trainer.py b/internnav/trainer/navdp_trainer.py
index 73849f1c..2e00e3dc 100644
--- a/internnav/trainer/navdp_trainer.py
+++ b/internnav/trainer/navdp_trainer.py
@@ -57,6 +57,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         inputs_on_device = {
             "batch_pg": inputs["batch_pg"].to(model_device),
             "batch_ig": inputs["batch_ig"].to(model_device),
+            "batch_tg": inputs["batch_tg"].to(model_device),
             "batch_rgb": inputs["batch_rgb"].to(model_device),
             "batch_depth": inputs["batch_depth"].to(model_device),
             "batch_labels": inputs["batch_labels"].to(model_device),
@@ -76,9 +77,10 @@
         batch_label_critic = inputs["batch_label_critic"]
         batch_augment_critic = inputs["batch_augment_critic"]
 
-        pred_ng, pred_pg, critic_pred, augment_pred, noise = model(
+        pred_ng, pred_mg, critic_pred, augment_pred, noise, aux_pred = model(
             inputs_on_device["batch_pg"],
             inputs_on_device["batch_ig"],
+            inputs_on_device["batch_tg"],
             inputs_on_device["batch_rgb"],
             inputs_on_device["batch_depth"],
             inputs_on_device["batch_labels"],
@@ -86,24 +88,22 @@
         )
 
         ng_action_loss = (pred_ng - noise[0]).square().mean()
-        pg_action_loss = (pred_pg - noise[1]).square().mean()
-        # ig_action_loss = (pred_ig - noise[2]).square().mean()
-        action_loss = 0.5 * pg_action_loss + 0.5 * ng_action_loss
-        critic_loss = (critic_pred - batch_label_critic).square().mean() + \
-                      (augment_pred - batch_augment_critic).square().mean()
-        loss = 0.8 * action_loss + 0.2 * critic_loss
+        mg_action_loss = (pred_mg - noise[1]).square().mean()
+        aux_loss = 0.5*(inputs_on_device["batch_pg"] - aux_pred[0]).square().mean() + 0.5*(inputs_on_device["batch_pg"] - aux_pred[1]).square().mean()
+        action_loss = 0.5 * mg_action_loss + 0.5 * ng_action_loss
+        critic_loss = (critic_pred - batch_label_critic).square().mean() + (augment_pred - batch_augment_critic).square().mean()
+        loss = 0.8 * action_loss + 0.2 * critic_loss + 0.5 * aux_loss
 
         outputs = {
             'pred_ng': pred_ng,
-            'pred_pg': pred_pg,
-            # 'pred_ig': pred_ig,
+            'pred_mg': pred_mg,
             'critic_pred': critic_pred,
             'augment_pred': augment_pred,
             'noise': noise,
             'loss': loss,
             'ng_action_loss': ng_action_loss,
-            'pg_action_loss': pg_action_loss,
-            # 'ig_action_loss': ig_action_loss,
+            'mg_action_loss': mg_action_loss,
+            'aux_loss': aux_loss,
             'critic_loss': critic_loss
         }
         # if self.logger:
@@ -193,4 +193,23 @@ def get_train_dataloader(self):
             collate_fn=self.data_collator
         )
         # print(loader)
-        return loader
\ No newline at end of file
+        return loader
+
+    def save_model(self, output_dir, state_dict=None, **kwargs):
+        """
+        save model to specified directory
+
+        handle DDP wrapped model
+        """
+        # check if it is a DDP wrapped model
+        if hasattr(self.model, 'module'):
+            # get original model
+            model_to_save = self.model.module
+        else:
+            model_to_save = self.model
+
+        # ensure the output directory exists; join the path so a missing
+        # trailing slash on output_dir cannot mangle the checkpoint name
+        os.makedirs(output_dir, exist_ok=True)
+        torch.save(model_to_save.state_dict(), os.path.join(output_dir, "navdp.ckpt"))
+
+        print(f"Saving model to {output_dir} (is DDP: {hasattr(self.model, 'module')})")
\ No newline at end of file
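Note: `save_model` stores the unwrapped module's `state_dict`, so the checkpoint can be restored without DDP. A hedged load-side sketch (the prefix strip is belt-and-braces for checkpoints written before this fix, when the wrapped model might have been saved):

```python
import os
import torch

def load_navdp_checkpoint(model: torch.nn.Module, output_dir: str) -> torch.nn.Module:
    """Restore weights written by the trainer's save_model above."""
    state = torch.load(os.path.join(output_dir, "navdp.ckpt"), map_location="cpu")
    # Strip a leading 'module.' prefix left over from DDP-wrapped saves, if any.
    state = {k.removeprefix("module."): v for k, v in state.items()}
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"missing={len(missing)} unexpected={len(unexpected)}")
    return model
```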