File size: 4,731 Bytes
cb7c3ed
2e2df88
cb7c3ed
5424c06
 
 
 
 
7adf3b2
cb7c3ed
 
279daf2
cb7c3ed
7adf3b2
cb7c3ed
5424c06
2e2df88
5424c06
 
 
 
 
 
 
 
7adf3b2
279daf2
 
cb7c3ed
 
7adf3b2
cb7c3ed
7adf3b2
cb7c3ed
 
 
7adf3b2
cb7c3ed
 
 
 
 
 
 
 
279daf2
2e2df88
5424c06
cb7c3ed
 
 
 
 
 
 
 
 
 
 
5424c06
 
 
279daf2
5424c06
 
cb7c3ed
 
 
 
5424c06
 
 
 
 
 
 
 
 
 
 
cb7c3ed
5424c06
cb7c3ed
 
5424c06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279daf2
5424c06
 
 
 
279daf2
5424c06
 
 
 
 
 
 
279daf2
5424c06
 
279daf2
 
 
 
 
 
 
5424c06
279daf2
 
 
5424c06
 
279daf2
 
 
 
 
5424c06
 
cb7c3ed
 
8b1271d
cb7c3ed
 
 
5424c06
cb7c3ed
5424c06
cb7c3ed
5424c06
2e2df88
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import functools
import glob

import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import imageio.v3 as iio
import einops

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchmetrics.functional import dice_score, jaccard_index

from UNET_perso import UNET
from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation

# Shared configuration for the slider-driven inference handlers below.
_VAL_IMG_PATTERN = './Data/val/Img/*.png'
_VAL_GT_PATTERN = './Data/val/GT/*.png'
_MODEL_PATH = 'UNET_perso'
_NUM_CLASSES = 4   # UNET output channels; metrics are reported for classes 1..3
_IMG_SIZE = 256    # side length the input image and network output are reshaped to


def _val_paths(pattern):
    """Return the paths matching ``pattern`` in sorted order.

    ``glob.glob`` returns files in whatever order the file system yields them.
    Without sorting, the same slider index could select different patients in
    the ``Img`` and ``GT`` folders, or on different machines.
    """
    return sorted(glob.glob(pattern))


def _patient_path(num):
    """Return the validation image path selected by slider index ``num``.

    ``num`` is cast to ``int`` because gradio documents Slider values as
    floats, and list indexing requires an int.
    """
    return _val_paths(_VAL_IMG_PATTERN)[int(num)]


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the UNET weights once (CPU, eval mode) and reuse them across calls.

    Caching avoids re-reading the checkpoint from disk on every slider
    release. A checkpoint replaced on disk is only picked up after a restart.
    """
    model = UNET(in_channels=1, out_channels=_NUM_CLASSES).to('cpu')
    model.load_state_dict(torch.load(_MODEL_PATH, map_location=torch.device('cpu')))
    model.eval()
    return model


def _find_sample(patient):
    """Return the ``(image, label)`` batch of the validation sample ``patient``.

    Matching uses exact string equality against the name the dataset yields.
    This assumes ``MedicalImageDataset`` reports the same path strings that
    ``glob`` returns -- TODO confirm.

    Raises:
        ValueError: if no validation sample has that name.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
    for img, label, name in val_loader:
        # batch_size=1, so the collated name batch holds a single entry.
        if name[0] == patient:
            return img, label  # stop early instead of loading every image
    raise ValueError(f'{patient!r} was not found in the validation split')


def _predict_mask(img):
    """Run the cached UNET on ``img`` and return the per-pixel argmax class map.

    The logits are reshaped to ``(_NUM_CLASSES, _IMG_SIZE, _IMG_SIZE)``, so
    ``img`` must contain exactly one sample.
    """
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = _load_model()(img)
    logits = logits.detach().numpy().reshape(_NUM_CLASSES, _IMG_SIZE, _IMG_SIZE)
    return np.argmax(logits, axis=0)


def disp_init(num):
    """Return a figure showing the raw validation image at slider index ``num``."""
    im = iio.imread(_patient_path(num))
    # A standalone Figure (not pyplot) is freed once gradio has rendered it.
    # pyplot figures stay registered globally until plt.close() is called.
    fig = Figure()
    ax = fig.add_subplot()
    ax.imshow(im)
    ax.axis('off')
    return fig


