csmithxc's picture
Upload 146 files
1530901 verified
raw
history blame
No virus
13.7 kB
# Copyright (c) Tencent Inc. All rights reserved.
from typing import List
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule, Linear
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmengine.model import BaseModule
from mmyolo.registry import MODELS
from mmyolo.models.layers import CSPLayerWithTwoConv
@MODELS.register_module()
class MaxSigmoidAttnBlock(BaseModule):
    """Max Sigmoid attention block.

    Projects ``x`` with a conv and gates each head of the projection with
    ``sigmoid(max_n <embed_pixel, guide_n> / sqrt(d) + bias) * scale``, where
    the max runs over all guide tokens ``n`` and ``d`` is the per-head width
    of the shared embedding space.

    Args:
        in_channels (int): Channels of the input feature map ``x``.
        out_channels (int): Channels of the output feature map.
        guide_channels (int): Width of the last dimension of ``guide``.
        embed_channels (int): Width of the embedding space in which image
            pixels and guide tokens are compared.
        kernel_size (int): Kernel size of the projection conv.
        padding (int): Padding of the projection conv.
        num_heads (int): Number of heads. Must divide both ``out_channels``
            and ``embed_channels``.
        use_depthwise (bool): Use ``DepthwiseSeparableConvModule`` for the
            projection conv instead of ``ConvModule``.
        with_scale (bool): If True, learn a per-head multiplier (initialised
            to 1) on the attention weights; otherwise use a constant 1.0.
        conv_cfg (dict, optional): Conv config for the conv modules.
        norm_cfg (dict): Norm config for the conv modules.
        init_cfg (dict or list[dict], optional): Initialization config.
        use_einsum (bool): Compute the similarity with ``torch.einsum``;
            otherwise use an equivalent ``permute``/``matmul`` formulation.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 guide_channels: int,
                 embed_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_heads: int = 1,
                 use_depthwise: bool = False,
                 with_scale: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN',
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None,
                 use_einsum: bool = True) -> None:
        super().__init__(init_cfg=init_cfg)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

        assert (out_channels % num_heads == 0 and
                embed_channels % num_heads == 0), \
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        # Per-head width of the embedding space. It splits the outputs of
        # ``embed_conv`` / ``guide_fc`` (both ``embed_channels`` wide) into
        # heads and scales the dot product, so it must be derived from
        # ``embed_channels``. Deriving it from ``out_channels`` breaks every
        # config where ``embed_channels != out_channels``. The output side
        # reshapes with ``-1`` and needs no per-head width.
        self.head_channels = embed_channels // num_heads
        self.use_einsum = use_einsum

        # 1x1 conv into the embedding space; skipped when ``x`` already has
        # ``embed_channels`` channels.
        self.embed_conv = ConvModule(
            in_channels,
            embed_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None) if embed_channels != in_channels else None
        self.guide_fc = Linear(guide_channels, embed_channels)
        # One learnable bias per head, added before the sigmoid.
        self.bias = nn.Parameter(torch.zeros(num_heads))
        if with_scale:
            self.scale = nn.Parameter(torch.ones(1, num_heads, 1, 1))
        else:
            self.scale = 1.0

        self.project_conv = conv(in_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=1,
                                 padding=padding,
                                 conv_cfg=conv_cfg,
                                 norm_cfg=norm_cfg,
                                 act_cfg=None)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Gate the projected feature map with guide-conditioned attention.

        Args:
            x (Tensor): Feature map of shape (B, in_channels, H, W).
            guide (Tensor): Guide embeddings whose last dimension is
                ``guide_channels``. Presumably (B, N, guide_channels); TODO
                confirm against callers.

        Returns:
            Tensor: Feature map of shape (B, out_channels, H, W).
        """
        B, _, H, W = x.shape
        # guide_fc maps the last dim to embed_channels; split it into heads:
        # (B, N, num_heads, head_channels).
        guide = self.guide_fc(guide)
        guide = guide.reshape(B, -1, self.num_heads, self.head_channels)
        # (B, embed_channels, H, W) -> (B, num_heads, head_channels, H, W)
        embed = self.embed_conv(x) if self.embed_conv is not None else x
        embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)

        if self.use_einsum:
            attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
        else:
            # Equivalent permute/matmul formulation of the einsum above.
            batch, m, channel, height, width = embed.shape
            _, n, _, _ = guide.shape
            embed = embed.permute(0, 1, 3, 4, 2)
            embed = embed.reshape(batch, m, -1, channel)  # (B, m, HW, c)
            guide = guide.permute(0, 2, 3, 1)  # (B, m, c, n)
            attn_weight = torch.matmul(embed, guide)  # (B, m, HW, n)
            attn_weight = attn_weight.reshape(batch, m, height, width, n)

        # For each head and pixel keep only the best-matching guide token.
        attn_weight = attn_weight.max(dim=-1)[0]  # (B, num_heads, H, W)
        attn_weight = attn_weight / (self.head_channels**0.5)
        attn_weight = attn_weight + self.bias[None, :, None, None]
        attn_weight = attn_weight.sigmoid() * self.scale

        x = self.project_conv(x)
        x = x.reshape(B, self.num_heads, -1, H, W)
        # Broadcast one gate value per (head, pixel) over that head's channels.
        x = x * attn_weight.unsqueeze(2)
        x = x.reshape(B, -1, H, W)
        return x
