File size: 2,030 Bytes
b3928af ac739cd 65c8302 b3928af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import os
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from unet import UNet
# Per-channel dataset mean and standard deviation used to normalise inputs
# (three values each, matching the model's three input channels).
# NOTE(review): these look like statistics on the [0, 1] scale produced by
# ToTensor — confirm against the training pipeline.
mean = (0.09189, 0.0833, 0.08749)
std = (0.13539, 0.1238, 0.12927)

model = UNet(in_channels=3, out_channels=1)
model.eval()

# Load the checkpoint from the script's own directory instead of the current
# working directory, so the app works wherever it is launched from. Map every
# tensor to the CPU: inference never moves the model to another device, and a
# checkpoint saved from GPU tensors would otherwise fail to load on a
# CPU-only host.
state_dict = torch.load(
    os.path.join(os.path.dirname(__file__), "unet.pt"),
    map_location="cpu",
)
model.load_state_dict(state_dict)
def outline(image, mask, color):
    """Return a copy of ``image`` with the boundary of ``mask`` painted in ``color``.

    The mask is rounded to the nearest integer first. A pixel is treated as a
    boundary pixel when its rounded mask value is non-zero and the mean of its
    3x3 neighbourhood (clipped at the array edges) lies strictly between 0 and
    the mask's maximum value, i.e. the neighbourhood mixes foreground with
    background.

    Args:
        image: Array indexed as ``image[y, x]`` with the same height and width
            as ``mask``. The input array is not modified.
        mask: 2-D array of mask values.
        color: Value written to every boundary pixel (e.g. an RGB tuple).

    Returns:
        The outlined copy of ``image``.
    """
    image = image.copy()
    mask = np.round(mask)
    if mask.size == 0:
        # Nothing to outline, and mask.max() would raise on an empty array.
        return image
    max_val = mask.max()
    height, width = mask.shape

    # Sum each pixel's 3x3 neighbourhood and count how many of those cells lie
    # inside the array. Zero padding contributes nothing to either total, so
    # window_sum / window_count is the mean over the clipped neighbourhood.
    padded_mask = np.pad(mask, 1).astype(np.float64)
    padded_ones = np.pad(np.ones((height, width)), 1)
    window_sum = np.zeros((height, width))
    window_count = np.zeros((height, width))
    for dy in range(3):
        for dx in range(3):
            window_sum += padded_mask[dy : dy + height, dx : dx + width]
            window_count += padded_ones[dy : dy + height, dx : dx + width]

    # Test 0 < sum / count < max_val without dividing. The rounded mask holds
    # integers, so these sums and products are exact for realistic values.
    boundary = (
        (mask != 0)
        & (window_sum > 0)
        & (window_sum < max_val * window_count)
    )
    image[boundary] = color
    return image
def segment(input_image):
    """Run the U-Net on one image and return it with the predicted region outlined.

    Args:
        input_image: Image as delivered by the Gradio ``Image`` input, or
            ``None`` when the user submits without providing an image. It is
            normalised with the module-level ``mean``/``std`` and fed to
            ``model`` as a batch of one.

    Returns:
        A copy of ``input_image`` with the boundary of the predicted mask drawn
        in red, or ``None`` if no image was provided.
    """
    # Gradio passes None when the input component is empty; without this guard
    # ToTensor raises a TypeError that surfaces as a generic app error.
    if input_image is None:
        return None

    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    # Add a leading batch dimension of size 1 for the model.
    input_batch = preprocess(input_image).unsqueeze(0)

    with torch.no_grad():
        output = model(input_batch)

    # First batch item, first output channel, rounded to the nearest integer.
    # NOTE(review): this is a 0/1 mask only if the model's output lies in
    # [0, 1] (e.g. a final sigmoid) — confirm in the UNet definition.
    pred_mask = torch.round(output[0, 0]).numpy()
    red = (255, 0, 0)
    return outline(input_image, pred_mask, red)
if __name__ == "__main__":
    # The bundled example slices live alongside this script.
    app_dir = os.path.dirname(__file__)
    example_paths = [
        os.path.join(app_dir, filename)
        for filename in (
            "TCGA_DU_5855_19951217_17.png",
            "TCGA_DU_5872_19950223_11.png",
        )
    ]

    # One image in (upload or clipboard), one outlined image out.
    webapp = gr.Interface(
        fn=segment,
        inputs=gr.Image(sources=["upload", "clipboard"], height=339, width=339),
        outputs=gr.Image(height=300, width=300),
        examples=example_paths,
        allow_flagging="never",
        theme="gradio/monochrome",
        title="Brain MRI Segmentation Using U-Net",
        description=(
            "Explore **U-Net** with batch normalization for abnormality segmentation in brain MRI.\n\n"
            "Input image must be a **3-channel brain MRI slice** from **pre-contrast**, **FLAIR**, and **post-contrast** sequences, respectively."
        ),
    )
    webapp.launch()
|