| | import torch.nn as nn
|
| |
|
"""
This code is based on the implementation of "Pyramid Attention Network for
Semantic Segmentation" available at
"https://github.com/JaveyWang/Pyramid-Attention-Networks-pytorch/blob/f719365c1780f062058dd0c94550c6c4766cd937/networks.py#L41"
"""
|
| |
|
class FPM(nn.Module):
    """Feature Pyramid Attention module.

    A 1x1 "master" branch keeps the input resolution, while a three-level
    pyramid (7x7, 5x5 and 3x3 convolutions, each halving the resolution) is
    computed, merged from the coarsest level upwards through transposed
    convolutions, and finally multiplied element-wise onto the master branch.
    """

    def __init__(self, channels=1024):
        """
        Build the layers of the module.

        :param channels: Number of input and output channels. The pyramid
            branches work on ``channels // 4`` channels.
        :type channels: int
        """
        super(FPM, self).__init__()
        # Integer floor division: the branch width must be an int for Conv2d.
        channels_mid = channels // 4

        self.channels_cond = channels

        # Master branch: 1x1 projection at full resolution.
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Downsampling path: each stride-2 convolution halves height and width.
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        # Per-level refinement: stride-1 convolutions that keep the resolution.
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Upsampling path: each transposed convolution doubles height and width.
        self.conv_upsample_3 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_3 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_2 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_2 = nn.BatchNorm2d(channels_mid)

        # The last upsampling also restores the full channel count so the
        # result can be multiplied with the master branch.
        self.conv_upsample_1 = nn.ConvTranspose2d(channels_mid, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """
        Apply the pyramid attention to a feature map.

        ``h`` and ``w`` should be multiples of 8: the three stride-2 stages
        and the three x2 transposed convolutions only produce matching sizes
        in that case.

        :param x: Input feature maps. Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        # Master branch (no activation before the attention product).
        x_master = self.conv_master(x)
        x_master = self.bn_master(x_master)

        # Level 1: 1/2 resolution.
        x1_1 = self.conv7x7_1(x)
        x1_1 = self.bn1_1(x1_1)
        x1_1 = self.relu(x1_1)
        x1_2 = self.conv7x7_2(x1_1)
        x1_2 = self.bn1_2(x1_2)

        # Level 2: 1/4 resolution, fed by level 1.
        x2_1 = self.conv5x5_1(x1_1)
        x2_1 = self.bn2_1(x2_1)
        x2_1 = self.relu(x2_1)
        x2_2 = self.conv5x5_2(x2_1)
        x2_2 = self.bn2_2(x2_2)

        # Level 3: 1/8 resolution, fed by level 2.
        x3_1 = self.conv3x3_1(x2_1)
        x3_1 = self.bn3_1(x3_1)
        x3_1 = self.relu(x3_1)
        x3_2 = self.conv3x3_2(x3_1)
        x3_2 = self.bn3_2(x3_2)

        # Merge the pyramid from the coarsest level upwards.
        x3_upsample = self.relu(self.bn_upsample_3(self.conv_upsample_3(x3_2)))
        x2_merge = self.relu(x2_2 + x3_upsample)
        x2_upsample = self.relu(self.bn_upsample_2(self.conv_upsample_2(x2_merge)))
        x1_merge = self.relu(x1_2 + x2_upsample)

        # Use the upsampled pyramid as a multiplicative attention map.
        x_master = x_master * self.relu(self.bn_upsample_1(self.conv_upsample_1(x1_merge)))

        out = self.relu(x_master)

        return out