Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2022 The IDEA Authors. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
from detrex.utils import get_world_size, is_dist_avail_and_initialized | |
from .two_stage_criterion import TwoStageCriterion | |
class DINOCriterion(TwoStageCriterion):
    """Loss criterion for DINO.

    The parent ``TwoStageCriterion.forward`` computes the losses for the matched
    predictions. Per the original description, that is done in two steps:

    1) a Hungarian assignment between ground-truth boxes and model outputs;
    2) supervision of each matched pair (class and box).

    This subclass adds losses for the denoising (DN) queries. Their assignment to
    ground-truth targets is not matched; it is rebuilt deterministically from
    ``dn_metas`` (see :meth:`compute_dn_loss`).
    """

    def forward(self, outputs, targets, dn_metas=None):
        """Compute the parent's losses plus the denoising losses.

        Args:
            outputs (dict): model outputs. The device of its first value is used
                for the loss bookkeeping tensors, so that value must be a tensor.
                It may contain ``"aux_outputs"``, a sequence whose length gives the
                number of auxiliary layers for which DN losses are computed.
            targets (list[dict]): one dict per image (``len(targets) == batch_size``).
                Each dict must contain ``"labels"``. The other required keys depend
                on the losses listed in ``self.losses``.
            dn_metas (dict, optional): denoising metadata; see :meth:`compute_dn_loss`.

        Returns:
            dict: loss name -> loss tensor. It contains the parent's losses plus the
            ``*_dn`` / ``*_dn_{i}`` denoising losses.
        """
        losses = super().forward(outputs, targets)

        device = next(iter(outputs.values())).device

        # Normalizer for the DN losses: the total number of target boxes summed over
        # all processes, divided by the world size, and clamped to >= 1 so an empty
        # batch cannot cause a division by zero.
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        aux_num = len(outputs["aux_outputs"]) if "aux_outputs" in outputs else 0
        losses.update(
            self.compute_dn_loss(dn_metas, targets, aux_num, num_boxes, device=device)
        )
        return losses

    def compute_dn_loss(self, dn_metas, targets, aux_num, num_boxes, *, device=None):
        """Compute the denoising (DN) losses.

        Args:
            dn_metas (dict or None): denoising metadata.
                When it contains ``"output_known_lbs_bboxes"``, the keys ``"dn_num"``
                (number of DN groups) and ``"single_padding"`` (query-slot stride
                between groups) are also read. ``output_known_lbs_bboxes`` is passed
                to ``self.get_loss`` as the outputs dict. It must also contain
                ``"aux_outputs"`` (indexable by layer) when ``aux_num > 0``.
            targets (list[dict]): per-image targets, each with a ``"labels"`` tensor.
            aux_num (int): number of auxiliary layers to compute DN losses for.
            num_boxes (float): per-process box normalizer. It is multiplied by
                ``dn_num`` before being passed to ``get_loss``.
            device (torch.device or str, optional): device for the zero placeholder
                losses. Defaults to the device of the first target's labels, or to
                CUDA/CPU (whichever is available) when ``targets`` is empty.

        Returns:
            dict: DN loss tensors.
                The last layer's losses are keyed ``<loss name>_dn``.
                Auxiliary layer ``i`` is keyed ``<loss name>_dn_{i}``.
                Without DN metadata, zero placeholders ``loss_bbox_dn``,
                ``loss_giou_dn``, ``loss_class_dn`` (plus ``_{i}`` variants per
                auxiliary layer) are returned instead. Presumably this keeps the
                logged key set stable across iterations; confirm that these names
                match the keys produced by ``get_loss``.
        """
        if device is None:
            device = self._infer_device(targets)

        losses = {}
        if dn_metas and "output_known_lbs_bboxes" in dn_metas:
            output_known_lbs_bboxes = dn_metas["output_known_lbs_bboxes"]
            dn_num = dn_metas["dn_num"]
            single_padding = dn_metas["single_padding"]

            dn_idx = self._build_dn_indices(targets, dn_num, single_padding)
            normalizer = num_boxes * dn_num

            losses.update(
                self._dn_losses(output_known_lbs_bboxes, targets, dn_idx, normalizer, "_dn")
            )
            for i in range(aux_num):
                aux_outputs = output_known_lbs_bboxes["aux_outputs"][i]
                losses.update(
                    self._dn_losses(aux_outputs, targets, dn_idx, normalizer, f"_dn_{i}")
                )
        else:
            losses.update(self._zero_dn_losses("", device))
            for i in range(aux_num):
                losses.update(self._zero_dn_losses(f"_{i}", device))
        return losses

    @staticmethod
    def _infer_device(targets):
        """Return the device of the first target's labels, else CUDA/CPU."""
        if targets:
            return targets[0]["labels"].device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _build_dn_indices(targets, dn_num, single_padding):
        """Build the fixed ``(query indices, target indices)`` pair for each image.

        For an image with ``n`` targets, query slot ``g * single_padding + j`` is
        supervised by target ``j``, for each group ``g < dn_num`` and each ``j < n``.
        Both index tensors are flattened group-major. They are created on the device
        of that image's labels, so they can index the label tensor without a device
        mismatch.
        """
        dn_idx = []
        for tgt in targets:
            labels = tgt["labels"]
            num_gt = len(labels)
            if num_gt > 0:
                gt_idx = torch.arange(num_gt, device=labels.device).long()
                # Multiply first and cast afterwards, matching the original
                # `(tensor * single_padding).long()` order.
                group_offsets = (
                    torch.arange(dn_num, device=labels.device) * single_padding
                ).long()
                output_idx = (group_offsets.unsqueeze(1) + gt_idx.unsqueeze(0)).flatten()
                tgt_idx = gt_idx.repeat(dn_num)
            else:
                output_idx = tgt_idx = torch.empty(0, dtype=torch.long, device=labels.device)
            dn_idx.append((output_idx, tgt_idx))
        return dn_idx

    def _dn_losses(self, dn_outputs, targets, dn_idx, normalizer, suffix):
        """Apply every loss in ``self.losses`` using the fixed DN assignment.

        ``suffix`` is appended to each resulting key.
        """
        l_dict = {}
        for loss in self.losses:
            # Label losses are called with log=False for DN queries. What `log`
            # controls is defined by the parent's get_loss (not visible here).
            kwargs = {"log": False} if "labels" in loss else {}
            l_dict.update(
                self.get_loss(loss, dn_outputs, targets, dn_idx, normalizer, **kwargs)
            )
        return {k + suffix: v for k, v in l_dict.items()}

    @staticmethod
    def _zero_dn_losses(suffix, device):
        """Return zero placeholder DN losses (a fresh tensor per key) on ``device``."""
        return {
            f"loss_bbox_dn{suffix}": torch.as_tensor(0.0, device=device),
            f"loss_giou_dn{suffix}": torch.as_tensor(0.0, device=device),
            f"loss_class_dn{suffix}": torch.as_tensor(0.0, device=device),
        }