File size: 2,024 Bytes
732de68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os

import gradio as gr
import numpy as np
import torch
from torchvision import transforms

from unet import UNet

# Per-channel dataset mean and standard deviation, consumed by
# transforms.Normalize in segment().
mean = (0.09189, 0.0833, 0.08749)
std = (0.13539, 0.1238, 0.12927)

model = UNet(in_channels=3, out_channels=1)
model.eval()  # inference only

# Load the checkpoint from next to this script (the same way the Gradio
# examples are located) rather than from the current working directory.
# map_location="cpu": segment() never moves tensors to another device and
# calls .numpy() on the output, so inference is CPU-only; without this a
# checkpoint saved from GPU tensors cannot be deserialized on a CPU-only host.
_CHECKPOINT_PATH = os.path.join(os.path.dirname(__file__), "unet.pt")
state_dict = torch.load(_CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(state_dict)

def outline(image, mask, color):
    """Return a copy of *image* with the boundary pixels of *mask* set to *color*.

    The mask is rounded first. A pixel counts as boundary when it is non-zero
    in the rounded mask and the mean of its 3x3 neighbourhood -- clipped to the
    mask's extent, so edge windows have 6 cells and corner windows 4 -- lies
    strictly between 0 and the mask's maximum value. Because windows are
    clipped rather than zero-padded, a region touching the image border is not
    outlined along that border.

    Args:
        image: array indexed as ``image[y, x]`` (e.g. an H x W x 3 RGB array);
            its first two dimensions must be at least as large as *mask*.
        mask: 2-D array of per-pixel mask values.
        color: value written into each boundary pixel (e.g. an RGB tuple).

    Returns:
        A new array; *image* itself is not modified. A zero-size mask yields
        an unmodified copy.
    """
    image = image.copy()
    mask = np.round(mask)
    if mask.size == 0:
        # mask.max() would raise on an empty array; there is nothing to outline.
        return image
    max_val = mask.max()

    # Sum every pixel's 3x3 neighbourhood in one vectorized pass instead of a
    # Python loop per mask pixel. Zero padding contributes nothing to the sums,
    # and counting only real (unpadded) cells reproduces the edge-clipped
    # window, so window_sum / window_count equals np.mean of the clipped slice.
    height, width = mask.shape
    padded = np.pad(mask.astype(np.float64), 1)
    real_cells = np.pad(np.ones(mask.shape), 1)
    window_sum = np.zeros(mask.shape)
    window_count = np.zeros(mask.shape)
    for dy in range(3):
        for dx in range(3):
            window_sum += padded[dy : dy + height, dx : dx + width]
            window_count += real_cells[dy : dy + height, dx : dx + width]
    window_mean = window_sum / window_count

    # A mask pixel whose neighbourhood mixes in lower values lies on the edge.
    boundary = (mask != 0) & (window_mean > 0) & (window_mean < max_val)
    ys, xs = np.nonzero(boundary)
    image[ys, xs] = color
    return image

# Tensor conversion followed by per-channel normalization with the dataset
# statistics. Built once at import time instead of on every request.
_preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])


def segment(input_image):
    """Run the U-Net on *input_image* and outline the predicted region in red.

    Args:
        input_image: image delivered by the ``gr.Image`` input component, or
            ``None`` when no image was provided. Assumed to be an H x W x 3
            array accepted by ``transforms.ToTensor`` -- TODO confirm against
            the Gradio version in use.

    Returns:
        A copy of *input_image* with the boundary of the predicted mask
        painted red.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a TypeError from ToTensor.
    """
    if input_image is None:
        raise gr.Error("Please upload a 3-channel brain MRI slice first.")

    input_batch = _preprocess(input_image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        output = model(input_batch)

    # Takes batch 0, channel 0 of the output. Assumes the model already emits
    # probabilities (e.g. a sigmoid inside UNet) -- TODO confirm; rounding then
    # yields a 0/1 mask.
    pred_mask = torch.round(output[0, 0]).numpy()

    red = (255, 0, 0)
    return outline(input_image, pred_mask, red)

if __name__ == "__main__":
    # Example images live in an "examples" directory next to this script.
    examples_dir = os.path.join(os.path.dirname(__file__), "examples")
    example_paths = [
        os.path.join(examples_dir, filename)
        for filename in ("TCGA_CS_4944.png", "TCGA_CS_4941.png")
    ]

    inputs = gr.Image(sources=["upload", "clipboard"], height=339, width=339)
    outputs = gr.Image(height=300, width=300)

    # Use the public gr.Interface export rather than the gradio.interface
    # submodule path.
    webapp = gr.Interface(
        fn=segment,
        inputs=inputs,
        outputs=outputs,
        examples=example_paths,
        # NOTE(review): newer Gradio releases rename this option to
        # flagging_mode -- confirm against the pinned gradio version.
        allow_flagging="never",
        theme="gradio/monochrome",
        title="Brain MRI Segmentation Using U-Net",
        description=("Explore **U-Net** with batch normalization for abnormality segmentation in brain MRI.\n\n"
                     "Input image must be a **3-channel brain MRI slice** from **pre-contrast**, **FLAIR**, and **post-contrast** sequences, respectively."),
    )
    webapp.launch()