amazinghaha committed · verified
Commit 5c4b9bd · Parent(s): 90c722a

Update app.py

Files changed (1): app.py (+235 −37)
app.py CHANGED
@@ -4,8 +4,11 @@ import os
 import numpy as np
 import SimpleITK as sitk
 from scipy.ndimage import zoom
-from resnet_gn import resnet50
 import pickle
+from model.Vision_Transformer_with_mask import vit_base_patch16_224, Attention, CrossAttention, Attention_ori
+from model.CoordAttention import *
+from typing import Tuple, Type
+from torch import Tensor, nn
 #import tempfile
 
 def load_from_pkl(load_path):
@@ -14,18 +17,190 @@ def load_from_pkl(load_path):
     data_input.close()
     return read_data
 
+class MLP_att_out(nn.Module):
+
+    def __init__(self, input_dim, inter_dim=None, output_dim=None, activation="relu", drop=0.0):
+        super().__init__()
+        self.input_dim = input_dim
+        self.inter_dim = inter_dim
+        self.output_dim = output_dim
+        if inter_dim is None: self.inter_dim = input_dim
+        if output_dim is None: self.output_dim = input_dim
+
+        self.linear1 = nn.Linear(self.input_dim, self.inter_dim)
+        self.activation = self._get_activation_fn(activation)
+        self.dropout3 = nn.Dropout(drop)
+        self.linear2 = nn.Linear(self.inter_dim, self.output_dim)
+        self.dropout4 = nn.Dropout(drop)
+        self.norm3 = nn.LayerNorm(self.output_dim)
+
+    def forward(self, x):
+        x = self.linear2(self.dropout3(self.activation(self.linear1(x))))
+        x = x + self.dropout4(x)
+        x = self.norm3(x)
+        return x
+
+    def _get_activation_fn(self, activation):
+        """Return an activation function given a string"""
+        if activation == "relu":
+            return F.relu
+        if activation == "gelu":
+            return F.gelu
+        if activation == "glu":
+            return F.glu
+        raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+class MLPBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        mlp_dim: int,
+        act: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+        self.act = act()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.lin2(self.act(self.lin1(x)))
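An aside on MLP_att_out: its forward adds the block output to a dropout of itself (x + self.dropout4(x)) rather than to the block input, so at eval time the LayerNorm cancels the doubling and no real skip connection survives. For comparison, a minimal sketch of the conventional residual MLP, assuming ReLU; this is not the repo's code:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class ResidualMLP(nn.Module):
        # Transformer-style MLP: residual from the *input*, then LayerNorm.
        def __init__(self, dim, inter_dim=None, drop=0.0):
            super().__init__()
            inter_dim = inter_dim or dim
            self.linear1 = nn.Linear(dim, inter_dim)
            self.linear2 = nn.Linear(inter_dim, dim)
            self.drop = nn.Dropout(drop)
            self.norm = nn.LayerNorm(dim)

        def forward(self, x):
            h = self.linear2(self.drop(F.relu(self.linear1(x))))
            return self.norm(x + self.drop(h))

    blk = ResidualMLP(768)
    print(blk(torch.randn(2, 196, 768)).shape)  # torch.Size([2, 196, 768])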
+class FusionAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+    ) -> None:
+        """
+        A transformer block with four layers: (1) self-attention of sparse
+        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+        block on sparse inputs, and (4) cross attention of dense inputs to
+        sparse inputs.
+
+        Arguments:
+          embedding_dim (int): the channel dimension of the embeddings
+          num_heads (int): the number of heads in the attention layers
+          mlp_dim (int): the hidden dimension of the mlp block
+          activation (nn.Module): the activation of the mlp block
+        """
+        super().__init__()
+        self.self_attn = Attention_ori(embedding_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_mask_to_image = CrossAttention(dim=embedding_dim, num_heads=num_heads)
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+        self.norm4 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_mask = CrossAttention(dim=embedding_dim, num_heads=num_heads)
+
+    def forward(self, img_emb: Tensor, mask_emb: Tensor, atten_mask: Tensor) -> Tuple[Tensor]:
+        # Self attention block (initially queries = query_pe)
+        # queries: Tensor, keys: Tensor
+        queries = mask_emb
+        attn_out = self.self_attn(queries)  # small (masked) image
+        queries = attn_out
+        # queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # Cross attention block, mask attending to image embedding
+        q = queries  # 1,5,256
+        k = img_emb  # v carries the values, hence reuse the keys?
+        input_x = torch.cat((q, k), dim=1)  # 2 50 768
+        attn_out = self.cross_attn_mask_to_image(input_x)  # TODO: apply a mask during the cross step? try without it first
+        queries = queries + attn_out
+        queries = self.norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = img_emb
+        k = queries
+        input_x = torch.cat((q, k), dim=1)
+        attn_out = self.cross_attn_image_to_mask(input_x)
+        img_emb = img_emb + attn_out
+        img_emb = self.norm4(img_emb)
+
+        return img_emb
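FusionAttentionBlock depends on the repo's Attention_ori and CrossAttention modules, which are not part of this diff; note that its CrossAttention consumes the concatenation of queries and keys rather than separate q/k/v streams. A self-contained sketch of the same four-step pattern (self-attention on the mask tokens, mask-to-image cross attention, MLP, image-to-mask cross attention), using torch.nn.MultiheadAttention as a stand-in and hypothetical dimensions:

    import torch
    from torch import nn

    class FusionSketch(nn.Module):
        def __init__(self, dim=768, heads=12, mlp_dim=2048):
            super().__init__()
            self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
            self.cross_m2i = nn.MultiheadAttention(dim, heads, batch_first=True)
            self.cross_i2m = nn.MultiheadAttention(dim, heads, batch_first=True)
            self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
            self.n1, self.n2, self.n3, self.n4 = (nn.LayerNorm(dim) for _ in range(4))

        def forward(self, img_emb, mask_emb):
            q, _ = self.self_attn(mask_emb, mask_emb, mask_emb)  # (1) self-attention on mask tokens
            q = self.n1(q)
            a, _ = self.cross_m2i(q, img_emb, img_emb)           # (2) mask -> image cross attention
            q = self.n2(q + a)
            q = self.n3(q + self.mlp(q))                         # (3) MLP + residual
            a, _ = self.cross_i2m(img_emb, q, q)                 # (4) image -> mask cross attention
            return self.n4(img_emb + a)

    img, msk = torch.randn(2, 196, 768), torch.randn(2, 196, 768)
    print(FusionSketch()(img, msk).shape)  # torch.Size([2, 196, 768])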
+class my_model7(nn.Module):
+    '''Variant that does not use the mask:
+    the concat stage adds a norm plus attention,
+    with a different attention method
+    '''
+    def __init__(self, pretrained=False, num_classes=3, in_chans=1, img_size=224, **kwargs):
+        super().__init__()
+        self.backboon1 = vit_base_patch16_224(pretrained=False, in_chans=in_chans, as_backbone=True, img_size=img_size)
+        if pretrained:
+            pre_train_model = timm.create_model('vit_base_patch16_224', pretrained=True, in_chans=in_chans, num_classes=3)
+            self.backboon1 = load_weights(self.backboon1, pre_train_model.state_dict())
+        # self.backboon2 = vit_base_patch32_224(pretrained=False, as_backbone=True)  # TODO: shared weights / separate networks / different patch sizes
+        self.self_atten_img = Attention_ori(dim=self.backboon1.embed_dim, num_heads=self.backboon1.num_heads)
+        # self.self_atten_mask = Attention(dim=self.backboon1.embed_dim, num_heads=self.backboon1.num_heads)
+        self.self_atten_mask = Attention_ori(dim=self.backboon1.embed_dim, num_heads=self.backboon1.num_heads)
+        self.cross_atten = FusionAttentionBlock(embedding_dim=self.backboon1.embed_dim, num_heads=self.backboon1.num_heads)
+        # self.external_attention = ExternalAttention(d_model=2304, S=8)
+        self.mlp = MLP_att_out(input_dim=self.backboon1.embed_dim * 3, output_dim=self.backboon1.embed_dim)
+        self.attention = CoordAtt(1, 1, 1)
+        self.norm1 = nn.LayerNorm(self.backboon1.embed_dim)
+        self.norm2 = nn.LayerNorm(self.backboon1.embed_dim)
+        self.norm3 = nn.LayerNorm(self.backboon1.embed_dim)
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+        self.head = nn.Linear(self.backboon1.embed_dim * 3, num_classes) if num_classes > 0 else nn.Identity()
+        # self.head = nn.Linear(196, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward(self, img, mask):
+        x1 = self.backboon1(torch.cat((img, torch.zeros_like(img)), dim=1))  # TODO: same model or separate? fuse multi-scale features in between?
+        x2 = self.backboon1(torch.cat((img * mask, torch.zeros_like(img)), dim=1))  # output already normalized  # small (masked) image
+        # self-attention + residual
+        x2_atten_mask = self.backboon1.atten_mask
+        x1_atten = self.self_atten_img(x1)
+        x2_atten = self.self_atten_mask(x2)
+        x1_out = self.norm1((x1 + x1_atten))
+        x2_out = self.norm2((x2 + x2_atten))
+        # cross attention
+        corss_out = self.norm3(self.cross_atten(x1, x2, x2_atten_mask))
+        # assemble the output features
+        out = torch.concat((x1_out, corss_out, x2_out), dim=2).permute(0, 2, 1)  # 12 2304 196
+        out = self.attention(out)  # 12 2304 196
+        # out_ = out.permute(0, 2, 1)
+        # out = self.mlp(out)  # mlp feature fusion  # 2 196 768
+        # out = self.norm1(out)  # seems unnecessary; could probably be removed
+        out = self.avgpool(out)  # B C 1
+        out = torch.flatten(out, 1)
+        out = self.head(out)
+
+        return out
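my_model7 runs the backbone on the full image and on the masked image, fuses the two token streams with FusionAttentionBlock, concatenates the three 768-dim streams into 2304 channels, re-weights with CoordAtt, pools over the 196 patch tokens, and classifies. A stand-alone shape walk-through, with nn.Identity in place of CoordAtt (defined in model/CoordAttention.py, not shown here):

    import torch
    from torch import nn

    B, tokens, dim = 2, 196, 768
    x1_out = torch.randn(B, tokens, dim)   # full-image branch
    cross  = torch.randn(B, tokens, dim)   # fusion branch
    x2_out = torch.randn(B, tokens, dim)   # masked-image branch

    attention = nn.Identity()              # stand-in for CoordAtt(1, 1, 1)
    avgpool = nn.AdaptiveAvgPool1d(1)
    head = nn.Linear(dim * 3, 3)           # num_classes = 3

    out = torch.cat((x1_out, cross, x2_out), dim=2).permute(0, 2, 1)  # (B, 2304, 196)
    out = attention(out)                                              # (B, 2304, 196)
    out = torch.flatten(avgpool(out), 1)                              # (B, 2304)
    print(head(out).shape)                                            # torch.Size([2, 3])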
 
 Image_3D = None
 Current_name = None
-ALL_message = load_from_pkl(r'./label0601.pkl')
 
-Model_Paht = r'./model_epoch152.pth.tar'
+ALL_message = load_from_pkl(r'.\label0601.pkl')
+ALL_message2 = load_from_pkl(r'.\all_data_label.pkl')
+a = ALL_message2['train']
+a.update(ALL_message2['val'])
+a.update(ALL_message2['test'])
+ALL_message2 = a
+
+LC_model_Paht = r'.\train_ADA_1.pkl'
+LC_model = load_from_pkl(LC_model_Paht)['model'][0]
+
+TF_model_Paht = r'.\tf_model.pkl'
+TF_model = load_from_pkl(TF_model_Paht)['model']
+DR_model = load_from_pkl(TF_model_Paht)['dr']
+
+Model_Paht = r'./model_epoch120.pth.tar'
 checkpoint = torch.load(Model_Paht, map_location='cpu')
 
-classnet = resnet50(
-    num_classes=1,
-    sample_size=128,
-    sample_duration=8)
+classnet = my_model7(pretrained=False, num_classes=3, in_chans=1, img_size=224)
 classnet.load_state_dict(checkpoint['model_dict'])
 
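The classical-model bundles (LC_model, TF_model, DR_model) are plain pickles whose estimator types are not visible in this diff; note also that the new .pkl paths use backslashes, which resolve only on Windows hosts. A minimal sketch of the assumed {'model': ..., 'dr': ...} bundle round-trip, with scikit-learn stand-ins (LogisticRegression and PCA are placeholders, not the repo's actual models):

    import io
    import pickle
    from sklearn.decomposition import PCA
    from sklearn.linear_model import LogisticRegression

    # Hypothetical bundle layout mirroring load_from_pkl(TF_model_Paht)['model'] / ['dr']:
    buf = io.BytesIO()
    pickle.dump({'model': LogisticRegression(), 'dr': PCA(n_components=4)}, buf)
    buf.seek(0)

    bundle = pickle.load(buf)
    TF_model, DR_model = bundle['model'], bundle['dr']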
@@ -48,34 +223,54 @@ def resize3D(img, aimsize, order=3):
                    order=order)  # resample for cube_size
 
+def get_lc():
+    global Current_name
+    lc_min = np.array([17, 1, 0, 1, 1, 1, 1, 1, 1, 1])
+    lc_max = np.array([96, 2, 3, 2, 2, 2, 2, 2, 2, 4])
+    lc_key = ['age', 'sex', 'time', 'postpartum', 'traumatism', 'diabetes', 'high_blood_pressure', 'cerebral_infarction', 'postoperation']
+
+    lc_all = [ALL_message2[Current_name][ii] for ii in lc_key]
+    site_ = Current_name.split('_', 1)[-1]
+    if site_ == 'A_L': lc_all.append(1)
+    elif site_ == 'A_R': lc_all.append(2)
+    elif site_ == 'B_L': lc_all.append(3)
+    elif site_ == 'B_R': lc_all.append(4)
+    else: pass
+    lc_all = (np.array(lc_all) - lc_min) / (lc_max - lc_min + 1e-12)
+    a = 5
+    return lc_all
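get_lc assembles nine clinical fields plus an ordinal site code parsed from the case name, then min-max scales them against fixed population bounds (the trailing a = 5 looks like a leftover debug assignment). A small worked example with hypothetical values:

    import numpy as np

    lc_min = np.array([17, 1, 0, 1, 1, 1, 1, 1, 1, 1])
    lc_max = np.array([96, 2, 3, 2, 2, 2, 2, 2, 2, 4])

    record = [63, 1, 2, 1, 2, 1, 2, 1, 1]                    # hypothetical clinical values
    site = {'A_L': 1, 'A_R': 2, 'B_L': 3, 'B_R': 4}['B_R']   # site code parsed from e.g. '2239561_B_R'
    features = np.array(record + [site], dtype=float)
    scaled = (features - lc_min) / (lc_max - lc_min + 1e-12)  # epsilon guards zero-width ranges
    print(scaled.round(3))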
 def inference():
     global Image_small_3D
+    global ROI_small_3D
     model = classnet
-    data = Image_small_3D
-
+    data_3d = Image_small_3D
+    lc_data = get_lc()
+    lc_data = np.expand_dims(lc_data, axis=0)
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model.eval()
-    all_loss = 0
-    length = 0
+
     try:
+        # imaging model
         with torch.no_grad():
-            data = torch.from_numpy(data)
-            image = torch.unsqueeze(data, 0)
-            patch_data = torch.unsqueeze(image, 0).to(device).float()  # (N, C_{in}, D_{in}, H_{in}, W_{in})
-
-            # Pre : Prediction Result
-            pre_probs = model(patch_data)
-
-            # pre_probs = F.sigmoid(pre_probs)  # todo
-            pre_flat = pre_probs.view(-1)
-            np.round(pre_flat.numpy()[0], decimals=2)
-            # (1-pre_flat.numpy()[0]).astype(np.float32)
-            # pre_flat.numpy()[0].astype(np.float32)
-            # p = float(np.round(pre_flat.numpy()[0], decimals=2))
-            # n = float(np.round(1 - p, decimals=2))
-            p = np.round(float(pre_flat.numpy()[0]), decimals=2)
-            n = np.round(float(1 - p), decimals=2)
-            return {'急性期': n, '亚急性期': p}
+            all_probs = np.empty((0, 3))
+            for ii in tqdm(range(0, data_3d.shape[1]), total=data_3d.shape[1]):
+                data = torch.from_numpy(data_3d[:, ii, :])
+                roi = torch.from_numpy(ROI_small_3D[:, ii, :].astype(np.int8))
+                image = torch.unsqueeze(data, 0)
+                roi = torch.unsqueeze(torch.unsqueeze(roi, 0), 0).to(device).float()
+                patch_data = torch.unsqueeze(image, 0).to(device).float()  # (N, C_{in}, D_{in}, H_{in}, W_{in})
+
+                # Pre : Prediction Result
+                pre_probs = model(patch_data, roi)
+                pre_probs = torch.nn.functional.softmax(pre_probs, dim=1)
+                all_probs = np.concatenate((all_probs, pre_probs.cpu().numpy()), axis=0)
+            dl_prob = np.mean(all_probs, axis=0)
+            dl_prob = np.expand_dims(dl_prob, axis=0)
+            lc_prob = LC_model.predict_proba(lc_data)
+            feature = DR_model.transform(np.concatenate([dl_prob, lc_prob], axis=1))
+            final_p = TF_model.predict_proba(feature)
+            final_p = np.round(final_p[0], decimals=2)
+            return {'急性期': final_p[0], '亚急性期': final_p[1], '慢性期': final_p[2]}  # acute / subacute / chronic phase
     except:
         return ' '
 
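The new inference() runs the 2D model on every slice along axis 1 of the preprocessed volume, averages the per-slice softmax outputs into dl_prob, then concatenates dl_prob with the clinical-model probabilities and pushes the result through DR_model and TF_model for the final three-phase prediction. A self-contained sketch of the slice-averaging stage, with DummyNet as a placeholder for classnet:

    import numpy as np
    import torch
    from torch import nn

    class DummyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.head = nn.Linear(224 * 224, 3)

        def forward(self, img, roi):          # roi accepted but ignored by the stand-in
            return self.head(img.flatten(1))

    model = DummyNet().eval()
    volume = np.random.rand(224, 8, 224).astype(np.float32)
    roi = np.ones_like(volume)

    with torch.no_grad():
        probs = []
        for ii in range(volume.shape[1]):     # iterate slices along axis 1
            img = torch.from_numpy(volume[:, ii, :])[None, None]  # (1, 1, 224, 224)
            r = torch.from_numpy(roi[:, ii, :])[None, None]
            probs.append(torch.softmax(model(img, r), dim=1).numpy())

    dl_prob = np.concatenate(probs).mean(axis=0)  # average over slices -> (3,)
    print(dl_prob.round(2), dl_prob.sum().round(2))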
@@ -231,6 +426,7 @@ def arry_crop_3D(img,mask,ex_pix):
 def data_pretreatment():
     global Image_3D
     global ROI_3D
+    global ROI_small_3D
     global Image_small_3D
     global Current_name
     global Input_File
@@ -238,15 +434,17 @@
         return '没有数据'
     else:
         roi = ROI_3D
-        waikuo = [4, 4, 4]
-        fina_img, fina_mask = arry_crop_3D(Image_3D,roi,waikuo)
+        # waikuo = [4, 4, 4]
+        # fina_img, fina_mask = arry_crop_3D(Image_3D,roi,waikuo)
 
         cut_thre = np.percentile(fina_img, 99.9)  # discard values in the top 0.1% of the histogram
         fina_img[fina_img >= cut_thre] = cut_thre
-
-        fina_img = resize3D(fina_img, [128,256,128], order=3)
+        z, y, x = fina_img.shape
+        fina_img = resize3D(fina_img, [224, y, 224], order=3)
+        fina_roi = resize3D(roi, [224, y, 224], order=3)
         fina_img = (np.max(fina_img)-fina_img)/(np.max(fina_img)-np.min(fina_img))
         Image_small_3D = fina_img
+        ROI_small_3D = fina_roi
         return '预处理结束'
 class App:
     def __init__(self):
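With the arry_crop_3D call commented out, fina_img is referenced before assignment in this hunk, so either the crop now happens elsewhere or this is a latent NameError. The intended preprocessing, as a stand-alone sketch: clip the top 0.1% of intensities, resample z and x to 224 while keeping y, then apply the inverted min-max scaling (bright voxels map to 0, dark to 1):

    import numpy as np
    from scipy.ndimage import zoom

    img = np.random.rand(96, 20, 96) * 1000.0          # hypothetical cropped volume

    cut_thre = np.percentile(img, 99.9)                # clip the top 0.1% of the histogram
    img[img >= cut_thre] = cut_thre

    z, y, x = img.shape
    img = zoom(img, (224 / z, 1.0, 224 / x), order=3)  # -> (224, y, 224), like resize3D
    img = (img.max() - img) / (img.max() - img.min())  # inverted min-max to [0, 1]
    print(img.shape, round(img.min(), 2), round(img.max(), 2))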
@@ -315,16 +513,16 @@
 
         gr.Markdown('''# Examples''')
         gr.Examples(
-            examples=[["./P125539_A_L_MRI.nii.gz"],
-                      ["./P415121_A_R_MRI.nii.gz"]],
+            examples=[["./2239561_B_R_MRI.nii.gz"],
+                      ["./2239561_B_R_MRI.nii.gz"]],
             inputs=inp,
             outputs=[out1, out2, out3, slider1, slider2, slider3,out8],
             fn=get_Image_reslice,
             cache_examples=True,
         )
         gr.Examples(
-            examples=[["./P125539_A_L_ROI.nii.gz"],
-                      ["./P415121_A_R_ROI.nii.gz"]],
+            examples=[["./2239561_B_R_ROI.nii.gz"],
+                      ["./2239561_B_R_ROI.nii.gz"]],
             inputs=inp2,
             outputs=out9,
             fn=get_ROI,
@@ -334,4 +532,4 @@
     demo.launch(share=False)
 
 
-app = App()
+app = App()