Robert001's picture
first commit
b334e29
raw
history blame
No virus
36.8 kB
import torch
import torch.nn as nn
from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
xavier_init)
from .builder import TRANSFORMER
class MultiheadAttention(nn.Module):
    """A wrapper for torch.nn.MultiheadAttention.

    This module implements MultiheadAttention with residual connection,
    and positional encoding used in DETR is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        dropout (float): Dropout probability. It is passed to
            `nn.MultiheadAttention` and is also used for the dropout that
            is applied to the attention output before the residual
            addition. Default 0.0.
    """

    def __init__(self, embed_dims, num_heads, dropout=0.0):
        super(MultiheadAttention, self).__init__()
        assert embed_dims % num_heads == 0, 'embed_dims must be ' \
            f'divisible by num_heads. got {embed_dims} and {num_heads}.'
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout)
        # `self.dropout` is the output dropout module. The configured
        # probability stays available as `self.dropout.p`, which is what
        # `__repr__` reports.
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x,
                key=None,
                value=None,
                residual=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None):
        """Forward function for `MultiheadAttention`.

        Args:
            x (Tensor): The input query with shape [num_query, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
            key (Tensor): The key tensor with shape [num_key, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
                Default None. If None, the `query` will be used.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Default None.
                If None, the `key` (without positional encoding) will be
                used.
            residual (Tensor): The tensor used for addition, with the
                same shape as `x`. Default None. If None, `x` will be used.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. Default None. If not None, it will
                be added to `x` before forward function.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Default None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`.
            attn_mask (Tensor): ByteTensor mask with shape [num_query,
                num_key]. Same in `nn.MultiheadAttention.forward`.
                Default None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
                Same in `nn.MultiheadAttention.forward`. Default None.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        query = x
        if key is None:
            key = query
        # `value` is taken before any positional encoding is added to `key`.
        if value is None:
            value = key
        if residual is None:
            residual = x
        # Reuse the query positional encoding for the key when none is given
        # and the shapes are compatible (e.g. self attention).
        if (key_pos is None and query_pos is not None
                and query_pos.shape == key.shape):
            key_pos = query_pos
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos
        # nn.MultiheadAttention returns (output, weights); only the output
        # is needed here.
        out = self.attn(
            query,
            key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]
        return residual + self.dropout(out)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        # Report the configured probability, not the nn.Dropout module.
        repr_str += f'dropout={self.dropout.p})'
        return repr_str
class FFN(nn.Module):
    """Implements feed-forward networks (FFNs) with residual connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Must be no less than 2. Defaults to 2.
        act_cfg (dict, optional): The activation config for FFNs.
        dropout (float, optional): Probability of an element to be
            zeroed. Default 0.0.
        add_residual (bool, optional): Add residual connection.
            Defaults to True.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 dropout=0.0,
                 add_residual=True):
        super(FFN, self).__init__()
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        # A single activation instance is shared by every hidden layer.
        self.activate = build_activation_layer(act_cfg)

        layers = []
        in_channels = embed_dims
        # num_fcs - 1 hidden blocks: Linear -> activation -> dropout.
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(dropout)))
            in_channels = feedforward_channels
        # Final projection back to the embedding dimension.
        layers.append(Linear(feedforward_channels, embed_dims))
        self.layers = nn.Sequential(*layers)
        # `self.dropout` is the output dropout module. The configured
        # probability stays available as `self.dropout.p`, which is what
        # `__repr__` reports.
        self.dropout = nn.Dropout(dropout)
        self.add_residual = add_residual

    def forward(self, x, residual=None):
        """Forward function for `FFN`.

        Args:
            x (Tensor): The input of the FFN layers.
            residual (Tensor): The tensor added to the (dropped-out) FFN
                output when `add_residual` is True. Default None. If None,
                `x` will be used.

        Returns:
            Tensor: The raw FFN output if `add_residual` is False,
                otherwise `residual + dropout(output)`.
        """
        out = self.layers(x)
        if not self.add_residual:
            return out
        if residual is None:
            residual = x
        return residual + self.dropout(out)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        # Report the configured probability, not the nn.Dropout module.
        repr_str += f'dropout={self.dropout.p}, '
        repr_str += f'add_residual={self.add_residual})'
        return repr_str
