Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2022 The IDEA Authors. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ------------------------------------------------------------------------------------------------ | |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
# Copyright (c) OpenMMLab. All rights reserved. | |
# ------------------------------------------------------------------------------------------------ | |
# Modified from: | |
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/giou_loss.py | |
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/iou_loss.py | |
# ------------------------------------------------------------------------------------------------ | |
import torch | |
import torch.nn as nn | |
from .utils import weight_reduce_loss | |
def giou_loss( | |
preds: torch.Tensor, | |
targets: torch.Tensor, | |
weight=None, | |
eps: float = 1e-6, | |
reduction: str = "mean", | |
avg_factor: int = None, | |
): | |
r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding | |
Box Regression <https://arxiv.org/abs/1902.09630>`_. | |
Args: | |
preds (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2), | |
shape (n, 4). | |
targets (torch.Tensor): Corresponding gt bboxes, shape (n, 4). | |
eps (float): Eps to avoid log(0). | |
Return: | |
Tensor: Loss tensor. | |
""" | |
if targets.numel() == 0: | |
return preds.sum() * 0 | |
x1, y1, x2, y2 = preds.unbind(dim=-1) | |
x1g, y1g, x2g, y2g = targets.unbind(dim=-1) | |
assert (x2 >= x1).all(), "bad box: x1 larger than x2" | |
assert (y2 >= y1).all(), "bad box: y1 larger than y2" | |
# Intersection keypoints | |
xkis1 = torch.max(x1, x1g) | |
ykis1 = torch.max(y1, y1g) | |
xkis2 = torch.min(x2, x2g) | |
ykis2 = torch.min(y2, y2g) | |
intsctk = torch.zeros_like(x1) | |
mask = (ykis2 > ykis1) & (xkis2 > xkis1) | |
intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) | |
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk | |
iouk = intsctk / (unionk + eps) | |
# smallest enclosing box | |
xc1 = torch.min(x1, x1g) | |
yc1 = torch.min(y1, y1g) | |
xc2 = torch.max(x2, x2g) | |
yc2 = torch.max(y2, y2g) | |
area_c = (xc2 - xc1) * (yc2 - yc1) | |
miouk = iouk - ((area_c - unionk) / (area_c + eps)) | |
loss = 1 - miouk | |
if weight is not None: | |
assert weight.ndim == loss.ndim | |
loss = weight_reduce_loss(loss, weight, reduction, avg_factor) | |
return loss | |
class GIoULoss(nn.Module): | |
def __init__( | |
self, | |
eps: float = 1e-6, | |
reduction: str = "mean", | |
loss_weight: float = 1.0, | |
): | |
super(GIoULoss, self).__init__() | |
self.eps = eps | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
def forward( | |
self, | |
preds, | |
targets, | |
weight=None, | |
avg_factor=None, | |
): | |
loss_giou = self.loss_weight * giou_loss( | |
preds, targets, weight, eps=self.eps, reduction=self.reduction, avg_factor=avg_factor | |
) | |
return loss_giou | |