@MODELS.register_module()
class MaxSigmoidCSPLayerWithTwoConv(CSPLayerWithTwoConv):
    """CSP layer with two convs followed by a max-sigmoid attention branch.

    Extends ``CSPLayerWithTwoConv`` with one extra branch.
    ``MaxSigmoidAttnBlock`` is applied to the last block output under the
    guidance of ``guide``, and its result is concatenated with the other
    branches before ``final_conv``.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        guide_channels (int): Width of the guide embeddings' last dimension.
        embed_channels (int): Embedding width used inside the attention block.
        num_heads (int): Number of attention heads.
        expand_ratio (float): Ratio used by the parent to derive
            ``self.mid_channels``.
        num_blocks (int): Number of blocks in ``self.blocks``.
        with_scale (bool): Learn a per-head scale in the attention block.
        add_identity (bool): Forwarded to the parent (shortcut flag).
        conv_cfg, norm_cfg, act_cfg: Conv / norm / activation configs.
        init_cfg: Initialization config.
        use_einsum (bool): Forwarded to ``MaxSigmoidAttnBlock``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            guide_channels: int,
            embed_channels: int,
            num_heads: int = 1,
            expand_ratio: float = 0.5,
            num_blocks: int = 1,
            with_scale: bool = False,
            add_identity: bool = True,  # shortcut
            conv_cfg: OptConfigType = None,
            norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: ConfigType = dict(type='SiLU', inplace=True),
            init_cfg: OptMultiConfig = None,
            use_einsum: bool = True) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            expand_ratio=expand_ratio,
            num_blocks=num_blocks,
            add_identity=add_identity,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)
        width = self.mid_channels
        # Concatenated branches: the two halves of the main_conv split, one
        # output per block, plus the attention output.
        branch_count = 2 + num_blocks + 1
        # Replaces the parent's final conv, which has no room for the
        # attention branch. Built before attn_block to keep the original
        # parameter-initialisation order.
        self.final_conv = ConvModule(
            branch_count * width,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.attn_block = MaxSigmoidAttnBlock(
            width,
            width,
            guide_channels=guide_channels,
            embed_channels=embed_channels,
            num_heads=num_heads,
            with_scale=with_scale,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            use_einsum=use_einsum)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Run the CSP branches plus the guided attention branch.

        Args:
            x (Tensor): Input feature map.
            guide (Tensor): Guide embeddings consumed by ``attn_block``.

        Returns:
            Tensor: Output of ``final_conv`` over all concatenated branches.
        """
        halves = self.main_conv(x).split(
            (self.mid_channels, self.mid_channels), 1)
        branches = list(halves)
        # Each block consumes the most recently produced branch.
        for block in self.blocks:
            branches.append(block(branches[-1]))
        branches.append(self.attn_block(branches[-1], guide))
        return self.final_conv(torch.cat(branches, 1))
@MODELS.register_module()
class ImagePoolingAttentionModule(nn.Module):
    """Update text embeddings by attending over pooled multi-level image features.

    Each image level is projected to ``embed_channels`` by a 1x1
    ``ConvModule`` and adaptively max-pooled to a ``pool_size`` x
    ``pool_size`` grid. The patches from all levels act as keys/values of a
    multi-head attention whose queries come from the text features. The
    attended result is projected back to ``text_channels`` and added to the
    text features as a residual.

    Args:
        image_channels (List[int]): Channels of each image feature level.
        text_channels (int): Width of the last dimension of the text
            features.
        embed_channels (int): Attention width. Must be divisible by
            ``num_heads``.
        with_scale (bool): If True, the residual update is multiplied by a
            learnable scalar initialised to 0, so the module starts out
            returning ``text_features`` unchanged. Otherwise the scale is a
            constant 1.0.
        num_feats (int): Number of image feature levels expected in
            ``forward``.
        num_heads (int): Number of attention heads.
        pool_size (int): Side length of the pooled grid per level.
        use_einsum (bool): Compute attention with ``torch.einsum``;
            otherwise use an equivalent ``permute``/``matmul`` formulation.
    """

    def __init__(self,
                 image_channels: List[int],
                 text_channels: int,
                 embed_channels: int,
                 with_scale: bool = False,
                 num_feats: int = 3,
                 num_heads: int = 8,
                 pool_size: int = 3,
                 use_einsum: bool = True):
        super().__init__()
        # Without this check a non-divisible width can still pass the
        # ``reshape(B, -1, num_heads, head_channels)`` calls in ``forward``
        # for some token counts. Channels then leak across tokens silently.
        assert embed_channels % num_heads == 0, \
            'embed_channels should be divisible by num_heads.'
        self.text_channels = text_channels
        self.embed_channels = embed_channels
        self.num_heads = num_heads
        self.num_feats = num_feats
        self.head_channels = embed_channels // num_heads
        self.pool_size = pool_size
        self.use_einsum = use_einsum
        if with_scale:
            self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
        else:
            self.scale = 1.0
        # NOTE(review): forward zips these projections with ``image_features``
        # and ``image_pools``. Configs are presumably expected to satisfy
        # len(image_channels) == num_feats; confirm.
        self.projections = nn.ModuleList([
            ConvModule(in_channels, embed_channels, 1, act_cfg=None)
            for in_channels in image_channels
        ])
        self.query = nn.Sequential(nn.LayerNorm(text_channels),
                                   Linear(text_channels, embed_channels))
        self.key = nn.Sequential(nn.LayerNorm(embed_channels),
                                 Linear(embed_channels, embed_channels))
        self.value = nn.Sequential(nn.LayerNorm(embed_channels),
                                   Linear(embed_channels, embed_channels))
        self.proj = Linear(embed_channels, text_channels)

        self.image_pools = nn.ModuleList([
            nn.AdaptiveMaxPool2d((pool_size, pool_size))
            for _ in range(num_feats)
        ])

    def forward(self, text_features: Tensor,
                image_features: List[Tensor]) -> Tensor:
        """Attend from text tokens to pooled image patches.

        Args:
            text_features (Tensor): Text embeddings whose last dimension is
                ``text_channels``. Presumably (B, N, text_channels); TODO
                confirm.
            image_features (List[Tensor]): ``num_feats`` feature maps, each
                of shape (B, C_i, H_i, W_i) with C_i from ``image_channels``.

        Returns:
            Tensor: ``text_features`` plus the scaled attention update.
        """
        # Validate before indexing so an empty list gives a clear error.
        assert len(image_features) == self.num_feats
        B = image_features[0].shape[0]
        num_patches = self.pool_size**2
        # Each level -> (B, embed_channels, pool_size * pool_size)
        mlvl_image_features = [
            pool(proj(x)).view(B, -1, num_patches)
            for (x, proj, pool
                 ) in zip(image_features, self.projections, self.image_pools)
        ]
        # (B, num_feats * num_patches, embed_channels)
        mlvl_image_features = torch.cat(mlvl_image_features,
                                        dim=-1).transpose(1, 2)
        q = self.query(text_features)
        k = self.key(mlvl_image_features)
        v = self.value(mlvl_image_features)

        # Split into heads: (B, tokens, num_heads, head_channels)
        q = q.reshape(B, -1, self.num_heads, self.head_channels)
        k = k.reshape(B, -1, self.num_heads, self.head_channels)
        v = v.reshape(B, -1, self.num_heads, self.head_channels)
        if self.use_einsum:
            attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
        else:
            q = q.permute(0, 2, 1, 3)  # (B, m, n, c)
            k = k.permute(0, 2, 3, 1)  # (B, m, c, k)
            attn_weight = torch.matmul(q, k)  # (B, m, n, k)

        attn_weight = attn_weight / (self.head_channels**0.5)
        attn_weight = F.softmax(attn_weight, dim=-1)
        if self.use_einsum:
            x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
        else:
            v = v.permute(0, 2, 1, 3)  # (B, m, k, c)
            x = torch.matmul(attn_weight, v)  # (B, m, n, c)
            x = x.permute(0, 2, 1, 3)  # (B, n, m, c)
        # Merge heads and project back to text_channels.
        x = self.proj(x.reshape(B, -1, self.embed_channels))
        return x * self.scale + text_features
