# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import List

import torch

from detrex.modeling import SetCriterion
from detrex.utils import get_world_size, is_dist_avail_and_initialized

from .assigner import Stage1Assigner, Stage2Assigner


class DETACriterion(SetCriterion):
"""This class computes the loss for Deformable-DETR
and two-stage Deformable-DETR
"""
def __init__(
self,
num_classes,
matcher,
weight_dict,
losses: List[str] = ["class", "boxes"],
eos_coef: float = 0.1,
loss_class_type: str = "focal_loss",
alpha: float = 0.25,
gamma: float = 2.0,
num_queries: int = 300,
assign_first_stage: bool = False,
assign_second_stage: bool = False,
):
        super().__init__(
num_classes=num_classes,
matcher=matcher,
weight_dict=weight_dict,
losses=losses,
eos_coef=eos_coef,
loss_class_type=loss_class_type,
alpha=alpha,
gamma=gamma,
)
        self.assign_first_stage = assign_first_stage
        self.assign_second_stage = assign_second_stage
        # When enabled, DETA swaps Hungarian matching for overlap-based
        # assignment: Stage1Assigner handles the encoder proposals and
        # Stage2Assigner the decoder queries (hence it needs num_queries).
        if self.assign_first_stage:
            self.stg1_assigner = Stage1Assigner()
        if self.assign_second_stage:
            self.stg2_assigner = Stage2Assigner(num_queries)

def forward(self, outputs, targets):
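        """Compute all requested losses for one batch.

        Args:
            outputs (dict): model predictions, expected to contain at least
                ``pred_logits`` and ``pred_boxes``, plus optional
                ``aux_outputs`` (per-decoder-layer predictions) and
                ``enc_outputs`` (encoder proposals of the two-stage variant).
            targets (list[dict]): one dict per image with ``labels`` and
                ``boxes`` entries, as expected by the parent ``SetCriterion``.
        """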
outputs_without_aux = {
k: v for k, v in outputs.items() if k != "aux_outputs" and k != "enc_outputs"
}
# Retrieve the matching between the outputs of the last layer and the targets
if self.assign_second_stage:
indices = self.stg2_assigner(outputs_without_aux, targets)
else:
indices = self.matcher(outputs_without_aux, targets)
        # Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor(
[num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
# Compute all the requested losses
        losses = {}
        kwargs = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs))
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if "aux_outputs" in outputs:
for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                # With second-stage assignment, the indices computed above are
                # reused for every decoder layer; otherwise re-match per layer.
                if not self.assign_second_stage:
                    indices = self.matcher(aux_outputs, targets)
for loss in self.losses:
l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
losses.update(l_dict)
        # Compute losses for the encoder outputs of two-stage Deformable-DETR.
        # Encoder proposals are class-agnostic, so every target label is set to
        # 0 (foreground) before matching and computing the losses.
        if "enc_outputs" in outputs:
            enc_outputs = outputs["enc_outputs"]
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                bt["labels"] = torch.zeros_like(bt["labels"])
if self.assign_first_stage:
indices = self.stg1_assigner(enc_outputs, bin_targets)
else:
indices = self.matcher(enc_outputs, bin_targets)
for loss in self.losses:
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
l_dict = {k + "_enc": v for k, v in l_dict.items()}
losses.update(l_dict)
return losses
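

# A minimal usage sketch (not part of the original module): how this criterion
# might be instantiated. The matcher import path, loss weights, and class count
# below are illustrative assumptions, not values taken from this repository.
#
#   from detrex.modeling import HungarianMatcher
#
#   criterion = DETACriterion(
#       num_classes=80,
#       matcher=HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0),
#       weight_dict={"loss_class": 1.0, "loss_bbox": 5.0, "loss_giou": 2.0},
#       num_queries=300,
#       assign_first_stage=True,
#       assign_second_stage=True,
#   )
#
# ``criterion(outputs, targets)`` then returns a dict of loss terms
# (e.g. ``loss_class``, ``loss_bbox``, ``loss_giou`` and their per-layer
# ``_{i}`` / encoder ``_enc`` variants) to be weighted by ``weight_dict``.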