# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmdet.registry import MODELS
from .accuracy import accuracy
from .cross_entropy_loss import cross_entropy
from .utils import weight_reduce_loss


def seesaw_ce_loss(cls_score: Tensor,
                   labels: Tensor,
                   label_weights: Tensor,
                   cum_samples: Tensor,
                   num_classes: int,
                   p: float,
                   q: float,
                   eps: float,
                   reduction: str = 'mean',
                   avg_factor: Optional[int] = None) -> Tensor:
"""Calculate the Seesaw CrossEntropy loss. | |
Args: | |
cls_score (Tensor): The prediction with shape (N, C), | |
C is the number of classes. | |
labels (Tensor): The learning label of the prediction. | |
label_weights (Tensor): Sample-wise loss weight. | |
cum_samples (Tensor): Cumulative samples for each category. | |
num_classes (int): The number of classes. | |
p (float): The ``p`` in the mitigation factor. | |
q (float): The ``q`` in the compenstation factor. | |
eps (float): The minimal value of divisor to smooth | |
the computation of compensation factor | |
reduction (str, optional): The method used to reduce the loss. | |
avg_factor (int, optional): Average factor that is used to average | |
the loss. Defaults to None. | |
Returns: | |
Tensor: The calculated loss | |
""" | |
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor
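    # for a sample labelled i, class j is weighted by (N_j / N_i)^p when class
    # j has fewer accumulated samples than class i, and by 1 otherwise, so the
    # penalty pushed onto rarer classes is softened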
    if p > 0:
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
        mitigation_factor = sample_weights[labels.long(), :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor
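    # for a sample labelled i, class j is weighted by (sigma_j / sigma_i)^q
    # when its detached softmax score exceeds that of the ground-truth class,
    # and by 1 otherwise, re-emphasizing classes that are currently mis-ranked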
    if q > 0:
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[
            torch.arange(0, len(scores)).to(scores.device).long(),
            labels.long()]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

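    # adding log(weight) to the logits of the negative classes multiplies
    # their exponentials by the seesaw weights inside the softmax, which
    # realizes the re-weighted cross-entropy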
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if label_weights is not None:
        label_weights = label_weights.float()
    loss = weight_reduce_loss(
        loss, weight=label_weights, reduction=reduction, avg_factor=avg_factor)
    return loss


@MODELS.register_module()
class SeesawLoss(nn.Module):
""" | |
Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021) | |
arXiv: https://arxiv.org/abs/2008.10032 | |
Args: | |
use_sigmoid (bool, optional): Whether the prediction uses sigmoid | |
of softmax. Only False is supported. | |
p (float, optional): The ``p`` in the mitigation factor. | |
Defaults to 0.8. | |
q (float, optional): The ``q`` in the compenstation factor. | |
Defaults to 2.0. | |
num_classes (int, optional): The number of classes. | |
Default to 1203 for LVIS v1 dataset. | |
eps (float, optional): The minimal value of divisor to smooth | |
the computation of compensation factor | |
reduction (str, optional): The method that reduces the loss to a | |
scalar. Options are "none", "mean" and "sum". | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
return_dict (bool, optional): Whether return the losses as a dict. | |
Default to True. | |
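
    Example:
        A minimal usage sketch; the 3-class setup and the random inputs below
        are illustrative only, not part of the original interface:

        >>> import torch
        >>> loss_cls = SeesawLoss(num_classes=3, return_dict=True)
        >>> # C class logits plus 2 objectness logits per sample
        >>> cls_score = torch.rand(8, 3 + 2)
        >>> # label == C marks a background sample
        >>> labels = torch.randint(0, 3 + 1, (8,))
        >>> losses = loss_cls(cls_score, labels)
        >>> assert set(losses.keys()) == {
        ...     'loss_cls_classes', 'loss_cls_objectness'}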
""" | |
    def __init__(self,
                 use_sigmoid: bool = False,
                 p: float = 0.8,
                 q: float = 2.0,
                 num_classes: int = 1203,
                 eps: float = 1e-2,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 return_dict: bool = True) -> None:
        super().__init__()
        assert not use_sigmoid
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_dict = return_dict

        # 0 for pos, 1 for neg
        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category
        self.register_buffer(
            'cum_samples',
            torch.zeros(self.num_classes + 1, dtype=torch.float))
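        # registered as a buffer so the running counts follow ``.to()`` /
        # ``.cuda()`` and are stored in checkpoints alongside the model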

        # custom output channels of the classifier
        self.custom_cls_channels = True
        # custom activation of cls_score
        self.custom_activation = True
        # custom accuracy of the classifier
        self.custom_accuracy = True

    def _split_cls_score(self, cls_score: Tensor) -> Tuple[Tensor, Tensor]:
        """Split cls_score.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).

        Returns:
            Tuple[Tensor, Tensor]: The score for classes and objectness,
            respectively.
        """
        # split cls_score to cls_score_classes and cls_score_objectness
        assert cls_score.size(-1) == self.num_classes + 2
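        # the first ``num_classes`` channels are the class logits; the last
        # two channels hold the (foreground, background) objectness logits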
        cls_score_classes = cls_score[..., :-2]
        cls_score_objectness = cls_score[..., -2:]
        return cls_score_classes, cls_score_objectness

    def get_cls_channels(self, num_classes: int) -> int:
        """Get custom classification channels.

        Args:
            num_classes (int): The number of classes.

        Returns:
            int: The custom classification channels.
        """
        assert num_classes == self.num_classes
        return num_classes + 2

    def get_activation(self, cls_score: Tensor) -> Tensor:
        """Get custom activation of cls_score.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).

        Returns:
            Tensor: The custom activation of cls_score with shape
            (N, C + 1).
        """
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        score_classes = F.softmax(cls_score_classes, dim=-1)
        score_objectness = F.softmax(cls_score_objectness, dim=-1)
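        # objectness softmax: channel 0 is the foreground probability,
        # channel 1 the background probability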
        score_pos = score_objectness[..., [0]]
        score_neg = score_objectness[..., [1]]
        score_classes = score_classes * score_pos
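        # class probabilities are conditioned on the box being foreground, so
        # they are scaled by the foreground probability before the background
        # probability is appended as the extra (C + 1)-th channel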
        scores = torch.cat([score_classes, score_neg], dim=-1)
        return scores

    def get_accuracy(self, cls_score: Tensor,
                     labels: Tensor) -> Dict[str, Tensor]:
        """Get custom accuracy w.r.t. cls_score and labels.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).
            labels (Tensor): The learning label of the prediction.

        Returns:
            Dict[str, Tensor]: The accuracy for objectness and classes,
            respectively.
        """
        pos_inds = labels < self.num_classes
        obj_labels = (labels == self.num_classes).long()
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        acc_objectness = accuracy(cls_score_objectness, obj_labels)
        acc_classes = accuracy(cls_score_classes[pos_inds], labels[pos_inds])
        acc = dict()
        acc['acc_objectness'] = acc_objectness
        acc['acc_classes'] = acc_classes
        return acc

    def forward(
            self,
            cls_score: Tensor,
            labels: Tensor,
            label_weights: Optional[Tensor] = None,
            avg_factor: Optional[int] = None,
            reduction_override: Optional[str] = None
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Forward function.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).
            labels (Tensor): The learning label of the prediction.
            label_weights (Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to override
                the original reduction of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            Tensor | Dict[str, Tensor]:
                if return_dict == False: The calculated loss.
                if return_dict == True: The dict of calculated losses
                for objectness and classes, respectively.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes + 2
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()

        # accumulate the samples for each category
        unique_labels = labels.unique()
        for u_l in unique_labels:
            inds_ = labels == u_l.item()
            self.cum_samples[u_l] += inds_.sum()
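        # the counts accumulate over the whole training run; this module never
        # resets them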

        if label_weights is not None:
            label_weights = label_weights.float()
        else:
            label_weights = labels.new_ones(labels.size(), dtype=torch.float)

        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        # calculate loss_cls_classes (only need pos samples)
        if pos_inds.sum() > 0:
            loss_cls_classes = self.loss_weight * self.cls_criterion(
                cls_score_classes[pos_inds], labels[pos_inds],
                label_weights[pos_inds], self.cum_samples[:self.num_classes],
                self.num_classes, self.p, self.q, self.eps, reduction,
                avg_factor)
        else:
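            # no positive samples in this batch: produce a zero-valued loss
            # that still keeps the computation graph connected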
            loss_cls_classes = cls_score_classes[pos_inds].sum()
        # calculate loss_cls_objectness
        loss_cls_objectness = self.loss_weight * cross_entropy(
            cls_score_objectness, obj_labels, label_weights, reduction,
            avg_factor)

        if self.return_dict:
            loss_cls = dict()
            loss_cls['loss_cls_objectness'] = loss_cls_objectness
            loss_cls['loss_cls_classes'] = loss_cls_classes
        else:
            loss_cls = loss_cls_classes + loss_cls_objectness
        return loss_cls