import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from torch import Tensor
from torch.nn import ModuleList

from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer
from .utils import MLP, ConditionalAttention, coordinate_to_encoding


class ConditionalDetrTransformerDecoder(DetrTransformerDecoder):
    """Decoder of Conditional DETR."""

    def _init_layers(self) -> None:
        """Initialize decoder layers and other layers."""
        self.layers = ModuleList([
            ConditionalDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]
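        # Extra heads introduced by Conditional DETR: `query_scale` maps the
        # content query to the transformation applied to the reference sine
        # embedding, and `ref_point_head` predicts 2-d reference points from
        # the positional queries.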
        self.query_scale = MLP(self.embed_dims, self.embed_dims,
                               self.embed_dims, 2)
        self.ref_point_head = MLP(self.embed_dims, self.embed_dims, 2, 2)
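        # In every layer after the first, the cross-attention positional
        # query comes from the projected sine embedding of the reference
        # points (`qpos_sine_proj`), so the unused `qpos_proj` is removed.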
        for layer_id in range(self.num_layers - 1):
            self.layers[layer_id + 1].cross_attn.qpos_proj = None

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                key_padding_mask: Tensor = None):
        """Forward function of decoder.

        Args:
            query (Tensor): The input query with shape
                (bs, num_queries, dim).
            key (Tensor): The input key with shape (bs, num_keys, dim). If
                `None`, the `query` will be used. Defaults to `None`.
            query_pos (Tensor): The positional encoding for `query`, with
                the same shape as `query`. If not `None`, it will be added
                to `query` before the forward function. Defaults to `None`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not `None`, it will be added to
                `key` before the forward function. If `None` and `query_pos`
                has the same shape as `key`, then `query_pos` will be used
                as `key_pos`. Defaults to `None`.
            key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys).
                Defaults to `None`.

        Returns:
            List[Tensor]: Forwarded results with shape (num_decoder_layers,
            bs, num_queries, dim) if `return_intermediate` is True, otherwise
            with shape (1, bs, num_queries, dim), together with the reference
            points with shape (bs, num_queries, 2).
        """
        reference_unsigmoid = self.ref_point_head(
            query_pos)  # (bs, num_queries, 2)
        reference = reference_unsigmoid.sigmoid()
        reference_xy = reference[..., :2]
        intermediate = []
        for layer_id, layer in enumerate(self.layers):
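            # In the first decoder layer there is no refined content query
            # yet, so the sine embedding is used as-is; later layers scale
            # it by a transformation learned from the current query.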
            if layer_id == 0:
                pos_transformation = 1
            else:
                pos_transformation = self.query_scale(query)
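            # Conditional spatial query: the sine embedding of the reference
            # points, scaled by the transformation above.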
            ref_sine_embed = coordinate_to_encoding(coord_tensor=reference_xy)
            ref_sine_embed = ref_sine_embed * pos_transformation
            query = layer(
                query,
                key=key,
                query_pos=query_pos,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                ref_sine_embed=ref_sine_embed,
                is_first=(layer_id == 0))
            if self.return_intermediate:
                intermediate.append(self.post_norm(query))

        if self.return_intermediate:
            return torch.stack(intermediate), reference

        query = self.post_norm(query)
        return query.unsqueeze(0), reference


class ConditionalDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in Conditional DETR transformer."""

    def _init_layers(self):
        """Initialize self-attention, cross-attention, FFN, and
        normalization."""
        self.self_attn = ConditionalAttention(**self.self_attn_cfg)
        self.cross_attn = ConditionalAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_masks: Tensor = None,
                cross_attn_masks: Tensor = None,
                key_padding_mask: Tensor = None,
                ref_sine_embed: Tensor = None,
                is_first: bool = False):
        """
        Args:
            query (Tensor): The input query, has shape (bs, num_queries,
                dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for `query`,
                has the same shape as `query`. If not `None`, it will be
                added to `query` before the forward function. Defaults to
                `None`.
            key_pos (Tensor, optional): The positional encoding for `key`,
                has the same shape as `key`. If not `None`, it will be added
                to `key` before the forward function. If `None` and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used as `key_pos`. Defaults to `None`.
            self_attn_masks (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), same as in
                `nn.MultiheadAttention.forward`. Defaults to `None`.
            cross_attn_masks (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), same as in
                `nn.MultiheadAttention.forward`. Defaults to `None`.
            key_padding_mask (Tensor, optional): ByteTensor, has shape
                (bs, num_keys). Defaults to `None`.
            ref_sine_embed (Tensor, optional): The positional encoding for
                `query` in cross-attention, has the same shape as `query`.
                Defaults to `None`.
            is_first (bool): An indicator of whether the current layer is
                the first layer of the decoder. Defaults to `False`.

        Returns:
            Tensor: Forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_masks)
        query = self.norms[0](query)
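        # Cross-attention over the encoder memory, where `ref_sine_embed`
        # provides the conditional spatial query.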
        query = self.cross_attn(
            query=query,
            key=key,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_masks,
            key_padding_mask=key_padding_mask,
            ref_sine_embed=ref_sine_embed,
            is_first=is_first)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)

        return query
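

if __name__ == '__main__':
    # A minimal smoke-test sketch (illustrative only, not part of the
    # original module). The cfg keys and values below are assumptions
    # modeled on typical Conditional DETR settings; check the signatures
    # of `ConditionalAttention` and `FFN` before relying on them.
    decoder = ConditionalDetrTransformerDecoder(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8),
            cross_attn_cfg=dict(embed_dims=256, num_heads=8, cross_attn=True),
            ffn_cfg=dict(embed_dims=256, feedforward_channels=2048)))
    bs, num_queries, num_keys, dim = 2, 300, 1064, 256
    query = torch.zeros(bs, num_queries, dim)     # initial content queries
    memory = torch.rand(bs, num_keys, dim)        # flattened encoder output
    query_pos = torch.rand(bs, num_queries, dim)  # learned query embeddings
    key_pos = torch.rand(bs, num_keys, dim)       # sine encoding of the keys
    out, reference = decoder(
        query, key=memory, query_pos=query_pos, key_pos=key_pos)
    # out: (num_layers, bs, num_queries, dim) since `return_intermediate`
    # defaults to True; reference: (bs, num_queries, 2).
    print(out.shape, reference.shape)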