Spaces:
Sleeping
Sleeping
amazinghaha
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -4,8 +4,11 @@ import os
|
|
4 |
import numpy as np
|
5 |
import SimpleITK as sitk
|
6 |
from scipy.ndimage import zoom
|
7 |
-
from resnet_gn import resnet50
|
8 |
import pickle
|
|
|
|
|
|
|
|
|
9 |
#import tempfile
|
10 |
|
11 |
def load_from_pkl(load_path):
|
@@ -14,18 +17,190 @@ def load_from_pkl(load_path):
|
|
14 |
data_input.close()
|
15 |
return read_data
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
Image_3D = None
|
19 |
Current_name = None
|
20 |
-
ALL_message = load_from_pkl(r'./label0601.pkl')
|
21 |
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
checkpoint = torch.load(Model_Paht, map_location='cpu')
|
24 |
|
25 |
-
classnet =
|
26 |
-
num_classes=1,
|
27 |
-
sample_size=128,
|
28 |
-
sample_duration=8)
|
29 |
classnet.load_state_dict(checkpoint['model_dict'])
|
30 |
|
31 |
|
@@ -48,34 +223,54 @@ def resize3D(img, aimsize, order=3):
|
|
48 |
order=order) # resample for cube_size
|
49 |
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def inference():
|
52 |
global Image_small_3D
|
|
|
53 |
model = classnet
|
54 |
-
|
55 |
-
|
|
|
56 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
57 |
model.eval()
|
58 |
-
|
59 |
-
length = 0
|
60 |
try:
|
|
|
61 |
with torch.no_grad():
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
79 |
except:
|
80 |
return ' '
|
81 |
|
@@ -231,6 +426,7 @@ def arry_crop_3D(img,mask,ex_pix):
|
|
231 |
def data_pretreatment():
|
232 |
global Image_3D
|
233 |
global ROI_3D
|
|
|
234 |
global Image_small_3D
|
235 |
global Current_name
|
236 |
global Input_File
|
@@ -238,15 +434,17 @@ def data_pretreatment():
|
|
238 |
return '没有数据'
|
239 |
else:
|
240 |
roi = ROI_3D
|
241 |
-
waikuo = [4, 4, 4]
|
242 |
-
fina_img, fina_mask = arry_crop_3D(Image_3D,roi,waikuo)
|
243 |
|
244 |
cut_thre = np.percentile(fina_img, 99.9) # 直方图99.9%右侧值不要
|
245 |
fina_img[fina_img >= cut_thre] = cut_thre
|
246 |
-
|
247 |
-
fina_img = resize3D(fina_img, [
|
|
|
248 |
fina_img = (np.max(fina_img)-fina_img)/(np.max(fina_img)-np.min(fina_img))
|
249 |
Image_small_3D = fina_img
|
|
|
250 |
return '预处理结束'
|
251 |
class App:
|
252 |
def __init__(self):
|
@@ -315,16 +513,16 @@ class App:
|
|
315 |
|
316 |
gr.Markdown('''# Examples''')
|
317 |
gr.Examples(
|
318 |
-
examples=[["./
|
319 |
-
["./
|
320 |
inputs=inp,
|
321 |
outputs=[out1, out2, out3, slider1, slider2, slider3,out8],
|
322 |
fn=get_Image_reslice,
|
323 |
cache_examples=True,
|
324 |
)
|
325 |
gr.Examples(
|
326 |
-
examples=[["./
|
327 |
-
["./
|
328 |
inputs=inp2,
|
329 |
outputs=out9,
|
330 |
fn=get_ROI,
|
@@ -334,4 +532,4 @@ class App:
|
|
334 |
demo.launch(share=False)
|
335 |
|
336 |
|
337 |
-
app = App()
|
|
|
4 |
import numpy as np
|
5 |
import SimpleITK as sitk
|
6 |
from scipy.ndimage import zoom
|
|
|
7 |
import pickle
|
8 |
+
from model.Vision_Transformer_with_mask import vit_base_patch16_224,Attention,CrossAttention,Attention_ori
|
9 |
+
from model.CoordAttention import *
|
10 |
+
from typing import Tuple, Type
|
11 |
+
from torch import Tensor, nn
|
12 |
#import tempfile
|
13 |
|
14 |
def load_from_pkl(load_path):
|
|
|
17 |
data_input.close()
|
18 |
return read_data
|
19 |
|
20 |
+
class MLP_att_out(nn.Module):
|
21 |
+
|
22 |
+
def __init__(self, input_dim, inter_dim=None, output_dim=None, activation="relu", drop=0.0):
|
23 |
+
super().__init__()
|
24 |
+
self.input_dim = input_dim
|
25 |
+
self.inter_dim = inter_dim
|
26 |
+
self.output_dim = output_dim
|
27 |
+
if inter_dim is None: self.inter_dim=input_dim
|
28 |
+
if output_dim is None: self.output_dim=input_dim
|
29 |
+
|
30 |
+
self.linear1 = nn.Linear(self.input_dim, self.inter_dim)
|
31 |
+
self.activation = self._get_activation_fn(activation)
|
32 |
+
self.dropout3 = nn.Dropout(drop)
|
33 |
+
self.linear2 = nn.Linear(self.inter_dim, self.output_dim)
|
34 |
+
self.dropout4 = nn.Dropout(drop)
|
35 |
+
self.norm3 = nn.LayerNorm(self.output_dim)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
x = self.linear2(self.dropout3(self.activation(self.linear1(x))))
|
39 |
+
x = x + self.dropout4(x)
|
40 |
+
x = self.norm3(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
def _get_activation_fn(self, activation):
|
44 |
+
"""Return an activation function given a string"""
|
45 |
+
if activation == "relu":
|
46 |
+
return F.relu
|
47 |
+
if activation == "gelu":
|
48 |
+
return F.gelu
|
49 |
+
if activation == "glu":
|
50 |
+
return F.glu
|
51 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
52 |
+
|
53 |
+
class MLPBlock(nn.Module):
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
embedding_dim: int,
|
57 |
+
mlp_dim: int,
|
58 |
+
act: Type[nn.Module] = nn.GELU,
|
59 |
+
) -> None:
|
60 |
+
super().__init__()
|
61 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
62 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
63 |
+
self.act = act()
|
64 |
+
|
65 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
66 |
+
return self.lin2(self.act(self.lin1(x)))
|
67 |
+
class FusionAttentionBlock(nn.Module):
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
embedding_dim: int,
|
71 |
+
num_heads: int,
|
72 |
+
mlp_dim: int = 2048,
|
73 |
+
activation: Type[nn.Module] = nn.ReLU,
|
74 |
+
) -> None:
|
75 |
+
"""
|
76 |
+
A transformer block with four layers: (1) self-attention of sparse
|
77 |
+
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
78 |
+
block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
79 |
+
inputs.
|
80 |
+
|
81 |
+
Arguments:
|
82 |
+
embedding_dim (int): the channel dimension of the embeddings
|
83 |
+
num_heads (int): the number of heads in the attention layers
|
84 |
+
mlp_dim (int): the hidden dimension of the mlp block
|
85 |
+
activation (nn.Module): the activation of the mlp block
|
86 |
+
"""
|
87 |
+
super().__init__()
|
88 |
+
self.self_attn = Attention_ori(embedding_dim, num_heads)
|
89 |
+
self.norm1 = nn.LayerNorm(embedding_dim)
|
90 |
+
self.cross_attn_mask_to_image = CrossAttention(dim=embedding_dim, num_heads=num_heads)
|
91 |
+
self.norm2 = nn.LayerNorm(embedding_dim)
|
92 |
+
|
93 |
+
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
|
94 |
+
self.norm3 = nn.LayerNorm(embedding_dim)
|
95 |
+
|
96 |
+
self.norm4 = nn.LayerNorm(embedding_dim)
|
97 |
+
self.cross_attn_image_to_mask = CrossAttention(dim=embedding_dim, num_heads=num_heads)
|
98 |
+
|
99 |
+
|
100 |
+
def forward(self, img_emb: Tensor, mask_emb: Tensor, atten_mask: Tensor) -> Tuple[ Tensor]:
|
101 |
+
# Self attention block #最开始的时候 queries=query_pe
|
102 |
+
#queries: Tensor, keys: Tensor
|
103 |
+
queries = mask_emb
|
104 |
+
attn_out = self.self_attn(queries) #小图
|
105 |
+
queries = attn_out
|
106 |
+
#queries = queries + attn_out
|
107 |
+
queries = self.norm1(queries)
|
108 |
+
|
109 |
+
# Cross attention block, mask attending to image embedding
|
110 |
+
q = queries #1,5,256
|
111 |
+
k = img_emb # v是值,因此用keys?
|
112 |
+
input_x = torch.cat((q, k), dim=1) # 2 50 768
|
113 |
+
attn_out = self.cross_attn_mask_to_image(input_x) #TODO 要不要mask呢 交叉的时候 先不用试试
|
114 |
+
queries = queries + attn_out
|
115 |
+
queries = self.norm2(queries)
|
116 |
+
|
117 |
+
# MLP block
|
118 |
+
mlp_out = self.mlp(queries)
|
119 |
+
queries = queries + mlp_out
|
120 |
+
queries = self.norm3(queries)
|
121 |
+
|
122 |
+
# Cross attention block, image embedding attending to tokens
|
123 |
+
q = img_emb
|
124 |
+
k = queries
|
125 |
+
input_x = torch.cat((q, k), dim=1)
|
126 |
+
attn_out = self.cross_attn_image_to_mask(input_x)
|
127 |
+
img_emb = img_emb + attn_out
|
128 |
+
img_emb = self.norm4(img_emb)
|
129 |
+
|
130 |
+
return img_emb
|
131 |
+
|
132 |
+
class my_model7(nn.Module):
|
133 |
+
'''不用mask的版本
|
134 |
+
concate 部分 加了nor 加 attention
|
135 |
+
attention 用不一样的方法
|
136 |
+
'''
|
137 |
+
def __init__(self, pretrained=False,num_classes=3,in_chans=1,img_size=224, **kwargs):
|
138 |
+
super().__init__()
|
139 |
+
self.backboon1 = vit_base_patch16_224(pretrained=False,in_chans=in_chans, as_backbone=True,img_size=img_size)
|
140 |
+
if pretrained:
|
141 |
+
pre_train_model = timm.create_model('vit_base_patch16_224', pretrained=True, in_chans=in_chans, num_classes=3)
|
142 |
+
self.backboon1 = load_weights(self.backboon1, pre_train_model.state_dict())
|
143 |
+
#self.backboon2 = vit_base_patch32_224(pretrained=False,as_backbone=True) #TODO 同一个网络共享参数/不共享参数/patch不同网络
|
144 |
+
self.self_atten_img = Attention_ori(dim= self.backboon1.embed_dim, num_heads=self.backboon1.num_heads)
|
145 |
+
#self.self_atten_mask = Attention(dim=self.backboon1.embed_dim, num_heads=self.backboon1.num_heads)
|
146 |
+
self.self_atten_mask = Attention_ori(dim=self.backboon1.embed_dim, num_heads=self.backboon1.num_heads)
|
147 |
+
self.cross_atten = FusionAttentionBlock(embedding_dim=self.backboon1.embed_dim, num_heads=self.backboon1.num_heads)
|
148 |
+
#self.external_attention = ExternalAttention(d_model=2304,S=8)
|
149 |
+
self.mlp = MLP_att_out(input_dim=self.backboon1.embed_dim * 3, output_dim=self.backboon1.embed_dim)
|
150 |
+
self.attention = CoordAtt(1,1,1)
|
151 |
+
self.norm1 = nn.LayerNorm(self.backboon1.embed_dim)
|
152 |
+
self.norm2 = nn.LayerNorm(self.backboon1.embed_dim)
|
153 |
+
self.norm3 = nn.LayerNorm(self.backboon1.embed_dim)
|
154 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
155 |
+
self.head = nn.Linear(self.backboon1.embed_dim*3, num_classes) if num_classes > 0 else nn.Identity()
|
156 |
+
#self.head = nn.Linear(196, num_classes) if num_classes > 0 else nn.Identity()
|
157 |
+
def forward(self, img, mask):
|
158 |
+
|
159 |
+
x1 = self.backboon1(torch.cat((img, torch.zeros_like(img)), dim=1)) #TODO 是否用同一模型 还是不同 中间是否融合多尺度
|
160 |
+
x2 = self.backboon1(torch.cat((img*mask, torch.zeros_like(img)), dim=1)) #输出经过了归一化层 #小图
|
161 |
+
#自注意力+残差
|
162 |
+
x2_atten_mask = self.backboon1.atten_mask
|
163 |
+
x1_atten = self.self_atten_img(x1)
|
164 |
+
x2_atten = self.self_atten_mask(x2)
|
165 |
+
x1_out = self.norm1((x1 + x1_atten))
|
166 |
+
x2_out = self.norm2((x2 + x2_atten))
|
167 |
+
#交叉注意力
|
168 |
+
corss_out = self.norm3(self.cross_atten(x1, x2, x2_atten_mask))
|
169 |
+
#得到输出特征
|
170 |
+
out = torch.concat((x1_out, corss_out, x2_out), dim=2).permute(0, 2, 1)#12 2304 196
|
171 |
+
out = self.attention(out) #12 2304 196
|
172 |
+
#out_ = out.permute(0, 2, 1)
|
173 |
+
#out = self.mlp(out) # mlp #特征融合 2 196 768
|
174 |
+
# out = self.norm1(out) #这个好像不用 好像可以删掉
|
175 |
+
out = self.avgpool(out) # B C 1
|
176 |
+
out = torch.flatten(out, 1)
|
177 |
+
out = self.head(out)
|
178 |
+
|
179 |
+
return out
|
180 |
+
|
181 |
+
|
182 |
|
183 |
Image_3D = None
|
184 |
Current_name = None
|
|
|
185 |
|
186 |
+
ALL_message = load_from_pkl(r'.\label0601.pkl')
|
187 |
+
ALL_message2 = load_from_pkl(r'.\all_data_label.pkl')
|
188 |
+
a = ALL_message2['train']
|
189 |
+
a.update(ALL_message2['val'])
|
190 |
+
a.update(ALL_message2['test'])
|
191 |
+
ALL_message2 = a
|
192 |
+
|
193 |
+
LC_model_Paht = r'.\train_ADA_1.pkl'
|
194 |
+
LC_model = load_from_pkl(LC_model_Paht)['model'][0]
|
195 |
+
|
196 |
+
TF_model_Paht = r'.\tf_model.pkl'
|
197 |
+
TF_model = load_from_pkl(TF_model_Paht)['model']
|
198 |
+
DR_model = load_from_pkl(TF_model_Paht)['dr']
|
199 |
+
|
200 |
+
Model_Paht = r'./model_epoch120.pth.tar'
|
201 |
checkpoint = torch.load(Model_Paht, map_location='cpu')
|
202 |
|
203 |
+
classnet = my_model7(pretrained=False,num_classes=3,in_chans=1, img_size=224)
|
|
|
|
|
|
|
204 |
classnet.load_state_dict(checkpoint['model_dict'])
|
205 |
|
206 |
|
|
|
223 |
order=order) # resample for cube_size
|
224 |
|
225 |
|
226 |
+
def get_lc():
|
227 |
+
global Current_name
|
228 |
+
lc_min = np.array([17,1,0,1,1,1,1,1 , 1 , 1])
|
229 |
+
lc_max = np.array([96 ,2, 3 ,2, 2,2 , 2 ,2 ,2 ,4])
|
230 |
+
lc_key = ['age', 'sex', 'time', 'postpartum', 'traumatism', 'diabetes', 'high_blood_pressure', 'cerebral_infarction', 'postoperation']
|
231 |
+
|
232 |
+
lc_all = [ALL_message2[Current_name][ii] for ii in lc_key]
|
233 |
+
site_ = Current_name.split('_',1)[-1]
|
234 |
+
if site_ == 'A_L': lc_all.append(1)
|
235 |
+
elif site_ == 'A_R': lc_all.append(2)
|
236 |
+
elif site_ == 'B_L': lc_all.append(3)
|
237 |
+
elif site_ == 'B_R': lc_all.append(4)
|
238 |
+
else: pass
|
239 |
+
lc_all = (np.array(lc_all)-lc_min)/(lc_max-lc_min+ 1e-12)
|
240 |
+
a = 5
|
241 |
+
return lc_all
|
242 |
def inference():
|
243 |
global Image_small_3D
|
244 |
+
global ROI_small_3D
|
245 |
model = classnet
|
246 |
+
data_3d = Image_small_3D
|
247 |
+
lc_data = get_lc()
|
248 |
+
lc_data = np.expand_dims(lc_data, axis=0)
|
249 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
250 |
model.eval()
|
251 |
+
|
|
|
252 |
try:
|
253 |
+
#影像模型
|
254 |
with torch.no_grad():
|
255 |
+
all_probs = np.empty((0, 3))
|
256 |
+
for ii in tqdm(range(0, data_3d.shape[1]),total = data_3d.shape[1]):
|
257 |
+
data = torch.from_numpy(data_3d[:,ii,:])
|
258 |
+
roi = torch.from_numpy(ROI_small_3D[:,ii,:].astype(np.int8))
|
259 |
+
image = torch.unsqueeze(data, 0)
|
260 |
+
roi = torch.unsqueeze(torch.unsqueeze(roi, 0),0).to(device).float()
|
261 |
+
patch_data = torch.unsqueeze(image, 0).to(device).float() # (N, C_{in}, D_{in}, H_{in}, W_{in})
|
262 |
+
|
263 |
+
# Pre : Prediction Result
|
264 |
+
pre_probs = model(patch_data,roi)
|
265 |
+
pre_probs = torch.nn.functional.softmax(pre_probs, dim=1)
|
266 |
+
all_probs = np.concatenate((all_probs, pre_probs.cpu().numpy()), axis=0)
|
267 |
+
dl_prob = np.mean(all_probs, axis=0)
|
268 |
+
dl_prob = np.expand_dims(dl_prob, axis=0)
|
269 |
+
lc_prob = LC_model.predict_proba(lc_data)
|
270 |
+
feature = DR_model.transform(np.concatenate([dl_prob, lc_prob], axis=1))
|
271 |
+
final_p = TF_model.predict_proba(feature)
|
272 |
+
final_p = np.round(final_p[0], decimals=2)
|
273 |
+
return {'急性期': final_p[0], '亚急性期': final_p[1], '慢性期': final_p[2]}
|
274 |
except:
|
275 |
return ' '
|
276 |
|
|
|
426 |
def data_pretreatment():
|
427 |
global Image_3D
|
428 |
global ROI_3D
|
429 |
+
global ROI_small_3D
|
430 |
global Image_small_3D
|
431 |
global Current_name
|
432 |
global Input_File
|
|
|
434 |
return '没有数据'
|
435 |
else:
|
436 |
roi = ROI_3D
|
437 |
+
# waikuo = [4, 4, 4]
|
438 |
+
# fina_img, fina_mask = arry_crop_3D(Image_3D,roi,waikuo)
|
439 |
|
440 |
cut_thre = np.percentile(fina_img, 99.9) # 直方图99.9%右侧值不要
|
441 |
fina_img[fina_img >= cut_thre] = cut_thre
|
442 |
+
z, y, x = fina_img.shape
|
443 |
+
fina_img = resize3D(fina_img, [224,y,224], order=3)
|
444 |
+
fina_roi = resize3D(roi, [224, y, 224], order=3)
|
445 |
fina_img = (np.max(fina_img)-fina_img)/(np.max(fina_img)-np.min(fina_img))
|
446 |
Image_small_3D = fina_img
|
447 |
+
ROI_small_3D = fina_roi
|
448 |
return '预处理结束'
|
449 |
class App:
|
450 |
def __init__(self):
|
|
|
513 |
|
514 |
gr.Markdown('''# Examples''')
|
515 |
gr.Examples(
|
516 |
+
examples=[["./2239561_B_R_MRI.nii.gz"],
|
517 |
+
["./2239561_B_R_MRI.nii.gz"]],
|
518 |
inputs=inp,
|
519 |
outputs=[out1, out2, out3, slider1, slider2, slider3,out8],
|
520 |
fn=get_Image_reslice,
|
521 |
cache_examples=True,
|
522 |
)
|
523 |
gr.Examples(
|
524 |
+
examples=[["./2239561_B_R_ROI.nii.gz"],
|
525 |
+
["./2239561_B_R_ROI.nii.gz"]],
|
526 |
inputs=inp2,
|
527 |
outputs=out9,
|
528 |
fn=get_ROI,
|
|
|
532 |
demo.launch(share=False)
|
533 |
|
534 |
|
535 |
+
app = App()
|