Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2022 The IDEA Authors. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import copy | |
| from typing import List | |
| import torch | |
| from detrex.modeling import SetCriterion | |
| from detrex.utils import get_world_size, is_dist_avail_and_initialized | |
| from .assigner import Stage1Assigner, Stage2Assigner | |
class DETACriterion(SetCriterion):
    """Compute the training losses for (two-stage) Deformable-DETR as used by DETA.

    Extends ``SetCriterion`` with two optional assigners that replace the
    default ``matcher`` for specific outputs:

    - ``assign_first_stage``: the encoder proposal outputs (``"enc_outputs"``)
      are assigned with ``Stage1Assigner`` instead of ``matcher``.
    - ``assign_second_stage``: the final decoder outputs are assigned with
      ``Stage2Assigner`` instead of ``matcher``, and the resulting indices are
      reused for every auxiliary decoder layer.
    """

    def __init__(
        self,
        num_classes,
        matcher,
        weight_dict,
        losses: List[str] = ["class", "boxes"],  # noqa: B006 - copied below, never aliased
        eos_coef: float = 0.1,
        loss_class_type: str = "focal_loss",
        alpha: float = 0.25,
        gamma: float = 2.0,
        num_queries: int = 300,
        assign_first_stage: bool = False,
        assign_second_stage: bool = False,
    ):
        """
        Args:
            num_classes: number of object classes, forwarded to ``SetCriterion``.
            matcher: default matcher, used wherever the corresponding assigner
                is disabled.
            weight_dict: loss-weight mapping, forwarded to ``SetCriterion``.
            losses: names of the losses to compute. A fresh copy is stored so
                the shared default list (or a caller-owned list) is never
                aliased by the criterion.
            eos_coef, loss_class_type, alpha, gamma: forwarded unchanged to
                ``SetCriterion``.
            num_queries: passed to ``Stage2Assigner``; only used when
                ``assign_second_stage`` is True.
            assign_first_stage: assign encoder outputs with ``Stage1Assigner``.
            assign_second_stage: assign decoder outputs with ``Stage2Assigner``.
        """
        super().__init__(
            num_classes=num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            # Copy: the default argument is a single list object created once at
            # function-definition time and would otherwise be shared by every
            # instance built with the default.
            losses=list(losses),
            eos_coef=eos_coef,
            loss_class_type=loss_class_type,
            alpha=alpha,
            gamma=gamma,
        )
        self.assign_first_stage = assign_first_stage
        self.assign_second_stage = assign_second_stage
        # The assigners are only constructed (and thus only available) when enabled.
        if self.assign_first_stage:
            self.stg1_assigner = Stage1Assigner()
        if self.assign_second_stage:
            self.stg2_assigner = Stage2Assigner(num_queries)

    def forward(self, outputs, targets):
        """Compute all requested losses for the final, auxiliary and encoder outputs.

        Args:
            outputs (dict): model outputs. Every entry other than
                ``"aux_outputs"`` and ``"enc_outputs"`` belongs to the final
                decoder layer. ``"aux_outputs"`` (optional) is an iterable of
                per-layer output dicts; ``"enc_outputs"`` (optional) is the
                encoder proposal output dict.
            targets (list[dict]): one dict per image, each holding at least a
                ``"labels"`` tensor (further required keys depend on the losses
                computed by ``SetCriterion.get_loss``).

        Returns:
            dict: loss name -> loss value. Auxiliary-layer losses carry a
            ``_{i}`` suffix (``i`` = index in ``"aux_outputs"``) and encoder
            losses an ``_enc`` suffix.
        """
        # Strip the nested outputs so the matcher/assigner only sees the last layer.
        outputs_without_aux = {
            k: v for k, v in outputs.items() if k not in ("aux_outputs", "enc_outputs")
        }

        # Retrieve the matching between the outputs of the last layer and the targets.
        if self.assign_second_stage:
            indices = self.stg2_assigner(outputs_without_aux, targets)
        else:
            indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all processes, for
        # normalization purposes; clamped to >= 1 to avoid dividing by zero.
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor(
            [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # No extra keyword arguments are forwarded to get_loss. Defined once, up
        # front, so it exists even when self.losses is empty (previously it was
        # only bound inside the first loop, causing a NameError below).
        kwargs = {}

        # Compute all the requested losses for the last layer.
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs))

        # In case of auxiliary losses, repeat this process with the output of each
        # intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                # With the stage-2 assigner enabled, the indices computed for the
                # last layer are deliberately reused for every auxiliary layer
                # (presumably because the assignment depends on inputs shared by
                # all decoder layers - confirm against Stage2Assigner).
                if not self.assign_second_stage:
                    indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    losses.update({k + f"_{i}": v for k, v in l_dict.items()})

        # Compute losses for the two-stage (encoder proposal) outputs.
        if "enc_outputs" in outputs:
            enc_outputs = outputs["enc_outputs"]
            # Relabel every target as class 0 so encoder proposals are matched
            # class-agnostically; deep-copied so the caller's targets stay intact.
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                bt["labels"] = torch.zeros_like(bt["labels"])
            if self.assign_first_stage:
                indices = self.stg1_assigner(enc_outputs, bin_targets)
            else:
                indices = self.matcher(enc_outputs, bin_targets)
            for loss in self.losses:
                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
                losses.update({k + "_enc": v for k, v in l_dict.items()})

        return losses