RSPrompter / mmdet /models /losses /varifocal_loss.py
KyanChen's picture
Upload 787 files
3e06e1c
raw
history blame
5.75 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmdet.registry import MODELS
from .utils import weight_reduce_loss
def varifocal_loss(pred: Tensor,
                   target: Tensor,
                   weight: Optional[Tensor] = None,
                   alpha: float = 0.75,
                   gamma: float = 2.0,
                   iou_weighted: bool = True,
                   reduction: str = 'mean',
                   avg_factor: Optional[int] = None) -> Tensor:
    """Compute `Varifocal Loss <https://arxiv.org/abs/2008.13367>`_.

    Each element's sigmoid binary cross-entropy is scaled by a focal weight:

    - where ``target > 0`` the weight is ``target`` itself when
      ``iou_weighted`` is True, otherwise ``1``;
    - where ``target <= 0`` the weight is
      ``alpha * |sigmoid(pred) - target| ** gamma``.

    The weighted element-wise loss is then handed to ``weight_reduce_loss``
    together with ``weight``, ``reduction`` and ``avg_factor``.

    Args:
        pred (Tensor): The prediction (logits) with shape (N, C), C is the
            number of classes.
        target (Tensor): The learning target of the iou-aware
            classification score with shape (N, C). It is cast to the type
            of ``pred`` before use.
        weight (Tensor, optional): The weight of loss for each
            prediction. Defaults to None.
        alpha (float, optional): A balance factor for the negative part of
            Varifocal Loss, which is different from the alpha of Focal Loss.
            Defaults to 0.75.
        gamma (float, optional): The exponent of the modulating factor
            applied to negatives. Defaults to 2.0.
        iou_weighted (bool, optional): Whether to weight the loss of the
            positive examples with the iou target. Defaults to True.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and
            "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        Tensor: Loss tensor.
    """
    # The element-wise weighting below needs matching shapes.
    assert pred.shape == target.shape

    prob = pred.sigmoid()
    target = target.type_as(pred)

    # 0/1 float masks selecting positive and negative entries.
    is_positive = (target > 0.0).float()
    is_negative = (target <= 0.0).float()

    # Negatives are modulated by how far the predicted probability is from
    # the target; positives use the target (IoU) itself or a constant 1.
    negative_weight = alpha * (prob - target).abs().pow(gamma) * is_negative
    if iou_weighted:
        positive_weight = target * is_positive
    else:
        positive_weight = is_positive
    focal_weight = positive_weight + negative_weight

    elementwise = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    return weight_reduce_loss(elementwise, weight, reduction, avg_factor)
@MODELS.register_module()
class VarifocalLoss(nn.Module):
    """Module wrapper around :func:`varifocal_loss`.

    Stores the loss hyper-parameters at construction time and applies them,
    scaled by ``loss_weight``, on every forward call.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 alpha: float = 0.75,
                 gamma: float = 2.0,
                 iou_weighted: bool = True,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction is
                used for sigmoid or softmax. Must be True; only the
                sigmoid variant is implemented. Defaults to True.
            alpha (float, optional): A balance factor for the negative part
                of Varifocal Loss, which is different from the alpha of
                Focal Loss. Must be non-negative. Defaults to 0.75.
            gamma (float, optional): The gamma for calculating the
                modulating factor. Defaults to 2.0.
            iou_weighted (bool, optional): Whether to weight the loss of the
                positive examples with the iou target. Defaults to True.
            reduction (str, optional): The method used to reduce the loss
                into a scalar. Defaults to 'mean'. Options are "none",
                "mean" and "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
        """
        super().__init__()
        assert use_sigmoid is True, \
            'Only sigmoid varifocal loss supported now.'
        assert alpha >= 0.0

        self.use_sigmoid = use_sigmoid
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): The prediction with shape (N, C), C is the
                number of classes.
            target (Tensor): The learning target of the iou-aware
                classification score with shape (N, C), C is
                the number of classes.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum". When None, the
                reduction given at construction time is used.

        Returns:
            Tensor: The calculated loss, multiplied by ``loss_weight``.

        Raises:
            NotImplementedError: If ``use_sigmoid`` is falsy.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # Every accepted override string is truthy, so ``or`` only falls
        # back to the configured reduction when the override is None.
        reduction = reduction_override or self.reduction

        if not self.use_sigmoid:
            raise NotImplementedError

        unscaled_loss = varifocal_loss(
            pred,
            target,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            iou_weighted=self.iou_weighted,
            reduction=reduction,
            avg_factor=avg_factor)
        loss_cls = self.loss_weight * unscaled_loss
        return loss_cls