import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """
    Channel attention in the squeeze-and-excitation style: global average
    pooling squeezes the input down to Cx1x1, the excitation branch (two 1x1
    convolutions with a channel bottleneck in between) computes per-channel
    weights in (0, 1), and the original input is rescaled by those weights.

    In:  NxCxHxW
    Out: NxCxHxW (same shape; the attention weights broadcast over H and W)
    """

    def __init__(self, in_channels, reduction_ratio=8, bias=True):
        super().__init__()
        # Squeeze: global average pooling collapses each channel to one value.
        self.squeezing = nn.AdaptiveAvgPool2d(1)
        # Excitation: a bottleneck of two 1x1 convolutions that reduces the
        # channel count by reduction_ratio and then restores it, followed by
        # a sigmoid that yields per-channel weights in (0, 1).
        self.excitation = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels // reduction_ratio,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            nn.PReLU(),
            nn.Conv2d(
                in_channels // reduction_ratio,
                in_channels,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        squeezed_x = self.squeezing(x)  # NxCx1x1
        excitation = self.excitation(squeezed_x)  # NxCx1x1, weights in (0, 1)
        # NxCxHxW: the weights broadcast over H and W, rescaling each channel
        # of the original input.
        return excitation * x
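

# A minimal smoke test (a sketch: the batch size, channel count, and spatial
# size below are arbitrary choices, not part of the original module).
if __name__ == "__main__":
    attn = ChannelAttention(in_channels=32, reduction_ratio=8)
    x = torch.randn(4, 32, 64, 64)  # NxCxHxW
    out = attn(x)
    assert out.shape == x.shape  # channel attention preserves the input shape
    print(out.shape)  # torch.Size([4, 32, 64, 64])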