# Author: thov
# Commit 32278ee: personal implementation of Dice and IoU
import glob
import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as iio
import einops
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from UNET_perso import UNET
from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation, IOU, dice_score
def disp_init(num):
    """Return a figure showing the raw validation image at index ``num``.

    Args:
        num: Index into the validation images under ``./Data/val/Img``,
            ordered by path.

    Returns:
        A matplotlib Figure with the image drawn on a single, axis-less Axes.

    Raises:
        IndexError: If ``num`` is out of range for the available images.
    """
    # glob() returns files in filesystem order, which is arbitrary; sort so
    # that index ``num`` refers to the same file in every callback.
    lst_patients = sorted(glob.glob('./Data/val/Img/*.png'))
    # The slider value may arrive as a float; list indices must be ints.
    patient = lst_patients[int(num)]
    im = iio.imread(patient)
    # Draw on an explicit Axes rather than pyplot's global "current figure",
    # so callbacks running at the same time cannot draw onto each other's figure.
    fig, ax = plt.subplots()
    ax.imshow(im)
    ax.axis('off')
    return fig
def disp_prediction(num, alpha=0.45):
    """Return a figure blending the UNET prediction with validation image ``num``.

    Loads the ``UNET_perso`` weights on CPU and finds the validation sample
    whose path matches the ``num``-th image of ``./Data/val/Img`` (sorted by
    path). It predicts 4 output channels and overlays the per-pixel argmax
    class map on the input image.

    Args:
        num: Index into the sorted validation images.
        alpha: Weight of the predicted class map in the blend; the input
            image is weighted by ``1 - alpha``.

    Returns:
        A matplotlib Figure with the overlay on an axis-less Axes. The Axes is
        left blank (with a warning) if ``imshow`` rejects the array.

    Raises:
        IndexError: If ``num`` is out of range for the available images.
        FileNotFoundError: If no validation-set sample matches the image path.
    """
    import os
    import warnings

    # Sort so index ``num`` is stable and matches the other callbacks.
    lst_patients = sorted(glob.glob('./Data/val/Img/*.png'))
    patient = lst_patients[int(num)]

    model = UNET(in_channels=1, out_channels=4).to('cpu')
    filepath = 'UNET_perso'
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()

    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            shuffle=False)

    # Normalize both sides so './Data/x.png' and 'Data/x.png' compare equal.
    target = os.path.normpath(patient)
    im = None
    for img, _, name in val_loader:
        if os.path.normpath(name[0]) == target:
            im = img
            break  # stop loading further samples once found
    if im is None:
        raise FileNotFoundError(f'{patient} not found in the validation dataset')

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        pred = model(im)
    # Assumes batch size 1 and a single-channel image, so the trailing two
    # dimensions are (height, width) -- TODO confirm for other datasets.
    h, w = im.shape[-2:]
    pred_mask = np.argmax(pred.numpy().reshape(4, h, w), axis=0)
    image = im.numpy().reshape(h, w)

    fig, ax = plt.subplots()
    try:
        ax.imshow((alpha * pred_mask + (1 - alpha) * image) / 2)
    except ValueError as err:
        # Best-effort display: keep the empty panel but report why.
        warnings.warn(f'could not display prediction for {patient}: {err}')
    ax.axis('off')
    return fig
def disp_label(num):
    """Return a figure showing the ground-truth mask at index ``num``.

    Args:
        num: Index into the masks under ``./Data/val/GT``, ordered by path.

    Returns:
        A matplotlib Figure with the mask on an axis-less Axes. The Axes is
        left blank (with a warning) if ``imshow`` rejects the array.

    Raises:
        IndexError: If ``num`` is out of range for the available masks.
    """
    import warnings

    # glob() order is filesystem-dependent; sort so index ``num`` is stable.
    # Assumes Img and GT files sort into the same patient order -- confirm.
    lst_GT = sorted(glob.glob('./Data/val/GT/*.png'))
    # The slider value may arrive as a float; list indices must be ints.
    GT = lst_GT[int(num)]
    label = iio.imread(GT)
    # Explicit Axes instead of pyplot's global "current figure".
    fig, ax = plt.subplots()
    try:
        ax.imshow(label)
    except ValueError as err:
        # Best-effort display: keep the empty panel but report why.
        warnings.warn(f'could not display ground truth {GT}: {err}')
    ax.axis('off')
    return fig
