Spaces:
Sleeping
Sleeping
import functools
import glob

import einops
import gradio as gr
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from torchvision import transforms

from UNET_perso import UNET
from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation, IOU, dice_score
def disp_init(num): | |
lst_patients = glob.glob('./Data/val/Img/*.png') | |
patient = lst_patients[num] | |
im = iio.imread(patient) | |
fig = plt.figure() | |
plt.imshow(im) | |
plt.axis('off') | |
return fig | |
def disp_prediction(num, alpha=0.45): | |
lst_patients = glob.glob('./Data/val/Img/*.png') | |
patient = lst_patients[num] | |
model = UNET(in_channels=1, out_channels=4).to('cpu') | |
filepath = 'UNET_perso' | |
model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu'))) | |
model.eval() | |
transform = transforms.Compose([transforms.ToTensor()]) | |
val_set = MedicalImageDataset('val', | |
'./Data', | |
transform=transform, | |
mask_transform=transform, | |
equalize=False) | |
val_loader = DataLoader(val_set, | |
batch_size=1, | |
shuffle=False) | |
for _, (img, _, name) in enumerate(val_loader): | |
if name[0] == patient: | |
im = img | |
pred = model(im) | |
pred = pred.detach().numpy() | |
pred = pred.reshape(4,256,256) | |
pred_mask = np.argmax(pred, axis=0) | |
im = im.detach().numpy() | |
im = im.reshape(256,256) | |
fig = plt.figure() | |
try: | |
plt.imshow((alpha*pred_mask + (1-alpha)*im)/2) | |
except ValueError: | |
pass | |
plt.axis('off') | |
return fig | |
def disp_label(num): | |
lst_GT = glob.glob('./Data/val/GT/*.png') | |
GT = lst_GT[num] | |
label = iio.imread(GT) | |
fig = plt.figure() | |
try: | |
plt.imshow(label) | |
except ValueError: | |
pass | |
plt.axis('off') | |
return fig | |
def compute_metrics(num): | |
lst_patients = glob.glob('./Data/val/Img/*.png') | |
patient = lst_patients[num] | |
model = UNET(in_channels=1, out_channels=4).to('cpu') | |
filepath = 'UNET_perso' | |
model.load_state_dict(torch.load(filepath, map_location=torch.device('cpu'))) | |
model.eval() | |
transform = transforms.Compose([transforms.ToTensor()]) | |
val_set = MedicalImageDataset('val', | |
'./Data', | |
transform=transform, | |
mask_transform=transform, | |
equalize=False) | |
val_loader = DataLoader(val_set, | |
batch_size=1, | |
shuffle=False) | |
for _, (img, label, name) in enumerate(val_loader): | |
if name[0] == patient: | |
im = img | |
lab = label.reshape(256,256) | |
pred = model(im) | |
pred = pred.detach().numpy() | |
pred = pred.reshape(4,256,256) | |
pred = np.argmax(pred, axis=0) | |
pred = torch.from_numpy(pred) | |
pred = torch.nn.functional.one_hot(pred, num_classes=4) | |
pred = einops.rearrange(pred, 'h w class -> class h w') | |
lab = getTargetSegmentation(lab) | |
lab = torch.nn.functional.one_hot(lab, num_classes=4) | |
lab = einops.rearrange(lab, 'h w class -> class h w') | |
dsc1 = dice_score(pred[1], lab[1]) | |
dsc2 = dice_score(pred[2], lab[2]) | |
dsc3 = dice_score(pred[3], lab[3]) | |
iou1 = IOU(pred[1], lab[1]) | |
iou2 = IOU(pred[2], lab[2]) | |
iou3 = IOU(pred[3], lab[3]) | |
df = pd.DataFrame(columns=['class 1', 'class 2', 'class 3']) | |
df.loc[len(df)] = [round(dsc1,2), round(dsc2,2), round(dsc3,2)] | |
df.loc[len(df)] = [round(iou1,2), round(iou2,2), round(iou3,2)] | |
df = df.assign(metric=['Dice Score', 'IoU']) | |
df = df[['metric','class 1','class 2','class 3']] | |
return df | |
# Slider upper bound = index of the last validation image (0 if none are
# found), derived from the data instead of a hard-coded 89 that breaks as
# soon as the validation split has a different size.
_n_val_images = len(glob.glob('./Data/val/Img/*.png'))

with gr.Blocks() as demo:
    with gr.Row() as row1:
        slide = gr.Slider(minimum=0,
                          maximum=max(_n_val_images - 1, 0),
                          value=0,
                          step=1,
                          label='inference with validation split data (choose patient)')
    # Each release of the slider refreshes the three plots and the metrics table.
    with gr.Row() as row2:
        slide.release(fn=disp_init, inputs=[slide], outputs=gr.Plot(label='initial image'))
        slide.release(fn=disp_prediction, inputs=[slide], outputs=gr.Plot(label='prediction'))
        slide.release(fn=disp_label, inputs=[slide], outputs=gr.Plot(label='groundtruth'))
    with gr.Row() as row3:
        slide.release(fn=compute_metrics, inputs=[slide], outputs=gr.DataFrame())

demo.queue()
demo.launch()