|
import torch |
|
import torch.nn as nn |
|
from loss.ncc import NCC |
|
|
|
|
|
class Normalize(nn.Module):
    """Rescale a tensor so that its slices along dim 1 have (near) unit L2 norm."""

    def __init__(self):
        super().__init__()

    def forward(self, bottom):
        """Divide ``bottom`` by its L2 norm computed over dim 1.

        A constant of 1e-12 is added to the norm before dividing, so an
        all-zero slice yields zeros instead of a division by zero.

        :param bottom: tensor with at least two dimensions
        :return: tensor with the same shape as ``bottom``
        """
        # keepdim=True keeps dim 1 as size 1 so the division broadcasts.
        l2 = torch.norm(bottom, p=2, dim=1, keepdim=True)
        return bottom / (l2 + 1e-12)
|
|
|
|
|
class OcclusionColorLoss(nn.Module):
    """Occlusion-aware L1 colour loss.

    The loss is the sum of three terms:

    * ``alpha * mean(weight[mask] * error)`` -- the weighted colour error,
    * ``beta  * mean(log(1 - weight + eps))`` -- coefficient ``beta`` is the
      value returned by ``self.adjuster`` on every call,
    * ``gama  * mean(log(weight + eps))``.

    :param alpha: coefficient of the weighted colour error term
    :param beta: initial coefficient handed to the ``ParamAdjuster``
    :param gama: coefficient of the ``log(weight)`` term
    :param occlusion_aware: if False, ``forward`` returns the plain mean error
    :param weight_thred: list of thresholds handed to the ``ParamAdjuster``;
        ``None`` (the default) means ``[0.6]``
    """

    def __init__(self, alpha=1, beta=0.025, gama=0.01, occlusion_aware=True, weight_thred=None):
        super(OcclusionColorLoss, self).__init__()
        # A None sentinel replaces a mutable ``[0.6]`` default so that
        # instances never share (and cannot accidentally mutate) one list.
        if weight_thred is None:
            weight_thred = [0.6]

        self.alpha = alpha
        self.beta = beta
        self.gama = gama
        self.occlusion_aware = occlusion_aware
        self.eps = 1e-4  # keeps both log() arguments strictly positive

        self.weight_thred = weight_thred
        self.adjuster = ParamAdjuster(self.weight_thred, self.beta)

    def forward(self, pred, gt, weight, mask, detach=False, occlusion_aware=True):
        """Compute the loss.

        :param pred: [N_pts, 3]
        :param gt: [N_pts, 3]
        :param weight: [N_pts]
        :param mask: [N_pts]
        :param detach: if True, no gradient flows into ``weight``
        :param occlusion_aware: per-call switch; the occlusion-aware form is
            used only when this and ``self.occlusion_aware`` are both true
        :return: ``(total, term1)`` in the occlusion-aware case, otherwise
            ``(mean_error, mean_error)``
        """
        if detach:
            weight = weight.detach()

        # Per-point L1 error summed over the last dimension, masked points only.
        error = torch.abs(pred - gt).sum(dim=-1, keepdim=False)
        error = error[mask]

        if not (self.occlusion_aware and occlusion_aware):
            plain = torch.mean(error)
            return plain, plain

        # The adjuster sees the mean of the un-clamped weights.
        beta = self.adjuster(weight.mean())

        weight = weight.clamp(0.0, 1.0)
        term1 = self.alpha * torch.mean(weight[mask] * error)
        # The two regularisers average over all weights, not only masked ones.
        term2 = beta * torch.log(1 - weight + self.eps).mean()
        term3 = self.gama * torch.log(weight + self.eps).mean()

        return term1 + term2 + term3, term1
