QJerry committed on
Commit
951b7de
1 Parent(s): 5f0497f

Upload 2 files

Files changed (2)
  1. UCTransNet.py +475 -0
  2. best_model.pth +3 -0
UCTransNet.py ADDED
@@ -0,0 +1,475 @@
+ # -*- coding: utf-8 -*-
+ # @Time    : 2024/2/17 11:06
+ # @Author  : Haonan Wang
+ # @File    : UCTransNet.py
+ # @Software: PyCharm
+
+
+ import torch.nn.functional as F
+ import copy
+ import math
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from torch.nn import Dropout, Softmax, LayerNorm
+ from torch.nn.modules.utils import _pair, _triple
+
+
+ def get_activation(activation_type):
+     # Activation classes on torch.nn are CamelCase (e.g. nn.ReLU), so the name
+     # is looked up as given; lower-casing it first would never match and the
+     # lookup would always fall through to the ReLU default.
+     if hasattr(nn, activation_type):
+         return getattr(nn, activation_type)()
+     else:
+         return nn.ReLU()
+
+ def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
+     layers = []
+     layers.append(ConvBatchNorm(in_channels, out_channels, activation))
+     for _ in range(nb_Conv - 1):
+         layers.append(ConvBatchNorm(out_channels, out_channels, activation))
+     return nn.Sequential(*layers)
+
+ class ConvBatchNorm(nn.Module):
+     """(convolution => [BN] => ReLU)"""
+
+     def __init__(self, in_channels, out_channels, activation='ReLU'):
+         super(ConvBatchNorm, self).__init__()
+         self.conv = nn.Conv3d(in_channels, out_channels,
+                               kernel_size=3, padding=1)
+         self.norm = nn.BatchNorm3d(out_channels)
+         self.activation = get_activation(activation)
+
+     def forward(self, x):
+         out = self.conv(x)
+         out = self.norm(out)
+         return self.activation(out)
+
+ class DownBlock(nn.Module):
+     """Downscaling with maxpool then convolutions"""
+
+     def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
+         super(DownBlock, self).__init__()
+         self.maxpool = nn.MaxPool3d(2)
+         self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)
+
+     def forward(self, x):
+         out = self.maxpool(x)
+         return self.nConvs(out)
+
+ class Flatten(nn.Module):
+     def forward(self, x):
+         return x.view(x.size(0), -1)
+
+ class CCA(nn.Module):
+     """
+     CCA (Channel-wise Cross Attention) block: gates the skip connection `x`
+     with channel attention computed from both the decoder feature `g` and `x`.
+     """
+     def __init__(self, F_g, F_x):
+         super().__init__()
+         self.mlp_x = nn.Sequential(
+             Flatten(),
+             nn.Linear(F_x, F_x))
+         self.mlp_g = nn.Sequential(
+             Flatten(),
+             nn.Linear(F_g, F_x))
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, g, x):
+         # channel-wise attention: global average pooling reduces each
+         # feature map to a single value per channel
+         avg_pool_x = F.avg_pool3d(x, (x.size(2), x.size(3), x.size(4)),
+                                   stride=(x.size(2), x.size(3), x.size(4)))
+         channel_att_x = self.mlp_x(avg_pool_x)
+         avg_pool_g = F.avg_pool3d(g, (g.size(2), g.size(3), g.size(4)),
+                                   stride=(g.size(2), g.size(3), g.size(4)))
+         channel_att_g = self.mlp_g(avg_pool_g)
+         channel_att_sum = (channel_att_x + channel_att_g) / 2.0
+         scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).unsqueeze(4).expand_as(x)
+         x_after_channel = x * scale
+         out = self.relu(x_after_channel)
+         return out
+
+ class UpBlock_attention(nn.Module):
+     def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
+         super().__init__()
+         self.up = nn.Upsample(scale_factor=2)
+         self.coatt = CCA(F_g=in_channels//2, F_x=in_channels//2)
+         self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)
+
+     def forward(self, x, skip_x):
+         up = self.up(x)
+         skip_x_att = self.coatt(g=up, x=skip_x)
+         x = torch.cat([skip_x_att, up], dim=1)  # dim 1 is the channel dimension
+         return self.nConvs(x)
+
+ class UCTransNet(nn.Module):
+     def __init__(self, in_channels, out_channels, num_layers, KV_size, num_heads,
+                  attention_dropout_rate, mlp_dropout_rate, feature_size, img_size, patch_sizes):
+         super().__init__()
+         self.inc = ConvBatchNorm(in_channels, feature_size)
+         self.down1 = DownBlock(feature_size, feature_size*2, nb_Conv=2)
+         self.down2 = DownBlock(feature_size*2, feature_size*4, nb_Conv=2)
+         self.down3 = DownBlock(feature_size*4, feature_size*8, nb_Conv=2)
+         self.down4 = DownBlock(feature_size*8, feature_size*8, nb_Conv=2)
+         self.mtc = ChannelTransformer(img_size, num_layers, KV_size, num_heads,
+                                       attention_dropout_rate, mlp_dropout_rate,
+                                       channel_num=[feature_size, feature_size*2, feature_size*4, feature_size*8],
+                                       patchSize=patch_sizes)
+         self.up4 = UpBlock_attention(feature_size*16, feature_size*4, nb_Conv=2)
+         self.up3 = UpBlock_attention(feature_size*8, feature_size*2, nb_Conv=2)
+         self.up2 = UpBlock_attention(feature_size*4, feature_size, nb_Conv=2)
+         self.up1 = UpBlock_attention(feature_size*2, feature_size, nb_Conv=2)
+         self.outc = nn.Conv3d(feature_size, out_channels, kernel_size=(1, 1, 1), stride=(1, 1, 1))
+
+     def forward(self, x):
+         x = x.float()
+         x1 = self.inc(x)
+         x2 = self.down1(x1)
+         x3 = self.down2(x2)
+         x4 = self.down3(x3)
+         x5 = self.down4(x4)
+         x1, x2, x3, x4 = self.mtc(x1, x2, x3, x4)
+         x = self.up4(x5, x4)
+         x = self.up3(x, x3)
+         x = self.up2(x, x2)
+         x = self.up1(x, x1)
+
+         logits = self.outc(x)  # raw logits, e.g. if using BCEWithLogitsLoss or n_classes > 1
+
+         return logits
+
+
+ class Channel_Embeddings(nn.Module):
+     """Construct the embeddings from patch and position embeddings."""
+     def __init__(self, patchsize, img_size, in_channels, reduce_scale):
+         super().__init__()
+         patch_size = _triple(patchsize)
+         n_patches = (img_size[0] // reduce_scale // patch_size[0]) * \
+                     (img_size[1] // reduce_scale // patch_size[1]) * \
+                     (img_size[2] // reduce_scale // patch_size[2])
+
+         self.patch_embeddings = nn.Conv3d(in_channels=in_channels,
+                                           out_channels=in_channels,
+                                           kernel_size=patch_size,
+                                           stride=patch_size)
+         self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))
+         self.dropout = Dropout(0.1)
+
+     def forward(self, x):
+         if x is None:
+             return None
+         x = self.patch_embeddings(x)  # (B, hidden, h, w, d)
+         h, w, d = x.shape[-3:]
+         x = x.flatten(2)
+         x = x.transpose(-1, -2)  # (B, n_patches, hidden)
+         embeddings = x + self.position_embeddings
+         embeddings = self.dropout(embeddings)
+         return embeddings, (h, w, d)
+
+ class Reconstruct(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
+         super(Reconstruct, self).__init__()
+         if kernel_size == 3:
+             padding = 1
+         else:
+             padding = 0
+         self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
+         self.norm = nn.BatchNorm3d(out_channels)
+         self.activation = nn.ReLU(inplace=True)
+         self.scale_factor = scale_factor
+
+     def forward(self, x, shp):
+         if x is None:
+             return None
+
+         B, n_patch, hidden = x.size()  # reshape from (B, n_patch, hidden) to (B, hidden, h, w, d)
+         h, w, d = shp
+         x = x.permute(0, 2, 1)
+         x = x.contiguous().view(B, hidden, h, w, d)
+         x = nn.Upsample(scale_factor=self.scale_factor)(x)
+
+         out = self.conv(x)
+         out = self.norm(out)
+         out = self.activation(out)
+         return out
+
+ class Attention_org(nn.Module):
+     def __init__(self, KV_size, channel_num, num_heads, attention_dropout_rate):
+         super(Attention_org, self).__init__()
+         self.KV_size = KV_size
+         self.channel_num = channel_num
+         self.num_attention_heads = num_heads
+
+         self.query1 = nn.ModuleList()
+         self.query2 = nn.ModuleList()
+         self.query3 = nn.ModuleList()
+         self.query4 = nn.ModuleList()
+         self.key = nn.ModuleList()
+         self.value = nn.ModuleList()
+
+         for _ in range(num_heads):
+             query1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
+             query2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
+             query3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
+             query4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
+             key = nn.Linear(self.KV_size, self.KV_size, bias=False)
+             value = nn.Linear(self.KV_size, self.KV_size, bias=False)
+             self.query1.append(copy.deepcopy(query1))
+             self.query2.append(copy.deepcopy(query2))
+             self.query3.append(copy.deepcopy(query3))
+             self.query4.append(copy.deepcopy(query4))
+             self.key.append(copy.deepcopy(key))
+             self.value.append(copy.deepcopy(value))
+         self.psi = nn.InstanceNorm2d(self.num_attention_heads)
+         self.softmax = Softmax(dim=3)
+         self.out1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
+         self.out2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
+         self.out3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
+         self.out4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
+         self.attn_dropout = Dropout(attention_dropout_rate)
+         self.proj_dropout = Dropout(attention_dropout_rate)
+
+     def forward(self, emb1, emb2, emb3, emb4, emb_all):
+         multi_head_Q1_list = []
+         multi_head_Q2_list = []
+         multi_head_Q3_list = []
+         multi_head_Q4_list = []
+         multi_head_K_list = []
+         multi_head_V_list = []
+         if emb1 is not None:
+             for query1 in self.query1:
+                 Q1 = query1(emb1)
+                 multi_head_Q1_list.append(Q1)
+         if emb2 is not None:
+             for query2 in self.query2:
+                 Q2 = query2(emb2)
+                 multi_head_Q2_list.append(Q2)
+         if emb3 is not None:
+             for query3 in self.query3:
+                 Q3 = query3(emb3)
+                 multi_head_Q3_list.append(Q3)
+         if emb4 is not None:
+             for query4 in self.query4:
+                 Q4 = query4(emb4)
+                 multi_head_Q4_list.append(Q4)
+         for key in self.key:
+             K = key(emb_all)
+             multi_head_K_list.append(K)
+         for value in self.value:
+             V = value(emb_all)
+             multi_head_V_list.append(V)
+
+         multi_head_Q1 = torch.stack(multi_head_Q1_list, dim=1) if emb1 is not None else None
+         multi_head_Q2 = torch.stack(multi_head_Q2_list, dim=1) if emb2 is not None else None
+         multi_head_Q3 = torch.stack(multi_head_Q3_list, dim=1) if emb3 is not None else None
+         multi_head_Q4 = torch.stack(multi_head_Q4_list, dim=1) if emb4 is not None else None
+         multi_head_K = torch.stack(multi_head_K_list, dim=1)
+         multi_head_V = torch.stack(multi_head_V_list, dim=1)
+
+         multi_head_Q1 = multi_head_Q1.transpose(-1, -2) if emb1 is not None else None
+         multi_head_Q2 = multi_head_Q2.transpose(-1, -2) if emb2 is not None else None
+         multi_head_Q3 = multi_head_Q3.transpose(-1, -2) if emb3 is not None else None
+         multi_head_Q4 = multi_head_Q4.transpose(-1, -2) if emb4 is not None else None
+
+         attention_scores1 = torch.matmul(multi_head_Q1, multi_head_K) if emb1 is not None else None
+         attention_scores2 = torch.matmul(multi_head_Q2, multi_head_K) if emb2 is not None else None
+         attention_scores3 = torch.matmul(multi_head_Q3, multi_head_K) if emb3 is not None else None
+         attention_scores4 = torch.matmul(multi_head_Q4, multi_head_K) if emb4 is not None else None
+
+         attention_scores1 = attention_scores1 / math.sqrt(self.KV_size) if emb1 is not None else None
+         attention_scores2 = attention_scores2 / math.sqrt(self.KV_size) if emb2 is not None else None
+         attention_scores3 = attention_scores3 / math.sqrt(self.KV_size) if emb3 is not None else None
+         attention_scores4 = attention_scores4 / math.sqrt(self.KV_size) if emb4 is not None else None
+
+         attention_probs1 = self.softmax(self.psi(attention_scores1)) if emb1 is not None else None
+         attention_probs2 = self.softmax(self.psi(attention_scores2)) if emb2 is not None else None
+         attention_probs3 = self.softmax(self.psi(attention_scores3)) if emb3 is not None else None
+         attention_probs4 = self.softmax(self.psi(attention_scores4)) if emb4 is not None else None
+
+         attention_probs1 = self.attn_dropout(attention_probs1) if emb1 is not None else None
+         attention_probs2 = self.attn_dropout(attention_probs2) if emb2 is not None else None
+         attention_probs3 = self.attn_dropout(attention_probs3) if emb3 is not None else None
+         attention_probs4 = self.attn_dropout(attention_probs4) if emb4 is not None else None
+
+         multi_head_V = multi_head_V.transpose(-1, -2)
+         context_layer1 = torch.matmul(attention_probs1, multi_head_V) if emb1 is not None else None
+         context_layer2 = torch.matmul(attention_probs2, multi_head_V) if emb2 is not None else None
+         context_layer3 = torch.matmul(attention_probs3, multi_head_V) if emb3 is not None else None
+         context_layer4 = torch.matmul(attention_probs4, multi_head_V) if emb4 is not None else None
+
+         context_layer1 = context_layer1.permute(0, 3, 2, 1).contiguous() if emb1 is not None else None
+         context_layer2 = context_layer2.permute(0, 3, 2, 1).contiguous() if emb2 is not None else None
+         context_layer3 = context_layer3.permute(0, 3, 2, 1).contiguous() if emb3 is not None else None
+         context_layer4 = context_layer4.permute(0, 3, 2, 1).contiguous() if emb4 is not None else None
+         # average over the head dimension
+         context_layer1 = context_layer1.mean(dim=3) if emb1 is not None else None
+         context_layer2 = context_layer2.mean(dim=3) if emb2 is not None else None
+         context_layer3 = context_layer3.mean(dim=3) if emb3 is not None else None
+         context_layer4 = context_layer4.mean(dim=3) if emb4 is not None else None
+
+         O1 = self.out1(context_layer1) if emb1 is not None else None
+         O2 = self.out2(context_layer2) if emb2 is not None else None
+         O3 = self.out3(context_layer3) if emb3 is not None else None
+         O4 = self.out4(context_layer4) if emb4 is not None else None
+         O1 = self.proj_dropout(O1) if emb1 is not None else None
+         O2 = self.proj_dropout(O2) if emb2 is not None else None
+         O3 = self.proj_dropout(O3) if emb3 is not None else None
+         O4 = self.proj_dropout(O4) if emb4 is not None else None
+         return O1, O2, O3, O4
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_channel, mlp_channel, dropout_rate):
+         super(Mlp, self).__init__()
+         self.fc1 = nn.Linear(in_channel, mlp_channel)
+         self.fc2 = nn.Linear(mlp_channel, in_channel)
+         self.act_fn = nn.GELU()
+         self.dropout = Dropout(dropout_rate)
+         self._init_weights()
+
+     def _init_weights(self):
+         nn.init.xavier_uniform_(self.fc1.weight)
+         nn.init.xavier_uniform_(self.fc2.weight)
+         nn.init.normal_(self.fc1.bias, std=1e-6)
+         nn.init.normal_(self.fc2.bias, std=1e-6)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act_fn(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x
+
+ class Block_ViT(nn.Module):
+     def __init__(self, KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate):
+         super(Block_ViT, self).__init__()
+         self.attn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
+         self.attn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
+         self.attn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
+         self.attn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
+         self.attn_norm = LayerNorm(KV_size, eps=1e-6)
+         self.channel_attn = Attention_org(KV_size, channel_num, num_heads, attention_dropout_rate)
+
+         self.ffn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
+         self.ffn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
+         self.ffn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
+         self.ffn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
+         self.ffn1 = Mlp(channel_num[0], channel_num[0]*4, mlp_dropout_rate)
+         self.ffn2 = Mlp(channel_num[1], channel_num[1]*4, mlp_dropout_rate)
+         self.ffn3 = Mlp(channel_num[2], channel_num[2]*4, mlp_dropout_rate)
+         self.ffn4 = Mlp(channel_num[3], channel_num[3]*4, mlp_dropout_rate)
+
+     def forward(self, emb1, emb2, emb3, emb4):
+         org1 = emb1
+         org2 = emb2
+         org3 = emb3
+         org4 = emb4
+         # concatenate whichever embeddings are present along the channel dimension
+         embcat = [emb for emb in (emb1, emb2, emb3, emb4) if emb is not None]
+
+         emb_all = torch.cat(embcat, dim=2)
+         cx1 = self.attn_norm1(emb1) if emb1 is not None else None
+         cx2 = self.attn_norm2(emb2) if emb2 is not None else None
+         cx3 = self.attn_norm3(emb3) if emb3 is not None else None
+         cx4 = self.attn_norm4(emb4) if emb4 is not None else None
+         emb_all = self.attn_norm(emb_all)
+         cx1, cx2, cx3, cx4 = self.channel_attn(cx1, cx2, cx3, cx4, emb_all)
+         cx1 = org1 + cx1 if emb1 is not None else None
+         cx2 = org2 + cx2 if emb2 is not None else None
+         cx3 = org3 + cx3 if emb3 is not None else None
+         cx4 = org4 + cx4 if emb4 is not None else None
+
+         org1 = cx1
+         org2 = cx2
+         org3 = cx3
+         org4 = cx4
+         x1 = self.ffn_norm1(cx1) if emb1 is not None else None
+         x2 = self.ffn_norm2(cx2) if emb2 is not None else None
+         x3 = self.ffn_norm3(cx3) if emb3 is not None else None
+         x4 = self.ffn_norm4(cx4) if emb4 is not None else None
+         x1 = self.ffn1(x1) if emb1 is not None else None
+         x2 = self.ffn2(x2) if emb2 is not None else None
+         x3 = self.ffn3(x3) if emb3 is not None else None
+         x4 = self.ffn4(x4) if emb4 is not None else None
+         x1 = x1 + org1 if emb1 is not None else None
+         x2 = x2 + org2 if emb2 is not None else None
+         x3 = x3 + org3 if emb3 is not None else None
+         x4 = x4 + org4 if emb4 is not None else None
+
+         return x1, x2, x3, x4
+
+
+ class Encoder(nn.Module):
+     def __init__(self, num_layers, KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate):
+         super(Encoder, self).__init__()
+         self.layer = nn.ModuleList()
+         self.encoder_norm1 = LayerNorm(channel_num[0], eps=1e-6)
+         self.encoder_norm2 = LayerNorm(channel_num[1], eps=1e-6)
+         self.encoder_norm3 = LayerNorm(channel_num[2], eps=1e-6)
+         self.encoder_norm4 = LayerNorm(channel_num[3], eps=1e-6)
+         for _ in range(num_layers):
+             layer = Block_ViT(KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate)
+             self.layer.append(copy.deepcopy(layer))
+
+     def forward(self, emb1, emb2, emb3, emb4):
+         for layer_block in self.layer:
+             emb1, emb2, emb3, emb4 = layer_block(emb1, emb2, emb3, emb4)
+         emb1 = self.encoder_norm1(emb1) if emb1 is not None else None
+         emb2 = self.encoder_norm2(emb2) if emb2 is not None else None
+         emb3 = self.encoder_norm3(emb3) if emb3 is not None else None
+         emb4 = self.encoder_norm4(emb4) if emb4 is not None else None
+         return emb1, emb2, emb3, emb4
+
+
+ class ChannelTransformer(nn.Module):
+     def __init__(self, img_size, num_layers, KV_size, num_heads, attention_dropout_rate, mlp_dropout_rate,
+                  channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):
+         super().__init__()
+
+         self.patchSize_1 = patchSize[0]
+         self.patchSize_2 = patchSize[1]
+         self.patchSize_3 = patchSize[2]
+         self.patchSize_4 = patchSize[3]
+         self.embeddings_1 = Channel_Embeddings(self.patchSize_1, img_size=img_size, reduce_scale=1, in_channels=channel_num[0])
+         self.embeddings_2 = Channel_Embeddings(self.patchSize_2, img_size=img_size, reduce_scale=2, in_channels=channel_num[1])
+         self.embeddings_3 = Channel_Embeddings(self.patchSize_3, img_size=img_size, reduce_scale=4, in_channels=channel_num[2])
+         self.embeddings_4 = Channel_Embeddings(self.patchSize_4, img_size=img_size, reduce_scale=8, in_channels=channel_num[3])
+         self.encoder = Encoder(num_layers, KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate)
+
+         self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1, scale_factor=_triple(self.patchSize_1))
+         self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1, scale_factor=_triple(self.patchSize_2))
+         self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1, scale_factor=_triple(self.patchSize_3))
+         self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1, scale_factor=_triple(self.patchSize_4))
+
+     def forward(self, en1, en2, en3, en4):
+         emb1, shp1 = self.embeddings_1(en1)
+         emb2, shp2 = self.embeddings_2(en2)
+         emb3, shp3 = self.embeddings_3(en3)
+         emb4, shp4 = self.embeddings_4(en4)
+
+         encoded1, encoded2, encoded3, encoded4 = self.encoder(emb1, emb2, emb3, emb4)  # (B, n_patch, hidden)
+         x1 = self.reconstruct_1(encoded1, shp1) if en1 is not None else None
+         x2 = self.reconstruct_2(encoded2, shp2) if en2 is not None else None
+         x3 = self.reconstruct_3(encoded3, shp3) if en3 is not None else None
+         x4 = self.reconstruct_4(encoded4, shp4) if en4 is not None else None
+
+         x1 = x1 + en1 if en1 is not None else None
+         x2 = x2 + en2 if en2 is not None else None
+         x3 = x3 + en3 if en3 is not None else None
+         x4 = x4 + en4 if en4 is not None else None
+
+         return x1, x2, x3, x4
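The file above only defines the architecture; the commit ships no example of how to call it. Below is a minimal smoke-test sketch. Every hyperparameter is an illustrative assumption chosen so the shapes line up (KV_size must equal the sum of the four channel counts, and img_size must be divisible by reduce_scale times the patch size at every level so all levels yield the same patch grid), not a setting recorded in this commit.

import torch
from UCTransNet import UCTransNet

# All values here are assumptions for illustration, not from the commit.
model = UCTransNet(
    in_channels=1,              # e.g. a single-modality 3D volume
    out_channels=2,             # e.g. background/foreground logits
    num_layers=4,
    KV_size=960,                # 64 + 128 + 256 + 512 with feature_size=64
    num_heads=4,
    attention_dropout_rate=0.1,
    mlp_dropout_rate=0.1,
    feature_size=64,
    img_size=(64, 64, 64),
    patch_sizes=[8, 4, 2, 1],   # every level then gives the same 8x8x8 patch grid
)
model.eval()

x = torch.randn(1, 1, 64, 64, 64)   # (B, C, H, W, D)
with torch.no_grad():
    logits = model(x)
print(logits.shape)                  # torch.Size([1, 2, 64, 64, 64])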
best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae3a4051a40de52f51db628ff7737501ecf043bfc11a8931829a9885c559766a
+ size 816404132
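best_model.pth is stored as a Git LFS pointer, so the actual ~816 MB weights are only fetched by an LFS-aware clone (e.g. `git lfs pull`). A minimal loading sketch, assuming the file holds a plain state_dict for the UCTransNet built above; if the commit saved a wrapped checkpoint or an entire pickled module instead, the unwrapping step would differ:

import torch

state = torch.load("best_model.pth", map_location="cpu")
if isinstance(state, dict) and "state_dict" in state:
    state = state["state_dict"]   # unwrap a common checkpoint layout (an assumption)
model.load_state_dict(state)      # `model` constructed as in the sketch above
model.eval()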