Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
########################################################################## | |
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build a 2-D convolution padded by ``kernel_size // 2``.

    With ``stride=1`` and an odd ``kernel_size`` the output keeps the input's
    height and width.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Square kernel size (int).
        bias: Whether the convolution has a learnable bias.
        stride: Convolution stride.

    Returns:
        An ``nn.Conv2d`` module.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
        stride=stride,
    )
def conv3x3(in_chn, out_chn, bias=True):
    """Return a 3x3, stride-1, padding-1 convolution (height/width preserved).

    Args:
        in_chn: Number of input channels.
        out_chn: Number of output channels.
        bias: Whether the convolution has a learnable bias (default True).
    """
    settings = dict(kernel_size=3, stride=1, padding=1, bias=bias)
    return nn.Conv2d(in_chn, out_chn, **settings)
def conv_down(in_chn, out_chn, bias=False):
    """Return a 4x4, stride-2, padding-1 convolution.

    For an even input height/width ``H`` the output size is ``H // 2``,
    i.e. 2x spatial downsampling.

    Args:
        in_chn: Number of input channels.
        out_chn: Number of output channels.
        bias: Whether the convolution has a learnable bias (default False).
    """
    return nn.Conv2d(
        in_channels=in_chn,
        out_channels=out_chn,
        kernel_size=4,
        stride=2,
        padding=1,
        bias=bias,
    )
########################################################################## | |
## Supervised Attention Module (SAM)
class SAM(nn.Module):
    """Supervised Attention Module.

    Projects features ``x`` to a 3-channel map, adds it to ``x_img``, and
    uses the resulting image to build a sigmoid gate over ``conv1(x)``; the
    gated features are added back to ``x`` as a residual.

    Args:
        n_feat: Number of channels of the feature input ``x``.
        kernel_size: Kernel size shared by all three convolutions; the
            residual add only lines up spatially for odd sizes, since
            ``conv`` pads by ``kernel_size // 2``.
        bias: Whether the convolutions use a bias term.
    """

    def __init__(self, n_feat, kernel_size, bias):
        super().__init__()
        # conv1 -> conv2 -> conv3 creation order fixes both the random
        # initialisation sequence and the state_dict key order.
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)

    def forward(self, x, x_img):
        """Return ``(features, img)``.

        ``img = conv2(x) + x_img``, so ``x_img`` must broadcast against a
        3-channel map.
        ``features = x + conv1(x) * sigmoid(conv3(img))``.
        """
        img = self.conv2(x) + x_img
        gate = torch.sigmoid(self.conv3(img))
        features = self.conv1(x) * gate + x
        return features, img
########################################################################## | |
## Spatial Attention | |
class SALayer(nn.Module):
    """Spatial attention gate.

    The channel-wise mean and channel-wise max (one map each) are stacked
    into a 2-channel tensor. A single convolution reduces it to one channel,
    and a sigmoid turns that into weights. The weights multiply the input,
    broadcasting across channels.

    Args:
        kernel_size: Size of the attention convolution (default 7); padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([mean_map, max_map], dim=1)
        weights = self.sigmoid(self.conv1(stacked))
        return x * weights
# Spatial Attention Block (SAB) | |
class SAB(nn.Module):
    """Spatial Attention Block: conv -> act -> conv, spatial gate, residual add.

    Args:
        n_feat: Number of input/output channels.
        kernel_size: Kernel size of the two body convolutions.
        reduction: Accepted but unused here (CAB/PAB, which share this
            signature, pass it to their attention layer).
        bias: Whether the body convolutions use a bias.
        act: Activation module placed between the two convolutions.
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        # Body convolutions are created before SALayer, keeping the random
        # initialisation order and the body-then-SA module order.
        self.body = nn.Sequential(
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
        )
        self.SA = SALayer(kernel_size=7)

    def forward(self, x):
        attended = self.SA(self.body(x))
        return attended + x
