Spaces:
Sleeping
Sleeping
File size: 6,147 Bytes
b4d6f1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
""" Bilinear-Attention-Transform and Non-Local Attention
Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
- https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html
Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
"""
import torch
from torch import nn
from torch.nn import functional as F
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
class NonLocalAttn(nn.Module):
    """Spatial non-local (NL) attention block for image classification.

    Adapted from https://github.com/BA-Transform/BAT-Image-Classification, whose
    NonLocal implementation was in turn inspired by
    https://github.com/facebookresearch/video-nonlocal-net.

    Three 1x1 convolutions (``t``, ``p``, ``g``) project the input to a reduced
    channel width, a softmax attention map over all spatial positions is built
    from ``t`` and ``p`` and applied to ``g``, and a final 1x1 convolution
    (``z``) followed by batch norm restores the input width before the residual
    addition.

    Args:
        in_channels: channel count of the input (and of the output).
        use_scale: if True, attention logits are multiplied by
            ``in_channels ** -0.5``; otherwise they are left unscaled.
        rd_ratio: reduction ratio used to derive the reduced width when
            ``rd_channels`` is not given.
        rd_channels: explicit reduced width; when None it is computed with
            ``make_divisible(in_channels * rd_ratio, divisor=rd_divisor)``.
        rd_divisor: divisor forwarded to ``make_divisible``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs):
        super().__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        self.scale = in_channels ** -0.5 if use_scale else 1.0

        # Registration order (t, p, g, z, norm) is kept so that parameter
        # ordering and random-initialisation order stay the same.
        proj_args = dict(kernel_size=1, stride=1, bias=True)
        self.t = nn.Conv2d(in_channels, rd_channels, **proj_args)
        self.p = nn.Conv2d(in_channels, rd_channels, **proj_args)
        self.g = nn.Conv2d(in_channels, rd_channels, **proj_args)
        self.z = nn.Conv2d(rd_channels, in_channels, **proj_args)
        self.norm = nn.BatchNorm2d(in_channels)
        self.reset_parameters()

    def forward(self, x):
        """Apply non-local attention to ``x`` and add the result to ``x``."""
        shortcut = x

        theta = self.t(x)
        phi = self.p(x)
        value = self.g(x)

        B, C, H, W = theta.size()
        theta = theta.view(B, C, -1).transpose(1, 2)  # (B, H*W, C)
        phi = phi.view(B, C, -1)                      # (B, C, H*W)
        value = value.view(B, C, -1).transpose(1, 2)  # (B, H*W, C)

        # Pairwise position affinities, normalised over the last axis.
        att = torch.bmm(theta, phi) * self.scale
        att = F.softmax(att, dim=2)

        out = torch.bmm(att, value)
        out = out.transpose(1, 2).reshape(B, C, H, W)
        out = self.z(out)
        return self.norm(out) + shortcut

    def reset_parameters(self):
        """Re-initialise all conv and norm sub-modules.

        Conv weights get Kaiming-normal init (fan_out, relu) with zero biases.
        Norm weights and biases are set to zero, so the normalised branch
        outputs zeros right after initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)
