Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,18 +1,35 @@
|
|
1 |
import glob
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
4 |
import torch
|
5 |
from torch.utils.data import DataLoader
|
6 |
from torchvision import transforms
|
7 |
-
|
8 |
|
9 |
from UNET_perso import UNET
|
10 |
-
import matplotlib.pyplot as plt
|
11 |
from src.medicalDataLoader import MedicalImageDataset
|
|
|
12 |
|
|
|
|
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
|
|
16 |
lst_patients = glob.glob('./Data/val/Img/*.png')
|
17 |
patient = lst_patients[num]
|
18 |
|
@@ -28,12 +45,11 @@ def disp_prediction(num, alpha=0.4):
|
|
28 |
transform=transform,
|
29 |
mask_transform=transform,
|
30 |
equalize=False)
|
31 |
-
|
32 |
val_loader = DataLoader(val_set,
|
33 |
batch_size=1,
|
34 |
shuffle=True)
|
35 |
|
36 |
-
for _, (img,
|
37 |
if name[0] == patient:
|
38 |
im = img
|
39 |
|
@@ -43,31 +59,100 @@ def disp_prediction(num, alpha=0.4):
|
|
43 |
pred = pred.reshape(4,256,256)
|
44 |
pred_mask = np.argmax(pred, axis=0)
|
45 |
|
46 |
-
fig = plt.figure()
|
47 |
im = im.detach().numpy()
|
48 |
im = im.reshape(256,256)
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
plt.axis('off')
|
51 |
|
52 |
return fig
|
53 |
|
54 |
-
def
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
plt.
|
|
|
|
|
|
|
|
|
61 |
plt.axis('off')
|
|
|
62 |
return fig
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
with gr.Blocks() as demo:
|
65 |
with gr.Row() as row1:
|
66 |
slide = gr.Slider(minimum=0, maximum=89, value=0, step=1, label='inference with validation split data (choose patient)')
|
67 |
with gr.Row() as row2:
|
68 |
slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
|
69 |
slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
|
|
|
70 |
with gr.Row() as row3:
|
71 |
-
gr.DataFrame()
|
72 |
|
|
|
73 |
demo.launch()
|
|
|
1 |
import glob
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import imageio.v3 as iio
|
7 |
+
import einops
|
8 |
+
|
9 |
import torch
|
10 |
from torch.utils.data import DataLoader
|
11 |
from torchvision import transforms
|
12 |
+
from torchmetrics.functional import dice, dice_score
|
13 |
|
14 |
from UNET_perso import UNET
|
|
|
15 |
from src.medicalDataLoader import MedicalImageDataset
|
16 |
+
from src.utils import getTargetSegmentation
|
17 |
|
18 |
+
import sys
|
19 |
+
from IPython.core.ultratb import ColorTB
|
20 |
|
21 |
+
sys.excepthook = ColorTB()
|
22 |
+
|
23 |
+
def disp_init(num):
|
24 |
+
lst_patients = glob.glob('./Data/val/Img/*.png')
|
25 |
+
patient = lst_patients[num]
|
26 |
+
im = iio.imread(patient)
|
27 |
+
fig = plt.figure()
|
28 |
+
plt.imshow(im)
|
29 |
+
plt.axis('off')
|
30 |
+
return fig
|
31 |
|
32 |
+
def disp_prediction(num, alpha=0.4):
|
33 |
lst_patients = glob.glob('./Data/val/Img/*.png')
|
34 |
patient = lst_patients[num]
|
35 |
|
|
|
45 |
transform=transform,
|
46 |
mask_transform=transform,
|
47 |
equalize=False)
|
|
|
48 |
val_loader = DataLoader(val_set,
|
49 |
batch_size=1,
|
50 |
shuffle=True)
|
51 |
|
52 |
+
for _, (img, _, name) in enumerate(val_loader):
|
53 |
if name[0] == patient:
|
54 |
im = img
|
55 |
|
|
|
59 |
pred = pred.reshape(4,256,256)
|
60 |
pred_mask = np.argmax(pred, axis=0)
|
61 |
|
|
|
62 |
im = im.detach().numpy()
|
63 |
im = im.reshape(256,256)
|
64 |
+
|
65 |
+
fig = plt.figure()
|
66 |
+
try:
|
67 |
+
plt.imshow(alpha*pred_mask + (1-alpha)*im)
|
68 |
+
except ValueError:
|
69 |
+
print(type(pred_mask))
|
70 |
plt.axis('off')
|
71 |
|
72 |
return fig
|
73 |
|
74 |
+
def disp_label(num, alpha=0.4):
|
75 |
+
lst_GT = glob.glob('./Data/val/GT/*.png')
|
76 |
+
GT = lst_GT[num]
|
77 |
+
|
78 |
+
label = iio.imread(GT)
|
79 |
+
|
80 |
+
fig = plt.figure()
|
81 |
+
try :
|
82 |
+
plt.imshow(label)
|
83 |
+
except ValueError:
|
84 |
+
print(type(label))
|
85 |
plt.axis('off')
|
86 |
+
|
87 |
return fig
|
88 |
|
89 |
+
|
90 |
+
def compute_metrics(num):
|
91 |
+
|
92 |
+
lst_patients = glob.glob('./Data/val/Img/*.png')
|
93 |
+
patient = lst_patients[num]
|
94 |
+
|
95 |
+
model = UNET(in_channels=1, out_channels=4).to('cpu')
|
96 |
+
|
97 |
+
filepath = 'UNET_perso'
|
98 |
+
model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
|
99 |
+
model.eval()
|
100 |
+
|
101 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
102 |
+
val_set = MedicalImageDataset('val',
|
103 |
+
'./Data',
|
104 |
+
transform=transform,
|
105 |
+
mask_transform=transform,
|
106 |
+
equalize=False)
|
107 |
+
|
108 |
+
val_loader = DataLoader(val_set,
|
109 |
+
batch_size=1,
|
110 |
+
shuffle=True)
|
111 |
+
|
112 |
+
for _, (img, label, name) in enumerate(val_loader):
|
113 |
+
if name[0] == patient:
|
114 |
+
im = img
|
115 |
+
label = label.reshape(256,256)
|
116 |
+
|
117 |
+
pred = model(im)
|
118 |
+
|
119 |
+
pred = pred.detach().numpy()
|
120 |
+
pred = pred.reshape(4,256,256)
|
121 |
+
pred = np.argmax(pred, axis=0)
|
122 |
+
pred = torch.from_numpy(pred)
|
123 |
+
pred = torch.nn.functional.one_hot(pred)
|
124 |
+
pred = einops.rearrange(pred, 'h w class -> class h w')
|
125 |
+
|
126 |
+
label = getTargetSegmentation(label)
|
127 |
+
label = torch.nn.functional.one_hot(label, num_classes=4)
|
128 |
+
label = einops.rearrange(label, 'h w class -> class h w')
|
129 |
+
|
130 |
+
print(pred.shape, label.shape)
|
131 |
+
#dsc = dice(pred, label, average='none', num_classes=4)
|
132 |
+
|
133 |
+
print(torch.mean(pred.float()), torch.max(pred.float()))
|
134 |
+
print(torch.mean(label.float()), torch.max(label.float()))
|
135 |
+
|
136 |
+
#print(dsc)
|
137 |
+
dsc1 = dice_score(pred[1], label[1])
|
138 |
+
dsc2 = dice_score(pred[2], label[2])
|
139 |
+
dsc3 = dice_score(pred[3], label[3])
|
140 |
+
print(dsc1, dsc2, dsc3)
|
141 |
+
|
142 |
+
df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3'])
|
143 |
+
|
144 |
+
return df
|
145 |
+
|
146 |
+
|
147 |
with gr.Blocks() as demo:
|
148 |
with gr.Row() as row1:
|
149 |
slide = gr.Slider(minimum=0, maximum=89, value=0, step=1, label='inference with validation split data (choose patient)')
|
150 |
with gr.Row() as row2:
|
151 |
slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
|
152 |
slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
|
153 |
+
slide.release(fn=disp_label, inputs=[slide], outputs=gr.Plot(label='groundtruth'))
|
154 |
with gr.Row() as row3:
|
155 |
+
slide.release(fn=compute_metrics, inputs=[slide], outputs=gr.DataFrame())
|
156 |
|
157 |
+
demo.queue()
|
158 |
demo.launch()
|