# Source: GIM (https://github.com/xuelunshen/gim)
import torch
import torch.nn as nn
from omegaconf import OmegaConf
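

# Per-batch NLL terms for a soft assignment matrix with "dustbin" bins:
# `log_assignment` has shape (B, M+1, N+1), where the extra row/column score
# unmatched keypoints, and `weights` is a mask of the same shape selecting
# ground-truth matches and unmatched points. The focal-loss `gamma` argument
# is accepted but not used in this version.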
def weight_loss(log_assignment, weights, gamma=0.0):
    b, m, n = log_assignment.shape
    m -= 1
    n -= 1

    loss_sc = log_assignment * weights

    num_neg0 = weights[:, :m, -1].sum(-1).clamp(min=1.0)
    num_neg1 = weights[:, -1, :n].sum(-1).clamp(min=1.0)
    num_pos = weights[:, :m, :n].sum((-1, -2)).clamp(min=1.0)

    nll_pos = -loss_sc[:, :m, :n].sum((-1, -2))
    nll_pos /= num_pos.clamp(min=1.0)

    nll_neg0 = -loss_sc[:, :m, -1].sum(-1)
    nll_neg1 = -loss_sc[:, -1, :n].sum(-1)
    nll_neg = (nll_neg0 + nll_neg1) / (num_neg0 + num_neg1)

    return nll_pos, nll_neg, num_pos, (num_neg0 + num_neg1) / 2.0
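

# NLL loss over a soft assignment matrix with dustbin rows/columns, balancing
# the matched (positive) and unmatched (dustbin) terms via `nll_balancing`.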
class NLLLoss(nn.Module):
    default_conf = {
        "nll_balancing": 0.5,
        "gamma_f": 0.0,  # focal loss
    }

    def __init__(self, conf):
        super().__init__()
        self.conf = OmegaConf.merge(self.default_conf, conf)
        self.loss_fn = self.nll_loss
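
    # `pred["log_assignment"]` is the (B, M+1, N+1) log-assignment matrix;
    # `data` holds the ground truth used to build the per-entry weights when
    # none are passed explicitly.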
    def forward(self, pred, data, weights=None):
        log_assignment = pred["log_assignment"]
        if weights is None:
            weights = self.loss_fn(log_assignment, data)

        nll_pos, nll_neg, num_pos, num_neg = weight_loss(
            log_assignment, weights, gamma=self.conf.gamma_f
        )
        nll = (
            self.conf.nll_balancing * nll_pos + (1 - self.conf.nll_balancing) * nll_neg
        )

        return (
            nll,
            weights,
            {
                "assignment_nll": nll,
                "nll_pos": nll_pos,
                "nll_neg": nll_neg,
                "num_matchable": num_pos,
                "num_unmatchable": num_neg,
            },
        )
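
    # Build a 0/1 weight mask from ground truth: 1 for true matches in the
    # (M, N) block, and 1 in the dustbin row/column for unmatched keypoints.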
    def nll_loss(self, log_assignment, data):
        m, n = data["gt_matches0"].size(-1), data["gt_matches1"].size(-1)
        positive = data["gt_assignment"].float()
        neg0 = (data["gt_matches0"] == -1).float()
        neg1 = (data["gt_matches1"] == -1).float()

        weights = torch.zeros_like(log_assignment)
        weights[:, :m, :n] = positive
        weights[:, :m, -1] = neg0
        weights[:, -1, :n] = neg1
        return weights
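

# Minimal usage sketch (not part of the original file). The key names match
# what `forward` and `nll_loss` index; the batch size, keypoint counts, and
# random tensors below are illustrative assumptions, not consistent ground
# truth from a real dataset.
if __name__ == "__main__":
    B, M, N = 2, 64, 48

    # Fake (B, M+1, N+1) log-assignment with dustbin row/column.
    scores = torch.randn(B, M + 1, N + 1)
    log_assignment = torch.log_softmax(scores.flatten(1), dim=-1).view_as(scores)

    # Fake ground truth: -1 marks an unmatched keypoint.
    gt_matches0 = torch.randint(-1, N, (B, M))
    gt_matches1 = torch.randint(-1, M, (B, N))
    valid = (gt_matches0 >= 0).float().unsqueeze(-1)
    gt_assignment = nn.functional.one_hot(gt_matches0.clamp(min=0), N).float() * valid

    pred = {"log_assignment": log_assignment}
    data = {
        "gt_matches0": gt_matches0,
        "gt_matches1": gt_matches1,
        "gt_assignment": gt_assignment,
    }

    loss, weights, metrics = NLLLoss({})(pred, data)
    print(loss, metrics["num_matchable"], metrics["num_unmatchable"])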