# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from mmpretrain.registry import MODELS
from .utils import weight_reduce_loss


def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        label (torch.Tensor): The gt label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # element-wise losses
    loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
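

# Illustrative sketch, not part of the upstream file; assumes PyTorch is
# available. It shows the expected shapes for the hard-label case: integer
# labels of shape (N, ) and logits of shape (N, C), with 'none' vs. 'mean'
# reduction.
def _demo_cross_entropy():
    import torch
    pred = torch.randn(4, 10)            # logits for 4 samples, 10 classes
    label = torch.randint(0, 10, (4, ))  # hard integer labels
    loss_none = cross_entropy(pred, label, reduction='none')  # shape (4, )
    loss_mean = cross_entropy(pred, label)                    # scalar
    assert loss_none.shape == (4, ) and loss_mean.dim() == 0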


def soft_cross_entropy(pred,
                       label,
                       weight=None,
                       reduction='mean',
                       class_weight=None,
                       avg_factor=None):
    """Calculate the Soft CrossEntropy loss. The label can be float.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        label (torch.Tensor): The gt label of the prediction with shape
            (N, C). When using "mixup", the label can be float.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str): The method used to reduce the loss.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Defaults to None.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # element-wise losses
    loss = -label * F.log_softmax(pred, dim=-1)
    if class_weight is not None:
        loss *= class_weight
    loss = loss.sum(dim=-1)

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
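

# Illustrative sketch, not part of the upstream file; assumes PyTorch is
# available. Builds a mixup-style soft target by blending each sample's
# one-hot vector with its neighbour's, then feeds the float target in.
def _demo_soft_cross_entropy():
    import torch
    pred = torch.randn(4, 10)
    hard = torch.randint(0, 10, (4, ))
    onehot = F.one_hot(hard, num_classes=10).float()
    # blend with the batch rolled by one position (mixup-like soft labels)
    soft = 0.7 * onehot + 0.3 * onehot.roll(1, dims=0)
    loss = soft_cross_entropy(pred, soft)  # scalar mean over the batch
    assert loss.dim() == 0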


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         pos_weight=None):
    r"""Calculate the binary CrossEntropy loss with logits.

    Args:
        pred (torch.Tensor): The prediction with shape (N, \*).
        label (torch.Tensor): The gt label with shape (N, \*).
        weight (torch.Tensor, optional): Element-wise weight of loss with
            shape (N, ). Defaults to None.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none',
            the loss has the same shape as pred and label. Defaults to
            'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Defaults to None.
        pos_weight (torch.Tensor, optional): The positive weight for each
            class with shape (C), C is the number of classes. Defaults to
            None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # Ensure that the size of class_weight is consistent with pred and label
    # to avoid automatic broadcasting.
    assert pred.dim() == label.dim()

    if class_weight is not None:
        N = pred.size()[0]
        class_weight = class_weight.repeat(N, 1)
    loss = F.binary_cross_entropy_with_logits(
        pred,
        label.float(),  # only accepts float type tensor
        weight=class_weight,
        pos_weight=pos_weight,
        reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            weight = weight.reshape(-1, 1)
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
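

# Illustrative sketch, not part of the upstream file; assumes PyTorch is
# available. Demonstrates the multi-label case: pred and label share the
# shape (N, C), and pos_weight carries one entry per class.
def _demo_binary_cross_entropy():
    import torch
    pred = torch.randn(4, 5)             # multi-label logits
    label = torch.randint(0, 2, (4, 5))  # 0/1 target per class
    pos_weight = torch.full((5, ), 2.0)  # up-weight positive targets
    loss = binary_cross_entropy(pred, label, pos_weight=pos_weight)
    assert loss.dim() == 0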


@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
    """Cross entropy loss.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid instead of
            softmax. Defaults to False.
        use_soft (bool): Whether to use the soft version of CrossEntropyLoss.
            Defaults to False.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (List[float], optional): The weight for each class with
            shape (C), C is the number of classes. Defaults to None.
        pos_weight (List[float], optional): The positive weight for each
            class with shape (C), C is the number of classes. Only enabled in
            BCE loss when ``use_sigmoid`` is True. Defaults to None.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_soft=False,
                 reduction='mean',
                 loss_weight=1.0,
                 class_weight=None,
                 pos_weight=None):
        super(CrossEntropyLoss, self).__init__()
        self.use_sigmoid = use_sigmoid
        self.use_soft = use_soft
        assert not (
            self.use_soft and self.use_sigmoid
        ), 'use_sigmoid and use_soft could not be set simultaneously'

        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.pos_weight = pos_weight

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_soft:
            self.cls_criterion = soft_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None

        # only BCE loss has pos_weight
        if self.pos_weight is not None and self.use_sigmoid:
            pos_weight = cls_score.new_tensor(self.pos_weight)
            kwargs.update({'pos_weight': pos_weight})
        else:
            pos_weight = None

        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
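

# Illustrative sketch, not part of the upstream file; assumes PyTorch and
# mmpretrain are available. The module can be instantiated directly, or
# built from a config dict through the MODELS registry as is typical in
# OpenMMLab configs.
def _demo_cross_entropy_loss_module():
    import torch
    criterion = CrossEntropyLoss(loss_weight=1.0)  # softmax CE by default
    # registry equivalent: MODELS.build(dict(type='CrossEntropyLoss'))
    cls_score = torch.randn(8, 100)
    label = torch.randint(0, 100, (8, ))
    loss = criterion(cls_score, label)
    assert loss.dim() == 0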