# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR Transformer class.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

import warnings
from typing import Optional

import torch
import torch.utils.checkpoint as checkpoint
from torch import Tensor, nn

from groundingdino.util.misc import inverse_sigmoid

from .fuse_modules import BiAttentionBlock
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
from .transformer_vanilla import TransformerEncoderLayer
from .utils import (
    MLP,
    _get_activation_fn,
    _get_clones,
    gen_encoder_output_proposals,
    gen_sineembed_for_position,
    get_sine_pos_embed,
)


class Transformer(nn.Module):
    """Image/text encoder plus deformable decoder.

    ``__init__`` builds a ``TransformerEncoder`` (deformable encoder layers,
    optionally paired with text-enhancer and image/text fusion layers) and a
    ``TransformerDecoder``. ``forward`` flattens the multi-level features,
    runs the encoder, prepares the decoder queries according to
    ``two_stage_type`` and runs the decoder.
    """

    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_queries=300,
        num_encoder_layers=6,
        num_unicoder_layers=0,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.0,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        query_dim=4,
        num_patterns=0,
        # for deformable encoder
        num_feature_levels=1,
        enc_n_points=4,
        dec_n_points=4,
        # init query
        learnable_tgt_init=False,
        # two stage
        two_stage_type="no",  # only "no" and "standard" pass the assert below
        embed_init_tgt=False,
        # for text
        use_text_enhancer=False,
        use_fusion_layer=False,
        use_checkpoint=False,
        use_transformer_ckpt=False,
        use_text_cross_attention=False,
        text_dropout=0.1,
        fusion_dropout=0.1,
        fusion_droppath=0.0,
    ):
        """Build the encoder, decoder and query embeddings.

        Constraints enforced by assertions in this constructor:
          - ``query_dim`` must be 4.
          - ``normalize_before`` must be falsy (the encoder norm must be None).
          - ``learnable_tgt_init`` must be truthy.
          - ``two_stage_type`` must be ``"no"`` or ``"standard"``.

        The decoder additionally asserts that ``return_intermediate_dec`` is
        truthy.
        """
        super().__init__()
        self.num_feature_levels = num_feature_levels
        self.num_encoder_layers = num_encoder_layers
        self.num_unicoder_layers = num_unicoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.num_queries = num_queries  # useful for single stage model only
        assert query_dim == 4

        # choose encoder layer type
        encoder_layer = DeformableTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
        )

        if use_text_enhancer:
            # The text branch uses half the heads and half the FFN width.
            text_enhance_layer = TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead // 2,
                dim_feedforward=dim_feedforward // 2,
                dropout=text_dropout,
            )
        else:
            text_enhance_layer = None

        if use_fusion_layer:
            feature_fusion_layer = BiAttentionBlock(
                v_dim=d_model,
                l_dim=d_model,
                embed_dim=dim_feedforward // 2,
                num_heads=nhead // 2,
                dropout=fusion_dropout,
                drop_path=fusion_droppath,
            )
        else:
            feature_fusion_layer = None

        # Pre-norm is not supported: the assert rejects normalize_before=True.
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        assert encoder_norm is None
        self.encoder = TransformerEncoder(
            encoder_layer,
            num_encoder_layers,
            d_model=d_model,
            num_queries=num_queries,
            text_enhance_layer=text_enhance_layer,
            feature_fusion_layer=feature_fusion_layer,
            use_checkpoint=use_checkpoint,
            use_transformer_ckpt=use_transformer_ckpt,
        )

        # choose decoder layer type
        decoder_layer = DeformableTransformerDecoderLayer(
            d_model,
            dim_feedforward,
            dropout,
            activation,
            num_feature_levels,
            nhead,
            dec_n_points,
            use_text_cross_attention=use_text_cross_attention,
        )

        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
            d_model=d_model,
            query_dim=query_dim,
            num_feature_levels=num_feature_levels,
        )

        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers
        self.num_patterns = num_patterns
        if not isinstance(num_patterns, int):
            # Actually emit the warning; constructing a bare ``Warning(...)``
            # object and discarding it reports nothing.
            warnings.warn("num_patterns should be int but {}".format(type(num_patterns)))
            self.num_patterns = 0

        # NOTE: ``level_embed`` is only defined when num_feature_levels > 1;
        # every reader guards with ``self.num_feature_levels > 1 and ...``.
        if num_feature_levels > 1:
            if self.num_encoder_layers > 0:
                self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
            else:
                self.level_embed = None

        self.learnable_tgt_init = learnable_tgt_init
        assert learnable_tgt_init, "why not learnable_tgt_init"
        self.embed_init_tgt = embed_init_tgt
        if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
            self.tgt_embed = nn.Embedding(self.num_queries, d_model)
            nn.init.normal_(self.tgt_embed.weight.data)
        else:
            self.tgt_embed = None

        # for two stage
        self.two_stage_type = two_stage_type
        assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
            two_stage_type
        )
        if two_stage_type == "standard":
            # anchor selection at the output of encoder
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            self.two_stage_wh_embedding = None

        if two_stage_type == "no":
            self.init_ref_points(num_queries)  # init self.refpoint_embed

        # Left as None here; ``forward`` calls both in the "standard" branch,
        # so they are presumably assigned by the enclosing model — TODO confirm.
        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise weights.

        Xavier-uniform for every parameter with more than one dimension, then
        each ``MSDeformAttn`` re-runs its own reset, then ``level_embed`` (if
        present) is drawn from a normal distribution.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        """Return the fraction of un-masked width and height per sample.

        Args:
            mask: boolean tensor unpacked as ``(_, H, W)``; ``True`` entries
                are excluded from the count (the code sums ``~mask``).

        Returns:
            Tensor whose last dimension is ``(valid_ratio_w, valid_ratio_h)``.
        """
        _, H, W = mask.shape
        # Count valid rows along the first column and valid columns along the
        # first row.
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def init_ref_points(self, use_num_queries):
        """Create the learnable 4-d reference-point embedding."""
        self.refpoint_embed = nn.Embedding(use_num_queries, 4)

    def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
        """Run encoder, query preparation and decoder.

        Input:
            - srcs: List of multi features [bs, ci, hi, wi]
            - masks: List of multi masks [bs, hi, wi]
            - refpoint_embed: [bs, num_dn, 4]. None in infer
            - pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
            - tgt: [bs, num_dn, d_model]. None in infer
            - attn_mask: forwarded to the decoder as ``tgt_mask``
            - text_dict: dict read with the keys "encoded_text",
              "text_token_mask", "position_ids" and
              "text_self_attention_masks". "encoded_text" is overwritten in
              place with the encoder's text output.
              NOTE(review): although the default is None, the dict is indexed
              unconditionally below, so passing None fails with a TypeError.

        Returns:
            ``(hs, references, hs_enc, ref_enc, init_box_proposal)``
            - hs: (n_dec, bs, nq, d_model)
            - references: sigmoid coordinates. (n_dec+1, bs, nq, 4)
            - hs_enc: (1, bs, nq, d_model) for two_stage_type "standard",
              otherwise None
            - ref_enc: sigmoid coordinates, (1, bs, nq, 4) for "standard",
              otherwise None
            - init_box_proposal: sigmoid of the initial reference boxes
        """
        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)  # bs, hw
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.num_feature_levels > 1 and self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hxw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device
        )
        # Offset of each level inside the flattened token axis.
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
        )
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # two stage
        enc_topk_proposals = enc_refpoint_embed = None

        #########################################################
        # Begin Encoder
        #########################################################
        memory, memory_text = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
            memory_text=text_dict["encoded_text"],
            text_attention_mask=~text_dict["text_token_mask"],
            # we ~ the mask . False means use the token; True means pad the token
            position_ids=text_dict["position_ids"],
            text_self_attention_masks=text_dict["text_self_attention_masks"],
        )
        #########################################################
        # End Encoder
        # - memory: bs, \sum{hw}, c
        # - mask_flatten: bs, \sum{hw}
        # - lvl_pos_embed_flatten: bs, \sum{hw}, c
        #########################################################
        text_dict["encoded_text"] = memory_text

        if self.two_stage_type == "standard":
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes
            )
            output_memory = self.enc_output_norm(self.enc_output(output_memory))

            if text_dict is not None:
                enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
            else:
                enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)

            topk_logits = enc_outputs_class_unselected.max(-1)[0]
            enc_outputs_coord_unselected = (
                self.enc_out_bbox_embed(output_memory) + output_proposals
            )  # (bs, \sum{hw}, 4) unsigmoid
            topk = self.num_queries

            topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq

            # gather boxes
            refpoint_embed_undetach = torch.gather(
                enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
            )  # unsigmoid
            refpoint_embed_ = refpoint_embed_undetach.detach()
            init_box_proposal = torch.gather(
                output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
            ).sigmoid()  # sigmoid

            # gather tgt
            tgt_undetach = torch.gather(
                output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
            )
            if self.embed_init_tgt:
                tgt_ = (
                    self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
                )  # bs, nq, d_model (after the transpose)
            else:
                tgt_ = tgt_undetach.detach()

            if refpoint_embed is not None:
                # Caller-supplied (dn) queries go first, selected queries after.
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_

        elif self.two_stage_type == "no":
            tgt_ = (
                self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
            )  # bs, nq, d_model (after the transpose)
            refpoint_embed_ = (
                self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
            )  # bs, nq, 4 (after the transpose)

            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_

            if self.num_patterns > 0:
                # NOTE(review): ``self.patterns`` is not created anywhere in
                # this class, so this path raises AttributeError as written.
                tgt_embed = tgt.repeat(1, self.num_patterns, 1)
                refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
                tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
                    self.num_queries, 1
                )  # 1, n_q*n_pat, d_model
                tgt = tgt_embed + tgt_pat

            init_box_proposal = refpoint_embed_.sigmoid()

        else:
            raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
        #########################################################
        # End preparing tgt
        # - tgt: bs, NQ, d_model
        # - refpoint_embed(unsigmoid): bs, NQ, 4
        #########################################################

        #########################################################
        # Begin Decoder
        #########################################################
        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=attn_mask,
            memory_text=text_dict["encoded_text"],
            text_attention_mask=~text_dict["text_token_mask"],
            # we ~ the mask . False means use the token; True means pad the token
        )
        #########################################################
        # End Decoder
        # hs: n_dec, bs, nq, d_model
        # references: n_dec+1, bs, nq, query_dim
        #########################################################

        #########################################################
        # Begin postprocess
        #########################################################
        if self.two_stage_type == "standard":
            hs_enc = tgt_undetach.unsqueeze(0)
            ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
        else:
            hs_enc = ref_enc = None
        #########################################################
        # End postprocess
        # hs_enc: (1, bs, nq, d_model) or None
        # ref_enc: (1, bs, nq, query_dim) or None
        #########################################################

        return hs, references, hs_enc, ref_enc, init_box_proposal


