Spaces:
Sleeping
Sleeping
File size: 4,519 Bytes
cb7c3ed 2e2df88 cb7c3ed 5424c06 7adf3b2 cb7c3ed 7adf3b2 cb7c3ed 32278ee 2e2df88 5424c06 7adf3b2 279daf2 cb7c3ed 7adf3b2 cb7c3ed 7adf3b2 cb7c3ed 7adf3b2 cb7c3ed 279daf2 2e2df88 5424c06 cb7c3ed 5424c06 279daf2 5424c06 32278ee cb7c3ed 32278ee 5424c06 32278ee 5424c06 32278ee cb7c3ed 5424c06 cb7c3ed 5424c06 279daf2 5424c06 279daf2 5424c06 279daf2 5424c06 279daf2 5424c06 32278ee 5424c06 32278ee 279daf2 5424c06 cb7c3ed 8b1271d cb7c3ed 5424c06 cb7c3ed 5424c06 cb7c3ed 5424c06 2e2df88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import glob
import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as iio
import einops
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from UNET_perso import UNET
from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation, IOU, dice_score
def disp_init(num):
    """Return a figure showing the raw validation image number ``num``.

    Args:
        num: Index into ``glob.glob('./Data/val/Img/*.png')``. The value may
            come straight from a Gradio slider.

    Returns:
        A ``matplotlib.figure.Figure`` with the image drawn and the axes
        hidden.

    Raises:
        IndexError: if ``num`` is outside the image listing.
    """
    num = int(num)  # in case the slider delivers a float such as 3.0
    patient = glob.glob('./Data/val/Img/*.png')[num]
    im = iio.imread(patient)

    # Draw on an explicit Axes instead of pyplot's implicit "current" one.
    fig, ax = plt.subplots()
    ax.imshow(im)
    ax.axis('off')
    # Unregister the figure from pyplot so repeated slider events do not
    # accumulate open figures; the returned Figure can still be rendered.
    plt.close(fig)
    return fig
def disp_prediction(num, alpha=0.45):
    """Return a figure blending the predicted class map with image ``num``.

    Loads the ``UNET_perso`` weights on CPU and picks the validation sample
    whose name equals the ``num``-th path of
    ``glob.glob('./Data/val/Img/*.png')``. It runs one forward pass and
    blends the per-pixel argmax over the 4 output channels with the input
    image.

    Args:
        num: Index into the validation image listing (may be a slider value).
        alpha: Weight of the predicted class map in the blend; the input
            image gets ``1 - alpha``.

    Returns:
        A ``matplotlib.figure.Figure``. If drawing the blend raises
        ``ValueError``, the figure is returned without an image.

    Raises:
        IndexError: if ``num`` is outside the image listing.
        ValueError: if no validation sample is named like the selected file.
    """
    num = int(num)  # in case the slider delivers a float such as 3.0
    patient = glob.glob('./Data/val/Img/*.png')[num]

    # NOTE(review): weights and dataset are rebuilt on every call; caching
    # them once at module level would make the slider much more responsive.
    model = UNET(in_channels=1, out_channels=4).to('cpu')
    filepath = 'UNET_perso'
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()

    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    # Stop at the first match instead of loading the rest of the split.
    im = None
    for img, _, name in val_loader:
        if name[0] == patient:
            im = img
            break
    if im is None:
        raise ValueError(f'no validation sample named {patient!r}')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = model(im).numpy()
    # Take the spatial size from the tensors rather than hard-coding 256x256.
    pred = pred.reshape(4, *pred.shape[-2:])
    pred_mask = np.argmax(pred, axis=0)
    im = im.detach().numpy()
    im = im.reshape(im.shape[-2:])

    fig, ax = plt.subplots()
    try:
        ax.imshow((alpha*pred_mask + (1-alpha)*im)/2)
    except ValueError:
        # Best-effort display (presumably e.g. mismatched mask/image shapes):
        # fall through and return the empty figure.
        pass
    ax.axis('off')
    # Unregister from pyplot so figures do not accumulate across calls.
    plt.close(fig)
    return fig
