# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConv(nn.Module):
    """Thin wrapper around nn.Conv2d that records the output channel count."""

    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        return self.conv(x)


class Flatten(nn.Module):
    """Flatten every dimension except the batch dimension."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    """Channel attention: squeeze spatial dims by pooling, then score each
    channel with a shared bottleneck MLP."""

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=("avg", "max")):
        super().__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == "avg":
                # Global average pool to (N, C, 1, 1).
                avg_pool = F.avg_pool2d(
                    x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
                )
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == "max":
                # Global max pool to (N, C, 1, 1).
                max_pool = F.max_pool2d(
                    x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
                )
                channel_att_raw = self.mlp(max_pool)
            else:
                # Fail loudly instead of silently reusing a stale result.
                raise ValueError(f"Unsupported pool type: {pool_type}")
            # Sum the raw attention logits across pooling branches.
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        # Sigmoid the summed logits and broadcast back to the input shape.
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


class ChannelPool(nn.Module):
    """Stack the per-pixel channel max and channel mean into a 2-channel map."""

    def forward(self, x):
        return torch.cat(
            (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
        )


class SpatialGate(nn.Module):
    """Spatial attention: compress channels to 2 planes, then score each
    location with a 7x7 convolution."""

    def __init__(self):
        super().__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(
            2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2
        )

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcast over the channel dimension
        return x * scale


class CBAM(nn.Module):
    """Convolutional Block Attention Module (https://arxiv.org/abs/1807.06521):
    a channel gate followed by an optional spatial gate."""

    def __init__(
        self,
        gate_channels,
        reduction_ratio=16,
        pool_types=("avg", "max"),
        no_spatial=False,
    ):
        super().__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
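

# A minimal smoke test, not part of the upstream file: applies CBAM to a
# random feature map and verifies that attention only rescales activations,
# so the output shape matches the input. The batch size, channel count (64),
# and spatial size (32x32) below are arbitrary illustrative choices.
if __name__ == "__main__":
    feat = torch.randn(2, 64, 32, 32)  # dummy (N, C, H, W) feature map
    cbam = CBAM(gate_channels=64, reduction_ratio=16)
    out = cbam(feat)
    assert out.shape == feat.shape  # gating preserves the tensor shape
    print(out.shape)  # torch.Size([2, 64, 32, 32])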