|
import os |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from torchvision import transforms |
|
|
|
from unet import UNet |
|
|
|
|
|
# Per-channel statistics handed to transforms.Normalize inside segment().
mean = (0.09189, 0.0833, 0.08749)
std = (0.13539, 0.1238, 0.12927)

# Build the network and switch it to inference mode once, at import time.
model = UNet(in_channels=3, out_channels=1)
model.eval()

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host. segment() feeds the model CPU tensors and calls .numpy() on
# its output, so the weights have to live on the CPU anyway.
# NOTE(review): torch.load unpickles the file; only ship a trusted "unet.pt".
state_dict = torch.load("unet.pt", map_location="cpu")
model.load_state_dict(state_dict)
|
|
|
def outline(image, mask, color):
    """Return a copy of ``image`` with the border pixels of ``mask`` painted.

    The mask is rounded first. A nonzero mask pixel counts as a border pixel
    when the mean of its 3x3 neighbourhood (clipped at the top/left edges by
    ``max(0, ...)`` and at the bottom/right edges by NumPy slicing) lies
    strictly between 0 and the largest value of the rounded mask.

    Args:
        image: Array to draw on; it is copied, the caller's array is untouched.
        mask: 2-D array marking the region to outline.
        color: Value assigned to every border pixel of the copy.

    Returns:
        The painted copy of ``image``.
    """
    canvas = image.copy()
    rounded = np.round(mask)
    peak = rounded.max()
    rows, cols = np.nonzero(rounded)

    for row, col in zip(rows, cols):
        top = max(0, row - 1)
        left = max(0, col - 1)
        window = rounded[top : row + 2, left : col + 2]
        window_mean = np.mean(window)

        # Interior pixels (window entirely at the peak value) are left alone.
        if not (0.0 < window_mean < peak):
            continue

        # row/col come from np.nonzero, so they are never negative and the
        # one-pixel slice starts exactly at the pixel itself.
        canvas[row : row + 1, col : col + 1] = color

    return canvas
|
|
|
def segment(input_image):
    """Run the U-Net on an image and outline the predicted region in red.

    Args:
        input_image: Image handed over by the Gradio input component, or
            ``None`` when the request was submitted without an image.
            NOTE(review): assumed to be an H x W x 3 NumPy array, matching
            the 3-value ``mean``/``std`` and ``in_channels=3`` (confirm).

    Returns:
        A copy of ``input_image`` with the border of the rounded prediction
        painted red, or ``None`` when no image was supplied.
    """
    # Gradio passes None for an empty image component; hand back an empty
    # result instead of failing inside the preprocessing pipeline.
    if input_image is None:
        return None

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Add the leading batch dimension the model call expects.
    input_batch = preprocess(input_image).unsqueeze(0)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(input_batch)

    # First sample, first (only) output channel, rounded to a hard mask.
    pred_mask = torch.round(output[0, 0]).numpy()

    red = (255, 0, 0)
    return outline(input_image, pred_mask, red)
|
|
|
if __name__ == "__main__":
    # The example images are looked up next to this script.
    script_dir = os.path.dirname(__file__)
    example_files = [
        os.path.join(script_dir, file_name)
        for file_name in (
            "TCGA_CS_6667_20011105_1.png",
            "TCGA_DU_5872_19950223_11.png",
        )
    ]

    description = (
        "Explore **U-Net** with batch normalization for abnormality segmentation in brain MRI.\n\n"
        "Input image must be a **3-channel brain MRI slice** from **pre-contrast**, **FLAIR**, and **post-contrast** sequences, respectively."
    )

    # Wire segment() between an upload/clipboard image input and an image output.
    webapp = gr.Interface(
        fn=segment,
        inputs=gr.Image(sources=["upload", "clipboard"], height=339, width=339),
        outputs=gr.Image(height=300, width=300),
        examples=example_files,
        allow_flagging="never",
        theme="gradio/monochrome",
        title="Brain MRI Segmentation Using U-Net",
        description=description,
    )
    webapp.launch()
|
|