Spaces:
Sleeping
Sleeping
File size: 2,055 Bytes
cb7c3ed 2e2df88 cb7c3ed 7adf3b2 cb7c3ed 7adf3b2 cb7c3ed 2e2df88 cb7c3ed 7adf3b2 cb7c3ed 7adf3b2 cb7c3ed 7adf3b2 cb7c3ed 7adf3b2 cb7c3ed 2e2df88 cb7c3ed 2e2df88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import functools
import glob

import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from src.medicalDataLoader import MedicalImageDataset
from UNET_perso import UNET
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the 1-in / 4-out UNET on CPU, load weights from 'UNET_perso', and cache it."""
    filepath = 'UNET_perso'
    model = UNET(in_channels=1, out_channels=4).to('cpu')
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()
    return model


def disp_prediction(num, alpha=0.4):
    """Plot the model's predicted class map blended over a validation image.

    Args:
        num: Index (from the UI slider) into the glob listing of
            ``./Data/val/Img/*.png``; coerced to ``int``.
        alpha: Weight given to the predicted class map in the blend; the
            image receives ``1 - alpha``.

    Returns:
        A matplotlib ``Figure`` showing ``alpha * pred_mask + (1 - alpha) * image``
        with axes hidden.

    Raises:
        IndexError: if ``num`` is outside the list of validation images.
        LookupError: if the selected file is not yielded by the validation dataset.
    """
    # A standalone Figure (not pyplot) is not registered in pyplot's global
    # figure manager, so figures created per slider event do not accumulate.
    from matplotlib.figure import Figure

    # Unsorted glob listing, kept as-is so the index -> file mapping matches
    # the other view that uses the same glob pattern.
    lst_patients = glob.glob('./Data/val/Img/*.png')
    patient = lst_patients[int(num)]

    model = _load_model()

    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    # We only scan for one file, so shuffling is unnecessary.
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    im = None
    for img, _label, name in val_loader:
        # NOTE(review): assumes the dataset reports names in the same form glob
        # returns (e.g. './Data/val/Img/<file>.png') — confirm in MedicalImageDataset.
        if name[0] == patient:
            im = img
            break
    if im is None:
        raise LookupError(f"{patient!r} was not found in the validation dataset")

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        pred = model(im).numpy()

    # Spatial size comes from the input; assumes the model output has the same
    # H x W as its input (the original hard-coded 256 x 256 for both).
    height, width = im.shape[-2:]
    # Drop the batch dimension, then pick the highest-scoring class per pixel.
    pred_mask = np.argmax(pred.reshape(-1, height, width), axis=0)
    image = im.numpy().reshape(height, width)

    # NOTE(review): pred_mask holds class indices while the image is presumably
    # ToTensor-scaled to [0, 1], so the mask dominates the blend — confirm intent.
    fig = Figure()
    ax = fig.add_subplot()
    ax.imshow(alpha * pred_mask + (1 - alpha) * image)
    ax.axis('off')
    return fig
def disp_init(num):
    """Plot the raw validation image selected by the UI slider.

    Args:
        num: Index into the glob listing of ``./Data/val/Img/*.png``;
            coerced to ``int``.

    Returns:
        A matplotlib ``Figure`` displaying the image with axes hidden.

    Raises:
        IndexError: if ``num`` is outside the list of validation images.
    """
    # A standalone Figure (not pyplot) is not kept alive by pyplot's global
    # figure manager, so one figure per slider event does not leak.
    from matplotlib.figure import Figure

    lst_patients = glob.glob('./Data/val/Img/*.png')
    patient = lst_patients[int(num)]
    print(patient)
    im = iio.imread(patient)

    fig = Figure()
    ax = fig.add_subplot()
    ax.imshow(im)
    ax.axis('off')
    return fig
# Slider range is derived from the number of validation images instead of a
# hard-coded 0..89, so it cannot point past the end of the listing.
_n_val_images = len(glob.glob('./Data/val/Img/*.png'))
_last_index = max(_n_val_images - 1, 0)

with gr.Blocks() as demo:
    with gr.Row():
        # Gradio's Slider takes `minimum`/`maximum`; `min`/`max` are not its
        # parameter names, so the intended range was not applied.
        slide = gr.Slider(minimum=0,
                          maximum=_last_index,
                          value=min(10, _last_index),
                          step=1)
    with gr.Row():
        initial_plot = gr.Plot(label='initial image')
        prediction_plot = gr.Plot(label='prediction')
    # Both views refresh when the user releases the slider.
    slide.release(fn=disp_init, inputs=[slide], outputs=initial_plot)
    slide.release(fn=disp_prediction, inputs=[slide], outputs=prediction_plot)
    with gr.Row():
        gr.DataFrame()
demo.launch()