# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from .cross_entropy_loss import CrossEntropyLoss
from .utils import convert_to_one_hot


@MODELS.register_module()
class LabelSmoothLoss(nn.Module):
r"""Initializer for the label smoothed cross entropy loss.
Refers to `Rethinking the Inception Architecture for Computer Vision
<https://arxiv.org/abs/1512.00567>`_
This decreases gap between output scores and encourages generalization.
Labels provided to forward can be one-hot like vectors (NxC) or class
indices (Nx1).
And this accepts linear combination of one-hot like labels from mixup or
cutmix except multi-label task.
    Args:
        label_smooth_val (float): The degree of label smoothing.
        num_classes (int, optional): Number of classes. Defaults to None.
        mode (str): Refers to the notes below. Options are 'original',
            'classy_vision' and 'multi_label'. Defaults to 'original'.
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Defaults to None, which means using sigmoid in
            "multi_label" mode and softmax in the other modes.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (List[float], optional): The weight for each class,
            forwarded to the underlying ``CrossEntropyLoss``.
            Defaults to None.
        pos_weight (List[float], optional): The positive weight for each
            class in sigmoid (multi-label) mode, forwarded to the underlying
            ``CrossEntropyLoss``. Defaults to None.
    Notes:
        - If the mode is **"original"**, this will use the same label
          smoothing method as the original paper:

          .. math::
              (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K}

          where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is
          the ``num_classes`` and :math:`\delta_{k, y}` is the Dirac delta,
          which equals 1 for :math:`k=y` and 0 otherwise.
        - If the mode is **"classy_vision"**, this will use the same label
          smoothing method as the facebookresearch/ClassyVision repo:

          .. math::
              \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon}

        - If the mode is **"multi_label"**, this will accept labels from a
          multi-label task and smooth them as:

          .. math::
              (1-2\epsilon)\delta_{k, y} + \epsilon
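
    Examples:
        As a quick worked example of the formulas above, take
        ``label_smooth_val=0.1`` and 4 classes. The one-hot label
        ``[0, 1, 0, 0]`` is smoothed to:

        - ``[0.025, 0.925, 0.025, 0.025]`` in ``'original'`` mode,
        - about ``[0.0227, 0.9318, 0.0227, 0.0227]`` in
          ``'classy_vision'`` mode, and
        - ``[0.1, 0.9, 0.1, 0.1]`` in ``'multi_label'`` mode.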
"""

    def __init__(self,
label_smooth_val,
num_classes=None,
use_sigmoid=None,
mode='original',
reduction='mean',
loss_weight=1.0,
class_weight=None,
pos_weight=None):
super().__init__()
self.num_classes = num_classes
self.loss_weight = loss_weight
        assert (isinstance(label_smooth_val, float)
                and 0 <= label_smooth_val < 1), \
            f'LabelSmoothLoss accepts a float label_smooth_val ' \
            f'in [0, 1), but got {label_smooth_val}'
self.label_smooth_val = label_smooth_val
accept_reduction = {'none', 'mean', 'sum'}
assert reduction in accept_reduction, \
            f'LabelSmoothLoss supports reduction {accept_reduction}, ' \
            f'but got {reduction}.'
self.reduction = reduction
accept_mode = {'original', 'classy_vision', 'multi_label'}
assert mode in accept_mode, \
            f'LabelSmoothLoss supports mode {accept_mode}, but got {mode}.'
self.mode = mode
self._eps = label_smooth_val
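        # 'classy_vision' folds the 1/(1+eps) normalization into an
        # effective epsilon, since (delta + eps/K) / (1 + eps) equals
        # (1 - eps') * delta + eps'/K with eps' = eps / (1 + eps).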
if mode == 'classy_vision':
self._eps = label_smooth_val / (1 + label_smooth_val)
if mode == 'multi_label':
if not use_sigmoid:
from mmengine.logging import MMLogger
MMLogger.get_current_instance().warning(
'For multi-label tasks, please set `use_sigmoid=True` '
'to use binary cross entropy.')
self.smooth_label = self.multilabel_smooth_label
use_sigmoid = True if use_sigmoid is None else use_sigmoid
else:
self.smooth_label = self.original_smooth_label
use_sigmoid = False if use_sigmoid is None else use_sigmoid
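        # Delegate the actual computation: soft cross entropy when a
        # softmax is used, binary cross entropy when a sigmoid is used.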
self.ce = CrossEntropyLoss(
use_sigmoid=use_sigmoid,
use_soft=not use_sigmoid,
reduction=reduction,
class_weight=class_weight,
pos_weight=pos_weight)

    def generate_one_hot_like_label(self, label):
        """Take one-hot or index label vectors and compute one-hot-like
        float label vectors."""
# check if targets are inputted as class integers
if label.dim() == 1 or (label.dim() == 2 and label.shape[1] == 1):
label = convert_to_one_hot(label.view(-1, 1), self.num_classes)
return label.float()

    def original_smooth_label(self, one_hot_like_label):
        """Smooth labels as in the original paper: (1 - eps) * y + eps / K."""
assert self.num_classes > 0
smooth_label = one_hot_like_label * (1 - self._eps)
smooth_label += self._eps / self.num_classes
return smooth_label

    def multilabel_smooth_label(self, one_hot_like_label):
        """Smooth multi-label targets: positives become ``1 - eps`` and
        negatives become ``eps``."""
assert self.num_classes > 0
smooth_label = torch.full_like(one_hot_like_label, self._eps)
smooth_label.masked_fill_(one_hot_like_label > 0, 1 - self._eps)
return smooth_label

    def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
r"""Label smooth loss.
Args:
pred (torch.Tensor): The prediction with shape (N, \*).
label (torch.Tensor): The ground truth label of the prediction
with shape (N, \*).
weight (torch.Tensor, optional): Sample-wise loss weight with shape
(N, \*). Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The method used to reduce the
loss into a scalar. Options are "none", "mean" and "sum".
Defaults to None.
Returns:
torch.Tensor: Loss.
"""
if self.num_classes is not None:
assert self.num_classes == cls_score.shape[1], \
                f'num_classes should be equal to cls_score.shape[1], ' \
f'but got num_classes: {self.num_classes} and ' \
f'cls_score.shape[1]: {cls_score.shape[1]}'
else:
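            # Infer the class number from the score shape on the first call.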
self.num_classes = cls_score.shape[1]
one_hot_like_label = self.generate_one_hot_like_label(label=label)
        assert one_hot_like_label.shape == cls_score.shape, \
            f'LabelSmoothLoss requires output and target ' \
            f'to have the same shape, but got output.shape: ' \
            f'{cls_score.shape} and target.shape: ' \
            f'{one_hot_like_label.shape}'
smoothed_label = self.smooth_label(one_hot_like_label)
return self.loss_weight * self.ce.forward(
cls_score,
smoothed_label,
weight=weight,
avg_factor=avg_factor,
reduction_override=reduction_override,
**kwargs)
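

# Minimal usage sketch, for illustration only: it exercises the 'original'
# and 'multi_label' modes end to end. The shapes, class count and random
# data below are arbitrary assumptions, not values from the module itself.
if __name__ == '__main__':
    logits = torch.randn(8, 10)  # (N, C) classification scores
    index_labels = torch.randint(0, 10, (8, ))  # class indices, shape (N, )

    # 'original' mode: smoothed one-hot targets with soft cross entropy.
    loss_fn = LabelSmoothLoss(label_smooth_val=0.1, mode='original')
    print('original:', loss_fn(logits, index_labels))

    # 'multi_label' mode: multi-hot targets with sigmoid + binary CE.
    multi_hot = (torch.rand(8, 10) > 0.7).float()  # (N, C) multi-hot labels
    ml_loss_fn = LabelSmoothLoss(
        label_smooth_val=0.1, mode='multi_label', use_sigmoid=True)
    print('multi_label:', ml_loss_fn(logits, multi_hot))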