import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import custom_fwd, custom_bwd


def box_area(boxes):
    """Return the area of each box in an [N, 4] tensor of (x0, y0, x1, y1) boxes."""
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Compute pairwise IoU between two sets of (x0, y0, x1, y1) boxes.

    Args:
        boxes1: tensor indexed as [N, 4].
        boxes2: tensor indexed as [M, 4].

    Returns:
        A tuple ``(iou, union)`` of [N, M] tensors. No epsilon is added to the
        denominator, so a pair whose union is zero yields nan/inf.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)

    NOTE: degenerate boxes (x1 < x0 or y1 < y0, or zero enclosing area) give
    inf / nan results. Inputs are deliberately not validated here; callers are
    expected to deal with non-finite entries themselves.
    """
    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each pair.
    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IOU for masks

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example; a sigmoid is applied
                and everything after the first dimension is flattened.
        targets: A float tensor storing the binary classification label for
                 each element in inputs (0 for the negative class and 1 for
                 the positive class). It is NOT flattened here, so it must
                 already be broadcast-compatible with ``inputs.flatten(1)``.
        num_boxes: Normalizer the summed loss is divided by.

    Returns:
        The summed per-example dice loss divided by ``num_boxes``.
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # The +1 terms smooth the ratio and avoid division by zero.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes


def sigmoid_focal_loss(
    inputs: torch.Tensor, targets: torch.Tensor, alpha: float = -1, gamma: float = 2, reduction: str = "none"
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                   'none': No reduction will be applied to the output.
                   'mean': The output will be averaged.
                   'sum': The output will be summed.

    Returns:
        Loss tensor with the reduction option applied.
    """
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    # A negative alpha disables the class-balancing weight.
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss


