Robert001's picture
first commit
b334e29
raw
history blame
No virus
6.15 kB
import torch
from annotator.uniformer.mmcv.cnn import ConvModule, constant_init
from torch import nn as nn
from torch.nn import functional as F
class SelfAttentionBlock(nn.Module):
    """General self-attention block/non-local block.

    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
    query and value.

    Args:
        key_in_channels (int): Input channels of key feature. The value
            projection reads the same feature, so it uses this size too.
        query_in_channels (int): Input channels of query feature.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether share projection weight between key
            and query projection. Requires
            ``key_in_channels == query_in_channels``.
        query_downsample (nn.Module|None): Module applied to the projected
            query, or None to skip downsampling.
        key_downsample (nn.Module|None): Module applied to both the
            projected key and the projected value, or None to skip.
        key_query_num_convs (int): Number of convs for key/query projection.
        value_out_num_convs (int): Number of convs for value projection and
            for the out projection.
        key_query_norm (bool): Whether the key/query projections are built
            from ``ConvModule`` (which receives ``conv_cfg``, ``norm_cfg``
            and ``act_cfg``) instead of plain ``nn.Conv2d``.
        value_out_norm (bool): Same switch as ``key_query_norm`` but for the
            value and out projections.
        matmul_norm (bool): Whether normalize attention map with sqrt of
            channels.
        with_out (bool): Whether use out projection.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        if share_key_query:
            # A single projection can only serve both inputs when they have
            # the same number of channels.
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        layer_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            **layer_cfg)
        if share_key_query:
            # Same module object, hence shared weights.
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                **layer_cfg)
        # Without an out projection the value projection has to emit the
        # final number of channels itself.
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            **layer_cfg)
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                **layer_cfg)
        else:
            self.out_project = None

        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

        self.init_weights()

    def init_weights(self):
        """Initialize weight of later layer.

        When an out projection exists and is not a ``ConvModule``, it is
        handed to ``constant_init`` with value 0.
        """
        if self.out_project is not None:
            if not isinstance(self.out_project, ConvModule):
                constant_init(self.out_project, 0)

    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out.

        Args:
            in_channels (int): Input channels of the first 1x1 conv.
            channels (int): Output channels of every conv in the projection.
            num_convs (int): Number of 1x1 convs to stack. At least one conv
                is always created.
            use_conv_module (bool): Build ``ConvModule`` layers (configured
                by ``conv_cfg``/``norm_cfg``/``act_cfg``) when True, plain
                ``nn.Conv2d`` layers otherwise.
            conv_cfg (dict|None): Config of conv layers.
            norm_cfg (dict|None): Config of norm layers.
            act_cfg (dict|None): Config of activation layers.

        Returns:
            nn.Module: The single layer when only one was built, otherwise
            an ``nn.Sequential`` of all layers.
        """

        def make_layer(layer_in_channels):
            if use_conv_module:
                return ConvModule(
                    layer_in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            return nn.Conv2d(layer_in_channels, channels, 1)

        # Only the first layer changes the channel count; the rest map
        # channels -> channels.
        layers = [make_layer(in_channels)]
        for _ in range(num_convs - 1):
            layers.append(make_layer(channels))

        if len(layers) > 1:
            return nn.Sequential(*layers)
        return layers[0]

    def forward(self, query_feats, key_feats):
        """Forward function.

        Args:
            query_feats (Tensor): Query feature, batch first and channels
                second, followed by the spatial dimensions.
            key_feats (Tensor): Key feature in the same layout; it also
                feeds the value projection.

        Returns:
            Tensor: Attention context with ``out_channels`` channels. Its
            spatial dimensions are those of the query after
            ``query_downsample`` (equal to those of ``query_feats`` when no
            query downsample module is set).
        """
        batch_size = query_feats.size(0)

        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        # The context holds one vector per (possibly downsampled) query
        # position, so the output must be laid out on this grid rather than
        # on the grid of the raw ``query_feats``.
        out_spatial_shape = query.shape[2:]
        # (B, C, *spatial) -> (B, Nq, C)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()

        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            # Key and value must stay aligned position by position, so both
            # go through the same downsample module.
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        # key stays (B, C, Nk); value becomes (B, Nk, C_v).
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()

        # (B, Nq, C) x (B, C, Nk) -> (B, Nq, Nk)
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # (B, Nq, Nk) x (B, Nk, C_v) -> (B, Nq, C_v) -> (B, C_v, *spatial)
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *out_spatial_shape)
        if self.out_project is not None:
            context = self.out_project(context)
        return context