# Copyright (c) OpenMMLab. All rights reserved.
# migrate from mmdetection with modifications
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpretrain.registry import MODELS
from .utils import weight_reduce_loss


def seesaw_ce_loss(cls_score,
                   labels,
                   weight,
                   cum_samples,
                   num_classes,
                   p,
                   q,
                   eps,
                   reduction='mean',
                   avg_factor=None):
    """Calculate the Seesaw CrossEntropy loss.

    Args:
        cls_score (torch.Tensor): The prediction with shape (N, C),
            C is the number of classes.
        labels (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor): Sample-wise loss weight.
        cum_samples (torch.Tensor): Cumulative samples for each category.
        num_classes (int): The number of classes.
        p (float): The ``p`` in the mitigation factor.
        q (float): The ``q`` in the compensation factor.
        eps (float): The minimal value of divisor to smooth
            the computation of compensation factor.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor
    if p > 0:
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        # M_{ij} in the paper
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
        mitigation_factor = sample_weights[labels.long(), :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor
    if q > 0:
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[
            torch.arange(0, len(scores)).to(scores.device).long(),
            labels.long()]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

    # only the logits of negative (non-target) classes are rescaled
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss
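
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal direct call
# to ``seesaw_ce_loss``. The batch size, class count, cumulative sample counts
# and the p/q/eps values below are assumptions chosen only to show the
# expected input shapes; in practice ``SeesawLoss`` (below) maintains
# ``cum_samples`` for you.
def _seesaw_ce_loss_example():
    num_classes = 5
    cls_score = torch.randn(8, num_classes)  # raw logits, shape (N, C)
    labels = torch.randint(0, num_classes, (8, ))  # ground-truth labels
    weight = cls_score.new_ones(8)  # uniform sample-wise weights
    # assumed per-class cumulative sample counts (long-tailed distribution)
    cum_samples = torch.tensor([100., 50., 10., 5., 1.])
    return seesaw_ce_loss(
        cls_score,
        labels,
        weight,
        cum_samples,
        num_classes,
        p=0.8,
        q=2.0,
        eps=1e-2,
        reduction='mean')
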
@MODELS.register_module()
class SeesawLoss(nn.Module):
    """Implementation of seesaw loss.

    Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
    <https://arxiv.org/abs/2008.10032>`_

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid or softmax.
            Only False is supported. Defaults to False.
        p (float): The ``p`` in the mitigation factor. Defaults to 0.8.
        q (float): The ``q`` in the compensation factor. Defaults to 2.0.
        num_classes (int): The number of classes. Defaults to 1000 for the
            ImageNet dataset.
        eps (float): The minimal value of divisor to smooth
            the computation of compensation factor. Defaults to 1e-2.
        reduction (str): The method that reduces the loss to a scalar.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float): The weight of the loss. Defaults to 1.0.
    """

    def __init__(self,
                 use_sigmoid=False,
                 p=0.8,
                 q=2.0,
                 num_classes=1000,
                 eps=1e-2,
                 reduction='mean',
                 loss_weight=1.0):
        super(SeesawLoss, self).__init__()
        assert not use_sigmoid, '`use_sigmoid` is not supported'
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category
        self.register_buffer(
            'cum_samples', torch.zeros(self.num_classes, dtype=torch.float))

    def forward(self,
                cls_score,
                labels,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C).
            labels (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum". Defaults to None.

        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            f'The `reduction_override` should be one of (None, "none", ' \
            f'"mean", "sum"), but got "{reduction_override}".'
        assert cls_score.size(0) == labels.view(-1).size(0), \
            f'Expected `labels` shape [{cls_score.size(0)}], ' \
            f'but got {list(labels.size())}'
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes, \
            f'The channel number of output ({cls_score.size(-1)}) does ' \
            f'not match the `num_classes` of seesaw loss ({self.num_classes}).'

        # accumulate the samples for each category
        unique_labels = labels.unique()
        for u_l in unique_labels:
            inds_ = labels == u_l.item()
            self.cum_samples[u_l] += inds_.sum()

        if weight is not None:
            weight = weight.float()
        else:
            weight = labels.new_ones(labels.size(), dtype=torch.float)

        # calculate the classification loss
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score, labels, weight, self.cum_samples, self.num_classes,
            self.p, self.q, self.eps, reduction, avg_factor)
        return loss_cls
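
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how ``SeesawLoss``
# might be instantiated and called. The class count, batch size and the
# registry-based construction mentioned in the comment are assumptions for
# demonstration only.
if __name__ == '__main__':
    criterion = SeesawLoss(p=0.8, q=2.0, num_classes=10)
    # an equivalent instance could presumably be built from a config, e.g.
    # MODELS.build(dict(type='SeesawLoss', num_classes=10))
    cls_score = torch.randn(16, 10)  # raw logits from a classification head
    labels = torch.randint(0, 10, (16, ))  # ground-truth labels
    loss = criterion(cls_score, labels)  # cum_samples is updated in forward
    print(loss)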