File size: 4,050 Bytes
82b70d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.nn as nn
import torch.nn.functional as F


class Avg_ChannelAttention_n(nn.Module):
    """Squeeze-and-excitation style channel attention with BatchNorm.

    The input is global-average-pooled to ``(B, C, 1, 1)``. It then passes
    through a ``C -> hidden -> C`` bottleneck of 1x1 convolutions, each
    followed by ``BatchNorm2d``. A sigmoid finally maps the result to one
    weight in ``[0, 1]`` per channel.

    Args:
        channels: Number of input (and output) channels ``C``.
        r: Bottleneck reduction ratio. The hidden width is ``channels // r``,
            clamped to at least 1 so that ``channels < r`` still yields a
            working bottleneck.

    Note:
        The BatchNorm layers see 1x1 maps. In training mode,
        ``BatchNorm2d`` therefore rejects a batch of size 1 (only one value
        per channel).
    """

    def __init__(self, channels, r=4):
        super(Avg_ChannelAttention_n, self).__init__()
        # channels // r == 0 would create zero-width convolutions. nn.Conv2d
        # skips bias init when fan_in == 0, so the second bias would be left
        # uninitialised and the output would not depend on x at all.
        hidden = max(1, channels // r)
        self.avg_channel = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # global average pooling: (B, C, H, W) -> (B, C, 1, 1)
            nn.Conv2d(channels, hidden, 1, 1, 0),  # (B, C, 1, 1) -> (B, hidden, 1, 1)
            nn.BatchNorm2d(hidden),
            nn.ReLU(True),
            nn.Conv2d(hidden, channels, 1, 1, 0),  # (B, hidden, 1, 1) -> (B, C, 1, 1)
            nn.BatchNorm2d(channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-channel weights of shape ``(B, C, 1, 1)`` for a ``(B, C, H, W)`` input."""
        return self.avg_channel(x)



class Avg_ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel attention (no normalisation).

    The input is global-average-pooled to ``(B, C, 1, 1)``. It then passes
    through a ``C -> hidden -> C`` bottleneck of 1x1 convolutions with a ReLU
    in between. A sigmoid finally maps the result to one weight in
    ``[0, 1]`` per channel.

    Args:
        channels: Number of input (and output) channels ``C``.
        r: Bottleneck reduction ratio. The hidden width is ``channels // r``,
            clamped to at least 1 so that ``channels < r`` still yields a
            working bottleneck.
    """

    def __init__(self, channels, r=4):
        super(Avg_ChannelAttention, self).__init__()
        # channels // r == 0 would create zero-width convolutions. nn.Conv2d
        # skips bias init when fan_in == 0, so the second bias would be left
        # uninitialised and the output would not depend on x at all.
        hidden = max(1, channels // r)
        self.avg_channel = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # global average pooling: (B, C, H, W) -> (B, C, 1, 1)
            nn.Conv2d(channels, hidden, 1, 1, 0),  # (B, C, 1, 1) -> (B, hidden, 1, 1)
            nn.ReLU(True),
            nn.Conv2d(hidden, channels, 1, 1, 0),  # (B, hidden, 1, 1) -> (B, C, 1, 1)
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-channel weights of shape ``(B, C, 1, 1)`` for a ``(B, C, H, W)`` input."""
        return self.avg_channel(x)


class AttnContrastLayer(nn.Module):
    """Contrast layer: attention-scaled centre response minus the full conv response.

    Both responses come from the same ``channels -> channels`` convolution
    ``self.conv``:

    * ``out_normal`` is ``self.conv(x)``.
    * ``out_center`` is an unpadded 1x1 convolution of ``x``, using each
      kernel of ``self.conv`` summed over its spatial taps.

    The result is ``theta * out_center - out_normal``, where ``theta`` is the
    ``(B, channels, 1, 1)`` output of ``Avg_ChannelAttention``. The two
    responses must have matching spatial sizes. With the defaults (3x3
    kernel, ``padding=1``, ``dilation=1``) they do.

    Args:
        channels: Number of input and output channels.
        kernel_size, stride, padding, dilation, groups, bias: Forwarded to
            ``nn.Conv2d``.
    """

    def __init__(self, channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=False):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.conv = nn.Conv2d(channels, channels, **conv_options)
        self.attn = Avg_ChannelAttention(channels)

    def forward(self, x):
        conv = self.conv
        full_response = conv(x)
        attention = self.attn(x)

        # Collapse every k x k kernel to one tap:
        # (C_out, C_in/groups, kh, kw) -> (C_out, C_in/groups, 1, 1).
        tap_sum = conv.weight.sum(dim=2).sum(dim=2)
        center_kernel = tap_sum.unsqueeze(-1).unsqueeze(-1)

        center_response = F.conv2d(
            x,
            center_kernel,
            bias=conv.bias,
            stride=conv.stride,
            padding=0,
            groups=conv.groups,
        )
        return attention * center_response - full_response



class AttnContrastLayer_n(nn.Module):
    """Contrast layer variant whose channel attention uses BatchNorm.

    The layer returns ``theta * out_center - out_normal``, where:

    * ``out_normal`` is ``self.conv(x)``.
    * ``out_center`` is an unpadded 1x1 convolution of ``x``, using the
      spatially summed kernels of ``self.conv``.
    * ``theta`` is the ``(B, channels, 1, 1)`` output of
      ``Avg_ChannelAttention_n``.

    Args:
        channels: Number of input and output channels.
        kernel_size, stride, padding, dilation, groups, bias: Forwarded to
            ``nn.Conv2d``.
    """

    def __init__(self, channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.attn = Avg_ChannelAttention_n(channels)

    def forward(self, x):
        layer = self.conv
        response = layer(x)
        scale = self.attn(x)

        # Sum each kernel over its spatial taps, keeping a 1x1 spatial footprint.
        collapsed = layer.weight.sum(dim=2).sum(dim=2).unsqueeze(-1).unsqueeze(-1)
        centre = F.conv2d(x, collapsed, bias=layer.bias, stride=layer.stride,
                          padding=0, groups=layer.groups)

        return scale * centre - response

class AttnContrastLayer_d(nn.Module):
    """Dilated contrast layer: centre response minus an attention-scaled conv response.

    Both responses come from the same ``channels -> channels`` convolution
    ``self.conv``:

    * ``out_normal`` is ``self.conv(x)`` (dilated by default).
    * ``out_center`` is an unpadded 1x1 convolution of ``x``, using each
      kernel of ``self.conv`` summed over its spatial taps.

    The result is ``out_center - theta * out_normal``, where ``theta`` is the
    ``(B, channels, 1, 1)`` output of ``Avg_ChannelAttention``.

    Args:
        channels: Number of input and output channels.
        kernel_size, stride, dilation, groups, bias: Forwarded to ``nn.Conv2d``.
        padding: Forwarded to ``nn.Conv2d`` when given. ``None`` (the default)
            selects ``dilation * (kernel_size - 1) // 2`` per spatial
            dimension. For odd kernels this keeps ``out_normal`` the same
            spatial size as ``out_center``, so the subtraction is well defined.
    """

    def __init__(self, channels, kernel_size=3, stride=1, padding=None, dilation=2, groups=1, bias=False):
        super(AttnContrastLayer_d, self).__init__()

        if padding is None:
            # A fixed padding of 1 with the 3x3 kernel and dilation=2 defaults
            # would make out_normal 2 pixels smaller than out_center (H+2-4 vs H),
            # and the final subtraction would fail. Use "same"-style padding instead.
            ks = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, kernel_size)
            dil = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
            padding = tuple(d * (k - 1) // 2 for k, d in zip(ks, dil))

        self.conv = nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)

        self.attn = Avg_ChannelAttention(channels)

    def forward(self, x):
        out_normal = self.conv(x)

        theta = self.attn(x)

        # (C_out, C_in/groups, kh, kw) -> (C_out, C_in/groups, 1, 1): one summed tap per kernel.
        kernel_w1 = self.conv.weight.sum(2).sum(2)
        kernel_w2 = kernel_w1[:, :, None, None]
        out_center = F.conv2d(input=x, weight=kernel_w2, bias=self.conv.bias, stride=self.conv.stride,
                              padding=0, groups=self.conv.groups)

        return out_center - theta * out_normal

class AtrousAttnWeight(nn.Module):
    """Thin wrapper that exposes ``Avg_ChannelAttention`` as its own module.

    Args:
        channels: Number of channels of the input given to ``forward``.
    """

    def __init__(self, channels):
        super().__init__()
        self.attn = Avg_ChannelAttention(channels)

    def forward(self, x):
        """Return the channel-attention weights that ``self.attn`` computes for ``x``."""
        weights = self.attn(x)
        return weights