|
""" |
|
3D Squeeze and Excitation Modules |
|
***************************** |
|
3D Extensions of the following 2D squeeze and excitation blocks: |
|
1. `Channel Squeeze and Excitation <https://arxiv.org/abs/1709.01507>`_ |
|
2. `Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_ |
|
3. `Channel and Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_ |
|
New Project & Excite block, designed specifically for 3D inputs |
|
'quote' |
|
Coded by -- Anne-Marie Rickmann (https://github.com/arickm) |
|
""" |
|
|
|
import torch |
|
from torch import nn as nn |
|
from torch.nn import functional as F |
|
|
|
|
|
class ChannelSELayer3D(nn.Module):
    """Channel-wise squeeze-and-excitation gate for 5-D feature maps.

    3D extension of the Squeeze-and-Excitation (SE) block described in:

        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*

        *Zhu et al., AnatomyNet, arXiv:1808.05238*

    Each channel is averaged over its whole volume, the resulting vector is
    passed through a two-layer bottleneck MLP, and the sigmoid of the result
    rescales the corresponding input channel.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        Args:
            num_channels (int): Number of input channels.
            reduction_ratio (int): Divisor applied (integer division) to
                ``num_channels`` to obtain the bottleneck width.
        """
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.reduction_ratio = reduction_ratio

        bottleneck_width = num_channels // reduction_ratio
        self.fc1 = nn.Linear(num_channels, bottleneck_width, bias=True)
        self.fc2 = nn.Linear(bottleneck_width, num_channels, bias=True)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale every channel of ``x`` by its learned gate.

        Args:
            x (torch.Tensor): Input of shape (batch_size, num_channels, D, H, W).

        Returns:
            (torch.Tensor): Gated tensor with the same shape as ``x``.
        """
        # Unpacking exactly five sizes also rejects inputs that are not 5-D.
        n_batch, n_chan, _depth, _height, _width = x.size()

        # Squeeze: one scalar per (sample, channel).
        pooled = self.avg_pool(x).view(n_batch, n_chan)

        # Excite: bottleneck MLP followed by a sigmoid gate in (0, 1).
        hidden = self.relu(self.fc1(pooled))
        gate = self.sigmoid(self.fc2(hidden))

        # Broadcast the per-channel gate over the three spatial axes.
        return torch.mul(x, gate.view(n_batch, n_chan, 1, 1, 1))
|
|
|
|
|
class SpatialSELayer3D(nn.Module):
    """Spatial excitation gate for 5-D feature maps.

    3D extension of the spatial SE block described in:

        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*

    The channels are projected to a single map by a 1x1x1 convolution, the
    sigmoid of that map gives one gate per voxel, and every channel of the
    input is multiplied by it.
    """

    def __init__(self, num_channels):
        """
        Args:
            num_channels (int): Number of input channels.
        """
        super(SpatialSELayer3D, self).__init__()
        self.conv = nn.Conv3d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, weights=None):
        """Rescale every voxel of ``x`` by its spatial gate.

        Args:
            x (torch.Tensor): Input of shape (batch_size, num_channels, D, H, W).
            weights (torch.Tensor, optional): Externally supplied projection
                weights (for few shot learning). Must hold exactly
                ``num_channels`` elements; they are used as a 1x1x1 kernel
                instead of the learned convolution, and no bias is applied.

        Returns:
            (torch.Tensor): Gated tensor with the same shape as ``x``.
        """
        batch_size, channel, D, H, W = x.size()

        # `if weights:` would raise for any multi-element tensor, so test
        # explicitly against None.
        if weights is not None:
            # The input is 5-D, so the kernel must be 5-D and the convolution
            # must be the 3D one.
            kernel = weights.reshape(1, channel, 1, 1, 1)
            out = F.conv3d(x, kernel)
        else:
            out = self.conv(x)

        squeeze_tensor = self.sigmoid(out)

        # Broadcast the single-channel gate across all input channels.
        output_tensor = torch.mul(x, squeeze_tensor.view(batch_size, 1, D, H, W))

        return output_tensor
|
|
|
|
|
class ChannelSpatialSELayer3D(nn.Module):
    """Concurrent channel and spatial squeeze & excitation for 5-D inputs.

    3D extension of the block described in:

        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*

    The input is gated independently by a channel SE branch and a spatial SE
    branch; the two results are merged with an element-wise maximum.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        Args:
            num_channels (int): Number of input channels.
            reduction_ratio (int): Reduction factor forwarded to the channel
                SE branch.
        """
        super().__init__()
        self.cSE = ChannelSELayer3D(num_channels, reduction_ratio)
        self.sSE = SpatialSELayer3D(num_channels)

    def forward(self, input_tensor):
        """Apply both branches to ``input_tensor`` and keep the larger response.

        Args:
            input_tensor (torch.Tensor): Input passed unchanged to both branches.

        Returns:
            (torch.Tensor): Element-wise maximum of the two branch outputs.
        """
        channel_gated = self.cSE(input_tensor)
        spatial_gated = self.sSE(input_tensor)
        return torch.max(channel_gated, spatial_gated)
|
|