class TransformerEncoderLayer(nn.Module):
    """One encoder layer of the DETR transformer.

    The layer is assembled from a self-attention module, an FFN and two
    normalization layers, executed in the sequence given by `order`.

    Args:
        embed_dims (int): The feature dimension. Same as `FFN`.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        dropout (float): Probability of an element to be zeroed. Default 0.0.
        order (tuple[str]): The order for encoder layer. Valid examples are
            ('selfattn', 'norm', 'ffn', 'norm') and ('norm', 'selfattn',
            'norm', 'ffn'). Default ('selfattn', 'norm', 'ffn', 'norm').
        act_cfg (dict): The activation config for FFNs. Default ReLU.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default 2.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'ffn', 'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super().__init__()
        assert isinstance(order, tuple) and len(order) == 4
        assert set(order) == {'selfattn', 'norm', 'ffn'}

        # Plain configuration values, kept for `__repr__`.
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        # A leading 'norm' means normalization happens before each sub-layer.
        self.pre_norm = order[0] == 'norm'

        self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout)
        self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
                       dropout)
        self.norms = nn.ModuleList(
            [build_norm_layer(norm_cfg, embed_dims)[1] for _ in range(2)])

    def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None):
        """Run the sub-layers in the configured order.

        Args:
            x (Tensor): The input query with shape [num_key, bs,
                embed_dims]. Same in `MultiheadAttention.forward`.
            pos (Tensor): The positional encoding for query. Default None.
                Same as `query_pos` in `MultiheadAttention.forward`.
            attn_mask (Tensor): ByteTensor mask with shape [num_key,
                num_key]. Same in `MultiheadAttention.forward`. Default None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
                Same in `MultiheadAttention.forward`. Default None.

        Returns:
            Tensor: forwarded results with shape [num_key, bs, embed_dims].
        """
        pending_norms = iter(self.norms)
        shortcut_src = x
        for step in self.order:
            # In pre-norm mode the un-normalized tensor is passed as the
            # residual; otherwise None lets the sub-module use its own input.
            shortcut = shortcut_src if self.pre_norm else None
            if step == 'norm':
                x = next(pending_norms)(x)
            elif step == 'ffn':
                x = self.ffn(x, shortcut)
            elif step == 'selfattn':
                x = self.self_attn(
                    x,
                    x,
                    x,
                    shortcut,
                    query_pos=pos,
                    key_pos=pos,
                    attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask)
                shortcut_src = x
        return x

    def __repr__(self):
        """str: a string that describes the module"""
        shown = ('embed_dims', 'num_heads', 'feedforward_channels',
                 'dropout', 'order', 'act_cfg', 'norm_cfg', 'num_fcs')
        body = ', '.join(f'{key}={getattr(self, key)}' for key in shown)
        return f'{self.__class__.__name__}({body})'
class TransformerDecoderLayer(nn.Module):
    """One decoder layer of the DETR transformer.

    The layer is assembled from a self-attention module, an attention
    module over the encoder memory, an FFN and three normalization layers,
    executed in the sequence given by `order`.

    Args:
        embed_dims (int): The feature dimension. Same as
            `TransformerEncoderLayer`.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): Same as `TransformerEncoderLayer`.
        dropout (float): Same as `TransformerEncoderLayer`. Default 0.0.
        order (tuple[str]): The order for decoder layer. Valid examples are
            ('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', 'norm') and
            ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn').
            Default the former.
        act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        num_fcs (int): The number of fully-connected layers in FFNs.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
                        'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super().__init__()
        assert isinstance(order, tuple) and len(order) == 6
        assert set(order) == {'selfattn', 'norm', 'multiheadattn', 'ffn'}

        # Plain configuration values, kept for `__repr__`.
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        # A leading 'norm' means normalization happens before each sub-layer.
        self.pre_norm = order[0] == 'norm'

        self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout)
        self.multihead_attn = MultiheadAttention(embed_dims, num_heads,
                                                 dropout)
        self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
                       dropout)
        # 3 norm layers in official DETR's TransformerDecoderLayer
        self.norms = nn.ModuleList(
            [build_norm_layer(norm_cfg, embed_dims)[1] for _ in range(3)])

    def forward(self,
                x,
                memory,
                memory_pos=None,
                query_pos=None,
                memory_attn_mask=None,
                target_attn_mask=None,
                memory_key_padding_mask=None,
                target_key_padding_mask=None):
        """Run the sub-layers in the configured order.

        Args:
            x (Tensor): Input query with shape [num_query, bs, embed_dims].
            memory (Tensor): Tensor got from `TransformerEncoder`, with shape
                [num_key, bs, embed_dims].
            memory_pos (Tensor): The positional encoding for `memory`. Default
                None. Same as `key_pos` in `MultiheadAttention.forward`.
            query_pos (Tensor): The positional encoding for `query`. Default
                None. Same as `query_pos` in `MultiheadAttention.forward`.
            memory_attn_mask (Tensor): ByteTensor mask for `memory`, with
                shape [num_key, num_key]. Same as `attn_mask` in
                `MultiheadAttention.forward`. Default None.
            target_attn_mask (Tensor): ByteTensor mask for `x`, with shape
                [num_query, num_query]. Same as `attn_mask` in
                `MultiheadAttention.forward`. Default None.
            memory_key_padding_mask (Tensor): ByteTensor for `memory`, with
                shape [bs, num_key]. Same as `key_padding_mask` in
                `MultiheadAttention.forward`. Default None.
            target_key_padding_mask (Tensor): ByteTensor for `x`, with shape
                [bs, num_query]. Same as `key_padding_mask` in
                `MultiheadAttention.forward`. Default None.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        pending_norms = iter(self.norms)
        shortcut_src = x
        for step in self.order:
            # In pre-norm mode the un-normalized tensor is passed as the
            # residual; otherwise None lets the sub-module use its own input.
            shortcut = shortcut_src if self.pre_norm else None
            if step == 'norm':
                x = next(pending_norms)(x)
            elif step == 'ffn':
                x = self.ffn(x, shortcut)
            elif step == 'selfattn':
                # Queries attend to each other.
                x = self.self_attn(
                    x,
                    x,
                    x,
                    shortcut,
                    query_pos,
                    key_pos=query_pos,
                    attn_mask=target_attn_mask,
                    key_padding_mask=target_key_padding_mask)
                shortcut_src = x
            elif step == 'multiheadattn':
                # Queries attend to the encoder memory.
                x = self.multihead_attn(
                    x,
                    memory,
                    memory,
                    shortcut,
                    query_pos,
                    key_pos=memory_pos,
                    attn_mask=memory_attn_mask,
                    key_padding_mask=memory_key_padding_mask)
                shortcut_src = x
        return x

    def __repr__(self):
        """str: a string that describes the module"""
        shown = ('embed_dims', 'num_heads', 'feedforward_channels',
                 'dropout', 'order', 'act_cfg', 'norm_cfg', 'num_fcs')
        body = ', '.join(f'{key}={getattr(self, key)}' for key in shown)
        return f'{self.__class__.__name__}({body})'
class TransformerEncoder(nn.Module):
    """The encoder of the DETR transformer: a stack of encoder layers.

    When `order` starts with 'norm' (pre-norm), one extra normalization
    layer is applied to the output of the last layer.

    Args:
        num_layers (int): The number of `TransformerEncoderLayer`.
        embed_dims (int): Same as `TransformerEncoderLayer`.
        num_heads (int): Same as `TransformerEncoderLayer`.
        feedforward_channels (int): Same as `TransformerEncoderLayer`.
        dropout (float): Same as `TransformerEncoderLayer`. Default 0.0.
        order (tuple[str]): Same as `TransformerEncoderLayer`.
        act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU.
        norm_cfg (dict): Same as `TransformerEncoderLayer`. Default
            layer normalization.
        num_fcs (int): Same as `TransformerEncoderLayer`. Default 2.
    """

    def __init__(self,
                 num_layers,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'ffn', 'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super().__init__()
        assert isinstance(order, tuple) and len(order) == 4
        assert set(order) == {'selfattn', 'norm', 'ffn'}

        # Plain configuration values, kept for `__repr__`.
        self.num_layers = num_layers
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.pre_norm = order[0] == 'norm'

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dims, num_heads,
                                    feedforward_channels, dropout, order,
                                    act_cfg, norm_cfg, num_fcs)
            for _ in range(num_layers)
        ])
        # Only the pre-norm variant gets a final normalization layer.
        if self.pre_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

    def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None):
        """Pass the input through every encoder layer in turn.

        Args:
            x (Tensor): Input query. Same in `TransformerEncoderLayer.forward`.
            pos (Tensor): Positional encoding for query. Default None.
                Same in `TransformerEncoderLayer.forward`.
            attn_mask (Tensor): ByteTensor attention mask. Default None.
                Same in `TransformerEncoderLayer.forward`.
            key_padding_mask (Tensor): Same in
                `TransformerEncoderLayer.forward`. Default None.

        Returns:
            Tensor: Results with shape [num_key, bs, embed_dims].
        """
        out = x
        for encoder_layer in self.layers:
            out = encoder_layer(out, pos, attn_mask, key_padding_mask)
        if self.norm is None:
            return out
        return self.norm(out)

    def __repr__(self):
        """str: a string that describes the module"""
        shown = ('num_layers', 'embed_dims', 'num_heads',
                 'feedforward_channels', 'dropout', 'order', 'act_cfg',
                 'norm_cfg', 'num_fcs')
        body = ', '.join(f'{key}={getattr(self, key)}' for key in shown)
        return f'{self.__class__.__name__}({body})'
class TransformerDecoder(nn.Module):
    """Implements the decoder in DETR transformer.

    A stack of `TransformerDecoderLayer` followed by one normalization
    layer that is applied to the returned output(s).

    Args:
        num_layers (int): The number of `TransformerDecoderLayer`.
        embed_dims (int): Same as `TransformerDecoderLayer`.
        num_heads (int): Same as `TransformerDecoderLayer`.
        feedforward_channels (int): Same as `TransformerDecoderLayer`.
        dropout (float): Same as `TransformerDecoderLayer`. Default 0.0.
        order (tuple[str]): Same as `TransformerDecoderLayer`.
        act_cfg (dict): Same as `TransformerDecoderLayer`. Default ReLU.
        norm_cfg (dict): Same as `TransformerDecoderLayer`. Default
            layer normalization.
        num_fcs (int): Same as `TransformerDecoderLayer`. Default 2.
        return_intermediate (bool): Whether to return the normalized output
            of every decoder layer instead of only the last one.
            Default False.
    """

    def __init__(self,
                 num_layers,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
                        'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2,
                 return_intermediate=False):
        super(TransformerDecoder, self).__init__()
        assert isinstance(order, tuple) and len(order) == 6
        assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn'])
        self.num_layers = num_layers
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.return_intermediate = return_intermediate
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerDecoderLayer(embed_dims, num_heads,
                                        feedforward_channels, dropout, order,
                                        act_cfg, norm_cfg, num_fcs))
        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self,
                x,
                memory,
                memory_pos=None,
                query_pos=None,
                memory_attn_mask=None,
                target_attn_mask=None,
                memory_key_padding_mask=None,
                target_key_padding_mask=None):
        """Forward function for `TransformerDecoder`.

        Args:
            x (Tensor): Input query. Same in `TransformerDecoderLayer.forward`.
            memory (Tensor): Same in `TransformerDecoderLayer.forward`.
            memory_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
                Default None.
            query_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
                Default None.
            memory_attn_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            target_attn_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            memory_key_padding_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            target_key_padding_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.

        Returns:
            Tensor: If `return_intermediate` is True, the stacked normalized
                outputs of all layers, i.e. one leading entry per decoder
                layer. Otherwise the normalized output of the last layer
                with a leading dimension of size 1.
        """
        intermediate = []
        for layer in self.layers:
            x = layer(x, memory, memory_pos, query_pos, memory_attn_mask,
                      target_attn_mask, memory_key_padding_mask,
                      target_key_padding_mask)
            if self.return_intermediate:
                intermediate.append(self.norm(x))

        # The last intermediate entry already is the normalized final
        # output, so it is not recomputed. With no layers the list is empty
        # and the single-output path below yields the equivalent result.
        if self.return_intermediate and intermediate:
            return torch.stack(intermediate)

        if self.norm is not None:
            x = self.norm(x)
        return x.unsqueeze(0)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(num_layers={self.num_layers}, '
        repr_str += f'embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'order={self.order}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'return_intermediate={self.return_intermediate})'
        return repr_str
@TRANSFORMER.register_module()
class Transformer(nn.Module):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module copy-paste
    from torch.nn.Transformer with modifications:

        * positional encodings are passed in MultiheadAttention
        * extra LN at the end of encoder is removed
        * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        num_encoder_layers (int): Number of `TransformerEncoderLayer`.
        num_decoder_layers (int): Number of `TransformerDecoderLayer`.
        feedforward_channels (int): The hidden dimension for FFNs used in both
            encoder and decoder.
        dropout (float): Probability of an element to be zeroed. Default 0.0.
        act_cfg (dict): Activation config for FFNs used in both encoder
            and decoder. Default ReLU.
        norm_cfg (dict): Config dict for normalization used in both encoder
            and decoder. Default layer normalization.
        num_fcs (int): The number of fully-connected layers in FFNs, which is
            used for both encoder and decoder.
        pre_norm (bool): Whether the normalization layer is ordered
            first in the encoder and decoder. Default False.
        return_intermediate_dec (bool): Whether to return the intermediate
            output from each TransformerDecoderLayer or only the last
            TransformerDecoderLayer. Default False. If True, the returned
            `hs` has shape [num_decoder_layers, bs, num_query, embed_dims].
            If False, the returned `hs` will have shape [1, bs, num_query,
            embed_dims].
    """

    def __init__(self,
                 embed_dims=512,
                 num_heads=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 feedforward_channels=2048,
                 dropout=0.0,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2,
                 pre_norm=False,
                 return_intermediate_dec=False):
        super(Transformer, self).__init__()
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.pre_norm = pre_norm
        self.return_intermediate_dec = return_intermediate_dec
        # `pre_norm` only selects where the norm layers sit relative to the
        # attention / FFN sub-layers.
        if self.pre_norm:
            encoder_order = ('norm', 'selfattn', 'norm', 'ffn')
            decoder_order = ('norm', 'selfattn', 'norm', 'multiheadattn',
                             'norm', 'ffn')
        else:
            encoder_order = ('selfattn', 'norm', 'ffn', 'norm')
            decoder_order = ('selfattn', 'norm', 'multiheadattn', 'norm',
                             'ffn', 'norm')
        self.encoder = TransformerEncoder(num_encoder_layers, embed_dims,
                                          num_heads, feedforward_channels,
                                          dropout, encoder_order, act_cfg,
                                          norm_cfg, num_fcs)
        self.decoder = TransformerDecoder(num_decoder_layers, embed_dims,
                                          num_heads, feedforward_channels,
                                          dropout, decoder_order, act_cfg,
                                          norm_cfg, num_fcs,
                                          return_intermediate_dec)

    def init_weights(self, distribution='uniform'):
        """Initialize the transformer weights.

        Args:
            distribution (str): Passed through to `xavier_init`.
                Default 'uniform'.
        """
        # follow the official DETR to init parameters
        for m in self.modules():
            # A module may register `weight` as None, so check for None
            # before asking for its dimensionality.
            weight = getattr(m, 'weight', None)
            if weight is not None and weight.dim() > 1:
                xavier_init(m, distribution=distribution)

    def forward(self, x, mask, query_embed, pos_embed):
        """Forward function for `Transformer`.

        Args:
            x (Tensor): Input query with shape [bs, c, h, w] where
                c = embed_dims.
            mask (Tensor): The key_padding_mask used for encoder and decoder,
                with shape [bs, h, w].
            query_embed (Tensor): The query embedding for decoder, with shape
                [num_query, c].
            pos_embed (Tensor): The positional encoding for encoder and
                decoder, with the same shape as `x`.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - out_dec: Output from decoder. If return_intermediate_dec \
                      is True output has shape [num_dec_layers, bs,
                      num_query, embed_dims], else has shape [1, bs, \
                      num_query, embed_dims].
                - memory: Output results from encoder, with shape \
                      [bs, embed_dims, h, w].
        """
        bs, c, h, w = x.shape
        x = x.flatten(2).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(
            1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]
        mask = mask.flatten(1)  # [bs, h, w] -> [bs, h*w]
        memory = self.encoder(
            x, pos=pos_embed, attn_mask=None, key_padding_mask=mask)
        # The decoder input starts from zeros; the query embedding is fed
        # in as the query positional encoding instead.
        target = torch.zeros_like(query_embed)
        # out_dec: [num_layers, num_query, bs, dim]
        out_dec = self.decoder(
            target,
            memory,
            memory_pos=pos_embed,
            query_pos=query_embed,
            memory_attn_mask=None,
            target_attn_mask=None,
            memory_key_padding_mask=mask,
            target_key_padding_mask=None)
        # -> [num_layers, bs, num_query, dim]
        out_dec = out_dec.transpose(1, 2)
        # [h*w, bs, c] -> [bs, c, h, w]
        memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
        return out_dec, memory

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'num_encoder_layers={self.num_encoder_layers}, '
        repr_str += f'num_decoder_layers={self.num_decoder_layers}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'pre_norm={self.pre_norm}, '
        repr_str += f'return_intermediate_dec={self.return_intermediate_dec})'
        return repr_str
