File size: 1,713 Bytes
414b431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
import torch.nn.functional as torch_F

from copy import deepcopy
from model.depth.midas_loss import MidasLoss

class Loss(nn.Module):
    """Collection of training losses configured from an options object.

    Settings are read from ``opt.training.shape_loss`` (``impt_thres``,
    ``impt_weight``) and ``opt.training.depth_loss`` (``grad_reg``,
    ``depth_inv``, ``mask_shrink``).
    """

    def __init__(self, opt):
        super().__init__()
        # Private copy so later mutation of the caller's ``opt`` cannot
        # change this module's configuration.
        self.opt = deepcopy(opt)
        # reduction='none' keeps per-element losses so shape_loss can
        # reweight individual entries before averaging.
        self.occ_loss = nn.BCEWithLogitsLoss(reduction='none')
        self.midas_loss = MidasLoss(alpha=opt.training.depth_loss.grad_reg,
                                    inverse_depth=opt.training.depth_loss.depth_inv,
                                    shrink_mask=opt.training.depth_loss.mask_shrink)

    def shape_loss(self, pred_occ_raw, gt_sdf):
        """Weighted binary occupancy loss against a signed-distance target.

        Args:
            pred_occ_raw: 2-D tensor of occupancy logits, ``[B, N]``.
            gt_sdf: 2-D tensor of ground-truth signed distances with the same
                shape. Entries with ``gt_sdf < 0`` get target 1, all others 0.

        Returns:
            Scalar mean of the per-element BCE-with-logits loss. Elements with
            ``|gt_sdf| < opt.training.shape_loss.impt_thres`` are scaled by
            ``opt.training.shape_loss.impt_weight``; the rest have weight 1.
        """
        assert len(pred_occ_raw.shape) == 2
        assert len(gt_sdf.shape) == 2
        # [B, N]
        # Build the target in the logits' dtype: BCEWithLogitsLoss combines
        # input and target with in-place ops, so mismatched dtypes (e.g.
        # float64 logits with a float32 target) can raise.
        gt_occ = (gt_sdf < 0).to(pred_occ_raw.dtype)
        loss = self.occ_loss(pred_occ_raw, gt_occ)
        shape_opt = self.opt.training.shape_loss
        # Compute the "important" mask once and give those elements the
        # configured weight (equivalent to ones * impt_weight at those spots).
        important = torch.abs(gt_sdf) < shape_opt.impt_thres
        weight_mask = torch.ones_like(loss).masked_fill_(important, shape_opt.impt_weight)
        loss = loss * weight_mask
        return loss.mean()

    def depth_loss(self, pred_depth, gt_depth, mask):
        """Depth loss computed by the configured ``MidasLoss``.

        Args:
            pred_depth, gt_depth, mask: 4-D tensors whose second dimension is
                1 (presumably ``[B, 1, H, W]`` -- confirm against callers).

        Returns:
            Whatever ``MidasLoss`` returns for these inputs.
        """
        assert len(pred_depth.shape) == len(gt_depth.shape) == len(mask.shape) == 4
        assert pred_depth.shape[1] == gt_depth.shape[1] == mask.shape[1] == 1
        loss = self.midas_loss(pred_depth, gt_depth, mask)
        return loss

    def intr_loss(self, seen_pred, seen_gt, mask):
        """Masked mean of the squared Euclidean distance over the last axis.

        Args:
            seen_pred: 3-D prediction tensor.
            seen_gt: 3-D target tensor, broadcast-compatible with ``seen_pred``.
            mask: 2-D weight/mask over the first two axes (``[B, HW]``).

        Returns:
            ``sum(distance * mask) / (sum(mask) + 1e-8)`` as a scalar tensor.
        """
        assert len(seen_pred.shape) == len(seen_gt.shape) == 3
        assert len(mask.shape) == 2
        # [B, HW]
        distance = torch.sum((seen_pred - seen_gt)**2, dim=-1)
        # The epsilon keeps the division finite when the mask is all zeros.
        loss = (distance * mask).sum() / (mask.sum() + 1.e-8)
        return loss