File size: 2,699 Bytes
854f0d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthLoss(nn.Module):
    """L1 loss between a predicted and a ground-truth depth tensor.

    :param type: label stored as ``self.type``.
        NOTE(review): ``forward`` never reads it and always computes an
        L1 loss; the parameter is kept only for interface compatibility.
    """

    def __init__(self, type='l1'):
        super(DepthLoss, self).__init__()
        # Not consulted by forward(); see class docstring.
        self.type = type

    def forward(self, depth_pred, depth_gt, mask=None):
        """Compute the (optionally masked) L1 depth error.

        :param depth_pred: predicted depth tensor.
        :param depth_gt: ground-truth depth tensor. If ANY entry is negative,
            the call short-circuits and returns a zero loss.
        :param mask: optional weights, multiplied (with broadcasting) by
            ``depth_gt > 0`` so that non-positive ground truth is ignored.
            The summed absolute error is divided by the summed weights
            plus 1e-5 (guards against an all-zero mask).
        :return: 0-dim tensor with ``depth_pred``'s dtype and device. Without
            a mask this is the plain mean absolute error over all elements,
            including entries where ``depth_gt == 0``.
        """
        if (depth_gt < 0).any():
            # Negative ground truth disables the depth term for this call.
            # new_zeros keeps the dtype/device consistent with the other
            # branches; the zero carries no gradient.
            return depth_pred.new_zeros(())

        if mask is not None:
            mask = mask * (depth_gt > 0).float()
            mask_sum = mask.sum() + 1e-5
            depth_error = (depth_pred - depth_gt) * mask
            # zeros_like already matches depth_error's device and dtype.
            return F.l1_loss(depth_error, torch.zeros_like(depth_error),
                             reduction='sum') / mask_sum

        depth_error = depth_pred - depth_gt
        return F.l1_loss(depth_error, torch.zeros_like(depth_error),
                         reduction='mean')

def forward(self, depth_pred, depth_gt, mask=None):
    """Module-level L1 depth loss; ``self`` is accepted but never used.

    NOTE(review): this appears to be a leftover copy of ``DepthLoss.forward``
    defined at module scope. Unlike that method, it has no early return for
    negative ground-truth depth. Confirm whether anything imports it.

    :param depth_pred: predicted depth tensor
    :param depth_gt: ground-truth depth tensor; in the masked branch only
        entries > 0 contribute
    :param mask: optional weights, multiplied (with broadcasting) by
        ``depth_gt > 0``
    :return: scalar tensor; masked: summed |error| / (summed weights + 1e-5),
        unmasked: mean |error| over all elements
    """
    if mask is None:
        residual = depth_pred - depth_gt
        return F.l1_loss(residual, torch.zeros_like(residual), reduction='mean')

    weights = mask * (depth_gt > 0).float()
    residual = (depth_pred - depth_gt) * weights
    total = F.l1_loss(residual, torch.zeros_like(residual), reduction='sum')
    return total / (weights.sum() + 1e-5)

class DepthSmoothLoss(nn.Module):
    """Edge-aware first-order smoothness penalty on a disparity map."""

    def __init__(self):
        super().__init__()

    def forward(self, disp, img, mask):
        """
        Edge-aware smoothness loss for a disparity map.

        Absolute differences between horizontally and vertically adjacent
        disparity values are scaled by exp(-|image difference|), so the
        penalty is relaxed where the guidance image changes sharply.

        :param disp: disparity, indexed as [B, 1, H, W]
        :param img: guidance image, indexed as [B, C, H, W]; its absolute
            differences are averaged over dim 1, so any channel count works
        :param mask: [B, 1, H, W]; each difference is weighted by the mask
            value at its left (x term) / top (y term) pixel
        :return: 0-dim tensor, the sum of the x- and y-term means (means are
            taken over all elements, not only the masked-in ones)
        """
        # Forward differences along width (dim 3) and height (dim 2).
        disp_dx = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
        disp_dy = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

        # Channel-averaged image gradients -> weights near 1 in flat regions.
        img_dx = (img[:, :, :, :-1] - img[:, :, :, 1:]).abs().mean(dim=1, keepdim=True)
        img_dy = (img[:, :, :-1, :] - img[:, :, 1:, :]).abs().mean(dim=1, keepdim=True)

        disp_dx.mul_(torch.exp(-img_dx))
        disp_dy.mul_(torch.exp(-img_dy))

        term_x = (disp_dx * mask[:, :, :, :-1]).mean()
        term_y = (disp_dy * mask[:, :, :-1, :]).mean()
        return term_x + term_y