# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from detrex.layers import MLP, box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
from detrex.utils import inverse_sigmoid

from detectron2.modeling import detector_postprocess
from detectron2.structures import Boxes, ImageList, Instances
from detectron2.layers.nms import batched_nms


class DeformableDETR(nn.Module):
    """Implements the Deformable DETR model.

    Code is modified from the `official github repo
    <https://github.com/fundamentalvision/Deformable-DETR>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2010.04159>`_ .

    Args:
        backbone (nn.Module): the backbone module.
        position_embedding (nn.Module): the position embedding module.
        neck (nn.Module): the neck module.
        transformer (nn.Module): the transformer module.
        embed_dim (int): the dimension of the embedding.
        num_classes (int): Number of total categories.
        num_queries (int): Number of proposal dynamic anchor boxes in Transformer.
        criterion (nn.Module): Criterion for calculating the total losses.
        pixel_mean (List[float]): Pixel mean value for image normalization.
            Default: [123.675, 116.280, 103.530].
        pixel_std (List[float]): Pixel std value for image normalization.
            Default: [58.395, 57.120, 57.375].
        aux_loss (bool): whether to use auxiliary loss. Default: True.
        with_box_refine (bool): whether to use box refinement. Default: False.
        as_two_stage (bool): whether to use two-stage. Default: False.
        select_box_nums_for_evaluation (int): the number of top-k candidates
            selected at postprocess for evaluation. Default: 100.
    """
    def __init__(
        self,
        backbone,
        position_embedding,
        neck,
        transformer,
        embed_dim,
        num_classes,
        num_queries,
        criterion,
        pixel_mean,
        pixel_std,
        aux_loss=True,
        with_box_refine=False,
        as_two_stage=False,
        select_box_nums_for_evaluation=100,
        device="cuda",
    ):
        super().__init__()

        # define backbone and position embedding module
        self.backbone = backbone
        self.position_embedding = position_embedding

        # define neck module
        self.neck = neck

        # define learnable query embedding
        self.num_queries = num_queries
        if not as_two_stage:
            self.query_embedding = nn.Embedding(num_queries, embed_dim * 2)

        # define transformer module
        self.transformer = transformer

        # define classification head and box head
        self.num_classes = num_classes
        self.class_embed = nn.Linear(embed_dim, num_classes)
        self.bbox_embed = MLP(embed_dim, embed_dim, 4, 3)

        # whether to calculate auxiliary loss in criterion
        self.aux_loss = aux_loss
        self.criterion = criterion

        # define controller for box refinement and two-stage variants
        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage

        # init parameters for heads
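        # (the classification bias is set to -log((1 - p) / p) with prior probability
        # p = 0.01, the usual focal-loss prior initialization, so that predicted class
        # scores start close to 0.01 at the beginning of training)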
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        for _, neck_layer in self.neck.named_modules():
            if isinstance(neck_layer, nn.Conv2d):
                nn.init.xavier_uniform_(neck_layer.weight, gain=1)
                nn.init.constant_(neck_layer.bias, 0)

        # If two-stage, the last class_embed and bbox_embed are used for region proposal
        # generation. Decoder layers share the same heads when box refinement is disabled,
        # and use separate heads when box refinement is enabled.
        num_pred = (
            (transformer.decoder.num_layers + 1) if as_two_stage else transformer.decoder.num_layers
        )
        if with_box_refine:
            self.class_embed = nn.ModuleList(
                [copy.deepcopy(self.class_embed) for i in range(num_pred)]
            )
            self.bbox_embed = nn.ModuleList(
                [copy.deepcopy(self.bbox_embed) for i in range(num_pred)]
            )
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
            self.transformer.decoder.bbox_embed = None

        # hack implementation for two-stage: the last class_embed and bbox_embed are used
        # for region proposal generation
        if as_two_stage:
            self.transformer.decoder.class_embed = self.class_embed
            for box_embed in self.bbox_embed:
                nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)

        # set topk boxes selected for inference
        self.select_box_nums_for_evaluation = select_box_nums_for_evaluation

        # normalizer for input raw images
        self.device = device
        pixel_mean = torch.Tensor(pixel_mean).to(self.device).view(3, 1, 1)
        pixel_std = torch.Tensor(pixel_std).to(self.device).view(3, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std

    def forward(self, batched_inputs):
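        """Forward pass of the model.

        Args:
            batched_inputs (list[dict]): detectron2-style input dicts. Each dict carries an
                "image" tensor; at training time it also carries an "instances" field with
                ground-truth boxes and classes, and optional "height"/"width" entries give
                the resolution that inference results are rescaled to.

        Returns:
            dict[str, Tensor]: weighted losses during training.
            list[dict]: at inference time, one {"instances": Instances} dict per image.
        """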
        images = self.preprocess_image(batched_inputs)
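        # build per-image padding masks for the batched tensor: during training, the padded
        # region of each image is masked out (1 = padding); at inference the whole tensor is
        # treated as valid.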
        if self.training:
            batch_size, _, H, W = images.tensor.shape
            img_masks = images.tensor.new_ones(batch_size, H, W)
            for img_id in range(batch_size):
                # mask padding regions in batched images
                img_h, img_w = batched_inputs[img_id]["instances"].image_size
                img_masks[img_id, :img_h, :img_w] = 0
        else:
            batch_size, _, H, W = images.tensor.shape
            img_masks = images.tensor.new_zeros(batch_size, H, W)

        # original features
        features = self.backbone(images.tensor)  # output feature dict

        # project backbone features to the required dimension of transformer
        # we use multi-scale features in deformable DETR
        multi_level_feats = self.neck(features)
        multi_level_masks = []
        multi_level_position_embeddings = []
        for feat in multi_level_feats:
            multi_level_masks.append(
                F.interpolate(img_masks[None], size=feat.shape[-2:]).to(torch.bool).squeeze(0)
            )
            multi_level_position_embeddings.append(self.position_embedding(multi_level_masks[-1]))

        # initialize object query embeddings
        query_embeds = None
        if not self.as_two_stage:
            query_embeds = self.query_embedding.weight

        (
            inter_states,
            init_reference,
            inter_references,
            enc_outputs_class,
            enc_outputs_coord_unact,
            anchors,
        ) = self.transformer(
            multi_level_feats, multi_level_masks, multi_level_position_embeddings, query_embeds
        )
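        # inter_states / inter_references hold the per-decoder-layer hidden states and
        # reference points; enc_outputs_class, enc_outputs_coord_unact and anchors come from
        # the encoder proposal stage and are only consumed below when as_two_stage is enabled.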

        # Calculate output coordinates and classes.
        outputs_classes = []
        outputs_coords = []
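        # For every decoder layer, predict class logits and box deltas relative to that
        # layer's reference points (in inverse-sigmoid space): 2-d references only offset the
        # box center, while 4-d references offset the full (cx, cy, w, h) box.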
        for lvl in range(inter_states.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.class_embed[lvl](inter_states[lvl])
            tmp = self.bbox_embed[lvl](inter_states[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        outputs_class = torch.stack(outputs_classes)
        # tensor shape: [num_decoder_layers, bs, num_query, num_classes]
        outputs_coord = torch.stack(outputs_coords)
        # tensor shape: [num_decoder_layers, bs, num_query, 4]

        # prepare for loss computation
        output = {
            "pred_logits": outputs_class[-1],
            "pred_boxes": outputs_coord[-1],
            "init_reference": init_reference,
        }
        if self.aux_loss:
            output["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
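            # aux_outputs holds the predictions of every decoder layer except the last one,
            # so the criterion can apply the same losses to each layer (deep supervision).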
        if self.as_two_stage:
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            output["enc_outputs"] = {
                "pred_logits": enc_outputs_class,
                "pred_boxes": enc_outputs_coord,
                "anchors": anchors,
            }
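            # the encoder (first-stage) proposals are supervised separately through
            # "enc_outputs"; the raw anchors are passed along so the criterion can use them
            # when assigning targets to these proposals.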

        if self.training:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            targets = self.prepare_targets(gt_instances)
            loss_dict = self.criterion(output, targets)
            weight_dict = self.criterion.weight_dict
            for k in loss_dict.keys():
                if k in weight_dict:
                    loss_dict[k] *= weight_dict[k]
            return loss_dict
        else:
            box_cls = output["pred_logits"]
            box_pred = output["pred_boxes"]
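            # criteria with assign_second_stage=True (DETA-style) use NMS-based
            # post-processing; otherwise fall back to the standard top-k selection.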
            if self.criterion.assign_second_stage:
                results = self.nms_inference(box_cls, box_pred, images.image_sizes)
            else:
                results = self.inference(box_cls, box_pred, images.image_sizes)
            processed_results = []
            for results_per_image, input_per_image, image_size in zip(
                results, batched_inputs, images.image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
                processed_results.append({"instances": r})
            return processed_results

    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]

    def inference(self, box_cls, box_pred, image_sizes):
        """
        Arguments:
            box_cls (Tensor): tensor of shape (batch_size, num_queries, K).
                The tensor predicts the classification probability for each query.
            box_pred (Tensor): tensors of shape (batch_size, num_queries, 4).
                The tensor predicts 4-vector (x, y, w, h) box
                regression values for every query.
            image_sizes (List[torch.Size]): the input image sizes

        Returns:
            results (List[Instances]): a list of #images elements.
        """
        assert len(box_cls) == len(image_sizes)
        results = []

        # Select top-k confidence boxes for inference
        prob = box_cls.sigmoid()
        topk_values, topk_indexes = torch.topk(
            prob.view(box_cls.shape[0], -1), self.select_box_nums_for_evaluation, dim=1
        )
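        # topk_indexes index the flattened (num_queries * K) score matrix per image: integer
        # division by K recovers the query (box) index, the remainder recovers the class label.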
        scores = topk_values
        topk_boxes = torch.div(topk_indexes, box_cls.shape[2], rounding_mode="floor")
        labels = topk_indexes % box_cls.shape[2]
        boxes = torch.gather(box_pred, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        for i, (scores_per_image, labels_per_image, box_pred_per_image, image_size) in enumerate(
            zip(scores, labels, boxes, image_sizes)
        ):
            result = Instances(image_size)
            result.pred_boxes = Boxes(box_cxcywh_to_xyxy(box_pred_per_image))
            result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0])
            result.scores = scores_per_image
            result.pred_classes = labels_per_image
            results.append(result)
        return results

    # DETA uses NMS for post-processing
    def nms_inference(self, box_cls, box_pred, image_sizes):
        """
        Arguments:
            box_cls (Tensor): tensor of shape (batch_size, num_queries, K).
                The tensor predicts the classification probability for each query.
            box_pred (Tensor): tensors of shape (batch_size, num_queries, 4).
                The tensor predicts 4-vector (x, y, w, h) box
                regression values for every query.
            image_sizes (List[torch.Size]): the input image sizes

        Returns:
            results (List[Instances]): a list of #images elements.
        """
        assert len(box_cls) == len(image_sizes)
        results = []
        bs, n_queries, n_cls = box_cls.shape

        # Score every (query, class) pair for NMS-based inference
        prob = box_cls.sigmoid()
        all_scores = prob.view(bs, n_queries * n_cls).to(box_cls.device)
        all_indexes = torch.arange(n_queries * n_cls)[None].repeat(bs, 1).to(box_cls.device)
        all_boxes = torch.div(all_indexes, box_cls.shape[2], rounding_mode="floor")
        all_labels = all_indexes % box_cls.shape[2]

        # convert to xyxy for nms post-process
        boxes = box_cxcywh_to_xyxy(box_pred)
        boxes = torch.gather(boxes, 1, all_boxes.unsqueeze(-1).repeat(1, 1, 4))
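        # Per image: keep the 10000 highest-scoring (query, class) candidates, run class-aware
        # NMS with an IoU threshold of 0.7, then keep at most the top 100 detections.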
        for i, (scores_per_image, labels_per_image, box_pred_per_image, image_size) in enumerate(
            zip(all_scores, all_labels, boxes, image_sizes)
        ):
            pre_topk = scores_per_image.topk(10000).indices
            box = box_pred_per_image[pre_topk]
            score = scores_per_image[pre_topk]
            label = labels_per_image[pre_topk]

            # nms post-process
            keep_index = batched_nms(box, score, label, 0.7)[:100]

            result = Instances(image_size)
            result.pred_boxes = Boxes(box[keep_index])
            result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0])
            result.scores = score[keep_index]
            result.pred_classes = label[keep_index]
            results.append(result)
        return results

    def prepare_targets(self, targets):
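        # Convert detectron2 Instances into the list-of-dicts format expected by the criterion:
        # integer class labels plus boxes normalized by image size, in (cx, cy, w, h) format.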
        new_targets = []
        for targets_per_image in targets:
            h, w = targets_per_image.image_size
            image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
            gt_classes = targets_per_image.gt_classes
            gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
            gt_boxes = box_xyxy_to_cxcywh(gt_boxes)
            new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
        return new_targets

    def preprocess_image(self, batched_inputs):
        images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
        images = ImageList.from_tensors(images)
        return images