sigmoid_focal_loss_jit = torch.jit.script(sigmoid_focal_loss)  # type: torch.jit.ScriptModule


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        use_focal: bool = False,
        focal_loss_alpha: float = 0.25,
        focal_loss_gamma: float = 2.0,
        **kwargs,
    ):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            use_focal: If True, the classification cost is a focal-style cost on sigmoid
                probabilities; otherwise it is the negative softmax probability of the target class
            focal_loss_alpha: alpha of the focal-style cost (only stored when use_focal is True)
            focal_loss_gamma: gamma of the focal-style cost (only stored when use_focal is True)
            **kwargs: accepted and ignored
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.use_focal = use_focal
        if self.use_focal:
            self.focal_loss_alpha = focal_loss_alpha
            self.focal_loss_gamma = focal_loss_gamma
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    @torch.no_grad()
    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, outputs, targets):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
                               (passed to the GIoU cost as-is, i.e. treated as x0, y0, x1, y1)

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes_xyxy": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
                 "image_size_xyxy": per-image tensor used to normalize that image's predicted boxes
                           (presumably of dim [4] -- it gets a leading batch axis and is repeated per query)
                 "image_size_xyxy_tgt": tensor used to normalize the target boxes
                           (presumably of dim [num_target_boxes, 4] -- it is concatenated across images
                           and divided element-wise into the concatenated target boxes)

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        flat_logits = outputs["pred_logits"].flatten(0, 1)
        if self.use_focal:
            out_prob = flat_logits.sigmoid()  # [batch_size * num_queries, num_classes]
        else:
            out_prob = flat_logits.softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes_xyxy"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        if self.use_focal:
            alpha = self.focal_loss_alpha
            gamma = self.focal_loss_gamma
            # The 1e-8 keeps the logs finite when a probability saturates at 0 or 1.
            neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
        else:
            cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes, after normalizing both sides by image size
        image_size_out = torch.cat([v["image_size_xyxy"].unsqueeze(0) for v in targets])
        image_size_out = image_size_out.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1)
        image_size_tgt = torch.cat([v["image_size_xyxy_tgt"] for v in targets])

        out_bbox_ = out_bbox / image_size_out
        tgt_bbox_ = tgt_bbox / image_size_tgt
        cost_bbox = torch.cdist(out_bbox_, tgt_bbox_, p=1)

        # Compute the giou cost between boxes (un-normalized boxes, used as-is)
        cost_giou = -generalized_box_iou(out_bbox, tgt_bbox)

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()
        # linear_sum_assignment cannot handle non-finite costs (e.g. from degenerate boxes).
        C[torch.isnan(C)] = 0.0
        C[torch.isinf(C)] = 0.0

        # The columns of C come from "boxes_xyxy", so the per-image split sizes
        # must be taken from that same entry to be guaranteed to add up.
        sizes = [len(v["boxes_xyxy"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


class SetCriterion(nn.Module):
    """
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(
        self,
        num_classes,
        matcher,
        weight_dict,
        eos_coef,
        losses,
        use_focal,
        focal_loss_alpha=0.25,
        focal_loss_gamma=2.0,
    ):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
                         (stored only; the weights are not applied inside this class)
            eos_coef: relative classification weight applied to the no-object category
                      (only used by the cross-entropy path, i.e. when use_focal is False)
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            use_focal: if True the classification loss is the sigmoid focal loss, else weighted cross-entropy
            focal_loss_alpha: alpha of the focal loss (only stored when use_focal is True)
            focal_loss_gamma: gamma of the focal loss (only stored when use_focal is True)
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.use_focal = use_focal
        if self.use_focal:
            self.focal_loss_alpha = focal_loss_alpha
            self.focal_loss_gamma = focal_loss_gamma
        else:
            # Per-class cross-entropy weights; the last index is the no-object class.
            empty_weight = torch.ones(self.num_classes + 1)
            empty_weight[-1] = self.eos_coef
            self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
        """Classification loss (focal loss or weighted cross-entropy)

        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]

        Returns a dict with the single key "loss_ce". The ``log`` argument is
        accepted for interface compatibility and is not used.
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Every query defaults to the no-object class (index num_classes);
        # matched queries are then overwritten with their target label.
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        if self.use_focal:
            src_logits = src_logits.flatten(0, 1)
            # prepare one_hot target; unmatched queries keep an all-zero row.
            target_classes = target_classes.flatten(0, 1)
            pos_inds = torch.nonzero(target_classes != self.num_classes, as_tuple=True)[0]
            labels = torch.zeros_like(src_logits)
            labels[pos_inds, target_classes[pos_inds]] = 1
            # comp focal loss.
            class_loss = (
                sigmoid_focal_loss_jit(
                    src_logits,
                    labels,
                    alpha=self.focal_loss_alpha,
                    gamma=self.focal_loss_gamma,
                    reduction="sum",
                )
                / num_boxes
            )
            losses = {"loss_ce": class_loss}
        else:
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
            losses = {"loss_ce": loss_ce}

        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss

        targets dicts must contain the key "boxes_xyxy" containing a tensor of dim [nb_target_boxes, 4]
        and the key "image_size_xyxy_tgt" used to normalize boxes for the L1 term.
        Boxes are passed to the GIoU as-is, i.e. treated as (x0, y0, x1, y1); the L1 term is
        computed on boxes divided by "image_size_xyxy_tgt".

        NOTE(review): "image_size_xyxy_tgt" is concatenated in target order and divided into the
        matched boxes, which presumably assumes every target box is matched -- confirm against caller.

        Returns a dict with the keys "loss_giou" and "loss_bbox", both normalized by num_boxes.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes_xyxy"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}
        # Only the diagonal (matched prediction/target pairs) of the pairwise matrix is needed.
        loss_giou = 1 - torch.diag(generalized_box_iou(src_boxes, target_boxes))
        losses["loss_giou"] = loss_giou.sum() / num_boxes

        image_size = torch.cat([v["image_size_xyxy_tgt"] for v in targets])
        src_boxes_ = src_boxes / image_size
        target_boxes_ = target_boxes / image_size

        loss_bbox = F.l1_loss(src_boxes_, target_boxes_, reduction="none")
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        return losses

    def _get_src_permutation_idx(self, indices):
        """Return (batch_idx, src_idx) index tensors selecting the matched predictions."""
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Return (batch_idx, tgt_idx) index tensors selecting the matched targets."""
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss function registered under ``loss`` ("labels" or "boxes")."""
        loss_map = {
            "labels": self.loss_labels,
            "boxes": self.loss_boxes,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, outputs, targets, *argrs, **kwargs):
        """This performs the loss computation.

        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format.
                      An optional "aux_outputs" entry holds a list of such dicts, one per
                      intermediate layer.
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc

        Extra positional and keyword arguments are accepted and ignored.

        Returns:
             dict mapping loss names to values; losses computed on the i-th auxiliary
             output get the suffix "_{i}".
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        # Take the device from the first tensor entry: the dict may start with the
        # "aux_outputs" list, which has no .device attribute.
        device = next(v.device for v in outputs.values() if torch.is_tensor(v))
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(num_boxes)
            world_size = dist.get_world_size()
        else:
            world_size = 1
        # Clamp to at least 1 so batches without any target do not divide by zero.
        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    extra_kwargs = {}
                    if loss == "labels":
                        # Logging is enabled only for the last layer
                        extra_kwargs = {"log": False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **extra_kwargs)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses