# app.py — Hugging Face Space source, last commit 8b1271d by thov ("Update app.py").
# File size 2.13 kB; the Hub's security scan reported no virus.
import functools
import glob

import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from torchvision import transforms

from src.medicalDataLoader import MedicalImageDataset
from UNET_perso import UNET
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the 1-in / 4-out UNET on CPU, load weights from ./UNET_perso, cache it."""
    model = UNET(in_channels=1, out_channels=4).to('cpu')
    # torch.load unpickles the checkpoint: only ever point this at a trusted file.
    state_dict = torch.load('UNET_perso', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _load_val_image(patient):
    """Return the image tensor of the validation sample whose name equals `patient`.

    Raises
    ------
    ValueError
        If no sample of the 'val' split has that name.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    # Order does not matter for a name lookup, so no shuffling; stop at the first hit.
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
    for img, _label, name in val_loader:
        # NOTE(review): assumes the dataset's `name` is the same path string that
        # glob returns (e.g. './Data/val/Img/x.png') — confirm in MedicalImageDataset.
        if name[0] == patient:
            return img
    raise ValueError(f'no validation sample named {patient!r}')


def disp_prediction(num, alpha=0.4):
    """Plot the UNET segmentation of the `num`-th validation image blended over it.

    Parameters
    ----------
    num : int
        Index into ``glob.glob('./Data/val/Img/*.png')`` (the same list that
        ``disp_init`` indexes, presumably in the same filesystem order).
    alpha : float, optional
        Weight of the predicted mask in the blend; the image gets ``1 - alpha``.

    Returns
    -------
    matplotlib.figure.Figure
        Figure showing ``alpha * mask + (1 - alpha) * image`` with axes hidden.

    Raises
    ------
    IndexError
        If `num` is outside the globbed file list.
    ValueError
        If the dataset has no sample whose name matches the selected file.
    """
    lst_patients = glob.glob('./Data/val/Img/*.png')
    patient = lst_patients[num]

    model = _load_model()  # loaded once per process instead of once per call
    img = _load_val_image(patient)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(img)

    # Drop the batch axis (batch_size=1) and take the per-pixel argmax over the
    # 4 output channels, giving a mask of class indices 0..3.
    pred_mask = np.argmax(logits.numpy()[0], axis=0)
    # Drop batch and (single) channel axes to get a 2-D image.
    im = img.numpy()[0, 0]

    # NOTE(review): class indices (0..3) are added to image intensities; if
    # ToTensor scaled the image to [0, 1], the mask dominates the colormap.
    # A standalone Figure is not tracked by pyplot, so it is freed once Gradio
    # has rendered it (plt.figure() would keep every call's figure alive).
    fig = Figure()
    ax = fig.add_subplot(111)
    ax.imshow(alpha * pred_mask + (1 - alpha) * im)
    ax.axis('off')
    return fig
def disp_init(num):
    """Plot the `num`-th validation PNG exactly as it is stored on disk.

    Parameters
    ----------
    num : int
        Index into ``glob.glob('./Data/val/Img/*.png')``.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with the raw image drawn and axes hidden.

    Raises
    ------
    IndexError
        If `num` is outside the globbed file list.
    """
    lst_patients = glob.glob('./Data/val/Img/*.png')
    patient = lst_patients[num]
    print(patient)
    im = iio.imread(patient)
    # A standalone Figure is not registered with pyplot, so it is garbage-collected
    # after Gradio renders it; plt.figure() would keep every call's figure alive.
    fig = Figure()
    ax = fig.add_subplot(111)
    ax.imshow(im)
    ax.axis('off')
    return fig
# The slider value indexes the same glob list the callbacks use, so derive its
# upper bound from the files actually present instead of hard-coding it.
_N_VAL_IMAGES = len(glob.glob('./Data/val/Img/*.png'))

with gr.Blocks() as demo:
    with gr.Row():
        slide = gr.Slider(minimum=0,
                          maximum=max(_N_VAL_IMAGES - 1, 0),
                          value=0,
                          step=1,
                          label='inference with validation split data (choose patient)')
    with gr.Row():
        init_plot = gr.Plot(label='initial image')
        pred_plot = gr.Plot(label='prediction')
        # Both plots are refreshed when the slider is released.
        slide.release(fn=disp_init, inputs=[slide], outputs=init_plot)
        slide.release(fn=disp_prediction, inputs=[slide], outputs=pred_plot)
    with gr.Row():
        # Placeholder table; nothing in this file writes to it.
        gr.DataFrame()

demo.launch()