File size: 2,125 Bytes
cb7c3ed
2e2df88
cb7c3ed
7adf3b2
cb7c3ed
 
 
 
7adf3b2
 
cb7c3ed
2e2df88
 
cb7c3ed
7adf3b2
cb7c3ed
 
7adf3b2
cb7c3ed
7adf3b2
cb7c3ed
 
 
7adf3b2
cb7c3ed
 
 
 
 
 
 
 
 
 
2e2df88
cb7c3ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b1271d
cb7c3ed
 
 
 
 
 
2e2df88
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import functools
import glob
import os

import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from torchvision import transforms

from src.medicalDataLoader import MedicalImageDataset
from UNET_perso import UNET


# Glob pattern selecting the validation images offered by the patient slider.
VAL_IMG_GLOB = './Data/val/Img/*.png'
# Weights file loaded into the UNET (resolved relative to the working directory).
MODEL_WEIGHTS = 'UNET_perso'


def _list_patients(pattern=VAL_IMG_GLOB):
    """Return the paths matching *pattern*, sorted.

    ``glob.glob`` returns files in arbitrary, filesystem-dependent order;
    sorting makes a given slider index always refer to the same image.
    """
    return sorted(glob.glob(pattern))


def _pick_patient(num, pattern=VAL_IMG_GLOB):
    """Return the path of the *num*-th image matching *pattern*.

    *num* may arrive from the UI slider as a float, so it is cast to ``int``.

    Raises:
        IndexError: if *num* is outside ``[0, number of matching images)``;
            the message also makes the "no images found" case obvious.
    """
    patients = _list_patients(pattern)
    index = int(num)
    if not 0 <= index < len(patients):
        raise IndexError(
            f'patient index {index} out of range: '
            f'{len(patients)} image(s) match {pattern!r}'
        )
    return patients[index]


@functools.lru_cache(maxsize=1)
def _load_model(filepath=MODEL_WEIGHTS):
    """Build the 1-in / 4-out UNET on CPU, load *filepath* and switch to eval mode.

    Cached so the weights are read once instead of on every slider release.
    """
    model = UNET(in_channels=1, out_channels=4).to('cpu')
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()
    return model


def _load_val_image(patient):
    """Return the batched image tensor the validation dataset yields for *patient*.

    NOTE(review): assumes the dataset reports each sample's name as its image
    path (the original code compared it directly to the glob result) — confirm.

    Raises:
        LookupError: if no sample in the validation dataset matches *patient*.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    # We only search for one sample, so shuffling is pointless.
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
    target = os.path.normpath(patient)
    for img, _label, name in val_loader:
        # normpath so that e.g. './Data/val/Img/x.png' and 'Data/val/Img/x.png'
        # (or backslash-separated paths on Windows) compare equal.
        if os.path.normpath(name[0]) == target:
            return img
    raise LookupError(f'{patient!r} not found in the validation dataset')


def _show(image):
    """Return a new figure displaying *image* with the axes hidden.

    A bare ``Figure`` is used instead of ``plt.figure()``: pyplot keeps every
    figure it creates alive until explicitly closed, and nothing here closes
    the figures handed to the UI.
    """
    fig = Figure()
    ax = fig.add_subplot()
    ax.imshow(image)
    ax.axis('off')
    return fig


def disp_prediction(num, alpha=0.4):
    """Return a figure blending the predicted class map with validation image *num*.

    Args:
        num: index into the sorted list of validation images (slider value).
        alpha: weight of the predicted class map in the blend; the image gets
            ``1 - alpha``.

    Raises:
        IndexError: if *num* does not index an existing validation image.
        LookupError: if the chosen image is not yielded by the validation dataset.
    """
    img = _load_val_image(_pick_patient(num))

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        logits = _load_model()(img)

    # Assumes model output is (batch=1, classes, H, W) — TODO confirm.
    # argmax over the class axis gives one class id per pixel.
    pred_mask = np.argmax(logits.numpy()[0], axis=0)
    image = img.numpy().reshape(pred_mask.shape)

    # NOTE(review): pred_mask holds class ids (0..3) while the image is
    # presumably scaled to [0, 1] by ToTensor, so the mask dominates the blend;
    # consider normalising if that is unintended.
    return _show(alpha * pred_mask + (1 - alpha) * image)


def disp_init(num):
    """Return a figure showing validation image *num* as read from disk.

    Also prints the selected image path.

    Raises:
        IndexError: if *num* does not index an existing validation image.
    """
    patient = _pick_patient(num)
    print(patient)
    return _show(iio.imread(patient))

# Number of validation images on disk. It bounds the patient slider, replacing a
# hard-coded maximum that could index past the end of the image list.
n_val_images = len(glob.glob('./Data/val/Img/*.png'))

with gr.Blocks() as demo:
    with gr.Row():
        slide = gr.Slider(minimum=0,
                          maximum=max(n_val_images - 1, 0),
                          value=0,
                          step=1,
                          label='inference with validation split data (choose patient)')
    with gr.Row():
        # Recompute both plots when the user releases the slider.
        slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
        slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
    with gr.Row():
        # NOTE(review): empty placeholder table; nothing writes to it yet.
        gr.DataFrame()

# Start the server only when run as a script, so importing this module
# (e.g. to reuse disp_init / disp_prediction) does not launch the UI.
if __name__ == '__main__':
    demo.launch()