hasibzunair committed on
Commit 8a3583d • 1 Parent(s): 3a4add5

add app files

cmap.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5648506a4b5dbeb787e93f26b429cab659c3b66a4d579645edb2f24ba41a919
+ size 848
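
cmap.npy is tracked with Git LFS, so only the pointer above lives in the repository; the actual file is an 848-byte NumPy array used as a color palette by the app. A minimal sketch of how it is consumed, assuming it holds one RGB triple per class id (the shape and dtype are assumptions; only the byte size is recorded here):

import numpy as np

# Load the palette; assumed to be indexable by integer class id,
# one RGB triple per class.
cmap = np.load("cmap.npy")

# Fancy indexing maps an (H, W) array of class ids to an (H, W, 3) color image.
label_map = np.zeros((224, 224), dtype=np.uint8)  # placeholder prediction
rgb = cmap[label_map]
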
description.html ADDED
@@ -0,0 +1,10 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+ <meta charset="UTF-8">
+ <title>Masked Supervised Learning for Semantic Segmentation</title>
+ </head>
+ <body>
+ This is a demo of our BMVC'2022 Oral paper <a href="https://arxiv.org/abs/2210.00923">Masked Supervised Learning for Semantic Segmentation</a>.<br>
+ </body>
+ </html>
nyu.ipynb ADDED
@@ -0,0 +1,165 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "import numpy as np\n",
+     "import cv2\n",
+     "import codecs\n",
+     "import torch\n",
+     "import torchvision.transforms as transforms\n",
+     "import gradio as gr\n",
+     "\n",
+     "from PIL import Image\n",
+     "\n",
+     "from unetplusplus import NestedUNet\n",
+     "\n",
+     "torch.manual_seed(0)\n",
+     "\n",
+     "if torch.cuda.is_available():\n",
+     "    torch.backends.cudnn.deterministic = True\n",
+     "\n",
+     "# Device\n",
+     "DEVICE = \"cpu\"\n",
+     "print(DEVICE)\n",
+     "\n",
+     "# Load color map\n",
+     "cmap = np.load('cmap.npy')\n",
+     "\n",
+     "# Make directories\n",
+     "os.system(\"mkdir ./models\")\n",
+     "\n",
+     "# Get model weights\n",
+     "if not os.path.exists(\"./models/masksupnyu39.31d.pth\"):\n",
+     "    os.system(\"wget -O ./models/masksupnyu39.31d.pth https://github.com/hasibzunair/masksup-segmentation/releases/download/v0.1/masksupnyu39.31iou.pth\")\n",
+     "\n",
+     "# Load model\n",
+     "model = NestedUNet(num_classes=40)\n",
+     "checkpoint = torch.load(\"./models/masksupnyu39.31d.pth\", map_location=DEVICE)\n",
+     "model.load_state_dict(checkpoint)\n",
+     "model = model.to(DEVICE)\n",
+     "model.eval()\n",
+     "\n",
+     "\n",
+     "# Main inference function\n",
+     "def inference(img_path):\n",
+     "    image = Image.open(img_path).convert(\"RGB\")\n",
+     "    transforms_image = transforms.Compose(\n",
+     "        [\n",
+     "            transforms.Resize((224, 224)),\n",
+     "            transforms.CenterCrop((224, 224)),\n",
+     "            transforms.ToTensor(),\n",
+     "            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
+     "        ]\n",
+     "    )\n",
+     "\n",
+     "    image = transforms_image(image)\n",
+     "    image = image[None, :]\n",
+     "    # Predict\n",
+     "    with torch.no_grad():\n",
+     "        output = torch.sigmoid(model(image.to(DEVICE).float()))\n",
+     "    output = torch.softmax(output, dim=1).argmax(dim=1)[0].float().cpu().numpy().astype(np.uint8)\n",
+     "    pred = cmap[output]\n",
+     "    return pred\n",
+     "\n",
+     "# App\n",
+     "title = \"Masked Supervised Learning for Semantic Segmentation\"\n",
+     "description = codecs.open(\"description.html\", \"r\", \"utf-8\").read()\n",
+     "article = \"<p style='text-align: center'><a href='https://arxiv.org/abs/2210.00923' target='_blank'>Masked Supervised Learning for Semantic Segmentation</a> | <a href='https://github.com/hasibzunair/masksup-segmentation' target='_blank'>Github</a></p>\"\n",
+     "\n",
+     "gr.Interface(\n",
+     "    inference,\n",
+     "    gr.inputs.Image(type='file', label=\"Input Image\"),\n",
+     "    gr.outputs.Image(type=\"file\", label=\"Predicted Output\"),\n",
+     "    examples=[\"./sample_images/a.png\", \"./sample_images/b.png\",\n",
+     "              \"./sample_images/c.png\", \"./sample_images/d.png\"],\n",
+     "    title=title,\n",
+     "    description=description,\n",
+     "    article=article,\n",
+     "    allow_flagging=False,\n",
+     "    analytics_enabled=False,\n",
+     ").launch(debug=True, enable_queue=True)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3.8.12 ('fifa')",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.8.12"
+   },
+   "orig_nbformat": 4,
+   "vscode": {
+    "interpreter": {
+     "hash": "5a4cff4f724f20f3784f32e905011239b516be3fadafd59414871df18d0dad63"
+    }
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
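
For reference, the same inference path can be exercised outside Gradio. A minimal sketch under the commit's own file layout (paths mirror the notebook; the map_location argument and the assumption that cmap holds uint8 RGB values are mine, not part of the commit):

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from unetplusplus import NestedUNet

cmap = np.load("cmap.npy")

model = NestedUNet(num_classes=40)
model.load_state_dict(torch.load("./models/masksupnyu39.31d.pth", map_location="cpu"))
model.eval()

# Same preprocessing as the notebook: resize, center-crop, normalize to [-1, 1].
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

image = tf(Image.open("./sample_images/a.png").convert("RGB"))[None]
with torch.no_grad():
    logits = model(image)
pred = logits.argmax(dim=1)[0].numpy().astype(np.uint8)  # (224, 224) class ids
Image.fromarray(cmap[pred]).save("pred.png")             # colorize with the palette

The notebook applies sigmoid and then softmax before the argmax; since sigmoid is monotonic and softmax preserves the argmax, skipping both leaves the predicted class unchanged, so the sketch omits them.
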
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ scipy==1.4.1
+ torch
+ h5py==2.10.0
+ numpy==1.18.1
+ opencv-python-headless==4.2.0.32
+ Pillow
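
Note that gradio itself is not pinned here even though the notebook imports it; on Hugging Face Spaces the SDK and its version are declared in the Space's README metadata, so requirements.txt presumably only lists the extra inference dependencies.
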
sample_images/a.png ADDED
sample_images/b.png ADDED
sample_images/c.png ADDED
sample_images/d.png ADDED
unetplusplus.py ADDED
@@ -0,0 +1,142 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ __all__ = ['UNet', 'NestedUNet']
+
+ """Taken from https://github.com/4uiiurz1/pytorch-nested-unet"""
+
+ class VGGBlock(nn.Module):
+     def __init__(self, in_channels, middle_channels, out_channels):
+         super().__init__()
+         self.relu = nn.ReLU(inplace=True)
+         self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
+         self.bn1 = nn.BatchNorm2d(middle_channels)
+         self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
+         self.bn2 = nn.BatchNorm2d(out_channels)
+
+     def forward(self, x):
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         return out
+
+
+ class UNet(nn.Module):
+     def __init__(self, num_classes, input_channels=3, **kwargs):
+         super().__init__()
+
+         nb_filter = [32, 64, 128, 256, 512]
+
+         self.pool = nn.MaxPool2d(2, 2)
+         self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+         self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
+         self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
+         self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
+         self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
+         self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
+
+         self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
+         self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
+         self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
+         self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
+
+         self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
+
+
+     def forward(self, input):
+         x0_0 = self.conv0_0(input)
+         x1_0 = self.conv1_0(self.pool(x0_0))
+         x2_0 = self.conv2_0(self.pool(x1_0))
+         x3_0 = self.conv3_0(self.pool(x2_0))
+         x4_0 = self.conv4_0(self.pool(x3_0))
+
+         x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
+         x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
+         x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
+         x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))
+
+         output = self.final(x0_4)
+         return output
+
+
+ class NestedUNet(nn.Module):
+     """
+     U-Net++ architecture
+     Reference: https://arxiv.org/abs/1807.10165
+     """
+     def __init__(self, num_classes=1, input_channels=3, deep_supervision=False, **kwargs):
+         super().__init__()
+
+         nb_filter = [32, 64, 128, 256, 512]
+
+         self.deep_supervision = deep_supervision
+
+         self.pool = nn.MaxPool2d(2, 2)
+         self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+         self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
+         self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
+         self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
+         self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
+         self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
+
+         self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
+         self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
+         self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
+         self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
+
+         self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
+         self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
+         self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
+
+         self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
+         self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
+
+         self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
+
+         if self.deep_supervision:
+             self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
+             self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
+             self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
+             self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
+         else:
+             self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
+
+
+     def forward(self, input):
+         x0_0 = self.conv0_0(input)
+         x1_0 = self.conv1_0(self.pool(x0_0))
+         x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
+
+         x2_0 = self.conv2_0(self.pool(x1_0))
+         x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
+         x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
+
+         x3_0 = self.conv3_0(self.pool(x2_0))
+         x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
+         x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
+         x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
+
+         x4_0 = self.conv4_0(self.pool(x3_0))
+         x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
+         x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
+         x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
+         x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
+
+         if self.deep_supervision:
+             output1 = self.final1(x0_1)
+             output2 = self.final2(x0_2)
+             output3 = self.final3(x0_3)
+             output4 = self.final4(x0_4)
+             return [output1, output2, output3, output4]
+
+         else:
+             output = self.final(x0_4)
+             return output
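
As a quick sanity check of the model as committed (a sketch, not part of the commit; input sides must be divisible by 16 because of the four 2x2 poolings):

import torch
from unetplusplus import NestedUNet

model = NestedUNet(num_classes=40, input_channels=3)
model.eval()

x = torch.randn(1, 3, 224, 224)  # the demo feeds 224x224 crops
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 40, 224, 224])

With deep_supervision=True, forward instead returns a list of four prediction maps, one from each nested decoder stage (x0_1 through x0_4).
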