Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmpretrain.registry import MODELS | |
from .utils import weight_reduce_loss | |
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None):
    """Compute the cross-entropy loss between logits and hard labels.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is
            the number of classes.
        label (torch.Tensor): The ground-truth label of each prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is cast
            to float before use. Defaults to None.
        reduction (str): The method used to reduce the loss; forwarded to
            ``weight_reduce_loss``. Defaults to 'mean'.
        avg_factor (int, optional): Average factor used to average the
            loss; forwarded to ``weight_reduce_loss``. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class
            with shape (C). Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # Keep the loss un-reduced here; weighting and reduction are delegated
    # to ``weight_reduce_loss`` below.
    per_sample_loss = F.cross_entropy(
        pred, label, weight=class_weight, reduction='none')

    sample_weight = None if weight is None else weight.float()

    return weight_reduce_loss(
        per_sample_loss,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
def soft_cross_entropy(pred,
                       label,
                       weight=None,
                       reduction='mean',
                       class_weight=None,
                       avg_factor=None):
    """Compute the cross-entropy loss against soft (possibly float) labels.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is
            the number of classes.
        label (torch.Tensor): The ground-truth label with shape (N, C).
            It may be float, e.g. when "mixup" is used.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is cast
            to float before use. Defaults to None.
        reduction (str): The method used to reduce the loss; forwarded to
            ``weight_reduce_loss``. Defaults to 'mean'.
        class_weight (torch.Tensor, optional): The weight for each class
            with shape (C). Defaults to None.
        avg_factor (int, optional): Average factor used to average the
            loss; forwarded to ``weight_reduce_loss``. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    log_probs = F.log_softmax(pred, dim=-1)

    # Per-class contribution of every sample.
    per_class_loss = -label * log_probs
    if class_weight is not None:
        # In-place scaling of the freshly created tensor.
        per_class_loss *= class_weight

    # Collapse the class axis to get one loss value per sample.
    per_sample_loss = per_class_loss.sum(dim=-1)

    sample_weight = None if weight is None else weight.float()

    return weight_reduce_loss(
        per_sample_loss,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         pos_weight=None):
    r"""Calculate the binary CrossEntropy loss with logits.

    Args:
        pred (torch.Tensor): The prediction with shape (N, \*).
        label (torch.Tensor): The gt label with shape (N, \*). It must have
            the same number of dimensions as ``pred``.
        weight (torch.Tensor, optional): Sample-wise weight of loss with
            shape (N, ). Defaults to None.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none',
            loss is same shape as pred and label. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class
            with shape (C), C is the number of classes. Default None.
        pos_weight (torch.Tensor, optional): The positive weight for each
            class with shape (C), C is the number of classes. Default None.

    Returns:
        torch.Tensor: The calculated loss.

    Raises:
        AssertionError: If ``pred`` and ``label`` differ in number of
            dimensions, or if ``weight`` is given and is not 1-D.
    """
    # Ensure that the size of class_weight is consistent with pred and label
    # to avoid automatic broadcast.
    assert pred.dim() == label.dim()

    if class_weight is not None:
        num_samples = pred.size(0)
        class_weight = class_weight.repeat(num_samples, 1)

    loss = F.binary_cross_entropy_with_logits(
        pred,
        label.float(),  # only accepts float type tensor
        weight=class_weight,
        pos_weight=pos_weight,
        reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            # Append one singleton axis per non-batch dimension of ``pred``
            # so the (N, ) sample weight lines up with the batch axis of the
            # (N, *) loss. A fixed (N, 1) shape is only correct for 2-D
            # predictions; for higher ranks it would broadcast against the
            # wrong axes.
            trailing_axes = (1, ) * (pred.dim() - 1)
            weight = weight.reshape(-1, *trailing_axes)

    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss
class CrossEntropyLoss(nn.Module):
    """Cross entropy loss.

    Depending on the flags, the loss is computed by ``binary_cross_entropy``
    (``use_sigmoid``), ``soft_cross_entropy`` (``use_soft``) or
    ``cross_entropy`` (neither flag set).

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid instead of
            softmax. Defaults to False.
        use_soft (bool): Whether to use the soft version of
            CrossEntropyLoss. Cannot be combined with ``use_sigmoid``.
            Defaults to False.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (List[float], optional): The weight for each class with
            shape (C), C is the number of classes. Default None.
        pos_weight (List[float], optional): The positive weight for each
            class with shape (C), C is the number of classes. Only enabled
            in BCE loss when ``use_sigmoid`` is True. Default None.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_soft=False,
                 reduction='mean',
                 loss_weight=1.0,
                 class_weight=None,
                 pos_weight=None):
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.use_soft = use_soft
        # The sigmoid and soft variants are mutually exclusive.
        assert not self.use_soft or not self.use_sigmoid, \
            'use_sigmoid and use_soft could not be set simultaneously'

        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.pos_weight = pos_weight

        # Pick the functional implementation once, at construction time.
        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        else:
            self.cls_criterion = (
                soft_cross_entropy if self.use_soft else cross_entropy)

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the weighted classification loss.

        Args:
            cls_score (torch.Tensor): The prediction passed to the selected
                criterion.
            label (torch.Tensor): The ground-truth label.
            weight (torch.Tensor, optional): Sample-wise loss weight.
                Defaults to None.
            avg_factor (int, optional): Average factor forwarded to the
                criterion. Defaults to None.
            reduction_override (str, optional): If set, replaces the
                ``reduction`` given at construction for this call. Must be
                one of None, 'none', 'mean' or 'sum'. Defaults to None.
            **kwargs: Extra keyword arguments forwarded to the criterion.

        Returns:
            torch.Tensor: The criterion output multiplied by
            ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        # Materialise the configured weights as tensors that share
        # ``cls_score``'s dtype and device.
        class_weight = (
            None if self.class_weight is None else cls_score.new_tensor(
                self.class_weight))

        # only BCE loss has pos_weight
        if self.use_sigmoid and self.pos_weight is not None:
            kwargs['pos_weight'] = cls_score.new_tensor(self.pos_weight)

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return self.loss_weight * raw_loss