"""Gradio demo for inspecting UNET segmentation on the validation split.

For the sample selected with the slider the app shows the input image, the
predicted mask blended over that image, the ground-truth mask, and a table
of per-class Dice / IoU scores.
"""
import functools
import glob
import logging

import einops
import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from torchvision import transforms

from UNET_perso import UNET
from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation, IOU, dice_score

logger = logging.getLogger(__name__)

DATA_ROOT = './Data'
VAL_IMG_PATTERN = './Data/val/Img/*.png'
VAL_GT_PATTERN = './Data/val/GT/*.png'
WEIGHTS_PATH = 'UNET_perso'
NUM_CLASSES = 4


def _sorted_paths(pattern):
    """Return the files matching *pattern* in sorted order.

    ``glob.glob`` makes no ordering guarantee, so without sorting the same
    slider index could point at different files from one call to the next.
    """
    return sorted(glob.glob(pattern))


def _path_at(pattern, num):
    """Return the *num*-th file matching *pattern*.

    Raises:
        IndexError: if *num* is outside the list of matching files.
    """
    paths = _sorted_paths(pattern)
    try:
        return paths[num]
    except IndexError as err:
        raise IndexError(
            f'index {num} is out of range: {len(paths)} file(s) match {pattern!r}'
        ) from err


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the UNET once, load its weights on CPU and put it in eval mode."""
    model = UNET(in_channels=1, out_channels=NUM_CLASSES).to('cpu')
    # NOTE(security): torch.load unpickles the file; only load trusted weights.
    state = torch.load(WEIGHTS_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(state)
    model.eval()
    return model


def _make_val_loader():
    """Return a batch-size-1, unshuffled loader over the validation split."""
    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val', DATA_ROOT, transform=transform,
                                  mask_transform=transform, equalize=False)
    return DataLoader(val_set, batch_size=1, shuffle=False)


def _find_sample(patient):
    """Return the ``(image, label)`` batch whose name equals *patient*.

    The dataset is matched by name rather than by position because its
    internal ordering is not visible from this file.

    Raises:
        LookupError: if no sample of the validation loader has that name.
    """
    for img, label, name in _make_val_loader():
        if name[0] == patient:
            return img, label
    raise LookupError(f'validation sample {patient!r} not found in the dataset')


def _predict_mask(img):
    """Run the model on *img* and return the per-pixel argmax class map."""
    with torch.no_grad():  # inference only: skip building the autograd graph
        scores = _load_model()(img).numpy()
    # Drop the leading batch axis, keeping (class, height, width).
    scores = scores.reshape(scores.shape[-3:])
    return np.argmax(scores, axis=0)


def _image_figure(image, *, best_effort=False):
    """Return a figure showing *image* with the axes hidden.

    The figure is created without pyplot so that it is not registered in
    pyplot's global figure list, which would otherwise grow by one figure
    per callback invocation.

    Args:
        image: array accepted by ``Axes.imshow``.
        best_effort: when true, a ``ValueError`` from ``imshow`` is logged
            and an empty plot is returned instead of propagating.
    """
    fig = Figure()
    ax = fig.subplots()
    try:
        ax.imshow(image)
    except ValueError:
        if not best_effort:
            raise
        logger.warning('could not render image; showing an empty plot',
                       exc_info=True)
    ax.axis('off')
    return fig


def disp_init(num):
    """Return a figure with the *num*-th validation input image."""
    image = iio.imread(_path_at(VAL_IMG_PATTERN, num))
    return _image_figure(image)


def disp_prediction(num, alpha=0.45):
    """Return a figure blending the predicted mask over the input image.

    Args:
        num: index of the validation image.
        alpha: weight of the predicted mask in the blend; the image gets
            ``1 - alpha``.
    """
    patient = _path_at(VAL_IMG_PATTERN, num)
    img, _ = _find_sample(patient)
    pred_mask = _predict_mask(img)

    image = img.detach().numpy()
    image = image.reshape(image.shape[-2:])  # keep only (height, width)
    overlay = (alpha * pred_mask + (1 - alpha) * image) / 2
    return _image_figure(overlay, best_effort=True)


def disp_label(num):
    """Return a figure with the *num*-th ground-truth mask.

    NOTE(review): assumes GT file names sort in the same order as the image
    file names, so that index *num* is the same patient - confirm.
    """
    label = iio.imread(_path_at(VAL_GT_PATTERN, num))
    return _image_figure(label, best_effort=True)


def compute_metrics(num):
    """Return a DataFrame of Dice score and IoU for classes 1..3.

    The frame has the columns ``metric``, ``class 1``, ``class 2`` and
    ``class 3`` and two rows, ``Dice Score`` and ``IoU``. Class 0 is not
    reported.
    """
    patient = _path_at(VAL_IMG_PATTERN, num)
    img, label = _find_sample(patient)

    pred = torch.from_numpy(_predict_mask(img))
    pred = torch.nn.functional.one_hot(pred, num_classes=NUM_CLASSES)
    pred = einops.rearrange(pred, 'h w c -> c h w')

    lab = getTargetSegmentation(label.reshape(label.shape[-2:]))
    lab = torch.nn.functional.one_hot(lab, num_classes=NUM_CLASSES)
    lab = einops.rearrange(lab, 'h w c -> c h w')

    classes = range(1, NUM_CLASSES)
    dice = [round(dice_score(pred[c], lab[c]), 2) for c in classes]
    iou = [round(IOU(pred[c], lab[c]), 2) for c in classes]
    return pd.DataFrame(
        [['Dice Score', *dice], ['IoU', *iou]],
        columns=['metric', *(f'class {c}' for c in classes)],
    )


def _build_demo():
    """Assemble the Gradio interface and wire the slider to the callbacks."""
    # Size the slider from the data instead of a hard-coded image count.
    last_index = max(len(_sorted_paths(VAL_IMG_PATTERN)) - 1, 0)
    with gr.Blocks() as blocks:
        with gr.Row():
            slide = gr.Slider(minimum=0, maximum=last_index, value=0, step=1,
                              label='inference with validation split data (choose patient)')
        with gr.Row():
            slide.release(fn=disp_init, inputs=[slide],
                          outputs=gr.Plot(label='initial image'))
            slide.release(fn=disp_prediction, inputs=[slide],
                          outputs=gr.Plot(label='prediction'))
            slide.release(fn=disp_label, inputs=[slide],
                          outputs=gr.Plot(label='groundtruth'))
        with gr.Row():
            slide.release(fn=compute_metrics, inputs=[slide],
                          outputs=gr.DataFrame())
    return blocks


demo = _build_demo()
demo.queue()

if __name__ == '__main__':
    # Only start the server when run as a script, not when imported.
    demo.launch()