AlexZou committed
Commit fa0e517
1 Parent(s): c77984b

Upload 8 files

Files changed (1):
  net/CMSFFT.py  +377 -0
net/CMSFFT.py ADDED
@@ -0,0 +1,377 @@
# -*- coding: utf-8 -*-
# @Author : Lintao Peng
# @File : CMSFFT.py
# Design based on the CTrans
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import math
import torch
import torch.nn as nn
import numpy as np
from torch.nn import Dropout, Softmax, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair


# Default hyper-parameters (hard-coded as keyword defaults below):
# KV_size = 480
# transformer.num_heads = 4
# transformer.num_layers = 4
# expand_ratio = 4


# Linear encoding (patch + position embeddings)
class Channel_Embeddings(nn.Module):
    """Construct the embeddings from patch and position embeddings."""
    def __init__(self, patchsize, img_size, in_channels):
        super().__init__()
        img_size = _pair(img_size)
        patch_size = _pair(patchsize)
        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=in_channels,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))
        self.dropout = Dropout(0.1)

    def forward(self, x):
        if x is None:
            return None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


# Feature reconstruction: map token sequences back to spatial feature maps
class Reconstruct(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super(Reconstruct, self).__init__()
        if kernel_size == 3:
            padding = 1
        else:
            padding = 0
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor

    def forward(self, x):
        if x is None:
            return None

        # reshape from (B, n_patch, hidden) to (B, hidden, h, w)
        B, n_patch, hidden = x.size()
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))  # assumes a square patch grid
        x = x.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = nn.Upsample(scale_factor=self.scale_factor)(x)

        out = self.conv(x)
        out = self.norm(out)
        out = self.activation(out)
        return out


class Attention_org(nn.Module):
    def __init__(self, vis, channel_num, KV_size=480, num_heads=4):
        super(Attention_org, self).__init__()
        self.vis = vis
        self.KV_size = KV_size
        self.channel_num = channel_num
        self.num_attention_heads = num_heads

        self.query1 = nn.ModuleList()
        self.query2 = nn.ModuleList()
        self.query3 = nn.ModuleList()
        self.query4 = nn.ModuleList()
        self.key = nn.ModuleList()
        self.value = nn.ModuleList()

        for _ in range(num_heads):
            query1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
            query2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
            query3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
            query4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
            key = nn.Linear(self.KV_size, self.KV_size, bias=False)
            value = nn.Linear(self.KV_size, self.KV_size, bias=False)
            # deepcopy gives each head its own independent projection weights,
            # fully detached from the modules created above
            self.query1.append(copy.deepcopy(query1))
            self.query2.append(copy.deepcopy(query2))
            self.query3.append(copy.deepcopy(query3))
            self.query4.append(copy.deepcopy(query4))
            self.key.append(copy.deepcopy(key))
            self.value.append(copy.deepcopy(value))
        self.psi = nn.InstanceNorm2d(self.num_attention_heads)
        self.softmax = Softmax(dim=3)
        self.out1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
        self.out2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
        self.out3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
        self.out4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
        self.attn_dropout = Dropout(0.1)
        self.proj_dropout = Dropout(0.1)

    def forward(self, emb1, emb2, emb3, emb4, emb_all):
        multi_head_Q1_list = []
        multi_head_Q2_list = []
        multi_head_Q3_list = []
        multi_head_Q4_list = []
        multi_head_K_list = []
        multi_head_V_list = []
        if emb1 is not None:
            for query1 in self.query1:
                Q1 = query1(emb1)
                multi_head_Q1_list.append(Q1)
        if emb2 is not None:
            for query2 in self.query2:
                Q2 = query2(emb2)
                multi_head_Q2_list.append(Q2)
        if emb3 is not None:
            for query3 in self.query3:
                Q3 = query3(emb3)
                multi_head_Q3_list.append(Q3)
        if emb4 is not None:
            for query4 in self.query4:
                Q4 = query4(emb4)
                multi_head_Q4_list.append(Q4)
        for key in self.key:
            K = key(emb_all)
            multi_head_K_list.append(K)
        for value in self.value:
            V = value(emb_all)
            multi_head_V_list.append(V)
        # print(len(multi_head_Q4_list))

        multi_head_Q1 = torch.stack(multi_head_Q1_list, dim=1) if emb1 is not None else None
        multi_head_Q2 = torch.stack(multi_head_Q2_list, dim=1) if emb2 is not None else None
        multi_head_Q3 = torch.stack(multi_head_Q3_list, dim=1) if emb3 is not None else None
        multi_head_Q4 = torch.stack(multi_head_Q4_list, dim=1) if emb4 is not None else None
        multi_head_K = torch.stack(multi_head_K_list, dim=1)
        multi_head_V = torch.stack(multi_head_V_list, dim=1)

        multi_head_Q1 = multi_head_Q1.transpose(-1, -2) if emb1 is not None else None
        multi_head_Q2 = multi_head_Q2.transpose(-1, -2) if emb2 is not None else None
        multi_head_Q3 = multi_head_Q3.transpose(-1, -2) if emb3 is not None else None
        multi_head_Q4 = multi_head_Q4.transpose(-1, -2) if emb4 is not None else None

        attention_scores1 = torch.matmul(multi_head_Q1, multi_head_K) if emb1 is not None else None
        attention_scores2 = torch.matmul(multi_head_Q2, multi_head_K) if emb2 is not None else None
        attention_scores3 = torch.matmul(multi_head_Q3, multi_head_K) if emb3 is not None else None
        attention_scores4 = torch.matmul(multi_head_Q4, multi_head_K) if emb4 is not None else None

        attention_scores1 = attention_scores1 / math.sqrt(self.KV_size) if emb1 is not None else None
        attention_scores2 = attention_scores2 / math.sqrt(self.KV_size) if emb2 is not None else None
        attention_scores3 = attention_scores3 / math.sqrt(self.KV_size) if emb3 is not None else None
        attention_scores4 = attention_scores4 / math.sqrt(self.KV_size) if emb4 is not None else None

        attention_probs1 = self.softmax(self.psi(attention_scores1)) if emb1 is not None else None
        attention_probs2 = self.softmax(self.psi(attention_scores2)) if emb2 is not None else None
        attention_probs3 = self.softmax(self.psi(attention_scores3)) if emb3 is not None else None
        attention_probs4 = self.softmax(self.psi(attention_scores4)) if emb4 is not None else None
        # print(attention_probs4.size())

        if self.vis:
            weights = []
            weights.append(attention_probs1.mean(1))
            weights.append(attention_probs2.mean(1))
            weights.append(attention_probs3.mean(1))
            weights.append(attention_probs4.mean(1))
        else:
            weights = None

        attention_probs1 = self.attn_dropout(attention_probs1) if emb1 is not None else None
        attention_probs2 = self.attn_dropout(attention_probs2) if emb2 is not None else None
        attention_probs3 = self.attn_dropout(attention_probs3) if emb3 is not None else None
        attention_probs4 = self.attn_dropout(attention_probs4) if emb4 is not None else None

        multi_head_V = multi_head_V.transpose(-1, -2)
        context_layer1 = torch.matmul(attention_probs1, multi_head_V) if emb1 is not None else None
        context_layer2 = torch.matmul(attention_probs2, multi_head_V) if emb2 is not None else None
        context_layer3 = torch.matmul(attention_probs3, multi_head_V) if emb3 is not None else None
        context_layer4 = torch.matmul(attention_probs4, multi_head_V) if emb4 is not None else None

        context_layer1 = context_layer1.permute(0, 3, 2, 1).contiguous() if emb1 is not None else None
        context_layer2 = context_layer2.permute(0, 3, 2, 1).contiguous() if emb2 is not None else None
        context_layer3 = context_layer3.permute(0, 3, 2, 1).contiguous() if emb3 is not None else None
        context_layer4 = context_layer4.permute(0, 3, 2, 1).contiguous() if emb4 is not None else None
        context_layer1 = context_layer1.mean(dim=3) if emb1 is not None else None
        context_layer2 = context_layer2.mean(dim=3) if emb2 is not None else None
        context_layer3 = context_layer3.mean(dim=3) if emb3 is not None else None
        context_layer4 = context_layer4.mean(dim=3) if emb4 is not None else None

        O1 = self.out1(context_layer1) if emb1 is not None else None
        O2 = self.out2(context_layer2) if emb2 is not None else None
        O3 = self.out3(context_layer3) if emb3 is not None else None
        O4 = self.out4(context_layer4) if emb4 is not None else None
        O1 = self.proj_dropout(O1) if emb1 is not None else None
        O2 = self.proj_dropout(O2) if emb2 is not None else None
        O3 = self.proj_dropout(O3) if emb3 is not None else None
        O4 = self.proj_dropout(O4) if emb4 is not None else None
        return O1, O2, O3, O4, weights


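# Shape walkthrough (illustrative only, derived from the forward pass above;
# B = batch, H = num_heads, P = n_patches, Ci = channel_num[i], and KV_size is the
# width of the concatenated embedding emb_all, i.e. C1+C2+C3+C4 when all four
# scales are present):
#   Qi        : (B, H, P, Ci) -> transpose(-1, -2)        -> (B, H, Ci, P)
#   K, V      : (B, H, P, KV_size)
#   scores_i  = Qi @ K / sqrt(KV_size)                    -> (B, H, Ci, KV_size)
#   probs_i   = softmax(InstanceNorm(scores_i), dim=3)
#   context_i = probs_i @ V^T                             -> (B, H, Ci, P)
#               -> permute + mean over heads              -> (B, P, Ci)
# so each scale's output Oi keeps the shape of its input embedding emb_i.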
class Mlp(nn.Module):
    def __init__(self, in_channel, mlp_channel):
        super(Mlp, self).__init__()
        self.fc1 = nn.Linear(in_channel, mlp_channel)
        self.fc2 = nn.Linear(mlp_channel, in_channel)
        self.act_fn = nn.GELU()
        self.dropout = Dropout(0.0)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Block_ViT(nn.Module):
    def __init__(self, vis, channel_num, expand_ratio=4, KV_size=480):
        super(Block_ViT, self).__init__()
        expand_ratio = 4
        self.attn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.attn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.attn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.attn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        self.attn_norm = LayerNorm(KV_size, eps=1e-6)
        self.channel_attn = Attention_org(vis, channel_num)

        self.ffn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.ffn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.ffn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.ffn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        self.ffn1 = Mlp(channel_num[0], channel_num[0] * expand_ratio)
        self.ffn2 = Mlp(channel_num[1], channel_num[1] * expand_ratio)
        self.ffn3 = Mlp(channel_num[2], channel_num[2] * expand_ratio)
        self.ffn4 = Mlp(channel_num[3], channel_num[3] * expand_ratio)

    def forward(self, emb1, emb2, emb3, emb4):
        org1 = emb1
        org2 = emb2
        org3 = emb3
        org4 = emb4
        # concatenate whichever scale embeddings are present along the channel dimension
        embcat = [emb for emb in (emb1, emb2, emb3, emb4) if emb is not None]

        emb_all = torch.cat(embcat, dim=2)
        cx1 = self.attn_norm1(emb1) if emb1 is not None else None
        cx2 = self.attn_norm2(emb2) if emb2 is not None else None
        cx3 = self.attn_norm3(emb3) if emb3 is not None else None
        cx4 = self.attn_norm4(emb4) if emb4 is not None else None
        emb_all = self.attn_norm(emb_all)
        cx1, cx2, cx3, cx4, weights = self.channel_attn(cx1, cx2, cx3, cx4, emb_all)
        # residual connection around the channel attention
        cx1 = org1 + cx1 if emb1 is not None else None
        cx2 = org2 + cx2 if emb2 is not None else None
        cx3 = org3 + cx3 if emb3 is not None else None
        cx4 = org4 + cx4 if emb4 is not None else None

        org1 = cx1
        org2 = cx2
        org3 = cx3
        org4 = cx4
        x1 = self.ffn_norm1(cx1) if emb1 is not None else None
        x2 = self.ffn_norm2(cx2) if emb2 is not None else None
        x3 = self.ffn_norm3(cx3) if emb3 is not None else None
        x4 = self.ffn_norm4(cx4) if emb4 is not None else None
        x1 = self.ffn1(x1) if emb1 is not None else None
        x2 = self.ffn2(x2) if emb2 is not None else None
        x3 = self.ffn3(x3) if emb3 is not None else None
        x4 = self.ffn4(x4) if emb4 is not None else None
        # residual connection around the feed-forward network
        x1 = x1 + org1 if emb1 is not None else None
        x2 = x2 + org2 if emb2 is not None else None
        x3 = x3 + org3 if emb3 is not None else None
        x4 = x4 + org4 if emb4 is not None else None

        return x1, x2, x3, x4, weights


class Encoder(nn.Module):
    def __init__(self, vis, channel_num, num_layers=4):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.encoder_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.encoder_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.encoder_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        for _ in range(num_layers):
            layer = Block_ViT(vis, channel_num)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, emb1, emb2, emb3, emb4):
        attn_weights = []
        for layer_block in self.layer:
            emb1, emb2, emb3, emb4, weights = layer_block(emb1, emb2, emb3, emb4)
            if self.vis:
                attn_weights.append(weights)
        emb1 = self.encoder_norm1(emb1) if emb1 is not None else None
        emb2 = self.encoder_norm2(emb2) if emb2 is not None else None
        emb3 = self.encoder_norm3(emb3) if emb3 is not None else None
        emb4 = self.encoder_norm4(emb4) if emb4 is not None else None
        return emb1, emb2, emb3, emb4, attn_weights


class ChannelTransformer(nn.Module):
    def __init__(self, vis=False, img_size=256, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):
        super().__init__()

        self.patchSize_1 = patchSize[0]
        self.patchSize_2 = patchSize[1]
        self.patchSize_3 = patchSize[2]
        self.patchSize_4 = patchSize[3]
        self.embeddings_1 = Channel_Embeddings(self.patchSize_1, img_size=img_size, in_channels=channel_num[0])
        self.embeddings_2 = Channel_Embeddings(self.patchSize_2, img_size=img_size // 2, in_channels=channel_num[1])
        self.embeddings_3 = Channel_Embeddings(self.patchSize_3, img_size=img_size // 4, in_channels=channel_num[2])
        self.embeddings_4 = Channel_Embeddings(self.patchSize_4, img_size=img_size // 8, in_channels=channel_num[3])
        self.encoder = Encoder(vis, channel_num)

        self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1, scale_factor=(self.patchSize_1, self.patchSize_1))
        self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1, scale_factor=(self.patchSize_2, self.patchSize_2))
        self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1, scale_factor=(self.patchSize_3, self.patchSize_3))
        self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1, scale_factor=(self.patchSize_4, self.patchSize_4))

    def forward(self, en1, en2, en3, en4):

        emb1 = self.embeddings_1(en1)
        emb2 = self.embeddings_2(en2)
        emb3 = self.embeddings_3(en3)
        emb4 = self.embeddings_4(en4)

        encoded1, encoded2, encoded3, encoded4, attn_weights = self.encoder(emb1, emb2, emb3, emb4)  # (B, n_patch, hidden)
        x1 = self.reconstruct_1(encoded1) if en1 is not None else None
        x2 = self.reconstruct_2(encoded2) if en2 is not None else None
        x3 = self.reconstruct_3(encoded3) if en3 is not None else None
        x4 = self.reconstruct_4(encoded4) if en4 is not None else None

        x1 = x1 + en1 if en1 is not None else None
        x2 = x2 + en2 if en2 is not None else None
        x3 = x3 + en3 if en3 is not None else None
        x4 = x4 + en4 if en4 is not None else None

        return x1, x2, x3, x4, attn_weights
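A minimal smoke-test sketch (not part of the committed file) for exercising ChannelTransformer end to end. The channel widths [32, 64, 128, 256] are an assumption chosen so that their sum matches the hard-coded default KV_size of 480 used by Block_ViT and Attention_org; the input tensors mimic a UNet-style encoder at strides 1, 2, 4 and 8, and the import path assumes this file is importable as net.CMSFFT.

import torch
from net.CMSFFT import ChannelTransformer

# channel_num must sum to the default KV_size (480) so that the LayerNorm and
# Linear layers applied to the concatenated embedding line up
model = ChannelTransformer(vis=False, img_size=256,
                           channel_num=[32, 64, 128, 256],
                           patchSize=[32, 16, 8, 4]).eval()

en1 = torch.randn(1, 32, 256, 256)   # encoder stage 1 (full resolution)
en2 = torch.randn(1, 64, 128, 128)   # encoder stage 2 (1/2 resolution)
en3 = torch.randn(1, 128, 64, 64)    # encoder stage 3 (1/4 resolution)
en4 = torch.randn(1, 256, 32, 32)    # encoder stage 4 (1/8 resolution)

with torch.no_grad():
    x1, x2, x3, x4, attn = model(en1, en2, en3, en4)

# each output keeps its input shape thanks to Reconstruct plus the skip additions
print(x1.shape, x2.shape, x3.shape, x4.shape)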