@MODELS.register_module()
class VanillaSigmoidBlock(BaseModule):
    """Conv projection followed by a self-sigmoid gate (guide is unused).

    Shares its constructor signature with ``MaxSigmoidAttnBlock`` so the two
    can be swapped. Only ``project_conv`` is built, though, and ``forward``
    ignores ``guide``: the output is ``y * sigmoid(y)`` with
    ``y = project_conv(x)``. ``guide_channels`` and ``with_scale`` are
    accepted but unused. ``embed_channels`` is used only in the
    divisibility check.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        guide_channels (int): Unused; kept for interface compatibility.
        embed_channels (int): Only checked for divisibility by ``num_heads``.
        kernel_size (int): Kernel size of the projection conv.
        padding (int): Padding of the projection conv.
        num_heads (int): Stored; must divide ``out_channels`` and
            ``embed_channels``.
        use_depthwise (bool): Use ``DepthwiseSeparableConvModule`` instead
            of ``ConvModule``.
        with_scale (bool): Unused; kept for interface compatibility.
        conv_cfg, norm_cfg: Conv / norm configs for the projection conv.
        init_cfg: Initialization config.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 guide_channels: int,
                 embed_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_heads: int = 1,
                 use_depthwise: bool = False,
                 with_scale: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN',
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        heads_divide_evenly = (out_channels % num_heads == 0
                               and embed_channels % num_heads == 0)
        assert heads_divide_evenly, \
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        self.head_channels = out_channels // num_heads

        if use_depthwise:
            conv_type = DepthwiseSeparableConvModule
        else:
            conv_type = ConvModule
        self.project_conv = conv_type(in_channels,
                                      out_channels,
                                      kernel_size,
                                      stride=1,
                                      padding=padding,
                                      conv_cfg=conv_cfg,
                                      norm_cfg=norm_cfg,
                                      act_cfg=None)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Project ``x`` and gate it by its own sigmoid; ``guide`` is ignored.

        Args:
            x (Tensor): Input feature map.
            guide (Tensor): Unused.

        Returns:
            Tensor: ``project_conv(x) * sigmoid(project_conv(x))``.
        """
        projected = self.project_conv(x)
        gate = projected.sigmoid()
        return projected * gate
@MODELS.register_module()
class EfficientCSPLayerWithTwoConv(CSPLayerWithTwoConv):
    """CSP layer with two convs followed by a ``VanillaSigmoidBlock`` branch.

    Same layout as ``MaxSigmoidCSPLayerWithTwoConv``, but the extra branch
    is a ``VanillaSigmoidBlock``, whose forward does not use ``guide``. Its
    output is concatenated with the other branches before ``final_conv``.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        guide_channels (int): Forwarded to ``VanillaSigmoidBlock``.
        embed_channels (int): Forwarded to ``VanillaSigmoidBlock``.
        num_heads (int): Forwarded to ``VanillaSigmoidBlock``.
        expand_ratio (float): Ratio used by the parent to derive
            ``self.mid_channels``.
        num_blocks (int): Number of blocks in ``self.blocks``.
        with_scale (bool): Forwarded to ``VanillaSigmoidBlock``.
        add_identity (bool): Forwarded to the parent (shortcut flag).
        conv_cfg, norm_cfg, act_cfg: Conv / norm / activation configs.
        init_cfg: Initialization config.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            guide_channels: int,
            embed_channels: int,
            num_heads: int = 1,
            expand_ratio: float = 0.5,
            num_blocks: int = 1,
            with_scale: bool = False,
            add_identity: bool = True,  # shortcut
            conv_cfg: OptConfigType = None,
            norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: ConfigType = dict(type='SiLU', inplace=True),
            init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            expand_ratio=expand_ratio,
            num_blocks=num_blocks,
            add_identity=add_identity,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)
        width = self.mid_channels
        # Concatenated branches: the two halves of the main_conv split, one
        # output per block, plus the sigmoid-block output.
        branch_count = 2 + num_blocks + 1
        # Replaces the parent's final conv. Built before attn_block to keep
        # the original parameter-initialisation order.
        self.final_conv = ConvModule(
            branch_count * width,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.attn_block = VanillaSigmoidBlock(
            width,
            width,
            guide_channels=guide_channels,
            embed_channels=embed_channels,
            num_heads=num_heads,
            with_scale=with_scale,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Run the CSP branches plus the trailing sigmoid branch.

        Args:
            x (Tensor): Input feature map.
            guide (Tensor): Passed to ``attn_block``.

        Returns:
            Tensor: Output of ``final_conv`` over all concatenated branches.
        """
        halves = self.main_conv(x).split(
            (self.mid_channels, self.mid_channels), 1)
        branches = list(halves)
        # Each block consumes the most recently produced branch.
        for block in self.blocks:
            branches.append(block(branches[-1]))
        branches.append(self.attn_block(branches[-1], guide))
        return self.final_conv(torch.cat(branches, 1))