Umong51 committed
Commit 732de68
1 Parent(s): 5815eea

Initial Commit

Files changed (7)
  1. README.md +1 -1
  2. app.py +69 -0
  3. examples/TCGA_CS_4941.png +0 -0
  4. examples/TCGA_CS_4944.png +0 -0
  5. requirements.txt +4 -0
  6. unet.pt +3 -0
  7. unet.py +98 -0
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Brain Mri Segmentation
-emoji: 📉
+emoji: 🧠
 colorFrom: blue
 colorTo: pink
 sdk: gradio
app.py ADDED
@@ -0,0 +1,70 @@
+import os
+
+import gradio as gr
+import numpy as np
+import torch
+from torchvision import transforms
+
+from unet import UNet
+
+# Dataset mean and std (per channel)
+mean = (0.09189, 0.0833, 0.08749)
+std = (0.13539, 0.1238, 0.12927)
+
+model = UNet(in_channels=3, out_channels=1)
+model.eval()
+
+# Load checkpoint (map to CPU so the Space runs without a GPU)
+state_dict = torch.load("unet.pt", map_location="cpu")
+model.load_state_dict(state_dict)
+
+def outline(image, mask, color):
+    image = image.copy()
+    mask = np.round(mask)
+    max_val = mask.max()
+    yy, xx = np.nonzero(mask)
+
+    # A mask pixel is on the boundary if its 3x3 neighborhood mixes values
+    for y, x in zip(yy, xx):
+        if 0.0 < np.mean(mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]) < max_val:
+            image[max(0, y) : y + 1, max(0, x) : x + 1] = color
+    return image
+
+def segment(input_image):
+    preprocess = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=mean, std=std),
+    ])
+
+    input_tensor = preprocess(input_image)
+    input_batch = input_tensor.unsqueeze(0)
+
+    with torch.no_grad():
+        output = model(input_batch)
+
+    pred_mask = torch.round(output[0, 0]).numpy()  # threshold sigmoid output at 0.5
+
+    red = (255, 0, 0)
+    output_image = outline(input_image, pred_mask, red)
+
+    return output_image
+
+if __name__ == "__main__":
+    inputs = gr.Image(sources=["upload", "clipboard"], height=339, width=339)
+    outputs = gr.Image(height=300, width=300)
+
+    webapp = gr.Interface(
+        fn=segment,
+        inputs=inputs,
+        outputs=outputs,
+        examples=[
+            os.path.join(os.path.dirname(__file__), "examples/TCGA_CS_4944.png"),
+            os.path.join(os.path.dirname(__file__), "examples/TCGA_CS_4941.png"),
+        ],
+        allow_flagging="never",
+        theme="gradio/monochrome",
+        title="Brain MRI Segmentation Using U-Net",
+        description=("Explore **U-Net** with batch normalization for abnormality segmentation in brain MRI.\n\n"
+                     "Input image must be a **3-channel brain MRI slice** from **pre-contrast**, **FLAIR**, and **post-contrast** sequences, respectively."),
+    )
+    webapp.launch()
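
The inference path in app.py is simply normalize, forward pass, threshold at 0.5, and outline. A minimal sketch of calling segment() directly, outside the Gradio UI (it assumes one of the bundled example slices; Pillow is available as a torchvision dependency):

# Minimal sketch: exercise the segment() pipeline without launching Gradio.
import numpy as np
from PIL import Image

from app import segment  # importing app also loads the unet.pt checkpoint

image = np.array(Image.open("examples/TCGA_CS_4944.png").convert("RGB"))
result = segment(image)  # original slice with the predicted mask outlined in red
Image.fromarray(result).save("segmented.png")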
examples/TCGA_CS_4941.png ADDED
examples/TCGA_CS_4944.png ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
+torch
+torchvision
+numpy
+gradio
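
These four unpinned packages mirror the imports in app.py and unet.py; locally, pip install -r requirements.txt followed by python app.py should be enough to serve the demo.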
unet.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:571a1e09cab5a895848ea63db76d0ab2d3e045cbb47709ef680584ee027ac2bd
+size 31108781
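
The checkpoint is tracked with Git LFS, so a plain clone fetches only this pointer file; git lfs pull retrieves the roughly 31 MB of actual weights.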
unet.py ADDED
@@ -0,0 +1,103 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+
+class UNet(nn.Module):
+
+    def __init__(self, in_channels=3, out_channels=1, init_features=32):
+        super(UNet, self).__init__()
+
+        features = init_features
+        # Contracting path: four conv blocks, each followed by 2x2 max pooling
+        self.encoder1 = UNet._block(in_channels, features, name="enc1")
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.encoder2 = UNet._block(features, features * 2, name="enc2")
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
+        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
+        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")
+
+        # Expanding path: transposed convolutions upsample, and encoder skip
+        # connections are concatenated before each decoder block
+        self.upconv4 = nn.ConvTranspose2d(
+            features * 16, features * 8, kernel_size=2, stride=2
+        )
+        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
+        self.upconv3 = nn.ConvTranspose2d(
+            features * 8, features * 4, kernel_size=2, stride=2
+        )
+        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
+        self.upconv2 = nn.ConvTranspose2d(
+            features * 4, features * 2, kernel_size=2, stride=2
+        )
+        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
+        self.upconv1 = nn.ConvTranspose2d(
+            features * 2, features, kernel_size=2, stride=2
+        )
+        self.decoder1 = UNet._block(features * 2, features, name="dec1")
+
+        self.conv = nn.Conv2d(
+            in_channels=features, out_channels=out_channels, kernel_size=1
+        )
+
+    def forward(self, x):
+        enc1 = self.encoder1(x)
+        enc2 = self.encoder2(self.pool1(enc1))
+        enc3 = self.encoder3(self.pool2(enc2))
+        enc4 = self.encoder4(self.pool3(enc3))
+
+        bottleneck = self.bottleneck(self.pool4(enc4))
+
+        dec4 = self.upconv4(bottleneck)
+        dec4 = torch.cat((dec4, enc4), dim=1)
+        dec4 = self.decoder4(dec4)
+        dec3 = self.upconv3(dec4)
+        dec3 = torch.cat((dec3, enc3), dim=1)
+        dec3 = self.decoder3(dec3)
+        dec2 = self.upconv2(dec3)
+        dec2 = torch.cat((dec2, enc2), dim=1)
+        dec2 = self.decoder2(dec2)
+        dec1 = self.upconv1(dec2)
+        dec1 = torch.cat((dec1, enc1), dim=1)
+        dec1 = self.decoder1(dec1)
+        return torch.sigmoid(self.conv(dec1))
+
+    @staticmethod
+    def _block(in_channels, features, name):
+        # Two (Conv2d -> BatchNorm2d -> ReLU) layers; conv bias is omitted
+        # because batch norm makes it redundant
+        return nn.Sequential(
+            OrderedDict(
+                [
+                    (
+                        name + "conv1",
+                        nn.Conv2d(
+                            in_channels=in_channels,
+                            out_channels=features,
+                            kernel_size=3,
+                            padding=1,
+                            bias=False,
+                        ),
+                    ),
+                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
+                    (name + "relu1", nn.ReLU(inplace=True)),
+                    (
+                        name + "conv2",
+                        nn.Conv2d(
+                            in_channels=features,
+                            out_channels=features,
+                            kernel_size=3,
+                            padding=1,
+                            bias=False,
+                        ),
+                    ),
+                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
+                    (name + "relu2", nn.ReLU(inplace=True)),
+                ]
+            )
+        )