class BilinearAttnTransform(nn.Module):
    """Grouped bilinear attentional transform: ``y = conv2(P @ x @ Q)``.

    Two attention matrices are predicted per group from pooled features: ``P``
    multiplies the input from the left (over height) and ``Q`` from the right
    (over width). Each is predicted at ``block_size x block_size`` resolution
    and expanded to the full spatial size by ``resize_mat``.

    Args:
        in_channels: channel count of the input; must be divisible by
            ``groups`` for the per-group expansion in ``forward`` to reshape.
        block_size: side length of the predicted attention matrices; the input
            height and width must both be multiples of it.
        groups: number of channel groups sharing one pair of matrices.
        act_layer: activation class forwarded to ``ConvBnAct``.
        norm_layer: normalisation class forwarded to ``ConvBnAct``.
    """

    def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        attn_channels = block_size * block_size * groups

        self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.conv_p = nn.Conv2d(groups, attn_channels, kernel_size=(block_size, 1))
        self.conv_q = nn.Conv2d(groups, attn_channels, kernel_size=(1, block_size))
        self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)

        self.block_size = block_size
        self.groups = groups
        self.in_channels = in_channels

    def resize_mat(self, x, t: int):
        """Expand each square matrix in ``x`` by a factor of ``t``.

        Every entry of a ``(block_size, block_size)`` matrix is replaced by
        that entry times a ``t x t`` identity, i.e. the result is the
        Kronecker product of the matrix with ``eye(t)``. For ``t <= 1`` the
        input is returned unchanged.
        """
        B, C, block_size, block_size1 = x.shape
        assert block_size == block_size1
        if t <= 1:
            return x

        side = block_size * t
        eye = torch.eye(t, t, dtype=x.dtype, device=x.device)
        # (B*C, bs*bs, 1, 1) * (t, t) -> one scaled identity per matrix entry.
        tiles = x.view(B * C, -1, 1, 1) * eye
        tiles = tiles.view(B * C, block_size, block_size, t, t)
        # Interleave tile rows with block rows and tile cols with block cols.
        tiles = tiles.permute(0, 1, 3, 2, 4)
        return tiles.reshape(B, C, side, side)

    def forward(self, x):
        """Return ``conv2(P @ x @ Q)`` with ``P``/``Q`` predicted from ``x``."""
        assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
        B, C, H, W = x.shape
        bs = self.block_size
        groups = self.groups
        per_group = C // groups

        feat = self.conv1(x)
        row_pooled = F.adaptive_max_pool2d(feat, (bs, 1))
        col_pooled = F.adaptive_max_pool2d(feat, (1, bs))

        p = torch.sigmoid(self.conv_p(row_pooled).view(B, groups, bs, bs))
        q = torch.sigmoid(self.conv_q(col_pooled).view(B, groups, bs, bs))
        # Normalise P over its last axis and Q over its second-to-last axis.
        p = p / p.sum(dim=3, keepdim=True)
        q = q / q.sum(dim=2, keepdim=True)

        # Share each group's matrix across all channels of that group.
        p = p.unsqueeze(2).expand(B, groups, per_group, bs, bs).contiguous()
        p = p.view(B, C, bs, bs)
        q = q.unsqueeze(2).expand(B, groups, per_group, bs, bs).contiguous()
        q = q.view(B, C, bs, bs)

        p = self.resize_mat(p, H // bs)
        q = self.resize_mat(q, W // bs)

        y = p.matmul(x).matmul(q)
        return self.conv2(y)
class BatNonLocalAttn(nn.Module):
    """BAT (bilinear attentional transform) non-local attention block.

    Adapted from: https://github.com/BA-Transform/BAT-Image-Classification

    The input is reduced to ``rd_channels`` with a 1x1 ``ConvBnAct``, passed
    through ``BilinearAttnTransform``, projected back to ``in_channels`` with a
    second 1x1 ``ConvBnAct``, run through 2D dropout and added to the input.

    Args:
        in_channels: channel count of the input (and of the output).
        block_size: block size forwarded to ``BilinearAttnTransform``.
        groups: group count forwarded to ``BilinearAttnTransform``.
        rd_ratio: reduction ratio used when ``rd_channels`` is not given.
        rd_channels: explicit reduced width; when None it is computed with
            ``make_divisible(in_channels * rd_ratio, divisor=rd_divisor)``.
        rd_divisor: divisor forwarded to ``make_divisible``.
        drop_rate: probability passed to ``nn.Dropout2d``.
        act_layer: activation class forwarded to the sub-layers.
        norm_layer: normalisation class forwarded to the sub-layers.
        **_: extra keyword arguments are accepted and ignored.
    """

    def __init__(
            self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
            drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_):
        super().__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)

        # Every sub-layer shares the same activation / normalisation choice.
        layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
        self.conv1 = ConvBnAct(in_channels, rd_channels, 1, **layer_args)
        self.ba = BilinearAttnTransform(rd_channels, block_size, groups, **layer_args)
        self.conv2 = ConvBnAct(rd_channels, in_channels, 1, **layer_args)
        self.dropout = nn.Dropout2d(p=drop_rate)

    def forward(self, x):
        """Return ``dropout(conv2(ba(conv1(x)))) + x``."""
        reduced = self.conv1(x)
        transformed = self.ba(reduced)
        restored = self.conv2(transformed)
        return self.dropout(restored) + x
|