import torch
import torch.nn as nn
from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
                      xavier_init)

from .builder import TRANSFORMER


class MultiheadAttention(nn.Module):
    """A wrapper for torch.nn.MultiheadAttention.

    This module implements MultiheadAttention with a residual connection,
    and the positional encoding used in DETR is also passed in as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        dropout (float): Dropout probability on attn_output_weights and on
            the attention output before the residual addition. Default 0.0.
    """

    def __init__(self, embed_dims, num_heads, dropout=0.0):
        super(MultiheadAttention, self).__init__()
        assert embed_dims % num_heads == 0, 'embed_dims must be ' \
            f'divisible by num_heads. got {embed_dims} and {num_heads}.'
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.dropout = dropout
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x,
                key=None,
                value=None,
                residual=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None):
        """Forward function for `MultiheadAttention`.

        Args:
            x (Tensor): The input query with shape [num_query, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
            key (Tensor): The key tensor with shape [num_key, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
                Default None. If None, the `query` will be used.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Default None.
                If None, the `key` will be used.
            residual (Tensor): The tensor used for addition, with the
                same shape as `x`. Default None. If None, `x` will be used.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. Default None. If not None, it will
                be added to `x` before the forward function.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Default None. If not None, it will
                be added to `key` before the forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`.
            attn_mask (Tensor): ByteTensor mask with shape [num_query,
                num_key]. Same in `nn.MultiheadAttention.forward`.
                Default None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
                Same in `nn.MultiheadAttention.forward`. Default None.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        query = x
        if key is None:
            key = query
        if value is None:
            value = key
        if residual is None:
            residual = x
        if key_pos is None:
            if query_pos is not None and key is not None:
                if query_pos.shape == key.shape:
                    key_pos = query_pos
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos
        out = self.attn(
            query,
            key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        return residual + self.dropout(out)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'dropout={self.dropout})'
        return repr_str


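# Usage sketch for MultiheadAttention (illustrative; the sizes below are
# assumed, not defaults). Input and output both have shape
# [num_query, bs, embed_dims], and the residual is already added:
#
#     >>> self_attn = MultiheadAttention(embed_dims=256, num_heads=8)
#     >>> x = torch.rand(100, 2, 256)
#     >>> pos = torch.rand(100, 2, 256)
#     >>> self_attn(x, query_pos=pos).shape
#     torch.Size([100, 2, 256])

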
class FFN(nn.Module):
    """Implements feed-forward networks (FFNs) with residual connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Defaults to 2.
        act_cfg (dict, optional): The activation config for FFNs.
        dropout (float, optional): Probability of an element to be
            zeroed. Default 0.0.
        add_residual (bool, optional): Whether to add the residual
            connection. Defaults to True.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 dropout=0.0,
                 add_residual=True):
        super(FFN, self).__init__()
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.dropout = dropout
        self.activate = build_activation_layer(act_cfg)

        layers = nn.ModuleList()
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(dropout)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        self.layers = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)
        self.add_residual = add_residual

    def forward(self, x, residual=None):
        """Forward function for `FFN`."""
        out = self.layers(x)
        if not self.add_residual:
            return out
        if residual is None:
            residual = x
        return residual + self.dropout(out)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'add_residual={self.add_residual})'
        return repr_str


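# Usage sketch for FFN (illustrative; sizes are assumed). The last dimension
# stays at embed_dims, so the FFN applies to any [..., embed_dims] tensor:
#
#     >>> ffn = FFN(embed_dims=256, feedforward_channels=1024)
#     >>> ffn(torch.rand(100, 2, 256)).shape
#     torch.Size([100, 2, 256])

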
class TransformerEncoderLayer(nn.Module):
    """Implements one encoder layer in DETR transformer.

    Args:
        embed_dims (int): The feature dimension. Same as `FFN`.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        dropout (float): Probability of an element to be zeroed. Default 0.0.
        order (tuple[str]): The order for encoder layer. Valid examples are
            ('selfattn', 'norm', 'ffn', 'norm') and ('norm', 'selfattn',
            'norm', 'ffn'). Default ('selfattn', 'norm', 'ffn', 'norm').
        act_cfg (dict): The activation config for FFNs. Default ReLU.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default 2.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'ffn', 'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super(TransformerEncoderLayer, self).__init__()
        assert isinstance(order, tuple) and len(order) == 4
        assert set(order) == set(['selfattn', 'norm', 'ffn'])
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.pre_norm = order[0] == 'norm'
        self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout)
        self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
                       dropout)
        self.norms = nn.ModuleList()
        self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1])
        self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1])

    def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None):
        """Forward function for `TransformerEncoderLayer`.

        Args:
            x (Tensor): The input query with shape [num_key, bs,
                embed_dims]. Same in `MultiheadAttention.forward`.
            pos (Tensor): The positional encoding for query. Default None.
                Same as `query_pos` in `MultiheadAttention.forward`.
            attn_mask (Tensor): ByteTensor mask with shape [num_key,
                num_key]. Same in `MultiheadAttention.forward`. Default None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
                Same in `MultiheadAttention.forward`. Default None.

        Returns:
            Tensor: forwarded results with shape [num_key, bs, embed_dims].
        """
        norm_cnt = 0
        inp_residual = x
        for layer in self.order:
            if layer == 'selfattn':
                query = key = value = x
                x = self.self_attn(
                    query,
                    key,
                    value,
                    inp_residual if self.pre_norm else None,
                    query_pos=pos,
                    key_pos=pos,
                    attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask)
                inp_residual = x
            elif layer == 'norm':
                x = self.norms[norm_cnt](x)
                norm_cnt += 1
            elif layer == 'ffn':
                x = self.ffn(x, inp_residual if self.pre_norm else None)
        return x

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'order={self.order}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg}, '
        repr_str += f'num_fcs={self.num_fcs})'
        return repr_str


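# Usage sketch for TransformerEncoderLayer (illustrative; sizes are assumed).
# With the default post-norm order, one layer maps [num_key, bs, embed_dims]
# to the same shape:
#
#     >>> layer = TransformerEncoderLayer(
#     ...     embed_dims=256, num_heads=8, feedforward_channels=1024)
#     >>> x = torch.rand(600, 2, 256)
#     >>> layer(x, pos=torch.rand(600, 2, 256)).shape
#     torch.Size([600, 2, 256])

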
class TransformerDecoderLayer(nn.Module):
    """Implements one decoder layer in DETR transformer.

    Args:
        embed_dims (int): The feature dimension. Same as
            `TransformerEncoderLayer`.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): Same as `TransformerEncoderLayer`.
        dropout (float): Same as `TransformerEncoderLayer`. Default 0.0.
        order (tuple[str]): The order for decoder layer. Valid examples are
            ('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', 'norm') and
            ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn').
            Default the former.
        act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        num_fcs (int): The number of fully-connected layers in FFNs.
            Default 2.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
                        'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super(TransformerDecoderLayer, self).__init__()
        assert isinstance(order, tuple) and len(order) == 6
        assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn'])
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.pre_norm = order[0] == 'norm'
        self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout)
        self.multihead_attn = MultiheadAttention(embed_dims, num_heads,
                                                 dropout)
        self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
                       dropout)
        self.norms = nn.ModuleList()
        # one norm for each of the three sub-blocks (selfattn, crossattn, ffn)
        for _ in range(3):
            self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1])

    def forward(self,
                x,
                memory,
                memory_pos=None,
                query_pos=None,
                memory_attn_mask=None,
                target_attn_mask=None,
                memory_key_padding_mask=None,
                target_key_padding_mask=None):
        """Forward function for `TransformerDecoderLayer`.

        Args:
            x (Tensor): Input query with shape [num_query, bs, embed_dims].
            memory (Tensor): Tensor got from `TransformerEncoder`, with shape
                [num_key, bs, embed_dims].
            memory_pos (Tensor): The positional encoding for `memory`. Default
                None. Same as `key_pos` in `MultiheadAttention.forward`.
            query_pos (Tensor): The positional encoding for `query`. Default
                None. Same as `query_pos` in `MultiheadAttention.forward`.
            memory_attn_mask (Tensor): ByteTensor mask for `memory`, with
                shape [num_key, num_key]. Same as `attn_mask` in
                `MultiheadAttention.forward`. Default None.
            target_attn_mask (Tensor): ByteTensor mask for `x`, with shape
                [num_query, num_query]. Same as `attn_mask` in
                `MultiheadAttention.forward`. Default None.
            memory_key_padding_mask (Tensor): ByteTensor for `memory`, with
                shape [bs, num_key]. Same as `key_padding_mask` in
                `MultiheadAttention.forward`. Default None.
            target_key_padding_mask (Tensor): ByteTensor for `x`, with shape
                [bs, num_query]. Same as `key_padding_mask` in
                `MultiheadAttention.forward`. Default None.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        norm_cnt = 0
        inp_residual = x
        for layer in self.order:
            if layer == 'selfattn':
                query = key = value = x
                x = self.self_attn(
                    query,
                    key,
                    value,
                    inp_residual if self.pre_norm else None,
                    query_pos,
                    key_pos=query_pos,
                    attn_mask=target_attn_mask,
                    key_padding_mask=target_key_padding_mask)
                inp_residual = x
            elif layer == 'norm':
                x = self.norms[norm_cnt](x)
                norm_cnt += 1
            elif layer == 'multiheadattn':
                query = x
                key = value = memory
                x = self.multihead_attn(
                    query,
                    key,
                    value,
                    inp_residual if self.pre_norm else None,
                    query_pos,
                    key_pos=memory_pos,
                    attn_mask=memory_attn_mask,
                    key_padding_mask=memory_key_padding_mask)
                inp_residual = x
            elif layer == 'ffn':
                x = self.ffn(x, inp_residual if self.pre_norm else None)
        return x

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'order={self.order}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg}, '
        repr_str += f'num_fcs={self.num_fcs})'
        return repr_str


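# Usage sketch for TransformerDecoderLayer (illustrative; sizes are assumed).
# The [num_query, bs, embed_dims] target attends to itself and to the encoder
# memory of shape [num_key, bs, embed_dims]:
#
#     >>> layer = TransformerDecoderLayer(
#     ...     embed_dims=256, num_heads=8, feedforward_channels=1024)
#     >>> target = torch.rand(100, 2, 256)
#     >>> memory = torch.rand(600, 2, 256)
#     >>> layer(target, memory).shape
#     torch.Size([100, 2, 256])

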
class TransformerEncoder(nn.Module):
    """Implements the encoder in DETR transformer.

    Args:
        num_layers (int): The number of `TransformerEncoderLayer`.
        embed_dims (int): Same as `TransformerEncoderLayer`.
        num_heads (int): Same as `TransformerEncoderLayer`.
        feedforward_channels (int): Same as `TransformerEncoderLayer`.
        dropout (float): Same as `TransformerEncoderLayer`. Default 0.0.
        order (tuple[str]): Same as `TransformerEncoderLayer`.
        act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU.
        norm_cfg (dict): Same as `TransformerEncoderLayer`. Default
            layer normalization.
        num_fcs (int): Same as `TransformerEncoderLayer`. Default 2.
    """

    def __init__(self,
                 num_layers,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'ffn', 'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super(TransformerEncoder, self).__init__()
        assert isinstance(order, tuple) and len(order) == 4
        assert set(order) == set(['selfattn', 'norm', 'ffn'])
        self.num_layers = num_layers
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.pre_norm = order[0] == 'norm'
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(embed_dims, num_heads,
                                        feedforward_channels, dropout, order,
                                        act_cfg, norm_cfg, num_fcs))
        self.norm = build_norm_layer(norm_cfg,
                                     embed_dims)[1] if self.pre_norm else None

    def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None):
        """Forward function for `TransformerEncoder`.

        Args:
            x (Tensor): Input query. Same in `TransformerEncoderLayer.forward`.
            pos (Tensor): Positional encoding for query. Default None.
                Same in `TransformerEncoderLayer.forward`.
            attn_mask (Tensor): ByteTensor attention mask. Default None.
                Same in `TransformerEncoderLayer.forward`.
            key_padding_mask (Tensor): Same in
                `TransformerEncoderLayer.forward`. Default None.

        Returns:
            Tensor: Results with shape [num_key, bs, embed_dims].
        """
        for layer in self.layers:
            x = layer(x, pos, attn_mask, key_padding_mask)
        if self.norm is not None:
            x = self.norm(x)
        return x

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(num_layers={self.num_layers}, '
        repr_str += f'embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'order={self.order}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg}, '
        repr_str += f'num_fcs={self.num_fcs})'
        return repr_str


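# Usage sketch for TransformerEncoder (illustrative; sizes are assumed).
# Stacking six post-norm encoder layers keeps the input shape:
#
#     >>> encoder = TransformerEncoder(num_layers=6, embed_dims=256,
#     ...                              num_heads=8, feedforward_channels=1024)
#     >>> encoder(torch.rand(600, 2, 256)).shape
#     torch.Size([600, 2, 256])

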
class TransformerDecoder(nn.Module):
    """Implements the decoder in DETR transformer.

    Args:
        num_layers (int): The number of `TransformerDecoderLayer`.
        embed_dims (int): Same as `TransformerDecoderLayer`.
        num_heads (int): Same as `TransformerDecoderLayer`.
        feedforward_channels (int): Same as `TransformerDecoderLayer`.
        dropout (float): Same as `TransformerDecoderLayer`. Default 0.0.
        order (tuple[str]): Same as `TransformerDecoderLayer`.
        act_cfg (dict): Same as `TransformerDecoderLayer`. Default ReLU.
        norm_cfg (dict): Same as `TransformerDecoderLayer`. Default
            layer normalization.
        num_fcs (int): Same as `TransformerDecoderLayer`. Default 2.
        return_intermediate (bool): Whether to return the outputs of all
            decoder layers instead of only the last one. Default False.
    """

    def __init__(self,
                 num_layers,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
                        'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2,
                 return_intermediate=False):
        super(TransformerDecoder, self).__init__()
        assert isinstance(order, tuple) and len(order) == 6
        assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn'])
        self.num_layers = num_layers
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.return_intermediate = return_intermediate
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerDecoderLayer(embed_dims, num_heads,
                                        feedforward_channels, dropout, order,
                                        act_cfg, norm_cfg, num_fcs))
        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self,
                x,
                memory,
                memory_pos=None,
                query_pos=None,
                memory_attn_mask=None,
                target_attn_mask=None,
                memory_key_padding_mask=None,
                target_key_padding_mask=None):
        """Forward function for `TransformerDecoder`.

        Args:
            x (Tensor): Input query. Same in `TransformerDecoderLayer.forward`.
            memory (Tensor): Same in `TransformerDecoderLayer.forward`.
            memory_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
                Default None.
            query_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
                Default None.
            memory_attn_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            target_attn_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            memory_key_padding_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            target_key_padding_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.

        Returns:
            Tensor: Results with shape [num_layers, num_query, bs, embed_dims]
                if `return_intermediate` is True, otherwise with shape
                [1, num_query, bs, embed_dims].
        """
        intermediate = []
        for layer in self.layers:
            x = layer(x, memory, memory_pos, query_pos, memory_attn_mask,
                      target_attn_mask, memory_key_padding_mask,
                      target_key_padding_mask)
            if self.return_intermediate:
                intermediate.append(self.norm(x))
        if self.norm is not None:
            x = self.norm(x)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(x)
        if self.return_intermediate:
            return torch.stack(intermediate)
        return x.unsqueeze(0)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(num_layers={self.num_layers}, '
        repr_str += f'embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'order={self.order}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'return_intermediate={self.return_intermediate})'
        return repr_str


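# Usage sketch for TransformerDecoder (illustrative; sizes are assumed).
# Without return_intermediate the output is stacked to a single level of
# shape [1, num_query, bs, embed_dims]:
#
#     >>> decoder = TransformerDecoder(num_layers=6, embed_dims=256,
#     ...                              num_heads=8, feedforward_channels=1024)
#     >>> target = torch.rand(100, 2, 256)
#     >>> memory = torch.rand(600, 2, 256)
#     >>> decoder(target, memory).shape
#     torch.Size([1, 100, 2, 256])

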
@TRANSFORMER.register_module()
class Transformer(nn.Module):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module is copy-pasted
    from torch.nn.Transformer with modifications:

        * positional encodings are passed in MultiheadAttention
        * extra LN at the end of encoder is removed
        * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        num_encoder_layers (int): Number of `TransformerEncoderLayer`.
        num_decoder_layers (int): Number of `TransformerDecoderLayer`.
        feedforward_channels (int): The hidden dimension for FFNs used in both
            encoder and decoder.
        dropout (float): Probability of an element to be zeroed. Default 0.0.
        act_cfg (dict): Activation config for FFNs used in both encoder
            and decoder. Default ReLU.
        norm_cfg (dict): Config dict for normalization used in both encoder
            and decoder. Default layer normalization.
        num_fcs (int): The number of fully-connected layers in FFNs, which is
            used for both encoder and decoder.
        pre_norm (bool): Whether the normalization layer is ordered
            first in the encoder and decoder. Default False.
        return_intermediate_dec (bool): Whether to return the intermediate
            output from each TransformerDecoderLayer or only the last
            TransformerDecoderLayer. Default False. If True, the returned
            `hs` has shape [num_decoder_layers, bs, num_query, embed_dims].
            If False, the returned `hs` will have shape [1, bs, num_query,
            embed_dims].
    """

    def __init__(self,
                 embed_dims=512,
                 num_heads=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 feedforward_channels=2048,
                 dropout=0.0,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2,
                 pre_norm=False,
                 return_intermediate_dec=False):
        super(Transformer, self).__init__()
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.pre_norm = pre_norm
        self.return_intermediate_dec = return_intermediate_dec
        if self.pre_norm:
            encoder_order = ('norm', 'selfattn', 'norm', 'ffn')
            decoder_order = ('norm', 'selfattn', 'norm', 'multiheadattn',
                             'norm', 'ffn')
        else:
            encoder_order = ('selfattn', 'norm', 'ffn', 'norm')
            decoder_order = ('selfattn', 'norm', 'multiheadattn', 'norm',
                             'ffn', 'norm')
        self.encoder = TransformerEncoder(num_encoder_layers, embed_dims,
                                          num_heads, feedforward_channels,
                                          dropout, encoder_order, act_cfg,
                                          norm_cfg, num_fcs)
        self.decoder = TransformerDecoder(num_decoder_layers, embed_dims,
                                          num_heads, feedforward_channels,
                                          dropout, decoder_order, act_cfg,
                                          norm_cfg, num_fcs,
                                          return_intermediate_dec)

    def init_weights(self, distribution='uniform'):
        """Initialize the transformer weights."""
        # xavier init for every multi-dimensional weight, following DETR
        for m in self.modules():
            if hasattr(m, 'weight') and m.weight.dim() > 1:
                xavier_init(m, distribution=distribution)

    def forward(self, x, mask, query_embed, pos_embed):
        """Forward function for `Transformer`.

        Args:
            x (Tensor): Input query with shape [bs, c, h, w] where
                c = embed_dims.
            mask (Tensor): The key_padding_mask used for encoder and decoder,
                with shape [bs, h, w].
            query_embed (Tensor): The query embedding for decoder, with shape
                [num_query, c].
            pos_embed (Tensor): The positional encoding for encoder and
                decoder, with the same shape as `x`.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - out_dec: Output from decoder. If return_intermediate_dec \
                      is True, output has shape [num_dec_layers, bs,
                      num_query, embed_dims], else it has shape [1, bs, \
                      num_query, embed_dims].
                - memory: Output results from encoder, with shape \
                      [bs, embed_dims, h, w].
        """
        bs, c, h, w = x.shape
        x = x.flatten(2).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(
            1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]
        mask = mask.flatten(1)  # [bs, h, w] -> [bs, h*w]
        memory = self.encoder(
            x, pos=pos_embed, attn_mask=None, key_padding_mask=mask)
        target = torch.zeros_like(query_embed)
        # out_dec: [num_layers, num_query, bs, dim]
        out_dec = self.decoder(
            target,
            memory,
            memory_pos=pos_embed,
            query_pos=query_embed,
            memory_attn_mask=None,
            target_attn_mask=None,
            memory_key_padding_mask=mask,
            target_key_padding_mask=None)
        out_dec = out_dec.transpose(1, 2)
        memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
        return out_dec, memory

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'num_encoder_layers={self.num_encoder_layers}, '
        repr_str += f'num_decoder_layers={self.num_decoder_layers}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'pre_norm={self.pre_norm}, '
        repr_str += f'return_intermediate_dec={self.return_intermediate_dec})'
        return repr_str


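# Usage sketch for Transformer (illustrative; the feature-map size 20x30, the
# 100 queries and the reduced channel/FFN widths below are assumed, not the
# constructor defaults). A zero mask means no padded positions:
#
#     >>> transformer = Transformer(embed_dims=256, num_heads=8,
#     ...                           feedforward_channels=1024)
#     >>> x = torch.rand(2, 256, 20, 30)
#     >>> mask = torch.zeros(2, 20, 30).bool()
#     >>> query_embed = torch.rand(100, 256)
#     >>> pos_embed = torch.rand(2, 256, 20, 30)
#     >>> out_dec, memory = transformer(x, mask, query_embed, pos_embed)
#     >>> out_dec.shape, memory.shape
#     (torch.Size([1, 2, 100, 256]), torch.Size([2, 256, 20, 30]))

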
@TRANSFORMER.register_module()
class DynamicConv(nn.Module):
    """Implements Dynamic Convolution.

    This module generates parameters for each sample and
    uses bmm to implement 1x1 convolution. Code is modified
    from the `official github repo <https://github.com/PeizeSun/
    SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ .

    Args:
        in_channels (int): The input feature channel.
            Defaults to 256.
        feat_channels (int): The inner feature channel.
            Defaults to 64.
        out_channels (int, optional): The output feature channel.
            When not specified, it will be set to `in_channels`
            by default.
        input_feat_shape (int): The shape of input feature.
            Defaults to 7.
        act_cfg (dict): The activation config for DynamicConv.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
    """

    def __init__(self,
                 in_channels=256,
                 feat_channels=64,
                 out_channels=None,
                 input_feat_shape=7,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN')):
        super(DynamicConv, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.input_feat_shape = input_feat_shape
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.in_channels * self.feat_channels
        self.num_params_out = self.out_channels * self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        num_output = self.out_channels * input_feat_shape**2
        self.fc_layer = nn.Linear(num_output, self.out_channels)
        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, param_feature, input_feature):
        """Forward function for `DynamicConv`.

        Args:
            param_feature (Tensor): The feature used to generate the
                parameters, with shape (num_all_proposals, in_channels).
            input_feature (Tensor): The feature that interacts with the
                parameters, with shape (num_all_proposals, in_channels, H, W).

        Returns:
            Tensor: The output feature with shape
            (num_all_proposals, out_channels).
        """
        num_proposals = param_feature.size(0)
        input_feature = input_feature.view(num_proposals, self.in_channels,
                                           -1).permute(2, 0, 1)

        input_feature = input_feature.permute(1, 0, 2)
        parameters = self.dynamic_layer(param_feature)

        param_in = parameters[:, :self.num_params_in].view(
            -1, self.in_channels, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels, self.out_channels)

        # input_feature has shape (num_all_proposals, H*W, in_channels)
        # param_in has shape (num_all_proposals, in_channels, feat_channels)
        # features has shape (num_all_proposals, H*W, feat_channels)
        features = torch.bmm(input_feature, param_in)
        features = self.norm_in(features)
        features = self.activation(features)

        # param_out has shape (num_all_proposals, feat_channels, out_channels)
        features = torch.bmm(features, param_out)
        features = self.norm_out(features)
        features = self.activation(features)

        features = features.flatten(1)
        features = self.fc_layer(features)
        features = self.fc_norm(features)
        features = self.activation(features)

        return features

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(in_channels={self.in_channels}, '
        repr_str += f'feat_channels={self.feat_channels}, '
        repr_str += f'out_channels={self.out_channels_raw}, '
        repr_str += f'input_feat_shape={self.input_feat_shape}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg})'
        return repr_str


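# Usage sketch for DynamicConv (illustrative; the number of proposals below
# is assumed). With the default config, each 7x7 RoI feature is mixed with
# per-proposal generated parameters and reduced to an out_channels vector:
#
#     >>> dynamic_conv = DynamicConv()
#     >>> param_feature = torch.rand(14, 256)
#     >>> input_feature = torch.rand(14, 256, 7, 7)
#     >>> dynamic_conv(param_feature, input_feature).shape
#     torch.Size([14, 256])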