Umong51
Initial Commit
732de68
import os
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from unet import UNet
# Per-channel mean and standard deviation of the dataset, used by
# transforms.Normalize when preprocessing each input image in segment().
mean = (0.09189, 0.0833, 0.08749)
std = (0.13539, 0.1238, 0.12927)

# Single-output-channel U-Net: one mask channel for a 3-channel input slice.
model = UNet(in_channels=3, out_channels=1)
model.eval()  # inference only: switches layers such as batch norm to eval behaviour

# Load checkpoint.
# The path is resolved next to this script (as the example images are) so the app
# works regardless of the current working directory. Weights are mapped to CPU
# because the model is never moved to a GPU in this file, so a checkpoint saved
# from a GPU run would otherwise fail to load on a CPU-only host.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint_path = os.path.join(os.path.dirname(__file__), "unet.pt")
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict)
def outline(image, mask, color):
    """Return a copy of ``image`` with the border of ``mask``'s foreground painted.

    ``mask`` is rounded to the nearest integer first. A non-zero mask pixel is
    painted when the mean of its 3x3 neighbourhood (clipped at the array edges)
    lies strictly between 0 and the rounded mask's maximum value. For a 0/1 mask
    this selects foreground pixels that touch at least one background pixel;
    foreground pixels on the array edge whose in-bounds neighbours are all
    foreground are therefore not painted.

    Args:
        image: Array indexed as ``image[row, col]``; it is copied, never modified.
        mask: 2-D array; presumably the same height and width as ``image``.
        color: Value written to each border pixel, e.g. an RGB tuple.

    Returns:
        The painted copy of ``image``.
    """
    painted = image.copy()
    rounded = np.round(mask)
    peak = rounded.max()
    for row, col in np.argwhere(rounded):
        # Clamp the top/left start at 0; NumPy slicing already clamps the
        # bottom/right end, so edge pixels only see the in-bounds window.
        window = rounded[max(0, row - 1) : row + 2, max(0, col - 1) : col + 2]
        if 0.0 < window.mean() < peak:
            painted[row, col] = color
    return painted
def segment(input_image):
    """Segment an MRI slice and return it with the predicted region outlined in red.

    Args:
        input_image: Image supplied by the Gradio ``Image`` component; presumably
            an HxWx3 uint8 numpy array (Gradio's default ``type="numpy"``) --
            confirm if the component configuration changes.

    Returns:
        A copy of ``input_image`` with the boundary of the predicted mask painted
        red by ``outline``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Gradio passes None when Submit is pressed without an image; without this
        # guard ToTensor raises an opaque TypeError shown as a generic error.
        raise gr.Error("Please upload an image.")

    # ToTensor converts the HxWxC array to a CxHxW float tensor, then Normalize
    # standardizes each channel with the module-level dataset mean/std.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    input_batch = preprocess(input_image).unsqueeze(0)  # add batch dim: 1xCxHxW

    with torch.no_grad():
        output = model(input_batch)

    # First sample, first (only, out_channels=1) output channel, rounded to the
    # nearest integer -- presumably the model output is (1, 1, H, W); confirm.
    pred_mask = torch.round(output[0, 0]).numpy()
    red = (255, 0, 0)
    return outline(input_image, pred_mask, red)
if __name__ == "__main__":
    # Example images live in an ``examples/`` folder next to this script; resolving
    # them from the script's directory keeps them working from any working directory.
    examples_dir = os.path.join(os.path.dirname(__file__), "examples")

    inputs = gr.Image(sources=["upload", "clipboard"], height=339, width=339)
    outputs = gr.Image(height=300, width=300)

    # gr.Interface is Gradio's documented public entry point (rather than the
    # internal module path gr.interface.Interface).
    webapp = gr.Interface(
        fn=segment,
        inputs=inputs,
        outputs=outputs,
        examples=[
            os.path.join(examples_dir, "TCGA_CS_4944.png"),
            os.path.join(examples_dir, "TCGA_CS_4941.png"),
        ],
        allow_flagging="never",
        theme="gradio/monochrome",
        title="Brain MRI Segmentation Using U-Net",
        description=("Explore **U-Net** with batch normalization for abnormality segmentation in brain MRI.\n\n"
                     "Input image must be a **3-channel brain MRI slice** from **pre-contrast**, **FLAIR**, and **post-contrast** sequences, respectively."),
    )
    webapp.launch()