# ------------------------------------------------------------------------
# Copyright (c) IDEA, Inc. and its affiliates.
# Modified from DINO https://github.com/IDEA-Research/DINO by Feng Li and Hao Zhang.
# ------------------------------------------------------------------------

from typing import Optional, List, Union
import torch
from torch import nn, Tensor
from torch.cuda.amp import autocast

from ...utils.utils import MLP, _get_clones, _get_activation_fn, gen_sineembed_for_position, inverse_sigmoid
from detrex.layers import MultiScaleDeformableAttention


class TransformerDecoder(nn.Module):
    """Stack of decoder layers with per-layer reference-point refinement.

    Each layer receives the current queries together with a positional
    embedding derived from the current reference points. When
    ``self.bbox_embed`` is set, the reference points are updated after every
    layer from that layer's output.

    ``bbox_embed`` and ``class_embed`` are initialised to ``None`` here;
    ``forward`` indexes ``bbox_embed`` by layer id, so it is presumably
    assigned by the owning model after construction (not visible in this
    file -- confirm against the caller).
    """

    def __init__(self, decoder_layer, num_layers, norm=None,
                 return_intermediate=False,
                 d_model=256, query_dim=4,
                 modulate_hw_attn=True,
                 num_feature_levels=1,
                 deformable_decoder=True,
                 decoder_query_perturber=None,
                 dec_layer_number=None,  # number of queries each layer in decoder
                 rm_dec_query_scale=True,
                 dec_layer_share=False,
                 dec_layer_dropout_prob=None,
                 ):
        """Build the decoder stack.

        Args:
            decoder_layer: layer module that is cloned ``num_layers`` times.
            num_layers: number of decoder layers.
            norm: optional module applied to every layer output before it is
                stored as an intermediate result.
            return_intermediate: must be truthy; other modes are unsupported.
            d_model: feature width of the queries.
            query_dim: size of the last axis of the reference points (2 or 4).
            modulate_hw_attn: only used (together with ``deformable_decoder``)
                to decide whether ``ref_anchor_head`` is created.
            num_feature_levels: stored on the instance; not otherwise used in
                this class.
            deformable_decoder: when falsy, extra MLP heads are created.
            decoder_query_perturber: optional callable applied to the
                reference points during training for every layer but the first.
            dec_layer_number: optional list with one entry per layer.
            rm_dec_query_scale: must be truthy; the alternative is not
                implemented.
            dec_layer_share: forwarded to ``_get_clones`` as ``layer_share``.
            dec_layer_dropout_prob: optional list of per-layer probabilities,
                each within ``[0, 1]``.

        Raises:
            AssertionError: if an argument fails one of the checks below.
            NotImplementedError: if ``rm_dec_query_scale`` is falsy.
        """
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)
        else:
            # An empty ModuleList keeps the attribute type consistent with the
            # non-empty case (a plain list is not a registered container).
            self.layers = nn.ModuleList()
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        assert return_intermediate, "support return_intermediate only"
        self.query_dim = query_dim
        assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
        self.num_feature_levels = num_feature_levels

        # Maps the sine embedding of the reference points to d_model.
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None

        # A learned query scale is not implemented; only the "removed" mode is
        # accepted.
        if not rm_dec_query_scale:
            raise NotImplementedError
        self.query_scale = None

        self.bbox_embed = None
        self.class_embed = None

        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder

        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None

        self.decoder_query_perturber = decoder_query_perturber
        self.box_pred_damping = None

        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers

        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            assert isinstance(dec_layer_dropout_prob, list)
            assert len(dec_layer_dropout_prob) == num_layers
            for i in dec_layer_dropout_prob:
                assert 0.0 <= i <= 1.0

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension,
        then let each deformable-attention module apply its own init."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Runs after the generic pass so the module-specific initialisation is
        # the one that remains in effect.
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
                # for memory
                level_start_index: Optional[Tensor] = None,  # num_levels
                spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                valid_ratios: Optional[Tensor] = None,
                ):
        """Run all decoder layers.

        Input:
            - tgt: nq, bs, d_model
            - memory: hw, bs, d_model
            - pos: hw, bs, d_model
            - refpoints_unsigmoid: nq, bs, 2/4
            - valid_ratios/spatial_shapes: bs, nlevel, 2

        ``refpoints_unsigmoid`` and ``valid_ratios`` default to ``None`` in the
        signature but are dereferenced unconditionally, so both are required
        in practice.

        Returns:
            A two-element list ``[intermediate, ref_points]``:
            ``intermediate`` holds one (optionally normalised) output per
            layer and ``ref_points`` holds the initial reference points plus
            one refined entry per layer when ``self.bbox_embed`` is set. Every
            tensor has its first two axes swapped before being returned.
        """
        output = tgt

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]

        for layer_id, layer in enumerate(self.layers):
            # preprocess ref points (training only, never for the first layer)
            if self.training and self.decoder_query_perturber is not None and layer_id != 0:
                reference_points = self.decoder_query_perturber(reference_points)

            # Scale the reference points by the per-level valid ratios. The
            # ratios are 2-d, so they are duplicated for 4-d reference points
            # and used as-is otherwise (query_dim == 2), which previously
            # failed to broadcast.
            if reference_points.shape[-1] == 4:
                level_ratios = torch.cat([valid_ratios, valid_ratios], -1)
            else:
                level_ratios = valid_ratios
            reference_points_input = reference_points[:, :, None] \
                * level_ratios[None, :]  # nq, bs, nlevel, 2/4
            # Only the first level is used for the positional embedding.
            query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :])  # nq, bs, 256*2

            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos

            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,

                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,

                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask
            )

            # iter update
            if self.bbox_embed is not None:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                # The next layer sees a detached copy; the stored entry keeps
                # its autograd history.
                reference_points = new_reference_points.detach()
                ref_points.append(new_reference_points)

            # `norm` is optional in the constructor, so only apply it when set.
            intermediate.append(output if self.norm is None else self.norm(output))

        return [
            [itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]
        ]


