fix model
Browse files- app.py +16 -5
- environment.yaml +0 -161
app.py
CHANGED
@@ -22,6 +22,7 @@ import gradio as gr
|
|
22 |
import torchvision.transforms as standard_transforms
|
23 |
from torch.utils.data import DataLoader
|
24 |
from torch.utils.data import Dataset
|
|
|
25 |
|
26 |
warnings.filterwarnings('ignore')
|
27 |
|
@@ -96,14 +97,23 @@ with gr.Blocks() as demo:
|
|
96 |
We implemented an image crowd counting model with VGG16 following the paper of Song et al. (2021).
|
97 |
|
98 |
## Abstract
|
99 |
-
In this paper, we address the large scale variation problem in crowd counting by taking full advantage of the multi-scale feature representations in a multi-level network. We
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
## References
|
104 |
-
Song, Q., Wang, C., Wang, Y., Tai, Y., Wang, C., Li, J., … Ma, J. (2021). To Choose or to Fuse? Scale Selection for Crowd Counting.
|
|
|
105 |
""")
|
106 |
-
image_button = gr.Button("Count the Crowd!")
|
107 |
with gr.Row():
|
108 |
with gr.Column():
|
109 |
image_input = gr.Image(type="pil")
|
@@ -112,6 +122,7 @@ The code will be available at: https://github.com/TencentYoutuResearch/CrowdCoun
|
|
112 |
image_output = gr.Plot()
|
113 |
with gr.Column():
|
114 |
text_output = gr.Label()
|
|
|
115 |
|
116 |
|
117 |
image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])
|
|
|
22 |
import torchvision.transforms as standard_transforms
|
23 |
from torch.utils.data import DataLoader
|
24 |
from torch.utils.data import Dataset
|
25 |
+
from model import SASNet
|
26 |
|
27 |
warnings.filterwarnings('ignore')
|
28 |
|
|
|
97 |
We implemented an image crowd counting model with VGG16 following the paper of Song et al. (2021).
|
98 |
|
99 |
## Abstract
|
100 |
+
In this paper, we address the large scale variation problem in crowd counting by taking full advantage of the multi-scale feature representations in a multi-level network. We
|
101 |
+
implement such an idea by keeping the counting error of a patch as small as possible with a proper feature level selection strategy, since a specific feature level tends to perform
|
102 |
+
better for a certain range of scales. However, without scale annotations, it is sub-optimal and error-prone to manually assign the predictions for heads of different scales to
|
103 |
+
specific feature levels. Therefore, we propose a Scale-Adaptive Selection Network (SASNet), which automatically learns the internal correspondence between the scales and the feature
|
104 |
+
levels. Instead of directly using the predictions from the most appropriate feature level as the final estimation, our SASNet also considers the predictions from other feature
|
105 |
+
levels via weighted average, which helps to mitigate the gap between discrete feature levels and continuous scale variation. Since the heads in a local patch share roughly a same
|
106 |
+
scale, we conduct the adaptive selection strategy in a patch-wise style. However, pixels within a patch contribute different counting errors due to the various difficulty degrees of
|
107 |
+
learning. Thus, we further propose a Pyramid Region Awareness Loss (PRA Loss) to recursively select the most hard sub-regions within a patch until reaching the pixel level. With
|
108 |
+
awareness of whether the parent patch is over-estimated or under-estimated, the fine-grained optimization with the PRA Loss for these region-aware hard pixels helps to alleviate the
|
109 |
+
inconsistency problem between training target and evaluation metric. The state-of-the-art results on four datasets demonstrate the superiority of our approach.
|
110 |
+
|
111 |
+
The code will be available at: https://github.com/TencentYoutuResearch/CrowdCounting-SASNet.
|
112 |
|
113 |
## References
|
114 |
+
Song, Q., Wang, C., Wang, Y., Tai, Y., Wang, C., Li, J., … Ma, J. (2021). To Choose or to Fuse? Scale Selection for Crowd Counting.
|
115 |
+
The Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI-21).
|
116 |
""")
|
|
|
117 |
with gr.Row():
|
118 |
with gr.Column():
|
119 |
image_input = gr.Image(type="pil")
|
|
|
122 |
image_output = gr.Plot()
|
123 |
with gr.Column():
|
124 |
text_output = gr.Label()
|
125 |
+
image_button = gr.Button("Count the Crowd!")
|
126 |
|
127 |
|
128 |
image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])
|
environment.yaml
DELETED
@@ -1,161 +0,0 @@
|
|
1 |
-
name: SASNet
|
2 |
-
channels:
|
3 |
-
- pytorch
|
4 |
-
- nvidia
|
5 |
-
- anaconda
|
6 |
-
- defaults
|
7 |
-
dependencies:
|
8 |
-
- _libgcc_mutex=0.1=main
|
9 |
-
- _openmp_mutex=5.1=1_gnu
|
10 |
-
- _pytorch_select=0.1=cpu_0
|
11 |
-
- backcall=0.2.0=pyhd3eb1b0_0
|
12 |
-
- blas=1.0=mkl
|
13 |
-
- ca-certificates=2022.07.19=h06a4308_0
|
14 |
-
- certifi=2022.6.15=py37h06a4308_0
|
15 |
-
- cffi=1.15.0=py37h7f8727e_0
|
16 |
-
- cuda=12.0.0=0
|
17 |
-
- cuda-cccl=12.0.90=0
|
18 |
-
- cuda-command-line-tools=12.0.0=0
|
19 |
-
- cuda-compiler=12.0.0=0
|
20 |
-
- cuda-cudart=12.0.107=0
|
21 |
-
- cuda-cudart-dev=12.0.107=0
|
22 |
-
- cuda-cudart-static=12.0.107=0
|
23 |
-
- cuda-cuobjdump=12.0.76=0
|
24 |
-
- cuda-cupti=12.0.90=0
|
25 |
-
- cuda-cupti-static=12.0.90=0
|
26 |
-
- cuda-cuxxfilt=12.0.76=0
|
27 |
-
- cuda-demo-suite=12.0.76=0
|
28 |
-
- cuda-documentation=12.0.76=0
|
29 |
-
- cuda-driver-dev=12.0.107=0
|
30 |
-
- cuda-gdb=12.0.90=0
|
31 |
-
- cuda-libraries=12.0.0=0
|
32 |
-
- cuda-libraries-dev=12.0.0=0
|
33 |
-
- cuda-libraries-static=12.0.0=0
|
34 |
-
- cuda-nsight=12.0.78=0
|
35 |
-
- cuda-nsight-compute=12.0.0=0
|
36 |
-
- cuda-nvcc=12.0.76=0
|
37 |
-
- cuda-nvdisasm=12.0.76=0
|
38 |
-
- cuda-nvml-dev=12.0.76=0
|
39 |
-
- cuda-nvprof=12.0.90=0
|
40 |
-
- cuda-nvprune=12.0.76=0
|
41 |
-
- cuda-nvrtc=12.0.76=0
|
42 |
-
- cuda-nvrtc-dev=12.0.76=0
|
43 |
-
- cuda-nvrtc-static=12.0.76=0
|
44 |
-
- cuda-nvtx=12.0.76=0
|
45 |
-
- cuda-nvvp=12.0.90=0
|
46 |
-
- cuda-opencl=12.0.76=0
|
47 |
-
- cuda-opencl-dev=12.0.76=0
|
48 |
-
- cuda-profiler-api=12.0.76=0
|
49 |
-
- cuda-runtime=12.0.0=0
|
50 |
-
- cuda-sanitizer-api=12.0.90=0
|
51 |
-
- cuda-toolkit=12.0.0=0
|
52 |
-
- cuda-tools=12.0.0=0
|
53 |
-
- cuda-visual-tools=12.0.0=0
|
54 |
-
- cudatoolkit=10.2.89=hfd86e86_1
|
55 |
-
- debugpy=1.5.1=py37h295c915_0
|
56 |
-
- decorator=5.1.1=pyhd3eb1b0_0
|
57 |
-
- entrypoints=0.4=py37h06a4308_0
|
58 |
-
- freetype=2.12.1=h4a9f257_0
|
59 |
-
- gds-tools=1.5.0.59=0
|
60 |
-
- giflib=5.2.1=h7b6447c_0
|
61 |
-
- intel-openmp=2022.1.0=h9e868ea_3769
|
62 |
-
- ipykernel=6.9.1=py37h06a4308_0
|
63 |
-
- ipython=7.31.1=py37h06a4308_1
|
64 |
-
- jedi=0.18.1=py37h06a4308_1
|
65 |
-
- jpeg=9e=h7f8727e_0
|
66 |
-
- jupyter_client=7.2.2=py37h06a4308_0
|
67 |
-
- jupyter_core=4.10.0=py37h06a4308_0
|
68 |
-
- lcms2=2.12=h3be6417_0
|
69 |
-
- lerc=3.0=h295c915_0
|
70 |
-
- libcublas=12.0.1.189=0
|
71 |
-
- libcublas-dev=12.0.1.189=0
|
72 |
-
- libcublas-static=12.0.1.189=0
|
73 |
-
- libcufft=11.0.0.21=0
|
74 |
-
- libcufft-dev=11.0.0.21=0
|
75 |
-
- libcufft-static=11.0.0.21=0
|
76 |
-
- libcufile=1.5.0.59=0
|
77 |
-
- libcufile-dev=1.5.0.59=0
|
78 |
-
- libcufile-static=1.5.0.59=0
|
79 |
-
- libcurand=10.3.1.50=0
|
80 |
-
- libcurand-dev=10.3.1.50=0
|
81 |
-
- libcurand-static=10.3.1.50=0
|
82 |
-
- libcusolver=11.4.2.57=0
|
83 |
-
- libcusolver-dev=11.4.2.57=0
|
84 |
-
- libcusolver-static=11.4.2.57=0
|
85 |
-
- libcusparse=12.0.0.76=0
|
86 |
-
- libcusparse-dev=12.0.0.76=0
|
87 |
-
- libcusparse-static=12.0.0.76=0
|
88 |
-
- libdeflate=1.8=h7f8727e_5
|
89 |
-
- libedit=3.1.20221030=h5eee18b_0
|
90 |
-
- libffi=3.2.1=hf484d3e_1007
|
91 |
-
- libgcc-ng=11.2.0=h1234567_1
|
92 |
-
- libgfortran-ng=7.5.0=ha8ba4b0_17
|
93 |
-
- libgfortran4=7.5.0=ha8ba4b0_17
|
94 |
-
- libgomp=11.2.0=h1234567_1
|
95 |
-
- libnpp=12.0.0.30=0
|
96 |
-
- libnpp-dev=12.0.0.30=0
|
97 |
-
- libnpp-static=12.0.0.30=0
|
98 |
-
- libnvjitlink=12.0.76=0
|
99 |
-
- libnvjitlink-dev=12.0.76=0
|
100 |
-
- libnvjpeg=12.0.0.28=0
|
101 |
-
- libnvjpeg-dev=12.0.0.28=0
|
102 |
-
- libnvjpeg-static=12.0.0.28=0
|
103 |
-
- libnvvm-samples=12.0.94=0
|
104 |
-
- libpng=1.6.37=hbc83047_0
|
105 |
-
- libsodium=1.0.18=h7b6447c_0
|
106 |
-
- libstdcxx-ng=11.2.0=h1234567_1
|
107 |
-
- libtiff=4.5.0=hecacb30_0
|
108 |
-
- libwebp=1.2.4=h11a3e52_0
|
109 |
-
- libwebp-base=1.2.4=h5eee18b_0
|
110 |
-
- lz4-c=1.9.4=h6a678d5_0
|
111 |
-
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
|
112 |
-
- mkl=2019.4=243
|
113 |
-
- mkl-service=2.3.0=py37he8ac12f_0
|
114 |
-
- mkl_fft=1.3.0=py37h54f3939_0
|
115 |
-
- mkl_random=1.1.0=py37hd6b4f25_0
|
116 |
-
- ncurses=6.3=h5eee18b_3
|
117 |
-
- nest-asyncio=1.5.5=py37h06a4308_0
|
118 |
-
- ninja=1.10.2=h06a4308_5
|
119 |
-
- ninja-base=1.10.2=hd09550d_5
|
120 |
-
- nsight-compute=2022.4.0.15=0
|
121 |
-
- numpy-base=1.17.0=py37hde5b4d6_0
|
122 |
-
- openssl=1.0.2u=h7b6447c_0
|
123 |
-
- parso=0.8.3=pyhd3eb1b0_0
|
124 |
-
- pexpect=4.8.0=pyhd3eb1b0_3
|
125 |
-
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
126 |
-
- pip=22.3.1=py37h06a4308_0
|
127 |
-
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
|
128 |
-
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
129 |
-
- pycparser=2.21=pyhd3eb1b0_0
|
130 |
-
- pygments=2.11.2=pyhd3eb1b0_0
|
131 |
-
- python=3.7.0=h6e4f718_3
|
132 |
-
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
133 |
-
- pytorch=1.5.0=py3.7_cuda10.2.89_cudnn7.6.5_0
|
134 |
-
- pyzmq=23.2.0=py37h6a678d5_0
|
135 |
-
- readline=7.0=h7b6447c_5
|
136 |
-
- setuptools=65.6.3=py37h06a4308_0
|
137 |
-
- six=1.16.0=pyhd3eb1b0_1
|
138 |
-
- sqlite=3.33.0=h62c20be_0
|
139 |
-
- tk=8.6.12=h1ccaba5_0
|
140 |
-
- torchvision=0.6.0=py37_cu102
|
141 |
-
- tornado=6.1=py37h27cfd23_0
|
142 |
-
- traitlets=5.1.1=pyhd3eb1b0_0
|
143 |
-
- wcwidth=0.2.5=pyhd3eb1b0_0
|
144 |
-
- wheel=0.37.1=pyhd3eb1b0_0
|
145 |
-
- xz=5.2.8=h5eee18b_0
|
146 |
-
- zeromq=4.3.4=h2531618_0
|
147 |
-
- zlib=1.2.13=h5eee18b_0
|
148 |
-
- zstd=1.5.2=ha4553b6_0
|
149 |
-
- pip:
|
150 |
-
- cached-property==1.5.2
|
151 |
-
- cycler==0.11.0
|
152 |
-
- h5py==3.1.0
|
153 |
-
- kiwisolver==1.4.4
|
154 |
-
- matplotlib==3.3.3
|
155 |
-
- numpy==1.19.0
|
156 |
-
- opencv-python==4.4.0.46
|
157 |
-
- pillow==8.0.1
|
158 |
-
- pyparsing==3.0.9
|
159 |
-
- scipy==1.5.4
|
160 |
-
- typing-extensions==4.4.0
|
161 |
-
prefix: /home/leuschnm/miniconda3/envs/SASNet
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|