def disp_prediction(num, alpha=0.45):
    """Return a figure blending the predicted class map with the input image.

    Args:
        num: slider index of the validation patient.
        alpha: weight of the predicted mask in the blend; the image gets
            ``1 - alpha``.
    """
    img, _ = _find_sample(_patient_path(num))
    pred_mask = _predict_mask(img)
    im = img.detach().numpy().reshape(_IMG_SIZE, _IMG_SIZE)

    fig = Figure()
    ax = fig.add_subplot()
    try:
        # The /2 has no visible effect: imshow autoscales 2-D scalar data.
        ax.imshow((alpha * pred_mask + (1 - alpha) * im) / 2)
    except ValueError:
        # Best effort: report the offending type and return an empty figure.
        print(type(pred_mask))
    ax.axis('off')
    return fig


def disp_label(num, alpha=0.4):
    """Return a figure showing the ground-truth mask at slider index ``num``.

    ``alpha`` is accepted for symmetry with ``disp_prediction`` but is unused.
    """
    label = iio.imread(_val_paths(_VAL_GT_PATTERN)[int(num)])
    fig = Figure()
    ax = fig.add_subplot()
    try:
        ax.imshow(label)
    except ValueError:
        # Best effort: report the offending type and return an empty figure.
        print(type(label))
    ax.axis('off')
    return fig


def compute_metrics(num):
    """Return Dice score and IoU of the prediction for slider index ``num``.

    Returns:
        A DataFrame whose ``metric`` column holds 'Dice Score' and 'IoU'. It
        has one ``class k`` column per class 1 .. _NUM_CLASSES - 1 (class 0 is
        not scored), with values rounded to 2 decimals.
    """
    img, label = _find_sample(_patient_path(num))

    pred = torch.from_numpy(_predict_mask(img))
    pred = torch.nn.functional.one_hot(pred, num_classes=_NUM_CLASSES)
    pred = einops.rearrange(pred, 'h w class -> class h w')

    # one_hot needs integer class indices. getTargetSegmentation is assumed to
    # map the raw GT tensor to those (see src.utils).
    lab = getTargetSegmentation(label.reshape(_IMG_SIZE, _IMG_SIZE))
    lab = torch.nn.functional.one_hot(lab, num_classes=_NUM_CLASSES)
    lab = einops.rearrange(lab, 'h w class -> class h w')

    classes = range(1, _NUM_CLASSES)
    # NOTE(review): torchmetrics' functional dice_score takes (N, C, ...)
    # probability preds. Passing 2-D one-hot slices may not yield a per-class
    # binary Dice. Likewise, jaccard_index(num_classes=2) may average the
    # absent/present IoUs -- verify against the installed torchmetrics version.
    dice = [round(dice_score(pred[c], lab[c]).item(), 2) for c in classes]
    iou = [round(jaccard_index(pred[c], lab[c], num_classes=2).item(), 2)
           for c in classes]

    df = pd.DataFrame([dice, iou], columns=[f'class {c}' for c in classes])
    df.insert(0, 'metric', ['Dice Score', 'IoU'])
    return df

# Bound the patient slider by the actual number of validation images. A
# hard-coded maximum would break silently if the split changes size.
_n_val_images = len(glob.glob('./Data/val/Img/*.png'))

with gr.Blocks() as demo:
    with gr.Row():
        slide = gr.Slider(minimum=0, maximum=max(_n_val_images - 1, 0), value=0, step=1,
                          label='inference with validation split data (choose patient)')
    with gr.Row():
        # Each output component is created inline, so it renders inside this Row.
        slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
        slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
        slide.release(fn=disp_label, inputs=[slide], outputs=gr.Plot(label='groundtruth'))
    with gr.Row():
        slide.release(fn=compute_metrics, inputs=[slide], outputs=gr.DataFrame())

demo.queue()

# Only start the server when run as a script. Importing this module (e.g. to
# reuse the handlers or in gradio's reload mode) must not launch it.
if __name__ == '__main__':
    demo.launch()