class TransformerEncoder(nn.Module):
    """Stack of image encoder layers with optional text and fusion layers.

    For each layer index the forward pass runs, in order: the fusion layer
    (if any), the text-enhancer layer (if any), then the image encoder layer.
    """

    def __init__(
        self,
        encoder_layer,
        num_layers,
        d_model=256,
        num_queries=300,
        enc_layer_share=False,
        text_enhance_layer=None,
        feature_fusion_layer=None,
        use_checkpoint=False,
        use_transformer_ckpt=False,
    ):
        """Clone the prototype layers ``num_layers`` times.

        Args:
            encoder_layer: prototype image encoder layer.
            num_layers (int): number of layers; with 0 all layer lists stay
                empty and the prototypes are discarded.
            d_model (int, optional): stored on the instance. Defaults to 256.
            num_queries (int, optional): stored on the instance.
                Defaults to 300.
            enc_layer_share (bool, optional): forwarded to ``_get_clones`` as
                ``layer_share``. Defaults to False.
            text_enhance_layer (optional): prototype text layer, or None.
            feature_fusion_layer (optional): prototype fusion layer, or None.
            use_checkpoint (bool, optional): run fusion layers through
                ``torch.utils.checkpoint``. Defaults to False.
            use_transformer_ckpt (bool, optional): run image encoder layers
                through ``torch.utils.checkpoint``. Defaults to False.
        """
        super().__init__()
        # prepare layers
        self.layers = []
        self.text_layers = []
        self.fusion_layers = []
        if num_layers > 0:
            self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)

            if text_enhance_layer is not None:
                self.text_layers = _get_clones(
                    text_enhance_layer, num_layers, layer_share=enc_layer_share
                )
            if feature_fusion_layer is not None:
                self.fusion_layers = _get_clones(
                    feature_fusion_layer, num_layers, layer_share=enc_layer_share
                )
        else:
            self.layers = []
            del encoder_layer

            if text_enhance_layer is not None:
                self.text_layers = []
                del text_enhance_layer
            if feature_fusion_layer is not None:
                self.fusion_layers = []
                del feature_fusion_layer

        self.query_scale = None
        self.num_queries = num_queries
        self.num_layers = num_layers
        self.d_model = d_model

        self.use_checkpoint = use_checkpoint
        self.use_transformer_ckpt = use_transformer_ckpt

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build per-level reference points at the centre of every cell.

        Args:
            spatial_shapes: iterable of ``(H_, W_)`` per level.
            valid_ratios: indexed as ``[:, None, lvl, 0/1]`` (w, h ratios).
            device: device for the generated grids.

        Returns:
            Tensor of shape [bs, sum(hi*wi), num_level, 2].
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # Relies on torch.meshgrid's default ("ij") indexing: the first
            # output varies along rows (y), the second along columns (x).
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(
        self,
        # for images
        src: Tensor,
        pos: Tensor,
        spatial_shapes: Tensor,
        level_start_index: Tensor,
        valid_ratios: Tensor,
        key_padding_mask: Tensor,
        # for texts
        memory_text: Tensor = None,
        text_attention_mask: Tensor = None,
        pos_text: Tensor = None,
        text_self_attention_masks: Tensor = None,
        position_ids: Tensor = None,
    ):
        """Encode image tokens, updating text tokens alongside when enabled.

        Input:
            - src: [bs, sum(hi*wi), 256]
            - pos: pos embed for src. [bs, sum(hi*wi), 256]
            - spatial_shapes: h,w of each level [num_level, 2]
            - level_start_index: [num_level] start point of level in sum(hi*wi).
            - valid_ratios: [bs, num_level, 2]
            - key_padding_mask: [bs, sum(hi*wi)]

            - memory_text: bs, n_text, 256
            - text_attention_mask: bs, n_text
                False for no padding; True for padding
            - pos_text: bs, n_text, 256

            - position_ids: bs, n_text
        Intermediate:
            - reference_points: [bs, sum(hi*wi), num_level, 2]
        Outputs:
            - output: [bs, sum(hi*wi), 256]
            - memory_text: the (possibly updated) text features
        """
        output = src

        # preparation and reshape
        if self.num_layers > 0:
            reference_points = self.get_reference_points(
                spatial_shapes, valid_ratios, device=src.device
            )

        if self.text_layers:
            # generate pos_text
            bs, n_text, _ = memory_text.shape
            if pos_text is None and position_ids is None:
                # Fall back to plain 0..n_text-1 positions.
                pos_text = (
                    torch.arange(n_text, device=memory_text.device)
                    .float()
                    .unsqueeze(0)
                    .unsqueeze(-1)
                    .repeat(bs, 1, 1)
                )
                pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
            if position_ids is not None:
                # Explicit position ids take precedence over any given pos_text.
                pos_text = get_sine_pos_embed(
                    position_ids[..., None], num_pos_feats=256, exchange_xy=False
                )

        # main process
        for layer_id, layer in enumerate(self.layers):
            if self.fusion_layers:
                if self.use_checkpoint:
                    output, memory_text = checkpoint.checkpoint(
                        self.fusion_layers[layer_id],
                        output,
                        memory_text,
                        key_padding_mask,
                        text_attention_mask,
                    )
                else:
                    output, memory_text = self.fusion_layers[layer_id](
                        v=output,
                        l=memory_text,
                        attention_mask_v=key_padding_mask,
                        attention_mask_l=text_attention_mask,
                    )

            if self.text_layers:
                memory_text = self.text_layers[layer_id](
                    src=memory_text.transpose(0, 1),
                    src_mask=~text_self_attention_masks,  # note we use ~ for mask here
                    src_key_padding_mask=text_attention_mask,
                    pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
                ).transpose(0, 1)

            # main process
            if self.use_transformer_ckpt:
                output = checkpoint.checkpoint(
                    layer,
                    output,
                    pos,
                    reference_points,
                    spatial_shapes,
                    level_start_index,
                    key_padding_mask,
                )
            else:
                output = layer(
                    src=output,
                    pos=pos,
                    reference_points=reference_points,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    key_padding_mask=key_padding_mask,
                )

        return output, memory_text


