import torch import torch.nn as nn import torch.nn.functional as F ################################################################### # ########################## iou loss ############################# ################################################################### class IOU(torch.nn.Module): def __init__(self): super(IOU, self).__init__() def _iou(self, pred, target): pred = torch.sigmoid(pred) inter = (pred * target).sum(dim=(2, 3)) union = (pred + target).sum(dim=(2, 3)) - inter iou = 1 - (inter / union) return iou.mean() def forward(self, pred, target): return self._iou(pred, target)