leuschnm committed
Commit 994ce72 · Parent: 93e74b2
Files changed (2):
  1. app.py +61 -4
  2. model.py +257 -0
app.py CHANGED
@@ -1,8 +1,64 @@
+# Copyright 2021 Tencent
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+import os
 import numpy as np
+import torch
+import torchvision.transforms as standard_transforms
+from model import SASNet
+from types import SimpleNamespace
+import warnings
+import matplotlib.pyplot as plt
 import gradio as gr
 
-def flip_image(x):
-    return (1000, np.zeros((100, 100)))
+warnings.filterwarnings('ignore')
+
+# define the GPU id to be used
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+def predict(img):
+    """the main process of inference"""
+    # SASNet reads its hyper-parameters from an args-style namespace (see model.py)
+    model = SASNet(args=SimpleNamespace(block_size=32)).cuda()
+    model_path = "SHHA.pth"
+    # load the trained model
+    model.load_state_dict(torch.load(model_path))
+    print('successfully loaded model from', model_path)
+
+    with torch.no_grad():
+        model.eval()
+
+        img = img.convert('RGB')
+        transform = standard_transforms.Compose([
+            standard_transforms.ToTensor(),
+            standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                          std=[0.229, 0.224, 0.225]),
+        ])
+        # normalize to a tensor and add the batch dimension
+        img = transform(img).unsqueeze(0)
+
+        img = img.cuda()
+        pred_map = model(img)
+
+        pred_map = pred_map.data.cpu().numpy()
+        # density maps are scaled by 1000 during training, so divide it back out
+        pred_cnt = np.sum(pred_map[0]) / 1000
+
+        den_map = np.squeeze(pred_map[0])
+        fig = plt.figure(frameon=False)
+        ax = plt.Axes(fig, [0., 0., 1., 1.])
+        ax.set_axis_off()
+        fig.add_axes(ax)
+        ax.imshow(den_map, aspect='auto')
+        return (pred_cnt, fig)
 
 with gr.Blocks() as demo:
     gr.Markdown("""
@@ -19,9 +75,10 @@ with gr.Blocks() as demo:
         image_input = gr.Image(type="pil")
         gr.Examples(["IMG_1.jpg", "IMG_2.jpg", "IMG_3.jpg"], image_input)
     with gr.Column():
-        image_output = gr.Image()
         text_output = gr.Label()
+        image_output = gr.Plot()
+
 
-    image_button.click(flip_image, inputs=image_input, outputs=[text_output, image_output])
+    image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])
 
 demo.launch()
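
For a quick local sanity check outside Gradio, a driver along the following lines could exercise the new predict() path. This is a hypothetical sketch, not part of the commit: it assumes the SHHA.pth checkpoint, a bundled example image such as IMG_1.jpg, and a CUDA device are available, just as the Space itself requires.

    # hypothetical smoke test for app.py's predict(); not part of the commit
    from PIL import Image

    count, fig = predict(Image.open("IMG_1.jpg"))  # one of the bundled examples
    print(f"estimated crowd count: {count:.1f}")   # scalar shown in the gr.Label
    fig.savefig("density_map.png")                 # figure shown in the gr.Plot
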
model.py ADDED
@@ -0,0 +1,257 @@
+# Copyright 2021 Tencent
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+class Conv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride=1, NL='relu', same_padding=False, bn=False, dilation=1):
+        super(Conv2d, self).__init__()
+        padding = int((kernel_size - 1) // 2) if same_padding else 0
+        if dilation == 1:
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation)
+        else:
+            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=dilation, dilation=dilation)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None
+        if NL == 'relu':
+            self.relu = nn.ReLU(inplace=True)
+        elif NL == 'prelu':
+            self.relu = nn.PReLU()
+        else:
+            self.relu = None
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn is not None:
+            x = self.bn(x)
+        if self.relu is not None:
+            x = self.relu(x)
+        return x
+
+# the main implementation of the SASNet
+class SASNet(nn.Module):
+    def __init__(self, pretrained=False, args=None):
+        super(SASNet, self).__init__()
+        # define the backbone network
+        vgg = models.vgg16_bn(pretrained=pretrained)
+
+        features = list(vgg.features.children())
+        # get each stage of the backbone
+        self.features1 = nn.Sequential(*features[0:6])
+        self.features2 = nn.Sequential(*features[6:13])
+        self.features3 = nn.Sequential(*features[13:23])
+        self.features4 = nn.Sequential(*features[23:33])
+        self.features5 = nn.Sequential(*features[33:43])
+        # decoder definition
+        self.de_pred5 = nn.Sequential(
+            Conv2d(512, 1024, 3, same_padding=True, NL='relu'),
+            Conv2d(1024, 512, 3, same_padding=True, NL='relu'),
+        )
+
+        self.de_pred4 = nn.Sequential(
+            Conv2d(512 + 512, 512, 3, same_padding=True, NL='relu'),
+            Conv2d(512, 256, 3, same_padding=True, NL='relu'),
+        )
+
+        self.de_pred3 = nn.Sequential(
+            Conv2d(256 + 256, 256, 3, same_padding=True, NL='relu'),
+            Conv2d(256, 128, 3, same_padding=True, NL='relu'),
+        )
+
+        self.de_pred2 = nn.Sequential(
+            Conv2d(128 + 128, 128, 3, same_padding=True, NL='relu'),
+            Conv2d(128, 64, 3, same_padding=True, NL='relu'),
+        )
+
+        self.de_pred1 = nn.Sequential(
+            Conv2d(64 + 64, 64, 3, same_padding=True, NL='relu'),
+            Conv2d(64, 64, 3, same_padding=True, NL='relu'),
+        )
+        # density head definition; a MultiBranchModule outputs 4x its input channels
+        self.density_head5 = nn.Sequential(
+            MultiBranchModule(512),
+            Conv2d(2048, 1, 1, same_padding=True)
+        )
+
+        self.density_head4 = nn.Sequential(
+            MultiBranchModule(256),
+            Conv2d(1024, 1, 1, same_padding=True)
+        )
+
+        self.density_head3 = nn.Sequential(
+            MultiBranchModule(128),
+            Conv2d(512, 1, 1, same_padding=True)
+        )
+
+        self.density_head2 = nn.Sequential(
+            MultiBranchModule(64),
+            Conv2d(256, 1, 1, same_padding=True)
+        )
+
+        self.density_head1 = nn.Sequential(
+            MultiBranchModule(64),
+            Conv2d(256, 1, 1, same_padding=True)
+        )
+        # confidence head definition
+        self.confidence_head5 = nn.Sequential(
+            Conv2d(512, 256, 1, same_padding=True, NL='relu'),
+            Conv2d(256, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.confidence_head4 = nn.Sequential(
+            Conv2d(256, 128, 1, same_padding=True, NL='relu'),
+            Conv2d(128, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.confidence_head3 = nn.Sequential(
+            Conv2d(128, 64, 1, same_padding=True, NL='relu'),
+            Conv2d(64, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.confidence_head2 = nn.Sequential(
+            Conv2d(64, 32, 1, same_padding=True, NL='relu'),
+            Conv2d(32, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.confidence_head1 = nn.Sequential(
+            Conv2d(64, 32, 1, same_padding=True, NL='relu'),
+            Conv2d(32, 1, 1, same_padding=True, NL=None)
+        )
+
+        self.block_size = args.block_size
+
+    # the forward process
+    def forward(self, x):
+        size = x.size()
+        x1 = self.features1(x)
+        x2 = self.features2(x1)
+        x3 = self.features3(x2)
+        x4 = self.features4(x3)
+        x5 = self.features5(x4)
+        # beginning of decoding
+        x = self.de_pred5(x5)
+        x5_out = x
+        x = F.upsample_bilinear(x, size=x4.size()[2:])
+
+        x = torch.cat([x4, x], 1)
+        x = self.de_pred4(x)
+        x4_out = x
+        x = F.upsample_bilinear(x, size=x3.size()[2:])
+
+        x = torch.cat([x3, x], 1)
+        x = self.de_pred3(x)
+        x3_out = x
+        x = F.upsample_bilinear(x, size=x2.size()[2:])
+
+        x = torch.cat([x2, x], 1)
+        x = self.de_pred2(x)
+        x2_out = x
+        x = F.upsample_bilinear(x, size=x1.size()[2:])
+
+        x = torch.cat([x1, x], 1)
+        x = self.de_pred1(x)
+        x1_out = x
+        # density prediction
+        x5_density = self.density_head5(x5_out)
+        x4_density = self.density_head4(x4_out)
+        x3_density = self.density_head3(x3_out)
+        x2_density = self.density_head2(x2_out)
+        x1_density = self.density_head1(x1_out)
+        # get patch features for confidence prediction
+        x5_confi = F.adaptive_avg_pool2d(x5_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        x4_confi = F.adaptive_avg_pool2d(x4_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        x3_confi = F.adaptive_avg_pool2d(x3_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        x2_confi = F.adaptive_avg_pool2d(x2_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        x1_confi = F.adaptive_avg_pool2d(x1_out, output_size=(size[-2] // self.block_size, size[-1] // self.block_size))
+        # confidence prediction
+        x5_confi = self.confidence_head5(x5_confi)
+        x4_confi = self.confidence_head4(x4_confi)
+        x3_confi = self.confidence_head3(x3_confi)
+        x2_confi = self.confidence_head2(x2_confi)
+        x1_confi = self.confidence_head1(x1_confi)
+        # upsample the density predictions to the input resolution
+        x5_density = F.upsample_nearest(x5_density, size=x1.size()[2:])
+        x4_density = F.upsample_nearest(x4_density, size=x1.size()[2:])
+        x3_density = F.upsample_nearest(x3_density, size=x1.size()[2:])
+        x2_density = F.upsample_nearest(x2_density, size=x1.size()[2:])
+        x1_density = F.upsample_nearest(x1_density, size=x1.size()[2:])
+        # upsample the confidence predictions to the input resolution
+        x5_confi_upsample = F.upsample_nearest(x5_confi, size=x1.size()[2:])
+        x4_confi_upsample = F.upsample_nearest(x4_confi, size=x1.size()[2:])
+        x3_confi_upsample = F.upsample_nearest(x3_confi, size=x1.size()[2:])
+        x2_confi_upsample = F.upsample_nearest(x2_confi, size=x1.size()[2:])
+        x1_confi_upsample = F.upsample_nearest(x1_confi, size=x1.size()[2:])
+
+        # =============================================================================================================
+        # soft selection across the five scales
+        confidence_map = torch.cat([x5_confi_upsample, x4_confi_upsample,
+                                    x3_confi_upsample, x2_confi_upsample, x1_confi_upsample], 1)
+        confidence_map = torch.sigmoid(confidence_map)
+
+        # use softmax to normalize
+        confidence_map = F.softmax(confidence_map, dim=1)
+
+        density_map = torch.cat([x5_density, x4_density, x3_density, x2_density, x1_density], 1)
+        # soft selection
+        density_map *= confidence_map
+        density = torch.sum(density_map, 1, keepdim=True)
+
+        return density
+
+# the module definition for the multi-branch in the density head
+class MultiBranchModule(nn.Module):
+    def __init__(self, in_channels, sync=False):
+        super(MultiBranchModule, self).__init__()
+        self.branch1x1 = BasicConv2d(in_channels, in_channels // 2, kernel_size=1, sync=sync)
+        self.branch1x1_1 = BasicConv2d(in_channels // 2, in_channels, kernel_size=1, sync=sync)
+
+        self.branch3x3_1 = BasicConv2d(in_channels, in_channels // 2, kernel_size=1, sync=sync)
+        self.branch3x3_2 = BasicConv2d(in_channels // 2, in_channels, kernel_size=(3, 3), padding=(1, 1), sync=sync)
+
+        self.branch3x3dbl_1 = BasicConv2d(in_channels, in_channels // 2, kernel_size=1, sync=sync)
+        self.branch3x3dbl_2 = BasicConv2d(in_channels // 2, in_channels, kernel_size=5, padding=2, sync=sync)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+        branch1x1 = self.branch1x1_1(branch1x1)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = self.branch3x3_2(branch3x3)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+
+        # concatenating the three branches with the input gives 4x in_channels
+        outputs = [branch1x1, branch3x3, branch3x3dbl, x]
+        return torch.cat(outputs, 1)
+
+# the module definition for the basic conv module
+class BasicConv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, sync=False, **kwargs):
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        if sync:
+            # for sync bn
+            print('use sync inception')
+            self.bn = nn.SyncBatchNorm(out_channels, eps=0.001)
+        else:
+            self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
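
The fusion at the end of forward() is the crux of SASNet: each of the five decoder scales emits a density map and a per-patch confidence map, the confidences are normalized across scales, and the final density is their confidence-weighted sum over the channel dimension. A minimal shape check, assuming a CPU-only environment, pretrained=False, and an input whose sides are divisible by block_size:

    # hypothetical smoke test for model.py; not part of the commit
    import torch
    from types import SimpleNamespace
    from model import SASNet

    model = SASNet(pretrained=False, args=SimpleNamespace(block_size=32)).eval()
    x = torch.randn(1, 3, 256, 256)     # H and W divisible by block_size
    with torch.no_grad():
        density = model(x)
    print(density.shape)                # torch.Size([1, 1, 256, 256])
    print(density.sum().item() / 1000)  # count estimate (maps are scaled by 1000 in training)

Summing the single-channel output and dividing by the training-time scale factor is exactly how app.py derives the displayed count.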