Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import numpy as np | |
import torch | |
from torchvision import transforms | |
from unet import UNet | |
# Per-channel normalization statistics ("Dataset Mean and STD"), presumably
# computed on the training data; applied to every input in ``segment``.
mean = (0.09189, 0.0833, 0.08749)
std = (0.13539, 0.1238, 0.12927)

# 3 input channels (one per MRI sequence, per the app description) and a
# single-channel mask output.
model = UNet(in_channels=3, out_channels=1)
# Inference mode: layers such as batch normalization use their stored
# statistics instead of per-batch ones.
model.eval()

# Load checkpoint. The path is resolved relative to this file rather than the
# current working directory, so the app starts no matter where it is launched
# from. map_location="cpu" lets a checkpoint saved from GPU tensors load on
# CPU-only hosts. weights_only=True refuses arbitrary pickled objects; this is
# fine because the file is expected to hold a plain state_dict.
state_dict = torch.load(
    os.path.join(os.path.dirname(__file__), "unet.pt"),
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)
def outline(image, mask, color):
    """Return a copy of ``image`` with the boundary of ``mask`` painted in ``color``.

    The mask is rounded first. A pixel counts as boundary when both of these
    hold:

    - it is non-zero in the rounded mask;
    - the mean of its 3x3 neighbourhood lies strictly between 0 and the
      mask's maximum.

    Neighbourhoods are clipped at the image border, and the mean is taken
    over the in-bounds cells only. For a binary 0/1 mask, the boundary is
    therefore every foreground pixel with at least one background neighbour
    inside the image. Foreground touching the image edge is not outlined
    along that edge.

    Args:
        image: Array indexed as ``image[y, x]``. Presumably H x W x C with a
            channel count matching ``color``; TODO confirm with callers.
        mask: 2-D array, assumed to have the same height and width as
            ``image``.
        color: Value written to every boundary pixel, e.g. ``(255, 0, 0)``.

    Returns:
        A new array. ``image`` itself is not modified.

    Raises:
        ValueError: If ``mask`` is empty (raised by ``mask.max()``).
    """
    result = image.copy()
    mask = np.round(mask)
    max_val = mask.max()
    height, width = mask.shape

    # Vectorized 3x3 box mean, computed for all pixels at once instead of
    # calling np.mean once per foreground pixel. Zero-padding keeps every
    # window in bounds. A padded array of ones counts the in-bounds cells, so
    # border pixels are averaged only over their in-bounds neighbours.
    padded = np.pad(mask, 1)
    in_bounds = np.pad(np.ones_like(mask), 1)
    window_sum = np.zeros_like(mask)
    window_count = np.zeros_like(mask)
    for dy in range(3):
        for dx in range(3):
            window_sum += padded[dy:dy + height, dx:dx + width]
            window_count += in_bounds[dy:dy + height, dx:dx + width]
    window_mean = window_sum / window_count

    boundary = (mask != 0) & (window_mean > 0) & (window_mean < max_val)
    result[boundary] = color
    return result
def segment(input_image):
    """Predict an abnormality mask for an MRI slice and outline it in red.

    Args:
        input_image: Image from the Gradio ``gr.Image`` input. It is presumably
            an H x W x 3 ``uint8`` array with one channel per MRI sequence;
            TODO confirm against the input component's settings. Gradio passes
            ``None`` when the user submits without providing an image.

    Returns:
        A copy of ``input_image`` with the predicted mask boundary drawn in
        red, or ``None`` when no image was provided.
    """
    # Bail out on an empty submission instead of crashing inside ToTensor.
    if input_image is None:
        return None

    preprocess = transforms.Compose([
        transforms.ToTensor(),  # HWC uint8 [0, 255] -> CHW float [0, 1]
        transforms.Normalize(mean=mean, std=std),
    ])
    input_batch = preprocess(input_image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        output = model(input_batch)
    # Take the first sample's single output channel. Rounding binarizes it,
    # assuming the model emits probabilities in [0, 1].
    pred_mask = torch.round(output[0, 0]).numpy()

    red = (255, 0, 0)
    return outline(input_image, pred_mask, red)
if __name__ == "__main__":
    # Accept uploads or clipboard pastes only (no webcam).
    inputs = gr.Image(sources=["upload", "clipboard"], height=339, width=339)
    outputs = gr.Image(height=300, width=300)

    # Example images ship next to this file, so resolve them relative to it.
    examples_dir = os.path.join(os.path.dirname(__file__), "examples")

    # gr.Interface is the public entry point; gr.interface.Interface only
    # works because the submodule happens to be importable as an attribute.
    webapp = gr.Interface(
        fn=segment,
        inputs=inputs,
        outputs=outputs,
        examples=[
            os.path.join(examples_dir, "TCGA_CS_4944.png"),
            os.path.join(examples_dir, "TCGA_CS_4941.png"),
        ],
        allow_flagging="never",
        theme="gradio/monochrome",
        title="Brain MRI Segmentation Using U-Net",
        description=(
            "Explore **U-Net** with batch normalization for abnormality segmentation in brain MRI.\n\n"
            "Input image must be a **3-channel brain MRI slice** from **pre-contrast**, **FLAIR**, and **post-contrast** sequences, respectively."
        ),
    )
    webapp.launch()