"""Gradio demo: browse the validation split, run the UNET and show metrics.

Releasing the slider selects a validation image by index and displays the raw
image, the predicted class mask blended with the input, the ground-truth mask
and per-class Dice / IoU scores.
"""
import functools
import glob
import logging
import os

import einops
import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt  # noqa: F401  kept; figures below use the OO Figure API
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from torchmetrics.functional import dice_score, jaccard_index
from torchvision import transforms

from UNET_perso import UNET
from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation

logger = logging.getLogger(__name__)

# Data locations and model configuration shared by every handler.
DATA_ROOT = './Data'
IMG_PATTERN = './Data/val/Img/*.png'
GT_PATTERN = './Data/val/GT/*.png'
WEIGHTS_PATH = 'UNET_perso'
NUM_CLASSES = 4


def _list_images():
    """Return the validation image paths sorted by name.

    ``glob`` yields paths in arbitrary order; sorting makes a slider index refer
    to the same file on every call and lines it up with ``_list_labels``.
    """
    return sorted(glob.glob(IMG_PATTERN))


def _list_labels():
    """Return the validation ground-truth paths sorted by name.

    NOTE(review): assumes Img/ and GT/ contain one file per patient whose sorted
    orders correspond — confirm against the dataset layout.
    """
    return sorted(glob.glob(GT_PATTERN))


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the UNET on CPU, load ``WEIGHTS_PATH`` and switch to eval mode.

    Cached so the weights are read once rather than on every slider release.
    """
    model = UNET(in_channels=1, out_channels=NUM_CLASSES).to('cpu')
    # torch.load unpickles the file: only point WEIGHTS_PATH at trusted weights.
    model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=torch.device('cpu')))
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _load_val_set():
    """Build (once) the validation ``MedicalImageDataset`` with ToTensor transforms."""
    transform = transforms.Compose([transforms.ToTensor()])
    return MedicalImageDataset('val', DATA_ROOT, transform=transform,
                               mask_transform=transform, equalize=False)


def _find_sample(patient):
    """Return the ``(img, label)`` loader batch whose name matches ``patient``.

    Both paths are passed through ``os.path.normpath`` before comparison so that
    different spellings of the same relative path still match.

    Raises:
        LookupError: if no validation sample carries that name.
    """
    target = os.path.normpath(patient)
    loader = DataLoader(_load_val_set(), batch_size=1, shuffle=False)
    for img, label, name in loader:
        # batch_size=1, so ``name`` is a one-element batch of names.
        if os.path.normpath(name[0]) == target:
            return img, label
    raise LookupError(f'{patient!r} is not in the validation dataset')


def _predict(num):
    """Run the model on the validation image at slider index ``num``.

    Returns:
        ``(img, label, logits)``: the loader's image and label batches and the
        model output as a numpy array of shape ``(NUM_CLASSES, H, W)``.
    """
    patient = _list_images()[int(num)]
    img, label = _find_sample(patient)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        out = _load_model()(img)
    # Drop the batch dimension; output is expected to be (1, NUM_CLASSES, H, W).
    logits = out.numpy().reshape(NUM_CLASSES, *out.shape[-2:])
    return img, label, logits


def _one_hot_chw(mask):
    """One-hot encode an (H, W) integer mask into a (NUM_CLASSES, H, W) tensor."""
    return einops.rearrange(
        torch.nn.functional.one_hot(mask, num_classes=NUM_CLASSES),
        'h w class -> class h w',
    )


def _image_figure(data):
    """Return a new axis-free matplotlib Figure displaying ``data`` with imshow.

    Uses the object-oriented ``Figure`` API instead of pyplot so figures are not
    kept in pyplot's global registry (which otherwise grows on every request)
    and no GUI backend is involved from Gradio's worker threads. A ``ValueError``
    from ``imshow`` is logged and an empty figure is returned.
    """
    fig = Figure()
    ax = fig.subplots()
    try:
        ax.imshow(data)
    except ValueError:
        logger.exception('could not display array of type %s', type(data))
    ax.axis('off')
    return fig


def disp_init(num):
    """Return a figure of the raw validation image at slider index ``num``."""
    return _image_figure(iio.imread(_list_images()[int(num)]))


def disp_prediction(num, alpha=0.45):
    """Return the predicted class mask blended with the input image.

    Args:
        num: slider index into the sorted validation images.
        alpha: weight of the predicted mask in the blend; the input image gets
            ``1 - alpha``.
    """
    img, _, logits = _predict(num)
    pred_mask = np.argmax(logits, axis=0)
    image = img.numpy().reshape(pred_mask.shape)
    # imshow autoscales to the data range, so the /2 does not alter the picture.
    return _image_figure((alpha * pred_mask + (1 - alpha) * image) / 2)


def disp_label(num, alpha=0.4):
    """Return a figure of the ground-truth mask at slider index ``num``.

    ``alpha`` is unused; it is kept so existing callers keep working.
    """
    return _image_figure(iio.imread(_list_labels()[int(num)]))


def compute_metrics(num):
    """Return per-class Dice and IoU of the prediction for image ``num``.

    Classes 1..NUM_CLASSES-1 are scored (class 0, presumably background, is not
    reported). Scores are rounded to two decimals and returned as a DataFrame
    with a ``metric`` column followed by one ``class k`` column per class.
    """
    _, label, logits = _predict(num)
    pred = _one_hot_chw(torch.from_numpy(np.argmax(logits, axis=0)))
    lab = _one_hot_chw(getTargetSegmentation(label.reshape(label.shape[-2:])))
    classes = range(1, NUM_CLASSES)
    dice = [round(dice_score(pred[c], lab[c]).item(), 2) for c in classes]
    iou = [round(jaccard_index(pred[c], lab[c], num_classes=2).item(), 2) for c in classes]
    df = pd.DataFrame([dice, iou], columns=[f'class {c}' for c in classes])
    df.insert(0, 'metric', ['Dice Score', 'IoU'])
    return df


with gr.Blocks() as demo:
    with gr.Row():
        # Slider range follows the number of validation images on disk.
        slide = gr.Slider(minimum=0, maximum=max(len(_list_images()) - 1, 0), value=0, step=1,
                          label='inference with validation split data (choose patient)')
    with gr.Row():
        init_plot = gr.Plot(label='initial image')
        pred_plot = gr.Plot(label='prediction')
        gt_plot = gr.Plot(label='groundtruth')
    with gr.Row():
        metrics_table = gr.DataFrame()
    # Refresh every view when the user lets go of the slider.
    slide.release(fn=disp_init, inputs=[slide], outputs=init_plot)
    slide.release(fn=disp_prediction, inputs=[slide], outputs=pred_plot)
    slide.release(fn=disp_label, inputs=[slide], outputs=gt_plot)
    slide.release(fn=compute_metrics, inputs=[slide], outputs=metrics_table)

demo.queue()
demo.launch()