Spaces:
Runtime error
Runtime error
File size: 7,689 Bytes
3e06e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
import torch.nn as nn
from torch import Tensor
from mmdet.registry import MODELS
from .utils import weight_reduce_loss, weighted_loss
@weighted_loss
def gaussian_focal_loss(pred: Tensor,
                        gaussian_target: Tensor,
                        alpha: float = 2.0,
                        gamma: float = 4.0,
                        pos_weight: float = 1.0,
                        neg_weight: float = 1.0) -> Tensor:
    """Element-wise `Focal Loss <https://arxiv.org/abs/1708.02002>`_ for
    targets that follow a gaussian distribution.

    Elements whose target equals exactly 1 are treated as positives; every
    element also contributes a negative term that is damped by
    ``(1 - gaussian_target) ** gamma`` (which is 0 where the target is 1).

    Note: this function is wrapped by ``weighted_loss`` from ``.utils``;
    callers in this file additionally pass ``weight``, ``reduction`` and
    ``avg_factor``, which are presumably consumed by that decorator —
    confirm against ``.utils``.

    Args:
        pred (torch.Tensor): The prediction.
        gaussian_target (torch.Tensor): The learning target of the
            prediction in gaussian distribution.
        alpha (float, optional): Exponent applied to ``1 - pred`` in the
            positive term and to ``pred`` in the negative term.
            Defaults to 2.0.
        gamma (float, optional): Exponent applied to
            ``1 - gaussian_target`` to damp the negative term.
            Defaults to 4.0.
        pos_weight (float): Positive sample loss weight. Defaults to 1.0.
        neg_weight (float): Negative sample loss weight. Defaults to 1.0.

    Returns:
        torch.Tensor: The un-reduced loss computed by this body, one value
        per element of ``pred``.
    """
    # Small constant that keeps the logarithms away from log(0).
    eps = 1e-12

    # Positives are exactly the peaks of the gaussian heatmap.
    is_positive = gaussian_target.eq(1)
    # The closer a target is to 1, the less it is penalised as a negative.
    negative_scale = (1 - gaussian_target).pow(gamma)

    positive_log_term = -(pred + eps).log()
    positive_focal_term = (1 - pred).pow(alpha)
    positive_part = positive_log_term * positive_focal_term * is_positive

    negative_log_term = -(1 - pred + eps).log()
    negative_focal_term = pred.pow(alpha)
    negative_part = negative_log_term * negative_focal_term * negative_scale

    return pos_weight * positive_part + neg_weight * negative_part
def gaussian_focal_loss_with_pos_inds(
        pred: Tensor,
        gaussian_target: Tensor,
        pos_inds: Tensor,
        pos_labels: Tensor,
        alpha: float = 2.0,
        gamma: float = 4.0,
        pos_weight: float = 1.0,
        neg_weight: float = 1.0,
        reduction: str = 'mean',
        avg_factor: Optional[Union[int, float]] = None) -> Tensor:
    """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for targets in gaussian
    distribution, with explicitly supplied positive samples.

    Note: The index with a value of 1 in ``gaussian_target`` in the
    ``gaussian_focal_loss`` function is a positive sample, but in
    ``gaussian_focal_loss_with_pos_inds`` the positive sample is passed
    in through the ``pos_inds`` parameter.

    Args:
        pred (torch.Tensor): The prediction. The shape is (N, num_classes).
        gaussian_target (torch.Tensor): The learning target of the prediction
            in gaussian distribution. The shape is (N, num_classes).
        pos_inds (torch.Tensor): The positive sample index (row index into
            ``pred``). The shape is (M, ).
        pos_labels (torch.Tensor): The label (column index into ``pred``)
            corresponding to the positive sample index. The shape is (M, ).
        alpha (float, optional): Exponent applied to ``1 - pred`` in the
            positive term and to ``pred`` in the negative term.
            Defaults to 2.0.
        gamma (float, optional): Exponent applied to
            ``1 - gaussian_target`` to damp the negative term.
            Defaults to 4.0.
        pos_weight (float): Positive sample loss weight. Defaults to 1.0.
        neg_weight (float): Negative sample loss weight. Defaults to 1.0.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        avg_factor (int, float, optional): Average factor that is used to
            average the loss. Not used when ``reduction`` is "none".
            Defaults to None.

    Returns:
        torch.Tensor: For "none", an element-wise loss with the same shape
        as ``pred``. Otherwise the weighted sum of the separately reduced
        positive and negative terms, as produced by ``weight_reduce_loss``.
    """
    # Small constant that keeps the logarithms away from log(0).
    eps = 1e-12
    neg_weights = (1 - gaussian_target).pow(gamma)

    # Pick, for every positive sample, the prediction of its own class.
    pos_pred_pix = pred[pos_inds]
    pos_pred = pos_pred_pix.gather(1, pos_labels.unsqueeze(1))
    pos_loss = -(pos_pred + eps).log() * (1 - pos_pred).pow(alpha)
    neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights

    if reduction == 'none':
        # The positive term is (M, 1) while the negative term is
        # (N, num_classes); adding them directly either fails to broadcast
        # or smears positive losses over unrelated elements. Scatter the
        # positive losses back to their (row, class) positions instead.
        # ``accumulate=True`` keeps repeated (row, class) pairs additive,
        # matching how they would be counted by a sum.
        pos_loss_map = pred.new_zeros(pred.shape).index_put(
            (pos_inds, pos_labels), pos_loss.squeeze(1), accumulate=True)
        return pos_weight * pos_loss_map + neg_weight * neg_loss

    # Positive and negative terms are reduced independently, both with the
    # same ``avg_factor``.
    pos_loss = weight_reduce_loss(pos_loss, None, reduction, avg_factor)
    neg_loss = weight_reduce_loss(neg_loss, None, reduction, avg_factor)
    return pos_weight * pos_loss + neg_weight * neg_loss
