"""Gradio demo: browse validation images and overlay UNET segmentation predictions."""
import functools
import glob

import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from torchvision import transforms

from src.medicalDataLoader import MedicalImageDataset
from UNET_perso import UNET

# Glob pattern for the validation images the slider indexes into.
VAL_IMG_PATTERN = './Data/val/Img/*.png'
# Path to the saved UNET state dict loaded for inference.
MODEL_PATH = 'UNET_perso'


def _list_patients():
    """Return the validation image paths in a stable, sorted order.

    ``glob.glob`` yields files in arbitrary filesystem order; sorting keeps the
    slider index -> image mapping reproducible across runs and machines.
    """
    return sorted(glob.glob(VAL_IMG_PATTERN))


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the UNET, load its CPU weights once, and cache it in eval mode."""
    model = UNET(in_channels=1, out_channels=4).to('cpu')
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _load_val_set():
    """Build the validation dataset once (ToTensor on both images and masks)."""
    transform = transforms.Compose([transforms.ToTensor()])
    return MedicalImageDataset('val', './Data', transform=transform,
                               mask_transform=transform, equalize=False)


def _figure_of(image):
    """Render a 2-D array into a standalone, axis-free matplotlib Figure.

    A ``Figure`` built without pyplot is not registered in pyplot's global
    figure manager, so it is not leaked between calls. It also cannot be drawn
    on by another Gradio handler running concurrently.
    """
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(image)
    ax.axis('off')
    return fig


def disp_prediction(num, alpha=0.4):
    """Overlay the model's predicted class mask on the selected validation image.

    Args:
        num: Slider index into the sorted list of validation image paths.
        alpha: Blend weight of the predicted mask; the image gets ``1 - alpha``.

    Returns:
        A matplotlib Figure showing ``alpha * mask + (1 - alpha) * image``.

    Raises:
        ValueError: if the selected file is not returned by the validation dataset.
    """
    patient = _list_patients()[int(num)]
    model = _load_model()

    # Find the selected image by the name the dataset reports. Order does not
    # matter, so no shuffling is needed, and we stop at the first match.
    val_loader = DataLoader(_load_val_set(), batch_size=1, shuffle=False)
    im = None
    for img, _label, name in val_loader:
        if name[0] == patient:
            im = img
            break
    if im is None:
        raise ValueError(f'{patient!r} not found in the validation dataset')

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        pred = model(im).numpy()
    # pred[0] drops the batch dimension (batch_size=1); assumes the model output
    # is (1, classes, H, W), as the original reshape(4, 256, 256) implied.
    pred_mask = np.argmax(pred[0], axis=0)
    # Remove the singleton batch/channel dimensions of the grayscale input.
    image = im.numpy().squeeze()
    return _figure_of(alpha * pred_mask + (1 - alpha) * image)


def disp_init(num):
    """Display the raw validation image selected by slider index ``num``."""
    patient = _list_patients()[int(num)]
    print(patient)
    return _figure_of(iio.imread(patient))


with gr.Blocks() as demo:
    with gr.Row():
        # Bound the slider by the actual number of validation images.
        slide = gr.Slider(minimum=0, maximum=max(len(_list_patients()) - 1, 0),
                          value=0, step=1,
                          label='inference with validation split data (choose patient)')
    with gr.Row():
        slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
        slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
    with gr.Row():
        gr.DataFrame()  # placeholder table; not wired to any event yet

if __name__ == '__main__':
    demo.launch()