# Source: "desco" repository by zdou0830, commit 749745d (2.82 kB).
import torch
from torch import nn
class IOULoss(nn.Module):
    """IoU-family regression loss for boxes encoded as four side distances.

    Columns 0-3 of ``pred`` and ``target`` are read as left, top, right and
    bottom. Box areas are formed as ``(left + right) * (top + bottom)``, so the
    four values act as distances from a shared reference point to each side.

    Args:
        loss_type: ``"iou"`` (``-log(IoU)``), ``"linear_iou"`` (``1 - IoU``)
            or ``"giou"`` (``1 - GIoU``). The name is only checked when
            ``forward`` runs.
    """

    def __init__(self, loss_type="iou"):
        super(IOULoss, self).__init__()
        self.loss_type = loss_type

    def forward(self, pred, target, weight=None):
        """Return the summed, optionally weighted, IoU-family loss.

        Args:
            pred: predicted side distances; columns 0-3 are used.
            target: ground-truth side distances, laid out like ``pred``.
            weight: optional per-row weights. They are applied only when given
                and their sum is positive; otherwise the plain sum is returned.

        Returns:
            A scalar tensor.

        Raises:
            NotImplementedError: ``self.loss_type`` is not a supported name.
            AssertionError: the unweighted path receives zero elements.
        """
        p_left, p_top = pred[:, 0], pred[:, 1]
        p_right, p_bottom = pred[:, 2], pred[:, 3]
        t_left, t_top = target[:, 0], target[:, 1]
        t_right, t_bottom = target[:, 2], target[:, 3]

        area_t = (t_left + t_right) * (t_top + t_bottom)
        area_p = (p_left + p_right) * (p_top + p_bottom)

        # Per-axis overlap (min of each side) and enclosing extent (max).
        inter_w = torch.min(p_left, t_left) + torch.min(p_right, t_right)
        hull_w = torch.max(p_left, t_left) + torch.max(p_right, t_right)
        inter_h = torch.min(p_bottom, t_bottom) + torch.min(p_top, t_top)
        hull_h = torch.max(p_bottom, t_bottom) + torch.max(p_top, t_top)

        # Enclosing-box area. The epsilon keeps the GIoU division away from 0.
        hull = hull_w * hull_h + 1e-7
        inter = inter_w * inter_h
        union = area_t + area_p - inter

        # +1 smoothing avoids 0/0 when both boxes have zero area.
        iou = (inter + 1.0) / (union + 1.0)
        giou = iou - (hull - union) / hull

        if self.loss_type == "iou":
            losses = -torch.log(iou)
        elif self.loss_type == "linear_iou":
            losses = 1 - iou
        elif self.loss_type == "giou":
            losses = 1 - giou
        else:
            raise NotImplementedError

        # Fall back to the plain sum when no usable weights were supplied.
        if weight is None or not (weight.sum() > 0):
            assert losses.numel() != 0
            return losses.sum()
        return (losses * weight).sum()
class IOUWHLoss(nn.Module):  # used for anchor guiding
    """Shape-only IoU loss, ``1 - IoU ** 2``, between predicted and target boxes.

    Inputs are flattened to rows of four values. The overlap arithmetic reads
    columns 0-1 as a box centre and columns 2-3 as its extents. Target
    centres are replaced by zeros and reused for the prediction, so only the
    extents (columns 2-3) influence the loss.

    Args:
        reduction: ``"mean"`` or ``"sum"`` reduces the per-row losses. Any
            other value (default ``"none"``) returns them unreduced.
    """

    def __init__(self, reduction="none"):
        super(IOUWHLoss, self).__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        """Compute the shape-IoU loss of ``pred`` against ``target``.

        Args:
            pred: tensor whose element count is a multiple of 4. Columns 0-1
                of each flattened row are ignored.
            target: tensor with the same number of flattened rows as ``pred``.
                It is not modified.

        Returns:
            A 1-D tensor of per-row losses, or a scalar when ``reduction`` is
            ``"mean"`` or ``"sum"``.
        """
        pred = pred.reshape(-1, 4)
        # Clone before zeroing the centres. ``view``/``reshape`` may share
        # storage with the caller's tensor, and writing in place would
        # overwrite the caller's target centres (and fails on a leaf tensor
        # that requires grad).
        target = target.reshape(-1, 4).clone()
        target[:, :2] = 0

        # Both boxes are placed on the (zeroed) target centre.
        centre = target[:, :2]
        half_p = pred[:, 2:] / 2
        half_t = target[:, 2:] / 2
        tl = torch.max(centre - half_p, centre - half_t)
        br = torch.min(centre + half_p, centre + half_t)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)
        # 1 where the boxes overlap on both axes, else 0.
        en = (tl < br).to(tl.dtype).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        # Epsilon avoids dividing by zero when both areas are zero.
        union = area_p + area_g - area_i + 1e-16
        iou = area_i / union
        loss = 1 - iou**2

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss