import torch
import torch.nn as nn
import torch.nn.functional as F
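
# Attention-weighted contrast convolution layers. Each contrast layer pairs a
# k x k convolution with channel attention: the kernel weights are summed into
# an equivalent 1x1 "center" filter, and the layer outputs the difference
# between the center response and the ordinary local response, with the
# attention map theta rescaling one of the two terms.
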
class Avg_ChannelAttention_n(nn.Module):
    """Channel attention (BatchNorm variant): global average pooling followed by
    a 1x1-conv bottleneck, producing per-channel weights in (0, 1)."""
    def __init__(self, channels, r=4):
        super(Avg_ChannelAttention_n, self).__init__()
        self.avg_channel = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),                 # global average pooling: (bz, C, h, w) -> (bz, C, 1, 1)
            nn.Conv2d(channels, channels // r, 1, 1, 0),  # (bz, C, 1, 1) -> (bz, C/r, 1, 1)
            nn.BatchNorm2d(channels // r),
            nn.ReLU(True),
            nn.Conv2d(channels // r, channels, 1, 1, 0),  # (bz, C/r, 1, 1) -> (bz, C, 1, 1)
            nn.BatchNorm2d(channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.avg_channel(x)
class Avg_ChannelAttention(nn.Module):
    """Channel attention without BatchNorm: same bottleneck as above."""
    def __init__(self, channels, r=4):
        super(Avg_ChannelAttention, self).__init__()
        self.avg_channel = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),                 # global average pooling: (bz, C, h, w) -> (bz, C, 1, 1)
            nn.Conv2d(channels, channels // r, 1, 1, 0),  # (bz, C, 1, 1) -> (bz, C/r, 1, 1)
            nn.ReLU(True),
            nn.Conv2d(channels // r, channels, 1, 1, 0),  # (bz, C/r, 1, 1) -> (bz, C, 1, 1)
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.avg_channel(x)
class AttnContrastLayer(nn.Module):
    """Attention-weighted contrast: theta * (center response) - (local response)."""
    def __init__(self, channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=False):
        super(AttnContrastLayer, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.attn = Avg_ChannelAttention(channels)

    def forward(self, x):
        out_normal = self.conv(x)  # ordinary local convolution response
        theta = self.attn(x)       # per-channel attention weights, (bz, C, 1, 1)
        # Collapse each k x k kernel to its sum, giving an equivalent 1x1 kernel;
        # convolving with it yields the summed ("center") filter response.
        kernel_w = self.conv.weight.sum(dim=(2, 3), keepdim=True)  # (C, C/groups, 1, 1)
        out_center = F.conv2d(x, weight=kernel_w, bias=self.conv.bias,
                              stride=self.conv.stride, padding=0, groups=self.conv.groups)
        return theta * out_center - out_normal
class AttnContrastLayer_n(nn.Module):
    """Same contrast layer, but using the BatchNorm variant of channel attention."""
    def __init__(self, channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=False):
        super(AttnContrastLayer_n, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.attn = Avg_ChannelAttention_n(channels)

    def forward(self, x):
        out_normal = self.conv(x)
        theta = self.attn(x)
        kernel_w = self.conv.weight.sum(dim=(2, 3), keepdim=True)
        out_center = F.conv2d(x, weight=kernel_w, bias=self.conv.bias,
                              stride=self.conv.stride, padding=0, groups=self.conv.groups)
        return theta * out_center - out_normal
class AttnContrastLayer_d(nn.Module):
    """Dilated contrast layer. Note the flipped weighting relative to
    AttnContrastLayer: here attention rescales the local response.
    For a 3x3 kernel, padding must equal dilation so that out_normal and
    out_center have matching spatial sizes; the default is therefore
    padding=2 for dilation=2."""
    def __init__(self, channels, kernel_size=3, stride=1, padding=2, dilation=2, groups=1, bias=False):
        super(AttnContrastLayer_d, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.attn = Avg_ChannelAttention(channels)

    def forward(self, x):
        out_normal = self.conv(x)
        theta = self.attn(x)
        kernel_w = self.conv.weight.sum(dim=(2, 3), keepdim=True)
        out_center = F.conv2d(x, weight=kernel_w, bias=self.conv.bias,
                              stride=self.conv.stride, padding=0, groups=self.conv.groups)
        return out_center - theta * out_normal
class AtrousAttnWeight(nn.Module):
    """Thin wrapper exposing the channel-attention weights on their own."""
    def __init__(self, channels):
        super(AtrousAttnWeight, self).__init__()
        self.attn = Avg_ChannelAttention(channels)

    def forward(self, x):
        return self.attn(x)
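
if __name__ == "__main__":
    # Minimal shape-check sketch; the tensor sizes here are illustrative
    # assumptions, not from the original file. Each contrast layer should
    # preserve (bz, C, H, W), and the attention wrapper should return
    # per-channel weights of shape (bz, C, 1, 1).
    x = torch.randn(2, 16, 32, 32)
    for layer in (AttnContrastLayer(16),
                  AttnContrastLayer_n(16),
                  AttnContrastLayer_d(16, padding=2, dilation=2)):
        assert layer(x).shape == x.shape
    assert AtrousAttnWeight(16)(x).shape == (2, 16, 1, 1)
    print("all shape checks passed")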