|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class DepthLoss(nn.Module):
    """L1 loss between a predicted and a ground-truth depth map.

    Without a mask the loss is the mean absolute error over every element.
    With a mask, the mask is first combined with a validity mask
    (``depth_gt > 0``) and the summed absolute error is divided by the
    number of remaining (mask-weighted) elements.
    """

    def __init__(self, type='l1'):
        """
        :param type: loss type label; it is stored on the module but
            ``forward`` always computes an L1 loss regardless of its value.
        """
        super(DepthLoss, self).__init__()
        self.type = type

    def forward(self, depth_pred, depth_gt, mask=None):
        """
        Computes the (optionally masked) L1 depth loss.

        :param depth_pred: predicted depth tensor
        :param depth_gt: ground-truth depth tensor, broadcastable with
            ``depth_pred``
        :param mask: optional weighting/selection tensor, broadcastable with
            ``depth_gt``; when given, elements whose ground-truth depth is
            not strictly positive are excluded as well
        :return: scalar loss tensor
        """
        depth_error = depth_pred - depth_gt

        if mask is None:
            # Plain mean absolute error over all elements.
            return depth_error.abs().mean()

        # Keep only elements selected by the caller AND having depth_gt > 0.
        valid = mask * (depth_gt > 0).float()

        # The epsilon keeps the division finite when nothing is valid; in
        # that case the numerator is zero, so the loss is zero.
        valid_count = valid.sum() + 1e-5

        return (depth_error * valid).abs().sum() / valid_count
|
|
|
class DepthSmoothLoss(nn.Module):
    """Edge-aware smoothness penalty on a disparity map."""

    def __init__(self):
        super().__init__()

    def forward(self, disp, img, mask):
        """
        Penalises disparity changes between neighbouring pixels, with the
        penalty down-weighted wherever the image itself changes.

        :param disp: [B, 1, H, W] disparity
        :param img: [B, C, H, W] image; its absolute gradients are averaged
            over dim 1 before being used as weights
        :param mask: [B, 1, H, W] per-pixel weight applied to the gradients
        :return: scalar tensor, mean masked horizontal term plus mean masked
            vertical term
        """
        # Image-derived weights: exp(-|image gradient|), channel-averaged.
        img_dx = (img[:, :, :, :-1] - img[:, :, :, 1:]).abs().mean(1, keepdim=True)
        img_dy = (img[:, :, :-1, :] - img[:, :, 1:, :]).abs().mean(1, keepdim=True)
        weight_x = torch.exp(-img_dx)
        weight_y = torch.exp(-img_dy)

        # Absolute disparity differences along width (x) and height (y),
        # scaled in place by the edge-aware weights.
        smooth_x = (disp[:, :, :, :-1] - disp[:, :, :, 1:]).abs()
        smooth_y = (disp[:, :, :-1, :] - disp[:, :, 1:, :]).abs()
        smooth_x.mul_(weight_x)
        smooth_y.mul_(weight_y)

        # The mask is cropped by one column / row to match the gradient maps.
        term_x = (smooth_x * mask[:, :, :, :-1]).mean()
        term_y = (smooth_y * mask[:, :, :-1, :]).mean()
        return term_x + term_y
|
|