Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monodepth2/datasets/mono_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __getitem__(self, index):
inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

if do_color_aug:
color_aug = transforms.ColorJitter.get_params(
color_aug = transforms.ColorJitter(
self.brightness, self.contrast, self.saturation, self.hue)
else:
color_aug = (lambda x: x)
Expand Down
2 changes: 1 addition & 1 deletion monodepth2/evaluate_depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def evaluate(opt):
quit()

gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1')["data"]
gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]

print("-> Evaluating")

Expand Down
7 changes: 4 additions & 3 deletions monodepth2/networks/depth_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, num_ch_enc, scales=range(4), num_output_channels=2, use_skips

self.decoder = nn.ModuleList(list(self.convs.values()))
self.sigmoid = nn.Sigmoid()
self.relu = nn.ReLU()

def forward(self, input_features):
self.outputs = {}
Expand All @@ -60,8 +61,8 @@ def forward(self, input_features):
x = torch.cat(x, 1)
x = self.convs[("upconv", i, 1)](x)
if i in self.scales:
outs = self.sigmoid(self.convs[("dispconv", i)](x))
self.outputs[("disp", i)] = outs[:, 0, :, :]
self.outputs[("disp-sigma", i)] = outs[:, 1, :, :]
outs = self.convs[("dispconv", i)](x)
self.outputs[("disp", i)] = self.sigmoid(torch.unsqueeze(outs[:, 0, :, :], axis=1))
self.outputs[("disp-sigma", i)] = self.relu(torch.unsqueeze(outs[:, 1, :, :], axis=1))

return self.outputs
10 changes: 5 additions & 5 deletions monodepth2/networks/pose_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def __init__(self, num_input_frames):

self.num_convs = len(self.convs)

self.relu = nn.ReLU(True)
self.relu = nn.ReLU()

self.tanh = nn.Tanh(True)
self.tanh = nn.Tanh()

self.softplus = nn.Softplus(True)
self.softplus = nn.Softplus()

self.net = nn.ModuleList(list(self.convs.values()))

Expand All @@ -55,8 +55,8 @@ def forward(self, out):
out_b = out_b.mean(3).mean(2)

out_pose = 0.01 * out_pose.view(-1, self.num_input_frames - 1, 1, 6)
out_a = 0.01 * out_a.view(-1, self.num_input_frames - 1, 1, 1)
out_b = 0.01 * out_b.view(-1, self.num_input_frames - 1, 1, 1)
out_a = out_a.view(-1, self.num_input_frames - 1, 1, 1)
out_b = out_b.view(-1, self.num_input_frames - 1, 1, 1)

axisangle = out_pose[..., :3]
translation = out_pose[..., 3:]
Expand Down
29 changes: 25 additions & 4 deletions monodepth2/networks/pose_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=Non
self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)

self.a_conv = nn.Conv2d(256, num_frames_to_predict_for, 1)
self.b_conv = nn.Conv2d(256, num_frames_to_predict_for, 1)

self.tanh = nn.Tanh()
self.softplus = nn.Softplus()
self.relu = nn.ReLU()

self.net = nn.ModuleList(list(self.convs.values()))
Expand All @@ -39,16 +44,32 @@ def forward(self, input_features):
cat_features = torch.cat(cat_features, 1)

out = cat_features
out_ab = None
for i in range(3):
out = self.convs[("pose", i)](out)
if i != 2:
out = self.relu(out)

out = out.mean(3).mean(2)
if i==1:
out_ab = out

out_pose = out.mean(3).mean(2)


out_a = self.a_conv(out_ab)
out_a = self.softplus(out_a)
out_a = out_a.mean(3).mean(2)
out_b = self.b_conv(out_ab)
out_b = self.tanh(out_b)
out_b = out_b.mean(3).mean(2)

out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
out_a = 0.01 * out_a.view(-1, self.num_frames_to_predict_for, 1, 1)
out_b = 0.01 * out_b.view(-1, self.num_frames_to_predict_for, 1, 1)