def disp_label(num):
    """Return a figure showing the ground-truth mask for validation image ``num``.

    The image is selected as the ``num``-th path of
    ``glob.glob('./Data/val/Img/*.png')``, the same listing the other panels
    index. If ``./Data/val/GT`` holds a file with the same name, that file is
    shown. Otherwise the function falls back to indexing a separate GT
    listing. ``glob`` order is filesystem-dependent, so pairing two
    independent listings by index is not guaranteed to match.

    Args:
        num: Index into the validation image listing (may be a slider value).

    Returns:
        A ``matplotlib.figure.Figure``. If ``imshow`` raises ``ValueError``,
        the figure is returned without an image.

    Raises:
        IndexError: if ``num`` is outside the relevant listing.
    """
    import os  # used only to pair the GT file with the image by name

    num = int(num)  # in case the slider delivers a float such as 3.0
    img_path = glob.glob('./Data/val/Img/*.png')[num]
    gt_path = os.path.join('./Data/val/GT', os.path.basename(img_path))
    if not os.path.isfile(gt_path):
        # No same-named mask: keep the original index-based pairing.
        gt_path = glob.glob('./Data/val/GT/*.png')[num]
    label = iio.imread(gt_path)

    fig, ax = plt.subplots()
    try:
        ax.imshow(label)
    except ValueError:
        # Best-effort display: return the empty figure instead of failing.
        pass
    ax.axis('off')
    # Unregister from pyplot so figures do not accumulate across calls.
    plt.close(fig)
    return fig
def compute_metrics(num):
    """Return Dice and IoU of the prediction for validation image ``num``.

    Loads the ``UNET_perso`` weights on CPU and picks the validation sample
    whose name equals the ``num``-th path of
    ``glob.glob('./Data/val/Img/*.png')``. It one-hot encodes the argmax
    prediction and the target from ``getTargetSegmentation``, then scores
    classes 1-3 with ``dice_score`` and ``IOU``.

    Args:
        num: Index into the validation image listing (may be a slider value).

    Returns:
        A ``pandas.DataFrame`` with columns
        ``['metric', 'class 1', 'class 2', 'class 3']``. It has one row for
        ``'Dice Score'`` and one for ``'IoU'``, values rounded to 2 decimals.

    Raises:
        IndexError: if ``num`` is outside the image listing.
        ValueError: if no validation sample is named like the selected file.
    """
    num = int(num)  # in case the slider delivers a float such as 3.0
    patient = glob.glob('./Data/val/Img/*.png')[num]

    # NOTE(review): weights and dataset are rebuilt on every call; caching
    # them once at module level would make the slider much more responsive.
    model = UNET(in_channels=1, out_channels=4).to('cpu')
    filepath = 'UNET_perso'
    model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu')))
    model.eval()

    transform = transforms.Compose([transforms.ToTensor()])
    val_set = MedicalImageDataset('val',
                                  './Data',
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    # Stop at the first match instead of loading the rest of the split.
    im = lab = None
    for img, label, name in val_loader:
        if name[0] == patient:
            im, lab = img, label
            break
    if im is None:
        raise ValueError(f'no validation sample named {patient!r}')
    # Take the spatial size from the tensor rather than hard-coding 256x256.
    lab = lab.reshape(lab.shape[-2:])

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = model(im).numpy()
    pred = pred.reshape(4, *pred.shape[-2:])
    pred = torch.from_numpy(np.argmax(pred, axis=0))
    pred = torch.nn.functional.one_hot(pred, num_classes=4)
    pred = einops.rearrange(pred, 'h w class -> class h w')

    lab = getTargetSegmentation(lab)
    lab = torch.nn.functional.one_hot(lab, num_classes=4)
    lab = einops.rearrange(lab, 'h w class -> class h w')

    # Class 0 (presumably background) is not scored.
    classes = (1, 2, 3)
    dice_row = [round(dice_score(pred[c], lab[c]), 2) for c in classes]
    iou_row = [round(IOU(pred[c], lab[c]), 2) for c in classes]

    df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3'])
    df.loc[len(df)] = dice_row
    df.loc[len(df)] = iou_row
    df = df.assign(metric=['Dice Score', 'IoU'])
    return df[['metric', 'class 1', 'class 2', 'class 3']]
# Size the slider from the data on disk instead of hard-coding 89; keep the
# maximum non-negative even if the directory is empty or missing.
_N_VAL_IMAGES = len(glob.glob('./Data/val/Img/*.png'))

# UI: one slider selects a validation image. Releasing it refreshes the raw
# image, the prediction overlay, the ground truth and the metrics table.
with gr.Blocks() as demo:
    with gr.Row():
        slide = gr.Slider(minimum=0,
                          maximum=max(_N_VAL_IMAGES - 1, 0),
                          value=0,
                          step=1,
                          label='inference with validation split data (choose patient)')
    with gr.Row():
        slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
        slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
        slide.release(fn=disp_label, inputs=[slide], outputs=gr.Plot(label='groundtruth'))
    with gr.Row():
        slide.release(fn=compute_metrics, inputs=[slide], outputs=gr.DataFrame())
demo.queue()
demo.launch()