@TRANSFORMER.register_module()
class DynamicConv(nn.Module):
    """Dynamic convolution with per-sample generated parameters.

    A linear layer produces two parameter matrices for every sample, which
    are applied to the input feature with `torch.bmm` (a 1*1 convolution).
    Code is modified from the `official github repo
    <https://github.com/PeizeSun/
    SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ .

    Args:
        in_channels (int): The input feature channel.
            Defaults to 256.
        feat_channels (int): The inner feature channel.
            Defaults to 64.
        out_channels (int, optional): The output feature channel.
            When not specified, it will be set to `in_channels`
            by default
        input_feat_shape (int): The shape of input feature.
            Defaults to 7.
        act_cfg (dict): The activation config for DynamicConv.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
    """

    def __init__(self,
                 in_channels=256,
                 feat_channels=64,
                 out_channels=None,
                 input_feat_shape=7,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN')):
        super().__init__()
        # Values exactly as passed in, kept for `__repr__`.
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.input_feat_shape = input_feat_shape
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        # A falsy `out_channels` falls back to `in_channels`.
        self.out_channels = out_channels or in_channels

        # Sizes of the two generated parameter matrices (flattened).
        self.num_params_in = self.in_channels * self.feat_channels
        self.num_params_out = self.out_channels * self.feat_channels
        total_params = self.num_params_in + self.num_params_out
        self.dynamic_layer = nn.Linear(self.in_channels, total_params)

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        flattened_dim = self.out_channels * input_feat_shape**2
        self.fc_layer = nn.Linear(flattened_dim, self.out_channels)
        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, param_feature, input_feature):
        """Apply the generated parameters to the input feature.

        Args:
            param_feature (Tensor): The feature can be used
                to generate the parameter, has shape
                (num_all_proposals, in_channels).
            input_feature (Tensor): Feature that
                interact with parameters, has shape
                (num_all_proposals, in_channels, H, W).

        Returns:
            Tensor: The output feature has shape
                (num_all_proposals, out_channels).
        """
        n = param_feature.size(0)
        # (N, C, H, W) -> (N, C, H*W) -> (N, H*W, C)
        feats = input_feature.view(n, self.in_channels, -1).permute(0, 2, 1)

        generated = self.dynamic_layer(param_feature)
        # The leading slice holds the input-side matrix, the trailing slice
        # the output-side matrix.
        weight_in = generated[:, :self.num_params_in].view(
            -1, self.in_channels, self.feat_channels)
        weight_out = generated[:, -self.num_params_out:].view(
            -1, self.feat_channels, self.out_channels)

        # (N, H*W, in_channels) x (N, in_channels, feat_channels)
        #   -> (N, H*W, feat_channels)
        feats = torch.bmm(feats, weight_in)
        feats = self.activation(self.norm_in(feats))

        # (N, H*W, feat_channels) x (N, feat_channels, out_channels)
        #   -> (N, H*W, out_channels)
        feats = torch.bmm(feats, weight_out)
        feats = self.activation(self.norm_out(feats))

        # Flatten the spatial positions and project to out_channels.
        feats = self.fc_layer(feats.flatten(1))
        return self.activation(self.fc_norm(feats))

    def __repr__(self):
        """str: a string that describes the module"""
        shown = (
            ('in_channels', self.in_channels),
            ('feat_channels', self.feat_channels),
            ('out_channels', self.out_channels_raw),
            ('input_feat_shape', self.input_feat_shape),
            ('act_cfg', self.act_cfg),
            ('norm_cfg', self.norm_cfg),
        )
        body = ', '.join(f'{key}={val}' for key, val in shown)
        return f'{self.__class__.__name__}({body})'