import torch
import torch.nn as nn


##########################################################################
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=(kernel_size // 2), bias=bias, stride=stride)
    return layer


def conv3x3(in_chn, out_chn, bias=True):
    layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias)
    return layer


def conv_down(in_chn, out_chn, bias=False):
    layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)
    return layer


##########################################################################
## Supervised Attention Module (SAM)
class SAM(nn.Module):
    def __init__(self, n_feat, kernel_size, bias):
        super(SAM, self).__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)

    def forward(self, x, x_img):
        x1 = self.conv1(x)
        # restored RGB image = predicted residual + input image
        img = self.conv2(x) + x_img
        # attention map derived from the restored image
        x2 = torch.sigmoid(self.conv3(img))
        x1 = x1 * x2
        x1 = x1 + x
        return x1, img


##########################################################################
## Spatial Attention
class SALayer(nn.Module):
    def __init__(self, kernel_size=7):
        super(SALayer, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # channel-wise average and max pooling -> 2-channel spatial descriptor
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        y = torch.cat([avg_out, max_out], dim=1)
        y = self.conv1(y)
        y = self.sigmoid(y)
        return x * y


## Spatial Attention Block (SAB)
class SAB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(SAB, self).__init__()
        modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias),
                        act,
                        conv(n_feat, n_feat, kernel_size, bias=bias)]
        self.body = nn.Sequential(*modules_body)
        self.SA = SALayer(kernel_size=7)

    def forward(self, x):
        res = self.body(x)
        res = self.SA(res)
        res += x
        return res


##########################################################################
## Pixel Attention
class PALayer(nn.Module):
    def __init__(self, channel, reduction=16, bias=False):
        super(PALayer, self).__init__()
        self.pa = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            # keeps `channel` output maps (per-pixel, per-channel weights);
            # the classic pixel-attention layer maps to a single map instead
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.pa(x)
        return x * y


## Pixel Attention Block (PAB)
class PAB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(PAB, self).__init__()
        modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias),
                        act,
                        conv(n_feat, n_feat, kernel_size, bias=bias)]
        self.PA = PALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res = self.PA(res)
        res += x
        return res


##########################################################################
## Channel Attention Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
## Channel Attention Block (CAB)
class CAB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(CAB, self).__init__()
        modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias),
                        act,
                        conv(n_feat, n_feat, kernel_size, bias=bias)]
        self.CA = CALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res = self.CA(res)
        res += x
        return res


if __name__ == "__main__":
    import time

    from thop import profile

    # layer = CAB(64, 3, 4, False, nn.PReLU())
    layer = PAB(64, 3, 4, False, nn.PReLU())
    # layer = SAB(64, 3, 4, False, nn.PReLU())
    for idx, m in enumerate(layer.modules()):
        print(idx, "-", m)

    s = time.time()
    rgb = torch.ones(1, 64, 256, 256, dtype=torch.float, requires_grad=False)
    out = layer(rgb)
    flops, params = profile(layer, inputs=(rgb,))
    print('parameters:', params)
    print('flops:', flops)
    # elapsed wall-clock time (forward pass + profiling), in milliseconds
    print('time: {:.4f}ms'.format((time.time() - s) * 1000))
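# A minimal sketch (not part of the original file) showing how SAM can be
# exercised: unlike CAB/PAB/SAB, its forward takes both the feature map and
# the input image and returns refined features plus a restored RGB image.
# The feature width (64) and spatial size (256x256) below are illustrative
# assumptions, not values fixed by this module.
if __name__ == "__main__":
    sam = SAM(n_feat=64, kernel_size=3, bias=False)
    feats = torch.ones(1, 64, 256, 256, dtype=torch.float)
    img = torch.ones(1, 3, 256, 256, dtype=torch.float)
    refined, restored = sam(feats, img)
    print('SAM refined features:', refined.shape)   # torch.Size([1, 64, 256, 256])
    print('SAM restored image:  ', restored.shape)  # torch.Size([1, 3, 256, 256])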