AlexZou commited on
Commit
7eb6194
·
1 Parent(s): 676735b

Upload 17 files

Browse files
net/CMSFFT.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Author : Lintao Peng
3
+ # @File : CMSFFT.py
4
+ # coding=utf-8
5
+ # Design based on the CTrans
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+ import copy
10
+ import logging
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ from torch.nn import Dropout, Softmax, Conv2d, LayerNorm
16
+ from torch.nn.modules.utils import _pair
17
+
18
+
19
+ #KV_size = 480
20
+ #transformer.num_heads = 4
21
+ #transformer.num_layers = 4
22
+ #expand_ratio = 4
23
+
24
+
25
+
26
+ #线性编码
27
+ class Channel_Embeddings(nn.Module):
28
+ """Construct the embeddings from patch, position embeddings.
29
+ """
30
+ def __init__(self, patchsize, img_size, in_channels):
31
+ super().__init__()
32
+ img_size = _pair(img_size)
33
+ patch_size = _pair(patchsize)
34
+ n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
35
+
36
+ self.patch_embeddings = Conv2d(in_channels=in_channels,
37
+ out_channels=in_channels,
38
+ kernel_size=patch_size,
39
+ stride=patch_size)
40
+ self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))
41
+ self.dropout = Dropout(0.1)
42
+
43
+ def forward(self, x):
44
+ if x is None:
45
+ return None
46
+ x = self.patch_embeddings(x) # (B, hidden,n_patches^(1/2), n_patches^(1/2))
47
+ x = x.flatten(2)
48
+ x = x.transpose(-1, -2) # (B, n_patches, hidden)
49
+ embeddings = x + self.position_embeddings
50
+ embeddings = self.dropout(embeddings)
51
+ return embeddings
52
+
53
+
54
+ #特征重组
55
+ class Reconstruct(nn.Module):
56
+ def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
57
+ super(Reconstruct, self).__init__()
58
+ if kernel_size == 3:
59
+ padding = 1
60
+ else:
61
+ padding = 0
62
+ self.conv = nn.Conv2d(in_channels, out_channels,kernel_size=kernel_size, padding=padding)
63
+ self.norm = nn.BatchNorm2d(out_channels)
64
+ self.activation = nn.ReLU(inplace=True)
65
+ self.scale_factor = scale_factor
66
+
67
+ def forward(self, x):
68
+ if x is None:
69
+ return None
70
+
71
+ # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
72
+ B, n_patch, hidden = x.size()
73
+ h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
74
+ x = x.permute(0, 2, 1)
75
+ x = x.contiguous().view(B, hidden, h, w)
76
+ x = nn.Upsample(scale_factor=self.scale_factor)(x)
77
+
78
+ out = self.conv(x)
79
+ out = self.norm(out)
80
+ out = self.activation(out)
81
+ return out
82
+
83
+ class Attention_org(nn.Module):
84
+ def __init__(self, vis,channel_num, KV_size=480, num_heads=4):
85
+ super(Attention_org, self).__init__()
86
+ self.vis = vis
87
+ self.KV_size = KV_size
88
+ self.channel_num = channel_num
89
+ self.num_attention_heads = num_heads
90
+
91
+ self.query1 = nn.ModuleList()
92
+ self.query2 = nn.ModuleList()
93
+ self.query3 = nn.ModuleList()
94
+ self.query4 = nn.ModuleList()
95
+ self.key = nn.ModuleList()
96
+ self.value = nn.ModuleList()
97
+
98
+ for _ in range(num_heads):
99
+ query1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
100
+ query2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
101
+ query3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
102
+ query4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
103
+ key = nn.Linear( self.KV_size, self.KV_size, bias=False)
104
+ value = nn.Linear(self.KV_size, self.KV_size, bias=False)
105
+ #把所有的值都重新复制一遍,deepcopy为深复制,完全脱离原来的值,即将被复制对象完全再复制一遍作为独立的新个体单独存在
106
+ self.query1.append(copy.deepcopy(query1))
107
+ self.query2.append(copy.deepcopy(query2))
108
+ self.query3.append(copy.deepcopy(query3))
109
+ self.query4.append(copy.deepcopy(query4))
110
+ self.key.append(copy.deepcopy(key))
111
+ self.value.append(copy.deepcopy(value))
112
+ self.psi = nn.InstanceNorm2d(self.num_attention_heads)
113
+ self.softmax = Softmax(dim=3)
114
+ self.out1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
115
+ self.out2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
116
+ self.out3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
117
+ self.out4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
118
+ self.attn_dropout = Dropout(0.1)
119
+ self.proj_dropout = Dropout(0.1)
120
+
121
+
122
+
123
+ def forward(self, emb1,emb2,emb3,emb4, emb_all):
124
+ multi_head_Q1_list = []
125
+ multi_head_Q2_list = []
126
+ multi_head_Q3_list = []
127
+ multi_head_Q4_list = []
128
+ multi_head_K_list = []
129
+ multi_head_V_list = []
130
+ if emb1 is not None:
131
+ for query1 in self.query1:
132
+ Q1 = query1(emb1)
133
+ multi_head_Q1_list.append(Q1)
134
+ if emb2 is not None:
135
+ for query2 in self.query2:
136
+ Q2 = query2(emb2)
137
+ multi_head_Q2_list.append(Q2)
138
+ if emb3 is not None:
139
+ for query3 in self.query3:
140
+ Q3 = query3(emb3)
141
+ multi_head_Q3_list.append(Q3)
142
+ if emb4 is not None:
143
+ for query4 in self.query4:
144
+ Q4 = query4(emb4)
145
+ multi_head_Q4_list.append(Q4)
146
+ for key in self.key:
147
+ K = key(emb_all)
148
+ multi_head_K_list.append(K)
149
+ for value in self.value:
150
+ V = value(emb_all)
151
+ multi_head_V_list.append(V)
152
+ # print(len(multi_head_Q4_list))
153
+
154
+ multi_head_Q1 = torch.stack(multi_head_Q1_list, dim=1) if emb1 is not None else None
155
+ multi_head_Q2 = torch.stack(multi_head_Q2_list, dim=1) if emb2 is not None else None
156
+ multi_head_Q3 = torch.stack(multi_head_Q3_list, dim=1) if emb3 is not None else None
157
+ multi_head_Q4 = torch.stack(multi_head_Q4_list, dim=1) if emb4 is not None else None
158
+ multi_head_K = torch.stack(multi_head_K_list, dim=1)
159
+ multi_head_V = torch.stack(multi_head_V_list, dim=1)
160
+
161
+ multi_head_Q1 = multi_head_Q1.transpose(-1, -2) if emb1 is not None else None
162
+ multi_head_Q2 = multi_head_Q2.transpose(-1, -2) if emb2 is not None else None
163
+ multi_head_Q3 = multi_head_Q3.transpose(-1, -2) if emb3 is not None else None
164
+ multi_head_Q4 = multi_head_Q4.transpose(-1, -2) if emb4 is not None else None
165
+
166
+ attention_scores1 = torch.matmul(multi_head_Q1, multi_head_K) if emb1 is not None else None
167
+ attention_scores2 = torch.matmul(multi_head_Q2, multi_head_K) if emb2 is not None else None
168
+ attention_scores3 = torch.matmul(multi_head_Q3, multi_head_K) if emb3 is not None else None
169
+ attention_scores4 = torch.matmul(multi_head_Q4, multi_head_K) if emb4 is not None else None
170
+
171
+ attention_scores1 = attention_scores1 / math.sqrt(self.KV_size) if emb1 is not None else None
172
+ attention_scores2 = attention_scores2 / math.sqrt(self.KV_size) if emb2 is not None else None
173
+ attention_scores3 = attention_scores3 / math.sqrt(self.KV_size) if emb3 is not None else None
174
+ attention_scores4 = attention_scores4 / math.sqrt(self.KV_size) if emb4 is not None else None
175
+
176
+ attention_probs1 = self.softmax(self.psi(attention_scores1)) if emb1 is not None else None
177
+ attention_probs2 = self.softmax(self.psi(attention_scores2)) if emb2 is not None else None
178
+ attention_probs3 = self.softmax(self.psi(attention_scores3)) if emb3 is not None else None
179
+ attention_probs4 = self.softmax(self.psi(attention_scores4)) if emb4 is not None else None
180
+ # print(attention_probs4.size())
181
+
182
+ if self.vis:
183
+ weights = []
184
+ weights.append(attention_probs1.mean(1))
185
+ weights.append(attention_probs2.mean(1))
186
+ weights.append(attention_probs3.mean(1))
187
+ weights.append(attention_probs4.mean(1))
188
+ else: weights=None
189
+
190
+ attention_probs1 = self.attn_dropout(attention_probs1) if emb1 is not None else None
191
+ attention_probs2 = self.attn_dropout(attention_probs2) if emb2 is not None else None
192
+ attention_probs3 = self.attn_dropout(attention_probs3) if emb3 is not None else None
193
+ attention_probs4 = self.attn_dropout(attention_probs4) if emb4 is not None else None
194
+
195
+ multi_head_V = multi_head_V.transpose(-1, -2)
196
+ context_layer1 = torch.matmul(attention_probs1, multi_head_V) if emb1 is not None else None
197
+ context_layer2 = torch.matmul(attention_probs2, multi_head_V) if emb2 is not None else None
198
+ context_layer3 = torch.matmul(attention_probs3, multi_head_V) if emb3 is not None else None
199
+ context_layer4 = torch.matmul(attention_probs4, multi_head_V) if emb4 is not None else None
200
+
201
+ context_layer1 = context_layer1.permute(0, 3, 2, 1).contiguous() if emb1 is not None else None
202
+ context_layer2 = context_layer2.permute(0, 3, 2, 1).contiguous() if emb2 is not None else None
203
+ context_layer3 = context_layer3.permute(0, 3, 2, 1).contiguous() if emb3 is not None else None
204
+ context_layer4 = context_layer4.permute(0, 3, 2, 1).contiguous() if emb4 is not None else None
205
+ context_layer1 = context_layer1.mean(dim=3) if emb1 is not None else None
206
+ context_layer2 = context_layer2.mean(dim=3) if emb2 is not None else None
207
+ context_layer3 = context_layer3.mean(dim=3) if emb3 is not None else None
208
+ context_layer4 = context_layer4.mean(dim=3) if emb4 is not None else None
209
+
210
+ O1 = self.out1(context_layer1) if emb1 is not None else None
211
+ O2 = self.out2(context_layer2) if emb2 is not None else None
212
+ O3 = self.out3(context_layer3) if emb3 is not None else None
213
+ O4 = self.out4(context_layer4) if emb4 is not None else None
214
+ O1 = self.proj_dropout(O1) if emb1 is not None else None
215
+ O2 = self.proj_dropout(O2) if emb2 is not None else None
216
+ O3 = self.proj_dropout(O3) if emb3 is not None else None
217
+ O4 = self.proj_dropout(O4) if emb4 is not None else None
218
+ return O1,O2,O3,O4, weights
219
+
220
+
221
+
222
+
223
+ class Mlp(nn.Module):
224
+ def __init__(self, in_channel, mlp_channel):
225
+ super(Mlp, self).__init__()
226
+ self.fc1 = nn.Linear(in_channel, mlp_channel)
227
+ self.fc2 = nn.Linear(mlp_channel, in_channel)
228
+ self.act_fn = nn.GELU()
229
+ self.dropout = Dropout(0.0)
230
+ self._init_weights()
231
+
232
+ def _init_weights(self):
233
+ nn.init.xavier_uniform_(self.fc1.weight)
234
+ nn.init.xavier_uniform_(self.fc2.weight)
235
+ nn.init.normal_(self.fc1.bias, std=1e-6)
236
+ nn.init.normal_(self.fc2.bias, std=1e-6)
237
+
238
+ def forward(self, x):
239
+ x = self.fc1(x)
240
+ x = self.act_fn(x)
241
+ x = self.dropout(x)
242
+ x = self.fc2(x)
243
+ x = self.dropout(x)
244
+ return x
245
+
246
+ class Block_ViT(nn.Module):
247
+ def __init__(self, vis, channel_num, expand_ratio=4,KV_size=480):
248
+ super(Block_ViT, self).__init__()
249
+ expand_ratio = 4
250
+ self.attn_norm1 = LayerNorm(channel_num[0],eps=1e-6)
251
+ self.attn_norm2 = LayerNorm(channel_num[1],eps=1e-6)
252
+ self.attn_norm3 = LayerNorm(channel_num[2],eps=1e-6)
253
+ self.attn_norm4 = LayerNorm(channel_num[3],eps=1e-6)
254
+ self.attn_norm = LayerNorm(KV_size,eps=1e-6)
255
+ self.channel_attn = Attention_org(vis, channel_num)
256
+
257
+ self.ffn_norm1 = LayerNorm(channel_num[0],eps=1e-6)
258
+ self.ffn_norm2 = LayerNorm(channel_num[1],eps=1e-6)
259
+ self.ffn_norm3 = LayerNorm(channel_num[2],eps=1e-6)
260
+ self.ffn_norm4 = LayerNorm(channel_num[3],eps=1e-6)
261
+ self.ffn1 = Mlp(channel_num[0],channel_num[0]*expand_ratio)
262
+ self.ffn2 = Mlp(channel_num[1],channel_num[1]*expand_ratio)
263
+ self.ffn3 = Mlp(channel_num[2],channel_num[2]*expand_ratio)
264
+ self.ffn4 = Mlp(channel_num[3],channel_num[3]*expand_ratio)
265
+
266
+
267
+ def forward(self, emb1,emb2,emb3,emb4):
268
+ embcat = []
269
+ org1 = emb1
270
+ org2 = emb2
271
+ org3 = emb3
272
+ org4 = emb4
273
+ for i in range(4):
274
+ var_name = "emb"+str(i+1) #emb1,emb2,emb3,emb4
275
+ tmp_var = locals()[var_name]
276
+ if tmp_var is not None:
277
+ embcat.append(tmp_var)
278
+
279
+ emb_all = torch.cat(embcat,dim=2)
280
+ cx1 = self.attn_norm1(emb1) if emb1 is not None else None
281
+ cx2 = self.attn_norm2(emb2) if emb2 is not None else None
282
+ cx3 = self.attn_norm3(emb3) if emb3 is not None else None
283
+ cx4 = self.attn_norm4(emb4) if emb4 is not None else None
284
+ emb_all = self.attn_norm(emb_all)
285
+ cx1,cx2,cx3,cx4, weights = self.channel_attn(cx1,cx2,cx3,cx4,emb_all)
286
+ #残差
287
+ cx1 = org1 + cx1 if emb1 is not None else None
288
+ cx2 = org2 + cx2 if emb2 is not None else None
289
+ cx3 = org3 + cx3 if emb3 is not None else None
290
+ cx4 = org4 + cx4 if emb4 is not None else None
291
+
292
+ org1 = cx1
293
+ org2 = cx2
294
+ org3 = cx3
295
+ org4 = cx4
296
+ x1 = self.ffn_norm1(cx1) if emb1 is not None else None
297
+ x2 = self.ffn_norm2(cx2) if emb2 is not None else None
298
+ x3 = self.ffn_norm3(cx3) if emb3 is not None else None
299
+ x4 = self.ffn_norm4(cx4) if emb4 is not None else None
300
+ x1 = self.ffn1(x1) if emb1 is not None else None
301
+ x2 = self.ffn2(x2) if emb2 is not None else None
302
+ x3 = self.ffn3(x3) if emb3 is not None else None
303
+ x4 = self.ffn4(x4) if emb4 is not None else None
304
+ #残差
305
+ x1 = x1 + org1 if emb1 is not None else None
306
+ x2 = x2 + org2 if emb2 is not None else None
307
+ x3 = x3 + org3 if emb3 is not None else None
308
+ x4 = x4 + org4 if emb4 is not None else None
309
+
310
+ return x1, x2, x3, x4, weights
311
+
312
+
313
+ class Encoder(nn.Module):
314
+ def __init__(self, vis, channel_num, num_layers=4):
315
+ super(Encoder, self).__init__()
316
+ self.vis = vis
317
+ self.layer = nn.ModuleList()
318
+ self.encoder_norm1 = LayerNorm(channel_num[0],eps=1e-6)
319
+ self.encoder_norm2 = LayerNorm(channel_num[1],eps=1e-6)
320
+ self.encoder_norm3 = LayerNorm(channel_num[2],eps=1e-6)
321
+ self.encoder_norm4 = LayerNorm(channel_num[3],eps=1e-6)
322
+ for _ in range(num_layers):
323
+ layer = Block_ViT(vis, channel_num)
324
+ self.layer.append(copy.deepcopy(layer))
325
+
326
+ def forward(self, emb1,emb2,emb3,emb4):
327
+ attn_weights = []
328
+ for layer_block in self.layer:
329
+ emb1,emb2,emb3,emb4, weights = layer_block(emb1,emb2,emb3,emb4)
330
+ if self.vis:
331
+ attn_weights.append(weights)
332
+ emb1 = self.encoder_norm1(emb1) if emb1 is not None else None
333
+ emb2 = self.encoder_norm2(emb2) if emb2 is not None else None
334
+ emb3 = self.encoder_norm3(emb3) if emb3 is not None else None
335
+ emb4 = self.encoder_norm4(emb4) if emb4 is not None else None
336
+ return emb1,emb2,emb3,emb4, attn_weights
337
+
338
+
339
+ class ChannelTransformer(nn.Module):
340
+ def __init__(self, vis=False, img_size=256, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):
341
+ super().__init__()
342
+
343
+ self.patchSize_1 = patchSize[0]
344
+ self.patchSize_2 = patchSize[1]
345
+ self.patchSize_3 = patchSize[2]
346
+ self.patchSize_4 = patchSize[3]
347
+ self.embeddings_1 = Channel_Embeddings(self.patchSize_1, img_size=img_size, in_channels=channel_num[0])
348
+ self.embeddings_2 = Channel_Embeddings(self.patchSize_2, img_size=img_size//2, in_channels=channel_num[1])
349
+ self.embeddings_3 = Channel_Embeddings(self.patchSize_3, img_size=img_size//4, in_channels=channel_num[2])
350
+ self.embeddings_4 = Channel_Embeddings(self.patchSize_4, img_size=img_size//8, in_channels=channel_num[3])
351
+ self.encoder = Encoder( vis, channel_num)
352
+
353
+ self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1,scale_factor=(self.patchSize_1,self.patchSize_1))
354
+ self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1,scale_factor=(self.patchSize_2,self.patchSize_2))
355
+ self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1,scale_factor=(self.patchSize_3,self.patchSize_3))
356
+ self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1,scale_factor=(self.patchSize_4,self.patchSize_4))
357
+
358
+ def forward(self,en1,en2,en3,en4):
359
+
360
+ emb1 = self.embeddings_1(en1)
361
+ emb2 = self.embeddings_2(en2)
362
+ emb3 = self.embeddings_3(en3)
363
+ emb4 = self.embeddings_4(en4)
364
+
365
+ encoded1, encoded2, encoded3, encoded4, attn_weights = self.encoder(emb1,emb2,emb3,emb4) # (B, n_patch, hidden)
366
+ x1 = self.reconstruct_1(encoded1) if en1 is not None else None
367
+ x2 = self.reconstruct_2(encoded2) if en2 is not None else None
368
+ x3 = self.reconstruct_3(encoded3) if en3 is not None else None
369
+ x4 = self.reconstruct_4(encoded4) if en4 is not None else None
370
+
371
+ x1 = x1 + en1 if en1 is not None else None
372
+ x2 = x2 + en2 if en2 is not None else None
373
+ x3 = x3 + en3 if en3 is not None else None
374
+ x4 = x4 + en4 if en4 is not None else None
375
+
376
+ return x1, x2, x3, x4, attn_weights
377
+
net/IntmdSequential.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
+ class IntermediateSequential(nn.Sequential):
5
+ def __init__(self, *args, return_intermediate=False):
6
+ super().__init__(*args)
7
+ self.return_intermediate = return_intermediate
8
+
9
+ def forward(self, input):
10
+ if not self.return_intermediate:
11
+ return super().forward(input)
12
+
13
+ intermediate_outputs = {}
14
+ output = input
15
+ for name, module in self.named_children():
16
+ output = intermediate_outputs[name] = module(output)
17
+
18
+ return output, intermediate_outputs
19
+
net/PositionalEncoding.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ #实现了位置编码
6
+ class FixedPositionalEncoding(nn.Module):
7
+ def __init__(self, embedding_dim, max_length=512):
8
+ super(FixedPositionalEncoding, self).__init__()
9
+
10
+ pe = torch.zeros(max_length, embedding_dim)
11
+ position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
12
+ div_term = torch.exp(
13
+ torch.arange(0, embedding_dim, 2).float()
14
+ * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
15
+ )
16
+ pe[:, 0::2] = torch.sin(position * div_term)
17
+ pe[:, 1::2] = torch.cos(position * div_term)
18
+ pe = pe.unsqueeze(0).transpose(0, 1)
19
+ self.register_buffer('pe', pe)
20
+
21
+ def forward(self, x):
22
+ x = x + self.pe[: x.size(0), :]
23
+ return x
24
+
25
+
26
+ class LearnedPositionalEncoding(nn.Module):
27
+ def __init__(self, max_position_embeddings, embedding_dim, seq_length):
28
+ super(LearnedPositionalEncoding, self).__init__()
29
+
30
+ self.position_embeddings = nn.Parameter(torch.zeros(1, 256, 512)) #8x
31
+
32
+ def forward(self, x, position_ids=None):
33
+
34
+ position_embeddings = self.position_embeddings
35
+ return x + position_embeddings
net/SGFMT.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Author : Lintao Peng
3
+ # @File : SGFMT.py
4
+ # coding=utf-8
5
+ # Design based on the Vit
6
+
7
+ import torch.nn as nn
8
+ from net.IntmdSequential import IntermediateSequential
9
+
10
+
11
+ #实现了自注意力机制,相当于unet的bottleneck层
12
+ class SelfAttention(nn.Module):
13
+ def __init__(
14
+ self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
15
+ ):
16
+ super().__init__()
17
+ self.num_heads = heads
18
+ head_dim = dim // heads
19
+ self.scale = qk_scale or head_dim ** -0.5
20
+
21
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
22
+ self.attn_drop = nn.Dropout(dropout_rate)
23
+ self.proj = nn.Linear(dim, dim)
24
+ self.proj_drop = nn.Dropout(dropout_rate)
25
+
26
+ def forward(self, x):
27
+ B, N, C = x.shape
28
+ qkv = (
29
+ self.qkv(x)
30
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
31
+ .permute(2, 0, 3, 1, 4)
32
+ )
33
+ q, k, v = (
34
+ qkv[0],
35
+ qkv[1],
36
+ qkv[2],
37
+ ) # make torchscript happy (cannot use tensor as tuple)
38
+
39
+ attn = (q @ k.transpose(-2, -1)) * self.scale
40
+ attn = attn.softmax(dim=-1)
41
+ attn = self.attn_drop(attn)
42
+
43
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
44
+ x = self.proj(x)
45
+ x = self.proj_drop(x)
46
+ return x
47
+
48
+
49
+ class Residual(nn.Module):
50
+ def __init__(self, fn):
51
+ super().__init__()
52
+ self.fn = fn
53
+
54
+ def forward(self, x):
55
+ return self.fn(x) + x
56
+
57
+
58
+ class PreNorm(nn.Module):
59
+ def __init__(self, dim, fn):
60
+ super().__init__()
61
+ self.norm = nn.LayerNorm(dim)
62
+ self.fn = fn
63
+
64
+ def forward(self, x):
65
+ return self.fn(self.norm(x))
66
+
67
+
68
+ class PreNormDrop(nn.Module):
69
+ def __init__(self, dim, dropout_rate, fn):
70
+ super().__init__()
71
+ self.norm = nn.LayerNorm(dim)
72
+ self.dropout = nn.Dropout(p=dropout_rate)
73
+ self.fn = fn
74
+
75
+ def forward(self, x):
76
+ return self.dropout(self.fn(self.norm(x)))
77
+
78
+
79
+ class FeedForward(nn.Module):
80
+ def __init__(self, dim, hidden_dim, dropout_rate):
81
+ super().__init__()
82
+ self.net = nn.Sequential(
83
+ nn.Linear(dim, hidden_dim),
84
+ nn.GELU(),
85
+ nn.Dropout(p=dropout_rate),
86
+ nn.Linear(hidden_dim, dim),
87
+ nn.Dropout(p=dropout_rate),
88
+ )
89
+
90
+ def forward(self, x):
91
+ return self.net(x)
92
+
93
+
94
+ class TransformerModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ dim, #512
98
+ depth, #4
99
+ heads, #8
100
+ mlp_dim, #4096
101
+ dropout_rate=0.1,
102
+ attn_dropout_rate=0.1,
103
+ ):
104
+ super().__init__()
105
+ layers = []
106
+ for _ in range(depth):
107
+ layers.extend(
108
+ [
109
+ Residual(
110
+ PreNormDrop(
111
+ dim,
112
+ dropout_rate,
113
+ SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate),
114
+ )
115
+ ),
116
+ Residual(
117
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
118
+ ),
119
+ ]
120
+ )
121
+ # dim = dim / 2
122
+ self.net = IntermediateSequential(*layers)
123
+
124
+
125
+ def forward(self, x):
126
+ return self.net(x)
net/Transformer.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Author : Lintao Peng
3
+ # @File : SGFMT.py
4
+ # coding=utf-8
5
+ # Design based on the Vit
6
+
7
+ import torch.nn as nn
8
+ from net.IntmdSequential import IntermediateSequential
9
+
10
+
11
+ #实现了自注意力机制,相当于unet的bottleneck层
12
+ class SelfAttention(nn.Module):
13
+ def __init__(
14
+ self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
15
+ ):
16
+ super().__init__()
17
+ self.num_heads = heads
18
+ head_dim = dim // heads
19
+ self.scale = qk_scale or head_dim ** -0.5
20
+
21
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
22
+ self.attn_drop = nn.Dropout(dropout_rate)
23
+ self.proj = nn.Linear(dim, dim)
24
+ self.proj_drop = nn.Dropout(dropout_rate)
25
+
26
+ def forward(self, x):
27
+ B, N, C = x.shape
28
+ qkv = (
29
+ self.qkv(x)
30
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
31
+ .permute(2, 0, 3, 1, 4)
32
+ )
33
+ q, k, v = (
34
+ qkv[0],
35
+ qkv[1],
36
+ qkv[2],
37
+ ) # make torchscript happy (cannot use tensor as tuple)
38
+
39
+ attn = (q @ k.transpose(-2, -1)) * self.scale
40
+ attn = attn.softmax(dim=-1)
41
+ attn = self.attn_drop(attn)
42
+
43
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
44
+ x = self.proj(x)
45
+ x = self.proj_drop(x)
46
+ return x
47
+
48
+
49
+ class Residual(nn.Module):
50
+ def __init__(self, fn):
51
+ super().__init__()
52
+ self.fn = fn
53
+
54
+ def forward(self, x):
55
+ return self.fn(x) + x
56
+
57
+
58
+ class PreNorm(nn.Module):
59
+ def __init__(self, dim, fn):
60
+ super().__init__()
61
+ self.norm = nn.LayerNorm(dim)
62
+ self.fn = fn
63
+
64
+ def forward(self, x):
65
+ return self.fn(self.norm(x))
66
+
67
+
68
+ class PreNormDrop(nn.Module):
69
+ def __init__(self, dim, dropout_rate, fn):
70
+ super().__init__()
71
+ self.norm = nn.LayerNorm(dim)
72
+ self.dropout = nn.Dropout(p=dropout_rate)
73
+ self.fn = fn
74
+
75
+ def forward(self, x):
76
+ return self.dropout(self.fn(self.norm(x)))
77
+
78
+
79
+ class FeedForward(nn.Module):
80
+ def __init__(self, dim, hidden_dim, dropout_rate):
81
+ super().__init__()
82
+ self.net = nn.Sequential(
83
+ nn.Linear(dim, hidden_dim),
84
+ nn.GELU(),
85
+ nn.Dropout(p=dropout_rate),
86
+ nn.Linear(hidden_dim, dim),
87
+ nn.Dropout(p=dropout_rate),
88
+ )
89
+
90
+ def forward(self, x):
91
+ return self.net(x)
92
+
93
+
94
+ class TransformerModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ dim, #512
98
+ depth, #4
99
+ heads, #8
100
+ mlp_dim, #4096
101
+ dropout_rate=0.1,
102
+ attn_dropout_rate=0.1,
103
+ ):
104
+ super().__init__()
105
+ layers = []
106
+ for _ in range(depth):
107
+ layers.extend(
108
+ [
109
+ Residual(
110
+ PreNormDrop(
111
+ dim,
112
+ dropout_rate,
113
+ SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate),
114
+ )
115
+ ),
116
+ Residual(
117
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
118
+ ),
119
+ ]
120
+ )
121
+ # dim = dim / 2
122
+ self.net = IntermediateSequential(*layers)
123
+
124
+
125
+ def forward(self, x):
126
+ return self.net(x)
net/Ushape_Trans.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Author : Lintao Peng
3
+ # @File : Ushape_Trans.py
4
+ # coding=utf-8
5
+ # Design based on the pix2pix
6
+
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch
10
+ import datetime
11
+ import os
12
+ import time
13
+ import timeit
14
+ import copy
15
+ import numpy as np
16
+ from torch.nn import ModuleList
17
+ from torch.nn import Conv2d
18
+ from torch.nn import LeakyReLU
19
+ from net.block import *
20
+ from net.block import _equalized_conv2d
21
+ from net.SGFMT import TransformerModel
22
+ from net.PositionalEncoding import FixedPositionalEncoding,LearnedPositionalEncoding
23
+ from net.CMSFFT import ChannelTransformer
24
+
25
+
26
+
27
+
28
+
29
+
30
+
31
+ ##权重初始化
32
+ def weights_init_normal(m):
33
+ classname = m.__class__.__name__
34
+ if classname.find("Conv") != -1:
35
+ torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
36
+ elif classname.find("BatchNorm2d") != -1:
37
+ torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
38
+ torch.nn.init.constant_(m.bias.data, 0.0)
39
+
40
+
41
+
42
+
43
+
44
+
45
+ class Generator(nn.Module):
46
+ """
47
+ MSG-Unet-GAN的生成器部分
48
+ """
49
+ def __init__(self,
50
+ img_dim=256,
51
+ patch_dim=16,
52
+ embedding_dim=512,
53
+ num_channels=3,
54
+ num_heads=8,
55
+ num_layers=4,
56
+ hidden_dim=256,
57
+ dropout_rate=0.0,
58
+ attn_dropout_rate=0.0,
59
+ in_ch=3,
60
+ out_ch=3,
61
+ conv_patch_representation=True,
62
+ positional_encoding_type="learned",
63
+ use_eql=True):
64
+ super(Generator, self).__init__()
65
+ assert embedding_dim % num_heads == 0
66
+ assert img_dim % patch_dim == 0
67
+
68
+ self.out_ch=out_ch #输出通道数
69
+ self.in_ch=in_ch #输入通道数
70
+ self.img_dim = img_dim #输入图片尺寸
71
+ self.embedding_dim = embedding_dim #512
72
+ self.num_heads = num_heads #多头注意力中头的数量
73
+ self.patch_dim = patch_dim #每个patch的尺寸
74
+ self.num_channels = num_channels #图片通道数?
75
+ self.dropout_rate = dropout_rate #drop-out比率
76
+ self.attn_dropout_rate = attn_dropout_rate #注意力模块的dropout比率
77
+ self.conv_patch_representation = conv_patch_representation #True
78
+
79
+ self.num_patches = int((img_dim // patch_dim) ** 2) #将三通道图片分成多少块
80
+ self.seq_length = self.num_patches #每个sequence的长度为patches的大小
81
+ self.flatten_dim = 128 * num_channels #128*3=384
82
+
83
+ #线性编码
84
+ self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
85
+ #位置编码
86
+ if positional_encoding_type == "learned":
87
+ self.position_encoding = LearnedPositionalEncoding(
88
+ self.seq_length, self.embedding_dim, self.seq_length
89
+ )
90
+ elif positional_encoding_type == "fixed":
91
+ self.position_encoding = FixedPositionalEncoding(
92
+ self.embedding_dim,
93
+ )
94
+
95
+ self.pe_dropout = nn.Dropout(p=self.dropout_rate)
96
+
97
+ self.transformer = TransformerModel(
98
+ embedding_dim, #512
99
+ num_layers, #4
100
+ num_heads, #8
101
+ hidden_dim, #4096
102
+
103
+ self.dropout_rate,
104
+ self.attn_dropout_rate,
105
+ )
106
+
107
+ #layer Norm
108
+ self.pre_head_ln = nn.LayerNorm(embedding_dim)
109
+
110
+ if self.conv_patch_representation:
111
+
112
+ self.Conv_x = nn.Conv2d(
113
+ 256,
114
+ self.embedding_dim, #512
115
+ kernel_size=3,
116
+ stride=1,
117
+ padding=1
118
+ )
119
+
120
+ self.bn = nn.BatchNorm2d(256)
121
+ self.relu = nn.ReLU(inplace=True)
122
+
123
+
124
+
125
+ #modulelist
126
+ self.rgb_to_feature=ModuleList([from_rgb(32),from_rgb(64),from_rgb(128)])
127
+ self.feature_to_rgb=ModuleList([to_rgb(32),to_rgb(64),to_rgb(128),to_rgb(256)])
128
+
129
+ self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
130
+ self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
131
+ self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
132
+ self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
133
+ self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
134
+
135
+ self.Conv1=conv_block(self.in_ch, 16)
136
+ self.Conv1_1 = conv_block(16, 32)
137
+ self.Conv2 = conv_block(32, 32)
138
+ self.Conv2_1 = conv_block(32, 64)
139
+ self.Conv3 = conv_block(64,64)
140
+ self.Conv3_1 = conv_block(64,128)
141
+ self.Conv4 = conv_block(128,128)
142
+ self.Conv4_1 = conv_block(128,256)
143
+
144
+ self.Conv5 = conv_block(512,256)
145
+
146
+ #self.Conv_x = conv_block(256,512)
147
+ self.mtc = ChannelTransformer(channel_num=[32,64,128,256],
148
+ patchSize=[32, 16, 8, 4])
149
+
150
+
151
+ self.Up5 = up_conv(256, 256)
152
+ self.coatt5 = CCA(F_g=256, F_x=256)
153
+ self.Up_conv5 = conv_block(512, 256)
154
+ self.Up_conv5_1 = conv_block(256, 256)
155
+
156
+ self.Up4 = up_conv(256, 128)
157
+ self.coatt4 = CCA(F_g=128, F_x=128)
158
+ self.Up_conv4 = conv_block(256, 128)
159
+ self.Up_conv4_1 = conv_block(128, 128)
160
+
161
+ self.Up3 = up_conv(128, 64)
162
+ self.coatt3 = CCA(F_g=64, F_x=64)
163
+ self.Up_conv3 = conv_block(128, 64)
164
+ self.Up_conv3_1 = conv_block(64, 64)
165
+
166
+ self.Up2 = up_conv(64, 32)
167
+ self.coatt2 = CCA(F_g=32, F_x=32)
168
+ self.Up_conv2 = conv_block(64, 32)
169
+ self.Up_conv2_1 = conv_block(32, 32)
170
+
171
+ self.Conv = nn.Conv2d(32, self.out_ch, kernel_size=1, stride=1, padding=0)
172
+
173
+ # self.active = torch.nn.Sigmoid()
174
+ #
175
+ def reshape_output(self,x): #将transformer的输出resize为原来的特征图尺寸
176
+ x = x.view(
177
+ x.size(0),
178
+ int(self.img_dim / self.patch_dim),
179
+ int(self.img_dim / self.patch_dim),
180
+ self.embedding_dim,
181
+ )#B,16,16,512
182
+ x = x.permute(0, 3, 1, 2).contiguous()
183
+
184
+ return x
185
+
186
+ def forward(self, x):
187
+ #print(x.shape)
188
+
189
+
190
+ output=[]
191
+
192
+ x_1=self.Maxpool(x)
193
+ x_2=self.Maxpool(x_1)
194
+ x_3=self.Maxpool(x_2)
195
+
196
+
197
+ e1 = self.Conv1(x)
198
+ #print(e1.shape)
199
+ e1 = self.Conv1_1(e1)
200
+ e2 = self.Maxpool1(e1)
201
+ #32*128*128
202
+
203
+ x_1=self.rgb_to_feature[0](x_1)
204
+ #e2=torch.cat((x_1,e2), dim=1)
205
+ e2=x_1+e2
206
+ e2 = self.Conv2(e2)
207
+ e2 = self.Conv2_1(e2)
208
+ e3 = self.Maxpool2(e2)
209
+ #64*64*64
210
+
211
+ x_2=self.rgb_to_feature[1](x_2)
212
+ #e3=torch.cat((x_2,e3), dim=1)
213
+ e3=x_2+e3
214
+ e3 = self.Conv3(e3)
215
+ e3 = self.Conv3_1(e3)
216
+ e4 = self.Maxpool3(e3)
217
+ #128*32*32
218
+
219
+ x_3=self.rgb_to_feature[2](x_3)
220
+ #e4=torch.cat((x_3,e4), dim=1)
221
+ e4=x_3+e4
222
+ e4 = self.Conv4(e4)
223
+ e4 = self.Conv4_1(e4)
224
+ e5 = self.Maxpool4(e4)
225
+ #256*16*16
226
+
227
+ #channel-wise transformer-based attention
228
+ e1,e2,e3,e4,att_weights = self.mtc(e1,e2,e3,e4)
229
+
230
+
231
+
232
+
233
+ #spatial-wise transformer-based attention
234
+ residual=e5
235
+ #中间的隐变量
236
+ #conv_x应该接受256通道,输出512通道的中间隐变量
237
+ e5= self.bn(e5)
238
+ e5=self.relu(e5)
239
+ e5= self.Conv_x(e5) #out->512*16*16 shape->B,512,16,16
240
+ e5= e5.permute(0, 2, 3, 1).contiguous() # B,512,16,16->B,16,16,512
241
+ e5= e5.view(e5.size(0), -1, self.embedding_dim) #B,16,16,512->B,16*16,512 线性映射层
242
+ e5= self.position_encoding(e5) #位置编码
243
+ e5= self.pe_dropout(e5) #预dropout层
244
+ # apply transformer
245
+ e5= self.transformer(e5)
246
+ e5= self.pre_head_ln(e5)
247
+ e5= self.reshape_output(e5)#out->512*16*16 shape->B,512,16,16
248
+ e5=self.Conv5(e5) #out->256,16,16 shape->B,256,16,16
249
+ #residual是否要加bn和relu?
250
+ e5=e5+residual
251
+
252
+
253
+
254
+ d5 = self.Up5(e5)
255
+ e4_att = self.coatt5(g=d5, x=e4)
256
+ d5 = torch.cat((e4_att, d5), dim=1)
257
+ d5 = self.Up_conv5(d5)
258
+ d5 = self.Up_conv5_1(d5)
259
+ #256
260
+ out3=self.feature_to_rgb[3](d5)
261
+ output.append(out3)#32*32orH/8,W/8
262
+
263
+ d4 = self.Up4(d5)
264
+ e3_att = self.coatt4(g=d4, x=e3)
265
+ d4 = torch.cat((e3_att, d4), dim=1)
266
+ d4 = self.Up_conv4(d4)
267
+ d4 = self.Up_conv4_1(d4)
268
+ #128
269
+ out2=self.feature_to_rgb[2](d4)
270
+ output.append(out2)#64*64orH/4,W/4
271
+
272
+ d3 = self.Up3(d4)
273
+ e2_att = self.coatt3(g=d3, x=e2)
274
+ d3 = torch.cat((e2_att, d3), dim=1)
275
+ d3 = self.Up_conv3(d3)
276
+ d3 = self.Up_conv3_1(d3)
277
+ #64
278
+ out1=self.feature_to_rgb[1](d3)
279
+ output.append(out1)#128#128orH/2,W/2
280
+
281
+ d2 = self.Up2(d3)
282
+ e1_att = self.coatt2(g=d2, x=e1)
283
+ d2 = torch.cat((e1_att, d2), dim=1)
284
+ d2 = self.Up_conv2(d2)
285
+ d2 = self.Up_conv2_1(d2)
286
+ #32
287
+ out0=self.feature_to_rgb[0](d2)
288
+ output.append(out0)#256*256
289
+
290
+ #out = self.Conv(d2)
291
+
292
+ #d1 = self.active(out)
293
+ #output=np.array(output)
294
+
295
+ return output[3]
296
+
297
+
298
+
299
+
300
+ class Discriminator(nn.Module):
301
+ def __init__(self, in_channels=3,use_eql=True):
302
+ super(Discriminator, self).__init__()
303
+
304
+ self.use_eql=use_eql
305
+ self.in_channels=in_channels
306
+
307
+
308
+ #modulelist
309
+ self.rgb_to_feature1=ModuleList([from_rgb(32),from_rgb(64),from_rgb(128)])
310
+ self.rgb_to_feature2=ModuleList([from_rgb(32),from_rgb(64),from_rgb(128)])
311
+
312
+
313
+ self.layer=_equalized_conv2d(self.in_channels*2, 64, (1, 1), bias=True)
314
+ # pixel_wise feature normalizer:
315
+ self.pixNorm = PixelwiseNorm()
316
+ # leaky_relu:
317
+ self.lrelu = LeakyReLU(0.2)
318
+
319
+
320
+ self.layer0=DisGeneralConvBlock(64,64,use_eql=self.use_eql)
321
+ #128*128*32
322
+
323
+ self.layer1=DisGeneralConvBlock(128,128,use_eql=self.use_eql)
324
+ #64*64*64
325
+
326
+ self.layer2=DisGeneralConvBlock(256,256,use_eql=self.use_eql)
327
+ #32*32*128
328
+
329
+ self.layer3=DisGeneralConvBlock(512,512,use_eql=self.use_eql)
330
+ #16*16*256
331
+
332
+ self.layer4=DisFinalBlock(512,use_eql=self.use_eql)
333
+ #8*8*512
334
+
335
+
336
+
337
+ def forward(self, img_A, inputs):
338
+ #inputs图片尺寸从小到大
339
+ # Concatenate image and condition image by channels to produce input
340
+ #img_input = torch.cat((img_A, img_B), 1)
341
+ #img_A_128= F.interpolate(img_A, size=[128, 128])
342
+ #img_A_64= F.interpolate(img_A, size=[64, 64])
343
+ #img_A_32= F.interpolate(img_A, size=[32, 32])
344
+
345
+
346
+ x=torch.cat((img_A[3], inputs[3]), 1)
347
+ y = self.pixNorm(self.lrelu(self.layer(x)))
348
+
349
+ y=self.layer0(y)
350
+ #128*128*64
351
+
352
+
353
+ x1=self.rgb_to_feature1[0](img_A[2])
354
+ x2=self.rgb_to_feature2[0](inputs[2])
355
+ x=torch.cat((x1,x2),1)
356
+ y=torch.cat((x,y),1)
357
+ y=self.layer1(y)
358
+ #64*64*128
359
+
360
+
361
+ x1=self.rgb_to_feature1[1](img_A[1])
362
+ x2=self.rgb_to_feature2[1](inputs[1])
363
+ x=torch.cat((x1,x2),1)
364
+ y=torch.cat((x,y),1)
365
+ y=self.layer2(y)
366
+ #32*32*256
367
+
368
+ x1=self.rgb_to_feature1[2](img_A[0])
369
+ x2=self.rgb_to_feature2[2](inputs[0])
370
+ x=torch.cat((x1,x2),1)
371
+ y=torch.cat((x,y),1)
372
+ y=self.layer3(y)
373
+ #16*16*512
374
+
375
+ y=self.layer4(y)
376
+ #8*8*512
377
+
378
+ return y
net/__pycache__/CMSFFT.cpython-37.pyc ADDED
Binary file (11.5 kB). View file
 
net/__pycache__/CTrans.cpython-37.pyc ADDED
Binary file (11.4 kB). View file
 
net/__pycache__/IntmdSequential.cpython-37.pyc ADDED
Binary file (919 Bytes). View file
 
net/__pycache__/PositionalEncoding.cpython-37.pyc ADDED
Binary file (1.78 kB). View file
 
net/__pycache__/SGFMT.cpython-37.pyc ADDED
Binary file (3.98 kB). View file
 
net/__pycache__/Transformer.cpython-37.pyc ADDED
Binary file (3.93 kB). View file
 
net/__pycache__/Ushape_Trans.cpython-37.pyc ADDED
Binary file (6.65 kB). View file
 
net/__pycache__/block.cpython-37.pyc ADDED
Binary file (13.1 kB). View file
 
net/__pycache__/utils.cpython-37.pyc ADDED
Binary file (2.8 kB). View file
 
net/block.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch as th
4
+ import datetime
5
+ import os
6
+ import time
7
+ import timeit
8
+ import copy
9
+ import numpy as np
10
+ from torch.nn import ModuleList
11
+ from torch.nn import Conv2d
12
+ from torch.nn import LeakyReLU
13
+
14
+
15
+
16
+
17
+ #PixelwiseNorm代替了BatchNorm
18
+ class PixelwiseNorm(th.nn.Module):
19
+ def __init__(self):
20
+ super(PixelwiseNorm, self).__init__()
21
+
22
+ def forward(self, x, alpha=1e-8):
23
+ """
24
+ forward pass of the module
25
+ :param x: input activations volume
26
+ :param alpha: small number for numerical stability
27
+ :return: y => pixel normalized activations
28
+ """
29
+ y = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).sqrt() # [N1HW]
30
+ y = x / y # normalize the input x volume
31
+ return y
32
+
33
+
34
+
35
+ class MinibatchStdDev(th.nn.Module):
36
+ """
37
+ Minibatch standard deviation layer for the discriminator
38
+ """
39
+
40
+ def __init__(self):
41
+ """
42
+ derived class constructor
43
+ """
44
+ super().__init__()
45
+
46
+ def forward(self, x, alpha=1e-8):
47
+ """
48
+ forward pass of the layer
49
+ :param x: input activation volume
50
+ :param alpha: small number for numerical stability
51
+ :return: y => x appended with standard deviation constant map
52
+ """
53
+ batch_size, _, height, width = x.shape
54
+
55
+ # [B x C x H x W] Subtract mean over batch.
56
+ y = x - x.mean(dim=0, keepdim=True)
57
+
58
+ # [1 x C x H x W] Calc standard deviation over batch
59
+ y = th.sqrt(y.pow(2.).mean(dim=0, keepdim=False) + alpha)
60
+
61
+ # [1] Take average over feature_maps and pixels.
62
+ y = y.mean().view(1, 1, 1, 1)
63
+
64
+ # [B x 1 x H x W] Replicate over group and pixels.
65
+ y = y.repeat(batch_size, 1, height, width)
66
+
67
+ # [B x C x H x W] Append as new feature_map.
68
+ y = th.cat([x, y], 1)
69
+
70
+ # return the computed values:
71
+ return y
72
+
73
+
74
+
75
+
76
+
77
+ # ==========================================================
78
+ # Equalized learning rate blocks:
79
+ # extending Conv2D and Deconv2D layers for equalized learning rate logic
80
+ # ==========================================================
81
+ class _equalized_conv2d(th.nn.Module):
82
+ """ conv2d with the concept of equalized learning rate
83
+ Args:
84
+ :param c_in: input channels
85
+ :param c_out: output channels
86
+ :param k_size: kernel size (h, w) should be a tuple or a single integer
87
+ :param stride: stride for conv
88
+ :param pad: padding
89
+ :param bias: whether to use bias or not
90
+ """
91
+
92
+ def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
93
+ """ constructor for the class """
94
+ from torch.nn.modules.utils import _pair
95
+ from numpy import sqrt, prod
96
+
97
+ super().__init__()
98
+
99
+ # define the weight and bias if to be used
100
+ self.weight = th.nn.Parameter(th.nn.init.normal_(
101
+ th.empty(c_out, c_in, *_pair(k_size))
102
+ ))
103
+
104
+ self.use_bias = bias
105
+ self.stride = stride
106
+ self.pad = pad
107
+
108
+ if self.use_bias:
109
+ self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
110
+
111
+ fan_in = prod(_pair(k_size)) * c_in # value of fan_in
112
+ self.scale = sqrt(2) / sqrt(fan_in)
113
+
114
+ def forward(self, x):
115
+ """
116
+ forward pass of the network
117
+ :param x: input
118
+ :return: y => output
119
+ """
120
+ from torch.nn.functional import conv2d
121
+
122
+ return conv2d(input=x,
123
+ weight=self.weight * self.scale, # scale the weight on runtime
124
+ bias=self.bias if self.use_bias else None,
125
+ stride=self.stride,
126
+ padding=self.pad)
127
+
128
+ def extra_repr(self):
129
+ return ", ".join(map(str, self.weight.shape))
130
+
131
+
132
+ class _equalized_deconv2d(th.nn.Module):
133
+ """ Transpose convolution using the equalized learning rate
134
+ Args:
135
+ :param c_in: input channels
136
+ :param c_out: output channels
137
+ :param k_size: kernel size
138
+ :param stride: stride for convolution transpose
139
+ :param pad: padding
140
+ :param bias: whether to use bias or not
141
+ """
142
+
143
+ def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
144
+ """ constructor for the class """
145
+ from torch.nn.modules.utils import _pair
146
+ from numpy import sqrt
147
+
148
+ super().__init__()
149
+
150
+ # define the weight and bias if to be used
151
+ self.weight = th.nn.Parameter(th.nn.init.normal_(
152
+ th.empty(c_in, c_out, *_pair(k_size))
153
+ ))
154
+
155
+ self.use_bias = bias
156
+ self.stride = stride
157
+ self.pad = pad
158
+
159
+ if self.use_bias:
160
+ self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
161
+
162
+ fan_in = c_in # value of fan_in for deconv
163
+ self.scale = sqrt(2) / sqrt(fan_in)
164
+
165
+ def forward(self, x):
166
+ """
167
+ forward pass of the layer
168
+ :param x: input
169
+ :return: y => output
170
+ """
171
+ from torch.nn.functional import conv_transpose2d
172
+
173
+ return conv_transpose2d(input=x,
174
+ weight=self.weight * self.scale, # scale the weight on runtime
175
+ bias=self.bias if self.use_bias else None,
176
+ stride=self.stride,
177
+ padding=self.pad)
178
+
179
+ def extra_repr(self):
180
+ return ", ".join(map(str, self.weight.shape))
181
+
182
+
183
+
184
+ #basic block of the encoding part of the genarater
185
+ #编码器的基本卷积块
186
+ class conv_block(nn.Module):
187
+ """
188
+ Convolution Block
189
+ with two convolution layers
190
+ """
191
+ def __init__(self, in_ch, out_ch,use_eql=True):
192
+ super(conv_block, self).__init__()
193
+
194
+ if use_eql:
195
+ self.conv_1= _equalized_conv2d(in_ch, out_ch, (1, 1),
196
+ pad=0, bias=True)
197
+ self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
198
+ pad=1, bias=True)
199
+ self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
200
+ pad=1, bias=True)
201
+
202
+ else:
203
+ self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
204
+ padding=1, bias=True)
205
+ self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
206
+ padding=1, bias=True)
207
+
208
+ # pixel_wise feature normalizer:
209
+ self.pixNorm = PixelwiseNorm()
210
+
211
+ # leaky_relu:
212
+ self.lrelu = LeakyReLU(0.2)
213
+
214
+ def forward(self, x):
215
+ """
216
+ forward pass of the block
217
+ :param x: input
218
+ :return: y => output
219
+ """
220
+ from torch.nn.functional import interpolate
221
+
222
+ #y = interpolate(x, scale_factor=2)
223
+ y=self.conv_1(self.lrelu(self.pixNorm(x)))
224
+ residual=y
225
+ y=self.conv_2(self.lrelu(self.pixNorm(y)))
226
+ y=self.conv_3(self.lrelu(self.pixNorm(y)))
227
+ y=y+residual
228
+
229
+
230
+ return y
231
+
232
+
233
+
234
+
235
+ #basic up convolution block of the encoding part of the genarater
236
+ #编码器的基本卷积块
237
+ class up_conv(nn.Module):
238
+ """
239
+ Up Convolution Block
240
+ """
241
+ def __init__(self, in_ch, out_ch,use_eql=True):
242
+ super(up_conv, self).__init__()
243
+ if use_eql:
244
+ self.conv_1= _equalized_conv2d(in_ch, out_ch, (1, 1),
245
+ pad=0, bias=True)
246
+ self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
247
+ pad=1, bias=True)
248
+ self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
249
+ pad=1, bias=True)
250
+
251
+ else:
252
+ self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
253
+ padding=1, bias=True)
254
+ self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
255
+ padding=1, bias=True)
256
+
257
+ # pixel_wise feature normalizer:
258
+ self.pixNorm = PixelwiseNorm()
259
+
260
+ # leaky_relu:
261
+ self.lrelu = LeakyReLU(0.2)
262
+
263
+ def forward(self, x):
264
+ """
265
+ forward pass of the block
266
+ :param x: input
267
+ :return: y => output
268
+ """
269
+ from torch.nn.functional import interpolate
270
+
271
+ x = interpolate(x, scale_factor=2, mode="bilinear")
272
+ y=self.conv_1(self.lrelu(self.pixNorm(x)))
273
+ residual=y
274
+ y=self.conv_2(self.lrelu(self.pixNorm(y)))
275
+ y=self.conv_3(self.lrelu(self.pixNorm(y)))
276
+ y=y+residual
277
+
278
+ return y
279
+
280
+
281
+
282
+
283
+ #判别器的最后一层
284
+ class DisFinalBlock(th.nn.Module):
285
+ """ Final block for the Discriminator """
286
+
287
+ def __init__(self, in_channels, use_eql=True):
288
+ """
289
+ constructor of the class
290
+ :param in_channels: number of input channels
291
+ :param use_eql: whether to use equalized learning rate
292
+ """
293
+ from torch.nn import LeakyReLU
294
+ from torch.nn import Conv2d
295
+
296
+ super().__init__()
297
+
298
+ # declare the required modules for forward pass
299
+ self.batch_discriminator = MinibatchStdDev()
300
+
301
+ if use_eql:
302
+ self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3),
303
+ pad=1, bias=True)
304
+ self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4),stride=2,pad=1,
305
+ bias=True)
306
+
307
+ # final layer emulates the fully connected layer
308
+ self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
309
+
310
+ else:
311
+ # modules required:
312
+ self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
313
+ self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True)
314
+
315
+ # final conv layer emulates a fully connected layer
316
+ self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)
317
+
318
+ # leaky_relu:
319
+ self.lrelu = LeakyReLU(0.2)
320
+
321
+ def forward(self, x):
322
+ """
323
+ forward pass of the FinalBlock
324
+ :param x: input
325
+ :return: y => output
326
+ """
327
+ # minibatch_std_dev layer
328
+ y = self.batch_discriminator(x)
329
+
330
+ # define the computations
331
+ y = self.lrelu(self.conv_1(y))
332
+ y = self.lrelu(self.conv_2(y))
333
+
334
+ # fully connected layer
335
+ y = self.conv_3(y) # This layer has linear activation
336
+
337
+ # flatten the output raw discriminator scores
338
+ return y
339
+
340
+
341
+
342
+ #判别器基本卷积块
343
+ class DisGeneralConvBlock(th.nn.Module):
344
+ """ General block in the discriminator """
345
+
346
+ def __init__(self, in_channels, out_channels, use_eql=True):
347
+ """
348
+ constructor of the class
349
+ :param in_channels: number of input channels
350
+ :param out_channels: number of output channels
351
+ :param use_eql: whether to use equalized learning rate
352
+ """
353
+ from torch.nn import AvgPool2d, LeakyReLU
354
+ from torch.nn import Conv2d
355
+
356
+ super().__init__()
357
+
358
+ if use_eql:
359
+ self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3),
360
+ pad=1, bias=True)
361
+ self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3),
362
+ pad=1, bias=True)
363
+ else:
364
+ # convolutional modules
365
+ self.conv_1 = Conv2d(in_channels, in_channels, (3, 3),
366
+ padding=1, bias=True)
367
+ self.conv_2 = Conv2d(in_channels, out_channels, (3, 3),
368
+ padding=1, bias=True)
369
+
370
+ self.downSampler = AvgPool2d(2) # downsampler
371
+
372
+ # leaky_relu:
373
+ self.lrelu = LeakyReLU(0.2)
374
+
375
+ def forward(self, x):
376
+ """
377
+ forward pass of the module
378
+ :param x: input
379
+ :return: y => output
380
+ """
381
+ # define the computations
382
+ y = self.lrelu(self.conv_1(x))
383
+ y = self.lrelu(self.conv_2(y))
384
+ y = self.downSampler(y)
385
+
386
+ return y
387
+
388
+
389
+
390
+
391
+
392
+ class from_rgb(nn.Module):
393
+ """
394
+ The RGB image is transformed into a multi-channel feature map to be concatenated with
395
+ the feature map with the same number of channels in the network
396
+ 把RGB图转换为多通道特征图,以便与网络中相同通道数的特征图拼接
397
+ """
398
+ def __init__(self, outchannels, use_eql=True):
399
+ super(from_rgb, self).__init__()
400
+ if use_eql:
401
+ self.conv_1 = _equalized_conv2d(3, outchannels, (1, 1), bias=True)
402
+ else:
403
+ self.conv_1 = nn.Conv2d(3, outchannels, (1, 1),bias=True)
404
+ # pixel_wise feature normalizer:
405
+ self.pixNorm = PixelwiseNorm()
406
+
407
+ # leaky_relu:
408
+ self.lrelu = LeakyReLU(0.2)
409
+
410
+
411
+ def forward(self, x):
412
+ """
413
+ forward pass of the block
414
+ :param x: input
415
+ :return: y => output
416
+ """
417
+ y = self.pixNorm(self.lrelu(self.conv_1(x)))
418
+ return y
419
+
420
+ class to_rgb(nn.Module):
421
+ """
422
+ 把多通道特征图转换为RGB三通道图,以便输入判别器
423
+ The multi-channel feature map is converted into RGB image for input to the discriminator
424
+ """
425
+ def __init__(self, inchannels, use_eql=True):
426
+ super(to_rgb, self).__init__()
427
+ if use_eql:
428
+ self.conv_1 = _equalized_conv2d(inchannels, 3, (1, 1), bias=True)
429
+ else:
430
+ self.conv_1 = nn.Conv2d(inchannels, 3, (1, 1),bias=True)
431
+
432
+
433
+
434
+
435
+
436
+ def forward(self, x):
437
+ """
438
+ forward pass of the block
439
+ :param x: input
440
+ :return: y => output
441
+ """
442
+
443
+ y = self.conv_1(x)
444
+
445
+ return y
446
+
447
+ class Flatten(nn.Module):
448
+ def forward(self, x):
449
+ return x.view(x.size(0), -1)
450
+
451
+
452
+
453
+ class CCA(nn.Module):
454
+ """
455
+ CCA Block
456
+ """
457
+ def __init__(self, F_g, F_x):
458
+ super().__init__()
459
+ self.mlp_x = nn.Sequential(
460
+ Flatten(),
461
+ nn.Linear(F_x, F_x))
462
+ self.mlp_g = nn.Sequential(
463
+ Flatten(),
464
+ nn.Linear(F_g, F_x))
465
+ self.relu = nn.ReLU(inplace=True)
466
+
467
+ def forward(self, g, x):
468
+ # channel-wise attention
469
+ avg_pool_x = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
470
+ channel_att_x = self.mlp_x(avg_pool_x)
471
+ avg_pool_g = F.avg_pool2d( g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3)))
472
+ channel_att_g = self.mlp_g(avg_pool_g)
473
+ channel_att_sum = (channel_att_x + channel_att_g)/2.0
474
+ scale = th.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
475
+ x_after_channel = x * scale
476
+ out = self.relu(x_after_channel)
477
+ return out
net/utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from skimage.measure.simple_metrics import compare_psnr
6
+ from torchvision import models
7
+
8
+
9
+ def weights_init_kaiming(m):
10
+ classname = m.__class__.__name__
11
+ if classname.find('Conv') != -1:
12
+ nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
13
+ elif classname.find('Linear') != -1:
14
+ nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
15
+ elif classname.find('BatchNorm') != -1:
16
+ # nn.init.uniform(m.weight.data, 1.0, 0.02)
17
+ m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025,0.025)
18
+ nn.init.constant(m.bias.data, 0.0)
19
+
20
+ class VGG19_PercepLoss(nn.Module):
21
+ """ Calculates perceptual loss in vgg19 space
22
+ """
23
+ def __init__(self, _pretrained_=True):
24
+ super(VGG19_PercepLoss, self).__init__()
25
+ self.vgg = models.vgg19(pretrained=_pretrained_).features
26
+ for param in self.vgg.parameters():
27
+ param.requires_grad_(False)
28
+
29
+ def get_features(self, image, layers=None):
30
+ if layers is None:
31
+ layers = {'30': 'conv5_2'} # may add other layers
32
+ features = {}
33
+ x = image
34
+ for name, layer in self.vgg._modules.items():
35
+ x = layer(x)
36
+ if name in layers:
37
+ features[layers[name]] = x
38
+ return features
39
+
40
+ def forward(self, pred, true, layer='conv5_2'):
41
+ true_f = self.get_features(true)
42
+ pred_f = self.get_features(pred)
43
+ return torch.mean((true_f[layer]-pred_f[layer])**2)
44
+
45
+
46
+ def batch_PSNR(img, imclean, data_range):
47
+ Img = img.data.cpu().numpy().astype(np.float32)
48
+ Iclean = imclean.data.cpu().numpy().astype(np.float32)
49
+ PSNR = 0
50
+ for i in range(Img.shape[0]):
51
+ PSNR += compare_psnr(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)
52
+ return (PSNR/Img.shape[0])
53
+
54
+ def data_augmentation(image, mode):
55
+ out = np.transpose(image, (1,2,0))
56
+ #out = image
57
+ if mode == 0:
58
+ # original
59
+ out = out
60
+ elif mode == 1:
61
+ # flip up and down
62
+ out = np.flipud(out)
63
+ elif mode == 2:
64
+ # rotate counterwise 90 degree
65
+ out = np.rot90(out)
66
+ elif mode == 3:
67
+ # rotate 90 degree and flip up and down
68
+ out = np.rot90(out)
69
+ out = np.flipud(out)
70
+ elif mode == 4:
71
+ # rotate 180 degree
72
+ out = np.rot90(out, k=2)
73
+ elif mode == 5:
74
+ # rotate 180 degree and flip
75
+ out = np.rot90(out, k=2)
76
+ out = np.flipud(out)
77
+ elif mode == 6:
78
+ # rotate 270 degree
79
+ out = np.rot90(out, k=3)
80
+ elif mode == 7:
81
+ # rotate 270 degree and flip
82
+ out = np.rot90(out, k=3)
83
+ out = np.flipud(out)
84
+ return np.transpose(out, (2,0,1))
85
+ #return out
86
+