File size: 4,519 Bytes
cb7c3ed
2e2df88
cb7c3ed
5424c06
 
 
 
 
7adf3b2
cb7c3ed
 
 
7adf3b2
cb7c3ed
32278ee
2e2df88
5424c06
 
 
 
 
 
 
 
7adf3b2
279daf2
 
cb7c3ed
 
7adf3b2
cb7c3ed
7adf3b2
cb7c3ed
 
 
7adf3b2
cb7c3ed
 
 
 
 
 
 
 
279daf2
2e2df88
5424c06
cb7c3ed
 
 
 
 
 
 
 
 
 
 
5424c06
 
 
279daf2
5424c06
32278ee
cb7c3ed
 
 
 
32278ee
5424c06
 
 
 
32278ee
5424c06
 
32278ee
cb7c3ed
5424c06
cb7c3ed
 
5424c06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279daf2
5424c06
 
 
 
279daf2
5424c06
 
 
 
 
 
 
279daf2
5424c06
 
279daf2
 
 
 
 
 
 
5424c06
32278ee
 
 
5424c06
 
32278ee
 
279daf2
 
 
5424c06
 
cb7c3ed
 
8b1271d
cb7c3ed
 
 
5424c06
cb7c3ed
5424c06
cb7c3ed
5424c06
2e2df88
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import functools
import glob

import einops
import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from torchvision import transforms

from UNET_perso import UNET
from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation, IOU, dice_score

# Validation split layout and model settings shared by every view below.
VAL_IMG_GLOB = './Data/val/Img/*.png'
VAL_GT_GLOB = './Data/val/GT/*.png'
MODEL_PATH = 'UNET_perso'
NUM_CLASSES = 4
IMG_SIZE = 256


def _sorted_paths(pattern):
    """Return the paths matching *pattern*, sorted.

    ``glob.glob`` yields files in filesystem order, which is neither
    guaranteed to be sorted nor to agree between the ``Img`` and ``GT``
    directories. Sorting keeps slider index ``num`` on the same patient in
    every view.
    """
    return sorted(glob.glob(pattern))


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the UNET, load its CPU weights from MODEL_PATH once, and cache it."""
    model = UNET(in_channels=1, out_channels=NUM_CLASSES).to('cpu')
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    model.eval()
    return model


def _find_sample(patient):
    """Return the ``(img, label)`` batch whose sample name equals *patient*.

    Raises:
        ValueError: if no validation sample is named *patient*.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
    for img, label, name in val_loader:
        # batch_size=1, so ``name`` holds a single entry.
        if name[0] == patient:
            return img, label
    raise ValueError(f'no validation sample named {patient!r}')


def _predict_mask(img):
    """Run the cached model on *img* and return the per-pixel argmax class map.

    Returns:
        numpy integer array of shape (IMG_SIZE, IMG_SIZE).
    """
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = _load_model()(img)
    logits = logits.detach().numpy().reshape(NUM_CLASSES, IMG_SIZE, IMG_SIZE)
    return np.argmax(logits, axis=0)


def _show_image(data, error_title):
    """Return a standalone figure displaying *data* with the axes hidden.

    The figure is created with ``Figure`` rather than ``plt.figure`` so that
    pyplot keeps no reference to it; otherwise every slider move would leave
    more open figures behind. If ``imshow`` rejects *data*, a blank figure
    titled *error_title* is returned instead of raising.
    """
    fig = Figure()
    ax = fig.add_subplot()
    try:
        ax.imshow(data)
    except ValueError:
        ax.set_title(error_title)
    ax.axis('off')
    return fig


def disp_init(num):
    """Return a figure of the raw validation image at slider index *num*."""
    patient = _sorted_paths(VAL_IMG_GLOB)[int(num)]
    return _show_image(iio.imread(patient), 'could not display image')


def disp_prediction(num, alpha=0.45):
    """Return a figure blending the predicted class map with image *num*.

    Args:
        num: index into the sorted list of validation images.
        alpha: weight of the predicted class map in the blend; the input
            image is weighted ``1 - alpha``.
    """
    patient = _sorted_paths(VAL_IMG_GLOB)[int(num)]
    img, _ = _find_sample(patient)
    pred_mask = _predict_mask(img)
    im = img.detach().numpy().reshape(IMG_SIZE, IMG_SIZE)
    # NOTE(review): imshow normalizes scalar data to the colormap by default,
    # so the trailing /2 probably has no visible effect — confirm intent.
    return _show_image((alpha*pred_mask + (1-alpha)*im)/2,
                       'could not display prediction')


def disp_label(num):
    """Return a figure of the ground-truth mask at slider index *num*."""
    label = iio.imread(_sorted_paths(VAL_GT_GLOB)[int(num)])
    return _show_image(label, 'could not display ground truth')


def compute_metrics(num):
    """Return Dice and IoU of the model's prediction for validation image *num*.

    Scores are computed for classes 1..NUM_CLASSES-1 (class 0 is not scored)
    and rounded to two decimals.

    Returns:
        DataFrame with a ``metric`` column ('Dice Score', 'IoU') followed by
        one ``class k`` column per scored class.
    """
    patient = _sorted_paths(VAL_IMG_GLOB)[int(num)]
    img, label = _find_sample(patient)

    # One-hot both masks and move the class axis first so pred[c] / lab[c]
    # select the binary mask of class c.
    pred = torch.from_numpy(_predict_mask(img))
    pred = torch.nn.functional.one_hot(pred, num_classes=NUM_CLASSES)
    pred = einops.rearrange(pred, 'h w class -> class h w')

    lab = getTargetSegmentation(label.reshape(IMG_SIZE, IMG_SIZE))
    lab = torch.nn.functional.one_hot(lab, num_classes=NUM_CLASSES)
    lab = einops.rearrange(lab, 'h w class -> class h w')

    classes = range(1, NUM_CLASSES)
    dice = [round(dice_score(pred[c], lab[c]), 2) for c in classes]
    iou = [round(IOU(pred[c], lab[c]), 2) for c in classes]

    df = pd.DataFrame([dice, iou], columns=[f'class {c}' for c in classes])
    df.insert(0, 'metric', ['Dice Score', 'IoU'])
    return df

# UI: patient slider on top, the three image views side by side, metrics below.
with gr.Blocks() as demo:
    with gr.Row() as row1:
        slide = gr.Slider(minimum=0, maximum=89, value=0, step=1,
                          label='inference with validation split data (choose patient)')
    with gr.Row() as row2:
        init_plot = gr.Plot(label='initial image')
        pred_plot = gr.Plot(label='prediction')
        label_plot = gr.Plot(label='groundtruth')
    with gr.Row() as row3:
        metrics_table = gr.DataFrame()

    # Releasing the slider refreshes every view for the chosen patient index.
    handlers = (
        (disp_init, init_plot),
        (disp_prediction, pred_plot),
        (disp_label, label_plot),
        (compute_metrics, metrics_table),
    )
    for handler, target in handlers:
        slide.release(fn=handler, inputs=[slide], outputs=target)

demo.queue()
demo.launch()