# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn

from detectron2.layers.nms import batched_nms

from detrex.layers import (
    FFN,
    BaseTransformerLayer,
    MultiheadAttention,
    MultiScaleDeformableAttention,
    TransformerLayerSequence,
    box_cxcywh_to_xyxy,
)
from detrex.utils import inverse_sigmoid


class DeformableDetrTransformerEncoder(TransformerLayerSequence):
    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        feedforward_dim: int = 1024,
        attn_dropout: float = 0.1,
        ffn_dropout: float = 0.1,
        num_layers: int = 6,
        post_norm: bool = False,
        num_feature_levels: int = 4,
    ):
        super(DeformableDetrTransformerEncoder, self).__init__(
            transformer_layers=BaseTransformerLayer(
                attn=MultiScaleDeformableAttention(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    dropout=attn_dropout,
                    batch_first=True,
                    num_levels=num_feature_levels,
                ),
                ffn=FFN(
                    embed_dim=embed_dim,
                    feedforward_dim=feedforward_dim,
                    output_dim=embed_dim,
                    num_fcs=2,
                    ffn_drop=ffn_dropout,
                ),
                norm=nn.LayerNorm(embed_dim),
                operation_order=("self_attn", "norm", "ffn", "norm"),
            ),
            num_layers=num_layers,
        )
        self.embed_dim = self.layers[0].embed_dim
        self.pre_norm = self.layers[0].pre_norm

        if post_norm:
            self.post_norm_layer = nn.LayerNorm(self.embed_dim)
        else:
            self.post_norm_layer = None

    def forward(
        self,
        query,
        key,
        value,
        query_pos=None,
        key_pos=None,
        attn_masks=None,
        query_key_padding_mask=None,
        key_padding_mask=None,
        **kwargs,
    ):
        for layer in self.layers:
            query = layer(
                query,
                key,
                value,
                query_pos=query_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                **kwargs,
            )

        if self.post_norm_layer is not None:
            query = self.post_norm_layer(query)
        return query


class DeformableDetrTransformerDecoder(TransformerLayerSequence):
    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        feedforward_dim: int = 1024,
        attn_dropout: float = 0.1,
        ffn_dropout: float = 0.1,
        num_layers: int = 6,
        return_intermediate: bool = True,
        num_feature_levels: int = 4,
    ):
        super(DeformableDetrTransformerDecoder, self).__init__(
            transformer_layers=BaseTransformerLayer(
                attn=[
                    MultiheadAttention(
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        attn_drop=attn_dropout,
                        batch_first=True,
                    ),
                    MultiScaleDeformableAttention(
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        dropout=attn_dropout,
                        batch_first=True,
                        num_levels=num_feature_levels,
                    ),
                ],
                ffn=FFN(
                    embed_dim=embed_dim,
                    feedforward_dim=feedforward_dim,
                    output_dim=embed_dim,
                    ffn_drop=ffn_dropout,
                ),
                norm=nn.LayerNorm(embed_dim),
                operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
            ),
            num_layers=num_layers,
        )
        self.return_intermediate = return_intermediate

        self.bbox_embed = None
        self.class_embed = None

    def forward(
        self,
        query,
        key,
        value,
        query_pos=None,
        key_pos=None,
        attn_masks=None,
        query_key_padding_mask=None,
        key_padding_mask=None,
        reference_points=None,
        valid_ratios=None,
        **kwargs,
    ):
        output = query

        intermediate = []
        intermediate_reference_points = []
        for layer_idx, layer in enumerate(self.layers):
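            # Scale the normalized reference points by each level's valid ratio
            # before they are consumed by the deformable cross-attention.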
            if reference_points.shape[-1] == 4:
                reference_points_input = (
                    reference_points[:, :, None]
                    * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
                )
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]

            output = layer(
                output,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                reference_points=reference_points_input,
                **kwargs,
            )

            if self.bbox_embed is not None:
                tmp = self.bbox_embed[layer_idx](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)

        return output, reference_points


class DeformableDetrTransformer(nn.Module):
    """Transformer module for Deformable DETR.

    Args:
        encoder (nn.Module): encoder module.
        decoder (nn.Module): decoder module.
        as_two_stage (bool): whether to use the two-stage transformer. Default: False.
        num_feature_levels (int): number of feature levels. Default: 4.
        two_stage_num_proposals (int): number of proposals in the two-stage transformer.
            Default: 300. Only used when as_two_stage is True.
        assign_first_stage (bool): whether to select first-stage proposals with the
            DETA-style per-level top-k followed by NMS. Default: True.
            Only used when as_two_stage is True.
    """

    def __init__(
        self,
        encoder=None,
        decoder=None,
        num_feature_levels=4,
        as_two_stage=False,
        two_stage_num_proposals=300,
        assign_first_stage=True,
    ):
        super(DeformableDetrTransformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_feature_levels = num_feature_levels
        self.as_two_stage = as_two_stage
        self.two_stage_num_proposals = two_stage_num_proposals
        # DETA implementation
        self.assign_first_stage = assign_first_stage

        self.embed_dim = self.encoder.embed_dim

        self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dim))

        if self.as_two_stage:
            self.enc_output = nn.Linear(self.embed_dim, self.embed_dim)
            self.enc_output_norm = nn.LayerNorm(self.embed_dim)
            self.pos_trans = nn.Linear(self.embed_dim * 2, self.embed_dim * 2)
            self.pos_trans_norm = nn.LayerNorm(self.embed_dim * 2)
            # DETA implementation
            self.pix_trans = nn.Linear(self.embed_dim, self.embed_dim)
            self.pix_trans_norm = nn.LayerNorm(self.embed_dim)
        else:
            self.reference_points = nn.Linear(self.embed_dim, 2)

        self.init_weights()

    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        if not self.as_two_stage:
            nn.init.xavier_normal_(self.reference_points.weight.data, gain=1.0)
            nn.init.constant_(self.reference_points.bias.data, 0.0)
        nn.init.normal_(self.level_embeds)

    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
        N, S, C = memory.shape
        proposals = []
        _cur = 0
        level_ids = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

            grid_y, grid_x = torch.meshgrid(
                torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device),
                torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device),
            )
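            # Stack into an (x, y) grid of cell centers and normalize by the valid
            # (unpadded) extent of this feature level.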
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
            proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
            proposals.append(proposal)
            _cur += H * W
            level_ids.append(grid.new_ones(H * W, dtype=torch.long) * lvl)

        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
            -1, keepdim=True
        )
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(
            memory_padding_mask.unsqueeze(-1), float("inf")
        )
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))

        output_memory = memory
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
        output_memory = self.enc_output_norm(self.enc_output(output_memory))

        level_ids = torch.cat(level_ids)
        return output_memory, output_proposals, level_ids

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in decoder.

        Args:
            spatial_shapes (Tensor): The shape of all feature maps, has shape (num_level, 2).
            valid_ratios (Tensor): The ratios of valid points on the feature map,
                has shape (bs, num_levels, 2).
            device (obj:`device`): The device where reference_points should be.

        Returns:
            Tensor: reference points used in decoder, has shape (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # TODO check this 0.5
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def get_valid_ratio(self, mask):
        """Get the valid ratios of feature maps of all levels."""
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000):
        """Get the position embedding of proposal."""
        scale = 2 * math.pi
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
        return pos

    def forward(
        self,
        multi_level_feats,
        multi_level_masks,
        multi_level_pos_embeds,
        query_embed,
        **kwargs,
    ):
        assert self.as_two_stage or query_embed is not None

        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(
            zip(multi_level_feats, multi_level_masks, multi_level_pos_embeds)
        ):
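            # Flatten each level to (bs, h*w, c) and add its learnable level embedding
            # to the positional encoding.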
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            feat = feat.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device
        )
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
        )
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in multi_level_masks], 1)

        reference_points = self.get_reference_points(
            spatial_shapes, valid_ratios, device=feat.device
        )

        memory = self.encoder(
            query=feat_flatten,
            key=None,
            value=None,
            query_pos=lvl_pos_embed_flatten,
            query_key_padding_mask=mask_flatten,
            spatial_shapes=spatial_shapes,
            reference_points=reference_points,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            **kwargs,
        )

        bs, _, c = memory.shape

        if self.as_two_stage:
            output_memory, output_proposals, level_ids = self.gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes
            )
            enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
            enc_outputs_coord_unact = (
                self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals
            )

            topk = self.two_stage_num_proposals
            proposal_logit = enc_outputs_class[..., 0]

            if self.assign_first_stage:
                proposal_boxes = box_cxcywh_to_xyxy(
                    enc_outputs_coord_unact.sigmoid().float()
                ).clamp(0, 1)
                topk_proposals = []
                for b in range(bs):
                    prop_boxes_b = proposal_boxes[b]
                    prop_logits_b = proposal_logit[b]

                    # pre-nms per-level topk
                    pre_nms_topk = 1000
                    pre_nms_inds = []
                    for lvl in range(len(spatial_shapes)):
                        lvl_mask = level_ids == lvl
                        pre_nms_inds.append(
                            torch.topk(prop_logits_b.sigmoid() * lvl_mask, pre_nms_topk)[1]
                        )
                    pre_nms_inds = torch.cat(pre_nms_inds)

                    # nms on topk indices
                    post_nms_inds = batched_nms(
                        prop_boxes_b[pre_nms_inds],
                        prop_logits_b[pre_nms_inds],
                        level_ids[pre_nms_inds],
                        0.9,
                    )
                    keep_inds = pre_nms_inds[post_nms_inds]

                    if len(keep_inds) < self.two_stage_num_proposals:
                        print(
                            f"[WARNING] nms proposals ({len(keep_inds)}) < "
                            f"{self.two_stage_num_proposals}, running naive topk"
                        )
                        keep_inds = torch.topk(proposal_logit[b], topk)[1]

                    # keep top Q/L indices for L levels
                    q_per_l = topk // len(spatial_shapes)
                    is_level_ordered = (
                        level_ids[keep_inds][None]
                        == torch.arange(len(spatial_shapes), device=level_ids.device)[:, None]
                    )  # LS
                    keep_inds_mask = is_level_ordered & (is_level_ordered.cumsum(1) <= q_per_l)  # LS
                    keep_inds_mask = keep_inds_mask.any(0)  # S

                    # pad to Q indices (might let ones filtered from pre-nms sneak by,
                    # but unlikely because we pick high-confidence ones anyway)
                    if keep_inds_mask.sum() < topk:
                        num_to_add = topk - keep_inds_mask.sum()
                        pad_inds = (~keep_inds_mask).nonzero()[:num_to_add]
                        keep_inds_mask[pad_inds] = True

                    # index
                    keep_inds_topk = keep_inds[keep_inds_mask]
                    topk_proposals.append(keep_inds_topk)
                topk_proposals = torch.stack(topk_proposals)
            else:
                topk_proposals = torch.topk(proposal_logit, topk, dim=1)[1]

            topk_coords_unact = torch.gather(
                enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
            )
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            init_reference_out = reference_points
            pos_trans_out = self.pos_trans_norm(
                self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
            )
            query_pos, query = torch.split(pos_trans_out, c, dim=2)

            topk_feats = torch.stack(
                [output_memory[b][topk_proposals[b]] for b in range(bs)]
            ).detach()
            query = query + self.pix_trans_norm(self.pix_trans(topk_feats))
        else:
            query_pos, query = torch.split(query_embed, c, dim=1)
            query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
            query = query.unsqueeze(0).expand(bs, -1, -1)
            reference_points = self.reference_points(query_pos).sigmoid()
            init_reference_out = reference_points

        # decoder
        inter_states, inter_references = self.decoder(
            query=query,  # bs, num_queries, embed_dims
            key=None,  # bs, num_tokens, embed_dims
            value=memory,  # bs, num_tokens, embed_dims
            query_pos=query_pos,
            key_padding_mask=mask_flatten,  # bs, num_tokens
            reference_points=reference_points,  # num_queries, 4
            spatial_shapes=spatial_shapes,  # nlvl, 2
            level_start_index=level_start_index,  # nlvl
            valid_ratios=valid_ratios,  # bs, nlvl, 2
            **kwargs,
        )

        inter_references_out = inter_references
        if self.as_two_stage:
            return (
                inter_states,
                init_reference_out,
                inter_references_out,
                enc_outputs_class,
                enc_outputs_coord_unact,
                output_proposals.sigmoid(),
            )
        return inter_states, init_reference_out, inter_references_out, None, None, None
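

# The block below is an optional, hedged construction sketch, not part of the original
# module: it only instantiates the encoder/decoder/transformer with their default
# hyper-parameters and reports the parameter count, so it should run wherever detrex
# and detectron2 are installed. The input shapes noted in the trailing comments are
# assumptions about how multi-scale inputs are typically laid out, not values taken
# from this file.
if __name__ == "__main__":
    encoder = DeformableDetrTransformerEncoder(embed_dim=256, num_layers=6)
    decoder = DeformableDetrTransformerDecoder(embed_dim=256, num_layers=6)
    model = DeformableDetrTransformer(
        encoder=encoder,
        decoder=decoder,
        num_feature_levels=4,
        as_two_stage=False,  # the two-stage path also needs decoder.class_embed / bbox_embed set
        two_stage_num_proposals=300,
    )
    num_params = sum(p.numel() for p in model.parameters())
    print(f"DeformableDetrTransformer parameters: {num_params / 1e6:.1f}M")
    # forward() would additionally expect, per feature level:
    #   feat:      (bs, 256, h, w)
    #   mask:      (bs, h, w) boolean padding mask
    #   pos_embed: (bs, 256, h, w)
    # and, for the single-stage path, query_embed of shape (num_queries, 512).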