# Copyright (c) OpenMMLab. All rights reserved.
# migrate from mmdetection with modifications
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmpretrain.registry import MODELS
from .utils import weight_reduce_loss
def seesaw_ce_loss(cls_score,
                   labels,
                   weight,
                   cum_samples,
                   num_classes,
                   p,
                   q,
                   eps,
                   reduction='mean',
                   avg_factor=None):
    """Compute the Seesaw CrossEntropy loss.

    Args:
        cls_score (torch.Tensor): The prediction with shape (N, C),
            C is the number of classes.
        labels (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor): Sample-wise loss weight.
        cum_samples (torch.Tensor): Cumulative samples for each category.
        num_classes (int): The number of classes.
        p (float): The ``p`` in the mitigation factor.
        q (float): The ``q`` in the compensation factor.
        eps (float): The minimal value of divisor to smooth
            the computation of compensation factor.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    onehot = F.one_hot(labels, num_classes)
    # Per-sample, per-class reweighting factors; start from all-ones.
    factors = cls_score.new_ones(onehot.size())

    # Mitigation factor M_{ij}: down-weight gradients from head classes
    # towards tail classes based on the cumulative sample ratio.
    if p > 0:
        ratio = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        mask = (ratio < 1.0).float()
        mitigation = ratio.pow(p) * mask + (1 - mask)
        factors = factors * mitigation[labels.long(), :]

    # Compensation factor: re-amplify classes whose predicted score
    # exceeds that of the ground-truth class (gradients are detached).
    if q > 0:
        probs = F.softmax(cls_score.detach(), dim=1)
        row_idx = torch.arange(0, len(probs)).to(probs.device).long()
        gt_probs = probs[row_idx, labels.long()]
        rel_probs = probs / gt_probs[:, None].clamp(min=eps)
        mask = (rel_probs > 1.0).float()
        factors = factors * (rel_probs.pow(q) * mask + (1 - mask))

    # Apply the seesaw weights as additive log-space offsets on the
    # negative classes only, then take a standard cross-entropy.
    adjusted_score = cls_score + factors.log() * (1 - onehot)
    loss = F.cross_entropy(
        adjusted_score, labels, weight=None, reduction='none')

    if weight is not None:
        weight = weight.float()
    return weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
@MODELS.register_module()
class SeesawLoss(nn.Module):
    """Implementation of seesaw loss.

    Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
    <https://arxiv.org/abs/2008.10032>`_

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid or softmax.
            Only False is supported. Defaults to False.
        p (float): The ``p`` in the mitigation factor.
            Defaults to 0.8.
        q (float): The ``q`` in the compensation factor.
            Defaults to 2.0.
        num_classes (int): The number of classes.
            Defaults to 1000 for the ImageNet dataset.
        eps (float): The minimal value of divisor to smooth
            the computation of compensation factor, default to 1e-2.
        reduction (str): The method that reduces the loss to a scalar.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float): The weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_sigmoid=False,
                 p=0.8,
                 q=2.0,
                 num_classes=1000,
                 eps=1e-2,
                 reduction='mean',
                 loss_weight=1.0):
        super().__init__()
        assert not use_sigmoid, '`use_sigmoid` is not supported'
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.cls_criterion = seesaw_ce_loss
        # Cumulative samples seen for each category; registered as a buffer
        # so it moves with the module's device and is saved in checkpoints.
        self.register_buffer('cum_samples',
                             torch.zeros(self.num_classes, dtype=torch.float))

    def forward(self,
                cls_score,
                labels,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C).
            labels (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum". Defaults to None,
                which falls back to ``self.reduction``.

        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            f'The `reduction_override` should be one of (None, "none", ' \
            f'"mean", "sum"), but get "{reduction_override}".'
        assert cls_score.size(0) == labels.view(-1).size(0), \
            f'Expected `labels` shape [{cls_score.size(0)}], ' \
            f'but got {list(labels.size())}'
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes, \
            f'The channel number of output ({cls_score.size(-1)}) does ' \
            f'not match the `num_classes` of seesaw loss ({self.num_classes}).'
        # Accumulate the samples for each category. A single bincount
        # replaces the original Python loop over unique labels.
        self.cum_samples += torch.bincount(
            labels.view(-1).long(),
            minlength=self.num_classes).to(self.cum_samples)
        if weight is not None:
            weight = weight.float()
        else:
            weight = labels.new_ones(labels.size(), dtype=torch.float)
        # calculate loss_cls_classes
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score, labels, weight, self.cum_samples, self.num_classes,
            self.p, self.q, self.eps, reduction, avg_factor)
        return loss_cls