"""Gradio demo: abnormality segmentation in brain MRI slices with a U-Net."""

import os

import gradio as gr
import numpy as np
import torch
from torchvision import transforms

from unet import UNet

# Dataset Mean and STD
mean = (0.09189, 0.0833, 0.08749)
std = (0.13539, 0.1238, 0.12927)

model = UNet(in_channels=3, out_channels=1)

# Load Checkpoint
# map_location="cpu": inference below runs on the CPU (the output is converted
# with .numpy() directly), so a checkpoint saved from a GPU must be remapped or
# loading fails on CPU-only hosts.
state_dict = torch.load("unet.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Built once at import time instead of on every request.
_preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)


def outline(image: np.ndarray, mask: np.ndarray, color) -> np.ndarray:
    """Return a copy of *image* with the border of *mask* painted in *color*.

    The mask is rounded first. A pixel is painted when it is nonzero in the
    rounded mask and the mean of its 3x3 neighbourhood (clipped at the array
    edges) lies strictly between 0 and the mask's maximum value, i.e. the
    neighbourhood is not uniformly at the maximum.

    Args:
        image: Array indexed as ``image[y, x]``; it is not modified.
        mask: 2-D array indexed with the same ``(y, x)`` coordinates as
            *image*.
        color: Value assigned to each border pixel (e.g. an RGB tuple).

    Returns:
        The annotated copy of *image*.
    """
    image = image.copy()
    mask = np.round(mask)
    max_val = mask.max()
    for y, x in zip(*np.nonzero(mask)):
        # Slice upper bounds are clipped by NumPy; lower bounds must be clamped
        # by hand because a negative start would wrap around.
        neighbourhood = mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]
        if 0.0 < np.mean(neighbourhood) < max_val:
            image[y, x] = color
    return image


def segment(input_image):
    """Run the U-Net on *input_image* and return it with the mask outlined.

    Args:
        input_image: Image handed over by the Gradio ``Image`` input
            (presumably an H x W x 3 uint8 array, Gradio's default; verify if
            the component's ``type`` is ever changed).

    Returns:
        A copy of *input_image* with the predicted region outlined in red.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Gradio passes None when the user submits without an image; show a
        # readable message instead of an opaque torchvision TypeError.
        raise gr.Error("Please provide an input image.")

    input_batch = _preprocess(input_image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        output = model(input_batch)

    # First (only) batch element, first (only) output channel.
    # NOTE(review): rounding assumes the model emits values in [0, 1]
    # (e.g. a sigmoid output) - confirm against unet.UNet.
    pred_mask = torch.round(output[0, 0]).numpy()

    red = (255, 0, 0)
    return outline(input_image, pred_mask, red)


if __name__ == "__main__":
    examples_dir = os.path.dirname(__file__)

    inputs = gr.Image(sources=["upload", "clipboard"], height=339, width=339)
    outputs = gr.Image(height=300, width=300)

    webapp = gr.Interface(
        fn=segment,
        inputs=inputs,
        outputs=outputs,
        examples=[
            os.path.join(examples_dir, "examples/TCGA_CS_4944.png"),
            os.path.join(examples_dir, "examples/TCGA_CS_4941.png"),
        ],
        allow_flagging="never",
        theme="gradio/monochrome",
        title="Brain MRI Segmentation Using U-Net",
        description=(
            "Explore **U-Net** with batch normalization for abnormality segmentation in brain MRI.\n\n"
            "Input image must be a **3-channel brain MRI slice** from **pre-contrast**, **FLAIR**, and **post-contrast** sequences, respectively."
        ),
    )
    webapp.launch()