import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from maskrcnn_benchmark import _C
# TODO: Use JIT to replace CUDA implementation in the future.
class _SigmoidFocalLoss(Function):
    @staticmethod
    def forward(ctx, logits, targets, gamma, alpha):
        ctx.save_for_backward(logits, targets)
        num_classes = logits.shape[1]
        ctx.num_classes = num_classes
        ctx.gamma = gamma
        ctx.alpha = alpha

        losses = _C.sigmoid_focalloss_forward(logits, targets, num_classes, gamma, alpha)
        return losses

    @staticmethod
    @once_differentiable
    def backward(ctx, d_loss):
        logits, targets = ctx.saved_tensors
        num_classes = ctx.num_classes
        gamma = ctx.gamma
        alpha = ctx.alpha
        d_loss = d_loss.contiguous()
        d_logits = _C.sigmoid_focalloss_backward(logits, targets, d_loss, num_classes, gamma, alpha)
        return d_logits, None, None, None, None

sigmoid_focal_loss_cuda = _SigmoidFocalLoss.apply

def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
    num_classes = logits.shape[1]
    dtype = targets.dtype
    device = targets.device
    class_range = torch.arange(1, num_classes + 1, dtype=dtype, device=device).unsqueeze(0)

    t = targets.unsqueeze(1)
    p = torch.sigmoid(logits)
    term1 = (1 - p) ** gamma * torch.log(p)
    term2 = p**gamma * torch.log(1 - p)
    return -(t == class_range).float() * term1 * alpha - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha)
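
# Illustrative sketch, not part of the original module: the CPU path expects integer
# targets holding 1-based class indices, where 0 means background and negative values are
# ignored (the (t >= 0) mask above). The shapes and hyperparameters below are assumptions.
def _example_sigmoid_focal_loss_cpu():
    logits = torch.randn(8, 80)                        # 8 anchors, 80 classes
    targets = torch.randint(0, 81, (8,))               # 0 = background, 1..80 = class labels
    per_element = sigmoid_focal_loss_cpu(logits, targets, gamma=2.0, alpha=0.25)
    return per_element.sum()                           # per_element has shape (8, 80)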

class SigmoidFocalLoss(nn.Module):
    def __init__(self, gamma, alpha):
        super(SigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        # Protect for the case where there are no boxes!
        if targets.nelement() == 0:
            return torch.as_tensor(0, device=logits.device)

        if logits.is_cuda:
            loss_func = sigmoid_focal_loss_cuda
        else:
            loss_func = sigmoid_focal_loss_cpu
        loss = loss_func(logits, targets, self.gamma, self.alpha)
        return loss.sum()

    def __repr__(self):
        tmpstr = self.__class__.__name__ + "("
        tmpstr += "gamma=" + str(self.gamma)
        tmpstr += ", alpha=" + str(self.alpha)
        tmpstr += ")"
        return tmpstr
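
# Usage sketch (an illustration, not part of the original file): the module dispatches to
# the CUDA kernel or the CPU fallback, guards against empty targets, and sums all
# per-element losses; callers usually normalize by the number of positive anchors.
def _example_sigmoid_focal_loss_module():
    loss_fn = SigmoidFocalLoss(gamma=2.0, alpha=0.25)
    logits = torch.randn(8, 80)
    targets = torch.randint(0, 81, (8,))
    return loss_fn(logits, targets)                    # scalar sum over anchors and classes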

def token_sigmoid_softmax_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
    # Because this is the cross-entropy (softmax) version, there is no notion of frequent
    # vs. infrequent classes, so the alpha weighting is (for now) not used here.
    assert targets.dim() == 3
    assert pred_logits.dim() == 3  # batch x from x to

    # Reprocess the targets into a probability map T(x), ready for the softmax.
    targets = targets.float()
    target_num = targets.sum(-1) + 1e-8  # for numerical stability
    targets = targets / target_num.unsqueeze(-1)  # T(x)

    if text_mask is not None:
        # reserve the last token for non-object
        assert text_mask.dim() == 2
        text_mask[:, -1] = 1
        text_mask = (text_mask > 0).unsqueeze(1).repeat(1, pred_logits.size(1), 1)  # copy along the image channel
        pred_logits = pred_logits.masked_fill(~text_mask, -1000000)  # mask out invalid text tokens before the softmax

    out_prob = pred_logits.softmax(-1)

    filled_targets = targets.clone()
    filled_targets[filled_targets == 0] = 1.0

    weight = torch.clamp(targets - out_prob, min=0.001) / filled_targets
    weight = torch.pow(weight, gamma)  # weight = torch.pow(torch.clamp(target - out_prob, min=0.01), gamma)

    loss_ce = (
        -targets * weight * pred_logits.log_softmax(-1)
    )  # only those positives with a positive target_sim will have losses.
    return loss_ce
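
# Illustrative sketch (an assumption, not part of the original file): targets are binary
# box-to-token alignment maps; the function normalizes each row into a probability map, so
# a box aligned to several tokens spreads its target mass across them.
def _example_token_softmax_focal_loss():
    pred_logits = torch.randn(2, 4, 16)                # 2 images, 4 boxes, 16 text tokens
    targets = torch.zeros(2, 4, 16)
    targets[:, :, 2] = 1.0
    targets[:, :, 5] = 1.0                             # each box aligned to tokens 2 and 5
    text_masks = torch.ones(2, 16, dtype=torch.long)
    loss = token_sigmoid_softmax_focal_loss(pred_logits, targets, alpha=0.25, gamma=2.0, text_mask=text_masks)
    return loss.sum()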

def token_sigmoid_binary_focal_loss_v2(pred_logits, targets, alpha, gamma, text_mask=None):
    assert targets.dim() == 3
    assert pred_logits.dim() == 3  # batch x from x to
    if text_mask is not None:
        assert text_mask.dim() == 2

    # We convert everything into binary
    out_prob = pred_logits.sigmoid()
    out_prob_neg_pos = torch.stack([1 - out_prob, out_prob], dim=-1) + 1e-8  # batch x boxes x 256 x 2
    weight = torch.pow(-out_prob_neg_pos + 1.0, gamma)
    focal_zero = -weight[:, :, :, 0] * torch.log(out_prob_neg_pos[:, :, :, 0]) * (1 - alpha)  # negative class
    focal_one = -weight[:, :, :, 1] * torch.log(out_prob_neg_pos[:, :, :, 1]) * alpha  # positive class
    focal = torch.stack([focal_zero, focal_one], dim=-1)
    loss_ce = torch.gather(focal, index=targets.long().unsqueeze(-1), dim=-1)
    return loss_ce
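
# Sketch of the v2 variant (illustrative, not from the original file): targets here are hard
# 0/1 labels used as gather indices, so each element picks either the negative-class or the
# positive-class focal term.
def _example_token_binary_focal_loss_v2():
    pred_logits = torch.randn(2, 4, 16)
    targets = torch.zeros(2, 4, 16)
    targets[:, :, 2] = 1.0
    loss = token_sigmoid_binary_focal_loss_v2(pred_logits, targets, alpha=0.25, gamma=2.0)
    return loss.sum()                                  # loss has shape (2, 4, 16, 1) before the sum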

def token_sigmoid_binary_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
    # binary version of focal loss
    # copied from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        pred_logits: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as pred_logits. Stores the binary
            classification label for each element in pred_logits
            (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
            positive vs. negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
            balance easy vs. hard examples.
    Returns:
        Loss tensor with one entry per element (no reduction applied).
    """
    assert targets.dim() == 3
    assert pred_logits.dim() == 3  # batch x from x to

    bs, n, _ = pred_logits.shape
    if text_mask is not None:
        assert text_mask.dim() == 2
        text_mask = (text_mask > 0).unsqueeze(1)
        text_mask = text_mask.repeat(1, pred_logits.size(1), 1)  # copy along the image channel dimension
        pred_logits = torch.masked_select(pred_logits, text_mask)
        if targets.size(-1) > text_mask.size(-1):
            targets = targets[:, :, : text_mask.size(-1)]
        targets = torch.masked_select(targets, text_mask)

    p = torch.sigmoid(pred_logits)
    ce_loss = F.binary_cross_entropy_with_logits(pred_logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss

class TokenSigmoidFocalLoss(nn.Module):
    def __init__(self, alpha, gamma):
        super(TokenSigmoidFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets, text_masks=None, version="binary", **kwargs):
        # Protect for the case where there are no boxes!
        if targets.nelement() == 0:
            return torch.as_tensor(0, device=logits.device)

        if version == "binary":
            loss_func = token_sigmoid_binary_focal_loss
        elif version == "softmax":
            loss_func = token_sigmoid_softmax_focal_loss
        elif version == "binaryv2":
            loss_func = token_sigmoid_binary_focal_loss_v2
        else:
            raise NotImplementedError
        loss = loss_func(logits, targets, self.alpha, self.gamma, text_masks, **kwargs)
        return loss.sum()

    def __repr__(self):
        tmpstr = self.__class__.__name__ + "("
        tmpstr += "gamma=" + str(self.gamma)
        tmpstr += ", alpha=" + str(self.alpha)
        tmpstr += ")"
        return tmpstr
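
# Usage sketch (illustrative, not part of the original file): the "binary" version masks
# both logits and targets with text_masks before computing the per-token focal loss and
# then sums everything. Shapes and hyperparameters below are assumptions.
def _example_token_sigmoid_focal_loss_module():
    loss_fn = TokenSigmoidFocalLoss(alpha=0.25, gamma=2.0)
    logits = torch.randn(2, 5, 256)                    # 2 images, 5 boxes, 256 text tokens
    targets = torch.zeros(2, 5, 256)
    targets[:, :, 3] = 1.0                             # every box aligned to token 3
    text_masks = torch.ones(2, 256, dtype=torch.long)
    return loss_fn(logits, targets, text_masks, version="binary")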