# app.py — last commit cb7c3ed ("Update app.py") by thov; file size 2.06 kB.
import glob
from functools import lru_cache

import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from src.medicalDataLoader import MedicalImageDataset
from UNET_perso import UNET
@lru_cache(maxsize=1)
def _load_model(filepath='UNET_perso'):
    """Build the UNET (1 input channel, 4 output channels) and load CPU weights.

    Cached so the checkpoint at *filepath* is read from disk once per process
    instead of on every slider release.
    """
    model = UNET(in_channels=1, out_channels=4).to('cpu')
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()
    return model


def disp_prediction(num, alpha=0.4):
    """Plot the model's predicted class map blended over validation image *num*.

    Parameters
    ----------
    num : int
        Index into the ``./Data/val/Img/*.png`` files, in the order returned
        by ``glob.glob``. Cast to ``int`` because UI sliders may deliver a
        float.
    alpha : float, optional
        Weight of the predicted class map in the blend; the input image gets
        weight ``1 - alpha``.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with a single axis showing the blend, axes hidden.

    Raises
    ------
    IndexError
        If *num* does not index an existing image file.
    ValueError
        If the selected file is not yielded by the validation dataset.
    """
    lst_patients = glob.glob('./Data/val/Img/*.png')
    num = int(num)
    if not 0 <= num < len(lst_patients):
        raise IndexError(
            f'image index {num} out of range ({len(lst_patients)} images found)'
        )
    patient = lst_patients[num]

    model = _load_model('UNET_perso')

    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    # Searching for a single file: ordering is irrelevant, so don't shuffle.
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    im = None
    for img, _label, name in val_loader:
        # batch_size=1, so the collated name is a one-element sequence.
        # NOTE(review): assumes the dataset reports names as the same path
        # strings glob returns above — confirm against MedicalImageDataset.
        if name[0] == patient:
            im = img
            break
    if im is None:
        raise ValueError(f'{patient} not found in the validation dataset')

    with torch.no_grad():
        pred = model(im).numpy()
    # Drop the batch dimension, then take the highest-scoring channel per pixel.
    pred_mask = np.argmax(pred[0], axis=0)
    image = im.numpy().reshape(pred_mask.shape)

    fig, ax = plt.subplots()
    # NOTE(review): pred_mask holds class indices 0..3, while the image range
    # depends on the dataset/ToTensor (often [0, 1]); the blend is therefore
    # dominated by the mask — confirm this is the intended visualisation.
    ax.imshow(alpha * pred_mask + (1 - alpha) * image)
    ax.axis('off')
    # Unregister from pyplot so repeated callbacks don't accumulate open
    # figures; the returned Figure object can still be rendered/saved.
    plt.close(fig)
    return fig
def disp_init(num):
    """Plot the raw validation image at index *num*.

    Parameters
    ----------
    num : int
        Index into the ``./Data/val/Img/*.png`` files, in the order returned
        by ``glob.glob``. Cast to ``int`` because UI sliders may deliver a
        float.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with a single axis showing the image, axes hidden.

    Raises
    ------
    IndexError
        If *num* does not index an existing image file.
    """
    lst_patients = glob.glob('./Data/val/Img/*.png')
    num = int(num)
    if not 0 <= num < len(lst_patients):
        raise IndexError(
            f'image index {num} out of range ({len(lst_patients)} images found)'
        )
    patient = lst_patients[num]
    print(patient)
    im = iio.imread(patient)

    fig, ax = plt.subplots()
    ax.imshow(im)
    ax.axis('off')
    # Unregister from pyplot so repeated callbacks don't accumulate open
    # figures; the returned Figure object can still be rendered/saved.
    plt.close(fig)
    return fig
# Slider positions map 1:1 onto the validation images found on disk
# (the range collapses to 0 when none are present).
n_images = len(glob.glob('./Data/val/Img/*.png'))
last_index = max(n_images - 1, 0)

with gr.Blocks() as demo:
    with gr.Row():
        # gr.Slider's range parameters are `minimum`/`maximum`; `min`/`max`
        # are not Slider parameters and would not set the range.
        slide = gr.Slider(minimum=0, maximum=last_index,
                          value=min(10, last_index), step=1)
    with gr.Row():
        init_plot = gr.Plot(label='initial image')
        pred_plot = gr.Plot(label='prediction')
    # Redraw only when the slider is released, not on every drag step.
    slide.release(fn=disp_init, inputs=[slide], outputs=init_plot)
    slide.release(fn=disp_prediction, inputs=[slide], outputs=pred_plot)
    with gr.Row():
        gr.DataFrame()

if __name__ == '__main__':
    demo.launch()