# test_unet / app.py
# Author: wong132 — last commit 65c8302 "Update app.py" (verified)
# File scan: no virus · size: 2.03 kB
import os
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from unet import UNet
# Dataset mean and standard deviation, one value per input channel; passed to
# transforms.Normalize when preprocessing each image.
mean = (0.09189, 0.0833, 0.08749)
std = (0.13539, 0.1238, 0.12927)

# 3 input channels, 1 output channel (the segmentation mask).
model = UNet(in_channels=3, out_channels=1)
# Inference mode: batch-norm layers use their stored statistics instead of batch stats.
model.eval()

# Load checkpoint.
# - The path is resolved next to this file (like the bundled example images) so the app
#   does not depend on the current working directory.
# - map_location="cpu" lets a checkpoint saved from a GPU session load on a CPU-only host;
#   the model itself lives on the CPU here anyway.
# NOTE: torch.load unpickles the file — only point it at a trusted checkpoint.
checkpoint_path = os.path.join(os.path.dirname(__file__), "unet.pt")
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict)
def outline(image, mask, color):
    """Draw the boundary of a segmentation mask onto a copy of an image.

    The mask is rounded first. A pixel counts as a boundary pixel when its mask value
    is non-zero and the mean of its 3x3 neighbourhood (clipped at the array border)
    lies strictly between 0 and the mask's maximum value. For a 0/1 mask, these are
    the foreground pixels that touch at least one background pixel.

    Args:
        image: Array indexed as ``image[row, col]``; boundary pixels are set to ``color``.
        mask: 2-D array of per-pixel mask values (rounded before use).
        color: Value assigned to each boundary pixel, e.g. an RGB tuple.

    Returns:
        A copy of ``image`` with the boundary painted. ``image`` itself is not modified.
        If ``mask`` is empty, the copy is returned unchanged.
    """
    image = image.copy()
    mask = np.round(mask)
    if mask.size == 0:
        # Nothing to outline (and mask.max() would raise on an empty array).
        return image
    max_val = mask.max()

    rows, cols = mask.shape
    # Zero-pad the mask for the 3x3 window sums, and count in-bounds cells separately,
    # so border windows are averaged only over cells that exist — exactly what a
    # clipped slice such as mask[max(0, y - 1):y + 2, ...] would average over.
    padded = np.pad(mask, 1)
    in_bounds = np.pad(np.ones(mask.shape, dtype=np.int64), 1)
    offsets = [(dy, dx) for dy in range(3) for dx in range(3)]
    window_sum = sum(padded[dy:dy + rows, dx:dx + cols] for dy, dx in offsets)
    window_count = sum(in_bounds[dy:dy + rows, dx:dx + cols] for dy, dx in offsets)
    window_mean = window_sum / window_count

    edge = (mask != 0) & (window_mean > 0.0) & (window_mean < max_val)

    # Mask pixels that fall outside the image are ignored rather than raising,
    # matching the previous slice-based assignment.
    h = min(rows, image.shape[0])
    w = min(cols, image.shape[1])
    image[:h, :w][edge[:h, :w]] = color
    return image
def segment(input_image):
    """Predict a segmentation mask for one image and return the image with its outline.

    Args:
        input_image: The value delivered by the Gradio image input. It is presumably a
            NumPy array (gr.Image's default type) — confirm if the input type changes.
            It is ``None`` when the user submits without an image.

    Returns:
        A copy of ``input_image`` with the boundary of the predicted mask drawn in red.

    Raises:
        gr.Error: If no image was provided. Gradio then shows a readable message
            instead of a traceback from the preprocessing transform.
    """
    if input_image is None:
        raise gr.Error("Please upload or paste an image first.")

    preprocess = transforms.Compose([
        # ToTensor: HxWxC array -> CxHxW float tensor; Normalize uses the dataset stats.
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    input_batch = preprocess(input_image).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        output = model(input_batch)

    # Single batch item, single output channel, rounded to a hard mask.
    pred_mask = torch.round(output[0, 0]).numpy()
    red = (255, 0, 0)
    return outline(input_image, pred_mask, red)
if __name__ == "__main__":
    # Example MRI slices shipped alongside this script.
    app_dir = os.path.dirname(__file__)
    example_files = [
        os.path.join(app_dir, filename)
        for filename in ("TCGA_CS_6667_20011105_1.png", "TCGA_DU_5872_19950223_11.png")
    ]

    # Markdown shown under the title in the UI.
    description = ("Explore **U-Net** with batch normalization for abnormality segmentation in brain MRI.\n\n"
                   "Input image must be a **3-channel brain MRI slice** from **pre-contrast**, **FLAIR**, and **post-contrast** sequences, respectively.")

    webapp = gr.Interface(
        fn=segment,
        inputs=gr.Image(sources=["upload", "clipboard"], height=339, width=339),
        outputs=gr.Image(height=300, width=300),
        examples=example_files,
        allow_flagging="never",
        theme="gradio/monochrome",
        title="Brain MRI Segmentation Using U-Net",
        description=description,
    )
    webapp.launch()