Spaces:
Build error
Build error
import torch | |
import torch.nn.functional as F | |
import torch.distributed as dist | |
from torch import nn | |
from scipy.optimize import linear_sum_assignment | |
from torch.cuda.amp import custom_fwd, custom_bwd | |
def box_area(boxes):
    """Return the area of each box in an ``[N, 4]`` tensor of ``[x0, y0, x1, y1]`` boxes."""
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise IoU between two sets of ``[x0, y0, x1, y1]`` boxes.

    Returns:
        ``(iou, union)``, both ``[N, M]`` tensors where N = len(boxes1) and
        M = len(boxes2). The union is returned so callers (e.g. GIoU) can reuse it.
    """
    area_a = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area_b = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    # Intersection rectangle of every (boxes1[i], boxes2[j]) pair via broadcasting.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])      # [N, M, 2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N, M, 2]
    overlap_wh = (bottom_right - top_left).clamp(min=0)           # [N, M, 2]
    inter = overlap_wh[..., 0] * overlap_wh[..., 1]               # [N, M]

    union = area_a[:, None] + area_b - inter
    return inter / union, union
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format.
    Returns a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2).

    Box validity is not checked here. If an enclosing box has zero area, the final
    division is 0/0 or x/0 and the entry becomes nan/inf; callers must tolerate that.
    """
    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each (boxes1[i], boxes2[j]) pair.
    enclose_lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclose_rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclose_wh = (enclose_rb - enclose_lt).clamp(min=0)  # [N, M, 2]
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]

    return iou - (enclose_area - union) / enclose_area
def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: A float tensor of at least 2 dims whose first dim indexes examples.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs, or already flattened
                 to [N, -1]. Stores the binary classification label for each element
                 in inputs (0 for the negative class and 1 for the positive class).
        num_boxes: Normaliser applied to the summed per-example loss.

    Returns:
        Scalar tensor: the sum of per-example dice losses divided by ``num_boxes``.
    """
    inputs = inputs.sigmoid().flatten(1)
    # Flatten targets identically so element-wise products line up when targets
    # share the (unflattened) shape of inputs; a no-op for already-flat targets.
    targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # +1 smoothing keeps the ratio defined when both masks are empty.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes
def sigmoid_focal_loss(inputs: torch.Tensor, targets: torch.Tensor, alpha: float = -1, gamma: float = 2, reduction: str = "none"):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape holding the prediction logits.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Any negative value (the default, -1)
               disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
                 Any other value behaves like 'none'.

    Returns:
        Loss tensor with the reduction option applied.
    """
    # Kept TorchScript-compatible: this function is passed to torch.jit.script.
    prob = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # Probability the model assigns to the ground-truth label of each element.
    prob_true = targets * prob + (1 - targets) * (1 - prob)
    focal = bce * ((1 - prob_true) ** gamma)

    if alpha >= 0:
        focal = (alpha * targets + (1 - alpha) * (1 - targets)) * focal

    if reduction == "sum":
        return focal.sum()
    if reduction == "mean":
        return focal.mean()
    return focal
