diff --git a/internnav/dataset/navdp_dataset_lerobot.py b/internnav/dataset/navdp_dataset_lerobot.py index 3e6ee7ca..4437e7c4 100644 --- a/internnav/dataset/navdp_dataset_lerobot.py +++ b/internnav/dataset/navdp_dataset_lerobot.py @@ -1,50 +1,52 @@ -import numpy as np -import os +# Override the built-in print function with a timestamp version +import builtins import json +import os +from datetime import datetime + import cv2 +import numpy as np import open3d as o3d -import io -import time import pandas as pd -from datetime import datetime import torch -from tqdm import tqdm -from torch.utils.data import Dataset -from scipy.spatial.transform import Rotation -from scipy.interpolate import CubicSpline -import torchvision.transforms as T from PIL import Image -from io import BytesIO -import pdb +from scipy.interpolate import CubicSpline +from torch.utils.data import Dataset +from tqdm import tqdm -# Override the built-in print function with a timestamp version -import builtins original_print = builtins.print + + def print(*args, **kwargs): try: rank = int(os.environ.get('RANK', 0)) if rank == 0: timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] original_print(f"[{timestamp}]", *args, **kwargs) - except: - pass + except Exception: # Catch any exception to prevent crashes + pass + + builtins.print = print + class NavDP_Base_Datset(Dataset): - def __init__(self, - root_dirs, - preload_path=False, - memory_size=8, - predict_size=24, - batch_size=64, - image_size=224, - scene_data_scale=1.0, - trajectory_data_scale=1.0, - debug=False, - preload=False, - random_digit=False, - prior_sample=False): - + def __init__( + self, + root_dirs, + preload_path=False, + memory_size=8, + predict_size=24, + batch_size=64, + image_size=224, + scene_data_scale=1.0, + trajectory_data_scale=1.0, + debug=False, + preload=False, + random_digit=False, + prior_sample=False, + ): + self.dataset_dirs = np.array([p for p in os.listdir(root_dirs)]) self.memory_size = memory_size self.image_size = image_size @@ -63,47 +65,55 @@ def __init__(self, self.batch_size = batch_size self.batch_time_sum = 0.0 self._last_time = None - - if preload == False: - for group_dir in self.dataset_dirs: # gibson_zed, 3dfront ... - all_scene_dirs = np.array([p for p in os.listdir(os.path.join(root_dirs,group_dir))]) - select_scene_dirs = all_scene_dirs[np.arange(0,all_scene_dirs.shape[0],1/self.scene_scale_size).astype(np.int32)] + + if preload is False: + for group_dir in self.dataset_dirs: # gibson_zed, 3dfront ... 
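+                # subsample scenes at a stride of 1/scene_data_scale (e.g. scale 0.5 keeps every other scene)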
+ all_scene_dirs = np.array([p for p in os.listdir(os.path.join(root_dirs, group_dir))]) + select_scene_dirs = all_scene_dirs[ + np.arange(0, all_scene_dirs.shape[0], 1 / self.scene_scale_size).astype(np.int32) + ] for scene_dir in select_scene_dirs: - all_traj_dirs = np.array([p for p in os.listdir(os.path.join(root_dirs,group_dir,scene_dir))]) - select_traj_dirs = all_traj_dirs[np.arange(0,all_traj_dirs.shape[0],1/self.trajectory_data_scale).astype(np.int32)] + all_traj_dirs = np.array([p for p in os.listdir(os.path.join(root_dirs, group_dir, scene_dir))]) + select_traj_dirs = all_traj_dirs[ + np.arange(0, all_traj_dirs.shape[0], 1 / self.trajectory_data_scale).astype(np.int32) + ] for traj_dir in tqdm(select_traj_dirs): - entire_task_dir = os.path.join(root_dirs,group_dir,scene_dir,traj_dir) - rgb_dir = os.path.join(entire_task_dir,"videos/chunk-000/observation.images.rgb/") - depth_dir = os.path.join(entire_task_dir,"videos/chunk-000/observation.images.depth/") - data_path = os.path.join(entire_task_dir,'data/chunk-000/episode_000000.parquet') # intrinsic, extrinsic, cam_traj, path - afford_path = os.path.join(entire_task_dir,'data/chunk-000/path.ply') + entire_task_dir = os.path.join(root_dirs, group_dir, scene_dir, traj_dir) + rgb_dir = os.path.join(entire_task_dir, "videos/chunk-000/observation.images.rgb/") + depth_dir = os.path.join(entire_task_dir, "videos/chunk-000/observation.images.depth/") + data_path = os.path.join( + entire_task_dir, 'data/chunk-000/episode_000000.parquet' + ) # intrinsic, extrinsic, cam_traj, path + afford_path = os.path.join(entire_task_dir, 'data/chunk-000/path.ply') rgbs_length = len([p for p in os.listdir(rgb_dir)]) depths_length = len([p for p in os.listdir(depth_dir)]) - + rgbs_path = [] depths_path = [] if depths_length != rgbs_length: continue for i in range(rgbs_length): - rgbs_path.append(os.path.join(rgb_dir,"%d.jpg"%i)) - depths_path.append(os.path.join(depth_dir,"%d.png"%i)) - if os.path.exists(data_path) == False: + rgbs_path.append(os.path.join(rgb_dir, "%d.jpg" % i)) + depths_path.append(os.path.join(depth_dir, "%d.png" % i)) + if os.path.exists(data_path) is False: continue self.trajectory_dirs.append(entire_task_dir) self.trajectory_data_dir.append(data_path) self.trajectory_rgb_path.append(rgbs_path) self.trajectory_depth_path.append(depths_path) self.trajectory_afford_path.append(afford_path) - - save_dict = {'trajectory_dirs':self.trajectory_dirs, - 'trajectory_data_dir':self.trajectory_data_dir, - 'trajectory_rgb_path':self.trajectory_rgb_path, - 'trajectory_depth_path':self.trajectory_depth_path, - 'trajectory_afford_path':self.trajectory_afford_path} - with open(preload_path,'w') as f: - json.dump(save_dict,f,indent=4) + + save_dict = { + 'trajectory_dirs': self.trajectory_dirs, + 'trajectory_data_dir': self.trajectory_data_dir, + 'trajectory_rgb_path': self.trajectory_rgb_path, + 'trajectory_depth_path': self.trajectory_depth_path, + 'trajectory_afford_path': self.trajectory_afford_path, + } + with open(preload_path, 'w') as f: + json.dump(save_dict, f, indent=4) else: - load_dict = json.load(open(preload_path,'r')) + load_dict = json.load(open(preload_path, 'r')) self.trajectory_dirs = load_dict['trajectory_dirs'] * 50 self.trajectory_data_dir = load_dict['trajectory_data_dir'] * 50 self.trajectory_rgb_path = load_dict['trajectory_rgb_path'] * 50 @@ -112,48 +122,52 @@ def __init__(self, def __len__(self): return len(self.trajectory_dirs) - - def load_image(self,image_url): + + def load_image(self, image_url): image = 
Image.open(image_url) - image = np.array(image,np.uint8) + image = np.array(image, np.uint8) return image - - def load_depth(self,depth_url): + + def load_depth(self, depth_url): depth = Image.open(depth_url) - depth = np.array(depth,np.uint16) + depth = np.array(depth, np.uint16) return depth - - def load_pointcloud(self,pcd_url): + + def load_pointcloud(self, pcd_url): pcd = o3d.io.read_point_cloud(pcd_url) return pcd - - def process_image(self,image_path): + + def process_image(self, image_path): image = self.load_image(image_path) - H,W,C = image.shape - prop = self.image_size/max(H,W) - image = cv2.resize(image,(-1,-1),fx=prop,fy=prop) - pad_width = max((self.image_size - image.shape[1])//2,0) - pad_height = max((self.image_size - image.shape[0])//2,0) - pad_image = np.pad(image,((pad_height,pad_height),(pad_width,pad_width),(0,0)),mode='constant',constant_values=0) - image = cv2.resize(pad_image,(self.image_size,self.image_size)) - image = np.array(image,np.float32)/255.0 + H, W, C = image.shape + prop = self.image_size / max(H, W) + image = cv2.resize(image, (-1, -1), fx=prop, fy=prop) + pad_width = max((self.image_size - image.shape[1]) // 2, 0) + pad_height = max((self.image_size - image.shape[0]) // 2, 0) + pad_image = np.pad( + image, ((pad_height, pad_height), (pad_width, pad_width), (0, 0)), mode='constant', constant_values=0 + ) + image = cv2.resize(pad_image, (self.image_size, self.image_size)) + image = np.array(image, np.float32) / 255.0 return image - - def process_depth(self,depth_path): - depth = (self.load_depth(depth_path)/10000.0) - H,W = depth.shape - prop = self.image_size/max(H,W) - depth = cv2.resize(depth,(-1,-1),fx=prop,fy=prop) - pad_width = max((self.image_size - depth.shape[1])//2,0) - pad_height = max((self.image_size - depth.shape[0])//2,0) - pad_depth = np.pad(depth,((pad_height,pad_height),(pad_width,pad_width)),mode='constant',constant_values=0) + + def process_depth(self, depth_path): + depth = self.load_depth(depth_path) / 10000.0 + H, W = depth.shape + prop = self.image_size / max(H, W) + depth = cv2.resize(depth, (-1, -1), fx=prop, fy=prop) + pad_width = max((self.image_size - depth.shape[1]) // 2, 0) + pad_height = max((self.image_size - depth.shape[0]) // 2, 0) + pad_depth = np.pad( + depth, ((pad_height, pad_height), (pad_width, pad_width)), mode='constant', constant_values=0 + ) pad_depth[pad_depth > 5.0] = 0 pad_depth[pad_depth < 0.1] = 0 - depth = cv2.resize(pad_depth,(self.image_size,self.image_size)) - depth = np.array(depth,np.float32) - return depth[:,:,np.newaxis] + depth = cv2.resize(pad_depth, (self.image_size, self.image_size)) + depth = np.array(depth, np.float32) + return depth[:, :, np.newaxis] - def process_data_parquet(self,index): + def process_data_parquet(self, index): if not os.path.isfile(self.trajectory_data_dir[index]): raise FileNotFoundError(self.trajectory_data_dir[index]) df = pd.read_parquet(self.trajectory_data_dir[index]) @@ -161,244 +175,354 @@ def process_data_parquet(self,index): camera_extrinsic = np.vstack(np.array(df['observation.camera_extrinsic'].tolist()[0])).reshape(4, 4) trajectory_length = len(df['action'].tolist()) camera_trajectory = np.array([np.stack(frame) for frame in df['action']], dtype=np.float64) - return camera_intrinsic,camera_extrinsic,camera_trajectory,trajectory_length + return camera_intrinsic, camera_extrinsic, camera_trajectory, trajectory_length - def process_path_points(self,index): + def process_path_points(self, index): trajectory_pcd = 
self.load_pointcloud(self.trajectory_afford_path[index])
         trajectory_color = np.array(trajectory_pcd.colors)
-        color_distance = np.abs(trajectory_color - np.array([0,0,0])).sum(axis=-1) # sometimes, the path are saved as black points
-        select_index = np.where(color_distance<0.05)[0]
+        color_distance = np.abs(trajectory_color - np.array([0, 0, 0])).sum(
+            axis=-1
+        )  # sometimes the path is saved as black points
+        select_index = np.where(color_distance < 0.05)[0]
         trajectory_path = o3d.geometry.PointCloud()
         trajectory_path.points = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.points)[select_index])
         trajectory_path.colors = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.colors)[select_index])
-        return np.array(trajectory_path.points),trajectory_path
-
-    def process_obstacle_points(self,index,path_points):
+        return np.array(trajectory_path.points), trajectory_path
+
+    def process_obstacle_points(self, index, path_points):
         trajectory_pcd = self.load_pointcloud(self.trajectory_afford_path[index])
         trajectory_color = np.array(trajectory_pcd.colors)
         trajectory_points = np.array(trajectory_pcd.points)
-        color_distance = np.abs(trajectory_color - np.array([0,0,0.5])).sum(axis=-1) # the obstacles are save in blue
+        color_distance = np.abs(trajectory_color - np.array([0, 0, 0.5])).sum(axis=-1)  # the obstacles are saved in blue
         path_lower_bound = path_points.min(axis=0)
         path_upper_bound = path_points.max(axis=0)
-        condition_x = (trajectory_points[:,0] >= path_lower_bound[0]-2.0) & (trajectory_points[:,0] <= path_upper_bound[0]+2.0)
-        condition_y = (trajectory_points[:,1] >= path_lower_bound[1]-2.0) & (trajectory_points[:,1] <= path_upper_bound[1]+2.0)
-        select_index = np.where((color_distance<0.05) & condition_x & condition_y)[0]
+        condition_x = (trajectory_points[:, 0] >= path_lower_bound[0] - 2.0) & (
+            trajectory_points[:, 0] <= path_upper_bound[0] + 2.0
+        )
+        condition_y = (trajectory_points[:, 1] >= path_lower_bound[1] - 2.0) & (
+            trajectory_points[:, 1] <= path_upper_bound[1] + 2.0
+        )
+        select_index = np.where((color_distance < 0.05) & condition_x & condition_y)[0]
         trajectory_obstacle = o3d.geometry.PointCloud()
         trajectory_obstacle.points = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.points)[select_index])
         trajectory_obstacle.colors = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.colors)[select_index])
-        return np.array(trajectory_obstacle.points),trajectory_obstacle
-
-    def process_memory(self,rgb_paths,depth_paths,start_step,memory_digit=1):
-        memory_index = np.arange(start_step - (self.memory_size - 1) * memory_digit, start_step+1,memory_digit)
+        return np.array(trajectory_obstacle.points), trajectory_obstacle
+
+    def process_memory(self, rgb_paths, depth_paths, start_step, memory_digit=1):
+        memory_index = np.arange(start_step - (self.memory_size - 1) * memory_digit, start_step + 1, memory_digit)
         outrange_sum = (memory_index < 0).sum()
         memory_index = memory_index[outrange_sum:]
-        context_image = np.zeros((self.memory_size,self.image_size,self.image_size,3),np.float32)
+        context_image = np.zeros((self.memory_size, self.image_size, self.image_size, 3), np.float32)
         context_image[outrange_sum:] = np.array([self.process_image(rgb_paths[i]) for i in memory_index])
         context_depth = self.process_depth(depth_paths[start_step])
-        return context_image,context_depth,memory_index
-
-    def process_pixel_goal(self, image_url, target_point, camera_intrinsic, camera_extrinsic):
+        return context_image, context_depth, memory_index
+
+    def process_pixel_goal(self, image_url, 
target_point, camera_intrinsic, camera_extrinsic): image = Image.open(image_url) - image = np.array(image,np.uint8) + image = np.array(image, np.uint8) resize_image = self.process_image(image_url) - coordinate = np.array([-target_point[1],target_point[0],camera_extrinsic[2,3]*0.8]) - camera_coordinate = np.matmul(camera_extrinsic[0:3,0:3],coordinate[:,None]) - pixel_coord_x = camera_intrinsic[0,2] + (camera_coordinate[0] / camera_coordinate[2]) * camera_intrinsic[0,0] - pixel_coord_y = camera_intrinsic[1,2] + (-camera_coordinate[1] / camera_coordinate[2]) * camera_intrinsic[1,1] + coordinate = np.array([-target_point[1], target_point[0], camera_extrinsic[2, 3] * 0.8]) + camera_coordinate = np.matmul(camera_extrinsic[0:3, 0:3], coordinate[:, None]) + pixel_coord_x = camera_intrinsic[0, 2] + (camera_coordinate[0] / camera_coordinate[2]) * camera_intrinsic[0, 0] + pixel_coord_y = camera_intrinsic[1, 2] + (-camera_coordinate[1] / camera_coordinate[2]) * camera_intrinsic[1, 1] pixel_mask = np.zeros_like(image) visible_flag = False - - if pixel_coord_x > 0 and pixel_coord_x < image.shape[1] and pixel_coord_y > 0 and pixel_coord_y < image.shape[0]: - pixel_mask = cv2.rectangle(pixel_mask,(int(pixel_coord_x-np.random.randint(6,12)),int(pixel_coord_y-np.random.randint(6,12))),(int(pixel_coord_x+np.random.randint(6,12)),int(pixel_coord_y+np.random.randint(6,12))),(255,255,255),-1) - visible_flag = True - + + if ( + pixel_coord_x > 0 + and pixel_coord_x < image.shape[1] + and pixel_coord_y > 0 + and pixel_coord_y < image.shape[0] + ): + pixel_mask = cv2.rectangle( + pixel_mask, + (int(pixel_coord_x - np.random.randint(6, 12)), int(pixel_coord_y - np.random.randint(6, 12))), + (int(pixel_coord_x + np.random.randint(6, 12)), int(pixel_coord_y + np.random.randint(6, 12))), + (255, 255, 255), + -1, + ) + visible_flag = True + H, W, C = pixel_mask.shape - prop = self.image_size/max(H, W) + prop = self.image_size / max(H, W) pixel_mask = cv2.resize(pixel_mask, (-1, -1), fx=prop, fy=prop) - pad_width = max((self.image_size - pixel_mask.shape[1])//2, 0) - pad_height = max((self.image_size - pixel_mask.shape[0])//2, 0) - pad_mask = np.pad(pixel_mask, ((pad_height, pad_height), (pad_width, pad_width), (0, 0)), mode='constant', constant_values=0) + pad_width = max((self.image_size - pixel_mask.shape[1]) // 2, 0) + pad_height = max((self.image_size - pixel_mask.shape[0]) // 2, 0) + pad_mask = np.pad( + pixel_mask, ((pad_height, pad_height), (pad_width, pad_width), (0, 0)), mode='constant', constant_values=0 + ) mask = cv2.resize(pad_mask, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST) - mask = np.array(mask, np.float32)/255.0 - mask = mask.mean(axis=-1)[:,:,None] - return np.concatenate((resize_image,mask),axis=-1),visible_flag - - def relative_pose(self,R_base,T_base,R_world,T_world,base_extrinsic): - R_base = np.matmul(R_base,np.linalg.inv(base_extrinsic[0:3,0:3])) + mask = np.array(mask, np.float32) / 255.0 + mask = mask.mean(axis=-1)[:, :, None] + return np.concatenate((resize_image, mask), axis=-1), visible_flag + + def relative_pose(self, R_base, T_base, R_world, T_world, base_extrinsic): + R_base = np.matmul(R_base, np.linalg.inv(base_extrinsic[0:3, 0:3])) if len(T_world.shape) == 1: homo_RT = np.eye(4) - homo_RT[0:3,0:3] = R_base - homo_RT[0:3,3] = T_base - R_frame = np.dot(R_world,R_base.T) - T_frame = np.dot(np.linalg.inv(homo_RT),np.array([*T_world,1]).T)[0:3] - T_frame = np.array([T_frame[1],-T_frame[0],T_frame[2]]) #[:T[1],-T[0],T[2] - return R_frame,T_frame + homo_RT[0:3, 0:3] 
= R_base
+            homo_RT[0:3, 3] = T_base
+            R_frame = np.dot(R_world, R_base.T)
+            T_frame = np.dot(np.linalg.inv(homo_RT), np.array([*T_world, 1]).T)[0:3]
+            T_frame = np.array([T_frame[1], -T_frame[0], T_frame[2]])  # reorder to [T[1], -T[0], T[2]]
+            return R_frame, T_frame
         else:
             homo_RT = np.eye(4)
-            homo_RT[0:3,0:3] = R_base
-            homo_RT[0:3,3] = T_base
-            R_frame = np.dot(R_world,R_base.T)
-            T_frame = np.dot(np.linalg.inv(homo_RT),np.concatenate((T_world,np.ones((T_world.shape[0],1))),axis=-1).T).T[:,0:3]
-            T_frame = T_frame[:,[1,0,2]]
-            T_frame[:,1] = -T_frame[:,1]
-            return R_frame,T_frame
-
-    def absolute_pose(self,R_base,T_base,R_frame,T_frame,base_extrinsic):
-        R_base = np.matmul(R_base,np.linalg.inv(base_extrinsic[0:3,0:3]))
+            homo_RT[0:3, 0:3] = R_base
+            homo_RT[0:3, 3] = T_base
+            R_frame = np.dot(R_world, R_base.T)
+            T_frame = np.dot(
+                np.linalg.inv(homo_RT), np.concatenate((T_world, np.ones((T_world.shape[0], 1))), axis=-1).T
+            ).T[:, 0:3]
+            T_frame = T_frame[:, [1, 0, 2]]
+            T_frame[:, 1] = -T_frame[:, 1]
+            return R_frame, T_frame
+
+    def absolute_pose(self, R_base, T_base, R_frame, T_frame, base_extrinsic):
+        R_base = np.matmul(R_base, np.linalg.inv(base_extrinsic[0:3, 0:3]))
         if len(T_frame.shape) == 1:
             homo_RT = np.eye(4)
-            homo_RT[0:3,0:3] = R_base
-            homo_RT[0:3,3] = T_base
-            R_world = np.dot(R_frame,R_base)
-            T_world = np.dot(homo_RT,np.array([-T_frame[1],T_frame[0],T_frame[2],1]).T)[0:3]
+            homo_RT[0:3, 0:3] = R_base
+            homo_RT[0:3, 3] = T_base
+            R_world = np.dot(R_frame, R_base)
+            T_world = np.dot(homo_RT, np.array([-T_frame[1], T_frame[0], T_frame[2], 1]).T)[0:3]
         else:
             homo_RT = np.eye(4)
-            homo_RT[0:3,0:3] = R_base
-            homo_RT[0:3,3] = T_base
-            R_world = np.dot(R_frame,R_base)
-            T_world = np.dot(homo_RT,np.concatenate((np.stack((-T_frame[:,1],T_frame[:,0],T_frame[:,2]),axis=-1),np.ones((T_frame.shape[0],1))),axis=-1).T).T[:,0:3]
-        return R_world,T_world
-
-    def xyz_to_xyt(self,xyz_actions,init_vector):
+            homo_RT[0:3, 0:3] = R_base
+            homo_RT[0:3, 3] = T_base
+            R_world = np.dot(R_frame, R_base)
+            T_world = np.dot(
+                homo_RT,
+                np.concatenate(
+                    (np.stack((-T_frame[:, 1], T_frame[:, 0], T_frame[:, 2]), axis=-1), np.ones((T_frame.shape[0], 1))),
+                    axis=-1,
+                ).T,
+            ).T[:, 0:3]
+        return R_world, T_world
+
+    def xyz_to_xyt(self, xyz_actions, init_vector):
         xyt_actions = []
-        for i in range(0,xyz_actions.shape[0]-1):
-            current_vector = xyz_actions[i+1] - xyz_actions[i]
-            dot_product = np.dot(init_vector[0:2],current_vector[0:2])
-            cross_product = np.cross(init_vector[0:2],current_vector[0:2])
-            theta = np.arctan2(cross_product,dot_product)
-            xyt_actions.append([xyz_actions[i][0],xyz_actions[i][1],theta])
-        return np.array(xyt_actions)
-
-    def process_actions(self,extrinsics,base_extrinsic,start_step,end_step,pred_digit=1):
+        for i in range(0, xyz_actions.shape[0] - 1):
+            current_vector = xyz_actions[i + 1] - xyz_actions[i]
+            dot_product = np.dot(init_vector[0:2], current_vector[0:2])
+            cross_product = np.cross(init_vector[0:2], current_vector[0:2])
+            theta = np.arctan2(cross_product, dot_product)
+            xyt_actions.append([xyz_actions[i][0], xyz_actions[i][1], theta])
+        return np.array(xyt_actions)
+
+    def process_actions(self, extrinsics, base_extrinsic, start_step, end_step, pred_digit=1):
         label_linear_pos = []
-        for f_ext in extrinsics[start_step:end_step+1]:
-            R,T = self.relative_pose(extrinsics[start_step][0:3,0:3],extrinsics[start_step][0:3,3],f_ext[0:3,0:3],f_ext[0:3,3],base_extrinsic)
+        for f_ext in extrinsics[start_step : end_step + 1]:
+            R, T = self.relative_pose(
+                extrinsics[start_step][0:3, 0:3],
+                extrinsics[start_step][0:3, 3],
+                f_ext[0:3, 0:3],
+                f_ext[0:3, 3],
+                base_extrinsic,
+            )
             label_linear_pos.append(T)
         label_actions = np.array(label_linear_pos)
-
-        # this is usesd for action augmentations: 
+
+        # this is used for action augmentation:
         # (1) apply random rotation to the future steps
         # (2) interpolate between the rotated actions and origin actions
-        rotate_yaw_angle = np.random.uniform(-np.pi/3,np.pi/3)
-        rotate_matrix = np.array([[np.cos(rotate_yaw_angle),-np.sin(rotate_yaw_angle)],[np.sin(rotate_yaw_angle),np.cos(rotate_yaw_angle)]],np.float32)
-        rotate_local_actions = np.matmul(rotate_matrix,label_actions[:,0:2].T).T
-        rotate_local_actions = np.stack((rotate_local_actions[:,0],rotate_local_actions[:,1],np.zeros_like(rotate_local_actions[:,0])),axis=-1)
+        rotate_yaw_angle = np.random.uniform(-np.pi / 3, np.pi / 3)
+        rotate_matrix = np.array(
+            [
+                [np.cos(rotate_yaw_angle), -np.sin(rotate_yaw_angle)],
+                [np.sin(rotate_yaw_angle), np.cos(rotate_yaw_angle)],
+            ],
+            np.float32,
+        )
+        rotate_local_actions = np.matmul(rotate_matrix, label_actions[:, 0:2].T).T
+        rotate_local_actions = np.stack(
+            (rotate_local_actions[:, 0], rotate_local_actions[:, 1], np.zeros_like(rotate_local_actions[:, 0])), axis=-1
+        )
         rotate_world_points = []
         for act in rotate_local_actions:
-            w_rot,w_act = self.absolute_pose(extrinsics[start_step,0:3,0:3],extrinsics[start_step,0:3,3],np.eye(3),act,base_extrinsic)
+            w_rot, w_act = self.absolute_pose(
+                extrinsics[start_step, 0:3, 0:3], extrinsics[start_step, 0:3, 3], np.eye(3), act, base_extrinsic
+            )
             rotate_world_points.append(w_act)
         rotate_world_points = np.array(rotate_world_points)
-        origin_world_points = extrinsics[start_step:end_step+1,0:3,3]
+        origin_world_points = extrinsics[start_step : end_step + 1, 0:3, 3]
         mix_anchor_points = rotate_world_points
-
-        t = np.linspace(0,1,mix_anchor_points.shape[0])
-        cs_x = CubicSpline(t,mix_anchor_points[:,0])
-        cs_y = CubicSpline(t,mix_anchor_points[:,1])
-        cs_z = CubicSpline(t,mix_anchor_points[:,2])
+
+        t = np.linspace(0, 1, mix_anchor_points.shape[0])
+        cs_x = CubicSpline(t, mix_anchor_points[:, 0])
+        cs_y = CubicSpline(t, mix_anchor_points[:, 1])
+        cs_z = CubicSpline(t, mix_anchor_points[:, 2])
         interpolate_nums = origin_world_points.shape[0]
-        t_fine = np.linspace(0,1,int(interpolate_nums))
+        t_fine = np.linspace(0, 1, int(interpolate_nums))
         x_fine = cs_x(t_fine)
         y_fine = cs_y(t_fine)
         z_fine = cs_z(t_fine)
-        result_augment_points = np.stack((x_fine,y_fine,z_fine),axis=-1)
+        result_augment_points = np.stack((x_fine, y_fine, z_fine), axis=-1)
         local_label_points = []
         local_augment_points = []
-        for f_ext,g_ext in zip(origin_world_points,result_augment_points):
-            Rf,Tf = self.relative_pose(extrinsics[start_step][0:3,0:3],extrinsics[start_step][0:3,3],np.eye(3),f_ext,base_extrinsic)
-            Rg,Tg = self.relative_pose(extrinsics[start_step][0:3,0:3],extrinsics[start_step][0:3,3],np.eye(3),g_ext,base_extrinsic)
+        for f_ext, g_ext in zip(origin_world_points, result_augment_points):
+            Rf, Tf = self.relative_pose(
+                extrinsics[start_step][0:3, 0:3], extrinsics[start_step][0:3, 3], np.eye(3), f_ext, base_extrinsic
+            )
+            Rg, Tg = self.relative_pose(
+                extrinsics[start_step][0:3, 0:3], extrinsics[start_step][0:3, 3], np.eye(3), g_ext, base_extrinsic
+            )
             local_label_points.append(Tf)
             local_augment_points.append(Tg)
         local_label_points = np.array(local_label_points)
         local_augment_points = np.array(local_augment_points)
-        action_indexes = np.clip(np.arange(self.predict_size+1) * pred_digit,0,label_actions.shape[0]-2)
-        return 
local_label_points,local_augment_points,origin_world_points,result_augment_points,action_indexes - - def rank_steps(self,extrinsics,obstacle_points,pred_digit=4): + action_indexes = np.clip(np.arange(self.predict_size + 1) * pred_digit, 0, label_actions.shape[0] - 2) + return local_label_points, local_augment_points, origin_world_points, result_augment_points, action_indexes + + def rank_steps(self, extrinsics, obstacle_points, pred_digit=4): points_score = [] - trajectory = extrinsics[:,0:2,3] - bev_points = obstacle_points[:,0:2] - for i in range(0,trajectory.shape[0]-1): - future_actions = trajectory[i:min(i+self.predict_size * pred_digit, trajectory.shape[0]-1)] - future_bound = [np.min(future_actions[:,0]) - 1,np.min(future_actions[:,1]) - 1,np.max(future_actions[:,0]) + 1,np.max(future_actions[:,1]) + 1] - within_bound_points = (obstacle_points[:,0] > future_bound[0]) & (obstacle_points[:,1] > future_bound[1]) & (obstacle_points[:,0] < future_bound[2]) & (obstacle_points[:,1] < future_bound[3]) + trajectory = extrinsics[:, 0:2, 3] + # bev_points = obstacle_points[:, 0:2] + for i in range(0, trajectory.shape[0] - 1): + future_actions = trajectory[i : min(i + self.predict_size * pred_digit, trajectory.shape[0] - 1)] + future_bound = [ + np.min(future_actions[:, 0]) - 1, + np.min(future_actions[:, 1]) - 1, + np.max(future_actions[:, 0]) + 1, + np.max(future_actions[:, 1]) + 1, + ] + within_bound_points = ( + (obstacle_points[:, 0] > future_bound[0]) + & (obstacle_points[:, 1] > future_bound[1]) + & (obstacle_points[:, 0] < future_bound[2]) + & (obstacle_points[:, 1] < future_bound[3]) + ) points_score.append(np.sum(within_bound_points)) points_score = np.array(points_score) / (np.array(points_score).max() + 1e-8) probs = np.exp(points_score / 0.2) / np.sum(np.exp(points_score / 0.2)) - start_choice = np.random.choice(np.arange(probs.shape[0]),p=probs) - target_choice_candidates = np.arange(start_choice+1,trajectory.shape[0]) - target_choice_p = (target_choice_candidates-start_choice) / ((target_choice_candidates-start_choice).max() + 1e-8) - target_choice_p = np.exp(target_choice_p/0.2)/np.exp(target_choice_p/0.2).sum() - target_choice = np.random.choice(target_choice_candidates,p=target_choice_p) - return start_choice,target_choice - - def __getitem__(self,index): - import time, os + start_choice = np.random.choice(np.arange(probs.shape[0]), p=probs) + target_choice_candidates = np.arange(start_choice + 1, trajectory.shape[0]) + target_choice_p = (target_choice_candidates - start_choice) / ( + (target_choice_candidates - start_choice).max() + 1e-8 + ) + target_choice_p = np.exp(target_choice_p / 0.2) / np.exp(target_choice_p / 0.2).sum() + target_choice = np.random.choice(target_choice_candidates, p=target_choice_p) + return start_choice, target_choice + + def __getitem__(self, index): + import os + import time + if self._last_time is None: self._last_time = time.time() start_time = time.time() - - camera_intrinsic,trajectory_base_extrinsic,trajectory_extrinsics,trajectory_length = self.process_data_parquet(index) - trajectory_path_points,trajectory_path_pcd = self.process_path_points(index) - trajectory_obstacle_points,trajectory_obstacle_pcd = self.process_obstacle_points(index,trajectory_path_points) - + ( + camera_intrinsic, + trajectory_base_extrinsic, + trajectory_extrinsics, + trajectory_length, + ) = self.process_data_parquet(index) + + trajectory_path_points, trajectory_path_pcd = self.process_path_points(index) + trajectory_obstacle_points, trajectory_obstacle_pcd = 
self.process_obstacle_points(
+            index, trajectory_path_points
+        )
+
         if self.prior_sample:
-            pixel_start_choice,target_choice = self.rank_steps()
-            memory_start_choice = np.random.randint(pixel_start_choice,target_choice)
+            pixel_start_choice, target_choice = self.rank_steps(trajectory_extrinsics, trajectory_obstacle_points)
+            memory_start_choice = np.random.randint(pixel_start_choice, target_choice)
         else:
-            pixel_start_choice = np.random.randint(0,trajectory_length//2)
-            target_choice = np.random.randint(pixel_start_choice+1,trajectory_length-1)
-            memory_start_choice = np.random.randint(pixel_start_choice,target_choice)
-
-        target_extrinsic = trajectory_extrinsics[target_choice]
+            pixel_start_choice = np.random.randint(0, trajectory_length // 2)
+            target_choice = np.random.randint(pixel_start_choice + 1, trajectory_length - 1)
+            memory_start_choice = np.random.randint(pixel_start_choice, target_choice)
+
+        # target_extrinsic = trajectory_extrinsics[target_choice]
         if self.random_digit:
-            memory_digit = np.random.randint(2,8)
+            memory_digit = np.random.randint(2, 8)
             pred_digit = memory_digit
         else:
             memory_digit = 4
             pred_digit = 4
-
-        memory_images,depth_image,memory_index = self.process_memory(self.trajectory_rgb_path[index],self.trajectory_depth_path[index],memory_start_choice,memory_digit=memory_digit)
-        target_local_points,augment_local_points,target_world_points,augment_world_points,action_indexes = self.process_actions(trajectory_extrinsics,trajectory_base_extrinsic,memory_start_choice,target_choice,pred_digit=pred_digit)
+
+        memory_images, depth_image, memory_index = self.process_memory(
+            self.trajectory_rgb_path[index],
+            self.trajectory_depth_path[index],
+            memory_start_choice,
+            memory_digit=memory_digit,
+        )
+        (
+            target_local_points,
+            augment_local_points,
+            target_world_points,
+            augment_world_points,
+            action_indexes,
+        ) = self.process_actions(
+            trajectory_extrinsics, trajectory_base_extrinsic, memory_start_choice, target_choice, pred_digit=pred_digit
+        )
         # convert the xyz points into xy-theta points
         init_vector = target_local_points[1] - target_local_points[0]
-        target_xyt_actions = self.xyz_to_xyt(target_local_points,init_vector)
-        augment_xyt_actions = self.xyz_to_xyt(augment_local_points,init_vector)
+        target_xyt_actions = self.xyz_to_xyt(target_local_points, init_vector)
+        augment_xyt_actions = self.xyz_to_xyt(augment_local_points, init_vector)
         # based on the prediction length to decide the final prediction trajectories
         pred_actions = target_xyt_actions[action_indexes]
         augment_actions = augment_xyt_actions[action_indexes]
         if trajectory_obstacle_points.shape[0] != 0:
-            pred_distance = np.abs(target_world_points[:,np.newaxis,0:2] - trajectory_obstacle_points[np.newaxis,:,0:2]).sum(axis=-1).min(axis=-1)
-            augment_distance = np.abs(augment_world_points[:,np.newaxis,0:2] - trajectory_obstacle_points[np.newaxis,:,0:2]).sum(axis=-1).min(axis=-1)
-            pred_critic = -5.0 * (pred_distance[action_indexes[:-1]] < 0.1).mean() + 0.5*(pred_distance[action_indexes][1:] - pred_distance[action_indexes][:-1]).sum()
-            augment_critic = -5.0 * (augment_distance[action_indexes[:-1]] < 0.1).mean() + 0.5*(augment_distance[action_indexes][1:] - augment_distance[action_indexes][:-1]).sum()
+            pred_distance = (
+                np.abs(target_world_points[:, np.newaxis, 0:2] - trajectory_obstacle_points[np.newaxis, :, 0:2])
+                .sum(axis=-1)
+                .min(axis=-1)
+            )
+            augment_distance = (
+                np.abs(augment_world_points[:, np.newaxis, 0:2] - trajectory_obstacle_points[np.newaxis, :, 0:2])
+                .sum(axis=-1)
+                .min(axis=-1)
+            )
+            # critic = penalty for steps closer than 0.1 m to obstacles + bonus for gaining clearance
+            pred_critic = (
+                -5.0 * 
(pred_distance[action_indexes[:-1]] < 0.1).mean() + + 0.5 * (pred_distance[action_indexes][1:] - pred_distance[action_indexes][:-1]).sum() + ) + augment_critic = ( + -5.0 * (augment_distance[action_indexes[:-1]] < 0.1).mean() + + 0.5 * (augment_distance[action_indexes][1:] - augment_distance[action_indexes][:-1]).sum() + ) else: - pred_distance = np.ones(pred_actions.shape[0],dtype=np.float32) - augment_distance = np.ones(pred_actions.shape[0],dtype=np.float32) + pred_distance = np.ones(pred_actions.shape[0], dtype=np.float32) + augment_distance = np.ones(pred_actions.shape[0], dtype=np.float32) pred_critic = 2.0 augment_critic = 2.0 - + point_goal = target_xyt_actions[-1] - image_goal = np.concatenate((self.process_image(self.trajectory_rgb_path[index][target_choice]),self.process_image(self.trajectory_rgb_path[index][memory_start_choice])),axis=-1) + image_goal = np.concatenate( + ( + self.process_image(self.trajectory_rgb_path[index][target_choice]), + self.process_image(self.trajectory_rgb_path[index][memory_start_choice]), + ), + axis=-1, + ) # process pixel projection - pixel_target_local_points,_,_,_,_ = self.process_actions(trajectory_extrinsics,trajectory_base_extrinsic,pixel_start_choice,target_choice,pred_digit=pred_digit) + pixel_target_local_points, _, _, _, _ = self.process_actions( + trajectory_extrinsics, trajectory_base_extrinsic, pixel_start_choice, target_choice, pred_digit=pred_digit + ) pixel_init_vector = pixel_target_local_points[1] - pixel_target_local_points[0] - pixel_xyt_actions = self.xyz_to_xyt(pixel_target_local_points,pixel_init_vector) - pixel_goal,pixel_flag = self.process_pixel_goal(self.trajectory_rgb_path[index][pixel_start_choice],pixel_xyt_actions[-1],camera_intrinsic,trajectory_base_extrinsic) - pixel_goal = np.concatenate((pixel_goal,memory_images[-1]),axis=-1) - + pixel_xyt_actions = self.xyz_to_xyt(pixel_target_local_points, pixel_init_vector) + pixel_goal, pixel_flag = self.process_pixel_goal( + self.trajectory_rgb_path[index][pixel_start_choice], + pixel_xyt_actions[-1], + camera_intrinsic, + trajectory_base_extrinsic, + ) + pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1) + pred_actions = (pred_actions[1:] - pred_actions[:-1]) * 4.0 augment_actions = (augment_actions[1:] - augment_actions[:-1]) * 4.0 - + # Summarize avg time of batch end_time = time.time() self.item_cnt += 1 - self.batch_time_sum += (end_time - start_time) + self.batch_time_sum += end_time - start_time if self.item_cnt % self.batch_size == 0: avg_time = self.batch_time_sum / self.batch_size - print(f'__getitem__ pid={os.getpid()}, avg_time(last {self.batch_size})={avg_time:.2f}s, cnt={self.item_cnt}') + print( + f'__getitem__ pid={os.getpid()}, avg_time(last {self.batch_size})={avg_time:.2f}s, cnt={self.item_cnt}' + ) self.batch_time_sum = 0.0 point_goal = torch.tensor(point_goal, dtype=torch.float32) image_goal = torch.tensor(image_goal, dtype=torch.float32) @@ -409,12 +533,22 @@ def __getitem__(self,index): augment_actions = torch.tensor(augment_actions, dtype=torch.float32) pred_critic = torch.tensor(pred_critic, dtype=torch.float32) augment_critic = torch.tensor(augment_critic, dtype=torch.float32) - return point_goal,image_goal,pixel_goal,memory_images,depth_image,pred_actions,augment_actions,pred_critic,augment_critic,float(pixel_flag) - + return ( + point_goal, + image_goal, + pixel_goal, + memory_images, + depth_image, + pred_actions, + augment_actions, + pred_critic, + augment_critic, + float(pixel_flag), + ) def navdp_collate_fn(batch): - + 
collated = { "batch_pg": torch.stack([item[0] for item in batch]), "batch_ig": torch.stack([item[1] for item in batch]), @@ -429,27 +563,58 @@ def navdp_collate_fn(batch): return collated - if __name__ == "__main__": - # Debug - dataset = NavDP_Base_Datset("/path/to/nav_20w_lerobot/", - "/path/to/navdp_trainer/output_test/multiview_dataset_lerobot.json", - 8,24,224,trajectory_data_scale=1.0,scene_data_scale=1.0,preload=True) - for i in range(200): - point_goal,image_goal,pixel_goal,memory_images,depth_image,pred_actions,augment_actions,pred_critic,augment_critic,pixel_flag = dataset.__getitem__(i) - pixel_obs = pixel_goal[:,:,0:3] * 255 - pixel_obs[pixel_goal[:,:,3]==1] = np.array([0,0,255]) - - draw_current_image = image_goal[:,:,3:6].copy()*255 - draw_current_image = cv2.putText(draw_current_image,"Current-Image",(50,30),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255)) - - draw_goal_image = image_goal[:,:,0:3].copy()*255 - draw_goal_image = cv2.putText(draw_goal_image,"Image-Goal",(50,30),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255)) - - draw_pixel_image = pixel_obs.copy() - draw_pixel_image = cv2.putText(draw_pixel_image,"Pixel-Goal",(50,30),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255)) - - goal_info_image = np.concatenate((draw_current_image,draw_goal_image,draw_pixel_image),axis=1) - goal_info_image = cv2.putText(goal_info_image,"PointGoal=[{:.3f}, {:.3f}, {:.3f}]".format(*point_goal),(190,210),cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255)) - cv2.imwrite("./output_test/goal_information.png",goal_info_image) - \ No newline at end of file + os.makedirs("./navdp_dataset_test/", exist_ok=True) + dataset = NavDP_Base_Datset( + "/shared/smartbot_new/liuyu/vln-n1-minival/", + "./navdp_dataset_test/dataset_lerobot.json", + 8, + 24, + 224, + trajectory_data_scale=0.1, + scene_data_scale=0.1, + preload=False, + ) + + for i in range(10): + ( + point_goal, + image_goal, + pixel_goal, + memory_images, + depth_image, + pred_actions, + augment_actions, + pred_critic, + augment_critic, + pixel_flag, + ) = dataset.__getitem__(i) + if pixel_flag == 1.0: + pixel_obs = pixel_goal.numpy()[:, :, 0:3] * 255 + pixel_obs[pixel_goal[:, :, 3] == 1] = np.array([0, 0, 255]) + + draw_current_image = cv2.cvtColor(image_goal[:, :, 3:6].numpy() * 255, cv2.COLOR_BGR2RGB) + draw_current_image = cv2.putText( + draw_current_image, "Current-Image", (50, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255) + ) + + draw_goal_image = cv2.cvtColor(image_goal[:, :, 0:3].numpy() * 255, cv2.COLOR_BGR2RGB) + draw_goal_image = cv2.putText( + draw_goal_image, "Image-Goal", (50, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255) + ) + + draw_pixel_image = cv2.cvtColor(pixel_obs.copy(), cv2.COLOR_BGR2RGB) + draw_pixel_image = cv2.putText( + draw_pixel_image, "Pixel-Goal", (50, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255) + ) + + goal_info_image = np.concatenate((draw_current_image, draw_goal_image, draw_pixel_image), axis=1) + goal_info_image = cv2.putText( + goal_info_image, + "PointGoal=[{:.3f}, {:.3f}, {:.3f}]".format(*point_goal), + (190, 210), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 0, 255), + ) + cv2.imwrite("./navdp_dataset_test/goal_information_%d.png" % i, goal_info_image) diff --git a/internnav/model/basemodel/navdp/navdp_policy.py b/internnav/model/basemodel/navdp/navdp_policy.py index 2d290869..6a8da420 100644 --- a/internnav/model/basemodel/navdp/navdp_policy.py +++ b/internnav/model/basemodel/navdp/navdp_policy.py @@ -1,20 +1,13 @@ - -import copy import os -from typing import Dict, Optional, Tuple + import torch import torch.nn as nn -import 
torch.nn.functional as F -import numpy as np -import os -from scipy.signal import savgol_filter from diffusers.schedulers.scheduling_ddpm import DDPMScheduler -from internnav.model.encoder.navdp_backbone import * from transformers import PretrainedConfig, PreTrainedModel from internnav.configs.model.base_encoders import ModelCfg from internnav.configs.trainer.exp import ExpCfg - +from internnav.model.encoder.navdp_backbone import * class NavDPModelConfig(PretrainedConfig): @@ -25,7 +18,6 @@ def __init__(self, **kwargs): # pass in navdp_exp_cfg self.model_cfg = kwargs.get('model_cfg', None) - @classmethod def from_dict(cls, config_dict): if 'model_cfg' in config_dict: @@ -38,7 +30,7 @@ class NavDPNet(PreTrainedModel): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - config = kwargs.pop('config', None)#navdp_exp_cfg_dict_NavDPModelConfig + config = kwargs.pop('config', None) # navdp_exp_cfg_dict_NavDPModelConfig if config is None: config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) @@ -86,93 +78,115 @@ def __init__(self, config: NavDPModelConfig): self.input_channels = self.config.model_cfg['il']['channels'] self.dropout = self.config.model_cfg['il']['dropout'] self.token_dim = self.config.model_cfg['il']['token_dim'] - self.scratch=self.config.model_cfg['il']['scratch'] - self.finetune=self.config.model_cfg['il']['finetune'] - self.rgbd_encoder = NavDP_RGBD_Backbone(self.image_size,self.token_dim,memory_size=self.memory_size,finetune=self.finetune,device=self._device) - self.point_encoder = nn.Linear(3,self.token_dim) - decoder_layer = nn.TransformerDecoderLayer(d_model = self.token_dim, - nhead = self.attention_heads, - dim_feedforward = 4 * self.token_dim, - dropout = self.dropout, - activation = 'gelu', - batch_first = True, - norm_first = True) - self.decoder = nn.TransformerDecoder(decoder_layer = decoder_layer, - num_layers = self.temporal_depth) - self.input_embed = nn.Linear(3,self.token_dim) - - self.cond_pos_embed = LearnablePositionalEncoding(self.token_dim, self.memory_size * 16 + 2) + self.scratch = self.config.model_cfg['il']['scratch'] + self.finetune = self.config.model_cfg['il']['finetune'] + self.rgbd_encoder = NavDP_RGBD_Backbone( + self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device + ) + self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size, self.token_dim, device=self._device) + self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size, self.token_dim, device=self._device) + self.point_encoder = nn.Linear(3, self.token_dim) + + if not self.finetune: + for p in self.rgbd_encoder.parameters(): + p.requires_grad = False + self.rgbd_encoder.eval() + + decoder_layer = nn.TransformerDecoderLayer( + d_model=self.token_dim, + nhead=self.attention_heads, + dim_feedforward=4 * self.token_dim, + dropout=self.dropout, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=self.temporal_depth) + self.input_embed = nn.Linear(3, self.token_dim) + + self.cond_pos_embed = LearnablePositionalEncoding(self.token_dim, self.memory_size * 16 + 4) self.out_pos_embed = LearnablePositionalEncoding(self.token_dim, self.predict_size) self.drop = nn.Dropout(self.dropout) self.time_emb = SinusoidalPosEmb(self.token_dim) self.layernorm = nn.LayerNorm(self.token_dim) self.action_head = nn.Linear(self.token_dim, 3) self.critic_head = nn.Linear(self.token_dim, 1) - 
self.noise_scheduler = DDPMScheduler(num_train_timesteps=10, - beta_schedule='squaredcos_cap_v2', - clip_sample=True, - prediction_type='epsilon') + self.noise_scheduler = DDPMScheduler( + num_train_timesteps=10, beta_schedule='squaredcos_cap_v2', clip_sample=True, prediction_type='epsilon' + ) self.tgt_mask = (torch.triu(torch.ones(self.predict_size, self.predict_size)) == 1).transpose(0, 1) - self.tgt_mask = self.tgt_mask.float().masked_fill(self.tgt_mask == 0, float('-inf')).masked_fill(self.tgt_mask == 1, float(0.0)) - self.cond_critic_mask = torch.zeros((self.predict_size,2 + self.memory_size * 16)) - self.cond_critic_mask[:,0:2] = float('-inf') + self.tgt_mask = ( + self.tgt_mask.float() + .masked_fill(self.tgt_mask == 0, float('-inf')) + .masked_fill(self.tgt_mask == 1, float(0.0)) + ) self.tgt_mask = self.tgt_mask.to(self._device) - + + self.cond_critic_mask = torch.zeros((self.predict_size, 4 + self.memory_size * 16)) + self.cond_critic_mask[:, 0:4] = float('-inf') + + self.pixel_aux_head = nn.Linear(self.token_dim, 3) + self.image_aux_head = nn.Linear(self.token_dim, 3) + def to(self, device, *args, **kwargs): # first call the to method of the parent class self = super().to(device, *args, **kwargs) - + # ensure the buffer is on the correct device self.cond_critic_mask = self.cond_critic_mask.to(device) - + # update device attribute self._device = device - - return self - - def sample_noise(self,action): - # device = next(self.parameters()).device - # if device is None: - # device = action.device - # action = action.to(self._device) + + return self + + def sample_noise(self, action): device = action.device noise = torch.randn(action.shape, device=device) - timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps,(action.shape[0],), device=device).long() + timesteps = torch.randint( + 0, self.noise_scheduler.config.num_train_timesteps, (action.shape[0],), device=device + ).long() time_embeds = self.time_emb(timesteps).unsqueeze(1) noisy_action = self.noise_scheduler.add_noise(action, noise, timesteps) noisy_action_embed = self.input_embed(noisy_action) - return noise,time_embeds,noisy_action_embed + return noise, time_embeds, noisy_action_embed - def predict_noise(self,last_actions,timestep,goal_embed,rgbd_embed): + def predict_noise(self, last_actions, timestep, goal_embed, rgbd_embed): action_embeds = self.input_embed(last_actions) time_embeds = self.time_emb(timestep.to(self._device)).unsqueeze(1) - cond_embedding = torch.cat([time_embeds,goal_embed,rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([time_embeds,goal_embed,rgbd_embed],dim=1)) - cond_embedding = cond_embedding.repeat(action_embeds.shape[0],1,1) + cond_embedding = torch.cat( + [time_embeds, goal_embed, goal_embed, goal_embed, rgbd_embed], dim=1 + ) + self.cond_pos_embed(torch.cat([time_embeds, goal_embed, goal_embed, goal_embed, rgbd_embed], dim=1)) + cond_embedding = cond_embedding.repeat(action_embeds.shape[0], 1, 1) input_embedding = action_embeds + self.out_pos_embed(action_embeds) - output = self.decoder(tgt = input_embedding,memory = cond_embedding, tgt_mask = self.tgt_mask.to(self._device)) + output = self.decoder(tgt=input_embedding, memory=cond_embedding, tgt_mask=self.tgt_mask.to(self._device)) output = self.layernorm(output) output = self.action_head(output) return output - - def predict_critic(self,predict_trajectory,rgbd_embed): - repeat_rgbd_embed = rgbd_embed.repeat(predict_trajectory.shape[0],1,1) - nogoal_embed = torch.zeros_like(repeat_rgbd_embed[:,0:1]) + + def 
predict_critic(self, predict_trajectory, rgbd_embed): + repeat_rgbd_embed = rgbd_embed.repeat(predict_trajectory.shape[0], 1, 1) + nogoal_embed = torch.zeros_like(repeat_rgbd_embed[:, 0:1]) action_embeddings = self.input_embed(predict_trajectory) action_embeddings = action_embeddings + self.out_pos_embed(action_embeddings) - cond_embeddings = torch.cat([nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1)) - critic_output = self.decoder(tgt = action_embeddings, memory = cond_embeddings, memory_mask = self.cond_critic_mask) + cond_embeddings = torch.cat( + [nogoal_embed, nogoal_embed, nogoal_embed, nogoal_embed, repeat_rgbd_embed], dim=1 + ) + self.cond_pos_embed( + torch.cat([nogoal_embed, nogoal_embed, nogoal_embed, nogoal_embed, repeat_rgbd_embed], dim=1) + ) + critic_output = self.decoder(tgt=action_embeddings, memory=cond_embeddings, memory_mask=self.cond_critic_mask) critic_output = self.layernorm(critic_output) - critic_output = self.critic_head(critic_output.mean(dim=1))[:,0] + critic_output = self.critic_head(critic_output.mean(dim=1))[:, 0] return critic_output - - def forward(self,goal_point,goal_image,input_images,input_depths,output_actions,augment_actions): + + def forward(self, goal_point, goal_image, goal_pixel, input_images, input_depths, output_actions, augment_actions): # """get device safely""" # # get device safely # try: # # try to get device through model parameters # device = next(self.parameters()).device # except StopIteration: - # # model has no parameters, use the default device + # # model has no parameters, use the default device # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # # move all inputs to model device # goal_point = goal_point.to(device) @@ -184,7 +198,7 @@ def forward(self,goal_point,goal_image,input_images,input_depths,output_actions, # device = self._device # print(f"self.parameters() is:{self.parameters()}") device = next(self.parameters()).device - + assert input_images.shape[1] == self.memory_size tensor_point_goal = torch.as_tensor(goal_point, dtype=torch.float32).to(device) tensor_label_actions = torch.as_tensor(output_actions, dtype=torch.float32).to(device) @@ -192,55 +206,82 @@ def forward(self,goal_point,goal_image,input_images,input_depths,output_actions, input_images = input_images.to(device) input_depths = input_depths.to(device) - ng_noise,ng_time_embed,ng_noisy_action_embed = self.sample_noise(tensor_label_actions) - pg_noise,pg_time_embed,pg_noisy_action_embed = self.sample_noise(tensor_label_actions) - # ig_noise,ig_time_embed,ig_noisy_action_embed = self.sample_noise(tensor_label_actions) + ng_noise, ng_time_embed, ng_noisy_action_embed = self.sample_noise(tensor_label_actions) + mg_noise, mg_time_embed, mg_noisy_action_embed = self.sample_noise(tensor_label_actions) - rgbd_embed = self.rgbd_encoder(input_images,input_depths) + rgbd_embed = self.rgbd_encoder(input_images, input_depths) pointgoal_embed = self.point_encoder(tensor_point_goal).unsqueeze(1) nogoal_embed = torch.zeros_like(pointgoal_embed) - # imagegoal_embed = torch.zeros_like(pointgoal_embed) + imagegoal_embed = self.image_encoder(goal_image).unsqueeze(1) + pixelgoal_embed = self.pixel_encoder(goal_pixel).unsqueeze(1) + + imagegoal_aux_pred = self.image_aux_head(imagegoal_embed[:, 0]) + pixelgoal_aux_pred = self.pixel_aux_head(pixelgoal_embed[:, 0]) label_embed = self.input_embed(tensor_label_actions).detach() augment_embed = 
self.input_embed(tensor_augment_actions).detach() - - cond_pos_embed = self.cond_pos_embed(torch.cat([ng_time_embed,nogoal_embed,rgbd_embed],dim=1)) - ng_cond_embeddings = self.drop(torch.cat([ng_time_embed,nogoal_embed,rgbd_embed],dim=1) + cond_pos_embed) - pg_cond_embeddings = self.drop(torch.cat([pg_time_embed,pointgoal_embed,rgbd_embed],dim=1) + cond_pos_embed) - # ig_cond_embeddings = self.drop(torch.cat([ig_time_embed,imagegoal_embed,rgbd_embed],dim=1) + cond_pos_embed) + + cond_pos_embed = self.cond_pos_embed( + torch.cat([ng_time_embed, nogoal_embed, imagegoal_embed, pixelgoal_embed, rgbd_embed], dim=1) + ) + ng_cond_embeddings = self.drop( + torch.cat([ng_time_embed, nogoal_embed, nogoal_embed, nogoal_embed, rgbd_embed], dim=1) + cond_pos_embed + ) + + cand_goal_embed = [pointgoal_embed, imagegoal_embed, pixelgoal_embed] + batch_size = pointgoal_embed.shape[0] + + # Generate deterministic selections for each sample in the batch using vectorized operations + batch_indices = torch.arange(batch_size, device=pointgoal_embed.device) + pattern_indices = batch_indices % 27 # 3^3 = 27 possible combinations + selections_0 = pattern_indices % 3 + selections_1 = (pattern_indices // 3) % 3 + selections_2 = (pattern_indices // 9) % 3 + goal_embeds = torch.stack(cand_goal_embed, dim=0) # [3, batch_size, 1, token_dim] + selected_goals_0 = goal_embeds[selections_0, torch.arange(batch_size), :, :] # [batch_size, 1, token_dim] + selected_goals_1 = goal_embeds[selections_1, torch.arange(batch_size), :, :] + selected_goals_2 = goal_embeds[selections_2, torch.arange(batch_size), :, :] + mg_cond_embed_tensor = torch.cat( + [mg_time_embed, selected_goals_0, selected_goals_1, selected_goals_2, rgbd_embed], dim=1 + ) + mg_cond_embeddings = self.drop(mg_cond_embed_tensor + cond_pos_embed) out_pos_embed = self.out_pos_embed(ng_noisy_action_embed) ng_action_embeddings = self.drop(ng_noisy_action_embed + out_pos_embed) - pg_action_embeddings = self.drop(pg_noisy_action_embed + out_pos_embed) - # ig_action_embeddings = self.drop(ig_noisy_action_embed + out_pos_embed) + mg_action_embeddings = self.drop(mg_noisy_action_embed + out_pos_embed) label_action_embeddings = self.drop(label_embed + out_pos_embed) augment_action_embeddings = self.drop(augment_embed + out_pos_embed) - # ng_output = self.decoder(tgt = ng_action_embeddings,memory = ng_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device)) - ng_output = self.decoder(tgt = ng_action_embeddings,memory = ng_cond_embeddings, tgt_mask = self.tgt_mask) + ng_output = self.decoder(tgt=ng_action_embeddings, memory=ng_cond_embeddings, tgt_mask=self.tgt_mask) ng_output = self.layernorm(ng_output) noise_pred_ng = self.action_head(ng_output) - pg_output = self.decoder(tgt = pg_action_embeddings,memory = pg_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device)) - # pg_output = self.decoder(tgt = pg_action_embeddings,memory = pg_cond_embeddings, tgt_mask = self.tgt_mask) - pg_output = self.layernorm(pg_output) - noise_pred_pg = self.action_head(pg_output) - - # ig_output = self.decoder(tgt = ig_action_embeddings,memory = ig_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device)) - # ig_output = self.decoder(tgt = ig_action_embeddings,memory = ig_cond_embeddings, tgt_mask = self.tgt_mask) - # ig_output = self.layernorm(ig_output) - # noise_pred_ig = self.action_head(ig_output) + mg_output = self.decoder( + tgt=mg_action_embeddings, memory=mg_cond_embeddings, tgt_mask=self.tgt_mask.to(ng_action_embeddings.device) + ) 
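+        # LayerNorm + linear head map the decoded tokens to per-step (x, y, theta) noise for the mixed-goal branch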
+ mg_output = self.layernorm(mg_output) + noise_pred_mg = self.action_head(mg_output) - cr_label_output = self.decoder(tgt = label_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask.to(self._device)) - # cr_label_output = self.decoder(tgt = label_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask) + cr_label_output = self.decoder( + tgt=label_action_embeddings, memory=ng_cond_embeddings, memory_mask=self.cond_critic_mask.to(self._device) + ) cr_label_output = self.layernorm(cr_label_output) - cr_label_pred = self.critic_head(cr_label_output.mean(dim=1))[:,0] + cr_label_pred = self.critic_head(cr_label_output.mean(dim=1))[:, 0] - cr_augment_output = self.decoder(tgt = augment_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask.to(self._device)) + cr_augment_output = self.decoder( + tgt=augment_action_embeddings, memory=ng_cond_embeddings, memory_mask=self.cond_critic_mask.to(self._device) + ) cr_augment_output = self.layernorm(cr_augment_output) - cr_augment_pred = self.critic_head(cr_augment_output.mean(dim=1))[:,0] - return noise_pred_ng,noise_pred_pg,cr_label_pred,cr_augment_pred,[ng_noise,pg_noise] - + cr_augment_pred = self.critic_head(cr_augment_output.mean(dim=1))[:, 0] + return ( + noise_pred_ng, + noise_pred_mg, + cr_label_pred, + cr_augment_pred, + [ng_noise, mg_noise], + [imagegoal_aux_pred, pixelgoal_aux_pred], + ) + def _get_device(self): """Safe get device information""" # try to get device through model parameters @@ -249,14 +290,14 @@ def _get_device(self): return param.device except StopIteration: pass - + # try to get device through buffer try: for buffer in self.buffers(): return buffer.device except StopIteration: pass - + # try to get device through submodule for module in self.children(): try: @@ -264,45 +305,47 @@ def _get_device(self): return param.device except StopIteration: continue - + # finally revert to default device return torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def predict_pointgoal_batch_action_vel(self,goal_point,input_images,input_depths,sample_num=32): + + def predict_pointgoal_batch_action_vel(self, goal_point, input_images, input_depths, sample_num=32): with torch.no_grad(): - tensor_point_goal = torch.as_tensor(goal_point,dtype=torch.float32,device=self._device) - rgbd_embed = self.rgbd_encoder(input_images,input_depths) + tensor_point_goal = torch.as_tensor(goal_point, dtype=torch.float32, device=self._device) + rgbd_embed = self.rgbd_encoder(input_images, input_depths) pointgoal_embed = self.point_encoder(tensor_point_goal).unsqueeze(1) - noisy_action = torch.randn((sample_num * pointgoal_embed.shape[0], self.predict_size, 3), device=self._device) + noisy_action = torch.randn( + (sample_num * pointgoal_embed.shape[0], self.predict_size, 3), device=self._device + ) naction = noisy_action self.noise_scheduler.set_timesteps(self.noise_scheduler.config.num_train_timesteps) for k in self.noise_scheduler.timesteps[:]: - noise_pred = self.predict_noise(naction,k.to(self._device).unsqueeze(0),pointgoal_embed,rgbd_embed) - naction = self.noise_scheduler.step(model_output=noise_pred,timestep=k,sample=naction).prev_sample + noise_pred = self.predict_noise(naction, k.to(self._device).unsqueeze(0), pointgoal_embed, rgbd_embed) + naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample - critic_values = self.predict_critic(naction,rgbd_embed) + critic_values = self.predict_critic(naction, 
rgbd_embed) all_trajectory = torch.cumsum(naction / 4.0, dim=1) negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]] positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]] - return negative_trajectory,positive_trajectory - - def predict_nogoal_batch_action_vel(self,input_images,input_depths,sample_num=32): + return negative_trajectory, positive_trajectory + + def predict_nogoal_batch_action_vel(self, input_images, input_depths, sample_num=32): with torch.no_grad(): - rgbd_embed = self.rgbd_encoder(input_images,input_depths) - nogoal_embed = torch.zeros_like(rgbd_embed[:,0:1]) + rgbd_embed = self.rgbd_encoder(input_images, input_depths) + nogoal_embed = torch.zeros_like(rgbd_embed[:, 0:1]) noisy_action = torch.randn((sample_num * nogoal_embed.shape[0], self.predict_size, 3), device=self._device) naction = noisy_action self.noise_scheduler.set_timesteps(self.noise_scheduler.config.num_train_timesteps) for k in self.noise_scheduler.timesteps[:]: - noise_pred = self.predict_noise(naction,k.unsqueeze(0),nogoal_embed,rgbd_embed) - naction = self.noise_scheduler.step(model_output=noise_pred,timestep=k,sample=naction).prev_sample + noise_pred = self.predict_noise(naction, k.unsqueeze(0), nogoal_embed, rgbd_embed) + naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample - critic_values = self.predict_critic(naction,rgbd_embed) + critic_values = self.predict_critic(naction, rgbd_embed) all_trajectory = torch.cumsum(naction / 4.0, dim=1) negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]] positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]] - return negative_trajectory,positive_trajectory + return negative_trajectory, positive_trajectory diff --git a/internnav/model/encoder/navdp_backbone.py b/internnav/model/encoder/navdp_backbone.py index a1606b51..cd2e8794 100644 --- a/internnav/model/encoder/navdp_backbone.py +++ b/internnav/model/encoder/navdp_backbone.py @@ -1,6 +1,8 @@ +import math + import torch import torch.nn as nn -import math + from internnav.model.encoder.depth_anything.depth_anything_v2.dpt import DepthAnythingV2 @@ -8,6 +10,7 @@ class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim + def forward(self, x): device = x.device half_dim = self.dim // 2 @@ -17,8 +20,10 @@ def forward(self, x): emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb + class PositionalEncoding(nn.Module): """Positional encoding module""" + def __init__(self, embed_dim, max_len=1000): super(PositionalEncoding, self).__init__() pe = torch.zeros(max_len, embed_dim) @@ -29,10 +34,12 @@ def __init__(self, embed_dim, max_len=1000): self.register_buffer('pe', pe) def forward(self, x): - return self.pe[:x.size(1)] - + return self.pe[: x.size(1)] + + class LearnablePositionalEncoding(nn.Module): """Learnable positional encoding using nn.Embedding""" + def __init__(self, embed_dim, max_len=5000): super(LearnablePositionalEncoding, self).__init__() self.embed_dim = embed_dim @@ -61,8 +68,8 @@ def __init__(self, embed_dim, num_heads, target_length): self.target_embedding = nn.Embedding(target_length, embed_dim) # Positional encoding - self.positional_encoding = PositionalEncoding(embed_dim) - + self.positional_encoding = PositionalEncoding(embed_dim) + self.token_positional_encoding = LearnablePositionalEncoding(embed_dim) self.query_positional_encoding = 
diff --git a/internnav/model/encoder/navdp_backbone.py b/internnav/model/encoder/navdp_backbone.py
index a1606b51..cd2e8794 100644
--- a/internnav/model/encoder/navdp_backbone.py
+++ b/internnav/model/encoder/navdp_backbone.py
@@ -1,6 +1,8 @@
+import math
+
 import torch
 import torch.nn as nn
-import math
+
 from internnav.model.encoder.depth_anything.depth_anything_v2.dpt import DepthAnythingV2
@@ -8,6 +10,7 @@ class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
         self.dim = dim
+
     def forward(self, x):
         device = x.device
         half_dim = self.dim // 2
@@ -17,8 +20,10 @@ def forward(self, x):
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb

+
 class PositionalEncoding(nn.Module):
     """Positional encoding module"""
+
     def __init__(self, embed_dim, max_len=1000):
         super(PositionalEncoding, self).__init__()
         pe = torch.zeros(max_len, embed_dim)
@@ -29,10 +34,12 @@ def __init__(self, embed_dim, max_len=1000):
         self.register_buffer('pe', pe)

     def forward(self, x):
-        return self.pe[:x.size(1)]
-
+        return self.pe[: x.size(1)]
+
+
 class LearnablePositionalEncoding(nn.Module):
     """Learnable positional encoding using nn.Embedding"""
+
     def __init__(self, embed_dim, max_len=5000):
         super(LearnablePositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
@@ -61,8 +68,8 @@ def __init__(self, embed_dim, num_heads, target_length):
         self.target_embedding = nn.Embedding(target_length, embed_dim)

         # Positional encoding
-        self.positional_encoding = PositionalEncoding(embed_dim)
-
+        self.positional_encoding = PositionalEncoding(embed_dim)
+
         self.token_positional_encoding = LearnablePositionalEncoding(embed_dim)
         self.query_positional_encoding = LearnablePositionalEncoding(embed_dim)
@@ -75,37 +82,35 @@ def forward(self, x, padding_mask=None):
             padding_mask: (bs, N) - Padding mask for input sequence (True for padding positions)
         """
         bs, token_len, _ = x.shape
-
+
         # Add positional encoding to input
         token_pe = self.token_positional_encoding(x)
         x = x + token_pe
-
+
         query = self.target_embedding.weight.unsqueeze(0).expand(bs, -1, -1)  # Get target sequence from embedding
         query_pe = self.query_positional_encoding(query)
-
+
         query = query + query_pe

         # Cross Attention: target is Query, x is Key and Value
-        out, _ = self.cross_attention(
-            query=query,
-            key=x,
-            value=x,
-            key_padding_mask=padding_mask
-        )
+        out, _ = self.cross_attention(query=query, key=x, value=x, key_padding_mask=padding_mask)
         return out
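The `CrossAttention` module above condenses a variable-length token sequence into a fixed number of learned query slots: the `target_embedding` rows act as queries, and the input sequence supplies keys and values. A self-contained sketch of the same pattern with `nn.MultiheadAttention` (sizes are illustrative):

```python
import torch
import torch.nn as nn

embed_dim, num_heads, target_length = 384, 8, 16
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
target_embedding = nn.Embedding(target_length, embed_dim)

x = torch.randn(2, 100, embed_dim)                              # (bs, N, D) input tokens
query = target_embedding.weight.unsqueeze(0).expand(2, -1, -1)  # (bs, 16, D) learned queries
out, _ = attn(query=query, key=x, value=x)                      # queries attend over the tokens
print(out.shape)                                                # torch.Size([2, 16, 384])
```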

+
 class DAT_RGBD_Patch_Backbone(nn.Module):
-    def __init__(self,
-                 image_size=224,
-                 embed_size=512,
-                 finetune=True,
-                 memory_size=8,
-                 checkpoint="checkpoints/depth_anything_v2_vits.pth",
-                 input_dtype="bf16",
-                 version=0.0,
-                 device = 'cuda:0'):
+    def __init__(
+        self,
+        image_size=224,
+        embed_size=512,
+        finetune=True,
+        memory_size=8,
+        checkpoint="checkpoints/depth_anything_v2_vits.pth",
+        input_dtype="bf16",
+        version=0.0,
+        device='cuda:0',
+    ):
         super().__init__()
         self.finetune = finetune
         self.memory_size = memory_size
@@ -118,7 +123,7 @@ def __init__(self,
         self.rgb_model = DepthAnythingV2(**model_configs['vits'])
         self.rgb_model.load_state_dict(torch.load(checkpoint), strict=False)
         self.rgb_model = self.rgb_model.pretrained
-
+
         self.preprocess_mean = torch.tensor([0.485, 0.456, 0.406], dtype=self.input_dtype)
         self.preprocess_std = torch.tensor([0.229, 0.224, 0.225], dtype=self.input_dtype)
@@ -133,12 +138,12 @@ def __init__(self,
         self.former_query = nn.Embedding(self.memory_size * 16, 384)
         nn.init.constant_(self.former_query.weight, val=0)
-
+
         if self.version > 0.0:
             self.former_pe = nn.Embedding((self.memory_size * 2) * 256, 384)
         else:
             self.former_pe = nn.Embedding((self.memory_size + 1) * 256, 384)
-
+
         nn.init.constant_(self.former_pe.weight, val=0)
         self.former_net = nn.TransformerDecoder(nn.TransformerDecoderLayer(384, 8, batch_first=True), 2)
         self.project_layer = nn.Linear(384, embed_size)
@@ -147,13 +152,17 @@ def forward(self, images, depths):
         if len(images.shape) == 4:
             tensor_images = images.to(dtype=self.input_dtype).permute(0, 3, 1, 2)
             tensor_images = tensor_images.reshape(-1, 3, self.image_size, self.image_size)
-            tensor_norm_images = (tensor_images - self.preprocess_mean.reshape(1, 3, 1, 1).to(images.device)) / self.preprocess_std.to(images.device).reshape(1, 3, 1, 1)
+            tensor_norm_images = (
+                tensor_images - self.preprocess_mean.reshape(1, 3, 1, 1).to(images.device)
+            ) / self.preprocess_std.to(images.device).reshape(1, 3, 1, 1)
             image_token = self.rgb_model.get_intermediate_layers(tensor_norm_images)[0]
         elif len(images.shape) == 5:
             B, T, H, W, C = images.shape
             tensor_images = images.to(dtype=self.input_dtype).permute(0, 1, 4, 2, 3)
             tensor_images = tensor_images.reshape(-1, 3, self.image_size, self.image_size)
-            tensor_norm_images = (tensor_images - self.preprocess_mean.to(images.device).reshape(1, 3, 1, 1)) / self.preprocess_std.to(images.device).reshape(1, 3, 1, 1)
+            tensor_norm_images = (
+                tensor_images - self.preprocess_mean.to(images.device).reshape(1, 3, 1, 1)
+            ) / self.preprocess_std.to(images.device).reshape(1, 3, 1, 1)
             image_token = self.rgb_model.get_intermediate_layers(tensor_norm_images)[0].reshape(B, T * 256, -1)

         if not self.finetune:
@@ -170,30 +179,39 @@ def forward(self, images, depths):
         tensor_depths = tensor_depths.reshape(-1, 1, self.image_size, self.image_size)
         tensor_depths = torch.cat([tensor_depths, tensor_depths, tensor_depths], dim=1)
         depth_token = self.depth_model.get_intermediate_layers(tensor_depths)[0].reshape(B, T * 256, -1)
-
+
         if self.version > 0.0:
-            former_pe_indice = torch.arange((self.memory_size * 2) * 256, device=images.device).expand(image_token.shape[0], (self.memory_size * 2) * 256)
+            former_pe_indice = torch.arange((self.memory_size * 2) * 256, device=images.device).expand(
+                image_token.shape[0], (self.memory_size * 2) * 256
+            )
         else:
-            former_pe_indice = torch.arange((self.memory_size + 1) * 256, device=images.device).expand(image_token.shape[0], (self.memory_size + 1) * 256)
+            former_pe_indice = torch.arange((self.memory_size + 1) * 256, device=images.device).expand(
+                image_token.shape[0], (self.memory_size + 1) * 256
+            )
         former_pe = self.former_pe(former_pe_indice)
         former_token = torch.cat((image_token, depth_token), dim=1) + former_pe
-        former_query_indice = torch.arange(self.memory_size * 16, device=images.device).expand(image_token.shape[0], self.memory_size * 16)
+        former_query_indice = torch.arange(self.memory_size * 16, device=images.device).expand(
+            image_token.shape[0], self.memory_size * 16
+        )
         former_query = self.former_query(former_query_indice)
         memory_token = self.former_net(former_query, former_token)
         memory_token = self.project_layer(memory_token)
         return memory_token
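The "former" used by these backbones is a plain `nn.TransformerDecoder`: `memory_size * 16` query tokens cross-attend over the concatenated RGB and depth patch tokens and compress them into a short memory sequence. A shape-level sketch using the same hyperparameters (`memory_size=8`, ViT-S width 384):

```python
import torch
import torch.nn as nn

memory_size, width = 8, 384
former = nn.TransformerDecoder(nn.TransformerDecoderLayer(width, 8, batch_first=True), 2)

queries = torch.zeros(2, memory_size * 16, width)         # 128 learned memory slots (tgt)
patches = torch.randn(2, (memory_size + 1) * 256, width)  # RGB+depth patch tokens (memory)
memory_token = former(queries, patches)                   # queries cross-attend over patches
print(memory_token.shape)                                 # torch.Size([2, 128, 384])
```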

+
 class NavDP_RGBD_Backbone(nn.Module):
-    def __init__(self,
-                 image_size=224,
-                 embed_size=512,
-                 finetune=True,
-                 memory_size=8,
-                 checkpoint="checkpoints/depth_anything_v2_vits.pth",
-                 device='cuda:0'):
+    def __init__(
+        self,
+        image_size=224,
+        embed_size=512,
+        finetune=True,
+        memory_size=8,
+        checkpoint="checkpoints/depth_anything_v2_vits.pth",
+        device='cuda:0',
+    ):
         super().__init__()
         # ensure the device is valid
         if device is None:
@@ -212,8 +230,8 @@ def __init__(self,
         # TODO: Hack for navdp training using transformers 4.51.0 when loading the checkpoint
         self.rgb_model.load_state_dict(torch.load(checkpoint), strict=False)
         self.rgb_model = self.rgb_model.pretrained.float()
-        self.preprocess_mean = torch.tensor([0.485,0.456,0.406],dtype=torch.float32)
-        self.preprocess_std = torch.tensor([0.229,0.224,0.225],dtype=torch.float32)
+        self.preprocess_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
+        self.preprocess_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
         if finetune:
             self.rgb_model.train()
         else:
@@ -221,45 +239,52 @@ def __init__(self,
         self.depth_model = DepthAnythingV2(**model_configs['vits'])
         self.depth_model = self.depth_model.pretrained.float()
         self.depth_model.train()
-        self.former_query = LearnablePositionalEncoding(384,self.memory_size*16)
-        self.former_pe = LearnablePositionalEncoding(384,(self.memory_size+1)*256)
-        self.former_net = nn.TransformerDecoder(nn.TransformerDecoderLayer(384,8,batch_first=True),2)
-        self.project_layer = nn.Linear(384,embed_size)
+        self.former_query = LearnablePositionalEncoding(384, self.memory_size * 16)
+        self.former_pe = LearnablePositionalEncoding(384, (self.memory_size + 1) * 256)
+        self.former_net = nn.TransformerDecoder(nn.TransformerDecoderLayer(384, 8, batch_first=True), 2)
+        self.project_layer = nn.Linear(384, embed_size)
         self.to(device)
-
-    def forward(self,images,depths):
+
+    def forward(self, images, depths):
         device = self._get_device()
         images = images.to(device)
         depths = depths.to(device)
         if len(images.shape) == 4:
-            tensor_images = torch.as_tensor(images,dtype=torch.float32,device=device).permute(0,3,1,2)
-            tensor_images = tensor_images.reshape(-1,3,self.image_size,self.image_size)
-            tensor_norm_images = (tensor_images - self.preprocess_mean.reshape(1,3,1,1).to(device))/self.preprocess_std.reshape(1,3,1,1).to(device)
+            tensor_images = torch.as_tensor(images, dtype=torch.float32, device=device).permute(0, 3, 1, 2)
+            tensor_images = tensor_images.reshape(-1, 3, self.image_size, self.image_size)
+            tensor_norm_images = (
+                tensor_images - self.preprocess_mean.reshape(1, 3, 1, 1).to(device)
+            ) / self.preprocess_std.reshape(1, 3, 1, 1).to(device)
             image_token = self.rgb_model.get_intermediate_layers(tensor_norm_images)[0]
         elif len(images.shape) == 5:
-            tensor_images = torch.as_tensor(images,dtype=torch.float32,device=device).permute(0,1,4,2,3)
-            B,T,C,H,W = tensor_images.shape
-            tensor_images = tensor_images.reshape(-1,3,self.image_size,self.image_size)
-            tensor_norm_images = (tensor_images - self.preprocess_mean.reshape(1,3,1,1).to(device))/self.preprocess_std.reshape(1,3,1,1).to(device)
-            image_token = self.rgb_model.get_intermediate_layers(tensor_norm_images)[0].reshape(B,T*256,-1)
+            tensor_images = torch.as_tensor(images, dtype=torch.float32, device=device).permute(0, 1, 4, 2, 3)
+            B, T, C, H, W = tensor_images.shape
+            tensor_images = tensor_images.reshape(-1, 3, self.image_size, self.image_size)
+            tensor_norm_images = (
+                tensor_images - self.preprocess_mean.reshape(1, 3, 1, 1).to(device)
+            ) / self.preprocess_std.reshape(1, 3, 1, 1).to(device)
+            image_token = self.rgb_model.get_intermediate_layers(tensor_norm_images)[0].reshape(B, T * 256, -1)

         if not self.finetune:
             image_token = image_token.detach()

         if len(depths.shape) == 4:
-            tensor_depths = torch.as_tensor(depths,dtype=torch.float32,device=device).permute(0,3,1,2)
-            tensor_depths = tensor_depths.reshape(-1,1,self.image_size,self.image_size)
-            tensor_depths = torch.concat([tensor_depths,tensor_depths,tensor_depths],dim=1)
+            tensor_depths = torch.as_tensor(depths, dtype=torch.float32, device=device).permute(0, 3, 1, 2)
+            tensor_depths = tensor_depths.reshape(-1, 1, self.image_size, self.image_size)
+            tensor_depths = torch.concat([tensor_depths, tensor_depths, tensor_depths], dim=1)
             depth_token = self.depth_model.get_intermediate_layers(tensor_depths)[0]
         elif len(depths.shape) == 5:
-            tensor_depths = torch.as_tensor(depths,dtype=torch.float32,device=device).permute(0,1,4,2,3)
-            B,T,C,H,W = tensor_depths.shape
-            tensor_depths = tensor_depths.reshape(-1,1,self.image_size,self.image_size)
-            tensor_depths = torch.concat([tensor_depths,tensor_depths,tensor_depths],dim=1)
-            depth_token = self.depth_model.get_intermediate_layers(tensor_depths)[0].reshape(B,T*256,-1)
-        former_token = torch.concat((image_token,depth_token),dim=1) + self.former_pe(torch.concat((image_token,depth_token),dim=1))
-        former_query = self.former_query(torch.zeros((image_token.shape[0], self.memory_size * 16, 384),device=device))
-        memory_token = self.former_net(former_query,former_token)
+            tensor_depths = torch.as_tensor(depths, dtype=torch.float32, device=device).permute(0, 1, 4, 2, 3)
+            B, T, C, H, W = tensor_depths.shape
+            tensor_depths = tensor_depths.reshape(-1, 1, self.image_size, self.image_size)
+            tensor_depths = torch.concat([tensor_depths, tensor_depths, tensor_depths], dim=1)
+            depth_token = self.depth_model.get_intermediate_layers(tensor_depths)[0].reshape(B, T * 256, -1)
+        former_token = torch.concat((image_token, depth_token), dim=1) + self.former_pe(
+            torch.concat((image_token, depth_token), dim=1)
+        )
+        former_query = self.former_query(torch.zeros((image_token.shape[0], self.memory_size * 16, 384), device=device))
+        memory_token = self.former_net(former_query, former_token)
         memory_token = self.project_layer(memory_token)
         return memory_token
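Both branches of this forward share the same preprocessing: channels-last frames are permuted to NCHW, flattened over time, and normalized with the ImageNet statistics stored in `preprocess_mean`/`preprocess_std`. Isolated as a sketch (shapes illustrative):

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406])  # ImageNet statistics, as in the backbone
std = torch.tensor([0.229, 0.224, 0.225])

B, T, H, W = 2, 8, 224, 224
images = torch.rand(B, T, H, W, 3)                            # channels-last input frames
x = images.permute(0, 1, 4, 2, 3).reshape(-1, 3, H, W)        # -> (B*T, 3, H, W)
x = (x - mean.reshape(1, 3, 1, 1)) / std.reshape(1, 3, 1, 1)  # broadcast over batch and pixels
print(x.shape)                                                # torch.Size([16, 3, 224, 224])
```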
+
     def _get_device(self):
         """get device safely"""
         # try to get device through model parameters
@@ -268,14 +293,14 @@ def _get_device(self):
                 return param.device
         except StopIteration:
             pass
-
+
         # try to get device through buffer
         try:
             for buffer in self.buffers():
                 return buffer.device
         except StopIteration:
             pass
-
+
         # try to get device through submodule
         for module in self.children():
             try:
@@ -283,60 +308,132 @@ def _get_device(self):
                 return param.device
             except StopIteration:
                 continue
-
+
         # finally revert to default device
         return torch.device("cuda" if torch.cuda.is_available() else "cpu")

+
 class NavDP_ImageGoal_Backbone(nn.Module):
-    def __init__(self,
-                 image_size=224,
-                 embed_size=512,
-                 device='cuda:0'):
+    def __init__(self, image_size=224, embed_size=512, device='cuda:0'):
         super().__init__()
+        if device is None:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
         self.device = device
         self.image_size = image_size
         self.embed_size = embed_size
         model_configs = {'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}}
         self.imagegoal_encoder = DepthAnythingV2(**model_configs['vits'])
         self.imagegoal_encoder = self.imagegoal_encoder.pretrained.float()
-        self.imagegoal_encoder.patch_embed.proj = nn.Conv2d(in_channels=6,
-                                                            out_channels = self.imagegoal_encoder.patch_embed.proj.out_channels,
-                                                            kernel_size = self.imagegoal_encoder.patch_embed.proj.kernel_size,
-                                                            stride = self.imagegoal_encoder.patch_embed.proj.stride,
-                                                            padding = self.imagegoal_encoder.patch_embed.proj.padding)
+        self.imagegoal_encoder.patch_embed.proj = nn.Conv2d(
+            in_channels=6,
+            out_channels=self.imagegoal_encoder.patch_embed.proj.out_channels,
+            kernel_size=self.imagegoal_encoder.patch_embed.proj.kernel_size,
+            stride=self.imagegoal_encoder.patch_embed.proj.stride,
+            padding=self.imagegoal_encoder.patch_embed.proj.padding,
+        )
         self.imagegoal_encoder.train()
-        self.project_layer = nn.Linear(384,embed_size)
-
-    def forward(self,images):
-        assert len(images.shape) == 4 # B,C,H,W
-        tensor_images = torch.as_tensor(images,dtype=torch.float32,device=self.device).permute(0,3,1,2)
+        self.project_layer = nn.Linear(384, embed_size)
+        self.to(device)
+
+    def forward(self, images):
+        assert len(images.shape) == 4  # B,C,H,W
+        device = self._get_device()
+        images = images.to(device)
+        tensor_images = torch.as_tensor(images, dtype=torch.float32, device=device).permute(0, 3, 1, 2)
         image_token = self.imagegoal_encoder.get_intermediate_layers(tensor_images)[0].mean(dim=1)
         image_token = self.project_layer(image_token)
         return image_token

+    def _get_device(self):
+        """get device safely"""
+        # try to get device through model parameters
+        try:
+            for param in self.parameters():
+                return param.device
+        except StopIteration:
+            pass
+
+        # try to get device through buffer
+        try:
+            for buffer in self.buffers():
+                return buffer.device
+        except StopIteration:
+            pass
+
+        # try to get device through submodule
+        for module in self.children():
+            try:
+                for param in module.parameters():
+                    return param.device
+            except StopIteration:
+                continue
+
+        # finally revert to default device
+        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
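Worth noting: replacing `patch_embed.proj` with a freshly initialized 6-channel convolution (current RGB stacked with goal RGB) discards the pretrained projection weights. If keeping them matters, a common alternative is to tile the pretrained 3-channel kernel across the new channels and rescale. This is not what the patch does; a sketch only:

```python
import torch
import torch.nn as nn

def inflate_patch_embed(old: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Widen a pretrained patch-embed conv to more input channels by tiling its kernel."""
    new = nn.Conv2d(in_channels, old.out_channels, old.kernel_size, old.stride, old.padding)
    with torch.no_grad():
        repeats = -(-in_channels // old.in_channels)              # ceil division
        weight = old.weight.repeat(1, repeats, 1, 1)[:, :in_channels]
        new.weight.copy_(weight * old.in_channels / in_channels)  # preserve activation scale
        if old.bias is not None:
            new.bias.copy_(old.bias)
    return new

# e.g. encoder.patch_embed.proj = inflate_patch_embed(encoder.patch_embed.proj, 6)
```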
 class NavDP_PixelGoal_Backbone(nn.Module):
-    def __init__(self,
-                 image_size=224,
-                 embed_size=512,
-                 device='cuda:0'):
+    def __init__(self, image_size=224, embed_size=512, device='cuda:0'):
         super().__init__()
+        if device is None:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
         self.device = device
         self.image_size = image_size
         self.embed_size = embed_size
         model_configs = {'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}}
         self.pixelgoal_encoder = DepthAnythingV2(**model_configs['vits'])
         self.pixelgoal_encoder = self.pixelgoal_encoder.pretrained.float()
-        self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d(in_channels=4,
-                                                            out_channels = self.pixelgoal_encoder.patch_embed.proj.out_channels,
-                                                            kernel_size = self.pixelgoal_encoder.patch_embed.proj.kernel_size,
-                                                            stride = self.pixelgoal_encoder.patch_embed.proj.stride,
-                                                            padding = self.pixelgoal_encoder.patch_embed.proj.padding)
+        self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d(
+            in_channels=7,
+            out_channels=self.pixelgoal_encoder.patch_embed.proj.out_channels,
+            kernel_size=self.pixelgoal_encoder.patch_embed.proj.kernel_size,
+            stride=self.pixelgoal_encoder.patch_embed.proj.stride,
+            padding=self.pixelgoal_encoder.patch_embed.proj.padding,
+        )
         self.pixelgoal_encoder.train()
-        self.project_layer = nn.Linear(384,embed_size)
-
-    def forward(self,images):
-        assert len(images.shape) == 4 # B,C,H,W
-        tensor_images = torch.as_tensor(images,dtype=torch.float32,device=self.device).permute(0,3,1,2)
+        self.project_layer = nn.Linear(384, embed_size)
+        self.to(device)
+
+    def forward(self, images):
+        assert len(images.shape) == 4  # B,C,H,W
+        device = self._get_device()
+        images = images.to(device)
+        tensor_images = torch.as_tensor(images, dtype=torch.float32, device=device).permute(0, 3, 1, 2)
         image_token = self.pixelgoal_encoder.get_intermediate_layers(tensor_images)[0].mean(dim=1)
         image_token = self.project_layer(image_token)
-        return image_token
\ No newline at end of file
+        return image_token
+
+    def _get_device(self):
+        """get device safely"""
+        # try to get device through model parameters
+        try:
+            for param in self.parameters():
+                return param.device
+        except StopIteration:
+            pass
+
+        # try to get device through buffer
+        try:
+            for buffer in self.buffers():
+                return buffer.device
+        except StopIteration:
+            pass
+
+        # try to get device through submodule
+        for module in self.children():
+            try:
+                for param in module.parameters():
+                    return param.device
+            except StopIteration:
+                continue
+
+        # finally revert to default device
+        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
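Editorial aside: `_get_device` is now copy-pasted into four classes, and its `except StopIteration` arms are effectively dead code, since a `for` loop over an empty iterator simply falls through rather than raising (and `Module.parameters()` already recurses into submodules, making the `children()` pass redundant). A possible shared helper, sketched here and not part of the patch:

```python
import torch
import torch.nn as nn

class DeviceMixin:
    """Resolve a module's device: parameters first, then buffers, then a CUDA/CPU default."""

    def _get_device(self) -> torch.device:
        for param in self.parameters():
            return param.device   # first parameter wins; recurses into submodules
        for buffer in self.buffers():
            return buffer.device  # parameter-free module that still registers buffers
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# usage: class NavDP_PixelGoal_Backbone(DeviceMixin, nn.Module): ...
```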
diff --git a/internnav/trainer/navdp_trainer.py b/internnav/trainer/navdp_trainer.py
index 73849f1c..5542562e 100644
--- a/internnav/trainer/navdp_trainer.py
+++ b/internnav/trainer/navdp_trainer.py
@@ -1,13 +1,12 @@
+import os
+import time
+
 import torch
-import torch.nn.functional as F
-from torch.utils.data import DataLoader, DistributedSampler
-from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
+from torch.utils.data import DataLoader, DistributedSampler
+
 from internnav.trainer.base import BaseTrainer
-import os
-import time
-from datetime import datetime
-import multiprocessing
+

 class NavDPTrainer(BaseTrainer):
     def __init__(self, config, **kwargs):
@@ -21,13 +20,13 @@ def __init__(self, config, **kwargs):
             self.model_device = self.model.module.device
         else:
             self.model_device = self.model.device
-
+
         print(f"[Rank {dist.get_rank() if dist.is_initialized() else 0}] Model device: {self.model_device}")
-
+
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         # get model device
         model_device = next(model.parameters()).device
-
+
         # ensure all inputs are on the model device
         inputs_on_device = {}
         for key, value in inputs.items():
@@ -36,8 +35,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
                 inputs_on_device[key] = value.to(model_device, non_blocking=True)
             else:
                 inputs_on_device[key] = value
-
+
         import os
+
         import psutil

         current_pid = os.getpid()
@@ -52,20 +52,21 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             print(f"There are {len(children)} training processes running")
         else:
             print("Cannot determine parent process")
-
+
         # Ensure all inputs are on the model device
         inputs_on_device = {
             "batch_pg": inputs["batch_pg"].to(model_device),
             "batch_ig": inputs["batch_ig"].to(model_device),
+            "batch_tg": inputs["batch_tg"].to(model_device),
             "batch_rgb": inputs["batch_rgb"].to(model_device),
             "batch_depth": inputs["batch_depth"].to(model_device),
             "batch_labels": inputs["batch_labels"].to(model_device),
             "batch_augments": inputs["batch_augments"].to(model_device),
             "batch_label_critic": inputs["batch_label_critic"].to(model_device),
-            "batch_augment_critic": inputs["batch_augment_critic"].to(model_device)
+            "batch_augment_critic": inputs["batch_augment_critic"].to(model_device),
        }
         torch.cuda.synchronize(model_device)
-
+
         # unpack input data and move to device
         # batch_pg = inputs["batch_pg"]
         # batch_ig = inputs["batch_ig"]
@@ -75,36 +76,40 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         # batch_augments = inputs["batch_augments"]
         batch_label_critic = inputs["batch_label_critic"]
         batch_augment_critic = inputs["batch_augment_critic"]
-
-        pred_ng, pred_pg, critic_pred, augment_pred, noise = model(
-            inputs_on_device["batch_pg"],
-            inputs_on_device["batch_ig"],
-            inputs_on_device["batch_rgb"],
-            inputs_on_device["batch_depth"],
-            inputs_on_device["batch_labels"],
-            inputs_on_device["batch_augments"]
-        )
-
+
+        pred_ng, pred_mg, critic_pred, augment_pred, noise, aux_pred = model(
+            inputs_on_device["batch_pg"],
+            inputs_on_device["batch_ig"],
+            inputs_on_device["batch_tg"],
+            inputs_on_device["batch_rgb"],
+            inputs_on_device["batch_depth"],
+            inputs_on_device["batch_labels"],
+            inputs_on_device["batch_augments"],
+        )
+
         ng_action_loss = (pred_ng - noise[0]).square().mean()
-        pg_action_loss = (pred_pg - noise[1]).square().mean()
-        # ig_action_loss = (pred_ig - noise[2]).square().mean()
-        action_loss = 0.5 * pg_action_loss + 0.5 * ng_action_loss
-        critic_loss = (critic_pred - batch_label_critic).square().mean() + \
-                      (augment_pred - batch_augment_critic).square().mean()
-        loss = 0.8 * action_loss + 0.2 * critic_loss
-
+        mg_action_loss = (pred_mg - noise[1]).square().mean()
+        aux_loss = (
+            0.5 * (inputs_on_device["batch_pg"] - aux_pred[0]).square().mean()
+            + 0.5 * (inputs_on_device["batch_pg"] - aux_pred[1]).square().mean()
+        )
+        action_loss = 0.5 * mg_action_loss + 0.5 * ng_action_loss
+        critic_loss = (critic_pred - batch_label_critic).square().mean() + (
+            augment_pred - batch_augment_critic
+        ).square().mean()
+        loss = 0.8 * action_loss + 0.2 * critic_loss + 0.5 * aux_loss
+
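On the new objective: `loss = 0.8 * action + 0.2 * critic + 0.5 * aux`, where both auxiliary heads regress the point goal `batch_pg`. The coefficients now sum to 1.5 rather than 1.0, so the overall loss scale grows relative to the old `0.8/0.2` split — relevant when comparing training curves or retuning the learning rate. The assembly, isolated as a sketch:

```python
def combine_losses(ng_action, mg_action, critic, aux_ig, aux_pg):
    """Mirror the weighting in compute_loss above (weights sum to 1.5, not 1.0)."""
    action_loss = 0.5 * mg_action + 0.5 * ng_action
    aux_loss = 0.5 * aux_ig + 0.5 * aux_pg
    return 0.8 * action_loss + 0.2 * critic + 0.5 * aux_loss

assert combine_losses(1.0, 1.0, 1.0, 1.0, 1.0) == 1.5
```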
         outputs = {
             'pred_ng': pred_ng,
-            'pred_pg': pred_pg,
-            # 'pred_ig': pred_ig,
+            'pred_mg': pred_mg,
             'critic_pred': critic_pred,
             'augment_pred': augment_pred,
             'noise': noise,
             'loss': loss,
             'ng_action_loss': ng_action_loss,
-            'pg_action_loss': pg_action_loss,
-            # 'ig_action_loss': ig_action_loss,
-            'critic_loss': critic_loss
+            'mg_action_loss': mg_action_loss,
+            'aux_loss': aux_loss,
+            'critic_loss': critic_loss,
         }
         # if self.logger:
         #     self.logger.info(
@@ -114,13 +119,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         #         f"Critic Loss: {critic_loss.item():.4f}"
         #     )

-
         return (loss, outputs) if return_outputs else loss

     def create_optimizer(self):
         """create and return optimizer"""
         rank = dist.get_rank() if dist.is_initialized() else 0
-
+
         # get learning rate
         try:
             lr = self.config.il.lr
@@ -130,59 +134,45 @@ def create_optimizer(self):
             lr = 1e-4
             if rank == 0:
                 print(f"[Rank 0] Warning: Using default learning rate: {lr}")
-
+
         # Ensure the model is on the correct device
         if hasattr(self.model, 'module'):
             model_for_optim = self.model.module
         else:
             model_for_optim = self.model
-
+
         # Create optimizer
-        optimizer = torch.optim.Adam(
-            model_for_optim.parameters(),
-            lr=lr
-        )
-
+        optimizer = torch.optim.Adam(model_for_optim.parameters(), lr=lr)
+
         if rank == 0:
             print(f"[Rank 0] Optimizer created with {len(optimizer.param_groups)} param groups")
             total_params = sum(p.numel() for p in model_for_optim.parameters() if p.requires_grad)
             print(f"[Rank 0] Total trainable parameters: {total_params:,}")
-
-        return optimizer
-
+
+        # keep the return: create_optimizer_and_scheduler below relies on it
+        return optimizer
+
     def create_scheduler(self, optimizer, num_training_steps: int):
         """Create learning rate scheduler"""
-        scheduler = torch.optim.lr_scheduler.LinearLR(
-            optimizer,
-            start_factor=1.0,
-            end_factor=0.5,
-            total_iters=10000
-        )
+        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=10000)
         return scheduler
-
+
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """override parent class method, completely control the creation process"""
         print("\n=== create optimizer and scheduler ===")
-
+
         # create optimizer
         self.optimizer = self.create_optimizer()
-
+
         # create scheduler (note the parameter order)
         self.lr_scheduler = self.create_scheduler(self.optimizer, num_training_steps)
-
+
         return self.optimizer, self.lr_scheduler
-
+
     def get_train_dataloader(self):
         world_size = dist.get_world_size() if dist.is_initialized() else 1
         rank = dist.get_rank() if dist.is_initialized() else 0
-        sampler = DistributedSampler(self.train_dataset,
-                                     num_replicas=world_size,
-                                     rank=rank,
-                                     shuffle=True,
-                                     seed=1234)
-
+        sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=1234)
+
         loader = DataLoader(
             self.train_dataset,
             batch_size=self.config.il.batch_size,
@@ -190,7 +180,26 @@ def get_train_dataloader(self):
             num_workers=self.config.il.num_workers,
             pin_memory=True,
             drop_last=True,
-            collate_fn=self.data_collator
+            collate_fn=self.data_collator,
         )
         # print(loader)
-        return loader
\ No newline at end of file
+        return loader
+
+    def save_model(self, output_dir, state_dict=None, **kwargs):
+        """
+        save model to specified directory
+
+        handle DDP wrapped model
+        """
+        # check if it is a DDP wrapped model
+        if hasattr(self.model, 'module'):
+            # get original model
+            model_to_save = self.model.module
+        else:
+            model_to_save = self.model
+
+        # ensure the output directory exists
+        os.makedirs(output_dir, exist_ok=True)
+        # join the path properly; plain concatenation drops the separator
+        torch.save(model_to_save.state_dict(), os.path.join(output_dir, "navdp.ckpt"))
+
+        print(f"Saving model to {output_dir} (is DDP: {hasattr(self.model, 'module')})")
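Since `save_model` writes a plain `state_dict` for the unwrapped module, loading is symmetric: construct the bare (non-DDP) model, load, then wrap with DistributedDataParallel if training resumes. A usage sketch assuming the same `navdp.ckpt` filename (`map_location` keeps loading device-agnostic):

```python
import os
import torch

def load_navdp_checkpoint(model: torch.nn.Module, output_dir: str) -> torch.nn.Module:
    """Load the state_dict written by save_model into a bare (non-DDP) model."""
    state_dict = torch.load(os.path.join(output_dir, "navdp.ckpt"), map_location="cpu")
    model.load_state_dict(state_dict)
    return model

# Build the plain model first, load, then wrap in DistributedDataParallel if needed.
```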