asdasdasdasd committed on
Commit
2f0673d
1 Parent(s): b0df336

Upload attention.py

Files changed (1)
  1. attention.py +252 -0
attention.py ADDED
@@ -0,0 +1,252 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Channel Attention and Spatial Attention from
Woo, S., Park, J., Lee, J.Y., & Kweon, I. CBAM: Convolutional Block Attention Module. ECCV 2018.
"""


class ChannelAttention(nn.Module):
    """Channel attention gate from CBAM: global avg/max pooling -> shared MLP -> per-channel sigmoid weights."""

    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)


class SpatialAttention(nn.Module):
    """Spatial attention gate from CBAM: channel-wise avg/max maps -> conv -> per-pixel sigmoid weights."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)

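
# Illustrative usage sketch (not from the original attention.py): CBAM applies the
# two gates above sequentially, multiplying each onto the feature map. The names
# `feat`, `ca`, `sa` and this helper are hypothetical, added only for illustration.
def _cbam_usage_example():
    feat = torch.rand(2, 64, 32, 32)   # B x C x H x W feature map
    ca = ChannelAttention(in_planes=64)
    sa = SpatialAttention(kernel_size=7)
    out = ca(feat) * feat              # channel gate (B x C x 1 x 1), broadcast over H, W
    out = sa(out) * out                # spatial gate (B x 1 x H x W), broadcast over channels
    return out                         # same shape as feat
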

"""
The following modules are modified from https://github.com/heykeetae/Self-Attention-GAN
"""


class Self_Attn(nn.Module):
    """Self attention layer."""

    def __init__(self, in_dim, out_dim=None, add=False, ratio=8):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim
        self.add = add
        if out_dim is None:
            out_dim = in_dim
        self.out_dim = out_dim
        # self.activation = activation

        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // ratio, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // ratio, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x H x W)
        returns :
            out : attention-weighted features, residually added to the input
                  when add=True (B x out_dim x H x W)
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(
            m_batchsize, -1, height * width).permute(0, 2, 1)  # B x N x C', N = H*W
        proj_key = self.key_conv(x).view(
            m_batchsize, -1, height * width)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # B x N x N
        proj_value = self.value_conv(x).view(
            m_batchsize, -1, height * width)  # B x out_dim x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, self.out_dim, height, width)

        if self.add:
            out = self.gamma * out + x
        else:
            out = self.gamma * out
        return out  # , attention

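
# Illustrative usage sketch (not from the original attention.py): Self_Attn keeps
# the spatial size and maps in_dim -> out_dim channels; with add=True the attended
# features are residually added to the input, so out_dim must equal in_dim.
# The helper name and tensor sizes are hypothetical.
def _self_attn_usage_example():
    feat = torch.rand(2, 64, 16, 16)
    attn = Self_Attn(in_dim=64, out_dim=64, add=True)
    out = attn(feat)                   # (2, 64, 16, 16)
    return out
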

class CrossModalAttention(nn.Module):
    """Cross-modal attention (CMA) layer: queries from x, keys from y."""

    def __init__(self, in_dim, activation=None, ratio=8, cross_value=True):
        super(CrossModalAttention, self).__init__()
        self.channel_in = in_dim
        self.activation = activation
        self.cross_value = cross_value

        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // ratio, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // ratio, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x, y):
        """
        inputs :
            x, y : input feature maps (B x C x H x W)
        returns :
            out : cross-attention value + input feature x (B x C x H x W)
        """
        B, C, H, W = x.size()

        proj_query = self.query_conv(x).view(
            B, -1, H * W).permute(0, 2, 1)  # B x HW x C'
        proj_key = self.key_conv(y).view(
            B, -1, H * W)  # B x C' x HW
        energy = torch.bmm(proj_query, proj_key)  # B x HW x HW
        attention = self.softmax(energy)  # B x HW x HW
        if self.cross_value:
            proj_value = self.value_conv(y).view(
                B, -1, H * W)  # B x C x HW
        else:
            proj_value = self.value_conv(x).view(
                B, -1, H * W)  # B x C x HW

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)

        out = self.gamma * out + x

        if self.activation is not None:
            out = self.activation(out)

        return out  # , attention

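
# Illustrative usage sketch (not from the original attention.py): queries come from
# x, keys from y, and values from y when cross_value=True, so the module injects
# y-information into x at spatially corresponding positions. Names are hypothetical.
def _cross_modal_usage_example():
    x = torch.rand(2, 64, 16, 16)      # features from modality one
    y = torch.rand(2, 64, 16, 16)      # features from modality two
    cma = CrossModalAttention(in_dim=64, cross_value=True)
    out = cma(x, y)                    # (2, 64, 16, 16), y fused into x
    return out
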

class DualCrossModalAttention(nn.Module):
    """Dual CMA attention layer: attends in both directions between x and y."""

    def __init__(self, in_dim, activation=None, size=16, ratio=8, ret_att=False):
        super(DualCrossModalAttention, self).__init__()
        self.channel_in = in_dim
        self.activation = activation
        self.ret_att = ret_att

        # query/key projections (key_conv_share is applied on top of both)
        self.key_conv1 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // ratio, kernel_size=1)
        self.key_conv2 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // ratio, kernel_size=1)
        self.key_conv_share = nn.Conv2d(
            in_channels=in_dim // ratio, out_channels=in_dim // ratio, kernel_size=1)

        # `size` is the spatial resolution (H = W = size) the Linear layers expect
        self.linear1 = nn.Linear(size * size, size * size)
        self.linear2 = nn.Linear(size * size, size * size)

        # separated value conv
        self.value_conv1 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma1 = nn.Parameter(torch.zeros(1))

        self.value_conv2 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma2 = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x, y):
        """
        inputs :
            x, y : input feature maps (B x C x H x W)
        returns :
            out_x, out_y : cross-attended features added back onto x and y,
                           plus both attention maps (B x HW x HW) if ret_att=True
        """
        B, C, H, W = x.size()

        def _get_att(a, b):
            proj_key1 = self.key_conv_share(self.key_conv1(a)).view(
                B, -1, H * W).permute(0, 2, 1)  # B x HW x C'
            proj_key2 = self.key_conv_share(self.key_conv2(b)).view(
                B, -1, H * W)  # B x C' x HW
            energy = torch.bmm(proj_key1, proj_key2)  # B x HW x HW

            attention1 = self.softmax(self.linear1(energy))
            attention2 = self.softmax(self.linear2(
                energy.permute(0, 2, 1)))  # B x HW x HW

            return attention1, attention2

        att_y_on_x, att_x_on_y = _get_att(x, y)
        proj_value_y_on_x = self.value_conv2(y).view(
            B, -1, H * W)  # B x C x HW
        out_y_on_x = torch.bmm(proj_value_y_on_x, att_y_on_x.permute(0, 2, 1))
        out_y_on_x = out_y_on_x.view(B, C, H, W)
        out_x = self.gamma1 * out_y_on_x + x

        proj_value_x_on_y = self.value_conv1(x).view(
            B, -1, H * W)  # B x C x HW
        out_x_on_y = torch.bmm(proj_value_x_on_y, att_x_on_y.permute(0, 2, 1))
        out_x_on_y = out_x_on_y.view(B, C, H, W)
        out_y = self.gamma2 * out_x_on_y + y

        if self.ret_att:
            return out_x, out_y, att_y_on_x, att_x_on_y

        return out_x, out_y  # , attention


if __name__ == "__main__":
    x = torch.rand(10, 768, 16, 16)
    y = torch.rand(10, 768, 16, 16)
    dcma = DualCrossModalAttention(768, ret_att=True)
    out_x, out_y, att_y_on_x, att_x_on_y = dcma(x, y)
    print(out_y.size())        # torch.Size([10, 768, 16, 16])
    print(att_x_on_y.size())   # torch.Size([10, 256, 256])