axisangle = out[..., :3]
translation = out[..., 3:]
axisangle = out_pose[..., :3]
translation = out_pose[..., 3:]
a = out_a
b = out_b

return axisangle, translation
return axisangle, translation, a, b
12 changes: 11 additions & 1 deletion monodepth2/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,23 @@ def test_simple(args):
input_image = input_image.to(device)
features = encoder(input_image)
outputs = depth_decoder(features)
output_name = os.path.splitext(os.path.basename(image_path))[0]

disp = outputs[("disp", 0)]

dis_sigma = outputs[("disp-sigma", 0)]
disp_sigma_resized = torch.nn.functional.interpolate(
disp, (original_height, original_width), mode="bilinear", align_corners=False)
disp_sigma_im = disp_sigma_resized.detach().cpu().numpy()
disp_sigma_im = pil.fromarray((disp_sigma_im[0][0] * 255.0).astype(np.uint8))
name_dest_im_sigma = os.path.join(output_directory, "{}_disp_sigma.jpeg".format(output_name))
disp_sigma_im.save(name_dest_im_sigma)


disp_resized = torch.nn.functional.interpolate(
disp, (original_height, original_width), mode="bilinear", align_corners=False)

# Saving numpy file
output_name = os.path.splitext(os.path.basename(image_path))[0]
scaled_disp, depth = disp_to_depth(disp, 0.1, 100)
if args.pred_metric_depth:
name_dest_npy = os.path.join(output_directory, "{}_depth.npy".format(output_name))
Expand Down
52 changes: 38 additions & 14 deletions monodepth2/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import torch.optim as optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torchvision.utils import save_image


import json

Expand Down Expand Up @@ -121,6 +123,7 @@ def __init__(self, options):
self.train_loader = DataLoader(
train_dataset, self.opt.batch_size, True, num_workers=self.opt.num_workers, pin_memory=True, drop_last=True
)

val_dataset = self.dataset(
self.opt.data_path, val_filenames, self.opt.height, self.opt.width, self.opt.frame_ids, 4, is_train=False, img_ext=img_ext
)
Expand Down Expand Up @@ -210,11 +213,11 @@ def run_epoch(self):
self.compute_depth_losses(inputs, outputs, losses)

self.log("train", inputs, outputs, losses)
self.val()
self.val(self.epoch * 10 + batch_idx)

self.step += 1

def process_batch(self, inputs):
def process_batch(self, inputs, batch_idx = -1):
"""Pass a minibatch through the network and generate images and losses
"""
for key, ipt in inputs.items():
Expand Down Expand Up @@ -244,7 +247,7 @@ def process_batch(self, inputs):
outputs.update(self.predict_poses(inputs, features))

self.generate_images_pred(inputs, outputs)
losses = self.compute_losses(inputs, outputs)
losses = self.compute_losses(inputs, outputs, batch_idx)

return outputs, losses

Expand Down Expand Up @@ -318,18 +321,18 @@ def predict_poses(self, inputs, features):

return outputs

def val(self):
def val(self, batch_idx):
"""Validate the model on a single minibatch
"""
self.set_eval()
try:
inputs = self.val_iter.next()
inputs = next(self.val_iter)
except StopIteration:
self.val_iter = iter(self.val_loader)
inputs = self.val_iter.next()
inputs = next(self.val_iter)

with torch.no_grad():
outputs, losses = self.process_batch(inputs)
outputs, losses = self.process_batch(inputs, batch_idx)

