thov commited on
Commit
5424c06
1 Parent(s): 9c477fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -15
app.py CHANGED
@@ -1,18 +1,35 @@
1
  import glob
2
  import gradio as gr
3
  import numpy as np
 
 
 
 
 
4
  import torch
5
  from torch.utils.data import DataLoader
6
  from torchvision import transforms
7
- import imageio.v3 as iio
8
 
9
  from UNET_perso import UNET
10
- import matplotlib.pyplot as plt
11
  from src.medicalDataLoader import MedicalImageDataset
 
12
 
 
 
13
 
14
- def disp_prediction(num, alpha=0.4):
 
 
 
 
 
 
 
 
 
15
 
 
16
  lst_patients = glob.glob('./Data/val/Img/*.png')
17
  patient = lst_patients[num]
18
 
@@ -28,12 +45,11 @@ def disp_prediction(num, alpha=0.4):
28
  transform=transform,
29
  mask_transform=transform,
30
  equalize=False)
31
-
32
  val_loader = DataLoader(val_set,
33
  batch_size=1,
34
  shuffle=True)
35
 
36
- for _, (img, label, name) in enumerate(val_loader):
37
  if name[0] == patient:
38
  im = img
39
 
@@ -43,31 +59,100 @@ def disp_prediction(num, alpha=0.4):
43
  pred = pred.reshape(4,256,256)
44
  pred_mask = np.argmax(pred, axis=0)
45
 
46
- fig = plt.figure()
47
  im = im.detach().numpy()
48
  im = im.reshape(256,256)
49
- plt.imshow(alpha*pred_mask + (1-alpha)*im)
 
 
 
 
 
50
  plt.axis('off')
51
 
52
  return fig
53
 
54
- def disp_init(num):
55
- lst_patients = glob.glob('./Data/val/Img/*.png')
56
- patient = lst_patients[num]
57
- print(patient)
58
- im = iio.imread(patient)
59
- fig = plt.figure()
60
- plt.imshow(im)
 
 
 
 
61
  plt.axis('off')
 
62
  return fig
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  with gr.Blocks() as demo:
65
  with gr.Row() as row1:
66
  slide = gr.Slider(minimum=0, maximum=89, value=0, step=1, label='inference with validation split data (choose patient)')
67
  with gr.Row() as row2:
68
  slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
69
  slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
 
70
  with gr.Row() as row3:
71
- gr.DataFrame()
72
 
 
73
  demo.launch()
 
1
  import glob
2
  import gradio as gr
3
  import numpy as np
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ import imageio.v3 as iio
7
+ import einops
8
+
9
  import torch
10
  from torch.utils.data import DataLoader
11
  from torchvision import transforms
12
+ from torchmetrics.functional import dice, dice_score
13
 
14
  from UNET_perso import UNET
 
15
  from src.medicalDataLoader import MedicalImageDataset
16
+ from src.utils import getTargetSegmentation
17
 
18
+ import sys
19
+ from IPython.core.ultratb import ColorTB
20
 
21
+ sys.excepthook = ColorTB()
22
+
23
+ def disp_init(num):
24
+ lst_patients = glob.glob('./Data/val/Img/*.png')
25
+ patient = lst_patients[num]
26
+ im = iio.imread(patient)
27
+ fig = plt.figure()
28
+ plt.imshow(im)
29
+ plt.axis('off')
30
+ return fig
31
 
32
+ def disp_prediction(num, alpha=0.4):
33
  lst_patients = glob.glob('./Data/val/Img/*.png')
34
  patient = lst_patients[num]
35
 
 
45
  transform=transform,
46
  mask_transform=transform,
47
  equalize=False)
 
48
  val_loader = DataLoader(val_set,
49
  batch_size=1,
50
  shuffle=True)
51
 
52
+ for _, (img, _, name) in enumerate(val_loader):
53
  if name[0] == patient:
54
  im = img
55
 
 
59
  pred = pred.reshape(4,256,256)
60
  pred_mask = np.argmax(pred, axis=0)
61
 
 
62
  im = im.detach().numpy()
63
  im = im.reshape(256,256)
64
+
65
+ fig = plt.figure()
66
+ try:
67
+ plt.imshow(alpha*pred_mask + (1-alpha)*im)
68
+ except ValueError:
69
+ print(type(pred_mask))
70
  plt.axis('off')
71
 
72
  return fig
73
 
74
+ def disp_label(num, alpha=0.4):
75
+ lst_GT = glob.glob('./Data/val/GT/*.png')
76
+ GT = lst_GT[num]
77
+
78
+ label = iio.imread(GT)
79
+
80
+ fig = plt.figure()
81
+ try :
82
+ plt.imshow(label)
83
+ except ValueError:
84
+ print(type(label))
85
  plt.axis('off')
86
+
87
  return fig
88
 
89
+
90
+ def compute_metrics(num):
91
+
92
+ lst_patients = glob.glob('./Data/val/Img/*.png')
93
+ patient = lst_patients[num]
94
+
95
+ model = UNET(in_channels=1, out_channels=4).to('cpu')
96
+
97
+ filepath = 'UNET_perso'
98
+ model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
99
+ model.eval()
100
+
101
+ transform = transforms.Compose([transforms.ToTensor()])
102
+ val_set = MedicalImageDataset('val',
103
+ './Data',
104
+ transform=transform,
105
+ mask_transform=transform,
106
+ equalize=False)
107
+
108
+ val_loader = DataLoader(val_set,
109
+ batch_size=1,
110
+ shuffle=True)
111
+
112
+ for _, (img, label, name) in enumerate(val_loader):
113
+ if name[0] == patient:
114
+ im = img
115
+ label = label.reshape(256,256)
116
+
117
+ pred = model(im)
118
+
119
+ pred = pred.detach().numpy()
120
+ pred = pred.reshape(4,256,256)
121
+ pred = np.argmax(pred, axis=0)
122
+ pred = torch.from_numpy(pred)
123
+ pred = torch.nn.functional.one_hot(pred)
124
+ pred = einops.rearrange(pred, 'h w class -> class h w')
125
+
126
+ label = getTargetSegmentation(label)
127
+ label = torch.nn.functional.one_hot(label, num_classes=4)
128
+ label = einops.rearrange(label, 'h w class -> class h w')
129
+
130
+ print(pred.shape, label.shape)
131
+ #dsc = dice(pred, label, average='none', num_classes=4)
132
+
133
+ print(torch.mean(pred.float()), torch.max(pred.float()))
134
+ print(torch.mean(label.float()), torch.max(label.float()))
135
+
136
+ #print(dsc)
137
+ dsc1 = dice_score(pred[1], label[1])
138
+ dsc2 = dice_score(pred[2], label[2])
139
+ dsc3 = dice_score(pred[3], label[3])
140
+ print(dsc1, dsc2, dsc3)
141
+
142
+ df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3'])
143
+
144
+ return df
145
+
146
+
147
  with gr.Blocks() as demo:
148
  with gr.Row() as row1:
149
  slide = gr.Slider(minimum=0, maximum=89, value=0, step=1, label='inference with validation split data (choose patient)')
150
  with gr.Row() as row2:
151
  slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
152
  slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
153
+ slide.release(fn=disp_label, inputs=[slide], outputs=gr.Plot(label='groundtruth'))
154
  with gr.Row() as row3:
155
+ slide.release(fn=compute_metrics, inputs=[slide], outputs=gr.DataFrame())
156
 
157
+ demo.queue()
158
  demo.launch()