# TorchScript-compiled version of ``sigmoid_focal_loss``. Scripting a plain function
# produces a ``torch.jit.ScriptFunction`` (a ScriptModule only results from scripting
# an nn.Module), so the type comment names the correct class.
sigmoid_focal_loss_jit = torch.jit.script(
    sigmoid_focal_loss
)  # type: torch.jit.ScriptFunction
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1,
                 use_focal: bool = False, focal_loss_alpha: float = 0.25, focal_loss_gamma: float = 2.0,
                 **kwargs):
        """Creates the matcher.

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            use_focal: If True, classes are scored with a focal-loss style cost on sigmoid
                       probabilities; otherwise with the negative softmax probability.
            focal_loss_alpha: Focal cost alpha (stored only when use_focal is True).
            focal_loss_gamma: Focal cost gamma (stored only when use_focal is True).
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.use_focal = use_focal
        if self.use_focal:
            self.focal_loss_alpha = focal_loss_alpha
            self.focal_loss_gamma = focal_loss_gamma
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching.

        Runs under ``torch.no_grad()``: only the resulting indices are used, so no
        autograd graph is needed for the cost matrix.

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
                               (compared directly against "boxes_xyxy", so presumably xyxy too)
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] containing the class labels
                 "boxes_xyxy": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
                 "image_size_xyxy": per-image size tensor used to normalise predictions
                                    (presumably length 4 — confirm against the caller)
                 "image_size_xyxy_tgt": Tensor with one row per target box used to normalise targets

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        logits = outputs["pred_logits"].flatten(0, 1)  # [batch_size * num_queries, num_classes]
        out_prob = logits.sigmoid() if self.use_focal else logits.softmax(-1)
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes_xyxy"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        if self.use_focal:
            alpha = self.focal_loss_alpha
            gamma = self.focal_loss_gamma
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
        else:
            cost_class = -out_prob[:, tgt_ids]

        # L1 cost between boxes, each side normalised by its own image size.
        image_size_out = torch.cat([v["image_size_xyxy"].unsqueeze(0) for v in targets])
        image_size_out = image_size_out.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1)
        image_size_tgt = torch.cat([v["image_size_xyxy_tgt"] for v in targets])
        cost_bbox = torch.cdist(out_bbox / image_size_out, tgt_bbox / image_size_tgt, p=1)

        # GIoU cost between boxes (both sides are used as xyxy without conversion).
        cost_giou = -generalized_box_iou(out_bbox, tgt_bbox)

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        # Explicit last dim instead of -1: a batch with zero targets yields a 0-element
        # tensor, for which view(..., -1) is ambiguous and raises.
        C = C.view(bs, num_queries, C.shape[-1]).cpu()
        # Degenerate boxes can make the GIoU term nan/inf; neutralise before the solver.
        C[torch.isnan(C)] = 0.0
        C[torch.isinf(C)] = 0.0

        # Per-image column counts must match the boxes concatenated into tgt_bbox above.
        sizes = [len(v["boxes_xyxy"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
class SetCriterion(nn.Module):
    """Loss computation for set-prediction detection.

    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
                 use_focal, focal_loss_alpha=0.25, focal_loss_gamma=2.0):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: callable ``matcher(outputs, targets)`` returning one (src_idx, tgt_idx) pair per image
            weight_dict: dict containing as key the names of the losses and as values their relative weight
                         (stored only; this class does not apply the weights itself)
            eos_coef: relative classification weight applied to the no-object category
                      (used only when use_focal is False)
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            use_focal: use sigmoid focal loss instead of weighted cross-entropy for classification
            focal_loss_alpha: focal loss alpha (stored only when use_focal is True)
            focal_loss_gamma: focal loss gamma (stored only when use_focal is True)
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.use_focal = use_focal
        if self.use_focal:
            self.focal_loss_alpha = focal_loss_alpha
            self.focal_loss_gamma = focal_loss_gamma
        else:
            # Per-class CE weights; the trailing no-object class is down-weighted.
            empty_weight = torch.ones(self.num_classes + 1)
            empty_weight[-1] = self.eos_coef
            self.register_buffer('empty_weight', empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
        """Classification loss.

        With ``use_focal`` this is the summed sigmoid focal loss divided by ``num_boxes``;
        otherwise it is cross-entropy weighted by ``empty_weight``.
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes].
        ``log`` is accepted for API compatibility and is not used.
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Unmatched queries get the extra "no-object" class index num_classes.
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        if self.use_focal:
            src_logits = src_logits.flatten(0, 1)
            # prepare one_hot target; no-object rows stay all-zero.
            target_classes = target_classes.flatten(0, 1)
            pos_inds = torch.nonzero(target_classes != self.num_classes, as_tuple=True)[0]
            labels = torch.zeros_like(src_logits)
            labels[pos_inds, target_classes[pos_inds]] = 1
            # comp focal loss.
            class_loss = sigmoid_focal_loss_jit(
                src_logits,
                labels,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / num_boxes
            losses = {'loss_ce': class_loss}
        else:
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
            losses = {'loss_ce': loss_ce}

        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.

        targets dicts must contain "boxes_xyxy" (a [nb_target_boxes, 4] tensor in
        [x0, y0, x1, y1] format) and "image_size_xyxy_tgt" (one row per target box).
        The L1 term divides both predicted and target boxes by those image sizes, so
        the boxes are presumably in absolute pixel coordinates — confirm with the caller.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes_xyxy'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}
        loss_giou = 1 - torch.diag(generalized_box_iou(src_boxes, target_boxes))
        losses['loss_giou'] = loss_giou.sum() / num_boxes

        # Select image-size rows with the same matched indices as target_boxes so both
        # tensors have exactly one row per matched pair, in the same order.
        image_size = torch.cat([t["image_size_xyxy_tgt"][i] for t, (_, i) in zip(targets, indices)], dim=0)
        src_boxes_ = src_boxes / image_size
        target_boxes_ = target_boxes / image_size

        loss_bbox = F.l1_loss(src_boxes_, target_boxes_, reduction='none')
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets, *args, **kwargs):
        """This performs the loss computation.

        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format.
                     An optional 'aux_outputs' entry holds a list of per-layer output dicts.
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied, see each loss' doc
            *args, **kwargs: accepted and ignored.

        Returns:
            Dict of loss tensors; auxiliary-layer losses are suffixed with ``_{layer_index}``.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes.
        num_boxes = sum(len(t["labels"]) for t in targets)
        # Take the device from the first non-aux tensor: the first dict value may be the
        # 'aux_outputs' list, which has no .device.
        device = next(v.device for v in outputs_without_aux.values() if torch.is_tensor(v))
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(num_boxes)
            world_size = dist.get_world_size()
        else:
            world_size = 1
        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    # Logging is enabled only for the last layer.
                    loss_kwargs = {'log': False} if loss == 'labels' else {}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **loss_kwargs)
                    losses.update({k + f'_{i}': v for k, v in l_dict.items()})

        return losses