Spaces:
Running
Running
add IOU and Dice Score
Browse files
app.py
CHANGED
@@ -9,13 +9,12 @@ import einops
|
|
9 |
import torch
|
10 |
from torch.utils.data import DataLoader
|
11 |
from torchvision import transforms
|
12 |
-
from torchmetrics.functional import
|
13 |
|
14 |
from UNET_perso import UNET
|
15 |
from src.medicalDataLoader import MedicalImageDataset
|
16 |
from src.utils import getTargetSegmentation
|
17 |
|
18 |
-
|
19 |
def disp_init(num):
|
20 |
lst_patients = glob.glob('./Data/val/Img/*.png')
|
21 |
patient = lst_patients[num]
|
@@ -25,7 +24,8 @@ def disp_init(num):
|
|
25 |
plt.axis('off')
|
26 |
return fig
|
27 |
|
28 |
-
def disp_prediction(num, alpha=0.
|
|
|
29 |
lst_patients = glob.glob('./Data/val/Img/*.png')
|
30 |
patient = lst_patients[num]
|
31 |
|
@@ -43,7 +43,7 @@ def disp_prediction(num, alpha=0.4):
|
|
43 |
equalize=False)
|
44 |
val_loader = DataLoader(val_set,
|
45 |
batch_size=1,
|
46 |
-
shuffle=
|
47 |
|
48 |
for _, (img, _, name) in enumerate(val_loader):
|
49 |
if name[0] == patient:
|
@@ -60,7 +60,7 @@ def disp_prediction(num, alpha=0.4):
|
|
60 |
|
61 |
fig = plt.figure()
|
62 |
try:
|
63 |
-
plt.imshow(alpha*pred_mask + (1-alpha)*im)
|
64 |
except ValueError:
|
65 |
print(type(pred_mask))
|
66 |
plt.axis('off')
|
@@ -100,15 +100,14 @@ def compute_metrics(num):
|
|
100 |
transform=transform,
|
101 |
mask_transform=transform,
|
102 |
equalize=False)
|
103 |
-
|
104 |
val_loader = DataLoader(val_set,
|
105 |
batch_size=1,
|
106 |
-
shuffle=
|
107 |
|
108 |
for _, (img, label, name) in enumerate(val_loader):
|
109 |
if name[0] == patient:
|
110 |
im = img
|
111 |
-
|
112 |
|
113 |
pred = model(im)
|
114 |
|
@@ -116,30 +115,29 @@ def compute_metrics(num):
|
|
116 |
pred = pred.reshape(4,256,256)
|
117 |
pred = np.argmax(pred, axis=0)
|
118 |
pred = torch.from_numpy(pred)
|
119 |
-
pred = torch.nn.functional.one_hot(pred)
|
120 |
pred = einops.rearrange(pred, 'h w class -> class h w')
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
print(torch.mean(pred.float()), torch.max(pred.float()))
|
130 |
-
print(torch.mean(label.float()), torch.max(label.float()))
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
#dsc3 = dice_score(pred[3], label[3])
|
136 |
-
#print(dsc1, dsc2, dsc3)
|
137 |
|
138 |
df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3'])
|
139 |
-
|
|
|
|
|
|
|
|
|
140 |
return df
|
141 |
|
142 |
-
|
143 |
with gr.Blocks() as demo:
|
144 |
with gr.Row() as row1:
|
145 |
slide = gr.Slider(minimum=0, maximum=89, value=0, step=1, label='inference with validation split data (choose patient)')
|
|
|
9 |
import torch
|
10 |
from torch.utils.data import DataLoader
|
11 |
from torchvision import transforms
|
12 |
+
from torchmetrics.functional import dice_score, jaccard_index
|
13 |
|
14 |
from UNET_perso import UNET
|
15 |
from src.medicalDataLoader import MedicalImageDataset
|
16 |
from src.utils import getTargetSegmentation
|
17 |
|
|
|
18 |
def disp_init(num):
|
19 |
lst_patients = glob.glob('./Data/val/Img/*.png')
|
20 |
patient = lst_patients[num]
|
|
|
24 |
plt.axis('off')
|
25 |
return fig
|
26 |
|
27 |
+
def disp_prediction(num, alpha=0.45):
|
28 |
+
|
29 |
lst_patients = glob.glob('./Data/val/Img/*.png')
|
30 |
patient = lst_patients[num]
|
31 |
|
|
|
43 |
equalize=False)
|
44 |
val_loader = DataLoader(val_set,
|
45 |
batch_size=1,
|
46 |
+
shuffle=False)
|
47 |
|
48 |
for _, (img, _, name) in enumerate(val_loader):
|
49 |
if name[0] == patient:
|
|
|
60 |
|
61 |
fig = plt.figure()
|
62 |
try:
|
63 |
+
plt.imshow((alpha*pred_mask + (1-alpha)*im)/2)
|
64 |
except ValueError:
|
65 |
print(type(pred_mask))
|
66 |
plt.axis('off')
|
|
|
100 |
transform=transform,
|
101 |
mask_transform=transform,
|
102 |
equalize=False)
|
|
|
103 |
val_loader = DataLoader(val_set,
|
104 |
batch_size=1,
|
105 |
+
shuffle=False)
|
106 |
|
107 |
for _, (img, label, name) in enumerate(val_loader):
|
108 |
if name[0] == patient:
|
109 |
im = img
|
110 |
+
lab = label.reshape(256,256)
|
111 |
|
112 |
pred = model(im)
|
113 |
|
|
|
115 |
pred = pred.reshape(4,256,256)
|
116 |
pred = np.argmax(pred, axis=0)
|
117 |
pred = torch.from_numpy(pred)
|
118 |
+
pred = torch.nn.functional.one_hot(pred, num_classes=4)
|
119 |
pred = einops.rearrange(pred, 'h w class -> class h w')
|
120 |
|
121 |
+
lab = getTargetSegmentation(lab)
|
122 |
+
lab = torch.nn.functional.one_hot(lab, num_classes=4)
|
123 |
+
lab = einops.rearrange(lab, 'h w class -> class h w')
|
124 |
+
|
125 |
+
dsc1 = dice_score(pred[1], lab[1])
|
126 |
+
dsc2 = dice_score(pred[2], lab[2])
|
127 |
+
dsc3 = dice_score(pred[3], lab[3])
|
|
|
|
|
128 |
|
129 |
+
iou1 = jaccard_index(pred[1], lab[1], num_classes=2)
|
130 |
+
iou2 = jaccard_index(pred[2], lab[2], num_classes=2)
|
131 |
+
iou3 = jaccard_index(pred[3], lab[3], num_classes=2)
|
|
|
|
|
132 |
|
133 |
df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3'])
|
134 |
+
df.loc[len(df)] = [round(dsc1.item(),2), round(dsc2.item(),2), round(dsc3.item(),2)]
|
135 |
+
df.loc[len(df)] = [round(iou1.item(),2), round(iou2.item(),2), round(iou3.item(),2)]
|
136 |
+
df = df.assign(metric=['Dice Score', 'IoU'])
|
137 |
+
df = df[['metric','class 1','class 2','class 3']]
|
138 |
+
|
139 |
return df
|
140 |
|
|
|
141 |
with gr.Blocks() as demo:
|
142 |
with gr.Row() as row1:
|
143 |
slide = gr.Slider(minimum=0, maximum=89, value=0, step=1, label='inference with validation split data (choose patient)')
|