import torch
from torch import nn


class IOULoss(nn.Module):
    """IoU-family regression loss for boxes encoded as side distances.

    Each row of ``pred`` / ``target`` is read as ``(left, top, right, bottom)``.
    Box width is ``left + right`` and height is ``top + bottom``, so both boxes
    are measured from one shared reference point (presumably the location that
    regresses the box, FCOS-style -- confirm against the caller).

    Args:
        loss_type: ``"iou"`` for ``-log(IoU)``, ``"linear_iou"`` for
            ``1 - IoU``, or ``"giou"`` for ``1 - GIoU``. Any other value makes
            :meth:`forward` raise ``NotImplementedError``.
    """

    def __init__(self, loss_type="iou"):
        super().__init__()
        self.loss_type = loss_type

    def forward(self, pred, target, weight=None):
        """Compute the summed IoU-family loss.

        Args:
            pred: predicted distances, indexed as ``pred[:, k]`` for k in 0..3
                (presumably shape ``(N, 4)`` -- confirm against the caller).
            target: ground-truth distances in the same layout as ``pred``.
            weight: optional per-row weights multiplied into the per-row loss.
                They are used only when ``weight.sum() > 0``.

        Returns:
            A scalar tensor. It is the weighted sum of per-row losses when a
            positive-sum ``weight`` is given, otherwise the plain sum.

        Raises:
            NotImplementedError: if ``loss_type`` is not a supported value.
            AssertionError: if no usable weight is given and there are no rows.
        """
        pred_left = pred[:, 0]
        pred_top = pred[:, 1]
        pred_right = pred[:, 2]
        pred_bottom = pred[:, 3]

        target_left = target[:, 0]
        target_top = target[:, 1]
        target_right = target[:, 2]
        target_bottom = target[:, 3]

        target_area = (target_left + target_right) * (target_top + target_bottom)
        pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)

        # Both boxes share the reference point. Per side, min() gives the
        # intersection extent and max() gives the smallest enclosing box.
        w_intersect = torch.min(pred_left, target_left) + torch.min(pred_right, target_right)
        g_w_intersect = torch.max(pred_left, target_left) + torch.max(pred_right, target_right)
        h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(pred_top, target_top)
        g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(pred_top, target_top)

        # The 1e-7 guards the GIoU division against a zero-area enclosing box.
        enclosing_area = g_w_intersect * g_h_intersect + 1e-7
        area_intersect = w_intersect * h_intersect
        area_union = target_area + pred_area - area_intersect
        # +1 smoothing keeps IoU (and therefore log(IoU)) finite for empty boxes.
        ious = (area_intersect + 1.0) / (area_union + 1.0)
        gious = ious - (enclosing_area - area_union) / enclosing_area

        if self.loss_type == "iou":
            losses = -torch.log(ious)
        elif self.loss_type == "linear_iou":
            losses = 1 - ious
        elif self.loss_type == "giou":
            losses = 1 - gious
        else:
            raise NotImplementedError

        if weight is not None and weight.sum() > 0:
            return (losses * weight).sum()
        assert losses.numel() != 0
        return losses.sum()


class IOUWHLoss(nn.Module):  # used for anchor guiding
    """Width/height-only IoU loss: ``1 - IoU**2`` of two centre-aligned boxes.

    Only the last two of the four columns are used; the code treats them as
    box width and height. The first two columns of both inputs are ignored,
    because both boxes are placed on a common centre. This makes the loss
    measure shape agreement only. Neither input is modified.

    Args:
        reduction: ``"mean"`` or ``"sum"`` reduce the per-box loss. Any other
            value (default ``"none"``) returns the per-box loss unreduced.
    """

    def __init__(self, reduction="none"):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        """Return ``1 - IoU**2`` per box (flattened to 1-D), optionally reduced.

        Args:
            pred: tensor whose total size is a multiple of 4. It is flattened
                to rows of 4, and columns 2:4 are read as (w, h).
            target: tensor laid out like ``pred``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        pred_wh = pred.reshape(-1, 4)[:, 2:]
        target_wh = target.reshape(-1, 4)[:, 2:]

        # Put both boxes on a common centre at the origin. Do NOT zero
        # target[:, :2] in place for this: reshape/view can share storage,
        # which would overwrite the caller's tensor.
        centre = torch.zeros_like(target_wh)
        tl = torch.max(centre - pred_wh / 2, centre - target_wh / 2)
        br = torch.min(centre + pred_wh / 2, centre + target_wh / 2)

        area_p = torch.prod(pred_wh, 1)
        area_g = torch.prod(target_wh, 1)
        # 1 only when the boxes overlap on both axes; otherwise the
        # intersection is forced to zero.
        overlap = (tl < br).to(tl.dtype).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * overlap
        union = area_p + area_g - area_i + 1e-16
        iou = area_i / union
        loss = 1 - iou ** 2

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss