# Source: RSPrompter/mmpretrain/models/losses/label_smooth_loss.py
# (uploaded by KyanChen in "Upload 303 files", commit 4d0eb62, 7.17 kB)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
from .cross_entropy_loss import CrossEntropyLoss
from .utils import convert_to_one_hot
@MODELS.register_module()
class LabelSmoothLoss(nn.Module):
    r"""Initializer for the label smoothed cross entropy loss.

    Refers to `Rethinking the Inception Architecture for Computer Vision
    <https://arxiv.org/abs/1512.00567>`_

    This decreases the gap between output scores and encourages
    generalization. Labels provided to forward can be one-hot like vectors
    (NxC) or class indices (N or Nx1).

    It also accepts linear combinations of one-hot like labels produced by
    mixup or cutmix, except for the multi-label task.

    Args:
        label_smooth_val (float): The degree of label smoothing. Must be a
            float in [0, 1).
        num_classes (int, optional): Number of classes. Defaults to None,
            in which case it is inferred from ``cls_score.shape[1]`` on the
            first call to :meth:`forward`.
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid or
            softmax. Defaults to None, which means to use sigmoid in
            "multi_label" mode and not use it in other modes.
        mode (str): Refers to notes. Options are 'original', 'classy_vision',
            'multi_label'. Defaults to 'original'.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (optional): Passed unchanged to the internal
            :class:`CrossEntropyLoss` (presumably per-class weights; see
            that class for the expected format). Defaults to None.
        pos_weight (optional): Passed unchanged to the internal
            :class:`CrossEntropyLoss` (presumably only meaningful with
            sigmoid; see that class). Defaults to None.

    Notes:
        - if the mode is **"original"**, this will use the same label smooth
          method as the original paper as:

          .. math::
              (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K}

          where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is the
          ``num_classes`` and :math:`\delta_{k, y}` is Dirac delta, which
          equals 1 for :math:`k=y` and 0 otherwise.

        - if the mode is **"classy_vision"**, this will use the same label
          smooth method as the facebookresearch/ClassyVision repo as:

          .. math::
              \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon}

        - if the mode is **"multi_label"**, this will accept labels from a
          multi-label task and smooth them as:

          .. math::
              (1-2\epsilon)\delta_{k, y} + \epsilon
    """

    def __init__(self,
                 label_smooth_val,
                 num_classes=None,
                 use_sigmoid=None,
                 mode='original',
                 reduction='mean',
                 loss_weight=1.0,
                 class_weight=None,
                 pos_weight=None):
        super().__init__()
        self.num_classes = num_classes
        self.loss_weight = loss_weight

        assert (isinstance(label_smooth_val, float)
                and 0 <= label_smooth_val < 1), \
            f'LabelSmoothLoss accepts a float label_smooth_val ' \
            f'over [0, 1), but gets {label_smooth_val}'
        self.label_smooth_val = label_smooth_val

        accept_reduction = {'none', 'mean', 'sum'}
        assert reduction in accept_reduction, \
            f'LabelSmoothLoss supports reduction {accept_reduction}, ' \
            f'but gets {reduction}.'
        self.reduction = reduction

        accept_mode = {'original', 'classy_vision', 'multi_label'}
        assert mode in accept_mode, \
            f'LabelSmoothLoss supports mode {accept_mode}, but gets {mode}.'
        self.mode = mode

        self._eps = label_smooth_val
        if mode == 'classy_vision':
            # Rescaling eps lets original_smooth_label produce
            # (delta + eps/K) / (1 + eps), the ClassyVision formula.
            self._eps = label_smooth_val / (1 + label_smooth_val)

        if mode == 'multi_label':
            # Warn only when sigmoid was explicitly disabled: ``None`` is
            # resolved to ``True`` just below, so warning on it is spurious.
            if use_sigmoid is not None and not use_sigmoid:
                from mmengine.logging import MMLogger
                MMLogger.get_current_instance().warning(
                    'For multi-label tasks, please set `use_sigmoid=True` '
                    'to use binary cross entropy.')
            self.smooth_label = self.multilabel_smooth_label
            use_sigmoid = True if use_sigmoid is None else use_sigmoid
        else:
            self.smooth_label = self.original_smooth_label
            use_sigmoid = False if use_sigmoid is None else use_sigmoid

        self.ce = CrossEntropyLoss(
            use_sigmoid=use_sigmoid,
            # Smoothed targets are probability vectors rather than class
            # indices, so the non-sigmoid path needs soft-target CE.
            use_soft=not use_sigmoid,
            reduction=reduction,
            class_weight=class_weight,
            pos_weight=pos_weight)

    def generate_one_hot_like_label(self, label):
        """Convert index labels to one-hot; pass one-hot-like labels through.

        Args:
            label (torch.Tensor): Either class indices with shape (N,) or
                (N, 1), or one-hot-like vectors with shape (N, C).

        Returns:
            torch.Tensor: A float tensor of one-hot-like labels. Index labels
            are expanded with ``convert_to_one_hot`` using
            ``self.num_classes``.
        """
        # A 1-D tensor or a single column is treated as class indices.
        if label.dim() == 1 or (label.dim() == 2 and label.shape[1] == 1):
            label = convert_to_one_hot(label.view(-1, 1), self.num_classes)
        return label.float()

    def original_smooth_label(self, one_hot_like_label):
        """Smooth as ``label * (1 - eps) + eps / num_classes``.

        Used for both the "original" and "classy_vision" modes. They differ
        only in how ``self._eps`` is computed in ``__init__``.
        """
        assert self.num_classes > 0
        smooth_label = one_hot_like_label * (1 - self._eps)
        smooth_label += self._eps / self.num_classes
        return smooth_label

    def multilabel_smooth_label(self, one_hot_like_label):
        """Smooth multi-label targets.

        Positive entries (``> 0``) become ``1 - eps``; all other entries
        become ``eps``.
        """
        assert self.num_classes > 0
        smooth_label = torch.full_like(one_hot_like_label, self._eps)
        smooth_label.masked_fill_(one_hot_like_label > 0, 1 - self._eps)
        return smooth_label

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        r"""Label smooth loss.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C).
            label (torch.Tensor): The ground truth label of the prediction,
                either class indices with shape (N,) or (N, 1), or
                one-hot-like vectors with shape (N, C).
            weight (torch.Tensor, optional): Sample-wise loss weight with shape
                (N, \*). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss into a scalar. Options are "none", "mean" and "sum".
                Defaults to None.
            **kwargs: Forwarded to the internal :class:`CrossEntropyLoss`.

        Returns:
            torch.Tensor: Loss, scaled by ``loss_weight``.

        Note:
            If ``num_classes`` was None at construction, it is set from
            ``cls_score.shape[1]`` on the first call. Later calls must then
            use the same number of classes.
        """
        if self.num_classes is not None:
            assert self.num_classes == cls_score.shape[1], \
                f'num_classes should equal to cls_score.shape[1], ' \
                f'but got num_classes: {self.num_classes} and ' \
                f'cls_score.shape[1]: {cls_score.shape[1]}'
        else:
            self.num_classes = cls_score.shape[1]

        one_hot_like_label = self.generate_one_hot_like_label(label=label)
        assert one_hot_like_label.shape == cls_score.shape, \
            f'LabelSmoothLoss requires output and target ' \
            f'to be same shape, but got output.shape: {cls_score.shape} ' \
            f'and target.shape: {one_hot_like_label.shape}'

        smoothed_label = self.smooth_label(one_hot_like_label)
        # Call the module rather than ``.forward`` so nn.Module hooks run.
        return self.loss_weight * self.ce(
            cls_score,
            smoothed_label,
            weight=weight,
            avg_factor=avg_factor,
            reduction_override=reduction_override,
            **kwargs)