# CBAM (Convolutional Block Attention Module) layers.
# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class BasicConv(nn.Module):
    """Thin wrapper around a single ``nn.Conv2d``.

    Only the convolution is applied; no normalization or activation follows.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels (also exposed as ``out_channels``).
        kernel_size, stride, padding, dilation, groups, bias: forwarded
            unchanged to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)

    def forward(self, x):
        """Apply the wrapped convolution to ``x``."""
        return self.conv(x)
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one.

    ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``.
    """

    def forward(self, x):
        # ``reshape`` returns a view whenever ``view`` would have succeeded,
        # and falls back to a copy for non-contiguous inputs (e.g. permuted
        # tensors), where ``view`` raises a RuntimeError.
        return x.reshape(x.size(0), -1)
class ChannelGate(nn.Module):
    """Channel attention gate from CBAM.

    Each requested global pooling of the input (average and/or max over the
    full spatial extent) is passed through a shared two-layer MLP. The MLP
    outputs are summed, squashed with a sigmoid, and used to rescale every
    channel of the input.

    Args:
        gate_channels: number of input channels ``C``.
        reduction_ratio: the MLP hidden width is ``C // reduction_ratio``.
        pool_types: iterable of pooling names, each ``"avg"`` or ``"max"``.
            A private list copy is stored, so instances never share it.

    Raises:
        ValueError: if ``pool_types`` is empty or names an unsupported pooling.
    """

    _SUPPORTED_POOLS = ("avg", "max")

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=("avg", "max")):
        super().__init__()
        # Copy so neither the default nor a caller's list is shared/aliased.
        pool_types = list(pool_types)
        if not pool_types:
            raise ValueError("pool_types must contain at least one pooling type")
        unknown = [p for p in pool_types if p not in self._SUPPORTED_POOLS]
        if unknown:
            raise ValueError(
                f"Unsupported pool type(s) {unknown!r}; expected any of "
                f"{self._SUPPORTED_POOLS!r}"
            )
        self.gate_channels = gate_channels
        hidden = gate_channels // reduction_ratio
        # NOTE(review): hidden is 0 when gate_channels < reduction_ratio, which
        # makes the attention independent of the input; left unchanged because
        # altering it would change parameter shapes of existing checkpoints.
        self.mlp = nn.Sequential(
            # Parameter-free at index 0, so the Linear layers keep their
            # "mlp.1.*" / "mlp.3.*" state_dict keys.
            nn.Flatten(),
            nn.Linear(gate_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, gate_channels),
        )
        self.pool_types = pool_types

    def forward(self, x):
        """Rescale the channels of 4-D input ``x`` by the learned attention.

        Raises:
            ValueError: if ``self.pool_types`` is empty or contains an
                unsupported name (e.g. after being modified post-construction).
        """
        # Pooling window = whole spatial extent -> one value per channel.
        window = (x.size(2), x.size(3))
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == "avg":
                pooled = F.avg_pool2d(x, window, stride=window)
            elif pool_type == "max":
                pooled = F.max_pool2d(x, window, stride=window)
            else:
                # Without this, the previous iteration's result would be re-added.
                raise ValueError(f"Unsupported pool type: {pool_type!r}")
            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        if channel_att_sum is None:
            raise ValueError("pool_types is empty; nothing to pool")
        # (N, C) -> (N, C, 1, 1); broadcasting spreads it over H and W.
        scale = torch.sigmoid(channel_att_sum)[:, :, None, None]
        return x * scale
class ChannelPool(nn.Module):
    """Summarize every spatial position across the channel dimension.

    Returns a 2-channel tensor: channel 0 is the maximum over dim 1 of the
    input, channel 1 is the mean over dim 1.
    """

    def forward(self, x):
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        return torch.cat([channel_max, channel_mean], dim=1)
class SpatialGate(nn.Module):
    """Spatial attention gate from CBAM.

    The input is compressed by ``ChannelPool`` into a 2-channel map, convolved
    down to a single channel with a 7x7 ``BasicConv``, passed through a
    sigmoid, and multiplied back onto the input.
    """

    def __init__(self):
        super().__init__()
        k = 7
        self.compress = ChannelPool()
        # Stride 1 with padding k // 2 (= 3) keeps H and W unchanged.
        self.spatial = BasicConv(2, 1, k, stride=1, padding=k // 2)

    def forward(self, x):
        attention = torch.sigmoid(self.spatial(self.compress(x)))
        # attention has a single channel; broadcasting applies it to all of x's channels.
        return x * attention
class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Applies ``ChannelGate`` and then, unless ``no_spatial`` is set,
    ``SpatialGate`` to the channel-gated result.

    Args:
        gate_channels: number of input channels.
        reduction_ratio: forwarded to ``ChannelGate``.
        pool_types: pooling names forwarded to ``ChannelGate``. A fresh list
            is built for every instance, so the default is never shared.
        no_spatial: if True, the spatial gate is skipped and no
            ``SpatialGate`` attribute is created.
    """

    def __init__(
        self,
        gate_channels,
        reduction_ratio=16,
        pool_types=("avg", "max"),
        no_spatial=False,
    ):
        super().__init__()
        # Pass a private list so the default argument (or the caller's list)
        # never ends up shared between the gates of different instances.
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, list(pool_types))
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        """Apply channel attention, then (optionally) spatial attention."""
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out