# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

import torch
from mmcv.ops import MultiScaleDeformableAttention, batched_nms
from torch import Tensor, nn
from torch.nn.init import normal_

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
from mmdet.utils import OptConfigType

from ..layers import DDQTransformerDecoder
from ..utils import align_tensor
from .deformable_detr import DeformableDETR
from .dino import DINO


@MODELS.register_module()
class DDQDETR(DINO):
r"""Implementation of `Dense Distinct Query for | |
End-to-End Object Detection <https://arxiv.org/abs/2303.12776>`_ | |
Code is modified from the `official github repo | |
<https://github.com/jshilong/DDQ>`_. | |
Args: | |
dense_topk_ratio (float): Ratio of num_dense queries to num_queries. | |
Defaults to 1.5. | |
dqs_cfg (:obj:`ConfigDict` or dict, optional): Config of | |
Distinct Queries Selection. Defaults to nms with | |
`iou_threshold` = 0.8. | |
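
    Example:
        >>> # A minimal config sketch; only the DDQ-specific arguments are
        >>> # shown, the other fields (backbone, neck, bbox_head, decoder,
        >>> # ...) are omitted and the values here are illustrative.
        >>> model_cfg = dict(
        ...     type='DDQDETR',
        ...     dense_topk_ratio=1.5,
        ...     dqs_cfg=dict(type='nms', iou_threshold=0.8))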
""" | |

    def __init__(self,
                 *args,
                 dense_topk_ratio: float = 1.5,
                 dqs_cfg: OptConfigType = dict(type='nms', iou_threshold=0.8),
                 **kwargs):
        self.dense_topk_ratio = dense_topk_ratio
        self.decoder_cfg = kwargs['decoder']
        self.dqs_cfg = dqs_cfg
        super().__init__(*args, **kwargs)

        # A shared dict used by all modules to pass intermediate results
        # and config parameters.
        cache_dict = dict()
        for m in self.modules():
            m.cache_dict = cache_dict
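        # NOTE: every module holds a reference to the same dict object, so
        # an update made by one module is immediately visible to the others.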
        # The first element is the start index of the matching queries, the
        # second is the number of matching queries.
        self.cache_dict['dis_query_info'] = [0, 0]

        # Mask for distinct queries in each decoder layer.
        self.cache_dict['distinct_query_mask'] = []
        # Passed to the decoder to perform the DQS.
        self.cache_dict['cls_branches'] = self.bbox_head.cls_branches

        # Used to construct the attention mask after the DQS.
        self.cache_dict['num_heads'] = self.encoder.layers[
            0].self_attn.num_heads

        # Passed to the decoder to perform the DQS.
        self.cache_dict['dqs_cfg'] = self.dqs_cfg

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        super(DDQDETR, self)._init_layers()
        self.decoder = DDQTransformerDecoder(**self.decoder_cfg)
        self.query_embedding = None
        self.query_map = nn.Linear(self.embed_dims, self.embed_dims)

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super(DeformableDETR, self).init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        nn.init.xavier_uniform_(self.memory_trans_fc.weight)
        normal_(self.level_embed)

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict]:
"""Prepare intermediate variables before entering Transformer decoder, | |
such as `query`, `memory`, and `reference_points`. | |
Args: | |
memory (Tensor): The output embeddings of the Transformer encoder, | |
has shape (bs, num_feat_points, dim). | |
memory_mask (Tensor): ByteTensor, the padding mask of the memory, | |
has shape (bs, num_feat_points). Will only be used when | |
`as_two_stage` is `True`. | |
spatial_shapes (Tensor): Spatial shapes of features in all levels. | |
With shape (num_levels, 2), last dimension represents (h, w). | |
Will only be used when `as_two_stage` is `True`. | |
batch_data_samples (list[:obj:`DetDataSample`]): The batch | |
data samples. It usually includes information such | |
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. | |
Defaults to None. | |
Returns: | |
tuple[dict]: The decoder_inputs_dict and head_inputs_dict. | |
- decoder_inputs_dict (dict): The keyword dictionary args of | |
`self.forward_decoder()`, which includes 'query', 'memory', | |
`reference_points`, and `dn_mask`. The reference points of | |
decoder input here are 4D boxes, although it has `points` | |
in its name. | |
- head_inputs_dict (dict): The keyword dictionary args of the | |
bbox_head functions, which includes `topk_score`, `topk_coords`, | |
`dense_topk_score`, `dense_topk_coords`, | |
and `dn_meta`, when `self.training` is `True`, else is empty. | |
""" | |
        bs, _, c = memory.shape
        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)
        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](
                output_memory)
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals

        if self.training:
            # Aux dense branch particularly in DDQ DETR, which doesn't exist
            # in DINO; `-1` indexes the aux head for the encoder.
            dense_enc_outputs_class = self.bbox_head.cls_branches[-1](
                output_memory)
            dense_enc_outputs_coord_unact = self.bbox_head.reg_branches[-1](
                output_memory) + output_proposals

        topk = self.num_queries
        dense_topk = int(topk * self.dense_topk_ratio)
        proposals = enc_outputs_coord_unact.sigmoid()
        proposals = bbox_cxcywh_to_xyxy(proposals)
        scores = enc_outputs_class.max(-1)[0].sigmoid()

        if self.training:
            # Aux dense branch particularly in DDQ DETR, which doesn't exist
            # in DINO.
            dense_proposals = dense_enc_outputs_coord_unact.sigmoid()
            dense_proposals = bbox_cxcywh_to_xyxy(dense_proposals)
            dense_scores = dense_enc_outputs_class.max(-1)[0].sigmoid()

        num_imgs = len(scores)
        topk_score = []
        topk_coords_unact = []
        # Distinct queries.
        query = []
        dense_topk_score = []
        dense_topk_coords_unact = []
        dense_query = []

        for img_id in range(num_imgs):
            single_proposals = proposals[img_id]
            single_scores = scores[img_id]

            # `batched_nms` of class scores and bbox coordinates is used
            # particularly by DDQ DETR for region proposal generation,
            # instead of `topk` of class scores by DINO.
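            # Passing a constant label tensor (all ones) makes the NMS
            # class-agnostic: every proposal competes with every other one.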
            _, keep_idxs = batched_nms(
                single_proposals, single_scores,
                torch.ones(len(single_scores), device=single_scores.device),
                self.cache_dict['dqs_cfg'])

            if self.training:
                # Aux dense branch particularly in DDQ DETR, which doesn't
                # exist in DINO.
                dense_single_proposals = dense_proposals[img_id]
                dense_single_scores = dense_scores[img_id]
                # Sort according to the score: only sorting by classification
                # score is needed here, neither NMS nor topk, so the
                # `nms_cfg` parameter is None.
                _, dense_keep_idxs = batched_nms(
                    dense_single_proposals, dense_single_scores,
                    torch.ones(
                        len(dense_single_scores),
                        device=dense_single_scores.device), None)

                dense_topk_score.append(dense_enc_outputs_class[img_id]
                                        [dense_keep_idxs][:dense_topk])
                dense_topk_coords_unact.append(
                    dense_enc_outputs_coord_unact[img_id][dense_keep_idxs]
                    [:dense_topk])

            topk_score.append(enc_outputs_class[img_id][keep_idxs][:topk])

            # Instead of initializing the content part with transformed
            # coordinates as in Deformable DETR, we fuse the feature map
            # embeddings at the distinct positions as the content part,
            # which makes the initial queries more distinct.
            topk_coords_unact.append(
                enc_outputs_coord_unact[img_id][keep_idxs][:topk])

            map_memory = self.query_map(memory[img_id].detach())
            query.append(map_memory[keep_idxs][:topk])
            if self.training:
                # Aux dense branch particularly in DDQ DETR, which doesn't
                # exist in DINO.
                dense_query.append(map_memory[dense_keep_idxs][:dense_topk])
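
        # `align_tensor` pads the per-image results (whose lengths differ
        # after NMS) to a common query count and stacks them into a batch.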
        topk_score = align_tensor(topk_score, topk)
        topk_coords_unact = align_tensor(topk_coords_unact, topk)
        query = align_tensor(query, topk)
        if self.training:
            dense_topk_score = align_tensor(dense_topk_score)
            dense_topk_coords_unact = align_tensor(dense_topk_coords_unact)
            dense_query = align_tensor(dense_query)
            num_dense_queries = dense_query.size(1)
        if self.training:
            query = torch.cat([query, dense_query], dim=1)
            topk_coords_unact = torch.cat(
                [topk_coords_unact, dense_topk_coords_unact], dim=1)

        topk_coords = topk_coords_unact.sigmoid()
        if self.training:
            dense_topk_coords = topk_coords[:, -num_dense_queries:]
            topk_coords = topk_coords[:, :-num_dense_queries]

        topk_coords_unact = topk_coords_unact.detach()

        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)

            # Update `dn_mask` to add the mask for the dense queries.
            ori_size = dn_mask.size(-1)
            new_size = dn_mask.size(-1) + num_dense_queries
            new_dn_mask = dn_mask.new_ones((new_size, new_size)).bool()
            dense_mask = torch.zeros(num_dense_queries,
                                     num_dense_queries).bool()
            self.cache_dict['dis_query_info'] = [dn_label_query.size(1), topk]

            new_dn_mask[ori_size:, ori_size:] = dense_mask
            new_dn_mask[:ori_size, :ori_size] = dn_mask
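
            # Resulting layout of `new_dn_mask` (True = attention blocked):
            #
            #     [ dn_mask | True       ]  <- dn + distinct queries
            #     [ True    | dense_mask ]  <- dense queries
            #
            # Dense queries attend only among themselves, while the dn and
            # distinct queries keep their original mask.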
            dn_meta['num_dense_queries'] = num_dense_queries
            dn_mask = new_dn_mask

            self.cache_dict['num_dense_queries'] = num_dense_queries
            self.decoder.aux_reg_branches = self.bbox_head.aux_reg_branches
        else:
            self.cache_dict['dis_query_info'] = [0, topk]
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None

        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask)
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            aux_enc_outputs_class=dense_topk_score,
            aux_enc_outputs_coord=dense_topk_coords,
            dn_meta=dn_meta) if self.training else dict()
        return decoder_inputs_dict, head_inputs_dict