class DeformableTransformerDecoderLayer(nn.Module):
    """Decoder layer: self-attention, multi-scale deformable cross-attention
    and a feed-forward network, each followed by a LayerNorm."""

    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4,
                 use_deformable_box_attn=False,
                 key_aware_type=None,
                 ):
        """Create the sub-modules.

        Args:
            d_model: feature width.
            d_ffn: hidden width of the feed-forward network.
            dropout: dropout probability used by every dropout in the layer.
            activation: name resolved through ``_get_activation_fn``.
            n_levels, n_heads, n_points: forwarded to the deformable
                cross-attention module.
            use_deformable_box_attn: must be falsy; the box-attention variant
                is not implemented.
            key_aware_type: ``None``, ``'mean'`` or ``'proj_mean'``; see
                ``forward``.

        Raises:
            NotImplementedError: if ``use_deformable_box_attn`` is truthy.
        """
        super().__init__()

        # cross attention
        if use_deformable_box_attn:
            raise NotImplementedError
        self.cross_attn = MultiScaleDeformableAttention(
            embed_dim=d_model,
            num_levels=n_levels,
            num_heads=n_heads,
            num_points=n_points,
            batch_first=True,
            dropout=dropout,
        )
        # dropout1 is created but not used by forward() in this file.
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        self.key_aware_type = key_aware_type
        # NOTE(review): never assigned a module in this file, so the
        # 'proj_mean' branch of forward() needs it to be set elsewhere first.
        self.key_aware_proj = None

    def rm_self_attn_modules(self):
        """Drop the self-attention sub-modules; forward() then skips that step."""
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Apply the feed-forward network with a residual connection and norm3."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    @autocast(enabled=False)
    def forward(self,
                # for tgt
                tgt: Optional[Tensor],  # nq, bs, d_model
                tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
                tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
                tgt_key_padding_mask: Optional[Tensor] = None,
                tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4

                # for memory
                memory: Optional[Tensor] = None,  # hw, bs, d_model
                memory_key_padding_mask: Optional[Tensor] = None,
                memory_level_start_index: Optional[Tensor] = None,  # num_levels
                memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                memory_pos: Optional[Tensor] = None,  # pos for memory

                # sa
                self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
                cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
                ):
        """Apply self-attention, cross-attention and the FFN to ``tgt``.

        Input:
            - tgt/tgt_query_pos: nq, bs, d_model

        ``tgt_query_sine_embed``, ``tgt_key_padding_mask``, ``memory_pos`` and
        ``cross_attn_mask`` are accepted for interface compatibility but are
        not read by this implementation. ``tgt_reference_points`` and
        ``memory`` are dereferenced unconditionally and therefore required.

        Returns:
            The updated queries, in the same layout as ``tgt``.

        Raises:
            NotImplementedError: for an unrecognised ``key_aware_type``.
        """
        # self attention (skipped after rm_self_attn_modules())
        if self.self_attn is not None:
            q = k = self.with_pos_embed(tgt, tgt_query_pos)
            tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

        # cross attention
        if self.key_aware_type is not None:
            if self.key_aware_type == 'mean':
                tgt = tgt + memory.mean(0, keepdim=True)
            elif self.key_aware_type == 'proj_mean':
                tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)
            else:
                raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type))

        # The attention module was built with batch_first=True, so the first
        # two axes are swapped on the way in and swapped back on the way out.
        # The query position is optional (as in the self-attention step), so
        # it is only transposed when present.
        query_pos = None if tgt_query_pos is None else tgt_query_pos.transpose(0, 1)
        tgt2 = self.cross_attn(
            query=tgt.transpose(0, 1),
            query_pos=query_pos,
            reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
            value=memory.transpose(0, 1),
            spatial_shapes=memory_spatial_shapes,
            level_start_index=memory_level_start_index,
            key_padding_mask=memory_key_padding_mask,
        ).transpose(0, 1)
        # NOTE(review): the attention output is used directly, with no
        # `tgt + dropout1(tgt2)` here; presumably the attention module adds
        # the residual and dropout itself -- confirm.
        tgt = tgt2
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt


# The helpers below are kept for reference only; the live versions of
# `_get_clones` and `_get_activation_fn` are imported at the top of the file.
#
# def _get_clones(module, N, layer_share=False):
#     # import ipdb; ipdb.set_trace()
#     if layer_share:
#         return nn.ModuleList([module for i in range(N)])
#     else:
#         return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
#
#
# def _get_activation_fn_(activation):
#     """Return an activation function given a string"""
#     if activation == "relu":
#         return F.relu
#     if activation == "gelu":
#         return F.gelu
#     if activation == "glu":
#         return F.glu
#     raise RuntimeError(f"activation should be relu/gelu, not {activation}.")