# ZeroShape / utils/loss.py
# zxhuang1698 - initial commit (414b431)
import torch
import torch.nn as nn
import torch.nn.functional as torch_F
from copy import deepcopy
from model.depth.midas_loss import MidasLoss
class Loss(nn.Module):
    """Training losses: occupancy (shape), depth, and masked point regression.

    Hyper-parameters are read from ``opt.training``. A deep copy of ``opt`` is
    stored on the module, so mutating the caller's config afterwards does not
    change the threshold/weight used by :meth:`shape_loss`.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = deepcopy(opt)
        # reduction='none' keeps the per-point BCE so shape_loss can re-weight
        # individual points before averaging.
        self.occ_loss = nn.BCEWithLogitsLoss(reduction='none')
        # Depth loss is delegated to MidasLoss (model.depth.midas_loss); its
        # options come from opt.training.depth_loss.
        self.midas_loss = MidasLoss(alpha=opt.training.depth_loss.grad_reg,
                                    inverse_depth=opt.training.depth_loss.depth_inv,
                                    shrink_mask=opt.training.depth_loss.mask_shrink)

    def shape_loss(self, pred_occ_raw, gt_sdf):
        """Near-surface-weighted binary occupancy loss.

        Args:
            pred_occ_raw: [B, N] occupancy logits (pre-sigmoid; fed to
                ``BCEWithLogitsLoss``).
            gt_sdf: [B, N] ground-truth signed distances. A point is treated
                as occupied when ``gt_sdf < 0``.

        Returns:
            Scalar tensor: the mean over all B*N points of the per-point BCE,
            where points with ``|gt_sdf| < opt.training.shape_loss.impt_thres``
            are multiplied by ``opt.training.shape_loss.impt_weight``. Note the
            mean divides by B*N, not by the sum of the weights.
        """
        assert len(pred_occ_raw.shape) == 2
        assert len(gt_sdf.shape) == 2
        cfg = self.opt.training.shape_loss
        gt_occ = (gt_sdf < 0).float()
        loss = self.occ_loss(pred_occ_raw, gt_occ)  # [B, N]
        # Evaluate the near-surface mask once and up-weight those points
        # out of place; all other points keep weight 1.
        near_surface = torch.abs(gt_sdf) < cfg.impt_thres
        loss = torch.where(near_surface, loss * cfg.impt_weight, loss)
        return loss.mean()

    def depth_loss(self, pred_depth, gt_depth, mask):
        """Depth loss computed by ``self.midas_loss``.

        Args:
            pred_depth: predicted depth, 4-D with a single channel in dim 1
                (e.g. [B, 1, H, W]).
            gt_depth: ground-truth depth, same layout as ``pred_depth``.
            mask: validity mask, same layout as ``pred_depth``.

        Returns:
            Whatever ``MidasLoss`` returns for these inputs (presumably a
            scalar tensor; confirm in model.depth.midas_loss).
        """
        assert len(pred_depth.shape) == len(gt_depth.shape) == len(mask.shape) == 4
        assert pred_depth.shape[1] == gt_depth.shape[1] == mask.shape[1] == 1
        return self.midas_loss(pred_depth, gt_depth, mask)

    def intr_loss(self, seen_pred, seen_gt, mask):
        """Masked mean of squared Euclidean distances between two point sets.

        Args:
            seen_pred: [B, HW, C] predicted values.
            seen_gt: [B, HW, C] targets. The squared difference is summed over
                the last dimension C.
            mask: [B, HW] per-point weights (bool or float, e.g. 0/1 validity).

        Returns:
            Scalar tensor ``sum(mask * ||pred - gt||^2) / (sum(mask) + 1e-8)``.
            The epsilon keeps the result finite (zero) for an all-zero mask.
        """
        assert len(seen_pred.shape) == len(seen_gt.shape) == 3
        assert len(mask.shape) == 2
        sq_dist = torch.sum((seen_pred - seen_gt) ** 2, dim=-1)  # [B, HW]
        return (sq_dist * mask).sum() / (mask.sum() + 1.e-8)