add model
app.py CHANGED
@@ -1,8 +1,64 @@
+# Copyright 2021 Tencent
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+import os
 import numpy as np
+import torch
+import torchvision.transforms as standard_transforms
+from types import SimpleNamespace
+from model import SASNet
+import warnings
+import matplotlib.pyplot as plt
 import gradio as gr
 
-
-
+warnings.filterwarnings('ignore')
+
+# define the GPU id to be used
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+def predict(img):
+    """The main inference routine: takes a PIL image and returns the
+    predicted crowd count plus a density-map figure."""
+    # SASNet reads block_size from an args-style namespace (see model.py)
+    model = SASNet(args=SimpleNamespace(block_size=32)).cuda()
+    model_path = "SHHA.pth"
+    # load the trained weights
+    model.load_state_dict(torch.load(model_path))
+    print('successfully loaded model from', model_path)
+
+    with torch.no_grad():
+        model.eval()
+
+        img = img.convert('RGB')
+        transform = standard_transforms.Compose([
+            standard_transforms.ToTensor(),
+            standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                          std=[0.229, 0.224, 0.225]),
+        ])
+        # add a batch dimension before moving the tensor to the GPU
+        img = transform(img).unsqueeze(0).cuda()
+        pred_map = model(img)
+
+        pred_map = pred_map.data.cpu().numpy()
+        # density maps are scaled by 1000 (log_para) during training
+        pred_cnt = np.sum(pred_map[0]) / 1000
+
+        den_map = np.squeeze(pred_map[0])
+        fig = plt.figure(frameon=False)
+        ax = plt.Axes(fig, [0., 0., 1., 1.])
+        ax.set_axis_off()
+        fig.add_axes(ax)
+        ax.imshow(den_map, aspect='auto')
+        return (pred_cnt, fig)
 
 with gr.Blocks() as demo:
     gr.Markdown("""
@@ -19,9 +75,10 @@ with gr.Blocks() as demo:
             image_input = gr.Image(type="pil")
            gr.Examples(["IMG_1.jpg", "IMG_2.jpg", "IMG_3.jpg"], image_input)
        with gr.Column():
-            image_output = gr.Image()
             text_output = gr.Label()
+            image_output = gr.Plot()
+
 
-    image_button.click(
+    image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])
 
 demo.launch()
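Review note: image_button is referenced by the second hunk but defined in the unchanged lines between the two hunks. Also, as committed, predict rebuilds SASNet and reloads SHHA.pth on every button click, so each request pays the full checkpoint-loading cost. A minimal sketch of the usual alternative, building the model once at startup (this is a suggested refactor, not part of this commit; it assumes the same SHHA.pth checkpoint and block size used above, and for brevity returns only the count):

    from types import SimpleNamespace

    import torch
    import torchvision.transforms as standard_transforms
    from model import SASNet

    # build and load the model once, at import time, instead of per request
    MODEL = SASNet(args=SimpleNamespace(block_size=32)).cuda()
    MODEL.load_state_dict(torch.load("SHHA.pth"))
    MODEL.eval()

    TRANSFORM = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225]),
    ])

    def predict(img):
        # same preprocessing and 1000x count scaling as in the diff above
        x = TRANSFORM(img.convert('RGB')).unsqueeze(0).cuda()
        with torch.no_grad():
            pred_map = MODEL(x)
        return pred_map.sum().item() / 1000

This keeps per-request latency down to the forward pass and avoids re-allocating GPU memory on every call.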
model.py ADDED
@@ -0,0 +1,257 @@
+# Copyright 2021 Tencent
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+class Conv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride=1, NL='relu', same_padding=False, bn=False, dilation=1):
+        super(Conv2d, self).__init__()
+        padding = int((kernel_size - 1) // 2) if same_padding else 0
+        if dilation == 1:
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation)
+        else:
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=dilation, dilation=dilation)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else nn.Identity()
+        if NL == 'relu':
+            self.relu = nn.ReLU(inplace=True)
+        elif NL == 'prelu':
+            self.relu = nn.PReLU()
+        else:
+            self.relu = None
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        if self.relu is not None:
+            x = self.relu(x)
+        return x
+
+# the main implementation of the SASNet
+class SASNet(nn.Module):
+    def __init__(self, pretrained=False, args=None):
+        super(SASNet, self).__init__()
+        # define the backbone network
+        vgg = models.vgg16_bn(pretrained=pretrained)
+
+        features = list(vgg.features.children())
+        # get each stage of the backbone
+        self.features1 = nn.Sequential(*features[0:6])
+        self.features2 = nn.Sequential(*features[6:13])
+        self.features3 = nn.Sequential(*features[13:23])
+        self.features4 = nn.Sequential(*features[23:33])
+        self.features5 = nn.Sequential(*features[33:43])
+        # decoder definition
+        self.de_pred5 = nn.Sequential(
+            Conv2d(512, 1024, 3, same_padding=True, NL='relu'),
+            Conv2d(1024, 512, 3, same_padding=True, NL='relu'),
+        )
+
+        self.de_pred4 = nn.Sequential(
+            Conv2d(512 + 512, 512, 3, same_padding=True, NL='relu'),
+            Conv2d(512, 256, 3, same_padding=True, NL='relu'),
+        )
+
+        self.de_pred3 = nn.Sequential(
+            Conv2d(256 + 256, 256, 3, same_padding=True, NL='relu'),
+            Conv2d(256, 128, 3, same_padding=True, NL='relu'),
+        )
+
+        self.de_pred2 = nn.Sequential(
+            Conv2d(128 + 128, 128, 3, same_padding=True, NL='relu'),
+            Conv2d(128, 64, 3, same_padding=True, NL='relu'),
+        )
+
+        self.de_pred1 = nn.Sequential(
+            Conv2d(64 + 64, 64, 3, same_padding=True, NL='relu'),
+            Conv2d(64, 64, 3, same_padding=True, NL='relu'),
+        )
+        # density head definition; each MultiBranchModule outputs 4x its input channels
+        self.density_head5 = nn.Sequential(
+            MultiBranchModule(512),
+            Conv2d(2048, 1, 1, same_padding=True)
+        )
+
+        self.density_head4 = nn.Sequential(
+            MultiBranchModule(256),
+            Conv2d(1024, 1, 1, same_padding=True)
+        )
+
+        self.density_head3 = nn.Sequential(
+            MultiBranchModule(128),
+            Conv2d(512, 1, 1, same_padding=True)
+        )
+
+        self.density_head2 = nn.Sequential(
+            MultiBranchModule(64),
+            Conv2d(256, 1, 1, same_padding=True)
+        )
+
+        self.density_head1 = nn.Sequential(
+            MultiBranchModule(64),
+            Conv2d(256, 1, 1, same_padding=True)
+        )
+        # confidence head definition
+        self.confidence_head5 = nn.Sequential(
+            Conv2d(512, 256, 1, same_padding=True, NL='relu'),
+            Conv2d(256, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.confidence_head4 = nn.Sequential(
+            Conv2d(256, 128, 1, same_padding=True, NL='relu'),
+            Conv2d(128, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.confidence_head3 = nn.Sequential(
+            Conv2d(128, 64, 1, same_padding=True, NL='relu'),
+            Conv2d(64, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.confidence_head2 = nn.Sequential(
+            Conv2d(64, 32, 1, same_padding=True, NL='relu'),
+            Conv2d(32, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.confidence_head1 = nn.Sequential(
+            Conv2d(64, 32, 1, same_padding=True, NL='relu'),
+            Conv2d(32, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.block_size = args.block_size
+
+    # the forward process
+    def forward(self, x):
+        size = x.size()
+        x1 = self.features1(x)
+        x2 = self.features2(x1)
+        x3 = self.features3(x2)
+        x4 = self.features4(x3)
+        x5 = self.features5(x4)
+        # beginning of decoding
+        x = self.de_pred5(x5)
+        x5_out = x
+        x = F.upsample_bilinear(x, size=x4.size()[2:])
+
+        x = torch.cat([x4, x], 1)
+        x = self.de_pred4(x)
+        x4_out = x
+        x = F.upsample_bilinear(x, size=x3.size()[2:])
+
+        x = torch.cat([x3, x], 1)
+        x = self.de_pred3(x)
+        x3_out = x
+        x = F.upsample_bilinear(x, size=x2.size()[2:])
+
+        x = torch.cat([x2, x], 1)
+        x = self.de_pred2(x)
+        x2_out = x
+        x = F.upsample_bilinear(x, size=x1.size()[2:])
+
+        x = torch.cat([x1, x], 1)
+        x = self.de_pred1(x)
+        x1_out = x
+        # density prediction
+        x5_density = self.density_head5(x5_out)
+        x4_density = self.density_head4(x4_out)
+        x3_density = self.density_head3(x3_out)
+        x2_density = self.density_head2(x2_out)
+        x1_density = self.density_head1(x1_out)
+        # get patch features for confidence prediction
+        x5_confi = F.adaptive_avg_pool2d(x5_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        x4_confi = F.adaptive_avg_pool2d(x4_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        x3_confi = F.adaptive_avg_pool2d(x3_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        x2_confi = F.adaptive_avg_pool2d(x2_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        x1_confi = F.adaptive_avg_pool2d(x1_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        # confidence prediction
+        x5_confi = self.confidence_head5(x5_confi)
+        x4_confi = self.confidence_head4(x4_confi)
+        x3_confi = self.confidence_head3(x3_confi)
+        x2_confi = self.confidence_head2(x2_confi)
+        x1_confi = self.confidence_head1(x1_confi)
+        # upsample the density predictions to the input size
+        x5_density = F.upsample_nearest(x5_density, size=x1.size()[2:])
+        x4_density = F.upsample_nearest(x4_density, size=x1.size()[2:])
+        x3_density = F.upsample_nearest(x3_density, size=x1.size()[2:])
+        x2_density = F.upsample_nearest(x2_density, size=x1.size()[2:])
+        x1_density = F.upsample_nearest(x1_density, size=x1.size()[2:])
+        # upsample the confidence predictions to the input size
+        x5_confi_upsample = F.upsample_nearest(x5_confi, size=x1.size()[2:])
+        x4_confi_upsample = F.upsample_nearest(x4_confi, size=x1.size()[2:])
+        x3_confi_upsample = F.upsample_nearest(x3_confi, size=x1.size()[2:])
+        x2_confi_upsample = F.upsample_nearest(x2_confi, size=x1.size()[2:])
+        x1_confi_upsample = F.upsample_nearest(x1_confi, size=x1.size()[2:])
+
+        # =============================================================================================================
+        # soft selection
+        confidence_map = torch.cat([x5_confi_upsample, x4_confi_upsample,
+                                    x3_confi_upsample, x2_confi_upsample, x1_confi_upsample], 1)
+        confidence_map = torch.nn.functional.sigmoid(confidence_map)
+
+        # use softmax to normalize the confidences across the five scales
+        confidence_map = torch.nn.functional.softmax(confidence_map, 1)
+
+        density_map = torch.cat([x5_density, x4_density, x3_density, x2_density, x1_density], 1)
+        # soft selection: confidence-weighted sum over the scales
+        density_map *= confidence_map
+        density = torch.sum(density_map, 1, keepdim=True)
+
+        return density
+
+# the module definition for the multi-branch in the density head
+class MultiBranchModule(nn.Module):
+    def __init__(self, in_channels, sync=False):
+        super(MultiBranchModule, self).__init__()
+        self.branch1x1 = BasicConv2d(in_channels, in_channels//2, kernel_size=1, sync=sync)
+        self.branch1x1_1 = BasicConv2d(in_channels//2, in_channels, kernel_size=1, sync=sync)
+
+        self.branch3x3_1 = BasicConv2d(in_channels, in_channels//2, kernel_size=1, sync=sync)
+        self.branch3x3_2 = BasicConv2d(in_channels // 2, in_channels, kernel_size=(3, 3), padding=(1, 1), sync=sync)
+
+        self.branch3x3dbl_1 = BasicConv2d(in_channels, in_channels//2, kernel_size=1, sync=sync)
+        self.branch3x3dbl_2 = BasicConv2d(in_channels // 2, in_channels, kernel_size=5, padding=2, sync=sync)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+        branch1x1 = self.branch1x1_1(branch1x1)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = self.branch3x3_2(branch3x3)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+
+        # concatenating the three branches with the input quadruples the channel count
+        outputs = [branch1x1, branch3x3, branch3x3dbl, x]
+        return torch.cat(outputs, 1)
+
+# the module definition for the basic conv module
+class BasicConv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, sync=False, **kwargs):
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        if sync:
+            # for sync bn
+            print('use sync inception')
+            self.bn = nn.SyncBatchNorm(out_channels, eps=0.001)
+        else:
+            self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
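
Review note: SASNet fuses five decoder scales by predicting a per-block confidence map for each scale, normalizing the confidences across scales with a softmax, and taking the confidence-weighted sum of the per-scale density maps. A minimal shape check under the assumption that only this model.py is needed and no pretrained weights are loaded (the SimpleNamespace stands in for the training script's argument parser):

    from types import SimpleNamespace

    import torch
    from model import SASNet

    net = SASNet(pretrained=False, args=SimpleNamespace(block_size=32)).eval()

    with torch.no_grad():
        # dummy RGB batch; H and W should be multiples of block_size
        x = torch.randn(1, 3, 224, 224)
        density = net(x)

    print(density.shape)                # torch.Size([1, 1, 224, 224]): one map at input resolution
    print(density.sum().item() / 1000)  # predicted count under the 1000x (log_para) training scale

With random weights the count itself is meaningless; the point is that the forward pass runs end to end and produces a single-channel density map at the input resolution, which is what app.py sums and plots.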
|