# Hugging Face Spaces app (Space status when this page was captured: Sleeping)
import glob | |
import gradio as gr | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader | |
from torchvision import transforms | |
import imageio.v3 as iio | |
from UNET_perso import UNET | |
import matplotlib.pyplot as plt | |
from src.medicalDataLoader import MedicalImageDataset | |
def disp_prediction(num, alpha=0.4):
    """Build a figure blending the predicted segmentation mask with its image.

    The image is selected by position in
    ``glob.glob('./Data/val/Img/*.png')``. The sample whose name equals that
    path is then looked up in the ``'val'`` split of ``MedicalImageDataset``
    and passed through a ``UNET`` (1 input channel, 4 output channels) whose
    weights are loaded on CPU from the ``UNET_perso`` file.

    Args:
        num: Index into the globbed list of validation image paths.
        alpha: Blend weight of the predicted mask; the image is weighted by
            ``1 - alpha``.

    Returns:
        A matplotlib ``Figure`` showing ``alpha * mask + (1 - alpha) * image``
        with the axes hidden, where ``mask`` is the per-pixel argmax over the
        4 output channels.

    Raises:
        IndexError: If ``num`` is outside the globbed list.
        ValueError: If no dataset sample is named like the selected path.
    """
    # Local import: a standalone Figure is not registered in pyplot's global
    # figure registry, so repeated UI callbacks do not accumulate open figures.
    from matplotlib.figure import Figure

    lst_patients = glob.glob('./Data/val/Img/*.png')
    patient = lst_patients[num]

    model = UNET(in_channels=1, out_channels=4).to('cpu')
    filepath = 'UNET_perso'
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()

    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    # The loader is only scanned for one sample by name, so shuffling is
    # unnecessary.
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            shuffle=False)

    im = None
    for img, _label, name in val_loader:
        if name[0] == patient:
            im = img
            break  # stop scanning once the requested sample is found
    if im is None:
        raise ValueError(f'no validation sample named {patient!r} in the dataset')

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pred = model(im)
    pred = pred.numpy().reshape(4, 256, 256)
    pred_mask = np.argmax(pred, axis=0)
    im = im.numpy().reshape(256, 256)

    fig = Figure()
    ax = fig.add_subplot(111)
    ax.imshow(alpha * pred_mask + (1 - alpha) * im)
    ax.axis('off')
    return fig
def disp_init(num):
    """Build a figure showing the raw validation image at position ``num``.

    Args:
        num: Index into ``glob.glob('./Data/val/Img/*.png')``.

    Returns:
        A matplotlib ``Figure`` displaying the image with the axes hidden.

    Raises:
        IndexError: If ``num`` is outside the globbed list.
    """
    # Local import: a standalone Figure is not registered in pyplot's global
    # figure registry, so repeated UI callbacks do not accumulate open figures.
    from matplotlib.figure import Figure

    lst_patients = glob.glob('./Data/val/Img/*.png')
    patient = lst_patients[num]
    print(patient)
    im = iio.imread(patient)

    fig = Figure()
    ax = fig.add_subplot(111)
    ax.imshow(im)
    ax.axis('off')
    return fig
# Gradio UI: a slider picks a validation image; releasing it redraws the raw
# image and the model prediction in two plots.
with gr.Blocks() as demo:
    with gr.Row() as row1:
        # gr.Slider names its bounds `minimum` / `maximum`; `min` / `max` are
        # not Slider parameters, so the intended 0..89 range was not applied.
        # NOTE(review): 89 presumably corresponds to 90 validation images —
        # confirm against ./Data/val/Img.
        slide = gr.Slider(minimum=0, maximum=89, value=10, step=1)
    with gr.Row() as row2:
        slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
        slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
    with gr.Row() as row3:
        gr.DataFrame()
demo.launch()