# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Tuple

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import MultiScaleDeformableAttention, batched_nms
from torch import Tensor, nn
from torch.nn.init import normal_
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
from mmdet.utils import OptConfigType
from ..layers import DDQTransformerDecoder
from ..utils import align_tensor
from .deformable_detr import DeformableDETR
from .dino import DINO


@MODELS.register_module()
class DDQDETR(DINO):
r"""Implementation of `Dense Distinct Query for
End-to-End Object Detection <https://arxiv.org/abs/2303.12776>`_
Code is modified from the `official github repo
<https://github.com/jshilong/DDQ>`_.
Args:
dense_topk_ratio (float): Ratio of num_dense queries to num_queries.
Defaults to 1.5.
dqs_cfg (:obj:`ConfigDict` or dict, optional): Config of
Distinct Queries Selection. Defaults to nms with
`iou_threshold` = 0.8.
"""
def __init__(self,
*args,
dense_topk_ratio: float = 1.5,
dqs_cfg: OptConfigType = dict(type='nms', iou_threshold=0.8),
**kwargs):
self.dense_topk_ratio = dense_topk_ratio
self.decoder_cfg = kwargs['decoder']
self.dqs_cfg = dqs_cfg
super().__init__(*args, **kwargs)

        # a dict shared by all modules, used to pass some intermediate
        # results and config parameters
cache_dict = dict()
for m in self.modules():
m.cache_dict = cache_dict
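        # Because every submodule holds a reference to this same dict,
        # values written here (or later by the decoder during query
        # selection) are visible everywhere without extra plumbing.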
# first element is the start index of matching queries
# second element is the number of matching queries
self.cache_dict['dis_query_info'] = [0, 0]
# mask for distinct queries in each decoder layer
self.cache_dict['distinct_query_mask'] = []
        # pass to decoder to do the dqs
self.cache_dict['cls_branches'] = self.bbox_head.cls_branches
# Used to construct the attention mask after dqs
self.cache_dict['num_heads'] = self.encoder.layers[
0].self_attn.num_heads
# pass to decoder to do the dqs
self.cache_dict['dqs_cfg'] = self.dqs_cfg

    def _init_layers(self) -> None:
"""Initialize layers except for backbone, neck and bbox_head."""
super(DDQDETR, self)._init_layers()
self.decoder = DDQTransformerDecoder(**self.decoder_cfg)
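        # DDQ derives the content part of each query from encoder features
        # (via `query_map` below) rather than from learned embeddings, so
        # the inherited `query_embedding` is disabled.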
self.query_embedding = None
self.query_map = nn.Linear(self.embed_dims, self.embed_dims)

    def init_weights(self) -> None:
"""Initialize weights for Transformer and other components."""
super(DeformableDETR, self).init_weights()
for coder in self.encoder, self.decoder:
for p in coder.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MultiScaleDeformableAttention):
m.init_weights()
nn.init.xavier_uniform_(self.memory_trans_fc.weight)
normal_(self.level_embed)

    def pre_decoder(
self,
memory: Tensor,
memory_mask: Tensor,
spatial_shapes: Tensor,
batch_data_samples: OptSampleList = None,
) -> Tuple[Dict]:
"""Prepare intermediate variables before entering Transformer decoder,
such as `query`, `memory`, and `reference_points`.

        Args:
memory (Tensor): The output embeddings of the Transformer encoder,
has shape (bs, num_feat_points, dim).
memory_mask (Tensor): ByteTensor, the padding mask of the memory,
has shape (bs, num_feat_points). Will only be used when
`as_two_stage` is `True`.
spatial_shapes (Tensor): Spatial shapes of features in all levels.
With shape (num_levels, 2), last dimension represents (h, w).
Will only be used when `as_two_stage` is `True`.
batch_data_samples (list[:obj:`DetDataSample`]): The batch
data samples. It usually includes information such
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
Defaults to None.

        Returns:
            tuple[dict]: The decoder_inputs_dict and head_inputs_dict.

            - decoder_inputs_dict (dict): The keyword dictionary args of
              `self.forward_decoder()`, which includes 'query', 'memory',
              `reference_points`, and `dn_mask`. The reference points of
              the decoder input here are 4D boxes, despite the `points`
              in the name.
            - head_inputs_dict (dict): The keyword dictionary args of the
              bbox_head functions, which includes `topk_score`,
              `topk_coords`, `dense_topk_score`, `dense_topk_coords`, and
              `dn_meta` when `self.training` is `True`, and is empty
              otherwise.
        """
bs, _, c = memory.shape
output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, memory_mask, spatial_shapes)
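        # As in DINO, the classification / regression branches at index
        # `num_layers` serve as the encoder head: they score and refine
        # every encoder position to produce the initial proposals.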
enc_outputs_class = self.bbox_head.cls_branches[
self.decoder.num_layers](
output_memory)
enc_outputs_coord_unact = self.bbox_head.reg_branches[
self.decoder.num_layers](output_memory) + output_proposals
if self.training:
            # auxiliary dense branch, specific to DDQ DETR and absent
            # in DINO.
            # -1 is the aux head for the encoder
dense_enc_outputs_class = self.bbox_head.cls_branches[-1](
output_memory)
dense_enc_outputs_coord_unact = self.bbox_head.reg_branches[-1](
output_memory) + output_proposals
topk = self.num_queries
dense_topk = int(topk * self.dense_topk_ratio)
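        # e.g. with num_queries = 900 and the default dense_topk_ratio = 1.5,
        # 1350 dense queries are kept for the auxiliary branch.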
proposals = enc_outputs_coord_unact.sigmoid()
proposals = bbox_cxcywh_to_xyxy(proposals)
scores = enc_outputs_class.max(-1)[0].sigmoid()
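        # Each position's score is its highest class probability; these
        # class-agnostic scores drive the NMS-based selection below.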
if self.training:
            # auxiliary dense branch, specific to DDQ DETR and absent
            # in DINO.
dense_proposals = dense_enc_outputs_coord_unact.sigmoid()
dense_proposals = bbox_cxcywh_to_xyxy(dense_proposals)
dense_scores = dense_enc_outputs_class.max(-1)[0].sigmoid()
num_imgs = len(scores)
topk_score = []
topk_coords_unact = []
# Distinct query.
query = []
dense_topk_score = []
dense_topk_coords_unact = []
dense_query = []
for img_id in range(num_imgs):
single_proposals = proposals[img_id]
single_scores = scores[img_id]
            # `batched_nms` over class scores and bbox coordinates is used
            # by DDQ DETR in particular for region proposal generation,
            # instead of the `topk` of class scores used by DINO.
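            # All labels are set to 1, so the NMS is class-agnostic;
            # `keep_idxs` is sorted by descending score with near-duplicate
            # boxes suppressed, and its first `topk` entries become the
            # distinct queries.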
_, keep_idxs = batched_nms(
single_proposals, single_scores,
torch.ones(len(single_scores), device=single_scores.device),
self.cache_dict['dqs_cfg'])
if self.training:
                # auxiliary dense branch, specific to DDQ DETR and absent
                # in DINO.
dense_single_proposals = dense_proposals[img_id]
dense_single_scores = dense_scores[img_id]
                # sort according to the score
                # Only sorting by classification score is needed; neither
                # nms nor topk is required, so the `nms_cfg` parameter is
                # None.
_, dense_keep_idxs = batched_nms(
dense_single_proposals, dense_single_scores,
torch.ones(
len(dense_single_scores),
device=dense_single_scores.device), None)
dense_topk_score.append(dense_enc_outputs_class[img_id]
[dense_keep_idxs][:dense_topk])
dense_topk_coords_unact.append(
dense_enc_outputs_coord_unact[img_id][dense_keep_idxs]
[:dense_topk])
topk_score.append(enc_outputs_class[img_id][keep_idxs][:topk])
            # Instead of initializing the content part with transformed
            # coordinates as in Deformable DETR, we use the feature map
            # embeddings at the distinct positions as the content part,
            # which makes the initial queries more distinct.
topk_coords_unact.append(
enc_outputs_coord_unact[img_id][keep_idxs][:topk])
map_memory = self.query_map(memory[img_id].detach())
query.append(map_memory[keep_idxs][:topk])
if self.training:
                # auxiliary dense branch, specific to DDQ DETR and absent
                # in DINO.
dense_query.append(map_memory[dense_keep_idxs][:dense_topk])
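        # NMS keeps a different number of boxes per image, so `align_tensor`
        # pads each per-image tensor to a common length before stacking the
        # results into batch tensors.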
topk_score = align_tensor(topk_score, topk)
topk_coords_unact = align_tensor(topk_coords_unact, topk)
query = align_tensor(query, topk)
if self.training:
dense_topk_score = align_tensor(dense_topk_score)
dense_topk_coords_unact = align_tensor(dense_topk_coords_unact)
dense_query = align_tensor(dense_query)
num_dense_queries = dense_query.size(1)
if self.training:
query = torch.cat([query, dense_query], dim=1)
topk_coords_unact = torch.cat(
[topk_coords_unact, dense_topk_coords_unact], dim=1)
topk_coords = topk_coords_unact.sigmoid()
if self.training:
dense_topk_coords = topk_coords[:, -num_dense_queries:]
topk_coords = topk_coords[:, :-num_dense_queries]
topk_coords_unact = topk_coords_unact.detach()
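        # Detached so that decoder gradients do not flow back into the
        # encoder proposal coordinates (same treatment as in DINO).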
if self.training:
dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
self.dn_query_generator(batch_data_samples)
query = torch.cat([dn_label_query, query], dim=1)
reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
dim=1)
# Update `dn_mask` to add mask for dense queries.
ori_size = dn_mask.size(-1)
new_size = dn_mask.size(-1) + num_dense_queries
new_dn_mask = dn_mask.new_ones((new_size, new_size)).bool()
dense_mask = torch.zeros(num_dense_queries,
num_dense_queries).bool()
self.cache_dict['dis_query_info'] = [dn_label_query.size(1), topk]
new_dn_mask[ori_size:, ori_size:] = dense_mask
new_dn_mask[:ori_size, :ori_size] = dn_mask
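            # In this mask, `True` blocks attention: dense queries may attend
            # only to each other, while the original denoising / matching
            # block keeps its previous mask.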
dn_meta['num_dense_queries'] = num_dense_queries
dn_mask = new_dn_mask
self.cache_dict['num_dense_queries'] = num_dense_queries
self.decoder.aux_reg_branches = self.bbox_head.aux_reg_branches
else:
self.cache_dict['dis_query_info'] = [0, topk]
reference_points = topk_coords_unact
dn_mask, dn_meta = None, None
reference_points = reference_points.sigmoid()
decoder_inputs_dict = dict(
query=query,
memory=memory,
reference_points=reference_points,
dn_mask=dn_mask)
head_inputs_dict = dict(
enc_outputs_class=topk_score,
enc_outputs_coord=topk_coords,
aux_enc_outputs_class=dense_topk_score,
aux_enc_outputs_coord=dense_topk_coords,
dn_meta=dn_meta) if self.training else dict()
return decoder_inputs_dict, head_inputs_dict