# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from .cross_entropy_loss import CrossEntropyLoss
from .utils import convert_to_one_hot


@MODELS.register_module()
class LabelSmoothLoss(nn.Module):
    r"""Label-smoothed cross entropy loss.

    Refers to `Rethinking the Inception Architecture for Computer Vision
    <https://arxiv.org/abs/1512.00567>`_.

    This decreases the gap between output scores and encourages
    generalization. Labels provided to forward can be one-hot like vectors
    (NxC) or class indices (Nx1). It also accepts linear combinations of
    one-hot like labels from mixup or cutmix, except in the multi-label
    task.

    Args:
        label_smooth_val (float): The degree of label smoothing.
        num_classes (int, optional): Number of classes. Defaults to None.
        mode (str): Refers to the notes below. Options are 'original',
            'classy_vision' and 'multi_label'. Defaults to 'original'.
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Defaults to None, which means to use sigmoid in
            "multi_label" mode and not to use it in other modes.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (List[float], optional): The weight for each class.
            Defaults to None.
        pos_weight (List[float], optional): The positive weight for each
            class, only used in "multi_label" mode. Defaults to None.

    Notes:
        - If the mode is **"original"**, this uses the same label smoothing
          method as the original paper:

          .. math::
              (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K}

          where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is
          the ``num_classes`` and :math:`\delta_{k, y}` is the Dirac delta,
          which equals 1 for :math:`k=y` and 0 otherwise.

        - If the mode is **"classy_vision"**, this uses the same label
          smoothing method as the facebookresearch/ClassyVision repo:

          .. math::
              \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon}

        - If the mode is **"multi_label"**, this accepts labels from a
          multi-label task and smooths them as:

          .. math::
              (1-2\epsilon)\delta_{k, y} + \epsilon
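
    Examples:
        A small illustrative check of the smoothing schemes (the values
        below assume ``label_smooth_val=0.1`` and 4 classes):

        >>> import torch
        >>> loss = LabelSmoothLoss(label_smooth_val=0.1, num_classes=4)
        >>> loss.smooth_label(torch.tensor([[0., 1., 0., 0.]]))
        tensor([[0.0250, 0.9250, 0.0250, 0.0250]])
        >>> loss = LabelSmoothLoss(label_smooth_val=0.1, num_classes=4,
        ...                        mode='multi_label', use_sigmoid=True)
        >>> loss.smooth_label(torch.tensor([[0., 1., 1., 0.]]))
        tensor([[0.1000, 0.9000, 0.9000, 0.1000]])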
""" | |

    def __init__(self,
                 label_smooth_val,
                 num_classes=None,
                 use_sigmoid=None,
                 mode='original',
                 reduction='mean',
                 loss_weight=1.0,
                 class_weight=None,
                 pos_weight=None):
        super().__init__()
        self.num_classes = num_classes
        self.loss_weight = loss_weight

        assert (isinstance(label_smooth_val, float)
                and 0 <= label_smooth_val < 1), \
            f'LabelSmoothLoss accepts a float label_smooth_val ' \
            f'in [0, 1), but got {label_smooth_val}'
        self.label_smooth_val = label_smooth_val

        accept_reduction = {'none', 'mean', 'sum'}
        assert reduction in accept_reduction, \
            f'LabelSmoothLoss supports reduction {accept_reduction}, ' \
            f'but got {reduction}.'
        self.reduction = reduction

        accept_mode = {'original', 'classy_vision', 'multi_label'}
        assert mode in accept_mode, \
            f'LabelSmoothLoss supports mode {accept_mode}, but got {mode}.'
        self.mode = mode

        self._eps = label_smooth_val
        if mode == 'classy_vision':
            # Rescale eps so that the 'original' formula reproduces the
            # ClassyVision smoothing (delta + eps/K) / (1 + eps).
            self._eps = label_smooth_val / (1 + label_smooth_val)

        if mode == 'multi_label':
            if not use_sigmoid:
                from mmengine.logging import MMLogger
                MMLogger.get_current_instance().warning(
                    'For multi-label tasks, please set `use_sigmoid=True` '
                    'to use binary cross entropy.')
            self.smooth_label = self.multilabel_smooth_label
            use_sigmoid = True if use_sigmoid is None else use_sigmoid
        else:
            self.smooth_label = self.original_smooth_label
            use_sigmoid = False if use_sigmoid is None else use_sigmoid

        self.ce = CrossEntropyLoss(
            use_sigmoid=use_sigmoid,
            use_soft=not use_sigmoid,
            reduction=reduction,
            class_weight=class_weight,
            pos_weight=pos_weight)

    def generate_one_hot_like_label(self, label):
        """Convert index labels into one-hot like label vectors (float);
        labels that are already one-hot like are passed through unchanged.
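
        Example (illustrative):

        >>> import torch
        >>> loss = LabelSmoothLoss(label_smooth_val=0.1, num_classes=3)
        >>> loss.generate_one_hot_like_label(torch.tensor([0, 2]))
        tensor([[1., 0., 0.],
                [0., 0., 1.]])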
        """
        # check if targets are inputted as class integers
        if label.dim() == 1 or (label.dim() == 2 and label.shape[1] == 1):
            label = convert_to_one_hot(label.view(-1, 1), self.num_classes)
        return label.float()

    def original_smooth_label(self, one_hot_like_label):
        """Smooth labels as ``(1 - eps) * one_hot + eps / num_classes``."""
        assert self.num_classes > 0
        smooth_label = one_hot_like_label * (1 - self._eps)
        smooth_label += self._eps / self.num_classes
        return smooth_label

    def multilabel_smooth_label(self, one_hot_like_label):
        """Smooth labels so that positive entries become ``1 - eps`` and
        negative entries become ``eps``."""
        assert self.num_classes > 0
        smooth_label = torch.full_like(one_hot_like_label, self._eps)
        smooth_label.masked_fill_(one_hot_like_label > 0, 1 - self._eps)
        return smooth_label

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        r"""Label smooth loss.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, \*).
            label (torch.Tensor): The ground truth label of the prediction
                with shape (N, \*).
            weight (torch.Tensor, optional): Sample-wise loss weight with
                shape (N, \*). Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce
                the loss into a scalar. Options are "none", "mean" and
                "sum". Defaults to None.

        Returns:
            torch.Tensor: Loss.
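
        Example:
            An illustrative call; the expected value assumes the default
            soft cross entropy, which sums over classes and averages over
            the batch:

            >>> import torch
            >>> loss_fn = LabelSmoothLoss(label_smooth_val=0.1,
            ...                           num_classes=2)
            >>> cls_score = torch.tensor([[10., 0.]])
            >>> loss_fn(cls_score, torch.tensor([0]))
            tensor(0.5000)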
""" | |
        if self.num_classes is not None:
            assert self.num_classes == cls_score.shape[1], \
                f'num_classes should equal cls_score.shape[1], ' \
                f'but got num_classes: {self.num_classes} and ' \
                f'cls_score.shape[1]: {cls_score.shape[1]}'
        else:
            self.num_classes = cls_score.shape[1]

        one_hot_like_label = self.generate_one_hot_like_label(label=label)
        assert one_hot_like_label.shape == cls_score.shape, \
            f'LabelSmoothLoss requires output and target ' \
            f'to be of the same shape, but got output.shape: ' \
            f'{cls_score.shape} and target.shape: ' \
            f'{one_hot_like_label.shape}'

        smoothed_label = self.smooth_label(one_hot_like_label)
        return self.loss_weight * self.ce.forward(
            cls_score,
            smoothed_label,
            weight=weight,
            avg_factor=avg_factor,
            reduction_override=reduction_override,
            **kwargs)