########################################################################## | |
## Pixel Attention | |
class PALayer(nn.Module):
    """Pixel attention gate.

    Two 1x1 convolutions (``channel -> hidden -> channel`` with a ReLU in
    between) followed by a sigmoid produce a gate with the same number of
    channels as the input. The gate is multiplied element-wise into ``x``,
    giving one weight per channel per pixel.

    Args:
        channel: Number of input/output channels.
        reduction: Channel reduction factor for the hidden layer.
        bias: Whether the 1x1 convolutions use a bias.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super().__init__()
        # Keep at least one hidden channel. When channel < reduction, the
        # integer division yields 0, giving a zero-width hidden layer whose
        # gate can no longer depend on the input. Unchanged otherwise.
        hidden = max(1, channel // reduction)
        self.pa = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            # NOTE(review): an earlier comment here read "channel <-> 1",
            # suggesting a single-channel gate (one weight per pixel, shared
            # across channels) was considered; this keeps `channel` outputs.
            nn.Conv2d(hidden, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.pa(x)
## Pixel Attention Block (PAB) | |
class PAB(nn.Module):
    """Pixel Attention Block: conv -> act -> conv, pixel gate, residual add.

    Args:
        n_feat: Number of input/output channels.
        kernel_size: Kernel size of the two body convolutions.
        reduction: Channel reduction factor forwarded to ``PALayer``.
        bias: Whether the convolutions use a bias.
        act: Activation module placed between the two body convolutions.
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        # The body convolutions are built before PALayer, which fixes the
        # random-init order. PA is registered before body, which fixes the
        # modules()/parameters()/state_dict order.
        body_layers = [
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
        ]
        self.PA = PALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*body_layers)

    def forward(self, x):
        attended = self.PA(self.body(x))
        return attended + x
########################################################################## | |
## Channel Attention Layer | |
class CALayer(nn.Module):
    """Channel attention gate (squeeze-and-excitation style).

    Global average pooling reduces each channel to one value. Two 1x1
    convolutions (``channel -> hidden -> channel`` with a ReLU in between)
    and a sigmoid turn that into one weight per channel, which rescales ``x``.

    Args:
        channel: Number of input/output channels.
        reduction: Channel reduction factor for the hidden layer.
        bias: Whether the 1x1 convolutions use a bias.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super().__init__()
        # global average pooling: feature map --> one value per channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Keep at least one hidden channel. When channel < reduction, the
        # integer division yields 0, giving a zero-width hidden layer whose
        # weights can no longer depend on the input. Unchanged otherwise.
        hidden = max(1, channel // reduction)
        # channel downscale and upscale --> per-channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        weights = self.conv_du(self.avg_pool(x))
        return x * weights
## Channel Attention Block (CAB) | |
class CAB(nn.Module):
    """Channel Attention Block: conv -> act -> conv, channel gate, residual add.

    Args:
        n_feat: Number of input/output channels.
        kernel_size: Kernel size of the two body convolutions.
        reduction: Channel reduction factor forwarded to ``CALayer``.
        bias: Whether the convolutions use a bias.
        act: Activation module placed between the two body convolutions.
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        # The body convolutions are built before CALayer, which fixes the
        # random-init order. CA is registered before body, which fixes the
        # modules()/parameters()/state_dict order.
        first_conv = conv(n_feat, n_feat, kernel_size, bias=bias)
        second_conv = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.CA = CALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(first_conv, act, second_conv)

    def forward(self, x):
        attended = self.CA(self.body(x))
        return attended + x
if __name__ == "__main__":
    # Smoke test / benchmark: build one block, print its modules, time a
    # single forward pass, and report parameter count and FLOPs via thop.
    import time

    from thop import profile

    # CAB(...) or SAB(...) take the same arguments and can be swapped in here.
    layer = PAB(64, 3, 4, False, nn.PReLU())
    for idx, m in enumerate(layer.modules()):
        print(idx, "-", m)

    rgb = torch.ones(1, 64, 256, 256, dtype=torch.float)
    # Time only the forward pass (not input creation or thop profiling), and
    # convert seconds to milliseconds with *1000; the old *10 was off by 100x.
    with torch.no_grad():
        start = time.perf_counter()
        layer(rgb)
        elapsed_ms = (time.perf_counter() - start) * 1000.0

    flops, params = profile(layer, inputs=(rgb,))
    print('parameters:', params)
    print('flops', flops)
    print('time: {:.4f}ms'.format(elapsed_ms))