def compute_metrics(num):
    """Compute per-class Dice and IoU for validation image ``num``.

    Loads the ``UNET_perso`` weights on CPU and finds the validation sample
    whose path matches the ``num``-th image of ``./Data/val/Img`` (sorted by
    path). The argmax prediction and the ground truth (via
    ``getTargetSegmentation``) are both one-hot encoded over 4 classes and
    compared for classes 1, 2 and 3.

    Args:
        num: Index into the sorted validation images.

    Returns:
        A 2-row DataFrame with columns ``metric``, ``class 1``, ``class 2``
        and ``class 3``. Row 0 holds the Dice scores and row 1 the IoU
        scores, each rounded to 2 decimals.

    Raises:
        IndexError: If ``num`` is out of range for the available images.
        FileNotFoundError: If no validation-set sample matches the image path.
    """
    import os

    # Sort so index ``num`` is stable and matches the other callbacks.
    lst_patients = sorted(glob.glob('./Data/val/Img/*.png'))
    patient = lst_patients[int(num)]

    model = UNET(in_channels=1, out_channels=4).to('cpu')
    filepath = 'UNET_perso'
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()

    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            shuffle=False)

    # Normalize both sides so './Data/x.png' and 'Data/x.png' compare equal.
    target = os.path.normpath(patient)
    im = lab = None
    for img, label, name in val_loader:
        if os.path.normpath(name[0]) == target:
            im, lab = img, label
            break  # stop loading further samples once found
    if im is None:
        raise FileNotFoundError(f'{patient} not found in the validation dataset')

    # Assumes batch size 1 and single-channel tensors, so the trailing two
    # dimensions are (height, width) -- TODO confirm for other datasets.
    h, w = im.shape[-2:]

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        pred = model(im)
    pred = np.argmax(pred.numpy().reshape(4, h, w), axis=0)
    pred = torch.from_numpy(pred)
    pred = torch.nn.functional.one_hot(pred, num_classes=4)
    pred = einops.rearrange(pred, 'h w class -> class h w')

    lab = getTargetSegmentation(lab.reshape(h, w))
    lab = torch.nn.functional.one_hot(lab, num_classes=4)
    lab = einops.rearrange(lab, 'h w class -> class h w')

    # Class 0 is not reported (presumably background -- confirm).
    classes = (1, 2, 3)
    dice = [round(dice_score(pred[c], lab[c]), 2) for c in classes]
    iou = [round(IOU(pred[c], lab[c]), 2) for c in classes]

    df = pd.DataFrame([dice, iou], columns=['class 1', 'class 2', 'class 3'])
    df.insert(0, 'metric', ['Dice Score', 'IoU'])
    return df
# Size the slider from the data instead of hard-coding 89, so it can never
# point past the last validation image.
n_val_images = len(glob.glob('./Data/val/Img/*.png'))

# Gradio UI: one slider selects a validation patient. On release it refreshes
# the input image, the prediction overlay, the ground truth and the metrics table.
with gr.Blocks() as demo:
    with gr.Row():
        slide = gr.Slider(minimum=0,
                          maximum=max(n_val_images - 1, 0),
                          value=0,
                          step=1,
                          label='inference with validation split data (choose patient)')
    with gr.Row():
        # One listener per output, so each panel is produced by its own callback.
        slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
        slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
        slide.release(fn=disp_label, inputs=[slide], outputs=gr.Plot(label='groundtruth'))
    with gr.Row():
        slide.release(fn=compute_metrics, inputs=[slide], outputs=gr.DataFrame())

demo.queue()

if __name__ == "__main__":
    demo.launch()