desco / maskrcnn_benchmark /layers /sigmoid_focal_loss.py
zdou0830's picture
desco
749745d
raw
history blame
No virus
7.88 kB
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from maskrcnn_benchmark import _C
# TODO: Use JIT to replace CUDA implementation in the future.
class _SigmoidFocalLoss(Function):
@staticmethod
def forward(ctx, logits, targets, gamma, alpha):
ctx.save_for_backward(logits, targets)
num_classes = logits.shape[1]
ctx.num_classes = num_classes
ctx.gamma = gamma
ctx.alpha = alpha
losses = _C.sigmoid_focalloss_forward(logits, targets, num_classes, gamma, alpha)
return losses
@staticmethod
@once_differentiable
def backward(ctx, d_loss):
logits, targets = ctx.saved_tensors
num_classes = ctx.num_classes
gamma = ctx.gamma
alpha = ctx.alpha
d_loss = d_loss.contiguous()
d_logits = _C.sigmoid_focalloss_backward(logits, targets, d_loss, num_classes, gamma, alpha)
return d_logits, None, None, None, None
sigmoid_focal_loss_cuda = _SigmoidFocalLoss.apply
def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
num_classes = logits.shape[1]
dtype = targets.dtype
device = targets.device
class_range = torch.arange(1, num_classes + 1, dtype=dtype, device=device).unsqueeze(0)
t = targets.unsqueeze(1)
p = torch.sigmoid(logits)
term1 = (1 - p) ** gamma * torch.log(p)
term2 = p**gamma * torch.log(1 - p)
return -(t == class_range).float() * term1 * alpha - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha)
class SigmoidFocalLoss(nn.Module):
def __init__(self, gamma, alpha):
super(SigmoidFocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, logits, targets):
# Protect for the case where there are no boxes!
if targets.nelement() == 0:
return torch.as_tensor(0, device=logits.device)
if logits.is_cuda:
loss_func = sigmoid_focal_loss_cuda
else:
loss_func = sigmoid_focal_loss_cpu
loss = loss_func(logits, targets, self.gamma, self.alpha)
return loss.sum()
def __repr__(self):
tmpstr = self.__class__.__name__ + "("
tmpstr += "gamma=" + str(self.gamma)
tmpstr += ", alpha=" + str(self.alpha)
tmpstr += ")"
return tmpstr
def token_sigmoid_softmax_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
# Another modification is that because we use the cross entropy version, there is no frequent or not frequent class.
# So we temporarily retired the design of alpha.
assert targets.dim() == 3
assert pred_logits.dim() == 3 # batch x from x to
# reprocess target to become probability map ready for softmax
targets = targets.float()
target_num = targets.sum(-1) + 1e-8 # numerical stability
targets = targets / target_num.unsqueeze(-1) # T(x)
if text_mask is not None:
# reserve the last token for non object
assert text_mask.dim() == 2
text_mask[:, -1] = 1
text_mask = (text_mask > 0).unsqueeze(1).repeat(1, pred_logits.size(1), 1) # copy along the image channel
pred_logits = pred_logits.masked_fill(~text_mask, -1000000) # softmax
out_prob = pred_logits.softmax(-1)
filled_targets = targets.clone()
filled_targets[filled_targets == 0] = 1.0
weight = torch.clamp(targets - out_prob, min=0.001) / filled_targets
weight = torch.pow(weight, gamma) # weight = torch.pow(torch.clamp(target - out_prob, min=0.01), gamma)
loss_ce = (
-targets * weight * pred_logits.log_softmax(-1)
) # only those positives with positive target_sim will have losses.
return loss_ce
def token_sigmoid_binary_focal_loss_v2(pred_logits, targets, alpha, gamma, text_mask=None):
assert targets.dim() == 3
assert pred_logits.dim() == 3 # batch x from x to
if text_mask is not None:
assert text_mask.dim() == 2
# We convert everything into binary
out_prob = pred_logits.sigmoid()
out_prob_neg_pos = torch.stack([1 - out_prob, out_prob], dim=-1) + 1e-8 # batch x boxes x 256 x 2
weight = torch.pow(-out_prob_neg_pos + 1.0, gamma)
focal_zero = -weight[:, :, :, 0] * torch.log(out_prob_neg_pos[:, :, :, 0]) * (1 - alpha) # negative class
focal_one = -weight[:, :, :, 1] * torch.log(out_prob_neg_pos[:, :, :, 1]) * alpha # positive class
focal = torch.stack([focal_zero, focal_one], dim=-1)
loss_ce = torch.gather(focal, index=targets.long().unsqueeze(-1), dim=-1)
return loss_ce
def token_sigmoid_binary_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
# binary version of focal loss
# copied from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor with the reduction option applied.
"""
assert targets.dim() == 3
assert pred_logits.dim() == 3 # batch x from x to
bs, n, _ = pred_logits.shape
if text_mask is not None:
assert text_mask.dim() == 2
text_mask = (text_mask > 0).unsqueeze(1)
text_mask = text_mask.repeat(1, pred_logits.size(1), 1) # copy along the image channel dimension
pred_logits = torch.masked_select(pred_logits, text_mask)
if targets.size(-1) > text_mask.size(-1):
targets = targets[:, :, : text_mask.size(-1)]
targets = torch.masked_select(targets, text_mask)
# print(pred_logits.shape)
# print(targets.shape)
p = torch.sigmoid(pred_logits)
ce_loss = F.binary_cross_entropy_with_logits(pred_logits, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
return loss
class TokenSigmoidFocalLoss(nn.Module):
def __init__(self, alpha, gamma):
super(TokenSigmoidFocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, logits, targets, text_masks=None, version="binary", **kwargs):
# Protect for the case where there are no boxes!
if targets.nelement() == 0:
return torch.as_tensor(0, device=logits.device)
if version == "binary":
loss_func = token_sigmoid_binary_focal_loss
elif version == "softmax":
loss_func = token_sigmoid_softmax_focal_loss
elif version == "binaryv2":
loss_func = token_sigmoid_binary_focal_loss_v2
else:
raise NotImplementedError
loss = loss_func(logits, targets, self.alpha, self.gamma, text_masks, **kwargs)
return loss.sum()
def __repr__(self):
tmpstr = self.__class__.__name__ + "("
tmpstr += "gamma=" + str(self.gamma)
tmpstr += ", alpha=" + str(self.alpha)
tmpstr += ")"
return tmpstr