# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/facebookresearch/detr/blob/main/models/segmentation.py
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
# ------------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import weight_reduce_loss


def sigmoid_focal_loss(
    preds,
    targets,
    weight=None,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "mean",
    avg_factor: int = None,
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        preds (torch.Tensor): A float tensor of arbitrary shape.
            The predictions (logits) for each example.
        targets (torch.Tensor): A float tensor with the same shape as ``preds``. Stores the
            binary classification label for each element in ``preds``
            (0 for the negative class and 1 for the positive class).
        weight (torch.Tensor, optional): The weight of the loss for each
            prediction. Default: None.
        alpha (float, optional): Weighting factor in range (0, 1) to balance
            positive vs negative examples. Default: 0.25.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
            balance easy vs hard examples. Default: 2.
        reduction (str): 'none' | 'mean' | 'sum'
            'none': No reduction will be applied to the output.
            'mean': The output will be averaged.
            'sum': The output will be summed.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Default: None.

    Returns:
        torch.Tensor: The computed sigmoid focal loss with the reduction option applied.
    """
    preds = preds.float()
    targets = targets.float()
    p = torch.sigmoid(preds)
    ce_loss = F.binary_cross_entropy_with_logits(preds, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if weight is not None:
        assert weight.ndim == loss.ndim

    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
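

# Usage sketch (an illustrative addition, not part of the original file). The
# shapes and values below are assumptions chosen purely for demonstration:
# ``preds`` holds raw logits and ``targets`` holds binary labels of the same shape.
#
#   >>> logits = torch.randn(4, 80)                    # (N, C) raw logits
#   >>> labels = torch.randint(0, 2, (4, 80)).float()  # binary target per logit
#   >>> loss = sigmoid_focal_loss(logits, labels, reduction="mean")  # scalar loss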


def focal_loss_with_prob(
    preds,
    targets,
    weight=None,
    alpha=0.25,
    gamma=2.0,
    reduction="mean",
    avg_factor=None,
):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Different from `sigmoid_focal_loss`, this function accepts probabilities
    rather than logits as input.

    Args:
        preds (torch.Tensor): The prediction probabilities with shape (N, C),
            where C is the number of classes.
        targets (torch.Tensor): The learning labels (class indices) of the
            prediction, where the value ``C`` denotes the background class.
        weight (torch.Tensor, optional): Sample-wise loss weight.
            Defaults to None.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The computed focal loss with the reduction option applied.
    """
    num_classes = preds.size(1)
    targets = F.one_hot(targets, num_classes=num_classes + 1)
    targets = targets[:, :num_classes]
    targets = targets.type_as(preds)

    p_t = preds * targets + (1 - preds) * (1 - targets)
    ce_loss = F.binary_cross_entropy(preds, targets, reduction="none")
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if weight is not None:
        assert weight.ndim == loss.ndim

    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
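

# Usage sketch (an illustrative addition, not part of the original file). Here
# the predictions are assumed to be probabilities, e.g. already passed through
# a sigmoid, and the targets are integer class indices where index C (80 in
# this sketch) acts as the background label zeroed out by the one-hot slicing above.
#
#   >>> probs = torch.randn(4, 80).sigmoid()  # (N, C) probabilities in (0, 1)
#   >>> labels = torch.randint(0, 81, (4,))   # class indices; 80 == background
#   >>> loss = focal_loss_with_prob(probs, labels, reduction="mean")  # scalar loss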


class FocalLoss(nn.Module):
    """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

    Args:
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and
            "sum".
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
        activated (bool, optional): Whether the input is already activated,
            i.e. whether ``preds`` holds probabilities instead of raw logits.
            Defaults to False.
    """

    def __init__(
        self,
        alpha=0.25,
        gamma=2.0,
        reduction="mean",
        loss_weight=1.0,
        activated=False,
    ):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(
        self,
        preds,
        targets,
        weight=None,
        avg_factor=None,
    ):
        """Forward function for FocalLoss.

        Args:
            preds (torch.Tensor): The prediction with shape ``(N, C)``, where
                C is the number of classes. Raw logits by default, or
                probabilities when ``self.activated`` is True.
            targets (torch.Tensor): The learning labels (class indices) of the
                prediction, where the value ``C`` denotes the background class.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        if self.activated:
            loss_func = focal_loss_with_prob
        else:
            num_classes = preds.size(1)
            targets = F.one_hot(targets, num_classes=num_classes + 1)
            targets = targets[:, :num_classes]
            loss_func = sigmoid_focal_loss

        loss_class = self.loss_weight * loss_func(
            preds,
            targets,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
            avg_factor=avg_factor,
        )
        return loss_class
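

# Usage sketch (an illustrative addition, not part of the original file). It
# assumes 80 foreground classes with index 80 reserved for background,
# mirroring the one-hot handling in ``forward`` above.
#
#   >>> criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction="mean")
#   >>> logits = torch.randn(4, 80)          # raw, unactivated class logits
#   >>> labels = torch.randint(0, 81, (4,))  # integer labels; 80 == background
#   >>> loss = criterion(logits, labels)     # scalar, scaled by loss_weight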