# Author: Bingxin Ke
# Last modified: 2024-02-22
import torch
def get_loss(loss_name, **kwargs):
    """Factory: build a loss criterion by name.

    Args:
        loss_name (str): one of "silog_mse", "silog_rmse", "mse_loss",
            "l1_loss", "l1_loss_with_mask", "mean_abs_rel".
        **kwargs: forwarded to the chosen criterion's constructor
            (ignored for "mean_abs_rel", which takes no arguments).

    Returns:
        A callable criterion object.

    Raises:
        NotImplementedError: if `loss_name` is not recognized.
    """
    if "silog_mse" == loss_name:
        criterion = SILogMSELoss(**kwargs)
    elif "silog_rmse" == loss_name:
        criterion = SILogRMSELoss(**kwargs)
    elif "mse_loss" == loss_name:
        criterion = torch.nn.MSELoss(**kwargs)
    elif "l1_loss" == loss_name:
        criterion = torch.nn.L1Loss(**kwargs)
    elif "l1_loss_with_mask" == loss_name:
        criterion = L1LossWithMask(**kwargs)
    elif "mean_abs_rel" == loss_name:
        criterion = MeanAbsRelLoss()
    else:
        # Include the offending name so a config typo is easy to diagnose.
        raise NotImplementedError(f"Unknown loss name: {loss_name}")
    return criterion
class L1LossWithMask:
    """L1 (mean absolute error) loss with an optional validity mask.

    Invalid pixels (where `valid_mask` is False) are zeroed out of the
    difference and excluded from the per-image pixel count.
    """

    def __init__(self, batch_reduction=False):
        # If True, average the per-image losses into a scalar.
        self.batch_reduction = batch_reduction

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        diff = depth_pred - depth_gt
        if valid_mask is not None:
            diff[~valid_mask] = 0
            # Per-image count of valid pixels.
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]
        # Bug fix: sum over the spatial dims only. The original summed over
        # ALL dims (batch included) while dividing by a per-image `n`,
        # mixing samples across the batch; siblings here reduce over (-1, -2).
        loss = torch.sum(torch.abs(diff), (-1, -2)) / n
        if self.batch_reduction:
            loss = loss.mean()
        return loss
class MeanAbsRelLoss:
    """Absolute relative error |pred - gt| / |gt|, averaged over dim 0 only.

    NOTE: the reduction is `mean(dim=0)` (the batch axis), so the result
    keeps the remaining per-pixel shape — it is not a scalar. `gt` must be
    nonzero wherever evaluated, or the ratio is inf/nan.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, pred, gt):
        # |(pred - gt) / gt| == |pred - gt| / |gt|
        relative_error = ((pred - gt) / gt).abs()
        return relative_error.mean(dim=0)
class SILogMSELoss:
    """Scale-invariant log MSE loss.

    Computes, per image, mean(d^2) - lamb * mean(d)^2 where
    d = log(pred) - log(gt), reduced over the last two (spatial) dims.

    Args:
        lamb: lambda weight; 1 -> fully scale invariant, 0 -> plain MSE
            in log space.
        log_pred (bool, optional): if True, `depth_pred` is already
            log-depth and is not log-transformed again.
        batch_reduction (bool, optional): if True, average the per-image
            losses into a scalar.
    """

    def __init__(self, lamb, log_pred=True, batch_reduction=True):
        super().__init__()
        self.lamb = lamb
        self.pred_in_log = log_pred
        self.batch_reduction = batch_reduction

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        if self.pred_in_log:
            log_pred = depth_pred
        else:
            # Clamp away from zero so log() never sees a non-positive value.
            log_pred = torch.log(torch.clip(depth_pred, 1e-8))
        delta = log_pred - torch.log(depth_gt)
        if valid_mask is None:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]
        else:
            # Zero invalid pixels and count only the valid ones per image.
            delta[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        mean_sq = torch.sum(delta ** 2, (-1, -2)) / n
        sq_of_mean = self.lamb * torch.sum(delta, (-1, -2)) ** 2 / (n ** 2)
        loss = mean_sq - sq_of_mean
        return loss.mean() if self.batch_reduction else loss
class SILogRMSELoss:
    def __init__(self, lamb, alpha, log_pred=True):
        """Scale Invariant Log RMSE Loss

        Computes alpha * sqrt(mean(d^2) - lamb * mean(d)^2), averaged over
        the batch, where d = log(pred) - log(gt) over the spatial dims.

        Args:
            lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> RMSE in log space
            alpha: scalar weight multiplying the final loss value
            log_pred (bool, optional): True if model prediction is logarithmic depth. Will not do log for depth_pred
        """
        super().__init__()
        self.lamb = lamb
        self.alpha = alpha
        self.pred_in_log = log_pred

    # Consistency fix: `valid_mask` now defaults to None like the sibling
    # losses — the body already handled the None case.
    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        log_depth_pred = depth_pred if self.pred_in_log else torch.log(depth_pred)
        log_depth_gt = torch.log(depth_gt)
        # borrowed from https://github.com/aliyun/NeWCRFs
        # diff = log_depth_pred[valid_mask] - log_depth_gt[valid_mask]
        # return torch.sqrt((diff ** 2).mean() - self.lamb * (diff.mean() ** 2)) * self.alpha
        diff = log_depth_pred - log_depth_gt
        if valid_mask is not None:
            # Zero invalid pixels; count only valid ones per image.
            diff[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]
        diff2 = torch.pow(diff, 2)
        first_term = torch.sum(diff2, (-1, -2)) / n
        second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
        loss = torch.sqrt(first_term - second_term).mean() * self.alpha
        return loss