diff --git a/monodepth2/datasets/mono_dataset.py b/monodepth2/datasets/mono_dataset.py
index a381934..fc7eaba 100644
--- a/monodepth2/datasets/mono_dataset.py
+++ b/monodepth2/datasets/mono_dataset.py
@@ -173,7 +173,7 @@ def __getitem__(self, index):
             inputs[("inv_K", scale)] = torch.from_numpy(inv_K)
 
         if do_color_aug:
-            color_aug = transforms.ColorJitter.get_params(
+            color_aug = transforms.ColorJitter(
                 self.brightness, self.contrast, self.saturation, self.hue)
         else:
             color_aug = (lambda x: x)
diff --git a/monodepth2/evaluate_depth.py b/monodepth2/evaluate_depth.py
index 7746ef9..5b11b15 100644
--- a/monodepth2/evaluate_depth.py
+++ b/monodepth2/evaluate_depth.py
@@ -163,7 +163,7 @@ def evaluate(opt):
            quit()
 
    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
-    gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1')["data"]
+    gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]
 
    print("-> Evaluating")
diff --git a/monodepth2/networks/depth_decoder.py b/monodepth2/networks/depth_decoder.py
index 786899a..5c1082c 100644
--- a/monodepth2/networks/depth_decoder.py
+++ b/monodepth2/networks/depth_decoder.py
@@ -46,6 +46,7 @@ def __init__(self, num_ch_enc, scales=range(4), num_output_channels=2, use_skips=True):
         self.decoder = nn.ModuleList(list(self.convs.values()))
         self.sigmoid = nn.Sigmoid()
+        self.relu = nn.ReLU()
 
     def forward(self, input_features):
         self.outputs = {}
@@ -60,8 +61,8 @@ def forward(self, input_features):
             x = torch.cat(x, 1)
             x = self.convs[("upconv", i, 1)](x)
             if i in self.scales:
-                outs = self.sigmoid(self.convs[("dispconv", i)](x))
-                self.outputs[("disp", i)] = outs[:, 0, :, :]
-                self.outputs[("disp-sigma", i)] = outs[:, 1, :, :]
+                outs = self.convs[("dispconv", i)](x)
+                self.outputs[("disp", i)] = self.sigmoid(torch.unsqueeze(outs[:, 0, :, :], dim=1))
+                self.outputs[("disp-sigma", i)] = self.relu(torch.unsqueeze(outs[:, 1, :, :], dim=1))
 
         return self.outputs
diff --git a/monodepth2/networks/pose_cnn.py b/monodepth2/networks/pose_cnn.py
index a8e4dda..17b89be 100644
--- a/monodepth2/networks/pose_cnn.py
+++ b/monodepth2/networks/pose_cnn.py
@@ -31,11 +31,11 @@ def __init__(self, num_input_frames):
         self.num_convs = len(self.convs)
 
-        self.relu = nn.ReLU(True)
+        self.relu = nn.ReLU()
 
-        self.tanh = nn.Tanh(True)
+        self.tanh = nn.Tanh()
 
-        self.softplus = nn.Softplus(True)
+        self.softplus = nn.Softplus()
 
         self.net = nn.ModuleList(list(self.convs.values()))
@@ -55,8 +55,8 @@ def forward(self, out):
         out_b = out_b.mean(3).mean(2)
 
         out_pose = 0.01 * out_pose.view(-1, self.num_input_frames - 1, 1, 6)
-        out_a = 0.01 * out_a.view(-1, self.num_input_frames - 1, 1, 1)
-        out_b = 0.01 * out_b.view(-1, self.num_input_frames - 1, 1, 1)
+        out_a = out_a.view(-1, self.num_input_frames - 1, 1, 1)
+        out_b = out_b.view(-1, self.num_input_frames - 1, 1, 1)
 
         axisangle = out_pose[..., :3]
         translation = out_pose[..., 3:]
diff --git a/monodepth2/networks/pose_decoder.py b/monodepth2/networks/pose_decoder.py
index 4b03b60..99f23c8 100644
--- a/monodepth2/networks/pose_decoder.py
+++ b/monodepth2/networks/pose_decoder.py
@@ -28,6 +28,11 @@ def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
         self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
         self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
 
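+        # New 1x1 heads: a_conv and b_conv predict per-source-frame affine
+        # brightness parameters (a via softplus, b via tanh) for photometric alignment.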
+        self.a_conv = nn.Conv2d(256, num_frames_to_predict_for, 1)
+        self.b_conv = nn.Conv2d(256, num_frames_to_predict_for, 1)
+
+        self.tanh = nn.Tanh()
+        self.softplus = nn.Softplus()
         self.relu = nn.ReLU()
 
         self.net = nn.ModuleList(list(self.convs.values()))
@@ -39,16 +44,32 @@ def forward(self, input_features):
         cat_features = torch.cat(cat_features, 1)
 
         out = cat_features
+        out_ab = None
         for i in range(3):
             out = self.convs[("pose", i)](out)
             if i != 2:
                 out = self.relu(out)
+            if i == 1:
+                out_ab = out
 
-        out = out.mean(3).mean(2)
+        out_pose = out.mean(3).mean(2)
+
+        out_a = self.a_conv(out_ab)
+        out_a = self.softplus(out_a)
+        out_a = out_a.mean(3).mean(2)
+        out_b = self.b_conv(out_ab)
+        out_b = self.tanh(out_b)
+        out_b = out_b.mean(3).mean(2)
 
-        out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
+        out_pose = 0.01 * out_pose.view(-1, self.num_frames_to_predict_for, 1, 6)
+        out_a = 0.01 * out_a.view(-1, self.num_frames_to_predict_for, 1, 1)
+        out_b = 0.01 * out_b.view(-1, self.num_frames_to_predict_for, 1, 1)
 
-        axisangle = out[..., :3]
-        translation = out[..., 3:]
+        axisangle = out_pose[..., :3]
+        translation = out_pose[..., 3:]
 
-        return axisangle, translation
+        return axisangle, translation, out_a, out_b
diff --git a/monodepth2/test_simple.py b/monodepth2/test_simple.py
index d74d63b..063d8e0 100644
--- a/monodepth2/test_simple.py
+++ b/monodepth2/test_simple.py
@@ -131,13 +131,23 @@ def test_simple(args):
             input_image = input_image.to(device)
             features = encoder(input_image)
             outputs = depth_decoder(features)
+            output_name = os.path.splitext(os.path.basename(image_path))[0]
 
             disp = outputs[("disp", 0)]
+
+            disp_sigma = outputs[("disp-sigma", 0)]
+            disp_sigma_resized = torch.nn.functional.interpolate(
+                disp_sigma, (original_height, original_width), mode="bilinear", align_corners=False)
+            disp_sigma_im = disp_sigma_resized.detach().cpu().numpy()
+            disp_sigma_im = pil.fromarray((disp_sigma_im[0][0] * 255.0).astype(np.uint8))
+            name_dest_im_sigma = os.path.join(output_directory, "{}_disp_sigma.jpeg".format(output_name))
+            disp_sigma_im.save(name_dest_im_sigma)
+
+
             disp_resized = torch.nn.functional.interpolate(
                 disp, (original_height, original_width), mode="bilinear", align_corners=False)
 
             # Saving numpy file
-            output_name = os.path.splitext(os.path.basename(image_path))[0]
             scaled_disp, depth = disp_to_depth(disp, 0.1, 100)
             if args.pred_metric_depth:
                 name_dest_npy = os.path.join(output_directory, "{}_depth.npy".format(output_name))
diff --git a/monodepth2/trainer.py b/monodepth2/trainer.py
index 90450ac..77a345e 100644
--- a/monodepth2/trainer.py
+++ b/monodepth2/trainer.py
@@ -14,6 +14,8 @@ import torch.optim as optim
 from torch.utils.data import DataLoader
 from tensorboardX import SummaryWriter
+from torchvision.utils import save_image
+
 
 import json
@@ -121,6 +123,7 @@ def __init__(self, options):
         self.train_loader = DataLoader(
             train_dataset, self.opt.batch_size, True,
             num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
+
         val_dataset = self.dataset(
             self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
             self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
@@ -210,11 +213,11 @@ def run_epoch(self):
                     self.compute_depth_losses(inputs, outputs, losses)
 
                 self.log("train", inputs, outputs, losses)
-                self.val()
+                self.val(self.epoch * 10 + batch_idx)
 
             self.step += 1
 
-    def process_batch(self, inputs):
+    def process_batch(self, inputs, batch_idx=-1):
         """Pass a minibatch through the network and generate images and losses
         """
         for key, ipt in inputs.items():
@@ -244,7 +247,7 @@ def process_batch(self, inputs):
             outputs.update(self.predict_poses(inputs, features))
 
         self.generate_images_pred(inputs, outputs)
-        losses = self.compute_losses(inputs, outputs)
+        losses = self.compute_losses(inputs, outputs, batch_idx)
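+        # batch_idx is -1 during training; validation passes a real index so
+        # that compute_losses can save debug images.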
 
         return outputs, losses
 
@@ -318,18 +321,18 @@ def predict_poses(self, inputs, features):
         return outputs
 
-    def val(self):
+    def val(self, batch_idx):
        """Validate the model on a single minibatch
        """
        self.set_eval()
        try:
-            inputs = self.val_iter.next()
+            inputs = next(self.val_iter)
        except StopIteration:
            self.val_iter = iter(self.val_loader)
-            inputs = self.val_iter.next()
+            inputs = next(self.val_iter)
 
        with torch.no_grad():
-            outputs, losses = self.process_batch(inputs)
+            outputs, losses = self.process_batch(inputs, batch_idx)
 
            if "depth_gt" in inputs:
                self.compute_depth_losses(inputs, outputs, losses)
@@ -345,10 +348,13 @@ def generate_images_pred(self, inputs, outputs):
        """
        for scale in self.opt.scales:
            disp = outputs[("disp", scale)]
+            sigma = outputs[("disp-sigma", scale)]
            if self.opt.v1_multiscale:
                source_scale = scale
            else:
                disp = F.interpolate(
                    disp, [self.opt.height, self.opt.width], mode="bilinear", align_corners=False)
+                disp_sigma = F.interpolate(sigma, [self.opt.height, self.opt.width], mode="bilinear", align_corners=False)
+                outputs[("disp-sigma", scale)] = disp_sigma
                source_scale = 0
 
            _, depth = disp_to_depth(disp, self.opt.min_depth, self.opt.max_depth)
@@ -389,21 +395,26 @@ def compute_reprojection_loss(self, pred, target, sigma):
        """Computes reprojection loss between a batch of predicted and target images
        """
        abs_diff = torch.abs(target - pred)
-        l1_loss = abs_diff
+        l1_loss = abs_diff.mean(1, True)
 
        if self.opt.no_ssim:
            reprojection_loss = l1_loss
        else:
-            ssim_loss = (self.ssim(pred, target)) / sigma + torch.log(sigma)
+            ssim_loss = self.ssim(pred, target).mean(1, True)
            reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss
 
+        # Reference: https://github.com/no-Seaweed/Learning-Deep-Learning-1/blob/master/paper_notes/sfm_learner.md
+        # transformed_sigma = (10 * sigma + 0.1)
+
+        # Exp 1
+        transformed_sigma = sigma + 1
-        reprojection_loss = reprojection_loss / sigma + torch.log(sigma)
+        reprojection_loss = (reprojection_loss / transformed_sigma) + torch.log(transformed_sigma)
 
-        reprojection_loss = reprojection_loss.mean(1, True)
+        # reprojection_loss = (reprojection_loss * sigma)
 
        return reprojection_loss
 
-    def compute_losses(self, inputs, outputs):
+    def compute_losses(self, inputs, outputs, batch_idx=-1):
        """Compute the reprojection and smoothness losses for a minibatch
        """
        losses = {}
@@ -429,6 +440,11 @@ def compute_losses(self, inputs, outputs):
                a = outputs[("a", 0, frame_id)].unsqueeze(1)
                b = outputs[("b", 0, frame_id)].unsqueeze(1)
                target_frame = target * a + b
+                if batch_idx != -1:
+                    save_image(pred[-1], f'val_images/pred_{self.step}.jpeg')
+                    save_image(target_frame[-1], f'val_images/target_frame_{self.step}.jpeg')
+                    save_image(target[-1], f'val_images/target_{self.step}.jpeg')
+
                reprojection_losses.append(self.compute_reprojection_loss(pred, target_frame, sigma))
                ab_losses.append((a - 1) ** 2 + b ** 2)
@@ -440,7 +456,10 @@ def compute_losses(self, inputs, outputs):
            identity_reprojection_losses = []
            for frame_id in self.opt.frame_ids[1:]:
                pred = inputs[("color", frame_id, source_scale)]
-                identity_reprojection_losses.append(self.compute_reprojection_loss(pred, target))
+                a = outputs[("a", 0, frame_id)].unsqueeze(1)
+                b = outputs[("b", 0, frame_id)].unsqueeze(1)
+                target_frame = target * a + b
+                identity_reprojection_losses.append(self.compute_reprojection_loss(pred, target_frame, sigma))
 
            identity_reprojection_losses = torch.cat(identity_reprojection_losses, 1)
@@ -488,8 +507,13 @@ def compute_losses(self, inputs, outputs):
            mean_disp = disp.mean(2, True).mean(3, True)
            norm_disp = disp / (mean_disp + 1e-7)
            smooth_loss = get_smooth_loss(norm_disp, color)
-            reg_loss = smooth_loss + self.opt.ab_weight * ab_loss
+
+            reg_loss = smooth_loss + self.opt.ab_weight * torch.mean(ab_loss)
+            # loss += torch.mean((sigma - 1) ** 2)
+            # categorical_loss = nn.CrossEntropyLoss()
+            # loss += categorical_loss(sigma, torch.ones_like(sigma))
+            # print(sigma.min(), sigma.max(), sigma.mean())
            loss += self.opt.disparity_smoothness * reg_loss / (2 ** scale)
            total_loss += loss
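
Note on the loss change: the patch replaces monodepth2's plain photometric loss with an uncertainty-weighted one, in the style of the aleatoric-uncertainty loss of Kendall & Gal. The depth decoder's second output channel is a per-pixel sigma (ReLU, so sigma >= 0), and the pose decoder's new a/b heads give a per-source-frame affine brightness correction. Below is a minimal standalone sketch of the combined loss as this patch computes it; the function name photometric_uncertainty_loss and the ssim argument are illustrative, not from the repo, and the SSIM term is assumed to follow monodepth2's SSIM module interface:

    import torch

    def photometric_uncertainty_loss(pred, target, sigma, a, b, ssim=None):
        """Sketch of the patched reprojection loss.

        pred, target: (B, 3, H, W) warped source / target images
        sigma:        (B, 1, H, W) per-pixel uncertainty (ReLU output, >= 0)
        a, b:         (B, 1, 1, 1) affine brightness parameters
        ssim:         callable matching monodepth2's SSIM module, or None
        """
        target_aligned = target * a + b                      # brightness-align the target
        l1 = torch.abs(target_aligned - pred).mean(1, True)  # per-pixel L1, keep channel dim
        if ssim is None:
            photo = l1
        else:
            photo = 0.85 * ssim(pred, target_aligned).mean(1, True) + 0.15 * l1
        s = sigma + 1  # the "Exp 1" transform: keeps s >= 1, so division and log are safe
        # Heteroscedastic weighting: high-sigma pixels are down-weighted,
        # while log(s) penalises declaring everything uncertain.
        return photo / s + torch.log(s)

Because s >= 1, the log term is non-negative and vanishes only where sigma = 0, so the network pays for every pixel it marks uncertain; the (a - 1)^2 + b^2 term accumulated in compute_losses likewise keeps the affine correction close to the identity. One practical caveat: torchvision's save_image does not create directories, so the val_images/ folder must exist before validation runs.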