Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as torch_F | |
from copy import deepcopy | |
from model.depth.midas_loss import MidasLoss | |
class Loss(nn.Module):
    """Container for the training losses, all driven by one config ``opt``.

    Registered sub-modules (names and order are kept stable):

    * ``occ_loss``   -- element-wise ``BCEWithLogitsLoss`` used by
      :meth:`shape_loss`.
    * ``midas_loss`` -- a ``MidasLoss`` used by :meth:`depth_loss`,
      configured from ``opt.training.depth_loss``.
    """

    def __init__(self, opt):
        super().__init__()
        # Deep copy: later mutation of the caller's config object does not
        # change the values shape_loss reads from self.opt.
        self.opt = deepcopy(opt)
        self.occ_loss = nn.BCEWithLogitsLoss(reduction='none')
        depth_cfg = opt.training.depth_loss
        self.midas_loss = MidasLoss(
            alpha=depth_cfg.grad_reg,
            inverse_depth=depth_cfg.depth_inv,
            shrink_mask=depth_cfg.mask_shrink,
        )

    def shape_loss(self, pred_occ_raw, gt_sdf):
        """Importance-weighted occupancy BCE loss.

        Args:
            pred_occ_raw: 2-D tensor of occupancy logits (annotated [B, N]
                by the original author).
            gt_sdf: 2-D tensor of ground-truth signed distances.

        Entries with ``gt_sdf < 0`` form the positive class. Per-entry losses
        where ``|gt_sdf| < opt.training.shape_loss.impt_thres`` are multiplied
        by ``opt.training.shape_loss.impt_weight``; all others keep weight 1.

        Returns:
            Scalar tensor: mean of the weighted per-entry losses.

        Raises:
            AssertionError: if either input is not 2-D.
        """
        assert pred_occ_raw.ndim == 2
        assert gt_sdf.ndim == 2
        # [B, N]
        target = gt_sdf.lt(0).float()
        per_entry = self.occ_loss(pred_occ_raw, target)
        cfg = self.opt.training.shape_loss
        # Build the |sdf| < threshold mask once and reuse it for read and write.
        near_surface = gt_sdf.abs() < cfg.impt_thres
        weights = torch.ones_like(per_entry)
        weights[near_surface] = weights[near_surface] * cfg.impt_weight
        return (per_entry * weights).mean()

    def depth_loss(self, pred_depth, gt_depth, mask):
        """Depth loss, delegated to the ``MidasLoss`` built in ``__init__``.

        All three inputs must be 4-D with size 1 in dim 1 (presumably
        [B, 1, H, W] -- TODO confirm against callers).

        Returns:
            Whatever ``self.midas_loss(pred_depth, gt_depth, mask)`` returns.

        Raises:
            AssertionError: on a rank or channel-count mismatch.
        """
        assert (pred_depth.ndim, gt_depth.ndim, mask.ndim) == (4, 4, 4)
        assert (pred_depth.shape[1], gt_depth.shape[1], mask.shape[1]) == (1, 1, 1)
        return self.midas_loss(pred_depth, gt_depth, mask)

    def intr_loss(self, seen_pred, seen_gt, mask):
        """Mask-weighted mean of squared Euclidean distances.

        Args:
            seen_pred: 3-D tensor of predictions.
            seen_gt: 3-D tensor of targets.
            mask: 2-D weights for the per-element distances (annotated
                [B, HW] by the original author).

        Returns:
            ``sum(mask * ||seen_pred - seen_gt||^2) / (sum(mask) + 1e-8)``,
            where the squared norm is taken over the last dimension. The
            epsilon keeps an all-zero mask from dividing by zero.

        Raises:
            AssertionError: if the input ranks are not 3, 3 and 2.
        """
        assert seen_pred.ndim == seen_gt.ndim == 3
        assert mask.ndim == 2
        # [B, HW]
        sq_dist = (seen_pred - seen_gt).pow(2).sum(dim=-1)
        weighted_total = (sq_dist * mask).sum()
        return weighted_total / (mask.sum() + 1.e-8)