# SparseNeuS_demo_v1/loss/color_loss.py
import torch
import torch.nn as nn
from loss.ncc import NCC


class Normalize(nn.Module):
    """Row-wise L2 normalization: rescales each row of the input to unit norm."""

    def __init__(self):
        super(Normalize, self).__init__()

    def forward(self, bottom):
        # Per-row L2 norm; the small epsilon guards against division by zero.
        qn = torch.norm(bottom, p=2, dim=1).unsqueeze(dim=1) + 1e-12
        top = bottom.div(qn)
        return top
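

# Usage sketch (illustration only, shapes assumed):
#   feats = torch.randn(8, 3)
#   unit = Normalize()(feats)   # each row now has (approximately) unit L2 norm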


class OcclusionColorLoss(nn.Module):
    """Occlusion-aware per-point color loss.

    Combines an L1 color error weighted by a per-point visibility weight with
    two logarithmic regularizers that pull the weights in opposite directions:
    log(1 - weight) favors treating points as visible, log(weight) leaves room
    to mark points as occluded.
    """

    def __init__(self, alpha=1, beta=0.025, gama=0.01, occlusion_aware=True, weight_thred=[0.6]):
        super(OcclusionColorLoss, self).__init__()
        self.alpha = alpha  # scale of the weighted color term
        self.beta = beta    # scale of the log(1 - weight) term; adjusted online
        self.gama = gama    # scale of the log(weight) term ("gamma")
        self.occlusion_aware = occlusion_aware
        self.eps = 1e-4
        self.weight_thred = weight_thred
        self.adjuster = ParamAdjuster(self.weight_thred, self.beta)

    def forward(self, pred, gt, weight, mask, detach=False, occlusion_aware=True):
        """
        :param pred: [N_pts, 3] predicted colors
        :param gt: [N_pts, 3] ground-truth colors
        :param weight: [N_pts] per-point visibility weights in [0, 1]
        :param mask: [N_pts] boolean validity mask
        :return: (total loss, weighted color term)
        """
        if detach:
            weight = weight.detach()

        error = torch.abs(pred - gt).sum(dim=-1, keepdim=False)  # [N_pts] L1 color error
        error = error[mask]

        if not (self.occlusion_aware and occlusion_aware):
            return torch.mean(error), torch.mean(error)

        beta = self.adjuster(weight.mean())

        # weight = weight[mask]
        weight = weight.clamp(0.0, 1.0)

        # term1: color error down-weighted by the visibility weight
        term1 = self.alpha * torch.mean(weight[mask] * error)
        # term2: log(1 - weight) decreases as weights grow, so it favors visible points
        term2 = beta * torch.log(1 - weight + self.eps).mean()
        # term3: log(weight) decreases as weights shrink, so it permits occluded points
        term3 = self.gama * torch.log(weight + self.eps).mean()

        return term1 + term2 + term3, term1
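

# Usage sketch (illustration only; random data, all points assumed valid):
#   loss_fn = OcclusionColorLoss()
#   pred, gt = torch.rand(1024, 3), torch.rand(1024, 3)
#   weight = torch.rand(1024)
#   mask = torch.ones(1024, dtype=torch.bool)
#   total, color_term = loss_fn(pred, gt, weight, mask)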


class OcclusionColorPatchLoss(nn.Module):
    """Patch-based variant of OcclusionColorLoss.

    Compares pixel patches instead of single pixels, using an L1, NCC, or SSD
    patch error combined with the same logarithmic regularizers on the
    per-point visibility weights.
    """

    def __init__(self, alpha=1, beta=0.025, gama=0.015,
                 occlusion_aware=True, type='l1', h_patch_size=3, weight_thred=[0.6]):
        super(OcclusionColorPatchLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gama = gama
        self.occlusion_aware = occlusion_aware
        self.type = type  # 'l1', 'ncc' or 'ssd' patch error
        self.ncc = NCC(h_patch_size=h_patch_size)
        self.eps = 1e-4
        self.weight_thred = weight_thred
        self.adjuster = ParamAdjuster(self.weight_thred, self.beta)

        print("type {} patch_size {} beta {} gama {} weight_thred {}".format(
            type, h_patch_size, beta, gama, weight_thred))

    def forward(self, pred, gt, weight, mask, penalize_ratio=0.9, detach=False, occlusion_aware=True):
        """
        :param pred: [N_pts, Npx, 3] predicted patch colors
        :param gt: [N_pts, Npx, 3] ground-truth patch colors
        :param weight: [N_pts] per-point visibility weights in [0, 1]
        :param mask: [N_pts] boolean validity mask
        :return: (total loss, weighted patch term, beta)
        """
        if detach:
            weight = weight.detach()

        if self.type == 'l1':
            error = torch.abs(pred - gt).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False)  # [N_pts]
        elif self.type == 'ncc':
            error = 1 - self.ncc(pred[:, None, :, :], gt)[:, 0]  # ncc: 1 positive, -1 negative
            # Keep only the best-matching patches and drop the worst
            # (1 - penalize_ratio) fraction; this can help at boundaries.
            error, indices = torch.sort(error)
            mask = torch.index_select(mask, 0, index=indices)
            # Permute weight the same way so weight[mask] below stays aligned
            # with the sorted error.
            weight = torch.index_select(weight, 0, index=indices)
            mask[int(penalize_ratio * mask.shape[0]):] = False
        elif self.type == 'ssd':
            error = ((pred - gt) ** 2).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False)

        error = error[mask]

        if not (self.occlusion_aware and occlusion_aware):
            return torch.mean(error), torch.mean(error), 0.

        # * weight adjuster
        beta = self.adjuster(weight.mean())

        # weight = weight[mask]
        weight = weight.clamp(0.0, 1.0)

        term1 = self.alpha * torch.mean(weight[mask] * error)
        term2 = beta * torch.log(1 - weight + self.eps).mean()
        term3 = self.gama * torch.log(weight + self.eps).mean()

        return term1 + term2 + term3, term1, beta
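

# Usage sketch (illustration only; Npx assumed to be (2 * h_patch_size + 1) ** 2):
#   patch_loss = OcclusionColorPatchLoss(type='ncc', h_patch_size=3)
#   pred, gt = torch.rand(512, 49, 3), torch.rand(512, 49, 3)
#   weight = torch.rand(512)
#   mask = torch.ones(512, dtype=torch.bool)
#   total, patch_term, beta = patch_loss(pred, gt, weight, mask)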


class ParamAdjuster(nn.Module):
    """Online scheduler for the beta coefficient.

    Every `statis_window` steps it checks how often the mean visibility weight
    fell below the current threshold; if that happened in more than 30% of the
    window, the parameter is increased by 0.005 and the next threshold in
    `weight_thred` (if any) becomes active.
    """

    def __init__(self, weight_thred, param):
        super(ParamAdjuster, self).__init__()
        self.weight_thred = weight_thred
        self.thred_num = len(weight_thred)
        self.param = param
        self.global_step = 0
        self.statis_window = 100  # statistics window, in steps
        self.counter = 0  # steps in the current window with weight_mean below the threshold
        self.adjusted = False
        self.adjusted_step = 0
        self.thred_idx = 0

    def reset(self):
        self.counter = 0
        self.adjusted = False

    def adjust(self):
        # Strengthen the regularizer if the mean weight stayed below the
        # threshold for more than 30% of the window.
        if (self.counter / self.statis_window) > 0.3:
            self.param = self.param + 0.005
            self.adjusted = True
            self.adjusted_step = self.global_step
            self.thred_idx += 1  # advance to the next threshold, if any
            print("adjusted param, now {}".format(self.param))

    def forward(self, weight_mean):
        self.global_step += 1

        if (self.global_step % self.statis_window == 0) and self.adjusted is False:
            self.adjust()
            self.reset()

        if self.thred_idx < self.thred_num:
            if weight_mean < self.weight_thred[self.thred_idx] and (not self.adjusted):
                self.counter += 1

        return self.param
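

# Minimal smoke test (illustration only, random data; not part of the training
# pipeline). Run with `python -m loss.color_loss` from the SparseNeuS_demo_v1
# directory, assuming `loss/ncc.py` is importable.
if __name__ == "__main__":
    torch.manual_seed(0)

    n_pts = 256
    pred = torch.rand(n_pts, 3)
    gt = torch.rand(n_pts, 3)
    weight = torch.rand(n_pts)
    mask = torch.ones(n_pts, dtype=torch.bool)

    color_loss = OcclusionColorLoss()
    total, color_term = color_loss(pred, gt, weight, mask)
    print("color loss: total {:.4f} term1 {:.4f}".format(total.item(), color_term.item()))

    # Exercise the 'l1' branch so the test does not depend on NCC internals;
    # npx = (2 * h_patch_size + 1) ** 2 = 49 for the default h_patch_size = 3.
    npx = 49
    patch_pred = torch.rand(n_pts, npx, 3)
    patch_gt = torch.rand(n_pts, npx, 3)
    patch_loss = OcclusionColorPatchLoss(type='l1')
    total, patch_term, beta = patch_loss(patch_pred, patch_gt, weight, mask)
    print("patch loss: total {:.4f} term1 {:.4f} beta {}".format(total.item(), patch_term.item(), beta))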