|
|
|
|
|
class OcclusionColorPatchLoss(nn.Module):
    """Occlusion-aware patch colour loss with a selectable error metric.

    The per-point patch error is one of ``'l1'``, ``'ssd'`` or ``'ncc'``
    (``1 - NCC``). The loss combines it with two log regularisers on the
    weights:

    * ``alpha * mean(weight[mask] * error)``,
    * ``beta  * mean(log(1 - weight + eps))`` -- coefficient ``beta`` is the
      value returned by ``self.adjuster`` on every call,
    * ``gama  * mean(log(weight + eps))``.

    :param alpha: coefficient of the weighted error term
    :param beta: initial coefficient handed to the ``ParamAdjuster``
    :param gama: coefficient of the ``log(weight)`` term
    :param occlusion_aware: if False, ``forward`` returns the plain mean error
    :param type: error metric, one of ``'l1'``, ``'ncc'``, ``'ssd'``
    :param h_patch_size: forwarded to ``NCC(h_patch_size=...)``
    :param weight_thred: list of thresholds handed to the ``ParamAdjuster``;
        ``None`` (the default) means ``[0.6]``
    """

    def __init__(self, alpha=1, beta=0.025, gama=0.015,
                 occlusion_aware=True, type='l1', h_patch_size=3, weight_thred=None):
        super(OcclusionColorPatchLoss, self).__init__()
        # A None sentinel replaces a mutable ``[0.6]`` default so that
        # instances never share (and cannot accidentally mutate) one list.
        if weight_thred is None:
            weight_thred = [0.6]

        self.alpha = alpha
        self.beta = beta
        self.gama = gama
        self.occlusion_aware = occlusion_aware
        # NOTE(review): this instance attribute shadows ``nn.Module.type()``;
        # the name is kept because external code may read ``.type``.
        self.type = type
        self.ncc = NCC(h_patch_size=h_patch_size)
        self.eps = 1e-4  # keeps both log() arguments strictly positive
        self.weight_thred = weight_thred

        self.adjuster = ParamAdjuster(self.weight_thred, self.beta)

        print("type {} patch_size {} beta {} gama {} weight_thred {}".format(type, h_patch_size, beta, gama,
                                                                             weight_thred))

    def forward(self, pred, gt, weight, mask, penalize_ratio=0.9, detach=False, occlusion_aware=True):
        """Compute the loss.

        :param pred: [N_pts, Npx, 3]
        :param gt: [N_pts, Npx, 3]
        :param weight: [N_pts]
        :param mask: [N_pts]
        :param penalize_ratio: only used for ``'ncc'``; fraction of points
            (those with the smallest error) that stay in the mask
        :param detach: if True, no gradient flows into ``weight``
        :param occlusion_aware: per-call switch; the occlusion-aware form is
            used only when this and ``self.occlusion_aware`` are both true
        :return: ``(total, term1, beta)`` in the occlusion-aware case,
            otherwise ``(mean_error, mean_error, 0.)``
        :raises ValueError: if ``self.type`` is not a supported metric
        """
        if detach:
            weight = weight.detach()

        if self.type == 'l1':
            error = torch.abs(pred - gt).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False)
        elif self.type == 'ncc':
            error = 1 - self.ncc(pred[:, None, :, :], gt)[:, 0]
            # Sort ascending and drop the largest-error tail from the mask.
            error, indices = torch.sort(error)
            # index_select returns a new tensor, so the caller's mask is untouched.
            mask = torch.index_select(mask, 0, index=indices)
            mask[int(penalize_ratio * mask.shape[0]):] = False
            # Permute the weights the same way; otherwise weight[mask] would
            # pair each sorted error with the weight of a different point.
            weight = torch.index_select(weight, 0, index=indices)
        elif self.type == 'ssd':
            error = ((pred - gt) ** 2).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False)
        else:
            raise ValueError(
                "unsupported loss type {!r}; expected 'l1', 'ncc' or 'ssd'".format(self.type))

        error = error[mask]
        if not (self.occlusion_aware and occlusion_aware):
            plain = torch.mean(error)
            return plain, plain, 0.

        # The adjuster sees the mean of the un-clamped weights.
        beta = self.adjuster(weight.mean())

        weight = weight.clamp(0.0, 1.0)

        term1 = self.alpha * torch.mean(weight[mask] * error)
        # The two regularisers average over all weights, not only masked ones.
        term2 = beta * torch.log(1 - weight + self.eps).mean()
        term3 = self.gama * torch.log(weight + self.eps).mean()

        return term1 + term2 + term3, term1, beta
|
|
|
|
|
class ParamAdjuster(nn.Module):
    """Step-wise schedule that raises a scalar when the mean weight stays low.

    Each ``forward`` call advances a step counter. While thresholds remain,
    calls whose ``weight_mean`` is below the current threshold are counted.
    Every ``statis_window`` calls the count is compared against the window
    length: if it exceeds 30 %, ``param`` grows by 0.005 and the next
    threshold in ``weight_thred`` becomes active. The count then restarts.

    :param weight_thred: sequence of thresholds, consumed in order
    :param param: initial value of the adjusted scalar
    """

    def __init__(self, weight_thred, param):
        super().__init__()
        self.weight_thred = weight_thred
        self.thred_num = len(weight_thred)
        self.param = param

        self.statis_window = 100  # number of forward calls per window
        self.global_step = 0      # total number of forward calls so far
        self.counter = 0          # below-threshold calls in the current window
        self.thred_idx = 0        # index of the active threshold

        self.adjusted = False
        self.adjusted_step = 0    # global_step at the most recent adjustment

    def reset(self):
        """Start a fresh window: clear the counter and the ``adjusted`` flag."""
        self.counter, self.adjusted = 0, False

    def adjust(self):
        """Raise ``param`` by 0.005 if over 30 % of the window was below threshold."""
        below_fraction = self.counter / self.statis_window
        if below_fraction <= 0.3:
            return

        self.param = self.param + 0.005
        self.adjusted = True
        self.adjusted_step = self.global_step
        self.thred_idx += 1
        print("adjusted param, now {}".format(self.param))

    def forward(self, weight_mean):
        """Record one observation and return the current ``param``.

        :param weight_mean: scalar compared against the active threshold
        :return: the (possibly just adjusted) ``param``
        """
        self.global_step += 1

        window_closed = self.global_step % self.statis_window == 0
        if window_closed and not self.adjusted:
            # reset() runs right after adjust(), so ``adjusted`` is already
            # False again by the time the next call checks it.
            self.adjust()
            self.reset()

        # The observation of a window-closing call is tallied after the
        # reset, i.e. it counts towards the next window.
        if (self.thred_idx < self.thred_num
                and weight_mean < self.weight_thred[self.thred_idx]
                and not self.adjusted):
            self.counter += 1

        return self.param
|
|