if "depth_gt" in inputs:
self.compute_depth_losses(inputs, outputs, losses)
Expand All @@ -345,10 +348,13 @@ def generate_images_pred(self, inputs, outputs):
"""
for scale in self.opt.scales:
disp = outputs[("disp", scale)]
sigma = outputs[("disp-sigma",scale)]
if self.opt.v1_multiscale:
source_scale = scale
else:
disp = F.interpolate(disp, [self.opt.height, self.opt.width], mode="bilinear", align_corners=False)
disp_sigma = F.interpolate(sigma, [self.opt.height, self.opt.width], mode="bilinear", align_corners=False)
outputs[("disp-sigma",scale)] = disp_sigma
source_scale = 0

_, depth = disp_to_depth(disp, self.opt.min_depth, self.opt.max_depth)
Expand Down Expand Up @@ -389,21 +395,26 @@ def compute_reprojection_loss(self, pred, target, sigma):
"""Computes reprojection loss between a batch of predicted and target images
"""
abs_diff = torch.abs(target - pred)
l1_loss = abs_diff
l1_loss = abs_diff.mean(1, True)

if self.opt.no_ssim:
reprojection_loss = l1_loss
else:
ssim_loss = (self.ssim(pred, target)) / sigma + torch.log(sigma)
ssim_loss = (self.ssim(pred, target)).mean(1, True)
reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss

reprojection_loss = reprojection_loss / sigma + torch.log(sigma)
# Reference: https://github.com/no-Seaweed/Learning-Deep-Learning-1/blob/master/paper_notes/sfm_learner.md
# transformed_sigma = (10 * sigma + 0.1)

# Exp 1
transformed_sigma = sigma + 1
reprojection_loss = (reprojection_loss / transformed_sigma) + torch.log(transformed_sigma)

reprojection_loss = reprojection_loss.mean(1, True)
# reprojection_loss = (reprojection_loss * sigma)

return reprojection_loss

def compute_losses(self, inputs, outputs):
def compute_losses(self, inputs, outputs, batch_idx=-1):
"""Compute the reprojection and smoothness losses for a minibatch
"""
losses = {}
Expand All @@ -429,6 +440,11 @@ def compute_losses(self, inputs, outputs):
a = outputs[("a", 0, frame_id)].unsqueeze(1)
b = outputs[("b", 0, frame_id)].unsqueeze(1)
target_frame = target * a + b
if batch_idx != -1:
save_image(pred[-1], f'val_images/pred_{self.step}.jpeg')
save_image(target_frame[-1], f'val_images/target_frame_{self.step}.jpeg')
save_image(target[-1], f'val_images/target_{self.step}.jpeg')

reprojection_losses.append(self.compute_reprojection_loss(pred, target_frame, sigma))
ab_losses.append((a - 1) ** 2 + b ** 2)

Expand All @@ -440,7 +456,10 @@ def compute_losses(self, inputs, outputs):
identity_reprojection_losses = []
for frame_id in self.opt.frame_ids[1:]:
pred = inputs[("color", frame_id, source_scale)]
identity_reprojection_losses.append(self.compute_reprojection_loss(pred, target))
a = outputs[("a", 0, frame_id)].unsqueeze(1)
b = outputs[("b", 0, frame_id)].unsqueeze(1)
target_frame = target * a + b
identity_reprojection_losses.append(self.compute_reprojection_loss(pred, target_frame, sigma))

identity_reprojection_losses = torch.cat(identity_reprojection_losses, 1)

Expand Down Expand Up @@ -488,8 +507,13 @@ def compute_losses(self, inputs, outputs):
mean_disp = disp.mean(2, True).mean(3, True)
norm_disp = disp / (mean_disp + 1e-7)
smooth_loss = get_smooth_loss(norm_disp, color)
reg_loss = smooth_loss + self.opt.ab_weight * ab_loss

reg_loss = smooth_loss + self.opt.ab_weight * torch.mean(ab_loss)

# loss += torch.mean((sigma - 1) ** 2)
# categorical_loss = nn.CrossEntropyLoss()
# loss += categorical_loss(sigma, torch.ones_like(sigma))
# print(sigma.min(), sigma.max(), sigma.mean())
loss += self.opt.disparity_smoothness * reg_loss / (2 ** scale)

total_loss += loss
Expand Down