@MODELS.register_module()
class GaussianFocalLoss(nn.Module):
    """Focal-loss variant whose target is a gaussian heatmap.

    More details can be found in the `paper
    <https://arxiv.org/abs/1808.01244>`_
    Code is modified from `kp_utils.py
    <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_  # noqa: E501
    Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
    not 0/1 binary target.

    Args:
        alpha (float): Power of prediction. Defaults to 2.0.
        gamma (float): Power of target for negative samples.
            Defaults to 4.0.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        loss_weight (float): Loss weight of current loss. Defaults to 1.0.
        pos_weight (float): Positive sample loss weight. Defaults to 1.0.
        neg_weight (float): Negative sample loss weight. Defaults to 1.0.
    """

    def __init__(self,
                 alpha: float = 2.0,
                 gamma: float = 4.0,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 pos_weight: float = 1.0,
                 neg_weight: float = 1.0) -> None:
        super().__init__()
        # Exponents of the focal terms.
        self.alpha = alpha
        self.gamma = gamma
        # Relative weights of the positive / negative parts and of the
        # loss as a whole.
        self.pos_weight = pos_weight
        self.neg_weight = neg_weight
        self.loss_weight = loss_weight
        # Default reduction; can be overridden per call in ``forward``.
        self.reduction = reduction

    def forward(self,
                pred: Tensor,
                target: Tensor,
                pos_inds: Optional[Tensor] = None,
                pos_labels: Optional[Tensor] = None,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[Union[int, float]] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Compute the gaussian focal loss.

        To decide manually which positions are positive samples, pass
        ``pos_inds`` together with ``pos_labels``. Currently, only the
        CenterNet update version uses these parameters. Note that
        ``weight`` is not forwarded on that path.

        Args:
            pred (torch.Tensor): The prediction. The shape is
                (N, num_classes).
            target (torch.Tensor): The learning target of the prediction
                in gaussian distribution. The shape is (N, num_classes).
            pos_inds (torch.Tensor, optional): The positive sample index.
                Defaults to None.
            pos_labels (torch.Tensor, optional): The label corresponding
                to the positive sample index. Must be given whenever
                ``pos_inds`` is given. Defaults to None.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Only used when ``pos_inds`` is None.
                Defaults to None.
            avg_factor (int, float, optional): Average factor that is used
                to average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The loss scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # Every accepted override is either None or a non-empty string, so
        # ``or`` picks the override exactly when one was supplied.
        reduction = reduction_override or self.reduction

        # Settings shared by both loss implementations.
        common_kwargs = dict(
            alpha=self.alpha,
            gamma=self.gamma,
            pos_weight=self.pos_weight,
            neg_weight=self.neg_weight,
            reduction=reduction,
            avg_factor=avg_factor)

        if pos_inds is None:
            # Default path: positives are the elements whose target is 1.
            unscaled_loss = gaussian_focal_loss(pred, target, weight,
                                                **common_kwargs)
            return self.loss_weight * unscaled_loss

        # Only used by centernet update version
        assert pos_labels is not None
        unscaled_loss = gaussian_focal_loss_with_pos_inds(
            pred, target, pos_inds, pos_labels, **common_kwargs)
        return self.loss_weight * unscaled_loss
|