class TransformerDecoder(nn.Module):
    """Stack of decoder layers with iterative reference-point refinement."""

    def __init__(
        self,
        decoder_layer,
        num_layers,
        norm=None,
        return_intermediate=False,
        d_model=256,
        query_dim=4,
        num_feature_levels=1,
    ):
        """Clone ``decoder_layer`` ``num_layers`` times and build the query-pos MLP.

        ``return_intermediate`` must be truthy and ``query_dim`` must be 2 or
        4; both are asserted.
        """
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer, num_layers)
        else:
            self.layers = []
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        assert return_intermediate, "support return_intermediate only"
        self.query_dim = query_dim
        assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
        self.num_feature_levels = num_feature_levels

        # Maps the sine embedding of the reference point to a d_model query pos.
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
        self.query_pos_sine_scale = None

        self.query_scale = None
        # Left as None here; ``forward`` indexes ``bbox_embed[layer_id]`` when
        # it is set, so it is presumably assigned externally — TODO confirm.
        self.bbox_embed = None
        self.class_embed = None

        self.d_model = d_model

        self.ref_anchor_head = None

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
        # for memory
        level_start_index: Optional[Tensor] = None,  # num_levels
        spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        valid_ratios: Optional[Tensor] = None,
        # for text
        memory_text: Optional[Tensor] = None,
        text_attention_mask: Optional[Tensor] = None,
    ):
        """Decode queries layer by layer.

        Input:
            - tgt: nq, bs, d_model
            - memory: hw, bs, d_model
            - pos: hw, bs, d_model
            - refpoints_unsigmoid: nq, bs, 2/4
            - valid_ratios/spatial_shapes: bs, nlevel, 2

        Returns:
            A two-element list ``[intermediate, ref_points]``; every tensor in
            both lists has its first two axes transposed (batch-first).
            ``intermediate`` holds the normed output of each layer and
            ``ref_points`` the initial reference points followed by one
            refined set per layer when ``self.bbox_embed`` is set.
        """
        output = tgt

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]

        for layer_id, layer in enumerate(self.layers):

            if reference_points.shape[-1] == 4:
                reference_points_input = (
                    reference_points[:, :, None]
                    * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
                )  # nq, bs, nlevel, 4
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :]
            )  # nq, bs, 256*2

            # conditional query
            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos

            # main process
            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,
                memory_text=memory_text,
                text_attention_mask=text_attention_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
            )
            # Diagnostic only: report non-finite activations but keep going.
            if output.isnan().any() | output.isinf().any():
                print(f"output layer_id {layer_id} is nan")
                try:
                    num_nan = output.isnan().sum().item()
                    num_inf = output.isinf().sum().item()
                    print(f"num_nan {num_nan}, num_inf {num_inf}")
                except Exception as e:
                    print(e)

            # iter update
            if self.bbox_embed is not None:
                # Predict a delta in un-sigmoided space and re-squash.
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                # The next layer sees a detached copy; the returned list keeps
                # the attached tensor.
                reference_points = new_reference_points.detach()
                ref_points.append(new_reference_points)

            intermediate.append(self.norm(output))

        return [
            [itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
        ]


class DeformableTransformerEncoderLayer(nn.Module):
    """Encoder layer: deformable self-attention followed by an FFN (post-norm)."""

    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_levels=4,
        n_heads=8,
        n_points=4,
    ):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(
            embed_dim=d_model,
            num_levels=n_levels,
            num_heads=n_heads,
            num_points=n_points,
            batch_first=True,
        )
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation, d_model=d_ffn)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        """Apply the feed-forward sub-layer with residual connection and norm."""
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(
        self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None
    ):
        """Run deformable self-attention (query = src + pos, value = src), then the FFN."""
        # self attention
        src2 = self.self_attn(
            query=self.with_pos_embed(src, pos),
            reference_points=reference_points,
            value=src,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            key_padding_mask=key_padding_mask,
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        return src


class DeformableTransformerDecoderLayer(nn.Module):
    """Decoder layer.

    Order of sub-layers in ``forward``: query self-attention, optional
    cross-attention over text, deformable cross-attention over image memory,
    then the FFN. Each sub-layer is residual and followed by a LayerNorm.
    """

    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_levels=4,
        n_heads=8,
        n_points=4,
        use_text_feat_guide=False,
        use_text_cross_attention=False,
    ):
        """Build the sub-layers; ``use_text_feat_guide`` must be falsy (asserted)."""
        super().__init__()

        # cross attention
        self.cross_attn = MSDeformAttn(
            embed_dim=d_model,
            num_levels=n_levels,
            num_heads=n_heads,
            num_points=n_points,
            batch_first=True,
        )
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)

        # cross attention text
        if use_text_cross_attention:
            self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
            self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            self.catext_norm = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)

        self.key_aware_proj = None
        self.use_text_feat_guide = use_text_feat_guide
        assert not use_text_feat_guide
        self.use_text_cross_attention = use_text_cross_attention

    def rm_self_attn_modules(self):
        """Disable query self-attention; ``forward`` skips it when ``self_attn`` is None."""
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Apply the FFN with residual and norm; the two linears run with CUDA autocast disabled."""
        with torch.cuda.amp.autocast(enabled=False):
            tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: Optional[Tensor],  # nq, bs, d_model
        tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
        tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4
        memory_text: Optional[Tensor] = None,  # bs, num_token, d_model
        text_attention_mask: Optional[Tensor] = None,  # bs, num_token
        # for memory
        memory: Optional[Tensor] = None,  # hw, bs, d_model
        memory_key_padding_mask: Optional[Tensor] = None,
        memory_level_start_index: Optional[Tensor] = None,  # num_levels
        memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        memory_pos: Optional[Tensor] = None,  # pos for memory
        # sa
        self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
        cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
    ):
        """Update the queries with one decoder layer.

        Input:
            - tgt/tgt_query_pos: nq, bs, d_model

        ``cross_attn_mask`` must be None (asserted). ``tgt_query_sine_embed``,
        ``tgt_key_padding_mask`` and ``memory_pos`` are accepted but not read
        in this method.

        Returns:
            The updated ``tgt`` (same layout as the input).
        """
        assert cross_attn_mask is None

        # self attention
        if self.self_attn is not None:
            q = k = self.with_pos_embed(tgt, tgt_query_pos)
            tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

        if self.use_text_cross_attention:
            # Text memory is batch-first; nn.MultiheadAttention here is
            # sequence-first, hence the transposes.
            tgt2 = self.ca_text(
                self.with_pos_embed(tgt, tgt_query_pos),
                memory_text.transpose(0, 1),
                memory_text.transpose(0, 1),
                key_padding_mask=text_attention_mask,
            )[0]
            tgt = tgt + self.catext_dropout(tgt2)
            tgt = self.catext_norm(tgt)

        # The deformable attention was built with batch_first=True, so inputs
        # are transposed in and the result transposed back.
        tgt2 = self.cross_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
            reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
            value=memory.transpose(0, 1),
            spatial_shapes=memory_spatial_shapes,
            level_start_index=memory_level_start_index,
            key_padding_mask=memory_key_padding_mask,
        ).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt


def build_transformer(args):
    """Construct a ``Transformer`` from an ``args`` namespace.

    ``return_intermediate_dec`` and ``learnable_tgt_init`` are always set to
    True; every other option is read from the matching attribute on ``args``.
    """
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        num_queries=args.num_queries,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
        query_dim=args.query_dim,
        activation=args.transformer_activation,
        num_patterns=args.num_patterns,
        num_feature_levels=args.num_feature_levels,
        enc_n_points=args.enc_n_points,
        dec_n_points=args.dec_n_points,
        learnable_tgt_init=True,
        # two stage
        two_stage_type=args.two_stage_type,  # Transformer accepts only 'no' or 'standard'
        embed_init_tgt=args.embed_init_tgt,
        use_text_enhancer=args.use_text_enhancer,
        use_fusion_layer=args.use_fusion_layer,
        use_checkpoint=args.use_checkpoint,
        use_transformer_ckpt=args.use_transformer_ckpt,
        use_text_cross_attention=args.use_text_cross_attention,
        text_dropout=args.text_dropout,
        fusion_dropout=args.fusion_dropout,
        fusion_droppath=args.fusion_droppath,
    )