Upload 2 files
- UCTransNet.py +475 -0
- best_model.pth +3 -0
UCTransNet.py
ADDED
@@ -0,0 +1,475 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/17 11:06
# @Author : Haonan Wang
# @File : UCTransNet.py
# @Software: PyCharm

import torch.nn.functional as F
import copy
import math
import torch
import torch.nn as nn
import numpy as np
from torch.nn import Dropout, Softmax, LayerNorm
from torch.nn.modules.utils import _pair, _triple


def get_activation(activation_type):
    activation_type = activation_type.lower()
    if hasattr(nn, activation_type):
        return getattr(nn, activation_type)()
    else:
        return nn.ReLU()


def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
    layers = []
    layers.append(ConvBatchNorm(in_channels, out_channels, activation))
    for _ in range(nb_Conv - 1):
        layers.append(ConvBatchNorm(out_channels, out_channels, activation))
    return nn.Sequential(*layers)


class ConvBatchNorm(nn.Module):
    """(convolution => [BN] => ReLU)"""

    def __init__(self, in_channels, out_channels, activation='ReLU'):
        super(ConvBatchNorm, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels,
                              kernel_size=3, padding=1)
        self.norm = nn.BatchNorm3d(out_channels)
        self.activation = get_activation(activation)

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        return self.activation(out)


class DownBlock(nn.Module):
    """Downscaling with maxpool convolution"""

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool3d(2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x):
        out = self.maxpool(x)
        return self.nConvs(out)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class CCA(nn.Module):
    """
    CCA Block
    """
    def __init__(self, F_g, F_x):
        super().__init__()
        self.mlp_x = nn.Sequential(
            Flatten(),
            nn.Linear(F_x, F_x))
        self.mlp_g = nn.Sequential(
            Flatten(),
            nn.Linear(F_g, F_x))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # channel-wise attention
        avg_pool_x = F.avg_pool3d(x, (x.size(2), x.size(3), x.size(4)),
                                  stride=(x.size(2), x.size(3), x.size(4)))
        channel_att_x = self.mlp_x(avg_pool_x)
        avg_pool_g = F.avg_pool3d(g, (g.size(2), g.size(3), g.size(4)),
                                  stride=(g.size(2), g.size(3), g.size(4)))
        channel_att_g = self.mlp_g(avg_pool_g)
        channel_att_sum = (channel_att_x + channel_att_g) / 2.0
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).unsqueeze(4).expand_as(x)
        x_after_channel = x * scale
        out = self.relu(x_after_channel)
        return out


class UpBlock_attention(nn.Module):
    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.coatt = CCA(F_g=in_channels//2, F_x=in_channels//2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x, skip_x):
        up = self.up(x)
        skip_x_att = self.coatt(g=up, x=skip_x)
        x = torch.cat([skip_x_att, up], dim=1)  # dim 1 is the channel dimension
        return self.nConvs(x)


class UCTransNet(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers, KV_size, num_heads,
                 attention_dropout_rate, mlp_dropout_rate, feature_size, img_size, patch_sizes):
        super().__init__()
        self.inc = ConvBatchNorm(in_channels, feature_size)
        self.down1 = DownBlock(feature_size, feature_size*2, nb_Conv=2)
        self.down2 = DownBlock(feature_size*2, feature_size*4, nb_Conv=2)
        self.down3 = DownBlock(feature_size*4, feature_size*8, nb_Conv=2)
        self.down4 = DownBlock(feature_size*8, feature_size*8, nb_Conv=2)
        self.mtc = ChannelTransformer(img_size, num_layers, KV_size, num_heads,
                                      attention_dropout_rate, mlp_dropout_rate,
                                      channel_num=[feature_size, feature_size*2, feature_size*4, feature_size*8],
                                      patchSize=patch_sizes)
        self.up4 = UpBlock_attention(feature_size*16, feature_size*4, nb_Conv=2)
        self.up3 = UpBlock_attention(feature_size*8, feature_size*2, nb_Conv=2)
        self.up2 = UpBlock_attention(feature_size*4, feature_size, nb_Conv=2)
        self.up1 = UpBlock_attention(feature_size*2, feature_size, nb_Conv=2)
        self.outc = nn.Conv3d(feature_size, out_channels, kernel_size=(1, 1, 1), stride=(1, 1, 1))

    def forward(self, x):
        x = x.float()
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x1, x2, x3, x4 = self.mtc(x1, x2, x3, x4)
        x = self.up4(x5, x4)
        x = self.up3(x, x3)
        x = self.up2(x, x2)
        x = self.up1(x, x1)

        logits = self.outc(x)  # if using BCEWithLogitsLoss or class > 1

        return logits

class Channel_Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings."""

    def __init__(self, patchsize, img_size, in_channels, reduce_scale):
        super().__init__()
        patch_size = _triple(patchsize)
        n_patches = (img_size[0] // reduce_scale // patch_size[0]) * \
                    (img_size[1] // reduce_scale // patch_size[1]) * \
                    (img_size[2] // reduce_scale // patch_size[2])

        self.patch_embeddings = nn.Conv3d(in_channels=in_channels,
                                          out_channels=in_channels,
                                          kernel_size=patch_size,
                                          stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))
        self.dropout = Dropout(0.1)

    def forward(self, x):
        if x is None:
            return None
        x = self.patch_embeddings(x)  # (B, C, D', H', W') after the strided patch conv
        h, w, d = x.shape[-3:]
        x = x.flatten(2)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, (h, w, d)


class Reconstruct(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super(Reconstruct, self).__init__()
        if kernel_size == 3:
            padding = 1
        else:
            padding = 0
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.BatchNorm3d(out_channels)
        self.activation = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor

    def forward(self, x, shp):
        if x is None:
            return None

        B, n_patch, hidden = x.size()  # reshape from (B, n_patch, hidden) to (B, hidden, h, w, d)
        h, w, d = shp
        x = x.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w, d)
        x = nn.Upsample(scale_factor=self.scale_factor)(x)

        out = self.conv(x)
        out = self.norm(out)
        out = self.activation(out)
        return out


class Attention_org(nn.Module):
    def __init__(self, KV_size, channel_num, num_heads, attention_dropout_rate):
        super(Attention_org, self).__init__()
        self.KV_size = KV_size
        self.channel_num = channel_num
        self.num_attention_heads = num_heads

        self.query1 = nn.ModuleList()
        self.query2 = nn.ModuleList()
        self.query3 = nn.ModuleList()
        self.query4 = nn.ModuleList()
        self.key = nn.ModuleList()
        self.value = nn.ModuleList()

        for _ in range(num_heads):
            query1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
            query2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
            query3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
            query4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
            key = nn.Linear(self.KV_size, self.KV_size, bias=False)
            value = nn.Linear(self.KV_size, self.KV_size, bias=False)
            self.query1.append(copy.deepcopy(query1))
            self.query2.append(copy.deepcopy(query2))
            self.query3.append(copy.deepcopy(query3))
            self.query4.append(copy.deepcopy(query4))
            self.key.append(copy.deepcopy(key))
            self.value.append(copy.deepcopy(value))
        self.psi = nn.InstanceNorm2d(self.num_attention_heads)
        self.softmax = Softmax(dim=3)
        self.out1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
        self.out2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
        self.out3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
        self.out4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
        self.attn_dropout = Dropout(attention_dropout_rate)
        self.proj_dropout = Dropout(attention_dropout_rate)

    def forward(self, emb1, emb2, emb3, emb4, emb_all):
        multi_head_Q1_list = []
        multi_head_Q2_list = []
        multi_head_Q3_list = []
        multi_head_Q4_list = []
        multi_head_K_list = []
        multi_head_V_list = []
        if emb1 is not None:
            for query1 in self.query1:
                Q1 = query1(emb1)
                multi_head_Q1_list.append(Q1)
        if emb2 is not None:
            for query2 in self.query2:
                Q2 = query2(emb2)
                multi_head_Q2_list.append(Q2)
        if emb3 is not None:
            for query3 in self.query3:
                Q3 = query3(emb3)
                multi_head_Q3_list.append(Q3)
        if emb4 is not None:
            for query4 in self.query4:
                Q4 = query4(emb4)
                multi_head_Q4_list.append(Q4)
        for key in self.key:
            K = key(emb_all)
            multi_head_K_list.append(K)
        for value in self.value:
            V = value(emb_all)
            multi_head_V_list.append(V)
        # print(len(multi_head_Q4_list))

        multi_head_Q1 = torch.stack(multi_head_Q1_list, dim=1) if emb1 is not None else None
        multi_head_Q2 = torch.stack(multi_head_Q2_list, dim=1) if emb2 is not None else None
        multi_head_Q3 = torch.stack(multi_head_Q3_list, dim=1) if emb3 is not None else None
        multi_head_Q4 = torch.stack(multi_head_Q4_list, dim=1) if emb4 is not None else None
        multi_head_K = torch.stack(multi_head_K_list, dim=1)
        multi_head_V = torch.stack(multi_head_V_list, dim=1)

        multi_head_Q1 = multi_head_Q1.transpose(-1, -2) if emb1 is not None else None
        multi_head_Q2 = multi_head_Q2.transpose(-1, -2) if emb2 is not None else None
        multi_head_Q3 = multi_head_Q3.transpose(-1, -2) if emb3 is not None else None
        multi_head_Q4 = multi_head_Q4.transpose(-1, -2) if emb4 is not None else None

        attention_scores1 = torch.matmul(multi_head_Q1, multi_head_K) if emb1 is not None else None
        attention_scores2 = torch.matmul(multi_head_Q2, multi_head_K) if emb2 is not None else None
        attention_scores3 = torch.matmul(multi_head_Q3, multi_head_K) if emb3 is not None else None
        attention_scores4 = torch.matmul(multi_head_Q4, multi_head_K) if emb4 is not None else None

        attention_scores1 = attention_scores1 / math.sqrt(self.KV_size) if emb1 is not None else None
        attention_scores2 = attention_scores2 / math.sqrt(self.KV_size) if emb2 is not None else None
        attention_scores3 = attention_scores3 / math.sqrt(self.KV_size) if emb3 is not None else None
        attention_scores4 = attention_scores4 / math.sqrt(self.KV_size) if emb4 is not None else None

        attention_probs1 = self.softmax(self.psi(attention_scores1)) if emb1 is not None else None
        attention_probs2 = self.softmax(self.psi(attention_scores2)) if emb2 is not None else None
        attention_probs3 = self.softmax(self.psi(attention_scores3)) if emb3 is not None else None
        attention_probs4 = self.softmax(self.psi(attention_scores4)) if emb4 is not None else None
        # print(attention_probs4.size())

        attention_probs1 = self.attn_dropout(attention_probs1) if emb1 is not None else None
        attention_probs2 = self.attn_dropout(attention_probs2) if emb2 is not None else None
        attention_probs3 = self.attn_dropout(attention_probs3) if emb3 is not None else None
        attention_probs4 = self.attn_dropout(attention_probs4) if emb4 is not None else None

        multi_head_V = multi_head_V.transpose(-1, -2)
        context_layer1 = torch.matmul(attention_probs1, multi_head_V) if emb1 is not None else None
        context_layer2 = torch.matmul(attention_probs2, multi_head_V) if emb2 is not None else None
        context_layer3 = torch.matmul(attention_probs3, multi_head_V) if emb3 is not None else None
        context_layer4 = torch.matmul(attention_probs4, multi_head_V) if emb4 is not None else None

        context_layer1 = context_layer1.permute(0, 3, 2, 1).contiguous() if emb1 is not None else None
        context_layer2 = context_layer2.permute(0, 3, 2, 1).contiguous() if emb2 is not None else None
        context_layer3 = context_layer3.permute(0, 3, 2, 1).contiguous() if emb3 is not None else None
        context_layer4 = context_layer4.permute(0, 3, 2, 1).contiguous() if emb4 is not None else None
        context_layer1 = context_layer1.mean(dim=3) if emb1 is not None else None
        context_layer2 = context_layer2.mean(dim=3) if emb2 is not None else None
        context_layer3 = context_layer3.mean(dim=3) if emb3 is not None else None
        context_layer4 = context_layer4.mean(dim=3) if emb4 is not None else None

        O1 = self.out1(context_layer1) if emb1 is not None else None
        O2 = self.out2(context_layer2) if emb2 is not None else None
        O3 = self.out3(context_layer3) if emb3 is not None else None
        O4 = self.out4(context_layer4) if emb4 is not None else None
        O1 = self.proj_dropout(O1) if emb1 is not None else None
        O2 = self.proj_dropout(O2) if emb2 is not None else None
        O3 = self.proj_dropout(O3) if emb3 is not None else None
        O4 = self.proj_dropout(O4) if emb4 is not None else None
        return O1, O2, O3, O4


class Mlp(nn.Module):
    def __init__(self, in_channel, mlp_channel, dropout_rate):
        super(Mlp, self).__init__()
        self.fc1 = nn.Linear(in_channel, mlp_channel)
        self.fc2 = nn.Linear(mlp_channel, in_channel)
        self.act_fn = nn.GELU()
        self.dropout = Dropout(dropout_rate)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Block_ViT(nn.Module):
    def __init__(self, KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate):
        super(Block_ViT, self).__init__()
        self.attn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.attn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.attn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.attn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        self.attn_norm = LayerNorm(KV_size, eps=1e-6)
        self.channel_attn = Attention_org(KV_size, channel_num, num_heads, attention_dropout_rate)

        self.ffn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.ffn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.ffn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.ffn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        self.ffn1 = Mlp(channel_num[0], channel_num[0]*4, mlp_dropout_rate)
        self.ffn2 = Mlp(channel_num[1], channel_num[1]*4, mlp_dropout_rate)
        self.ffn3 = Mlp(channel_num[2], channel_num[2]*4, mlp_dropout_rate)
        self.ffn4 = Mlp(channel_num[3], channel_num[3]*4, mlp_dropout_rate)

    def forward(self, emb1, emb2, emb3, emb4):
        embcat = []
        org1 = emb1
        org2 = emb2
        org3 = emb3
        org4 = emb4
        for i in range(4):
            var_name = "emb" + str(i + 1)
            tmp_var = locals()[var_name]
            if tmp_var is not None:
                embcat.append(tmp_var)

        emb_all = torch.cat(embcat, dim=2)
        cx1 = self.attn_norm1(emb1) if emb1 is not None else None
        cx2 = self.attn_norm2(emb2) if emb2 is not None else None
        cx3 = self.attn_norm3(emb3) if emb3 is not None else None
        cx4 = self.attn_norm4(emb4) if emb4 is not None else None
        emb_all = self.attn_norm(emb_all)
        cx1, cx2, cx3, cx4 = self.channel_attn(cx1, cx2, cx3, cx4, emb_all)
        cx1 = org1 + cx1 if emb1 is not None else None
        cx2 = org2 + cx2 if emb2 is not None else None
        cx3 = org3 + cx3 if emb3 is not None else None
        cx4 = org4 + cx4 if emb4 is not None else None

        org1 = cx1
        org2 = cx2
        org3 = cx3
        org4 = cx4
        x1 = self.ffn_norm1(cx1) if emb1 is not None else None
        x2 = self.ffn_norm2(cx2) if emb2 is not None else None
        x3 = self.ffn_norm3(cx3) if emb3 is not None else None
        x4 = self.ffn_norm4(cx4) if emb4 is not None else None
        x1 = self.ffn1(x1) if emb1 is not None else None
        x2 = self.ffn2(x2) if emb2 is not None else None
        x3 = self.ffn3(x3) if emb3 is not None else None
        x4 = self.ffn4(x4) if emb4 is not None else None
        x1 = x1 + org1 if emb1 is not None else None
        x2 = x2 + org2 if emb2 is not None else None
        x3 = x3 + org3 if emb3 is not None else None
        x4 = x4 + org4 if emb4 is not None else None

        return x1, x2, x3, x4


class Encoder(nn.Module):
    def __init__(self, num_layers, KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate):
        super(Encoder, self).__init__()
        self.layer = nn.ModuleList()
        self.encoder_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.encoder_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.encoder_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.encoder_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        for _ in range(num_layers):
            layer = Block_ViT(KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, emb1, emb2, emb3, emb4):
        for layer_block in self.layer:
            emb1, emb2, emb3, emb4 = layer_block(emb1, emb2, emb3, emb4)
        emb1 = self.encoder_norm1(emb1) if emb1 is not None else None
        emb2 = self.encoder_norm2(emb2) if emb2 is not None else None
        emb3 = self.encoder_norm3(emb3) if emb3 is not None else None
        emb4 = self.encoder_norm4(emb4) if emb4 is not None else None
        return emb1, emb2, emb3, emb4


class ChannelTransformer(nn.Module):
    def __init__(self, img_size, num_layers, KV_size, num_heads, attention_dropout_rate, mlp_dropout_rate,
                 channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):
        super().__init__()

        self.patchSize_1 = patchSize[0]
        self.patchSize_2 = patchSize[1]
        self.patchSize_3 = patchSize[2]
        self.patchSize_4 = patchSize[3]
        self.embeddings_1 = Channel_Embeddings(self.patchSize_1, img_size=img_size, reduce_scale=1, in_channels=channel_num[0])
        self.embeddings_2 = Channel_Embeddings(self.patchSize_2, img_size=img_size, reduce_scale=2, in_channels=channel_num[1])
        self.embeddings_3 = Channel_Embeddings(self.patchSize_3, img_size=img_size, reduce_scale=4, in_channels=channel_num[2])
        self.embeddings_4 = Channel_Embeddings(self.patchSize_4, img_size=img_size, reduce_scale=8, in_channels=channel_num[3])
        self.encoder = Encoder(num_layers, KV_size, channel_num, num_heads, attention_dropout_rate, mlp_dropout_rate)

        self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1, scale_factor=_triple(self.patchSize_1))
        self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1, scale_factor=_triple(self.patchSize_2))
        self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1, scale_factor=_triple(self.patchSize_3))
        self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1, scale_factor=_triple(self.patchSize_4))

    def forward(self, en1, en2, en3, en4):

        emb1, shp1 = self.embeddings_1(en1)
        emb2, shp2 = self.embeddings_2(en2)
        emb3, shp3 = self.embeddings_3(en3)
        emb4, shp4 = self.embeddings_4(en4)

        encoded1, encoded2, encoded3, encoded4 = self.encoder(emb1, emb2, emb3, emb4)  # (B, n_patch, hidden)
        x1 = self.reconstruct_1(encoded1, shp1) if en1 is not None else None
        x2 = self.reconstruct_2(encoded2, shp2) if en2 is not None else None
        x3 = self.reconstruct_3(encoded3, shp3) if en3 is not None else None
        x4 = self.reconstruct_4(encoded4, shp4) if en4 is not None else None

        x1 = x1 + en1 if en1 is not None else None
        x2 = x2 + en2 if en2 is not None else None
        x3 = x3 + en3 if en3 is not None else None
        x4 = x4 + en4 if en4 is not None else None

        return x1, x2, x3, x4
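
Note: below is a minimal shape-check sketch for the 3D model above, with hypothetical hyperparameters (they are illustrative, not the settings used to train best_model.pth). Two constraints must hold for the forward pass to work: KV_size has to equal the sum of the four encoder channel counts, i.e. feature_size * (1 + 2 + 4 + 8), and img_size together with patch_sizes must give every scale the same token count (img_size // 2**i // patch_sizes[i] identical for i = 0..3), since the four embeddings are concatenated along the channel axis.

# smoke_test.py -- illustrative only; assumes UCTransNet.py is on the import path
import torch
from UCTransNet import UCTransNet

feature_size = 16
model = UCTransNet(
    in_channels=1,
    out_channels=2,
    num_layers=4,
    KV_size=feature_size * 15,       # 16 + 32 + 64 + 128
    num_heads=4,
    attention_dropout_rate=0.1,
    mlp_dropout_rate=0.1,
    feature_size=feature_size,
    img_size=(64, 64, 64),
    patch_sizes=[8, 4, 2, 1],        # 64/1/8 = 64/2/4 = 64/4/2 = 64/8/1 = 8 tokens per axis
)
model.eval()

x = torch.randn(1, 1, 64, 64, 64)    # (B, C, D, H, W) dummy volume
with torch.no_grad():
    logits = model(x)
print(logits.shape)                  # torch.Size([1, 2, 64, 64, 64])
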
best_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ae3a4051a40de52f51db628ff7737501ecf043bfc11a8931829a9885c559766a
size 816404132
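
Note: best_model.pth is stored via Git LFS (the pointer above), so the real ~816 MB file must be fetched with `git lfs pull` before it can be loaded. The checkpoint layout is not documented in this upload; the sketch below assumes it is either a bare state_dict or a dict that wraps one under a "state_dict" key, which are the two most common torch.save conventions.

import torch

ckpt = torch.load("best_model.pth", map_location="cpu")
state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt

# To restore the weights, rebuild UCTransNet with the exact hyperparameters
# used during training (they are not recorded in this upload), then:
# model.load_state_dict(state_dict)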