Spaces:
Runtime error
Runtime error
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch | |
import torch.nn as nn | |
from detrex.layers.box_ops import box_cxcywh_to_xyxy | |
from detrex.modeling.losses import FocalLoss, GIoULoss, L1Loss | |
from detrex.modeling.matcher import FocalLossCost, GIoUCost, L1Cost, ModifedMatcher | |
from detrex.utils import get_world_size, is_dist_avail_and_initialized | |
class BaseCriterion(nn.Module):
    """Base criterion for calculating losses for DETR-like models.

    The process happens in two steps:

    1) compute the hungarian assignment between ground-truth boxes and the
       outputs of the model;
    2) supervise each pair of matched ground-truth / prediction
       (class and box).

    Args:
        num_classes (int): Number of classes predicted per query. The same
            value is used as the label written for every query that the
            matcher leaves unmatched.
        matcher: Callable invoked as ``matcher(pred_logits, pred_boxes,
            target_labels_list, target_boxes_list)``. Its result is consumed
            as one ``(src_idx, tgt_idx)`` pair of index tensors per image.
        loss_class (nn.Module): Classification loss, called as
            ``loss_class(logits, labels, avg_factor=num_boxes)``.
        loss_bbox (nn.Module): Box regression loss, called as
            ``loss_bbox(pred_boxes, target_boxes, avg_factor=num_boxes)``.
        loss_giou (nn.Module): IoU-style loss, called like ``loss_bbox`` but
            on boxes converted with ``box_cxcywh_to_xyxy``.

    Note:
        The default ``matcher`` / ``loss_*`` objects are created once, when
        the class is defined, so every instance built with the defaults
        shares the same objects. Pass fresh objects explicitly if that
        sharing is not wanted.
    """

    def __init__(
        self,
        num_classes: int,
        matcher=ModifedMatcher(
            cost_class=FocalLossCost(
                alpha=0.25,
                gamma=2.0,
                weight=2.0,
            ),
            cost_bbox=L1Cost(weight=5.0),
            cost_giou=GIoUCost(weight=2.0),
        ),
        loss_class: nn.Module = FocalLoss(
            alpha=0.25,
            gamma=2.0,
            loss_weight=1.0,
        ),
        loss_bbox: nn.Module = L1Loss(loss_weight=5.0),
        loss_giou: nn.Module = GIoULoss(eps=1e-6, loss_weight=2.0),
    ):
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.loss_class = loss_class
        self.loss_bbox = loss_bbox
        self.loss_giou = loss_giou

    def _get_src_permutation_idx(self, indices):
        """Return ``(batch_idx, src_idx)`` for indexing matched predictions.

        ``indices`` holds one ``(src, tgt)`` pair per image; the ``src`` parts
        are concatenated and paired with the position of their image in the
        list.
        """
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Return ``(batch_idx, tgt_idx)`` for indexing matched targets.

        Mirror of ``_get_src_permutation_idx`` using the ``tgt`` part of each
        ``(src, tgt)`` pair.
        """
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def _gather_matched_boxes(self, pred_boxes, targets, indices):
        """Select the matched prediction boxes and their ground-truth boxes.

        Returns:
            tuple: ``(matched_pred_boxes, matched_target_boxes)``, both
            concatenated over the batch in the same order.
        """
        idx = self._get_src_permutation_idx(indices)
        matched_pred_boxes = pred_boxes[idx]
        matched_target_boxes = torch.cat(
            [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
        )
        return matched_pred_boxes, matched_target_boxes

    def calculate_class_loss(self, pred_logits, targets, indices, num_boxes):
        """Compute the classification loss over all queries.

        Args:
            pred_logits (torch.Tensor): The predicted logits with shape
                ``(bs, num_queries, num_classes)``.
            targets (list[dict]): One dict per image; ``"labels"`` is read
                here.
            indices (list): One ``(src_idx, tgt_idx)`` pair per image, as
                returned by the matcher.
            num_boxes (float): Normalization factor forwarded to the loss as
                ``avg_factor``.

        Returns:
            The value returned by ``self.loss_class``.
        """
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])

        # Every query starts with the label ``num_classes``; matched queries
        # are then overwritten with their ground-truth label.
        target_classes = torch.full(
            pred_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=pred_logits.device,
        )
        target_classes[idx] = target_classes_o

        # Compute classification loss.
        # ``reshape`` (not ``view``) so non-contiguous logits are accepted.
        pred_logits = pred_logits.reshape(-1, self.num_classes)
        target_classes = target_classes.flatten()
        losses = self.loss_class(pred_logits, target_classes, avg_factor=num_boxes)
        return losses

    def calculate_bbox_loss(self, pred_boxes, targets, indices, num_boxes):
        """Compute the box regression loss on matched prediction/target pairs.

        Args:
            pred_boxes (torch.Tensor): Predicted boxes, indexed by
                ``(batch, query)``.
            targets (list[dict]): One dict per image; ``"boxes"`` is read
                here.
            indices (list): One ``(src_idx, tgt_idx)`` pair per image.
            num_boxes (float): Normalization factor forwarded as
                ``avg_factor``.

        Returns:
            The value returned by ``self.loss_bbox``.
        """
        pred_boxes, target_boxes = self._gather_matched_boxes(pred_boxes, targets, indices)

        # Compute regression loss
        losses = self.loss_bbox(pred_boxes, target_boxes, avg_factor=num_boxes)
        return losses

    def calculate_giou_loss(self, pred_boxes, targets, indices, num_boxes):
        """Compute the IoU-style loss on matched prediction/target pairs.

        Takes the same arguments as ``calculate_bbox_loss``. Both box sets are
        passed through ``box_cxcywh_to_xyxy`` before the loss is applied.

        Returns:
            The value returned by ``self.loss_giou``.
        """
        pred_boxes, target_boxes = self._gather_matched_boxes(pred_boxes, targets, indices)

        # Convert box format to (x1, y1, x2, y2)
        pred_boxes = box_cxcywh_to_xyxy(pred_boxes)
        target_boxes = box_cxcywh_to_xyxy(target_boxes)

        # Compute iou loss
        losses = self.loss_giou(pred_boxes, target_boxes, avg_factor=num_boxes)
        return losses

    def _match_and_compute_losses(
        self,
        pred_logits,
        pred_boxes,
        targets,
        target_labels_list,
        target_boxes_list,
        num_boxes,
        suffix="",
    ):
        """Run the matcher on one set of predictions and compute its losses.

        Returns:
            dict: ``loss_class``, ``loss_bbox`` and ``loss_giou`` entries,
            each key followed by ``suffix``.
        """
        indices = self.matcher(
            pred_logits,
            pred_boxes,
            target_labels_list,
            target_boxes_list,
        )
        return {
            "loss_class" + suffix: self.calculate_class_loss(
                pred_logits, targets, indices, num_boxes
            ),
            "loss_bbox" + suffix: self.calculate_bbox_loss(
                pred_boxes, targets, indices, num_boxes
            ),
            "loss_giou" + suffix: self.calculate_giou_loss(
                pred_boxes, targets, indices, num_boxes
            ),
        }

    def forward(self, outputs, targets):
        """Compute all losses for the final and the auxiliary predictions.

        Args:
            outputs (dict): Must contain ``"pred_logits"`` and
                ``"pred_boxes"``. May contain ``"aux_outputs"``, an iterable
                of dicts with the same two keys.
            targets (list[dict]): One dict per image with ``"labels"`` and
                ``"boxes"`` entries.

        Returns:
            dict: ``loss_class``, ``loss_bbox`` and ``loss_giou`` for the
            final predictions, plus ``loss_class_{i}``, ``loss_bbox_{i}`` and
            ``loss_giou_{i}`` for the ``i``-th auxiliary output.
        """
        # Final-layer predictions; "aux_outputs" is handled separately below.
        pred_logits = outputs["pred_logits"]
        pred_boxes = outputs["pred_boxes"]

        target_labels_list = [v["labels"] for v in targets]
        target_boxes_list = [v["boxes"] for v in targets]

        # Compute the average number of target boxes across all nodes, for
        # normalization purposes. The device is taken from ``pred_logits``
        # rather than from the first value of ``outputs``, which is not
        # guaranteed to be a tensor (e.g. when "aux_outputs" comes first).
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=pred_logits.device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        # Clamp to at least 1 so the value is never zero when used as a
        # normalization factor.
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all losses for DETR-like models
        losses = self._match_and_compute_losses(
            pred_logits,
            pred_boxes,
            targets,
            target_labels_list,
            target_boxes_list,
            num_boxes,
        )

        # In case of auxiliary losses, we repeat this process with the output
        # of each intermediate layer (each one gets its own matching).
        if "aux_outputs" in outputs:
            for i, aux_output in enumerate(outputs["aux_outputs"]):
                losses.update(
                    self._match_and_compute_losses(
                        aux_output["pred_logits"],
                        aux_output["pred_boxes"],
                        targets,
                        target_labels_list,
                        target_boxes_list,
                        num_boxes,
                        suffix=f"_{i}",
                    )
                )

        return losses