import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthLoss(nn.Module):
    def __init__(self, type='l1'):
        super(DepthLoss, self).__init__()
        self.type = type

    def forward(self, depth_pred, depth_gt, mask=None):
        # Skip the loss entirely if the ground truth contains negative depths.
        if (depth_gt < 0).sum() > 0:
            return torch.tensor(0.0).to(depth_pred.device)
        if mask is not None:
            # Restrict the loss to pixels with valid (positive) ground-truth depth.
            mask_d = (depth_gt > 0).float()
            mask = mask * mask_d
            mask_sum = mask.sum() + 1e-5
            depth_error = (depth_pred - depth_gt) * mask
            depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error),
                                   reduction='sum') / mask_sum
        else:
            depth_error = depth_pred - depth_gt
            depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error),
                                   reduction='mean')
        return depth_loss
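
# Equivalent formulation (illustrative sketch only, not part of the original file):
# with a non-negative mask, the masked branch of DepthLoss above reduces to a
# mean absolute error over valid pixels, normalized by the mask area.
def masked_l1(depth_pred, depth_gt, mask, eps=1e-5):
    valid = mask * (depth_gt > 0).float()
    return (depth_pred - depth_gt).abs().mul(valid).sum() / (valid.sum() + eps)
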
class DepthSmoothLoss(nn.Module):
    def __init__(self):
        super(DepthSmoothLoss, self).__init__()

    def forward(self, disp, img, mask):
        """
        Computes the smoothness loss for a disparity image.
        The color image is used for edge-aware smoothness.
        :param disp: [B, 1, H, W]
        :param img: [B, C, H, W]
        :param mask: [B, 1, H, W]
        :return: scalar smoothness loss
        """
        # First-order disparity gradients along x and y.
        grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
        grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

        # Image gradients, averaged over channels, down-weight disparity
        # smoothness across image edges.
        grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
        grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)

        grad_disp_x *= torch.exp(-grad_img_x)
        grad_disp_y *= torch.exp(-grad_img_y)

        # Average the edge-weighted gradients over the (cropped) mask.
        grad_disp = (grad_disp_x * mask[:, :, :, :-1]).mean() + (grad_disp_y * mask[:, :, :-1, :]).mean()
        return grad_disp
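
# Minimal usage sketch (illustrative only; the tensor shapes, the binary mask
# convention, and the disparity clamp below are assumptions, not taken from
# the original file).
if __name__ == "__main__":
    B, H, W = 2, 64, 64
    depth_pred = torch.rand(B, 1, H, W) + 0.1
    depth_gt = torch.rand(B, 1, H, W) + 0.1
    mask = (torch.rand(B, 1, H, W) > 0.5).float()
    img = torch.rand(B, 3, H, W)

    depth_loss = DepthLoss(type='l1')(depth_pred, depth_gt, mask)
    smooth_loss = DepthSmoothLoss()(1.0 / depth_pred.clamp(min=1e-3), img, mask)
    print(depth_loss.item(), smooth_loss.item())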