thov's picture
add IOU and Dice Score
279daf2
raw history blame
No virus
4.73 kB
import glob

import einops
import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from torchmetrics.functional import dice_score, jaccard_index
from torchvision import transforms

from UNET_perso import UNET
from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation
def disp_init(num):
    """Return a figure showing the raw validation image selected by the slider.

    Args:
        num: Index into ``glob.glob('./Data/val/Img/*.png')``.

    Returns:
        A ``matplotlib.figure.Figure`` with the image drawn and axes hidden.
    """
    # NOTE(review): glob order is not sorted; the sibling callbacks list this
    # directory the same way, so ``num`` points at the same file in all of
    # them only as long as the directory listing order is stable.
    lst_patients = glob.glob('./Data/val/Img/*.png')
    patient = lst_patients[num]
    im = iio.imread(patient)

    # Use the object-oriented API rather than plt.figure(): pyplot keeps every
    # figure it creates alive until plt.close(), so a long-running app would
    # accumulate one figure per slider release.
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(im)
    ax.axis('off')
    return fig
def disp_prediction(num, alpha=0.45):
    """Overlay the UNET's predicted class mask on the selected validation image.

    Args:
        num: Index into ``glob.glob('./Data/val/Img/*.png')`` selecting the image.
        alpha: Weight of the predicted class mask in the blend; the image is
            weighted by ``1 - alpha``.

    Returns:
        A ``matplotlib.figure.Figure`` showing the blended image, axes hidden.

    Raises:
        ValueError: If no sample of the validation dataset has the selected
            file as its name.
    """
    lst_patients = glob.glob('./Data/val/Img/*.png')
    patient = lst_patients[num]

    # Weights are reloaded on every call; the model has 1 input channel and
    # 4 output classes and runs on CPU.
    model = UNET(in_channels=1, out_channels=4).to('cpu')
    filepath = 'UNET_perso'
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()

    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    # Find the sample whose name matches the selected file and stop there
    # instead of iterating over the rest of the split.
    for img, _, name in val_loader:
        if name[0] == patient:
            break
    else:
        raise ValueError(f'{patient!r} is not in the validation dataset')

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        pred = model(img)
    pred = pred.numpy().reshape(4, 256, 256)
    # Per-pixel predicted class index (0-3).
    pred_mask = np.argmax(pred, axis=0)
    im = img.numpy().reshape(256, 256)

    # Object-oriented API: pyplot-managed figures would never be released.
    fig = Figure()
    ax = fig.subplots()
    try:
        # imshow rescales scalar data to its own min/max by default, so the
        # trailing /2 does not change the rendered colours.
        ax.imshow((alpha * pred_mask + (1 - alpha) * im) / 2)
    except ValueError as err:
        # Best-effort: report and still return the (empty) figure.
        print(type(pred_mask), err)
    ax.axis('off')
    return fig
def disp_label(num, alpha=0.4):
    """Return a figure showing the ground-truth mask for the selected image.

    Args:
        num: Index into ``glob.glob('./Data/val/Img/*.png')``, the same listing
            the image and prediction callbacks index.
        alpha: Unused; kept so existing callers keep working.

    Returns:
        A ``matplotlib.figure.Figure`` with the mask drawn and axes hidden.
    """
    # glob returns files in arbitrary (filesystem-dependent) order and the Img
    # and GT directories are listed independently, so indexing the GT listing
    # with ``num`` could show another patient's mask. Instead, the selected
    # image's rank among the sorted Img files picks the GT file with the same
    # rank among the sorted GT files.
    # NOTE(review): assumes Img/ and GT/ hold one file per sample whose names
    # sort in the same order (e.g. identical file names) — confirm on the data.
    lst_img = glob.glob('./Data/val/Img/*.png')
    patient = lst_img[num]
    rank = sorted(lst_img).index(patient)
    lst_GT = sorted(glob.glob('./Data/val/GT/*.png'))
    GT = lst_GT[rank]
    label = iio.imread(GT)

    # Object-oriented API: pyplot-managed figures would never be released.
    fig = Figure()
    ax = fig.subplots()
    try:
        ax.imshow(label)
    except ValueError as err:
        # Best-effort: report and still return the (empty) figure.
        print(type(label), err)
    ax.axis('off')
    return fig
def compute_metrics(num):
    """Compute per-class Dice score and IoU of the UNET prediction for one image.

    Args:
        num: Index into ``glob.glob('./Data/val/Img/*.png')`` selecting the image.

    Returns:
        A ``pandas.DataFrame`` with columns ``metric``, ``class 1``,
        ``class 2``, ``class 3`` and two rows (``'Dice Score'``, ``'IoU'``),
        values rounded to 2 decimals. Class 0 is not reported.

    Raises:
        ValueError: If no sample of the validation dataset has the selected
            file as its name.
    """
    lst_patients = glob.glob('./Data/val/Img/*.png')
    patient = lst_patients[num]

    # Weights are reloaded on every call; 1 input channel, 4 output classes, CPU.
    model = UNET(in_channels=1, out_channels=4).to('cpu')
    filepath = 'UNET_perso'
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()

    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    # Find the matching sample and stop; the rest of the split is not needed.
    for img, label, name in val_loader:
        if name[0] == patient:
            break
    else:
        raise ValueError(f'{patient!r} is not in the validation dataset')

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(img)
    # Per-pixel predicted class index, then one-hot as (class, h, w).
    pred = logits.reshape(4, 256, 256).argmax(dim=0)
    pred = torch.nn.functional.one_hot(pred, num_classes=4)
    pred = einops.rearrange(pred, 'h w class -> class h w')

    # Convert the mask to class indices (project helper), then one-hot likewise.
    lab = getTargetSegmentation(label.reshape(256, 256))
    lab = torch.nn.functional.one_hot(lab, num_classes=4)
    lab = einops.rearrange(lab, 'h w class -> class h w')

    # NOTE(review): each metric is fed a single 2-D binary mask per class.
    # Older torchmetrics `dice_score` expects (N, C, ...) predictions, and
    # `jaccard_index(..., num_classes=2)` averages over background and
    # foreground — confirm these compute the intended per-class scores for
    # the installed torchmetrics version.
    classes = (1, 2, 3)
    dice = [round(dice_score(pred[c], lab[c]).item(), 2) for c in classes]
    iou = [round(jaccard_index(pred[c], lab[c], num_classes=2).item(), 2)
           for c in classes]

    df = pd.DataFrame([dice, iou], columns=[f'class {c}' for c in classes])
    df.insert(0, 'metric', ['Dice Score', 'IoU'])
    return df
# Gradio UI: one slider selects a validation image; on release it shows the raw
# image, the model's prediction, the ground truth and a per-class metrics table.
# The slider range follows the number of validation images on disk instead of
# a hard-coded 0-89.
n_val_images = len(glob.glob('./Data/val/Img/*.png'))

with gr.Blocks() as demo:
    with gr.Row():
        slide = gr.Slider(minimum=0, maximum=max(n_val_images - 1, 0), value=0, step=1, label='inference with validation split data (choose patient)')
    with gr.Row():
        slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
        slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
        slide.release(fn=disp_label, inputs=[slide], outputs=gr.Plot(label='groundtruth'))
    with gr.Row():
        slide.release(fn=compute_metrics, inputs=[slide], outputs=gr.DataFrame())

demo.queue()

# Only start the server when run as a script, not when the module is imported.
if __name__ == '__main__':
    demo.launch()