Maksym-Lysyi committed
Commit e3641b1 · 1 Parent(s): 3964794

initial commit

This view is limited to 50 files because it contains too many changes.

Files changed (50):
  1. .dockerignore +8 -0
  2. .gitignore +5 -0
  3. Dockerfile +19 -0
  4. app.py +117 -0
  5. config.py +373 -0
  6. easy_ViTPose/__init__.py +5 -0
  7. easy_ViTPose/config.yaml +14 -0
  8. easy_ViTPose/configs/ViTPose_aic.py +20 -0
  9. easy_ViTPose/configs/ViTPose_ap10k.py +22 -0
  10. easy_ViTPose/configs/ViTPose_apt36k.py +22 -0
  11. easy_ViTPose/configs/ViTPose_coco.py +18 -0
  12. easy_ViTPose/configs/ViTPose_coco_25.py +20 -0
  13. easy_ViTPose/configs/ViTPose_common.py +195 -0
  14. easy_ViTPose/configs/ViTPose_mpii.py +18 -0
  15. easy_ViTPose/configs/ViTPose_wholebody.py +20 -0
  16. easy_ViTPose/configs/__init__.py +0 -0
  17. easy_ViTPose/datasets/COCO.py +556 -0
  18. easy_ViTPose/datasets/HumanPoseEstimation.py +17 -0
  19. easy_ViTPose/datasets/__init__.py +0 -0
  20. easy_ViTPose/easy_ViTPose.egg-info/PKG-INFO +4 -0
  21. easy_ViTPose/easy_ViTPose.egg-info/SOURCES.txt +35 -0
  22. easy_ViTPose/easy_ViTPose.egg-info/dependency_links.txt +1 -0
  23. easy_ViTPose/easy_ViTPose.egg-info/top_level.txt +2 -0
  24. easy_ViTPose/inference.py +334 -0
  25. easy_ViTPose/sort.py +266 -0
  26. easy_ViTPose/to_onnx.ipynb +0 -0
  27. easy_ViTPose/to_trt.ipynb +0 -0
  28. easy_ViTPose/train.py +174 -0
  29. easy_ViTPose/vit_models/__init__.py +8 -0
  30. easy_ViTPose/vit_models/backbone/__init__.py +0 -0
  31. easy_ViTPose/vit_models/backbone/vit.py +394 -0
  32. easy_ViTPose/vit_models/head/__init__.py +0 -0
  33. easy_ViTPose/vit_models/head/topdown_heatmap_base_head.py +120 -0
  34. easy_ViTPose/vit_models/head/topdown_heatmap_simple_head.py +334 -0
  35. easy_ViTPose/vit_models/losses/__init__.py +16 -0
  36. easy_ViTPose/vit_models/losses/classfication_loss.py +41 -0
  37. easy_ViTPose/vit_models/losses/heatmap_loss.py +83 -0
  38. easy_ViTPose/vit_models/losses/mesh_loss.py +402 -0
  39. easy_ViTPose/vit_models/losses/mse_loss.py +151 -0
  40. easy_ViTPose/vit_models/losses/multi_loss_factory.py +279 -0
  41. easy_ViTPose/vit_models/losses/regression_loss.py +444 -0
  42. easy_ViTPose/vit_models/model.py +24 -0
  43. easy_ViTPose/vit_models/optimizer.py +15 -0
  44. easy_ViTPose/vit_utils/__init__.py +6 -0
  45. easy_ViTPose/vit_utils/dist_util.py +212 -0
  46. easy_ViTPose/vit_utils/inference.py +93 -0
  47. easy_ViTPose/vit_utils/logging.py +133 -0
  48. easy_ViTPose/vit_utils/nms/__init__.py +0 -0
  49. easy_ViTPose/vit_utils/nms/cpu_nms.c +0 -0
  50. easy_ViTPose/vit_utils/nms/cpu_nms.cpython-37m-x86_64-linux-gnu.so +0 -0
.dockerignore ADDED
@@ -0,0 +1,8 @@
+ __pycache__
+ pose_env_1
+ testing
+ vit_env
+ vit_test
+ test_vit_model.ipynb
+ models
+ models_2
.gitignore ADDED
@@ -0,0 +1,5 @@
+ __pycache__
+ pose_env_1
+ testing
+ vit_env
+ test_vit_model.ipynb
Dockerfile ADDED
@@ -0,0 +1,19 @@
+ FROM python:3.10
+
+ WORKDIR /app
+
+ COPY requirements.txt .
+
+ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
+ RUN pip install --upgrade pip
+
+ # --no-cache-dir
+ RUN pip install -r requirements.txt
+
+ COPY . .
+
+ EXPOSE 7860
+
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
+ ENV USE_NNPACK=0
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,117 @@
1
+ import gradio as gr
2
+ from main_func import video_identity
3
+
4
+ with gr.Blocks() as demo:
5
+
6
+ with gr.Row(variant='compact'):
7
+
8
+ with gr.Column():
9
+ gr.Markdown("#### Dynamic Time Warping:")
10
+
11
+ with gr.Row(variant='compact'):
12
+ dtw_mean = gr.Slider(
13
+ value=0.5,
14
+ minimum=0,
15
+ maximum=1.0,
16
+ step=0.05,
17
+ label="Winsorize Mean"
18
+ )
19
+
20
+ dtw_filter = gr.Slider(
21
+ value=3,
22
+ minimum=1,
23
+ maximum=20,
24
+ step=1,
25
+ label="Savitzky-Golay Filter"
26
+ )
27
+
28
+ gr.Markdown("#### Thresholds:")
29
+
30
+ with gr.Row(variant='compact'):
31
+ angles_sensitive = gr.Number(
32
+ value=15,
33
+ minimum=0,
34
+ maximum=75,
35
+ step=1,
36
+ min_width=100,
37
+ label="Sensitive"
38
+ )
39
+
40
+ angles_common = gr.Number(
41
+ value=25,
42
+ minimum=0,
43
+ maximum=75,
44
+ step=1,
45
+ min_width=100,
46
+ label="Standart"
47
+ )
48
+
49
+ angles_insensitive = gr.Number(
50
+ value=45,
51
+ minimum=0,
52
+ maximum=75,
53
+ step=1,
54
+ min_width=100,
55
+ label="Insensitive"
56
+ )
57
+
58
+ gr.Markdown("#### Patience:")
59
+
60
+ trigger_state = gr.Radio(value="three", choices=["three", "two"], label="Trigger Count")
61
+
62
+ input_teacher = gr.Video(show_share_button=False, show_download_button=False, sources=["upload"], label="Teacher's Video")
63
+ input_student = gr.Video(show_share_button=False, show_download_button=False, sources=["upload"], label="Student's Video")
64
+
65
+
66
+ with gr.Accordion("Clarifications:", open=True):
67
+ with gr.Accordion("Dynamic Time Warping:", open=False):
68
+ gr.Markdown("""
69
+ Dynamic Time Warping is an algorithm that aligns the teacher's and student's videos frame by frame, even when they are performed at different speeds.
70
+
71
+ - **Winsorized mean**: Determines the portion of DTW paths, sorted from best to worst, to use for generating the mean DTW alignment. Reasonable values range from 0.25 to 0.6.
72
+ - **Savitzky-Golay Filter**: Smooths the winsorized-mean DTW alignment, bringing it closer to a straight line. Reasonable values range from 2 to 10.
73
+ """)
74
+
75
+ with gr.Accordion("Thresholds:", open=False):
76
+ gr.Markdown("""
77
+ Thresholds are used to identify student errors in dance. If the difference in angle between the teacher's and student's videos exceeds this threshold, it is counted as an error.
78
+
79
+ - **Sensitive**: A threshold that is currently not used.
80
+ - **Standard**: A threshold for most angles. Reasonable values range from 20 to 40.
81
+ - **Insensitive**: A threshold for difficult areas, such as hands and toes. Reasonable values range from 35 to 55.
82
+ """)
83
+
84
+ with gr.Accordion("Patience:", open=False):
85
+ gr.Markdown("""
86
+ Patience suppresses spurious model detections by highlighting only errors that are detected in consecutive frames.
87
+
88
+ - **Three**: Utilizes 3 consecutive frames for error detection.
89
+ - **Two**: Utilizes 2 consecutive frames for error detection.
90
+
91
+ Both options can be used interchangeably.
92
+ """)
93
+
94
+
95
+
96
+ with gr.Row():
97
+ gr_button = gr.Button("Run Pose Comparison")
98
+
99
+ with gr.Row():
100
+ gr.HTML("<div style='height: 100px;'></div>")
101
+
102
+
103
+ with gr.Row():
104
+ output_merged = gr.Video(show_download_button=True)
105
+
106
+ with gr.Row():
107
+ general_log = gr.TextArea(lines=10, max_lines=9999, label="Error log")
108
+
109
+ gr_button.click(
110
+ fn=video_identity,
111
+ inputs=[dtw_mean, dtw_filter, angles_sensitive, angles_common, angles_insensitive, trigger_state, input_teacher, input_student],
112
+ outputs=[output_merged, general_log]
113
+ )
114
+
115
+
116
+ if __name__ == "__main__":
117
+ demo.launch()
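The accordion text above only describes the DTW controls informally. Below is a minimal, self-contained sketch of the idea as described: compute a DTW path per angle series, keep the best fraction of paths (the "Winsorize Mean" slider), average them into a single teacher-to-student frame mapping, and smooth that mapping with a Savitzky-Golay filter (the second slider). This is not the code in main_func (which is not part of this 50-file view); the function names and the per-series formulation are assumptions.

import numpy as np
from scipy.signal import savgol_filter


def dtw_path(a, b):
    """Classic dynamic-programming DTW for two 1-D series; returns (path, total cost)."""
    n, m = len(a), len(b)
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = abs(a[i - 1] - b[j - 1])
            cost[i, j] = d + min(cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1])
    path, i, j = [], n, m
    while i > 0 and j > 0:  # backtrack from the bottom-right corner
        path.append((i - 1, j - 1))
        step = int(np.argmin([cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1]]))
        if step == 0:
            i, j = i - 1, j - 1
        elif step == 1:
            i -= 1
        else:
            j -= 1
    return path[::-1], cost[n, m]


def mean_alignment(teacher_series, student_series, winsorize_mean=0.5, savgol_filter_size=3):
    """Winsorized-mean DTW alignment: average the cheapest fraction of per-angle paths, then smooth."""
    results = sorted((dtw_path(t, s) for t, s in zip(teacher_series, student_series)),
                     key=lambda r: r[1])                      # best (cheapest) paths first
    kept = results[:max(1, int(len(results) * winsorize_mean))]

    n_teacher = len(teacher_series[0])
    mapped = np.zeros(n_teacher)
    counts = np.zeros(n_teacher)
    for path, _ in kept:                                      # average student index per teacher frame
        for i, j in path:
            mapped[i] += j
            counts[i] += 1
    mapped /= np.maximum(counts, 1)

    window = 2 * savgol_filter_size + 1                       # savgol_filter needs an odd window
    if len(mapped) > window:
        mapped = savgol_filter(mapped, window_length=window, polyorder=2)
    return mapped                                             # student frame index for each teacher frame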
config.py ADDED
@@ -0,0 +1,373 @@
1
+ CONNECTIONS_VIT_FULL = [
2
+ # head
3
+ (0, 2),
4
+ (0, 1),
5
+ (2, 4),
6
+ (1, 3),
7
+ (0, 6),
8
+ (0, 5),
9
+
10
+ # right arm
11
+ (6, 8),
12
+ (8, 10),
13
+
14
+ # right hand
15
+ (10, 112),
16
+
17
+ # thumb
18
+ (112, 113),
19
+ (113, 114),
20
+ (114, 115),
21
+ (115, 116),
22
+
23
+ # index finger
24
+ (112, 117),
25
+ (117, 118),
26
+ (118, 119),
27
+ (119, 120),
28
+
29
+ # middle finger
30
+ (112, 121),
31
+ (121, 122),
32
+ (122, 123),
33
+ (123, 124),
34
+
35
+ # ring finger
36
+ (112, 125),
37
+ (125, 126),
38
+ (126, 127),
39
+ (127, 128),
40
+
41
+ # pinky finger
42
+ (112, 129),
43
+ (129, 130),
44
+ (130, 131),
45
+ (131, 132),
46
+
47
+
48
+
49
+ # left arm
50
+ (5, 7),
51
+ (7, 9),
52
+
53
+ # left hand
54
+ (9, 91),
55
+
56
+
57
+ # thumb
58
+ (91, 92),
59
+ (92, 93),
60
+ (93, 94),
61
+ (94, 95),
62
+
63
+ # index finger
64
+ (91, 96),
65
+ (96, 97),
66
+ (97, 98),
67
+ (98, 99),
68
+
69
+ # middle finger
70
+ (91, 100),
71
+ (100, 101),
72
+ (101, 102),
73
+ (102, 103),
74
+
75
+ # ring finger
76
+ (91, 104),
77
+ (104, 105),
78
+ (105, 106),
79
+ (106, 107),
80
+
81
+ # pinky finger
82
+ (91, 108),
83
+ (108, 109),
84
+ (109, 110),
85
+ (110, 111),
86
+
87
+
88
+
89
+ # torso
90
+ (6, 5),
91
+ (12, 11),
92
+ (6, 12),
93
+ (5, 11),
94
+
95
+ # right leg
96
+ (12, 14),
97
+ (14, 16),
98
+
99
+ # right foot
100
+ (16, 22),
101
+ (22, 21),
102
+ (22, 20),
103
+
104
+
105
+ # left leg
106
+ (11, 13),
107
+ (13, 15),
108
+
109
+ # left foot
110
+ (15, 19),
111
+ (19, 18),
112
+ (19, 17),
113
+ ]
114
+
115
+ EDGE_GROUPS_FOR_ERRORS = [
116
+ [0, 2, 4],
117
+ [0, 1, 3],
118
+
119
+ # neck
120
+ [6, 0, 2],
121
+ [5, 0, 1],
122
+
123
+ # right arm
124
+
125
+ # right shoulder
126
+ [5, 6, 8],
127
+
128
+ # right elbow
129
+ [6, 8, 10],
130
+
131
+ # right hand
132
+ [8, 10, 121],
133
+
134
+ [112, 114, 116],
135
+ [112, 117, 120],
136
+ [112, 121, 124],
137
+ [112, 125, 128],
138
+ [112, 129, 132],
139
+
140
+ # left arm
141
+
142
+ # left shoulder
143
+ [6, 5, 7],
144
+
145
+ # left elbow
146
+ [5, 7, 9],
147
+
148
+ # left hand
149
+ [7, 9, 100],
150
+
151
+ [91, 93, 95],
152
+ [91, 96, 99],
153
+ [91, 100, 103],
154
+ [91, 104, 107],
155
+ [91, 108, 111],
156
+
157
+
158
+ # right leg
159
+
160
+ # right upper-leg
161
+ [6, 12, 14],
162
+
163
+ # right middle-leg
164
+ [12, 14, 16],
165
+
166
+ # right lower-leg
167
+ [14, 16, 22],
168
+ [16, 22, 21],
169
+ [16, 22, 20],
170
+
171
+ # left leg
172
+
173
+ # left upper-leg
174
+ [5, 11, 13],
175
+
176
+ # left middle-leg
177
+ [11, 13, 15],
178
+
179
+ # left lower-leg
180
+ [13, 15, 19],
181
+ [15, 19, 17],
182
+ [15, 19, 18],
183
+
184
+ ]
185
+
186
+
187
+
188
+ CONNECTIONS_FOR_ERROR = [
189
+ # head
190
+ (0, 2),
191
+ (2, 4),
192
+ (0, 1),
193
+ (1, 3),
194
+
195
+ # right arm
196
+ (6, 0),
197
+ (8, 6),
198
+ (10, 8),
199
+
200
+ # right hand
201
+ # (121, 10),
202
+
203
+ (112, 114),
204
+ (114, 116),
205
+
206
+ (112, 117),
207
+ (117, 120),
208
+
209
+ (112, 121),
210
+ (121, 124),
211
+
212
+ (112, 125),
213
+ (125, 128),
214
+
215
+ (112, 129),
216
+ (129, 132),
217
+
218
+ # left arm
219
+ (5, 0),
220
+ (7, 5),
221
+ (9, 7),
222
+
223
+ # left hand
224
+ # (100, 9),
225
+
226
+ (91, 93),
227
+ (93, 95),
228
+
229
+ (91, 96),
230
+ (96, 99),
231
+
232
+ (91, 100),
233
+ (100, 103),
234
+
235
+ (91, 104),
236
+ (104, 107),
237
+
238
+ (91, 108),
239
+ (108, 111),
240
+
241
+ # torso
242
+ (6, 12),
243
+ (5, 11),
244
+
245
+ # right leg
246
+ (12, 14),
247
+ (14, 16),
248
+
249
+ (16, 22),
250
+ (22, 21),
251
+ (22, 20),
252
+
253
+ # left leg
254
+ (11, 13),
255
+ (13, 15),
256
+
257
+ (15, 19),
258
+ (19, 17),
259
+ (19, 18),
260
+
261
+ ]
262
+
263
+ def get_thresholds(sensetive_error, general_error, unsensetive_error):
264
+ thresholds = [
265
+ general_error,
266
+ general_error,
267
+ general_error,
268
+ general_error,
269
+
270
+ general_error,
271
+ general_error,
272
+
273
+ unsensetive_error,
274
+ unsensetive_error,
275
+ unsensetive_error,
276
+ unsensetive_error,
277
+ unsensetive_error,
278
+ unsensetive_error,
279
+
280
+ general_error,
281
+ general_error,
282
+ unsensetive_error,
283
+ unsensetive_error,
284
+ unsensetive_error,
285
+ unsensetive_error,
286
+ unsensetive_error,
287
+ unsensetive_error,
288
+
289
+ general_error,
290
+ general_error,
291
+ unsensetive_error,
292
+ unsensetive_error,
293
+ unsensetive_error,
294
+
295
+ general_error,
296
+ general_error,
297
+ unsensetive_error,
298
+ unsensetive_error,
299
+ unsensetive_error,
300
+ ]
301
+
302
+ return thresholds
303
+
304
+
305
+ EDGE_GROUPS_FOR_SUMMARY = {
306
+ (2, 4): "Head position is incorrect",
307
+ (1, 3): "Head position is incorrect",
308
+
309
+ # neck
310
+
311
+ (0, 2): "Head position is incorrect",
312
+ (0, 1): "Head position is incorrect",
313
+
314
+ # right arm
315
+
316
+ # right shoulder
317
+ (6, 8): "Right shoulder position is incorrect",
318
+
319
+ # right elbow
320
+ (8, 10): "Right elbow position is incorrect",
321
+
322
+ # right hand
323
+ (10, 121): "Right hand's palm position is incorrect",
324
+
325
+ (114, 116): "Right thumb finger position is incorrect",
326
+ (117, 120): "Right index finger position is incorrect",
327
+ (121, 124): "Right middle finger position is incorrect",
328
+ (125, 128): "Right ring finger position is incorrect",
329
+ (129, 132): "Right pinky finger position is incorrect",
330
+
331
+ # left arm
332
+
333
+ # left shoulder
334
+ (5, 7): "Left shoulder position is incorrect",
335
+
336
+ # left elbow
337
+ (7, 9): "Left elbow position is incorrect",
338
+
339
+ # left hand
340
+ (9, 100): "Left hand palm position is incorrect",
341
+
342
+ (93, 95): "Left thumb finger position is incorrect",
343
+ (96, 99): "Left index finger position is incorrect",
344
+ (100, 103): "Left middle finger position is incorrect",
345
+ (104, 107): "Left ring finger position is incorrect",
346
+ (108, 111): "Left pinky finger position is incorrect",
347
+
348
+ # right leg
349
+
350
+ # right upper-leg
351
+ (12, 14): "Right thigh position is incorrect",
352
+
353
+ # right middle-leg
354
+ (14, 16): "Right shin position is incorrect",
355
+
356
+ # right lower-leg
357
+ (16, 22): "Right foot position is incorrect",
358
+ (22, 21): "Right shin position is incorrect",
359
+ (22, 20): "Right shin position is incorrect",
360
+
361
+ # left leg
362
+
363
+ # left upper-leg
364
+ (11, 13): "Left thigh position is incorrect",
365
+
366
+ # left middle-leg
367
+ (13, 15): "Left shin position is incorrect",
368
+
369
+ # left lower-leg
370
+ (15, 19): "Left foot position is incorrect",
371
+ (19, 17): "Left shin position is incorrect",
372
+ (19, 18): "Left shin position is incorrect"
373
+ }
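For reference, the 30 keypoint triples in EDGE_GROUPS_FOR_ERRORS and the 30 values returned by get_thresholds appear to line up one-to-one, and the "Patience" option in app.py keeps only errors seen in consecutive frames. The comparison code itself lives in main_func, which is not included in this view, so the helpers below (angle_at_joint, frame_errors, persistent_errors) are an illustrative sketch, not the repository's implementation.

import numpy as np


def angle_at_joint(p_a, p_b, p_c):
    """Angle in degrees at keypoint b, formed by the segments b->a and b->c."""
    v1 = np.asarray(p_a, dtype=float) - np.asarray(p_b, dtype=float)
    v2 = np.asarray(p_c, dtype=float) - np.asarray(p_b, dtype=float)
    cos = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-8)
    return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))))


def frame_errors(teacher_kpts, student_kpts, thresholds, edge_groups=EDGE_GROUPS_FOR_ERRORS):
    """Indices of angle groups whose teacher/student difference exceeds the matching threshold."""
    errors = []
    for idx, ((a, b, c), thr) in enumerate(zip(edge_groups, thresholds)):
        diff = abs(angle_at_joint(teacher_kpts[a], teacher_kpts[b], teacher_kpts[c])
                   - angle_at_joint(student_kpts[a], student_kpts[b], student_kpts[c]))
        if diff > thr:
            errors.append(idx)
    return errors


def persistent_errors(per_frame_errors, patience=3):
    """Keep only errors reported in `patience` consecutive frames ("Patience" in app.py)."""
    confirmed = []
    for t in range(patience - 1, len(per_frame_errors)):
        confirmed.append(set.intersection(*(set(per_frame_errors[t - k]) for k in range(patience))))
    return confirmed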
easy_ViTPose/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .inference import VitInference
+
+ __all__ = [
+     'VitInference'
+ ]
easy_ViTPose/config.yaml ADDED
@@ -0,0 +1,14 @@
+ # Train config ---------------------------------------
+ log_level: logging.INFO
+ seed: 0
+ deterministic: True
+ cudnn_benchmark: True # Use cudnn
+ resume_from: "ckpts/og-vitpose-s.pth" # CKPT path
+ # resume_from: False
+ gpu_ids: [0]
+ launcher: 'none' # When distributed training ['none', 'pytorch', 'slurm', 'mpi']
+ use_amp: True
+ validate: True
+ autoscale_lr: False
+ dist_params:
+ ...
easy_ViTPose/configs/ViTPose_aic.py ADDED
@@ -0,0 +1,20 @@
+ from .ViTPose_common import *
+
+ # Channel configuration
+ channel_cfg = dict(
+     num_output_channels=14,
+     dataset_joints=14,
+     dataset_channel=[
+         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
+     ],
+     inference_channel=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
+
+ # Set models channels
+ data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+ data_cfg['num_joints'] = channel_cfg['dataset_joints']
+ data_cfg['dataset_channel'] = channel_cfg['dataset_channel']
+ data_cfg['inference_channel'] = channel_cfg['inference_channel']
+
+ names = ['small', 'base', 'large', 'huge']
+ for name in names:
+     globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_ap10k.py ADDED
@@ -0,0 +1,22 @@
+ from .ViTPose_common import *
+
+ # Channel configuration
+ channel_cfg = dict(
+     num_output_channels=17,
+     dataset_joints=17,
+     dataset_channel=[
+         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+     ],
+     inference_channel=[
+         0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+     ])
+
+ # Set models channels
+ data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+ data_cfg['num_joints'] = channel_cfg['dataset_joints']
+ data_cfg['dataset_channel'] = channel_cfg['dataset_channel']
+ data_cfg['inference_channel'] = channel_cfg['inference_channel']
+
+ names = ['small', 'base', 'large', 'huge']
+ for name in names:
+     globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_apt36k.py ADDED
@@ -0,0 +1,22 @@
+ from .ViTPose_common import *
+
+ # Channel configuration
+ channel_cfg = dict(
+     num_output_channels=17,
+     dataset_joints=17,
+     dataset_channel=[
+         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+     ],
+     inference_channel=[
+         0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+     ])
+
+ # Set models channels
+ data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+ data_cfg['num_joints'] = channel_cfg['dataset_joints']
+ data_cfg['dataset_channel'] = channel_cfg['dataset_channel']
+ data_cfg['inference_channel'] = channel_cfg['inference_channel']
+
+ names = ['small', 'base', 'large', 'huge']
+ for name in names:
+     globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_coco.py ADDED
@@ -0,0 +1,18 @@
+ from .ViTPose_common import *
+
+ # Channel configuration
+ channel_cfg = dict(
+     num_output_channels=17,
+     dataset_joints=17,
+     dataset_channel=list(range(17)),
+     inference_channel=list(range(17)))
+
+ # Set models channels
+ data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+ data_cfg['num_joints'] = channel_cfg['dataset_joints']
+ data_cfg['dataset_channel'] = channel_cfg['dataset_channel']
+ data_cfg['inference_channel'] = channel_cfg['inference_channel']
+
+ names = ['small', 'base', 'large', 'huge']
+ for name in names:
+     globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_coco_25.py ADDED
@@ -0,0 +1,20 @@
+ from .ViTPose_common import *
+
+ # Channel configuration
+ channel_cfg = dict(
+     num_output_channels=25,
+     dataset_joints=25,
+     dataset_channel=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                       16, 17, 18, 19, 20, 21, 22, 23, 24], ],
+     inference_channel=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                        16, 17, 18, 19, 20, 21, 22, 23, 24])
+
+ # Set models channels
+ data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+ data_cfg['num_joints'] = channel_cfg['dataset_joints']
+ data_cfg['dataset_channel'] = channel_cfg['dataset_channel']
+ data_cfg['inference_channel'] = channel_cfg['inference_channel']
+
+ names = ['small', 'base', 'large', 'huge']
+ for name in names:
+     globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_common.py ADDED
@@ -0,0 +1,195 @@
1
+ # Common configuration
2
+ optimizer = dict(type='AdamW', lr=1e-3, betas=(0.9, 0.999), weight_decay=0.1,
3
+ constructor='LayerDecayOptimizerConstructor',
4
+ paramwise_cfg=dict(
5
+ num_layers=12,
6
+ layer_decay_rate=1 - 2e-4,
7
+ custom_keys={
8
+ 'bias': dict(decay_multi=0.),
9
+ 'pos_embed': dict(decay_mult=0.),
10
+ 'relative_position_bias_table': dict(decay_mult=0.),
11
+ 'norm': dict(decay_mult=0.)
12
+ }
13
+ )
14
+ )
15
+
16
+ optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
17
+
18
+ # learning policy
19
+ lr_config = dict(
20
+ policy='step',
21
+ warmup='linear',
22
+ warmup_iters=300,
23
+ warmup_ratio=0.001,
24
+ step=[3])
25
+
26
+ total_epochs = 4
27
+ target_type = 'GaussianHeatmap'
28
+
29
+ data_cfg = dict(
30
+ image_size=[192, 256],
31
+ heatmap_size=[48, 64],
32
+ soft_nms=False,
33
+ nms_thr=1.0,
34
+ oks_thr=0.9,
35
+ vis_thr=0.2,
36
+ use_gt_bbox=False,
37
+ det_bbox_thr=0.0,
38
+ bbox_file='data/coco/person_detection_results/'
39
+ 'COCO_val2017_detections_AP_H_56_person.json',
40
+ )
41
+
42
+ data_root = '/home/adryw/dataset/COCO17'
43
+ data = dict(
44
+ samples_per_gpu=64,
45
+ workers_per_gpu=6,
46
+ val_dataloader=dict(samples_per_gpu=128),
47
+ test_dataloader=dict(samples_per_gpu=128),
48
+ train=dict(
49
+ type='TopDownCocoDataset',
50
+ ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
51
+ img_prefix=f'{data_root}/train2017/',
52
+ data_cfg=data_cfg),
53
+ val=dict(
54
+ type='TopDownCocoDataset',
55
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
56
+ img_prefix=f'{data_root}/val2017/',
57
+ data_cfg=data_cfg),
58
+ test=dict(
59
+ type='TopDownCocoDataset',
60
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
61
+ img_prefix=f'{data_root}/val2017/',
62
+ data_cfg=data_cfg)
63
+ )
64
+
65
+ model_small = dict(
66
+ type='TopDown',
67
+ pretrained=None,
68
+ backbone=dict(
69
+ type='ViT',
70
+ img_size=(256, 192),
71
+ patch_size=16,
72
+ embed_dim=384,
73
+ depth=12,
74
+ num_heads=12,
75
+ ratio=1,
76
+ use_checkpoint=False,
77
+ mlp_ratio=4,
78
+ qkv_bias=True,
79
+ drop_path_rate=0.1,
80
+ ),
81
+ keypoint_head=dict(
82
+ type='TopdownHeatmapSimpleHead',
83
+ in_channels=384,
84
+ num_deconv_layers=2,
85
+ num_deconv_filters=(256, 256),
86
+ num_deconv_kernels=(4, 4),
87
+ extra=dict(final_conv_kernel=1, ),
88
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
89
+ train_cfg=dict(),
90
+ test_cfg=dict(
91
+ flip_test=True,
92
+ post_process='default',
93
+ shift_heatmap=False,
94
+ target_type=target_type,
95
+ modulate_kernel=11,
96
+ use_udp=True))
97
+
98
+ model_base = dict(
99
+ type='TopDown',
100
+ pretrained=None,
101
+ backbone=dict(
102
+ type='ViT',
103
+ img_size=(256, 192),
104
+ patch_size=16,
105
+ embed_dim=768,
106
+ depth=12,
107
+ num_heads=12,
108
+ ratio=1,
109
+ use_checkpoint=False,
110
+ mlp_ratio=4,
111
+ qkv_bias=True,
112
+ drop_path_rate=0.3,
113
+ ),
114
+ keypoint_head=dict(
115
+ type='TopdownHeatmapSimpleHead',
116
+ in_channels=768,
117
+ num_deconv_layers=2,
118
+ num_deconv_filters=(256, 256),
119
+ num_deconv_kernels=(4, 4),
120
+ extra=dict(final_conv_kernel=1, ),
121
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
122
+ train_cfg=dict(),
123
+ test_cfg=dict(
124
+ flip_test=True,
125
+ post_process='default',
126
+ shift_heatmap=False,
127
+ target_type=target_type,
128
+ modulate_kernel=11,
129
+ use_udp=True))
130
+
131
+ model_large = dict(
132
+ type='TopDown',
133
+ pretrained=None,
134
+ backbone=dict(
135
+ type='ViT',
136
+ img_size=(256, 192),
137
+ patch_size=16,
138
+ embed_dim=1024,
139
+ depth=24,
140
+ num_heads=16,
141
+ ratio=1,
142
+ use_checkpoint=False,
143
+ mlp_ratio=4,
144
+ qkv_bias=True,
145
+ drop_path_rate=0.5,
146
+ ),
147
+ keypoint_head=dict(
148
+ type='TopdownHeatmapSimpleHead',
149
+ in_channels=1024,
150
+ num_deconv_layers=2,
151
+ num_deconv_filters=(256, 256),
152
+ num_deconv_kernels=(4, 4),
153
+ extra=dict(final_conv_kernel=1, ),
154
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
155
+ train_cfg=dict(),
156
+ test_cfg=dict(
157
+ flip_test=True,
158
+ post_process='default',
159
+ shift_heatmap=False,
160
+ target_type=target_type,
161
+ modulate_kernel=11,
162
+ use_udp=True))
163
+
164
+ model_huge = dict(
165
+ type='TopDown',
166
+ pretrained=None,
167
+ backbone=dict(
168
+ type='ViT',
169
+ img_size=(256, 192),
170
+ patch_size=16,
171
+ embed_dim=1280,
172
+ depth=32,
173
+ num_heads=16,
174
+ ratio=1,
175
+ use_checkpoint=False,
176
+ mlp_ratio=4,
177
+ qkv_bias=True,
178
+ drop_path_rate=0.55,
179
+ ),
180
+ keypoint_head=dict(
181
+ type='TopdownHeatmapSimpleHead',
182
+ in_channels=1280,
183
+ num_deconv_layers=2,
184
+ num_deconv_filters=(256, 256),
185
+ num_deconv_kernels=(4, 4),
186
+ extra=dict(final_conv_kernel=1, ),
187
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
188
+ train_cfg=dict(),
189
+ test_cfg=dict(
190
+ flip_test=True,
191
+ post_process='default',
192
+ shift_heatmap=False,
193
+ target_type=target_type,
194
+ modulate_kernel=11,
195
+ use_udp=True))
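Each dataset-specific config below star-imports this module and then patches model_small / model_base / model_large / model_huge in place, so a caller can pick an architecture dict by dataset and size. The helper dyn_model_import used by inference.py is not shown in this 50-file view; the snippet below is only an assumed equivalent, included to illustrate the mechanism.

import importlib


def load_model_cfg(dataset, size):
    """Return e.g. model_small from easy_ViTPose.configs.ViTPose_coco for ('coco', 'small')."""
    module = importlib.import_module(f'easy_ViTPose.configs.ViTPose_{dataset}')
    return getattr(module, f'model_{size}')


# Example: the wholebody config overrides the keypoint head to 133 output channels.
# load_model_cfg('wholebody', 'small')['keypoint_head']['out_channels']  # -> 133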
easy_ViTPose/configs/ViTPose_mpii.py ADDED
@@ -0,0 +1,18 @@
+ from .ViTPose_common import *
+
+ # Channel configuration
+ channel_cfg = dict(
+     num_output_channels=16,
+     dataset_joints=16,
+     dataset_channel=list(range(16)),
+     inference_channel=list(range(16)))
+
+ # Set models channels
+ data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+ data_cfg['num_joints'] = channel_cfg['dataset_joints']
+ data_cfg['dataset_channel'] = channel_cfg['dataset_channel']
+ data_cfg['inference_channel'] = channel_cfg['inference_channel']
+
+ names = ['small', 'base', 'large', 'huge']
+ for name in names:
+     globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_wholebody.py ADDED
@@ -0,0 +1,20 @@
+ from .ViTPose_common import *
+
+ # Channel configuration
+ channel_cfg = dict(
+     num_output_channels=133,
+     dataset_joints=133,
+     dataset_channel=[
+         list(range(133)),
+     ],
+     inference_channel=list(range(133)))
+
+ # Set models channels
+ data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+ data_cfg['num_joints'] = channel_cfg['dataset_joints']
+ data_cfg['dataset_channel'] = channel_cfg['dataset_channel']
+ data_cfg['inference_channel'] = channel_cfg['inference_channel']
+
+ names = ['small', 'base', 'large', 'huge']
+ for name in names:
+     globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/__init__.py ADDED
File without changes
easy_ViTPose/datasets/COCO.py ADDED
@@ -0,0 +1,556 @@
1
+ # Part of this code is derived/taken from https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
2
+ import os
3
+ import sys
4
+ import pickle
5
+ import random
6
+
7
+ import cv2
8
+ import json_tricks as json
9
+ import numpy as np
10
+ from pycocotools.coco import COCO
11
+ from torchvision import transforms
12
+ import torchvision.transforms.functional as F
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+
16
+ from .HumanPoseEstimation import HumanPoseEstimationDataset as Dataset
17
+
18
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
19
+ from vit_utils.transform import fliplr_joints, affine_transform, get_affine_transform
20
+
21
+ import numpy as np
22
+
23
+
24
+ class COCODataset(Dataset):
25
+ """
26
+ COCODataset class.
27
+ """
28
+
29
+ def __init__(self, root_path="./datasets/COCO", data_version="train2017",
30
+ is_train=True, use_gt_bboxes=True, bbox_path="",
31
+ image_width=288, image_height=384,
32
+ scale=True, scale_factor=0.35, flip_prob=0.5, rotate_prob=0.5, rotation_factor=45., half_body_prob=0.3,
33
+ use_different_joints_weight=False, heatmap_sigma=3, soft_nms=False):
34
+ """
35
+ Initializes a new COCODataset object.
36
+
37
+ Image and annotation indexes are loaded and stored in memory.
38
+ Annotations are preprocessed to have a simple list of annotations to iterate over.
39
+
40
+ Bounding boxes can be loaded from the ground truth or from a pickle file (in this case, no annotations are
41
+ provided).
42
+
43
+ Args:
44
+ root_path (str): dataset root path.
45
+ Default: "./datasets/COCO"
46
+ data_version (str): desired version/folder of COCO. Possible options are "train2017", "val2017".
47
+ Default: "train2017"
48
+ is_train (bool): train or eval mode. If true, train mode is used.
49
+ Default: True
50
+ use_gt_bboxes (bool): use ground truth bounding boxes. If False, bbox_path is required.
51
+ Default: True
52
+ bbox_path (str): bounding boxes pickle file path.
53
+ Default: ""
54
+ image_width (int): image width.
55
+ Default: 288
56
+ image_height (int): image height.
57
+ Default: ``384``
58
+ color_rgb (bool): rgb or bgr color mode. If True, rgb color mode is used.
59
+ Default: True
60
+ scale (bool): scale mode.
61
+ Default: True
62
+ scale_factor (float): scale factor.
63
+ Default: 0.35
64
+ flip_prob (float): flip probability.
65
+ Default: 0.5
66
+ rotate_prob (float): rotate probability.
67
+ Default: 0.5
68
+ rotation_factor (float): rotation factor.
69
+ Default: 45.
70
+ half_body_prob (float): half body probability.
71
+ Default: 0.3
72
+ use_different_joints_weight (bool): use different joints weights.
73
+ If true, the following joints weights will be used:
74
+ [1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2, 1.2, 1.5, 1.5]
75
+ Default: False
76
+ heatmap_sigma (float): sigma of the gaussian used to create the heatmap.
77
+ Default: 3
78
+ soft_nms (bool): enable soft non-maximum suppression.
79
+ Default: False
80
+ """
81
+ super(COCODataset, self).__init__()
82
+
83
+ self.root_path = root_path
84
+ self.data_version = data_version
85
+ self.is_train = is_train
86
+ self.use_gt_bboxes = use_gt_bboxes
87
+ self.bbox_path = bbox_path
88
+ self.scale = scale # ToDo Check
89
+ self.scale_factor = scale_factor
90
+ self.flip_prob = flip_prob
91
+ self.rotate_prob = rotate_prob
92
+ self.rotation_factor = rotation_factor
93
+ self.half_body_prob = half_body_prob
94
+ self.use_different_joints_weight = use_different_joints_weight # ToDo Check
95
+ self.heatmap_sigma = heatmap_sigma
96
+ self.soft_nms = soft_nms
97
+
98
+ # Image & annotation path
99
+ self.data_path = f"{root_path}/{data_version}"
100
+ self.annotation_path = f"{root_path}/annotations/person_keypoints_{data_version}.json"
101
+
102
+ self.image_size = (image_width, image_height)
103
+ self.aspect_ratio = image_width * 1.0 / image_height
104
+
105
+ self.heatmap_size = (int(image_width / 4), int(image_height / 4))
106
+ self.heatmap_type = 'gaussian'
107
+ self.pixel_std = 200 # I don't understand the meaning of pixel_std (=200) in the original implementation
108
+
109
+ self.num_joints = 25
110
+ self.num_joints_half_body = 15
111
+
112
+ # eye, ear, shoulder, elbow, wrist, hip, knee, ankle
113
+ self.flip_pairs = [[1, 2], [3, 4], [6, 7], [8, 9], [10, 11], [12, 13],
114
+ [15, 16], [17, 18], [19, 22], [20, 23], [21, 24]]
115
+ self.upper_body_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
116
+ self.lower_body_ids = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
117
+ self.joints_weight = np.array([1., 1., 1., 1., 1., 1., 1., 1., 1.2, 1.2,
118
+ 1.5, 1.5, 1., 1., 1., 1.2, 1.2, 1.5, 1.5,
119
+ 1.5, 1.5, 1.5, 1.5, 1.5,
120
+ 1.5]).reshape((self.num_joints, 1))
121
+
122
+ self.transform = transforms.Compose([
123
+ transforms.ToTensor(),
124
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
125
+ ])
126
+
127
+ # Load COCO dataset - Create COCO object then load images and annotations
128
+ self.coco = COCO(self.annotation_path)
129
+
130
+ # Create a list of annotations and the corresponding image (each image can contain more than one detection)
131
+
132
+ """ Load bboxes and joints
133
+ - if self.use_gt_bboxes -> Load GT bboxes and joints
134
+ - else -> Load pre-predicted bboxes by a detector (as YOLOv3) and null joints
135
+ """
136
+
137
+ if not self.use_gt_bboxes:
138
+ """
139
+ bboxes must be saved as the original COCO annotations
140
+ i.e. the format must be:
141
+ bboxes = {
142
+ '<imgId>': [
143
+ {
144
+ 'id': <annId>, # progressive id for debugging
145
+ 'clean_bbox': np.array([<x>, <y>, <w>, <h>])}
146
+ },
147
+ ...
148
+ ],
149
+ ...
150
+ }
151
+ """
152
+ with open(self.bbox_path, 'rb') as fd:
153
+ bboxes = pickle.load(fd)
154
+
155
+ self.data = []
156
+ # load annotations for each image of COCO
157
+ for imgId in tqdm(self.coco.getImgIds(), desc="Prepare images, annotations ... "):
158
+ ann_ids = self.coco.getAnnIds(imgIds=imgId, iscrowd=False) # annotation ids
159
+ img = self.coco.loadImgs(imgId)[0] # load img
160
+
161
+ if self.use_gt_bboxes:
162
+ objs = self.coco.loadAnns(ann_ids)
163
+
164
+ # sanitize bboxes
165
+ valid_objs = []
166
+ for obj in objs:
167
+ # Skip non-person objects (it should never happen)
168
+ if obj['category_id'] != 1:
169
+ continue
170
+
171
+ # ignore objs without keypoints annotation
172
+ if max(obj['keypoints']) == 0 and max(obj['foot_kpts']) == 0:
173
+ continue
174
+
175
+ x, y, w, h = obj['bbox']
176
+ x1 = np.max((0, x))
177
+ y1 = np.max((0, y))
178
+ x2 = np.min((img['width'] - 1, x1 + np.max((0, w - 1))))
179
+ y2 = np.min((img['height'] - 1, y1 + np.max((0, h - 1))))
180
+
181
+ # Use only valid bounding boxes
182
+ if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
183
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
184
+ valid_objs.append(obj)
185
+
186
+ objs = valid_objs
187
+
188
+ else:
189
+ objs = bboxes[imgId]
190
+
191
+ # for each annotation of this image, add the formatted annotation to self.data
192
+ for obj in objs:
193
+ joints = np.zeros((self.num_joints, 2), dtype=np.float64)
194
+ joints_visibility = np.ones((self.num_joints, 2), dtype=np.float64)
195
+
196
+ # Add foot data to keypoints
197
+ obj['keypoints'].extend(obj['foot_kpts'])
198
+
199
+ if self.use_gt_bboxes:
200
+ """ COCO pre-processing
201
+
202
+ - Moved above
203
+ - Skip non-person objects (it should never happen)
204
+ if obj['category_id'] != 1:
205
+ continue
206
+
207
+ # ignore objs without keypoints annotation
208
+ if max(obj['keypoints']) == 0:
209
+ continue
210
+ """
211
+
212
+ # Not all joints are already present, skip them
213
+ vjoints = list(range(self.num_joints))
214
+ vjoints.remove(5)
215
+ vjoints.remove(14)
216
+
217
+ for idx, pt in enumerate(vjoints):
218
+ if pt == 5 or pt == 14:
219
+ continue # Neck and hip are manually filled
220
+ joints[pt, 0] = obj['keypoints'][idx * 3 + 0]
221
+ joints[pt, 1] = obj['keypoints'][idx * 3 + 1]
222
+ t_vis = int(np.clip(obj['keypoints'][idx * 3 + 2], 0, 1))
223
+ """
224
+ - COCO:
225
+ if visibility == 0 -> keypoint is not in the image.
226
+ if visibility == 1 -> keypoint is in the image BUT not visible
227
+ (e.g. behind an object).
228
+ if visibility == 2 -> keypoint looks clearly
229
+ (i.e. it is not hidden).
230
+ """
231
+ joints_visibility[pt, 0] = t_vis
232
+ joints_visibility[pt, 1] = t_vis
233
+
234
+ center, scale = self._box2cs(obj['clean_bbox'][:4])
235
+
236
+ # Add neck and c-hip (check utils/visualization.py for keypoints)
237
+ joints[5, 0] = (joints[6, 0] + joints[7, 0]) / 2
238
+ joints[5, 1] = (joints[6, 1] + joints[7, 1]) / 2
239
+ joints_visibility[5, :] = min(joints_visibility[6, 0],
240
+ joints_visibility[7, 0])
241
+ joints[14, 0] = (joints[12, 0] + joints[13, 0]) / 2
242
+ joints[14, 1] = (joints[12, 1] + joints[13, 1]) / 2
243
+ joints_visibility[14, :] = min(joints_visibility[12, 0],
244
+ joints_visibility[13, 0])
245
+
246
+ self.data.append({
247
+ 'imgId': imgId,
248
+ 'annId': obj['id'],
249
+ 'imgPath': f"{self.root_path}/{self.data_version}/{imgId:012d}.jpg",
250
+ 'center': center,
251
+ 'scale': scale,
252
+ 'joints': joints,
253
+ 'joints_visibility': joints_visibility,
254
+ })
255
+
256
+ # Done check if we need prepare_data -> We should not
257
+ print('\nCOCO dataset loaded!')
258
+
259
+ # Default values
260
+ self.bbox_thre = 1.0
261
+ self.image_thre = 0.0
262
+ self.in_vis_thre = 0.2
263
+ self.nms_thre = 1.0
264
+ self.oks_thre = 0.9
265
+
266
+ def __len__(self):
267
+ return len(self.data)
268
+
269
+ def __getitem__(self, index):
270
+ # index = 0
271
+ joints_data = self.data[index].copy()
272
+
273
+ # Load image
274
+ try:
275
+ image = np.array(Image.open(joints_data['imgPath']))
276
+ if image.ndim == 2:
277
+ # Some images are grayscale and will fail the transform, convert to RGB
278
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
279
+ except:
280
+ raise ValueError(f"Fail to read {joints_data['imgPath']}")
281
+
282
+ joints = joints_data['joints']
283
+ joints_vis = joints_data['joints_visibility']
284
+
285
+ c = joints_data['center']
286
+ s = joints_data['scale']
287
+ score = joints_data['score'] if 'score' in joints_data else 1
288
+ r = 0
289
+
290
+ # Apply data augmentation
291
+ if self.is_train:
292
+ if (self.half_body_prob and random.random() < self.half_body_prob and
293
+ np.sum(joints_vis[:, 0]) > self.num_joints_half_body):
294
+ c_half_body, s_half_body = self._half_body_transform(joints, joints_vis)
295
+
296
+ if c_half_body is not None and s_half_body is not None:
297
+ c, s = c_half_body, s_half_body
298
+
299
+ sf = self.scale_factor
300
+ rf = self.rotation_factor
301
+
302
+ if self.scale:
303
+ # A random scale factor in [1 - sf, 1 + sf]
304
+ s = s * np.clip(random.random() * sf + 1, 1 - sf, 1 + sf)
305
+
306
+ if self.rotate_prob and random.random() < self.rotate_prob:
307
+ # A random rotation factor in [-2 * rf, 2 * rf]
308
+ r = np.clip(random.random() * rf, -rf * 2, rf * 2)
309
+ else:
310
+ r = 0
311
+
312
+ if self.flip_prob and random.random() < self.flip_prob:
313
+ image = image[:, ::-1, :]
314
+ joints, joints_vis = fliplr_joints(joints, joints_vis,
315
+ image.shape[1],
316
+ self.flip_pairs)
317
+ c[0] = image.shape[1] - c[0] - 1
318
+
319
+ # Apply affine transform on joints and image
320
+ trans = get_affine_transform(c, s, self.pixel_std, r, self.image_size)
321
+ image = cv2.warpAffine(
322
+ image,
323
+ trans,
324
+ (int(self.image_size[0]), int(self.image_size[1])),
325
+ flags=cv2.INTER_LINEAR
326
+ )
327
+
328
+ for i in range(self.num_joints):
329
+ if joints_vis[i, 0] > 0.:
330
+ joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
331
+
332
+ # Convert image to tensor and normalize
333
+ if self.transform is not None: # I could remove this check
334
+ image = self.transform(image)
335
+
336
+ target, target_weight = self._generate_target(joints, joints_vis)
337
+
338
+ # Update metadata
339
+ joints_data['joints'] = joints
340
+ joints_data['joints_visibility'] = joints_vis
341
+ joints_data['center'] = c
342
+ joints_data['scale'] = s
343
+ joints_data['rotation'] = r
344
+ joints_data['score'] = score
345
+
346
+ # from utils.visualization import draw_points_and_skeleton, joints_dict
347
+ # image = np.rollaxis(image.detach().cpu().numpy(), 0, 3)
348
+ # joints = np.hstack((joints[:, ::-1], joints_vis[:, 0][..., None]))
349
+ # image = draw_points_and_skeleton(image.copy(), joints,
350
+ # joints_dict()['coco']['skeleton'],
351
+ # person_index=0,
352
+ # points_color_palette='gist_rainbow',
353
+ # skeleton_color_palette='jet',
354
+ # points_palette_samples=10,
355
+ # confidence_threshold=0.4)
356
+ # cv2.imshow('', image)
357
+ # cv2.waitKey(0)
358
+
359
+ return image, target.astype(np.float32), target_weight.astype(np.float32), joints_data
360
+
361
+
362
+ # Private methods
363
+ def _box2cs(self, box):
364
+ x, y, w, h = box[:4]
365
+ return self._xywh2cs(x, y, w, h)
366
+
367
+ def _xywh2cs(self, x, y, w, h):
368
+ center = np.zeros((2,), dtype=np.float32)
369
+ center[0] = x + w * 0.5
370
+ center[1] = y + h * 0.5
371
+
372
+ if w > self.aspect_ratio * h:
373
+ h = w * 1.0 / self.aspect_ratio
374
+ elif w < self.aspect_ratio * h:
375
+ w = h * self.aspect_ratio
376
+ scale = np.array(
377
+ [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
378
+ dtype=np.float32)
379
+ if center[0] != -1:
380
+ scale = scale * 1.25
381
+
382
+ return center, scale
383
+
384
+ def _half_body_transform(self, joints, joints_vis):
385
+ upper_joints = []
386
+ lower_joints = []
387
+ for joint_id in range(self.num_joints):
388
+ if joints_vis[joint_id][0] > 0:
389
+ if joint_id in self.upper_body_ids:
390
+ upper_joints.append(joints[joint_id])
391
+ else:
392
+ lower_joints.append(joints[joint_id])
393
+
394
+ if random.random() < 0.5 and len(upper_joints) > 2:
395
+ selected_joints = upper_joints
396
+ else:
397
+ selected_joints = lower_joints \
398
+ if len(lower_joints) > 2 else upper_joints
399
+
400
+ if len(selected_joints) < 2:
401
+ return None, None
402
+
403
+ selected_joints = np.array(selected_joints, dtype=np.float32)
404
+ center = selected_joints.mean(axis=0)[:2]
405
+
406
+ left_top = np.amin(selected_joints, axis=0)
407
+ right_bottom = np.amax(selected_joints, axis=0)
408
+
409
+ w = right_bottom[0] - left_top[0]
410
+ h = right_bottom[1] - left_top[1]
411
+
412
+ if w > self.aspect_ratio * h:
413
+ h = w * 1.0 / self.aspect_ratio
414
+ elif w < self.aspect_ratio * h:
415
+ w = h * self.aspect_ratio
416
+
417
+ scale = np.array(
418
+ [
419
+ w * 1.0 / self.pixel_std,
420
+ h * 1.0 / self.pixel_std
421
+ ],
422
+ dtype=np.float32
423
+ )
424
+
425
+ scale = scale * 1.5
426
+
427
+ return center, scale
428
+
429
+ def _generate_target(self, joints, joints_vis):
430
+ """
431
+ :param joints: [num_joints, 2]
432
+ :param joints_vis: [num_joints, 2]
433
+ :return: target, target_weight(1: visible, 0: invisible)
434
+ """
435
+ target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
436
+ target_weight[:, 0] = joints_vis[:, 0]
437
+
438
+ if self.heatmap_type == 'gaussian':
439
+ target = np.zeros((self.num_joints,
440
+ self.heatmap_size[1],
441
+ self.heatmap_size[0]),
442
+ dtype=np.float32)
443
+
444
+ tmp_size = self.heatmap_sigma * 3
445
+
446
+ for joint_id in range(self.num_joints):
447
+ feat_stride = np.asarray(self.image_size) / np.asarray(self.heatmap_size)
448
+ mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
449
+ mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
450
+ # Check that any part of the gaussian is in-bounds
451
+ ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
452
+ br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
453
+ if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
454
+ or br[0] < 0 or br[1] < 0:
455
+ # If not, just return the image as is
456
+ target_weight[joint_id] = 0
457
+ continue
458
+
459
+ # # Generate gaussian
460
+ size = 2 * tmp_size + 1
461
+ x = np.arange(0, size, 1, np.float32)
462
+ y = x[:, np.newaxis]
463
+ x0 = y0 = size // 2
464
+ # The gaussian is not normalized, we want the center value to equal 1
465
+ g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.heatmap_sigma ** 2))
466
+
467
+ # Usable gaussian range
468
+ g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
469
+ g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
470
+ # Image range
471
+ img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
472
+ img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
473
+
474
+ v = target_weight[joint_id]
475
+ if v > 0.5:
476
+ target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
477
+ g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
478
+ else:
479
+ raise NotImplementedError
480
+
481
+ if self.use_different_joints_weight:
482
+ target_weight = np.multiply(target_weight, self.joints_weight)
483
+
484
+ return target, target_weight
485
+
486
+ def _write_coco_keypoint_results(self, keypoints, res_file):
487
+ data_pack = [
488
+ {
489
+ 'cat_id': 1, # 1 == 'person'
490
+ 'cls': 'person',
491
+ 'ann_type': 'keypoints',
492
+ 'keypoints': keypoints
493
+ }
494
+ ]
495
+
496
+ results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
497
+ with open(res_file, 'w') as f:
498
+ json.dump(results, f, sort_keys=True, indent=4)
499
+ try:
500
+ json.load(open(res_file))
501
+ except Exception:
502
+ content = []
503
+ with open(res_file, 'r') as f:
504
+ for line in f:
505
+ content.append(line)
506
+ content[-1] = ']'
507
+ with open(res_file, 'w') as f:
508
+ for c in content:
509
+ f.write(c)
510
+
511
+ def _coco_keypoint_results_one_category_kernel(self, data_pack):
512
+ cat_id = data_pack['cat_id']
513
+ keypoints = data_pack['keypoints']
514
+ cat_results = []
515
+
516
+ for img_kpts in keypoints:
517
+ if len(img_kpts) == 0:
518
+ continue
519
+
520
+ _key_points = np.array([img_kpts[k]['keypoints'] for k in range(len(img_kpts))], dtype=np.float32)
521
+ key_points = np.zeros((_key_points.shape[0], self.num_joints * 3), dtype=np.float32)
522
+
523
+ for ipt in range(self.num_joints):
524
+ key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]
525
+ key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]
526
+ key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2] # keypoints score.
527
+
528
+ result = [
529
+ {
530
+ 'image_id': img_kpts[k]['image'],
531
+ 'category_id': cat_id,
532
+ 'keypoints': list(key_points[k]),
533
+ 'score': img_kpts[k]['score'].astype(np.float32),
534
+ 'center': list(img_kpts[k]['center']),
535
+ 'scale': list(img_kpts[k]['scale'])
536
+ }
537
+ for k in range(len(img_kpts))
538
+ ]
539
+ cat_results.extend(result)
540
+
541
+ return cat_results
542
+
543
+
544
+ if __name__ == '__main__':
545
+ # from skimage import io
546
+ coco = COCODataset(root_path=f"{os.path.dirname(__file__)}/COCO", data_version="traincoex", rotate_prob=0., half_body_prob=0.)
547
+ item = coco[1]
548
+ # io.imsave("tmp.jpg", item[0].permute(1,2,0).numpy())
549
+ print()
550
+ print(item[1].shape)
551
+ print('ok!!')
552
+ # img = np.clip(np.transpose(item[0].numpy(), (1, 2, 0))[:, :, ::-1] * np.asarray([0.229, 0.224, 0.225]) +
553
+ # np.asarray([0.485, 0.456, 0.406]), 0, 1) * 255
554
+ # cv2.imwrite('./tmp.png', img.astype(np.uint8))
555
+ # print(item[-1])
556
+ pass
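The target construction in _generate_target above stamps an unnormalized Gaussian (peak value exactly 1) onto a quarter-resolution heatmap at each visible joint, and zeroes the weight of joints whose Gaussian falls entirely outside the map. The following is a simplified single-joint sketch of that idea, not a drop-in replacement for the method above.

import numpy as np


def gaussian_heatmap(joint_xy, heatmap_size=(72, 96), image_size=(288, 384), sigma=3):
    """Return (heatmap, weight) for one joint given in image coordinates (x, y)."""
    w, h = heatmap_size
    stride = np.asarray(image_size, dtype=np.float32) / np.asarray(heatmap_size, dtype=np.float32)
    mu_x = int(joint_xy[0] / stride[0] + 0.5)        # joint centre in heatmap coordinates
    mu_y = int(joint_xy[1] / stride[1] + 0.5)

    tmp = 3 * sigma                                   # the Gaussian is effectively zero beyond 3*sigma
    if mu_x + tmp < 0 or mu_y + tmp < 0 or mu_x - tmp >= w or mu_y - tmp >= h:
        return np.zeros((h, w), dtype=np.float32), 0.0    # completely out of bounds -> ignore joint

    xs = np.arange(w, dtype=np.float32)
    ys = np.arange(h, dtype=np.float32)[:, None]
    # Unnormalized Gaussian: the value at (mu_x, mu_y) is exactly 1, as in _generate_target.
    target = np.exp(-((xs - mu_x) ** 2 + (ys - mu_y) ** 2) / (2 * sigma ** 2)).astype(np.float32)
    return target, 1.0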
easy_ViTPose/datasets/HumanPoseEstimation.py ADDED
@@ -0,0 +1,17 @@
+ from torch.utils.data import Dataset
+
+
+ class HumanPoseEstimationDataset(Dataset):
+     """
+     HumanPoseEstimationDataset class.
+
+     Generic class for HPE datasets.
+     """
+     def __init__(self):
+         pass
+
+     def __len__(self):
+         pass
+
+     def __getitem__(self, item):
+         pass
easy_ViTPose/datasets/__init__.py ADDED
File without changes
easy_ViTPose/easy_ViTPose.egg-info/PKG-INFO ADDED
@@ -0,0 +1,4 @@
+ Metadata-Version: 2.1
+ Name: easy-ViTPose
+ Version: 0.1
+ License-File: LICENSE
easy_ViTPose/easy_ViTPose.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,35 @@
+ LICENSE
+ README.md
+ setup.py
+ src/easy_ViTPose.egg-info/PKG-INFO
+ src/easy_ViTPose.egg-info/SOURCES.txt
+ src/easy_ViTPose.egg-info/dependency_links.txt
+ src/easy_ViTPose.egg-info/top_level.txt
+ src/vit_models/__init__.py
+ src/vit_models/model.py
+ src/vit_models/optimizer.py
+ src/vit_models/losses/__init__.py
+ src/vit_models/losses/classfication_loss.py
+ src/vit_models/losses/heatmap_loss.py
+ src/vit_models/losses/mesh_loss.py
+ src/vit_models/losses/mse_loss.py
+ src/vit_models/losses/multi_loss_factory.py
+ src/vit_models/losses/regression_loss.py
+ src/vit_utils/__init__.py
+ src/vit_utils/dist_util.py
+ src/vit_utils/inference.py
+ src/vit_utils/logging.py
+ src/vit_utils/top_down_eval.py
+ src/vit_utils/train_valid_fn.py
+ src/vit_utils/transform.py
+ src/vit_utils/util.py
+ src/vit_utils/visualization.py
+ src/vit_utils/nms/__init__.py
+ src/vit_utils/nms/nms.py
+ src/vit_utils/nms/nms_ori.py
+ src/vit_utils/nms/setup_linux.py
+ src/vit_utils/post_processing/__init__.py
+ src/vit_utils/post_processing/group.py
+ src/vit_utils/post_processing/nms.py
+ src/vit_utils/post_processing/one_euro_filter.py
+ src/vit_utils/post_processing/post_transforms.py
easy_ViTPose/easy_ViTPose.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
easy_ViTPose/easy_ViTPose.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
+ vit_models
+ vit_utils
easy_ViTPose/inference.py ADDED
@@ -0,0 +1,334 @@
1
+ import abc
2
+ import os
3
+ from typing import Optional
4
+ import typing
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+
10
+ from ultralytics import YOLO
11
+
12
+ from .configs.ViTPose_common import data_cfg
13
+ from .sort import Sort
14
+ from .vit_models.model import ViTPose
15
+ from .vit_utils.inference import draw_bboxes, pad_image
16
+ from .vit_utils.top_down_eval import keypoints_from_heatmaps
17
+ from .vit_utils.util import dyn_model_import, infer_dataset_by_path
18
+ from .vit_utils.visualization import draw_points_and_skeleton, joints_dict
19
+
20
+ try:
21
+ import torch_tensorrt
22
+ except ModuleNotFoundError:
23
+ pass
24
+
25
+ try:
26
+ import onnxruntime
27
+ except ModuleNotFoundError:
28
+ pass
29
+
30
+ __all__ = ['VitInference']
31
+ np.bool = np.bool_
32
+ MEAN = [0.485, 0.456, 0.406]
33
+ STD = [0.229, 0.224, 0.225]
34
+
35
+
36
+ DETC_TO_YOLO_YOLOC = {
37
+ 'human': [0],
38
+ 'cat': [15],
39
+ 'dog': [16],
40
+ 'horse': [17],
41
+ 'sheep': [18],
42
+ 'cow': [19],
43
+ 'elephant': [20],
44
+ 'bear': [21],
45
+ 'zebra': [22],
46
+ 'giraffe': [23],
47
+ 'animals': [15, 16, 17, 18, 19, 20, 21, 22, 23]
48
+ }
49
+
50
+
51
+ class VitInference:
52
+ """
53
+ Class for performing inference using ViTPose models with YOLOv8 human detection and SORT tracking.
54
+
55
+ Args:
56
+ model (str): Path to the ViT model file (.pth, .onnx, .engine).
57
+ yolo (str): Path of the YOLOv8 model to load.
58
+ model_name (str, optional): Name of the ViT model architecture to use.
59
+ Valid values are 's', 'b', 'l', 'h'.
60
+ Defaults to None; required when loading .pth checkpoints.
61
+ det_class (str, optional): the detection class. if None it is inferred by the dataset.
62
+ valid values are 'human', 'cat', 'dog', 'horse', 'sheep',
63
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
64
+ 'animals' (which is all previous but human)
65
+ dataset (str, optional): Name of the dataset. If None it's extracted from the file name.
66
+ Valid values are 'coco', 'coco_25', 'wholebody', 'mpii',
67
+ 'ap10k', 'apt36k', 'aic'
68
+ yolo_size (int, optional): Size of the input image for YOLOv8 model. Defaults to 320.
69
+ device (str, optional): Device to use for inference. Defaults to 'cuda' if available, else 'cpu'.
70
+ is_video (bool, optional): Flag indicating if the input is video. Defaults to False.
71
+ single_pose (bool, optional): Flag indicating if the video (on images this flag has no effect)
72
+ will contain a single pose.
73
+ In this case the SORT tracker is not used (increasing performance)
74
+ but people id tracking
75
+ won't be consistent among frames.
76
+ yolo_step (int, optional): The tracker can be used to predict the bboxes instead of yolo for performance,
77
+ this flag specifies how often yolo is applied (e.g. 1 applies yolo every frame).
78
+ This does not have any effect when is_video is False.
79
+ """
80
+
81
+ def __init__(self, model: str,
82
+ yolo: str,
83
+ model_name: Optional[str] = None,
84
+ det_class: Optional[str] = None,
85
+ dataset: Optional[str] = None,
86
+ yolo_size: Optional[int] = 320,
87
+ device: Optional[str] = None,
88
+ is_video: Optional[bool] = False,
89
+ single_pose: Optional[bool] = False,
90
+ yolo_step: Optional[int] = 1):
91
+ assert os.path.isfile(model), f'The model file {model} does not exist'
92
+ assert os.path.isfile(yolo), f'The YOLOv8 model {yolo} does not exist'
93
+
94
+ # Device priority is cuda / mps / cpu
95
+ if device is None:
96
+ if torch.cuda.is_available():
97
+ device = 'cuda'
98
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
99
+ device = 'mps'
100
+ else:
101
+ device = 'cpu'
102
+
103
+ self.device = device
104
+ self.yolo = YOLO(yolo, task='detect')
105
+ self.yolo_size = yolo_size
106
+ self.yolo_step = yolo_step
107
+ self.is_video = is_video
108
+ self.single_pose = single_pose
109
+ self.reset()
110
+
111
+ # State saving during inference
112
+ self.save_state = True # Can be disabled manually
113
+ self._img = None
114
+ self._yolo_res = None
115
+ self._tracker_res = None
116
+ self._keypoints = None
117
+
118
+ # Use extension to decide which kind of model has been loaded
119
+ use_onnx = model.endswith('.onnx')
120
+ use_trt = model.endswith('.engine')
121
+
122
+
123
+ # Extract dataset name
124
+ if dataset is None:
125
+ dataset = infer_dataset_by_path(model)
126
+
127
+ assert dataset in ['mpii', 'coco', 'coco_25', 'wholebody', 'aic', 'ap10k', 'apt36k'], \
128
+ 'The specified dataset is not valid'
129
+
130
+ # Dataset can now be set for visualization
131
+ self.dataset = dataset
132
+
133
+ # if we picked the dataset switch to correct yolo classes if not set
134
+ if det_class is None:
135
+ det_class = 'animals' if dataset in ['ap10k', 'apt36k'] else 'human'
136
+ self.yolo_classes = DETC_TO_YOLO_YOLOC[det_class]
137
+
138
+ assert model_name in [None, 's', 'b', 'l', 'h'], \
139
+ f'The model name {model_name} is not valid'
140
+
141
+ # onnx / trt models do not require model_cfg specification
142
+ if model_name is None:
143
+ assert use_onnx or use_trt, \
144
+ 'Specify the model_name if not using onnx / trt'
145
+ else:
146
+ # Dynamically import the model class
147
+ model_cfg = dyn_model_import(self.dataset, model_name)
148
+
149
+ self.target_size = data_cfg['image_size']
150
+ if use_onnx:
151
+ self._ort_session = onnxruntime.InferenceSession(model,
152
+ providers=['CUDAExecutionProvider',
153
+ 'CPUExecutionProvider'])
154
+ inf_fn = self._inference_onnx
155
+ else:
156
+ self._vit_pose = ViTPose(model_cfg)
157
+ self._vit_pose.eval()
158
+
159
+ if use_trt:
160
+ self._vit_pose = torch.jit.load(model)
161
+ else:
162
+ ckpt = torch.load(model, map_location='cpu')
163
+ if 'state_dict' in ckpt:
164
+ self._vit_pose.load_state_dict(ckpt['state_dict'])
165
+ else:
166
+ self._vit_pose.load_state_dict(ckpt)
167
+ self._vit_pose.to(torch.device(device))
168
+
169
+ inf_fn = self._inference_torch
170
+
171
+ # Override _inference abstract with selected engine
172
+ self._inference = inf_fn # type: ignore
173
+
174
+ def reset(self):
175
+ """
176
+ Reset the inference class to be ready for a new video.
177
+ This will reset the internal counter of frames, on videos
178
+ this is necessary to reset the tracker.
179
+ """
180
+ min_hits = 3 if self.yolo_step == 1 else 1
181
+ use_tracker = self.is_video and not self.single_pose
182
+ self.tracker = Sort(max_age=self.yolo_step,
183
+ min_hits=min_hits,
184
+ iou_threshold=0.3) if use_tracker else None # TODO: Params
185
+ self.frame_counter = 0
186
+
187
+ @classmethod
188
+ def postprocess(cls, heatmaps, org_w, org_h):
189
+ """
190
+ Postprocess the heatmaps to obtain keypoints and their probabilities.
191
+
192
+ Args:
193
+ heatmaps (ndarray): Heatmap predictions from the model.
194
+ org_w (int): Original width of the image.
195
+ org_h (int): Original height of the image.
196
+
197
+ Returns:
198
+ ndarray: Processed keypoints with probabilities.
199
+ """
200
+ points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,
201
+ center=np.array([[org_w // 2,
202
+ org_h // 2]]),
203
+ scale=np.array([[org_w, org_h]]),
204
+ unbiased=True, use_udp=True)
205
+ return np.concatenate([points[:, :, ::-1], prob], axis=2)
206
+
207
+ @abc.abstractmethod
208
+ def _inference(self, img: np.ndarray) -> np.ndarray:
209
+ """
210
+ Abstract method for performing inference on an image.
211
+ It is overloaded by each inference engine.
212
+
213
+ Args:
214
+ img (ndarray): Input image for inference.
215
+
216
+ Returns:
217
+ ndarray: Inference results.
218
+ """
219
+ raise NotImplementedError
220
+
221
+ def inference(self, img: np.ndarray) -> dict[typing.Any, typing.Any]:
222
+ """
223
+ Perform inference on the input image.
224
+
225
+ Args:
226
+ img (ndarray): Input image for inference in RGB format.
227
+
228
+ Returns:
229
+ dict[typing.Any, typing.Any]: Inference results.
230
+ """
231
+
232
+ # First use YOLOv8 for detection
233
+ res_pd = np.empty((0, 5))
234
+ results = None
235
+ if (self.tracker is None or
236
+ (self.frame_counter % self.yolo_step == 0 or self.frame_counter < 3)):
237
+ results = self.yolo(img, verbose=False, imgsz=self.yolo_size,
238
+ device=self.device if self.device != 'cuda' else 0,
239
+ classes=self.yolo_classes)[0]
240
+ res_pd = np.array([r[:5].tolist() for r in # TODO: Confidence threshold
241
+ results.boxes.data.cpu().numpy() if r[4] > 0.35]).reshape((-1, 5))
242
+ self.frame_counter += 1
243
+
244
+ frame_keypoints = {}
245
+ ids = None
246
+ if self.tracker is not None:
247
+ res_pd = self.tracker.update(res_pd)
248
+ ids = res_pd[:, 5].astype(int).tolist()
249
+
250
+ # Prepare boxes for inference
251
+ bboxes = res_pd[:, :4].round().astype(int)
252
+ scores = res_pd[:, 4].tolist()
253
+ pad_bbox = 10
254
+
255
+ if ids is None:
256
+ ids = range(len(bboxes))
257
+
258
+ for bbox, id in zip(bboxes, ids):
259
+ # TODO: Slightly bigger bbox
260
+ bbox[[0, 2]] = np.clip(bbox[[0, 2]] + [-pad_bbox, pad_bbox], 0, img.shape[1])
261
+ bbox[[1, 3]] = np.clip(bbox[[1, 3]] + [-pad_bbox, pad_bbox], 0, img.shape[0])
262
+
263
+ # Crop image and pad to 3/4 aspect ratio
264
+ img_inf = img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
265
+ img_inf, (left_pad, top_pad) = pad_image(img_inf, 3 / 4)
266
+
267
+ keypoints = self._inference(img_inf)[0]
268
+ # Transform keypoints to original image
269
+ keypoints[:, :2] += bbox[:2][::-1] - [top_pad, left_pad]
270
+ frame_keypoints[id] = keypoints
271
+
272
+ if self.save_state:
273
+ self._img = img
274
+ self._yolo_res = results
275
+ self._tracker_res = (bboxes, ids, scores)
276
+ self._keypoints = frame_keypoints
277
+
278
+ return frame_keypoints
279
+
280
+ def draw(self, show_yolo=True, show_raw_yolo=False, confidence_threshold=0.5):
281
+ """
282
+ Draw keypoints and bounding boxes on the image.
283
+
284
+ Args:
285
+ show_yolo (bool, optional): Whether to show YOLOv8 bounding boxes. Default is True.
286
+ show_raw_yolo (bool, optional): Whether to show raw YOLOv8 bounding boxes. Default is False.
+ confidence_threshold (float, optional): Minimum keypoint confidence required to draw a point or limb. Default is 0.5.
287
+
288
+ Returns:
289
+ ndarray: Image with keypoints and bounding boxes drawn.
290
+ """
291
+ img = self._img.copy()
292
+ bboxes, ids, scores = self._tracker_res
293
+
294
+ if self._yolo_res is not None and (show_raw_yolo or (self.tracker is None and show_yolo)):
295
+ img = np.array(self._yolo_res.plot())
296
+
297
+ if show_yolo and self.tracker is not None:
298
+ img = draw_bboxes(img, bboxes, ids, scores)
299
+
300
+ img = np.array(img)[..., ::-1] # RGB to BGR for cv2 modules
301
+ for idx, k in self._keypoints.items():
302
+ img = draw_points_and_skeleton(img.copy(), k,
303
+ joints_dict()[self.dataset]['skeleton'],
304
+ person_index=idx,
305
+ points_color_palette='gist_rainbow',
306
+ skeleton_color_palette='jet',
307
+ points_palette_samples=10,
308
+ confidence_threshold=confidence_threshold)
309
+ return img[..., ::-1] # Return RGB as original
310
+
311
+ def pre_img(self, img):
312
+ org_h, org_w = img.shape[:2]
313
+ img_input = cv2.resize(img, self.target_size, interpolation=cv2.INTER_LINEAR) / 255
314
+ img_input = ((img_input - MEAN) / STD).transpose(2, 0, 1)[None].astype(np.float32)
315
+ return img_input, org_h, org_w
316
+
317
+ @torch.no_grad()
318
+ def _inference_torch(self, img: np.ndarray) -> np.ndarray:
319
+ # Prepare input data
320
+ img_input, org_h, org_w = self.pre_img(img)
321
+ img_input = torch.from_numpy(img_input).to(torch.device(self.device))
322
+
323
+ # Feed to model
324
+ heatmaps = self._vit_pose(img_input).detach().cpu().numpy()
325
+ return self.postprocess(heatmaps, org_w, org_h)
326
+
327
+ def _inference_onnx(self, img: np.ndarray) -> np.ndarray:
328
+ # Prepare input data
329
+ img_input, org_h, org_w = self.pre_img(img)
330
+
331
+ # Feed to model
332
+ ort_inputs = {self._ort_session.get_inputs()[0].name: img_input}
333
+ heatmaps = self._ort_session.run(None, ort_inputs)[0]
334
+ return self.postprocess(heatmaps, org_w, org_h)
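An illustrative usage sketch for the inference wrapper implemented in this file; the class name, import path, checkpoint filenames and image path are assumptions made for the example, not values taken from this diff:

import cv2
from easy_ViTPose.inference import VitInference  # assumed class name and import path

# Placeholder weights: any ViTPose checkpoint plus a YOLOv8 detector file
model = VitInference('vitpose-b-coco.pth', 'yolov8s.pt',
                     model_name='b', dataset='coco', is_video=False)

img = cv2.cvtColor(cv2.imread('person.jpg'), cv2.COLOR_BGR2RGB)  # inference() expects RGB
keypoints = model.inference(img)  # dict: detection id -> (K, 3) array of [y, x, confidence]
canvas = model.draw(show_yolo=True, confidence_threshold=0.5)    # RGB image with skeletons drawn
cv2.imwrite('out.jpg', canvas[..., ::-1])                        # convert back to BGR for imwrite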
easy_ViTPose/sort.py ADDED
@@ -0,0 +1,266 @@
1
+ """
2
+ SORT: A Simple, Online and Realtime Tracker
3
+ Copyright (C) 2016-2020 Alex Bewley alex@bewley.ai
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
17
+ """
18
+ from __future__ import print_function
19
+
20
+ import os
21
+ import numpy as np
22
+ import matplotlib
23
+
24
+ import matplotlib.pyplot as plt
25
+ import matplotlib.patches as patches
26
+ from skimage import io
27
+
28
+ import glob
29
+ import time
30
+ import argparse
31
+ from filterpy.kalman import KalmanFilter
32
+
33
+ np.random.seed(0)
34
+
35
+
36
+ def linear_assignment(cost_matrix):
37
+ try:
38
+ import lap
39
+ _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
40
+ return np.array([[y[i], i] for i in x if i >= 0])
41
+ except ImportError:
42
+ from scipy.optimize import linear_sum_assignment
43
+ x, y = linear_sum_assignment(cost_matrix)
44
+ return np.array(list(zip(x, y)))
45
+
46
+
47
+ def iou_batch(bb_test, bb_gt):
48
+ """
49
+ From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
50
+ """
51
+ bb_gt = np.expand_dims(bb_gt, 0)
52
+ bb_test = np.expand_dims(bb_test, 1)
53
+
54
+ xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
55
+ yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
56
+ xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
57
+ yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
58
+ w = np.maximum(0., xx2 - xx1)
59
+ h = np.maximum(0., yy2 - yy1)
60
+ wh = w * h
61
+ o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
62
+ + (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
63
+ return(o)
64
+
65
+
66
+ def convert_bbox_to_z(bbox):
67
+ """
68
+ Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
69
+ [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
70
+ the aspect ratio
71
+ """
72
+ w = bbox[2] - bbox[0]
73
+ h = bbox[3] - bbox[1]
74
+ x = bbox[0] + w/2.
75
+ y = bbox[1] + h/2.
76
+ s = w * h # scale is just area
77
+ r = w / float(h)
78
+ return np.array([x, y, s, r]).reshape((4, 1))
79
+
80
+
81
+ def convert_x_to_bbox(x, score=None):
82
+ """
83
+ Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
84
+ [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
85
+ """
86
+ w = np.sqrt(x[2] * x[3])
87
+ h = x[2] / w
88
+ if score is None:
89
+ return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2.]).reshape((1, 4))
90
+ else:
91
+ return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2., score]).reshape((1, 5))
92
+
93
+
94
+ class KalmanBoxTracker(object):
95
+ """
96
+ This class represents the internal state of individual tracked objects observed as bbox.
97
+ """
98
+ count = 0
99
+
100
+ def __init__(self, bbox, score):
101
+ """
102
+ Initialises a tracker using initial bounding box.
103
+ """
104
+ # define constant velocity model
105
+ self.kf = KalmanFilter(dim_x=7, dim_z=4)
106
+ self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 1], [
107
+ 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1]])
108
+ self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
109
+
110
+ self.kf.R[2:, 2:] *= 10.
111
+ self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities
112
+ self.kf.P *= 10.
113
+ self.kf.Q[-1, -1] *= 0.01
114
+ self.kf.Q[4:, 4:] *= 0.01
115
+
116
+ self.kf.x[:4] = convert_bbox_to_z(bbox)
117
+ self.time_since_update = 0
118
+ self.id = KalmanBoxTracker.count
119
+ KalmanBoxTracker.count += 1
120
+ self.history = []
121
+ self.hits = 0
122
+ self.hit_streak = 0
123
+ self.age = 0
124
+ self.score = score
125
+
126
+ def update(self, bbox, score):
127
+ """
128
+ Updates the state vector with observed bbox.
129
+ """
130
+ self.time_since_update = 0
131
+ self.history = []
132
+ self.hits += 1
133
+ self.hit_streak += 1
134
+ self.kf.update(convert_bbox_to_z(bbox))
135
+ self.score = score
136
+
137
+ def predict(self):
138
+ """
139
+ Advances the state vector and returns the predicted bounding box estimate.
140
+ """
141
+ if((self.kf.x[6]+self.kf.x[2]) <= 0):
142
+ self.kf.x[6] *= 0.0
143
+ self.kf.predict()
144
+ self.age += 1
145
+ if(self.time_since_update > 0):
146
+ self.hit_streak = 0
147
+ self.time_since_update += 1
148
+ self.history.append(convert_x_to_bbox(self.kf.x))
149
+ return self.history[-1]
150
+
151
+ def get_state(self):
152
+ """
153
+ Returns the current bounding box estimate.
154
+ """
155
+ return convert_x_to_bbox(self.kf.x)
156
+
157
+
158
+ def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
159
+ """
160
+ Assigns detections to tracked object (both represented as bounding boxes)
161
+
162
+ Returns 3 lists of matches, unmatched_detections and unmatched_trackers
163
+ """
164
+ if(len(trackers) == 0):
165
+ return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)
166
+
167
+ iou_matrix = iou_batch(detections, trackers)
168
+
169
+ if min(iou_matrix.shape) > 0:
170
+ a = (iou_matrix > iou_threshold).astype(np.int32)
171
+ if a.sum(1).max() == 1 and a.sum(0).max() == 1:
172
+ matched_indices = np.stack(np.where(a), axis=1)
173
+ else:
174
+ matched_indices = linear_assignment(-iou_matrix)
175
+ else:
176
+ matched_indices = np.empty(shape=(0, 2))
177
+
178
+ unmatched_detections = []
179
+ for d, det in enumerate(detections):
180
+ if(d not in matched_indices[:, 0]):
181
+ unmatched_detections.append(d)
182
+ unmatched_trackers = []
183
+ for t, trk in enumerate(trackers):
184
+ if(t not in matched_indices[:, 1]):
185
+ unmatched_trackers.append(t)
186
+
187
+ # filter out matched with low IOU
188
+ matches = []
189
+ for m in matched_indices:
190
+ if(iou_matrix[m[0], m[1]] < iou_threshold):
191
+ unmatched_detections.append(m[0])
192
+ unmatched_trackers.append(m[1])
193
+ else:
194
+ matches.append(m.reshape(1, 2))
195
+ if(len(matches) == 0):
196
+ matches = np.empty((0, 2), dtype=int)
197
+ else:
198
+ matches = np.concatenate(matches, axis=0)
199
+
200
+ return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
201
+
202
+
203
+ class Sort(object):
204
+ def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
205
+ """
206
+ Sets key parameters for SORT
207
+ """
208
+ self.max_age = max_age
209
+ self.min_hits = min_hits
210
+ self.iou_threshold = iou_threshold
211
+ self.trackers = []
212
+ self.frame_count = 0
213
+
214
+ def update(self, dets=np.empty((0, 5))):
215
+ """
216
+ Params:
217
+ dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
218
+ Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections).
219
+ Returns the a similar array, where the last column is the object ID.
220
+
221
+ NOTE: The number of objects returned may differ from the number of detections provided.
222
+ """
223
+ self.frame_count += 1
224
+ empty_dets = dets.shape[0] == 0
225
+
226
+ # get predicted locations from existing trackers.
227
+ trks = np.zeros((len(self.trackers), 5))
228
+ to_del = []
229
+ ret = []
230
+ for t, trk in enumerate(trks):
231
+ pos = self.trackers[t].predict()[0]
232
+ trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
233
+ if np.any(np.isnan(pos)):
234
+ to_del.append(t)
235
+ trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
236
+ for t in reversed(to_del):
237
+ self.trackers.pop(t)
238
+ matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold)
239
+
240
+ # update matched trackers with assigned detections
241
+ for m in matched:
242
+ self.trackers[m[1]].update(dets[m[0], :], dets[m[0], -1])
243
+
244
+ # create and initialise new trackers for unmatched detections
245
+ for i in unmatched_dets:
246
+ trk = KalmanBoxTracker(dets[i, :], dets[i, -1])
247
+ self.trackers.append(trk)
248
+
249
+ i = len(self.trackers)
250
+ unmatched = []
251
+ for trk in reversed(self.trackers):
252
+ d = trk.get_state()[0]
253
+ if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
254
+ ret.append(np.concatenate((d, [trk.score, trk.id+1])).reshape(1, -1)) # +1 as MOT benchmark requires positive
255
+ i -= 1
256
+ # remove dead tracklet
257
+ if(trk.time_since_update > self.max_age):
258
+ self.trackers.pop(i)
259
+ if empty_dets:
260
+ unmatched.append(np.concatenate((d, [trk.score, trk.id + 1])).reshape(1, -1))
261
+
262
+ if len(ret):
263
+ return np.concatenate(ret)
264
+ elif empty_dets:
265
+ return np.concatenate(unmatched) if len(unmatched) else np.empty((0, 6))
266
+ return np.empty((0, 6))
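An illustrative sketch of how the tracker above is driven, one update() call per frame; the detection boxes are made-up numbers:

import numpy as np
from easy_ViTPose.sort import Sort  # assumed import path

tracker = Sort(max_age=1, min_hits=3, iou_threshold=0.3)

# Each row is [x1, y1, x2, y2, score]; pass np.empty((0, 5)) on frames without detections
frames = [
    np.array([[100., 100., 200., 300., 0.9]]),
    np.array([[105., 102., 205., 303., 0.8]]),
    np.empty((0, 5)),
]
for dets in frames:
    tracks = tracker.update(dets)  # rows are [x1, y1, x2, y2, score, track_id]
    print(tracks)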
easy_ViTPose/to_onnx.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
easy_ViTPose/to_trt.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
easy_ViTPose/train.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import copy
4
+ import os
5
+ import os.path as osp
6
+ import time
7
+ import warnings
8
+ import click
9
+ import yaml
10
+
11
+ from glob import glob
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from vit_utils.util import init_random_seed, set_random_seed
17
+ from vit_utils.dist_util import get_dist_info, init_dist
18
+ from vit_utils.logging import get_root_logger
19
+
20
+ import configs.ViTPose_small_coco_256x192 as s_cfg
21
+ import configs.ViTPose_base_coco_256x192 as b_cfg
22
+ import configs.ViTPose_large_coco_256x192 as l_cfg
23
+ import configs.ViTPose_huge_coco_256x192 as h_cfg
24
+
25
+ from vit_models.model import ViTPose
26
+ from datasets.COCO import COCODataset
27
+ from vit_utils.train_valid_fn import train_model
28
+
29
+ CUR_PATH = osp.dirname(__file__)
30
+
31
+ @click.command()
32
+ @click.option('--config-path', type=click.Path(exists=True), default='config.yaml', required=True, help='train config file path')
33
+ @click.option('--model-name', type=str, default='b', required=True, help='[s: ViT-S, b: ViT-B, l: ViT-L, h: ViT-H]')
34
+ def main(config_path, model_name):
35
+
36
+ cfg = {'b':b_cfg,
37
+ 's':s_cfg,
38
+ 'l':l_cfg,
39
+ 'h':h_cfg}.get(model_name.lower())
40
+ # Load config.yaml
41
+ with open(config_path, 'r') as f:
42
+ cfg_yaml = yaml.load(f, Loader=yaml.SafeLoader)
43
+
44
+ for k, v in cfg_yaml.items():
45
+ if hasattr(cfg, k):
46
+ raise ValueError(f"Key {k} already exists in config")
47
+ else:
48
+ cfg.__setattr__(k, v)
49
+
50
+ # set cudnn_benchmark
51
+ if cfg.cudnn_benchmark:
52
+ torch.backends.cudnn.benchmark = True
53
+
54
+ # Set work directory (session-level)
55
+ if not hasattr(cfg, 'work_dir'):
56
+ cfg.__setattr__('work_dir', f"{CUR_PATH}/runs/train")
57
+
58
+ if not osp.exists(cfg.work_dir):
59
+ os.makedirs(cfg.work_dir)
60
+ session_list = sorted(glob(f"{cfg.work_dir}/*"))
61
+ if len(session_list) == 0:
62
+ session = 1
63
+ else:
64
+ session = int(os.path.basename(session_list[-1])) + 1
65
+ session_dir = osp.join(cfg.work_dir, str(session).zfill(3))
66
+ os.makedirs(session_dir)
67
+ cfg.__setattr__('work_dir', session_dir)
68
+
69
+
70
+ if cfg.autoscale_lr:
71
+ # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
72
+ cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
73
+
74
+ # init distributed env first, since logger depends on the dist info.
75
+ if cfg.launcher == 'none':
76
+ distributed = False
77
+ if len(cfg.gpu_ids) > 1:
78
+ warnings.warn(
79
+ f"We treat {cfg.gpu_ids} as gpu-ids, and reset to "
80
+ f"{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in "
81
+ "non-distribute training time.")
82
+ cfg.gpu_ids = cfg.gpu_ids[0:1]
83
+ else:
84
+ distributed = True
85
+ init_dist(cfg.launcher, **cfg.dist_params)
86
+ # re-set gpu_ids with distributed training mode
87
+ _, world_size = get_dist_info()
88
+ cfg.gpu_ids = range(world_size)
89
+
90
+ # init the logger before other steps
91
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
92
+ log_file = osp.join(session_dir, f'{timestamp}.log')
93
+ logger = get_root_logger(log_file=log_file)
94
+
95
+ # init the meta dict to record some important information such as
96
+ # environment info and seed, which will be logged
97
+ meta = dict()
98
+
99
+ # log some basic info
100
+ logger.info(f'Distributed training: {distributed}')
101
+
102
+ # set random seeds
103
+ seed = init_random_seed(cfg.seed)
104
+ logger.info(f"Set random seed to {seed}, "
105
+ f"deterministic: {cfg.deterministic}")
106
+ set_random_seed(seed, deterministic=cfg.deterministic)
107
+ meta['seed'] = seed
108
+
109
+ # Set model
110
+ model = ViTPose(cfg.model)
111
+ if cfg.resume_from:
112
+ # Load ckpt partially
113
+ ckpt_state = torch.load(cfg.resume_from)['state_dict']
114
+ ckpt_state.pop('keypoint_head.final_layer.bias')
115
+ ckpt_state.pop('keypoint_head.final_layer.weight')
116
+ model.load_state_dict(ckpt_state, strict=False)
117
+
118
+ # freeze the backbone, leave the head to be finetuned
119
+ model.backbone.frozen_stages = model.backbone.depth - 1
120
+ model.backbone.freeze_ffn = True
121
+ model.backbone.freeze_attn = True
122
+ model.backbone._freeze_stages()
123
+
124
+ # Set dataset
125
+ datasets_train = COCODataset(
126
+ root_path=cfg.data_root,
127
+ data_version="feet_train",
128
+ is_train=True,
129
+ use_gt_bboxes=True,
130
+ image_width=192,
131
+ image_height=256,
132
+ scale=True,
133
+ scale_factor=0.35,
134
+ flip_prob=0.5,
135
+ rotate_prob=0.5,
136
+ rotation_factor=45.,
137
+ half_body_prob=0.3,
138
+ use_different_joints_weight=True,
139
+ heatmap_sigma=3,
140
+ soft_nms=False
141
+ )
142
+
143
+ datasets_valid = COCODataset(
144
+ root_path=cfg.data_root,
145
+ data_version="feet_val",
146
+ is_train=False,
147
+ use_gt_bboxes=True,
148
+ image_width=192,
149
+ image_height=256,
150
+ scale=False,
151
+ scale_factor=0.35,
152
+ flip_prob=0.5,
153
+ rotate_prob=0.5,
154
+ rotation_factor=45.,
155
+ half_body_prob=0.3,
156
+ use_different_joints_weight=True,
157
+ heatmap_sigma=3,
158
+ soft_nms=False
159
+ )
160
+
161
+ train_model(
162
+ model=model,
163
+ datasets_train=datasets_train,
164
+ datasets_valid=datasets_valid,
165
+ cfg=cfg,
166
+ distributed=distributed,
167
+ validate=cfg.validate,
168
+ timestamp=timestamp,
169
+ meta=meta
170
+ )
171
+
172
+
173
+ if __name__ == '__main__':
174
+ main()
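For reference, the keys main() reads from config.yaml, written as the dict yaml.load would return; every value below is an illustrative assumption, and train_model / COCODataset may require further keys not shown here:

cfg_yaml = {
    'cudnn_benchmark': True,
    'autoscale_lr': False,       # if True, optimizer['lr'] is rescaled by len(gpu_ids) / 8
    'launcher': 'none',          # 'none' for single-machine runs, otherwise a distributed launcher
    'dist_params': {'backend': 'nccl'},
    'gpu_ids': [0],
    'seed': 0,
    'deterministic': True,
    'resume_from': '',           # checkpoint path to partially load, or empty to train from scratch
    'data_root': './data/coco',  # root folder passed to COCODataset
    'validate': True,
}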
easy_ViTPose/vit_models/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ import sys
2
+ import os.path as osp
3
+
4
+ sys.path.append(osp.dirname(osp.dirname(__file__)))
5
+
6
+ from vit_utils.util import load_checkpoint, resize, constant_init, normal_init
7
+ from vit_utils.top_down_eval import keypoints_from_heatmaps, pose_pck_accuracy
8
+ from vit_utils.post_processing import *
easy_ViTPose/vit_models/backbone/__init__.py ADDED
File without changes
easy_ViTPose/vit_models/backbone/vit.py ADDED
@@ -0,0 +1,394 @@
1
+
2
+ import math
3
+ import warnings
4
+
5
+ from itertools import repeat
6
+ import collections.abc
7
+
8
+ import torch
9
+ from functools import partial
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint as checkpoint
13
+ from torch import Tensor
14
+
15
+ # from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+
17
+ # from .base_backbone import BaseBackbone
18
+
19
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
20
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
21
+
22
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
23
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
24
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
25
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
26
+ 'survival rate' as the argument.
27
+
28
+ """
29
+ if drop_prob == 0. or not training:
30
+ return x
31
+ keep_prob = 1 - drop_prob
32
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
33
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
34
+ if keep_prob > 0.0 and scale_by_keep:
35
+ random_tensor.div_(keep_prob)
36
+ return x * random_tensor
37
+
38
+ def _ntuple(n):
39
+ def parse(x):
40
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
41
+ return x
42
+ return tuple(repeat(x, n))
43
+ return parse
44
+
45
+
46
+ to_1tuple = _ntuple(1)
47
+ to_2tuple = _ntuple(2)
48
+ to_3tuple = _ntuple(3)
49
+ to_4tuple = _ntuple(4)
50
+ to_ntuple = _ntuple
51
+
52
+ def _trunc_normal_(tensor, mean, std, a, b):
53
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
54
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
55
+ def norm_cdf(x):
56
+ # Computes standard normal cumulative distribution function
57
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
58
+
59
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
60
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
61
+ "The distribution of values may be incorrect.",
62
+ stacklevel=2)
63
+
64
+ # Values are generated by using a truncated uniform distribution and
65
+ # then using the inverse CDF for the normal distribution.
66
+ # Get upper and lower cdf values
67
+ l = norm_cdf((a - mean) / std)
68
+ u = norm_cdf((b - mean) / std)
69
+
70
+ # Uniformly fill tensor with values from [l, u], then translate to
71
+ # [2l-1, 2u-1].
72
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
73
+
74
+ # Use inverse cdf transform for normal distribution to get truncated
75
+ # standard normal
76
+ tensor.erfinv_()
77
+
78
+ # Transform to proper mean, std
79
+ tensor.mul_(std * math.sqrt(2.))
80
+ tensor.add_(mean)
81
+
82
+ # Clamp to ensure it's in the proper range
83
+ tensor.clamp_(min=a, max=b)
84
+ return tensor
85
+
86
+
87
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
88
+ # type: (Tensor, float, float, float, float) -> Tensor
89
+ r"""Fills the input Tensor with values drawn from a truncated
90
+ normal distribution. The values are effectively drawn from the
91
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
92
+ with values outside :math:`[a, b]` redrawn until they are within
93
+ the bounds. The method used for generating the random values works
94
+ best when :math:`a \leq \text{mean} \leq b`.
95
+
96
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
97
+ applied while sampling the normal with mean/std applied, therefore a, b args
98
+ should be adjusted to match the range of mean, std args.
99
+
100
+ Args:
101
+ tensor: an n-dimensional `torch.Tensor`
102
+ mean: the mean of the normal distribution
103
+ std: the standard deviation of the normal distribution
104
+ a: the minimum cutoff value
105
+ b: the maximum cutoff value
106
+ Examples:
107
+ >>> w = torch.empty(3, 5)
108
+ >>> nn.init.trunc_normal_(w)
109
+ """
110
+ with torch.no_grad():
111
+ return _trunc_normal_(tensor, mean, std, a, b)
112
+
113
+ class DropPath(nn.Module):
114
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
115
+ """
116
+ def __init__(self, drop_prob=None):
117
+ super(DropPath, self).__init__()
118
+ self.drop_prob = drop_prob
119
+
120
+ def forward(self, x):
121
+ return drop_path(x, self.drop_prob, self.training)
122
+
123
+ def extra_repr(self):
124
+ return 'p={}'.format(self.drop_prob)
125
+
126
+ class Mlp(nn.Module):
127
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
128
+ super().__init__()
129
+ out_features = out_features or in_features
130
+ hidden_features = hidden_features or in_features
131
+ self.fc1 = nn.Linear(in_features, hidden_features)
132
+ self.act = act_layer()
133
+ self.fc2 = nn.Linear(hidden_features, out_features)
134
+ self.drop = nn.Dropout(drop)
135
+
136
+ def forward(self, x):
137
+ x = self.fc1(x)
138
+ x = self.act(x)
139
+ x = self.fc2(x)
140
+ x = self.drop(x)
141
+ return x
142
+
143
+ class Attention(nn.Module):
144
+ def __init__(
145
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
146
+ proj_drop=0., attn_head_dim=None,):
147
+ super().__init__()
148
+ self.num_heads = num_heads
149
+ head_dim = dim // num_heads
150
+ self.dim = dim
151
+
152
+ if attn_head_dim is not None:
153
+ head_dim = attn_head_dim
154
+ all_head_dim = head_dim * self.num_heads
155
+
156
+ self.scale = qk_scale or head_dim ** -0.5
157
+
158
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
159
+
160
+ self.attn_drop = nn.Dropout(attn_drop)
161
+ self.proj = nn.Linear(all_head_dim, dim)
162
+ self.proj_drop = nn.Dropout(proj_drop)
163
+
164
+ def forward(self, x):
165
+ B, N, C = x.shape
166
+ qkv = self.qkv(x)
167
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
168
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
169
+
170
+ q = q * self.scale
171
+ attn = (q @ k.transpose(-2, -1))
172
+
173
+ attn = attn.softmax(dim=-1)
174
+ attn = self.attn_drop(attn)
175
+
176
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
177
+ x = self.proj(x)
178
+ x = self.proj_drop(x)
179
+
180
+ return x
181
+
182
+ class Block(nn.Module):
183
+
184
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
185
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
186
+ norm_layer=nn.LayerNorm, attn_head_dim=None
187
+ ):
188
+ super().__init__()
189
+
190
+ self.norm1 = norm_layer(dim)
191
+ self.attn = Attention(
192
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
193
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
194
+ )
195
+
196
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
197
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
198
+ self.norm2 = norm_layer(dim)
199
+ mlp_hidden_dim = int(dim * mlp_ratio)
200
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
201
+
202
+ def forward(self, x):
203
+ x = x + self.drop_path(self.attn(self.norm1(x)))
204
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
205
+ return x
206
+
207
+
208
+ class PatchEmbed(nn.Module):
209
+ """ Image to Patch Embedding
210
+ """
211
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
212
+ super().__init__()
213
+ img_size = to_2tuple(img_size)
214
+ patch_size = to_2tuple(patch_size)
215
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
216
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
217
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
218
+ self.img_size = img_size
219
+ self.patch_size = patch_size
220
+ self.num_patches = num_patches
221
+
222
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
223
+
224
+ def forward(self, x):
225
+ x = self.proj(x)
226
+ B, C, Hp, Wp = x.shape
227
+ x = x.view(B, C, Hp * Wp).transpose(1, 2)
228
+ return x, (Hp, Wp)
229
+
230
+
231
+ class HybridEmbed(nn.Module):
232
+ """ CNN Feature Map Embedding
233
+ Extract feature map from CNN, flatten, project to embedding dim.
234
+ """
235
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
236
+ super().__init__()
237
+ assert isinstance(backbone, nn.Module)
238
+ img_size = to_2tuple(img_size)
239
+ self.img_size = img_size
240
+ self.backbone = backbone
241
+ if feature_size is None:
242
+ with torch.no_grad():
243
+ training = backbone.training
244
+ if training:
245
+ backbone.eval()
246
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
247
+ feature_size = o.shape[-2:]
248
+ feature_dim = o.shape[1]
249
+ backbone.train(training)
250
+ else:
251
+ feature_size = to_2tuple(feature_size)
252
+ feature_dim = self.backbone.feature_info.channels()[-1]
253
+ self.num_patches = feature_size[0] * feature_size[1]
254
+ self.proj = nn.Linear(feature_dim, embed_dim)
255
+
256
+ def forward(self, x):
257
+ x = self.backbone(x)[-1]
258
+ x = x.flatten(2).transpose(1, 2)
259
+ x = self.proj(x)
260
+ return x
261
+
262
+
263
+ class ViT(nn.Module):
264
+ def __init__(self,
265
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
266
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
267
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
268
+ frozen_stages=-1, ratio=1, last_norm=True,
269
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
270
+ ):
271
+ super(ViT, self).__init__()
272
+ # Protect mutable default arguments
273
+
274
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
275
+ self.num_classes = num_classes
276
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
277
+ self.frozen_stages = frozen_stages
278
+ self.use_checkpoint = use_checkpoint
279
+ self.patch_padding = patch_padding
280
+ self.freeze_attn = freeze_attn
281
+ self.freeze_ffn = freeze_ffn
282
+ self.depth = depth
283
+
284
+ if hybrid_backbone is not None:
285
+ self.patch_embed = HybridEmbed(
286
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
287
+ else:
288
+ self.patch_embed = PatchEmbed(
289
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
290
+ num_patches = self.patch_embed.num_patches
291
+
292
+ # since the pretraining model has class token
293
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
294
+
295
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
296
+
297
+ self.blocks = nn.ModuleList([
298
+ Block(
299
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
300
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
301
+ )
302
+ for i in range(depth)])
303
+
304
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
305
+
306
+ if self.pos_embed is not None:
307
+ trunc_normal_(self.pos_embed, std=.02)
308
+
309
+ self._freeze_stages()
310
+
311
+ def _freeze_stages(self):
312
+ """Freeze parameters."""
313
+ if self.frozen_stages >= 0:
314
+ self.patch_embed.eval()
315
+ for param in self.patch_embed.parameters():
316
+ param.requires_grad = False
317
+
318
+ for i in range(1, self.frozen_stages + 1):
319
+ m = self.blocks[i]
320
+ m.eval()
321
+ for param in m.parameters():
322
+ param.requires_grad = False
323
+
324
+ if self.freeze_attn:
325
+ for i in range(0, self.depth):
326
+ m = self.blocks[i]
327
+ m.attn.eval()
328
+ m.norm1.eval()
329
+ for param in m.attn.parameters():
330
+ param.requires_grad = False
331
+ for param in m.norm1.parameters():
332
+ param.requires_grad = False
333
+
334
+ if self.freeze_ffn:
335
+ self.pos_embed.requires_grad = False
336
+ self.patch_embed.eval()
337
+ for param in self.patch_embed.parameters():
338
+ param.requires_grad = False
339
+ for i in range(0, self.depth):
340
+ m = self.blocks[i]
341
+ m.mlp.eval()
342
+ m.norm2.eval()
343
+ for param in m.mlp.parameters():
344
+ param.requires_grad = False
345
+ for param in m.norm2.parameters():
346
+ param.requires_grad = False
347
+
348
+ def init_weights(self, pretrained=None):
349
+ """Initialize the weights in backbone.
350
+ Args:
351
+ pretrained (str, optional): Path to pre-trained weights.
352
+ Defaults to None.
353
+ """
354
+ # NOTE: BaseBackbone is not imported here (see the commented import above), so there is no parent init_weights to call; pretrained checkpoints are loaded via load_state_dict elsewhere.
355
+
356
+ if pretrained is None:
357
+ def _init_weights(m):
358
+ if isinstance(m, nn.Linear):
359
+ trunc_normal_(m.weight, std=.02)
360
+ if isinstance(m, nn.Linear) and m.bias is not None:
361
+ nn.init.constant_(m.bias, 0)
362
+ elif isinstance(m, nn.LayerNorm):
363
+ nn.init.constant_(m.bias, 0)
364
+ nn.init.constant_(m.weight, 1.0)
365
+
366
+ self.apply(_init_weights)
367
+
368
+ def get_num_layers(self):
369
+ return len(self.blocks)
370
+
371
+ @torch.jit.ignore
372
+ def no_weight_decay(self):
373
+ return {'pos_embed', 'cls_token'}
374
+
375
+ def forward(self, x):
376
+ B, C, H, W = x.shape
377
+ x, (Hp, Wp) = self.patch_embed(x)
378
+
379
+ if self.pos_embed is not None:
380
+ # fit for multiple GPU training
381
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
382
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
383
+
384
+ for blk in self.blocks:
385
+ x = blk(x)
386
+
387
+ x = self.last_norm(x)
388
+ x = x.permute(0, 2, 1).view(B, -1, Hp, Wp).contiguous()
389
+ return x
390
+
391
+ def train(self, mode=True):
392
+ """Convert the model into training mode."""
393
+ super().train(mode)
394
+ self._freeze_stages()
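A quick shape check for the ViT backbone defined above (an illustrative sketch; the hyper-parameters mirror a ViT-B style configuration and the import path assumes this repo's package layout):

import torch
from easy_ViTPose.vit_models.backbone.vit import ViT

backbone = ViT(img_size=(256, 192), patch_size=16, embed_dim=768,
               depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 256, 192))
print(feats.shape)  # torch.Size([1, 768, 16, 12]): a 1/16-resolution feature map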
easy_ViTPose/vit_models/head/__init__.py ADDED
File without changes
easy_ViTPose/vit_models/head/topdown_heatmap_base_head.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from abc import ABCMeta, abstractmethod
3
+
4
+ import numpy as np
5
+ import torch.nn as nn
6
+
7
+ from .. import keypoints_from_heatmaps
8
+
9
+
10
+ class TopdownHeatmapBaseHead(nn.Module):
11
+ """Base class for top-down heatmap heads.
12
+
13
+ All top-down heatmap heads should subclass it.
14
+ All subclass should overwrite:
15
+
16
+ Methods:`get_loss`, supporting to calculate loss.
17
+ Methods:`get_accuracy`, supporting to calculate accuracy.
18
+ Methods:`forward`, supporting to forward model.
19
+ Methods:`inference_model`, supporting to inference model.
20
+ """
21
+
22
+ __metaclass__ = ABCMeta
23
+
24
+ @abstractmethod
25
+ def get_loss(self, **kwargs):
26
+ """Gets the loss."""
27
+
28
+ @abstractmethod
29
+ def get_accuracy(self, **kwargs):
30
+ """Gets the accuracy."""
31
+
32
+ @abstractmethod
33
+ def forward(self, **kwargs):
34
+ """Forward function."""
35
+
36
+ @abstractmethod
37
+ def inference_model(self, **kwargs):
38
+ """Inference function."""
39
+
40
+ def decode(self, img_metas, output, **kwargs):
41
+ """Decode keypoints from heatmaps.
42
+
43
+ Args:
44
+ img_metas (list(dict)): Information about data augmentation
45
+ By default this includes:
46
+
47
+ - "image_file: path to the image file
48
+ - "center": center of the bbox
49
+ - "scale": scale of the bbox
50
+ - "rotation": rotation of the bbox
51
+ - "bbox_score": score of bbox
52
+ output (np.ndarray[N, K, H, W]): model predicted heatmaps.
53
+ """
54
+ batch_size = len(img_metas)
55
+
56
+ if 'bbox_id' in img_metas[0]:
57
+ bbox_ids = []
58
+ else:
59
+ bbox_ids = None
60
+
61
+ c = np.zeros((batch_size, 2), dtype=np.float32)
62
+ s = np.zeros((batch_size, 2), dtype=np.float32)
63
+ image_paths = []
64
+ score = np.ones(batch_size)
65
+ for i in range(batch_size):
66
+ c[i, :] = img_metas[i]['center']
67
+ s[i, :] = img_metas[i]['scale']
68
+ image_paths.append(img_metas[i]['image_file'])
69
+
70
+ if 'bbox_score' in img_metas[i]:
71
+ score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1)
72
+ if bbox_ids is not None:
73
+ bbox_ids.append(img_metas[i]['bbox_id'])
74
+
75
+ preds, maxvals = keypoints_from_heatmaps(
76
+ output,
77
+ c,
78
+ s,
79
+ unbiased=self.test_cfg.get('unbiased_decoding', False),
80
+ post_process=self.test_cfg.get('post_process', 'default'),
81
+ kernel=self.test_cfg.get('modulate_kernel', 11),
82
+ valid_radius_factor=self.test_cfg.get('valid_radius_factor',
83
+ 0.0546875),
84
+ use_udp=self.test_cfg.get('use_udp', False),
85
+ target_type=self.test_cfg.get('target_type', 'GaussianHeatmap'))
86
+
87
+ all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32)
88
+ all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
89
+ all_preds[:, :, 0:2] = preds[:, :, 0:2]
90
+ all_preds[:, :, 2:3] = maxvals
91
+ all_boxes[:, 0:2] = c[:, 0:2]
92
+ all_boxes[:, 2:4] = s[:, 0:2]
93
+ all_boxes[:, 4] = np.prod(s * 200.0, axis=1)
94
+ all_boxes[:, 5] = score
95
+
96
+ result = {}
97
+
98
+ result['preds'] = all_preds
99
+ result['boxes'] = all_boxes
100
+ result['image_paths'] = image_paths
101
+ result['bbox_ids'] = bbox_ids
102
+
103
+ return result
104
+
105
+ @staticmethod
106
+ def _get_deconv_cfg(deconv_kernel):
107
+ """Get configurations for deconv layers."""
108
+ if deconv_kernel == 4:
109
+ padding = 1
110
+ output_padding = 0
111
+ elif deconv_kernel == 3:
112
+ padding = 1
113
+ output_padding = 1
114
+ elif deconv_kernel == 2:
115
+ padding = 0
116
+ output_padding = 0
117
+ else:
118
+ raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')
119
+
120
+ return deconv_kernel, padding, output_padding
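A sketch of how decode() above is used at test time, via the concrete subclass added in the next file; the center/scale values and image path are invented, scale is in the 200-pixel units implied by the `s * 200.0` line above, and the head configuration is only an example:

import numpy as np
from easy_ViTPose.vit_models.head.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead

head = TopdownHeatmapSimpleHead(in_channels=768, out_channels=17)

heatmaps = np.random.rand(1, 17, 64, 48).astype(np.float32)  # stand-in for model output
img_metas = [dict(image_file='person.jpg',
                  center=np.array([128., 96.], dtype=np.float32),
                  scale=np.array([1.28, 0.96], dtype=np.float32),
                  rotation=0,
                  bbox_score=1.0)]
result = head.decode(img_metas, heatmaps)
print(result['preds'].shape)  # (1, 17, 3): x, y and confidence for every keypoint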
easy_ViTPose/vit_models/head/topdown_heatmap_simple_head.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from .. import constant_init, normal_init
5
+
6
+ from .. import pose_pck_accuracy, flip_back, resize
7
+ import torch.nn.functional as F
8
+ from .topdown_heatmap_base_head import TopdownHeatmapBaseHead
9
+
10
+
11
+ class TopdownHeatmapSimpleHead(TopdownHeatmapBaseHead):
12
+ """Top-down heatmap simple head. paper ref: Bin Xiao et al. ``Simple
13
+ Baselines for Human Pose Estimation and Tracking``.
14
+
15
+ TopdownHeatmapSimpleHead consists of (>= 0) deconv layers
16
+ and a simple conv2d layer.
17
+
18
+ Args:
19
+ in_channels (int): Number of input channels
20
+ out_channels (int): Number of output channels
21
+ num_deconv_layers (int): Number of deconv layers.
22
+ num_deconv_layers should >= 0. Note that 0 means
23
+ no deconv layers.
24
+ num_deconv_filters (list|tuple): Number of filters.
25
+ If num_deconv_layers > 0, its length should equal num_deconv_layers.
26
+ num_deconv_kernels (list|tuple): Kernel sizes.
27
+ in_index (int|Sequence[int]): Input feature index. Default: 0
28
+ input_transform (str|None): Transformation type of input features.
29
+ Options: 'resize_concat', 'multiple_select', None.
30
+ Default: None.
31
+
32
+ - 'resize_concat': Multiple feature maps will be resized to the
33
+ same size as the first one and then concatenated together.
34
+ Usually used in FCN head of HRNet.
35
+ - 'multiple_select': Multiple feature maps will be bundled into
36
+ a list and passed into decode head.
37
+ - None: Only one select feature map is allowed.
38
+ align_corners (bool): align_corners argument of F.interpolate.
39
+ Default: False.
40
+ loss_keypoint (dict): Config for keypoint loss. Default: None.
41
+ """
42
+
43
+ def __init__(self,
44
+ in_channels,
45
+ out_channels,
46
+ num_deconv_layers=3,
47
+ num_deconv_filters=(256, 256, 256),
48
+ num_deconv_kernels=(4, 4, 4),
49
+ extra=None,
50
+ in_index=0,
51
+ input_transform=None,
52
+ align_corners=False,
53
+ loss_keypoint=None,
54
+ train_cfg=None,
55
+ test_cfg=None,
56
+ upsample=0,):
57
+ super().__init__()
58
+
59
+ self.in_channels = in_channels
60
+ self.loss = loss_keypoint
61
+ self.upsample = upsample
62
+
63
+ self.train_cfg = {} if train_cfg is None else train_cfg
64
+ self.test_cfg = {} if test_cfg is None else test_cfg
65
+ self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')
66
+
67
+ self._init_inputs(in_channels, in_index, input_transform)
68
+ self.in_index = in_index
69
+ self.align_corners = align_corners
70
+
71
+ if extra is not None and not isinstance(extra, dict):
72
+ raise TypeError('extra should be dict or None.')
73
+
74
+ if num_deconv_layers > 0:
75
+ self.deconv_layers = self._make_deconv_layer(
76
+ num_deconv_layers,
77
+ num_deconv_filters,
78
+ num_deconv_kernels,
79
+ )
80
+ elif num_deconv_layers == 0:
81
+ self.deconv_layers = nn.Identity()
82
+ else:
83
+ raise ValueError(
84
+ f'num_deconv_layers ({num_deconv_layers}) should >= 0.')
85
+
86
+ identity_final_layer = False
87
+ if extra is not None and 'final_conv_kernel' in extra:
88
+ assert extra['final_conv_kernel'] in [0, 1, 3]
89
+ if extra['final_conv_kernel'] == 3:
90
+ padding = 1
91
+ elif extra['final_conv_kernel'] == 1:
92
+ padding = 0
93
+ else:
94
+ # 0 for Identity mapping.
95
+ identity_final_layer = True
96
+ kernel_size = extra['final_conv_kernel']
97
+ else:
98
+ kernel_size = 1
99
+ padding = 0
100
+
101
+ if identity_final_layer:
102
+ self.final_layer = nn.Identity()
103
+ else:
104
+ conv_channels = num_deconv_filters[
105
+ -1] if num_deconv_layers > 0 else self.in_channels
106
+
107
+ layers = []
108
+ if extra is not None:
109
+ num_conv_layers = extra.get('num_conv_layers', 0)
110
+ num_conv_kernels = extra.get('num_conv_kernels',
111
+ [1] * num_conv_layers)
112
+
113
+ for i in range(num_conv_layers):
114
+ layers.append(
115
+ nn.Conv2d(in_channels=conv_channels,
116
+ out_channels=conv_channels,
117
+ kernel_size=num_conv_kernels[i],
118
+ stride=1,
119
+ padding=(num_conv_kernels[i] - 1) // 2)
120
+ )
121
+ layers.append(nn.BatchNorm2d(conv_channels))
122
+ layers.append(nn.ReLU(inplace=True))
123
+
124
+ layers.append(
125
+ nn.Conv2d(in_channels=conv_channels,
126
+ out_channels=out_channels,
127
+ kernel_size=kernel_size,
128
+ stride=1,
129
+ padding=padding)
130
+ )
131
+
132
+ if len(layers) > 1:
133
+ self.final_layer = nn.Sequential(*layers)
134
+ else:
135
+ self.final_layer = layers[0]
136
+
137
+ def get_loss(self, output, target, target_weight):
138
+ """Calculate top-down keypoint loss.
139
+
140
+ Note:
141
+ - batch_size: N
142
+ - num_keypoints: K
143
+ - heatmaps height: H
144
+ - heatmaps width: W
145
+
146
+ Args:
147
+ output (torch.Tensor[N,K,H,W]): Output heatmaps.
148
+ target (torch.Tensor[N,K,H,W]): Target heatmaps.
149
+ target_weight (torch.Tensor[N,K,1]):
150
+ Weights across different joint types.
151
+ """
152
+
153
+ losses = dict()
154
+
155
+ assert not isinstance(self.loss, nn.Sequential)
156
+ assert target.dim() == 4 and target_weight.dim() == 3
157
+ losses['heatmap_loss'] = self.loss(output, target, target_weight)
158
+
159
+ return losses
160
+
161
+ def get_accuracy(self, output, target, target_weight):
162
+ """Calculate accuracy for top-down keypoint loss.
163
+
164
+ Note:
165
+ - batch_size: N
166
+ - num_keypoints: K
167
+ - heatmaps height: H
168
+ - heatmaps width: W
169
+
170
+ Args:
171
+ output (torch.Tensor[N,K,H,W]): Output heatmaps.
172
+ target (torch.Tensor[N,K,H,W]): Target heatmaps.
173
+ target_weight (torch.Tensor[N,K,1]):
174
+ Weights across different joint types.
175
+ """
176
+
177
+ accuracy = dict()
178
+
179
+ if self.target_type == 'GaussianHeatmap':
180
+ _, avg_acc, _ = pose_pck_accuracy(
181
+ output.detach().cpu().numpy(),
182
+ target.detach().cpu().numpy(),
183
+ target_weight.detach().cpu().numpy().squeeze(-1) > 0)
184
+ accuracy['acc_pose'] = float(avg_acc)
185
+
186
+ return accuracy
187
+
188
+ def forward(self, x):
189
+ """Forward function."""
190
+ x = self._transform_inputs(x)
191
+ x = self.deconv_layers(x)
192
+ x = self.final_layer(x)
193
+ return x
194
+
195
+ def inference_model(self, x, flip_pairs=None):
196
+ """Inference function.
197
+
198
+ Returns:
199
+ output_heatmap (np.ndarray): Output heatmaps.
200
+
201
+ Args:
202
+ x (torch.Tensor[N,K,H,W]): Input features.
203
+ flip_pairs (None | list[tuple]):
204
+ Pairs of keypoints which are mirrored.
205
+ """
206
+ output = self.forward(x)
207
+
208
+ if flip_pairs is not None:
209
+ output_heatmap = flip_back(
210
+ output.detach().cpu().numpy(),
211
+ flip_pairs,
212
+ target_type=self.target_type)
213
+ # feature is not aligned, shift flipped heatmap for higher accuracy
214
+ if self.test_cfg.get('shift_heatmap', False):
215
+ output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
216
+ else:
217
+ output_heatmap = output.detach().cpu().numpy()
218
+ return output_heatmap
219
+
220
+ def _init_inputs(self, in_channels, in_index, input_transform):
221
+ """Check and initialize input transforms.
222
+
223
+ The in_channels, in_index and input_transform must match.
224
+ Specifically, when input_transform is None, only single feature map
225
+ will be selected. So in_channels and in_index must be of type int.
226
+ When input_transform is not None, in_channels and in_index must be
227
+ list or tuple, with the same length.
228
+
229
+ Args:
230
+ in_channels (int|Sequence[int]): Input channels.
231
+ in_index (int|Sequence[int]): Input feature index.
232
+ input_transform (str|None): Transformation type of input features.
233
+ Options: 'resize_concat', 'multiple_select', None.
234
+
235
+ - 'resize_concat': Multiple feature maps will be resize to the
236
+ same size as first one and than concat together.
237
+ Usually used in FCN head of HRNet.
238
+ - 'multiple_select': Multiple feature maps will be bundled into
239
+ a list and passed into decode head.
240
+ - None: Only one select feature map is allowed.
241
+ """
242
+
243
+ if input_transform is not None:
244
+ assert input_transform in ['resize_concat', 'multiple_select']
245
+ self.input_transform = input_transform
246
+ self.in_index = in_index
247
+ if input_transform is not None:
248
+ assert isinstance(in_channels, (list, tuple))
249
+ assert isinstance(in_index, (list, tuple))
250
+ assert len(in_channels) == len(in_index)
251
+ if input_transform == 'resize_concat':
252
+ self.in_channels = sum(in_channels)
253
+ else:
254
+ self.in_channels = in_channels
255
+ else:
256
+ assert isinstance(in_channels, int)
257
+ assert isinstance(in_index, int)
258
+ self.in_channels = in_channels
259
+
260
+ def _transform_inputs(self, inputs):
261
+ """Transform inputs for decoder.
262
+
263
+ Args:
264
+ inputs (list[Tensor] | Tensor): multi-level img features.
265
+
266
+ Returns:
267
+ Tensor: The transformed inputs
268
+ """
269
+ if not isinstance(inputs, list):
270
+ if self.upsample > 0:
271
+ raise NotImplementedError
272
+ return inputs
273
+
274
+ if self.input_transform == 'resize_concat':
275
+ inputs = [inputs[i] for i in self.in_index]
276
+ upsampled_inputs = [
277
+ resize(
278
+ input=x,
279
+ size=inputs[0].shape[2:],
280
+ mode='bilinear',
281
+ align_corners=self.align_corners) for x in inputs
282
+ ]
283
+ inputs = torch.cat(upsampled_inputs, dim=1)
284
+ elif self.input_transform == 'multiple_select':
285
+ inputs = [inputs[i] for i in self.in_index]
286
+ else:
287
+ inputs = inputs[self.in_index]
288
+
289
+ return inputs
290
+
291
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
292
+ """Make deconv layers."""
293
+ if num_layers != len(num_filters):
294
+ error_msg = f'num_layers({num_layers}) ' \
295
+ f'!= length of num_filters({len(num_filters)})'
296
+ raise ValueError(error_msg)
297
+ if num_layers != len(num_kernels):
298
+ error_msg = f'num_layers({num_layers}) ' \
299
+ f'!= length of num_kernels({len(num_kernels)})'
300
+ raise ValueError(error_msg)
301
+
302
+ layers = []
303
+ for i in range(num_layers):
304
+ kernel, padding, output_padding = \
305
+ self._get_deconv_cfg(num_kernels[i])
306
+
307
+ planes = num_filters[i]
308
+ layers.append(
309
+ nn.ConvTranspose2d(in_channels=self.in_channels,
310
+ out_channels=planes,
311
+ kernel_size=kernel,
312
+ stride=2,
313
+ padding=padding,
314
+ output_padding=output_padding,
315
+ bias=False)
316
+ )
317
+ layers.append(nn.BatchNorm2d(planes))
318
+ layers.append(nn.ReLU(inplace=True))
319
+ self.in_channels = planes
320
+
321
+ return nn.Sequential(*layers)
322
+
323
+ def init_weights(self):
324
+ """Initialize model weights."""
325
+ for _, m in self.deconv_layers.named_modules():
326
+ if isinstance(m, nn.ConvTranspose2d):
327
+ normal_init(m, std=0.001)
328
+ elif isinstance(m, nn.BatchNorm2d):
329
+ constant_init(m, 1)
330
+ for m in self.final_layer.modules():
331
+ if isinstance(m, nn.Conv2d):
332
+ normal_init(m, std=0.001, bias=0)
333
+ elif isinstance(m, nn.BatchNorm2d):
334
+ constant_init(m, 1)
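A shape check for the head above, fed with the 1/16-resolution backbone features; this is an illustrative sketch and the deconv configuration is an assumption, not necessarily the one used by this repo's configs:

import torch
from easy_ViTPose.vit_models.head.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead

head = TopdownHeatmapSimpleHead(in_channels=768, out_channels=17,
                                num_deconv_layers=2,
                                num_deconv_filters=(256, 256),
                                num_deconv_kernels=(4, 4),
                                extra=dict(final_conv_kernel=1))
head.init_weights()

heatmaps = head(torch.randn(1, 768, 16, 12))  # features shaped like the ViT backbone output
print(heatmaps.shape)  # torch.Size([1, 17, 64, 48]): one heatmap per keypoint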
easy_ViTPose/vit_models/losses/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .classfication_loss import BCELoss
3
+ from .heatmap_loss import AdaptiveWingLoss
4
+ from .mesh_loss import GANLoss, MeshLoss
5
+ from .mse_loss import JointsMSELoss, JointsOHKMMSELoss
6
+ from .multi_loss_factory import AELoss, HeatmapLoss, MultiLossFactory
7
+ from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, MSELoss,
8
+ SemiSupervisionLoss, SmoothL1Loss, SoftWingLoss,
9
+ WingLoss)
10
+
11
+ __all__ = [
12
+ 'JointsMSELoss', 'JointsOHKMMSELoss', 'HeatmapLoss', 'AELoss',
13
+ 'MultiLossFactory', 'MeshLoss', 'GANLoss', 'SmoothL1Loss', 'WingLoss',
14
+ 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss',
15
+ 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss'
16
+ ]
easy_ViTPose/vit_models/losses/classfication_loss.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ __all__ = ['BCELoss']
7
+
8
+
9
+ class BCELoss(nn.Module):
10
+ """Binary Cross Entropy loss."""
11
+
12
+ def __init__(self, use_target_weight=False, loss_weight=1.):
13
+ super().__init__()
14
+ self.criterion = F.binary_cross_entropy
15
+ self.use_target_weight = use_target_weight
16
+ self.loss_weight = loss_weight
17
+
18
+ def forward(self, output, target, target_weight=None):
19
+ """Forward function.
20
+
21
+ Note:
22
+ - batch_size: N
23
+ - num_labels: K
24
+
25
+ Args:
26
+ output (torch.Tensor[N, K]): Output classification.
27
+ target (torch.Tensor[N, K]): Target classification.
28
+ target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
29
+ Weights across different labels.
30
+ """
31
+
32
+ if self.use_target_weight:
33
+ assert target_weight is not None
34
+ loss = self.criterion(output, target, reduction='none')
35
+ if target_weight.dim() == 1:
36
+ target_weight = target_weight[:, None]
37
+ loss = (loss * target_weight).mean()
38
+ else:
39
+ loss = self.criterion(output, target)
40
+
41
+ return loss * self.loss_weight
easy_ViTPose/vit_models/losses/heatmap_loss.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class AdaptiveWingLoss(nn.Module):
7
+ """Adaptive wing loss. paper ref: 'Adaptive Wing Loss for Robust Face
8
+ Alignment via Heatmap Regression' Wang et al. ICCV'2019.
9
+
10
+ Args:
11
+ alpha (float), omega (float), epsilon (float), theta (float)
12
+ are hyper-parameters.
13
+ use_target_weight (bool): Option to use weighted MSE loss.
14
+ Different joint types may have different target weights.
15
+ loss_weight (float): Weight of the loss. Default: 1.0.
16
+ """
17
+
18
+ def __init__(self,
19
+ alpha=2.1,
20
+ omega=14,
21
+ epsilon=1,
22
+ theta=0.5,
23
+ use_target_weight=False,
24
+ loss_weight=1.):
25
+ super().__init__()
26
+ self.alpha = float(alpha)
27
+ self.omega = float(omega)
28
+ self.epsilon = float(epsilon)
29
+ self.theta = float(theta)
30
+ self.use_target_weight = use_target_weight
31
+ self.loss_weight = loss_weight
32
+
33
+ def criterion(self, pred, target):
34
+ """Criterion of wingloss.
35
+
36
+ Note:
37
+ batch_size: N
38
+ num_keypoints: K
39
+
40
+ Args:
41
+ pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
42
+ target (torch.Tensor[NxKxHxW]): Target heatmaps.
43
+ """
44
+ H, W = pred.shape[2:4]
45
+ delta = (target - pred).abs()
46
+
47
+ A = self.omega * (
48
+ 1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
49
+ ) * (self.alpha - target) * (torch.pow(
50
+ self.theta / self.epsilon,
51
+ self.alpha - target - 1)) * (1 / self.epsilon)
52
+ C = self.theta * A - self.omega * torch.log(
53
+ 1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
54
+
55
+ losses = torch.where(
56
+ delta < self.theta,
57
+ self.omega *
58
+ torch.log(1 +
59
+ torch.pow(delta / self.epsilon, self.alpha - target)),
60
+ A * delta - C)
61
+
62
+ return torch.mean(losses)
63
+
64
+ def forward(self, output, target, target_weight):
65
+ """Forward function.
66
+
67
+ Note:
68
+ batch_size: N
69
+ num_keypoints: K
70
+
71
+ Args:
72
+ output (torch.Tensor[NxKxHxW]): Output heatmaps.
73
+ target (torch.Tensor[NxKxHxW]): Target heatmaps.
74
+ target_weight (torch.Tensor[NxKx1]):
75
+ Weights across different joint types.
76
+ """
77
+ if self.use_target_weight:
78
+ loss = self.criterion(output * target_weight.unsqueeze(-1),
79
+ target * target_weight.unsqueeze(-1))
80
+ else:
81
+ loss = self.criterion(output, target)
82
+
83
+ return loss * self.loss_weight
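A minimal sketch of how this heatmap loss is typically invoked (illustrative only; shapes follow the NxKxHxW convention of the docstrings above, and the import path assumes the repo root is importable):

import torch
from easy_ViTPose.vit_models.losses.heatmap_loss import AdaptiveWingLoss

criterion = AdaptiveWingLoss(use_target_weight=True)
pred = torch.rand(2, 17, 64, 48)       # N=2 images, K=17 predicted heatmaps
gt = torch.rand(2, 17, 64, 48)         # target heatmaps
target_weight = torch.ones(2, 17, 1)   # per-joint visibility weights, broadcast over H and W
loss = criterion(pred, gt, target_weight)
print(loss.item())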
easy_ViTPose/vit_models/losses/mesh_loss.py ADDED
@@ -0,0 +1,402 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['MeshLoss', 'GANLoss']
6
+
7
+ def rot6d_to_rotmat(x):
8
+ """Convert 6D rotation representation to 3x3 rotation matrix.
9
+
10
+ Based on Zhou et al., "On the Continuity of Rotation
11
+ Representations in Neural Networks", CVPR 2019
12
+ Input:
13
+ (B,6) Batch of 6-D rotation representations
14
+ Output:
15
+ (B,3,3) Batch of corresponding rotation matrices
16
+ """
17
+ x = x.view(-1, 3, 2)
18
+ a1 = x[:, :, 0]
19
+ a2 = x[:, :, 1]
20
+ b1 = F.normalize(a1)
21
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
22
+ b3 = torch.cross(b1, b2)
23
+ return torch.stack((b1, b2, b3), dim=-1)
24
+
25
+
26
+ def batch_rodrigues(theta):
27
+ """Convert axis-angle representation to rotation matrix.
28
+ Args:
29
+ theta: size = [B, 3]
30
+ Returns:
31
+ Rotation matrix corresponding to the quaternion
32
+ -- size = [B, 3, 3]
33
+ """
34
+ l2norm = torch.norm(theta + 1e-8, p=2, dim=1)
35
+ angle = torch.unsqueeze(l2norm, -1)
36
+ normalized = torch.div(theta, angle)
37
+ angle = angle * 0.5
38
+ v_cos = torch.cos(angle)
39
+ v_sin = torch.sin(angle)
40
+ quat = torch.cat([v_cos, v_sin * normalized], dim=1)
41
+ return quat_to_rotmat(quat)
42
+
43
+
44
+ def quat_to_rotmat(quat):
45
+ """Convert quaternion coefficients to rotation matrix.
46
+ Args:
47
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
48
+ Returns:
49
+ Rotation matrix corresponding to the quaternion
50
+ -- size = [B, 3, 3]
51
+ """
52
+ norm_quat = quat
53
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
54
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1],\
55
+ norm_quat[:, 2], norm_quat[:, 3]
56
+
57
+ B = quat.size(0)
58
+
59
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
60
+ wx, wy, wz = w * x, w * y, w * z
61
+ xy, xz, yz = x * y, x * z, y * z
62
+
63
+ rotMat = torch.stack([
64
+ w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
65
+ w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
66
+ w2 - x2 - y2 + z2
67
+ ],
68
+ dim=1).view(B, 3, 3)
69
+ return rotMat
70
+
71
+
72
+
73
+ def perspective_projection(points, rotation, translation, focal_length,
74
+ camera_center):
75
+ """This function computes the perspective projection of a set of 3D points.
76
+
77
+ Note:
78
+ - batch size: B
79
+ - point number: N
80
+
81
+ Args:
82
+ points (Tensor([B, N, 3])): A set of 3D points
83
+ rotation (Tensor([B, 3, 3])): Camera rotation matrix
84
+ translation (Tensor([B, 3])): Camera translation
85
+ focal_length (Tensor([B,])): Focal length
86
+ camera_center (Tensor([B, 2])): Camera center
87
+
88
+ Returns:
89
+ projected_points (Tensor([B, N, 2])): Projected 2D
90
+ points in image space.
91
+ """
92
+
93
+ batch_size = points.shape[0]
94
+ K = torch.zeros([batch_size, 3, 3], device=points.device)
95
+ K[:, 0, 0] = focal_length
96
+ K[:, 1, 1] = focal_length
97
+ K[:, 2, 2] = 1.
98
+ K[:, :-1, -1] = camera_center
99
+
100
+ # Transform points
101
+ points = torch.einsum('bij,bkj->bki', rotation, points)
102
+ points = points + translation.unsqueeze(1)
103
+
104
+ # Apply perspective distortion
105
+ projected_points = points / points[:, :, -1].unsqueeze(-1)
106
+
107
+ # Apply camera intrinsics
108
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
109
+ projected_points = projected_points[:, :, :-1]
110
+ return projected_points
111
+
112
+
113
+ class MeshLoss(nn.Module):
114
+ """Mix loss for 3D human mesh. It is composed of loss on 2D joints, 3D
115
+ joints, mesh vertices and smpl parameters (if any).
116
+
117
+ Args:
118
+ joints_2d_loss_weight (float): Weight for loss on 2D joints.
119
+ joints_3d_loss_weight (float): Weight for loss on 3D joints.
120
+ vertex_loss_weight (float): Weight for loss on 3D vertices.
121
+ smpl_pose_loss_weight (float): Weight for loss on SMPL
122
+ pose parameters.
123
+ smpl_beta_loss_weight (float): Weight for loss on SMPL
124
+ shape parameters.
125
+ img_res (int): Input image resolution.
126
+ focal_length (float): Focal length of camera model. Default=5000.
127
+ """
128
+
129
+ def __init__(self,
130
+ joints_2d_loss_weight,
131
+ joints_3d_loss_weight,
132
+ vertex_loss_weight,
133
+ smpl_pose_loss_weight,
134
+ smpl_beta_loss_weight,
135
+ img_res,
136
+ focal_length=5000):
137
+
138
+ super().__init__()
139
+ # Per-vertex loss on the mesh
140
+ self.criterion_vertex = nn.L1Loss(reduction='none')
141
+
142
+ # Joints (2D and 3D) loss
143
+ self.criterion_joints_2d = nn.SmoothL1Loss(reduction='none')
144
+ self.criterion_joints_3d = nn.SmoothL1Loss(reduction='none')
145
+
146
+ # Loss for SMPL parameter regression
147
+ self.criterion_regr = nn.MSELoss(reduction='none')
148
+
149
+ self.joints_2d_loss_weight = joints_2d_loss_weight
150
+ self.joints_3d_loss_weight = joints_3d_loss_weight
151
+ self.vertex_loss_weight = vertex_loss_weight
152
+ self.smpl_pose_loss_weight = smpl_pose_loss_weight
153
+ self.smpl_beta_loss_weight = smpl_beta_loss_weight
154
+ self.focal_length = focal_length
155
+ self.img_res = img_res
156
+
157
+ def joints_2d_loss(self, pred_joints_2d, gt_joints_2d, joints_2d_visible):
158
+ """Compute 2D reprojection loss on the joints.
159
+
160
+ The loss is weighted by joints_2d_visible.
161
+ """
162
+ conf = joints_2d_visible.float()
163
+ loss = (conf *
164
+ self.criterion_joints_2d(pred_joints_2d, gt_joints_2d)).mean()
165
+ return loss
166
+
167
+ def joints_3d_loss(self, pred_joints_3d, gt_joints_3d, joints_3d_visible):
168
+ """Compute 3D joints loss for the examples that 3D joint annotations
169
+ are available.
170
+
171
+ The loss is weighted by joints_3d_visible.
172
+ """
173
+ conf = joints_3d_visible.float()
174
+ if len(gt_joints_3d) > 0:
175
+ gt_pelvis = (gt_joints_3d[:, 2, :] + gt_joints_3d[:, 3, :]) / 2
176
+ gt_joints_3d = gt_joints_3d - gt_pelvis[:, None, :]
177
+ pred_pelvis = (pred_joints_3d[:, 2, :] +
178
+ pred_joints_3d[:, 3, :]) / 2
179
+ pred_joints_3d = pred_joints_3d - pred_pelvis[:, None, :]
180
+ return (
181
+ conf *
182
+ self.criterion_joints_3d(pred_joints_3d, gt_joints_3d)).mean()
183
+ return pred_joints_3d.sum() * 0
184
+
185
+ def vertex_loss(self, pred_vertices, gt_vertices, has_smpl):
186
+ """Compute 3D vertex loss for the examples that 3D human mesh
187
+ annotations are available.
188
+
189
+ The loss is weighted by the has_smpl.
190
+ """
191
+ conf = has_smpl.float()
192
+ loss_vertex = self.criterion_vertex(pred_vertices, gt_vertices)
193
+ loss_vertex = (conf[:, None, None] * loss_vertex).mean()
194
+ return loss_vertex
195
+
196
+ def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
197
+ has_smpl):
198
+ """Compute SMPL parameters loss for the examples that SMPL parameter
199
+ annotations are available.
200
+
201
+ The loss is weighted by has_smpl.
202
+ """
203
+ conf = has_smpl.float()
204
+ gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)
205
+ loss_regr_pose = self.criterion_regr(pred_rotmat, gt_rotmat)
206
+ loss_regr_betas = self.criterion_regr(pred_betas, gt_betas)
207
+ loss_regr_pose = (conf[:, None, None, None] * loss_regr_pose).mean()
208
+ loss_regr_betas = (conf[:, None] * loss_regr_betas).mean()
209
+ return loss_regr_pose, loss_regr_betas
210
+
211
+ def project_points(self, points_3d, camera):
212
+ """Perform orthographic projection of 3D points using the camera
213
+ parameters, return projected 2D points in image plane.
214
+
215
+ Note:
216
+ - batch size: B
217
+ - point number: N
218
+
219
+ Args:
220
+ points_3d (Tensor([B, N, 3])): 3D points.
221
+ camera (Tensor([B, 3])): camera parameters with the
222
+ 3 channel as (scale, translation_x, translation_y)
223
+
224
+ Returns:
225
+ Tensor([B, N, 2]): projected 2D points \
226
+ in image space.
227
+ """
228
+ batch_size = points_3d.shape[0]
229
+ device = points_3d.device
230
+ cam_t = torch.stack([
231
+ camera[:, 1], camera[:, 2], 2 * self.focal_length /
232
+ (self.img_res * camera[:, 0] + 1e-9)
233
+ ],
234
+ dim=-1)
235
+ camera_center = camera.new_zeros([batch_size, 2])
236
+ rot_t = torch.eye(
237
+ 3, device=device,
238
+ dtype=points_3d.dtype).unsqueeze(0).expand(batch_size, -1, -1)
239
+ joints_2d = perspective_projection(
240
+ points_3d,
241
+ rotation=rot_t,
242
+ translation=cam_t,
243
+ focal_length=self.focal_length,
244
+ camera_center=camera_center)
245
+ return joints_2d
246
+
247
+ def forward(self, output, target):
248
+ """Forward function.
249
+
250
+ Args:
251
+ output (dict): dict of network predicted results.
252
+ Keys: 'vertices', 'joints_3d', 'camera',
253
+ 'pose'(optional), 'beta'(optional)
254
+ target (dict): dict of ground-truth labels.
255
+ Keys: 'vertices', 'joints_3d', 'joints_3d_visible',
256
+ 'joints_2d', 'joints_2d_visible', 'pose', 'beta',
257
+ 'has_smpl'
258
+
259
+ Returns:
260
+ dict: dict of losses.
261
+ """
262
+ losses = {}
263
+
264
+ # Per-vertex loss for the shape
265
+ pred_vertices = output['vertices']
266
+
267
+ gt_vertices = target['vertices']
268
+ has_smpl = target['has_smpl']
269
+ loss_vertex = self.vertex_loss(pred_vertices, gt_vertices, has_smpl)
270
+ losses['vertex_loss'] = loss_vertex * self.vertex_loss_weight
271
+
272
+ # Compute loss on SMPL parameters, if available
273
+ if 'pose' in output.keys() and 'beta' in output.keys():
274
+ pred_rotmat = output['pose']
275
+ pred_betas = output['beta']
276
+ gt_pose = target['pose']
277
+ gt_betas = target['beta']
278
+ loss_regr_pose, loss_regr_betas = self.smpl_losses(
279
+ pred_rotmat, pred_betas, gt_pose, gt_betas, has_smpl)
280
+ losses['smpl_pose_loss'] = \
281
+ loss_regr_pose * self.smpl_pose_loss_weight
282
+ losses['smpl_beta_loss'] = \
283
+ loss_regr_betas * self.smpl_beta_loss_weight
284
+
285
+ # Compute 3D joints loss
286
+ pred_joints_3d = output['joints_3d']
287
+ gt_joints_3d = target['joints_3d']
288
+ joints_3d_visible = target['joints_3d_visible']
289
+ loss_joints_3d = self.joints_3d_loss(pred_joints_3d, gt_joints_3d,
290
+ joints_3d_visible)
291
+ losses['joints_3d_loss'] = loss_joints_3d * self.joints_3d_loss_weight
292
+
293
+ # Compute 2D reprojection loss for the 2D joints
294
+ pred_camera = output['camera']
295
+ gt_joints_2d = target['joints_2d']
296
+ joints_2d_visible = target['joints_2d_visible']
297
+ pred_joints_2d = self.project_points(pred_joints_3d, pred_camera)
298
+
299
+ # Normalize keypoints to [-1,1]
300
+ # The coordinate origin of pred_joints_2d is
301
+ # the center of the input image.
302
+ pred_joints_2d = 2 * pred_joints_2d / (self.img_res - 1)
303
+ # The coordinate origin of gt_joints_2d is
304
+ # the top left corner of the input image.
305
+ gt_joints_2d = 2 * gt_joints_2d / (self.img_res - 1) - 1
306
+ loss_joints_2d = self.joints_2d_loss(pred_joints_2d, gt_joints_2d,
307
+ joints_2d_visible)
308
+ losses['joints_2d_loss'] = loss_joints_2d * self.joints_2d_loss_weight
309
+
310
+ return losses
311
+
312
+
313
+ class GANLoss(nn.Module):
314
+ """Define GAN loss.
315
+
316
+ Args:
317
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
318
+ real_label_val (float): The value for real label. Default: 1.0.
319
+ fake_label_val (float): The value for fake label. Default: 0.0.
320
+ loss_weight (float): Loss weight. Default: 1.0.
321
+ Note that loss_weight is only for generators; and it is always 1.0
322
+ for discriminators.
323
+ """
324
+
325
+ def __init__(self,
326
+ gan_type,
327
+ real_label_val=1.0,
328
+ fake_label_val=0.0,
329
+ loss_weight=1.0):
330
+ super().__init__()
331
+ self.gan_type = gan_type
332
+ self.loss_weight = loss_weight
333
+ self.real_label_val = real_label_val
334
+ self.fake_label_val = fake_label_val
335
+
336
+ if self.gan_type == 'vanilla':
337
+ self.loss = nn.BCEWithLogitsLoss()
338
+ elif self.gan_type == 'lsgan':
339
+ self.loss = nn.MSELoss()
340
+ elif self.gan_type == 'wgan':
341
+ self.loss = self._wgan_loss
342
+ elif self.gan_type == 'hinge':
343
+ self.loss = nn.ReLU()
344
+ else:
345
+ raise NotImplementedError(
346
+ f'GAN type {self.gan_type} is not implemented.')
347
+
348
+ @staticmethod
349
+ def _wgan_loss(input, target):
350
+ """wgan loss.
351
+
352
+ Args:
353
+ input (Tensor): Input tensor.
354
+ target (bool): Target label.
355
+
356
+ Returns:
357
+ Tensor: wgan loss.
358
+ """
359
+ return -input.mean() if target else input.mean()
360
+
361
+ def get_target_label(self, input, target_is_real):
362
+ """Get target label.
363
+
364
+ Args:
365
+ input (Tensor): Input tensor.
366
+ target_is_real (bool): Whether the target is real or fake.
367
+
368
+ Returns:
369
+ (bool | Tensor): Target tensor. Return bool for wgan, \
370
+ otherwise, return Tensor.
371
+ """
372
+
373
+ if self.gan_type == 'wgan':
374
+ return target_is_real
375
+ target_val = (
376
+ self.real_label_val if target_is_real else self.fake_label_val)
377
+ return input.new_ones(input.size()) * target_val
378
+
379
+ def forward(self, input, target_is_real, is_disc=False):
380
+ """
381
+ Args:
382
+ input (Tensor): The input for the loss module, i.e., the network
383
+ prediction.
384
+ target_is_real (bool): Whether the target is real or fake.
385
+ is_disc (bool): Whether the loss for discriminators or not.
386
+ Default: False.
387
+
388
+ Returns:
389
+ Tensor: GAN loss value.
390
+ """
391
+ target_label = self.get_target_label(input, target_is_real)
392
+ if self.gan_type == 'hinge':
393
+ if is_disc: # for discriminators in hinge-gan
394
+ input = -input if target_is_real else input
395
+ loss = self.loss(1 + input).mean()
396
+ else: # for generators in hinge-gan
397
+ loss = -input.mean()
398
+ else: # other gan types
399
+ loss = self.loss(input, target_label)
400
+
401
+ # loss_weight is always 1.0 for discriminators
402
+ return loss if is_disc else loss * self.loss_weight
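For orientation, a small usage sketch of GANLoss (illustrative, not part of the commit; the 'lsgan' type, the loss weight, and the tensor shape are arbitrary choices):

import torch
from easy_ViTPose.vit_models.losses.mesh_loss import GANLoss

adversarial = GANLoss(gan_type='lsgan', loss_weight=0.5)
disc_scores = torch.randn(8, 1)                                          # raw discriminator outputs
g_loss = adversarial(disc_scores, target_is_real=True)                   # generator loss, scaled by loss_weight
d_loss = adversarial(disc_scores, target_is_real=False, is_disc=True)    # discriminator loss, weight fixed at 1.0
print(g_loss.item(), d_loss.item())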
easy_ViTPose/vit_models/losses/mse_loss.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ __all__ = ['JointsMSELoss', 'JointsOHKMMSELoss',]
7
+
8
+
9
+ class JointsMSELoss(nn.Module):
10
+ """MSE loss for heatmaps.
11
+
12
+ Args:
13
+ use_target_weight (bool): Option to use weighted MSE loss.
14
+ Different joint types may have different target weights.
15
+ loss_weight (float): Weight of the loss. Default: 1.0.
16
+ """
17
+
18
+ def __init__(self, use_target_weight=False, loss_weight=1.):
19
+ super().__init__()
20
+ self.criterion = nn.MSELoss()
21
+ self.use_target_weight = use_target_weight
22
+ self.loss_weight = loss_weight
23
+
24
+ def forward(self, output, target, target_weight):
25
+ """Forward function."""
26
+ batch_size = output.size(0)
27
+ num_joints = output.size(1)
28
+
29
+ heatmaps_pred = output.reshape(
30
+ (batch_size, num_joints, -1)).split(1, 1)
31
+ heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
32
+
33
+ loss = 0.
34
+
35
+ for idx in range(num_joints):
36
+ heatmap_pred = heatmaps_pred[idx].squeeze(1)
37
+ heatmap_gt = heatmaps_gt[idx].squeeze(1)
38
+ if self.use_target_weight:
39
+ loss += self.criterion(heatmap_pred * target_weight[:, idx],
40
+ heatmap_gt * target_weight[:, idx])
41
+ else:
42
+ loss += self.criterion(heatmap_pred, heatmap_gt)
43
+
44
+ return loss / num_joints * self.loss_weight
45
+
46
+
47
+ class CombinedTargetMSELoss(nn.Module):
48
+ """MSE loss for combined target.
49
+ CombinedTarget: The combination of classification target
50
+ (response map) and regression target (offset map).
51
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
52
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
53
+
54
+ Args:
55
+ use_target_weight (bool): Option to use weighted MSE loss.
56
+ Different joint types may have different target weights.
57
+ loss_weight (float): Weight of the loss. Default: 1.0.
58
+ """
59
+
60
+ def __init__(self, use_target_weight, loss_weight=1.):
61
+ super().__init__()
62
+ self.criterion = nn.MSELoss(reduction='mean')
63
+ self.use_target_weight = use_target_weight
64
+ self.loss_weight = loss_weight
65
+
66
+ def forward(self, output, target, target_weight):
67
+ batch_size = output.size(0)
68
+ num_channels = output.size(1)
69
+ heatmaps_pred = output.reshape(
70
+ (batch_size, num_channels, -1)).split(1, 1)
71
+ heatmaps_gt = target.reshape(
72
+ (batch_size, num_channels, -1)).split(1, 1)
73
+ loss = 0.
74
+ num_joints = num_channels // 3
75
+ for idx in range(num_joints):
76
+ heatmap_pred = heatmaps_pred[idx * 3].squeeze()
77
+ heatmap_gt = heatmaps_gt[idx * 3].squeeze()
78
+ offset_x_pred = heatmaps_pred[idx * 3 + 1].squeeze()
79
+ offset_x_gt = heatmaps_gt[idx * 3 + 1].squeeze()
80
+ offset_y_pred = heatmaps_pred[idx * 3 + 2].squeeze()
81
+ offset_y_gt = heatmaps_gt[idx * 3 + 2].squeeze()
82
+ if self.use_target_weight:
83
+ heatmap_pred = heatmap_pred * target_weight[:, idx]
84
+ heatmap_gt = heatmap_gt * target_weight[:, idx]
85
+ # classification loss
86
+ loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
87
+ # regression loss
88
+ loss += 0.5 * self.criterion(heatmap_gt * offset_x_pred,
89
+ heatmap_gt * offset_x_gt)
90
+ loss += 0.5 * self.criterion(heatmap_gt * offset_y_pred,
91
+ heatmap_gt * offset_y_gt)
92
+ return loss / num_joints * self.loss_weight
93
+
94
+
95
+ class JointsOHKMMSELoss(nn.Module):
96
+ """MSE loss with online hard keypoint mining.
97
+
98
+ Args:
99
+ use_target_weight (bool): Option to use weighted MSE loss.
100
+ Different joint types may have different target weights.
101
+ topk (int): Only top k joint losses are kept.
102
+ loss_weight (float): Weight of the loss. Default: 1.0.
103
+ """
104
+
105
+ def __init__(self, use_target_weight=False, topk=8, loss_weight=1.):
106
+ super().__init__()
107
+ assert topk > 0
108
+ self.criterion = nn.MSELoss(reduction='none')
109
+ self.use_target_weight = use_target_weight
110
+ self.topk = topk
111
+ self.loss_weight = loss_weight
112
+
113
+ def _ohkm(self, loss):
114
+ """Online hard keypoint mining."""
115
+ ohkm_loss = 0.
116
+ N = len(loss)
117
+ for i in range(N):
118
+ sub_loss = loss[i]
119
+ _, topk_idx = torch.topk(
120
+ sub_loss, k=self.topk, dim=0, sorted=False)
121
+ tmp_loss = torch.gather(sub_loss, 0, topk_idx)
122
+ ohkm_loss += torch.sum(tmp_loss) / self.topk
123
+ ohkm_loss /= N
124
+ return ohkm_loss
125
+
126
+ def forward(self, output, target, target_weight):
127
+ """Forward function."""
128
+ batch_size = output.size(0)
129
+ num_joints = output.size(1)
130
+ if num_joints < self.topk:
131
+ raise ValueError(f'topk ({self.topk}) should not '
132
+ f'larger than num_joints ({num_joints}).')
133
+ heatmaps_pred = output.reshape(
134
+ (batch_size, num_joints, -1)).split(1, 1)
135
+ heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
136
+
137
+ losses = []
138
+ for idx in range(num_joints):
139
+ heatmap_pred = heatmaps_pred[idx].squeeze(1)
140
+ heatmap_gt = heatmaps_gt[idx].squeeze(1)
141
+ if self.use_target_weight:
142
+ losses.append(
143
+ self.criterion(heatmap_pred * target_weight[:, idx],
144
+ heatmap_gt * target_weight[:, idx]))
145
+ else:
146
+ losses.append(self.criterion(heatmap_pred, heatmap_gt))
147
+
148
+ losses = [loss.mean(dim=1).unsqueeze(dim=1) for loss in losses]
149
+ losses = torch.cat(losses, dim=1)
150
+
151
+ return self._ohkm(losses) * self.loss_weight
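The two exported heatmap losses share the same call convention; a brief sketch (illustrative shapes, repo root assumed importable):

import torch
from easy_ViTPose.vit_models.losses.mse_loss import JointsMSELoss, JointsOHKMMSELoss

pred = torch.rand(2, 17, 64, 48)       # N=2, K=17 predicted heatmaps
gt = torch.rand(2, 17, 64, 48)         # target heatmaps
target_weight = torch.ones(2, 17, 1)   # per-joint weights

mse = JointsMSELoss(use_target_weight=True)
ohkm = JointsOHKMMSELoss(use_target_weight=True, topk=8)  # keep only the 8 hardest joints per sample
print(mse(pred, gt, target_weight).item())
print(ohkm(pred, gt, target_weight).item())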
easy_ViTPose/vit_models/losses/multi_loss_factory.py ADDED
@@ -0,0 +1,279 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Adapted from https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation
3
+ # Original licence: Copyright (c) Microsoft, under the MIT License.
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ __all__ = ['HeatmapLoss', 'AELoss', 'MultiLossFactory']
11
+
12
+
13
+ def _make_input(t, requires_grad=False, device=torch.device('cpu')):
14
+ """Make zero inputs for AE loss.
15
+
16
+ Args:
17
+ t (torch.Tensor): input
18
+ requires_grad (bool): Option to use requires_grad.
19
+ device: torch device
20
+
21
+ Returns:
22
+ torch.Tensor: zero input.
23
+ """
24
+ inp = torch.autograd.Variable(t, requires_grad=requires_grad)
25
+ inp = inp.sum()
26
+ inp = inp.to(device)
27
+ return inp
28
+
29
+
30
+ class HeatmapLoss(nn.Module):
31
+ """Accumulate the heatmap loss for each image in the batch.
32
+
33
+ Args:
34
+ supervise_empty (bool): Whether to supervise empty channels.
35
+ """
36
+
37
+ def __init__(self, supervise_empty=True):
38
+ super().__init__()
39
+ self.supervise_empty = supervise_empty
40
+
41
+ def forward(self, pred, gt, mask):
42
+ """Forward function.
43
+
44
+ Note:
45
+ - batch_size: N
46
+ - heatmaps width: W
47
+ - heatmaps height: H
48
+ - max_num_people: M
49
+ - num_keypoints: K
50
+
51
+ Args:
52
+ pred (torch.Tensor[N,K,H,W]):heatmap of output.
53
+ gt (torch.Tensor[N,K,H,W]): target heatmap.
54
+ mask (torch.Tensor[N,H,W]): mask of target.
55
+ """
56
+ assert pred.size() == gt.size(
57
+ ), f'pred.size() is {pred.size()}, gt.size() is {gt.size()}'
58
+
59
+ if not self.supervise_empty:
60
+ empty_mask = (gt.sum(dim=[2, 3], keepdim=True) > 0).float()
61
+ loss = ((pred - gt)**2) * empty_mask.expand_as(
62
+ pred) * mask[:, None, :, :].expand_as(pred)
63
+ else:
64
+ loss = ((pred - gt)**2) * mask[:, None, :, :].expand_as(pred)
65
+ loss = loss.mean(dim=3).mean(dim=2).mean(dim=1)
66
+ return loss
67
+
68
+
69
+ class AELoss(nn.Module):
70
+ """Associative Embedding loss.
71
+
72
+ `Associative Embedding: End-to-End Learning for Joint Detection and
73
+ Grouping <https://arxiv.org/abs/1611.05424v2>`_.
74
+ """
75
+
76
+ def __init__(self, loss_type):
77
+ super().__init__()
78
+ self.loss_type = loss_type
79
+
80
+ def singleTagLoss(self, pred_tag, joints):
81
+ """Associative embedding loss for one image.
82
+
83
+ Note:
84
+ - heatmaps width: W
85
+ - heatmaps height: H
86
+ - max_num_people: M
87
+ - num_keypoints: K
88
+
89
+ Args:
90
+ pred_tag (torch.Tensor[KxHxW,1]): tag of output for one image.
91
+ joints (torch.Tensor[M,K,2]): joints information for one image.
92
+ """
93
+ tags = []
94
+ pull = 0
95
+ for joints_per_person in joints:
96
+ tmp = []
97
+ for joint in joints_per_person:
98
+ if joint[1] > 0:
99
+ tmp.append(pred_tag[joint[0]])
100
+ if len(tmp) == 0:
101
+ continue
102
+ tmp = torch.stack(tmp)
103
+ tags.append(torch.mean(tmp, dim=0))
104
+ pull = pull + torch.mean((tmp - tags[-1].expand_as(tmp))**2)
105
+
106
+ num_tags = len(tags)
107
+ if num_tags == 0:
108
+ return (
109
+ _make_input(torch.zeros(1).float(), device=pred_tag.device),
110
+ _make_input(torch.zeros(1).float(), device=pred_tag.device))
111
+ elif num_tags == 1:
112
+ return (_make_input(
113
+ torch.zeros(1).float(), device=pred_tag.device), pull)
114
+
115
+ tags = torch.stack(tags)
116
+
117
+ size = (num_tags, num_tags)
118
+ A = tags.expand(*size)
119
+ B = A.permute(1, 0)
120
+
121
+ diff = A - B
122
+
123
+ if self.loss_type == 'exp':
124
+ diff = torch.pow(diff, 2)
125
+ push = torch.exp(-diff)
126
+ push = torch.sum(push) - num_tags
127
+ elif self.loss_type == 'max':
128
+ diff = 1 - torch.abs(diff)
129
+ push = torch.clamp(diff, min=0).sum() - num_tags
130
+ else:
131
+ raise ValueError('Unknown ae loss type')
132
+
133
+ push_loss = push / ((num_tags - 1) * num_tags) * 0.5
134
+ pull_loss = pull / (num_tags)
135
+
136
+ return push_loss, pull_loss
137
+
138
+ def forward(self, tags, joints):
139
+ """Accumulate the tag loss for each image in the batch.
140
+
141
+ Note:
142
+ - batch_size: N
143
+ - heatmaps width: W
144
+ - heatmaps height: H
145
+ - max_num_people: M
146
+ - num_keypoints: K
147
+
148
+ Args:
149
+ tags (torch.Tensor[N,KxHxW,1]): tag channels of output.
150
+ joints (torch.Tensor[N,M,K,2]): joints information.
151
+ """
152
+ pushes, pulls = [], []
153
+ joints = joints.cpu().data.numpy()
154
+ batch_size = tags.size(0)
155
+ for i in range(batch_size):
156
+ push, pull = self.singleTagLoss(tags[i], joints[i])
157
+ pushes.append(push)
158
+ pulls.append(pull)
159
+ return torch.stack(pushes), torch.stack(pulls)
160
+
161
+
162
+ class MultiLossFactory(nn.Module):
163
+ """Loss for bottom-up models.
164
+
165
+ Args:
166
+ num_joints (int): Number of keypoints.
167
+ num_stages (int): Number of stages.
168
+ ae_loss_type (str): Type of ae loss.
169
+ with_ae_loss (list[bool]): Use ae loss or not in multi-heatmap.
170
+ push_loss_factor (list[float]):
171
+ Parameter of push loss in multi-heatmap.
172
+ pull_loss_factor (list[float]):
173
+ Parameter of pull loss in multi-heatmap.
174
+ with_heatmaps_loss (list[bool]):
175
+ Use heatmap loss or not in multi-heatmap.
176
+ heatmaps_loss_factor (list[float]):
177
+ Parameter of heatmap loss in multi-heatmap.
178
+ supervise_empty (bool): Whether to supervise empty channels.
179
+ """
180
+
181
+ def __init__(self,
182
+ num_joints,
183
+ num_stages,
184
+ ae_loss_type,
185
+ with_ae_loss,
186
+ push_loss_factor,
187
+ pull_loss_factor,
188
+ with_heatmaps_loss,
189
+ heatmaps_loss_factor,
190
+ supervise_empty=True):
191
+ super().__init__()
192
+
193
+ assert isinstance(with_heatmaps_loss, (list, tuple)), \
194
+ 'with_heatmaps_loss should be a list or tuple'
195
+ assert isinstance(heatmaps_loss_factor, (list, tuple)), \
196
+ 'heatmaps_loss_factor should be a list or tuple'
197
+ assert isinstance(with_ae_loss, (list, tuple)), \
198
+ 'with_ae_loss should be a list or tuple'
199
+ assert isinstance(push_loss_factor, (list, tuple)), \
200
+ 'push_loss_factor should be a list or tuple'
201
+ assert isinstance(pull_loss_factor, (list, tuple)), \
202
+ 'pull_loss_factor should be a list or tuple'
203
+
204
+ self.num_joints = num_joints
205
+ self.num_stages = num_stages
206
+ self.ae_loss_type = ae_loss_type
207
+ self.with_ae_loss = with_ae_loss
208
+ self.push_loss_factor = push_loss_factor
209
+ self.pull_loss_factor = pull_loss_factor
210
+ self.with_heatmaps_loss = with_heatmaps_loss
211
+ self.heatmaps_loss_factor = heatmaps_loss_factor
212
+
213
+ self.heatmaps_loss = \
214
+ nn.ModuleList(
215
+ [
216
+ HeatmapLoss(supervise_empty)
217
+ if with_heatmaps_loss else None
218
+ for with_heatmaps_loss in self.with_heatmaps_loss
219
+ ]
220
+ )
221
+
222
+ self.ae_loss = \
223
+ nn.ModuleList(
224
+ [
225
+ AELoss(self.ae_loss_type) if with_ae_loss else None
226
+ for with_ae_loss in self.with_ae_loss
227
+ ]
228
+ )
229
+
230
+ def forward(self, outputs, heatmaps, masks, joints):
231
+ """Forward function to calculate losses.
232
+
233
+ Note:
234
+ - batch_size: N
235
+ - heatmaps width: W
236
+ - heatmaps height: H
237
+ - max_num_people: M
238
+ - num_keypoints: K
239
+ - output_channel: C C=2K if use ae loss else K
240
+
241
+ Args:
242
+ outputs (list(torch.Tensor[N,C,H,W])): outputs of stages.
243
+ heatmaps (list(torch.Tensor[N,K,H,W])): target of heatmaps.
244
+ masks (list(torch.Tensor[N,H,W])): masks of heatmaps.
245
+ joints (list(torch.Tensor[N,M,K,2])): joints of ae loss.
246
+ """
247
+ heatmaps_losses = []
248
+ push_losses = []
249
+ pull_losses = []
250
+ for idx in range(len(outputs)):
251
+ offset_feat = 0
252
+ if self.heatmaps_loss[idx]:
253
+ heatmaps_pred = outputs[idx][:, :self.num_joints]
254
+ offset_feat = self.num_joints
255
+ heatmaps_loss = self.heatmaps_loss[idx](heatmaps_pred,
256
+ heatmaps[idx],
257
+ masks[idx])
258
+ heatmaps_loss = heatmaps_loss * self.heatmaps_loss_factor[idx]
259
+ heatmaps_losses.append(heatmaps_loss)
260
+ else:
261
+ heatmaps_losses.append(None)
262
+
263
+ if self.ae_loss[idx]:
264
+ tags_pred = outputs[idx][:, offset_feat:]
265
+ batch_size = tags_pred.size()[0]
266
+ tags_pred = tags_pred.contiguous().view(batch_size, -1, 1)
267
+
268
+ push_loss, pull_loss = self.ae_loss[idx](tags_pred,
269
+ joints[idx])
270
+ push_loss = push_loss * self.push_loss_factor[idx]
271
+ pull_loss = pull_loss * self.pull_loss_factor[idx]
272
+
273
+ push_losses.append(push_loss)
274
+ pull_losses.append(pull_loss)
275
+ else:
276
+ push_losses.append(None)
277
+ pull_losses.append(None)
278
+
279
+ return heatmaps_losses, push_losses, pull_losses
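A rough end-to-end sketch of how MultiLossFactory is wired for a single-stage bottom-up head (hypothetical values; K, the heatmap size, the max number of people, and the loss factors are illustrative, not taken from this commit):

import torch
from easy_ViTPose.vit_models.losses.multi_loss_factory import MultiLossFactory

K = 17
criterion = MultiLossFactory(
    num_joints=K, num_stages=1, ae_loss_type='exp',
    with_ae_loss=[True], push_loss_factor=[0.001], pull_loss_factor=[0.001],
    with_heatmaps_loss=[True], heatmaps_loss_factor=[1.0])

outputs = [torch.rand(2, 2 * K, 64, 64)]               # K heatmap channels followed by K tag channels
heatmaps = [torch.rand(2, K, 64, 64)]                  # target heatmaps
masks = [torch.ones(2, 64, 64)]                        # valid-region masks
joints = [torch.zeros(2, 30, K, 2, dtype=torch.long)]  # (flattened tag index, visibility) per person
hm_losses, push_losses, pull_losses = criterion(outputs, heatmaps, masks, joints)
print(hm_losses[0].mean().item(), push_losses[0].mean().item(), pull_losses[0].mean().item())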
easy_ViTPose/vit_models/losses/regression_loss.py ADDED
@@ -0,0 +1,444 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ __all__ = ['SmoothL1Loss', 'WingLoss', 'SoftWingLoss',
10
+ 'L1Loss', 'MPJPELoss', 'MSELoss', 'BoneLoss',
11
+ 'SemiSupervisionLoss']
12
+
13
+
14
+ class SmoothL1Loss(nn.Module):
15
+ """SmoothL1Loss loss.
16
+
17
+ Args:
18
+ use_target_weight (bool): Option to use weighted MSE loss.
19
+ Different joint types may have different target weights.
20
+ loss_weight (float): Weight of the loss. Default: 1.0.
21
+ """
22
+
23
+ def __init__(self, use_target_weight=False, loss_weight=1.):
24
+ super().__init__()
25
+ self.criterion = F.smooth_l1_loss
26
+ self.use_target_weight = use_target_weight
27
+ self.loss_weight = loss_weight
28
+
29
+ def forward(self, output, target, target_weight=None):
30
+ """Forward function.
31
+
32
+ Note:
33
+ - batch_size: N
34
+ - num_keypoints: K
35
+ - dimension of keypoints: D (D=2 or D=3)
36
+
37
+ Args:
38
+ output (torch.Tensor[N, K, D]): Output regression.
39
+ target (torch.Tensor[N, K, D]): Target regression.
40
+ target_weight (torch.Tensor[N, K, D]):
41
+ Weights across different joint types.
42
+ """
43
+ if self.use_target_weight:
44
+ assert target_weight is not None
45
+ loss = self.criterion(output * target_weight,
46
+ target * target_weight)
47
+ else:
48
+ loss = self.criterion(output, target)
49
+
50
+ return loss * self.loss_weight
51
+
52
+
53
+ class WingLoss(nn.Module):
54
+ """Wing Loss. paper ref: 'Wing Loss for Robust Facial Landmark Localisation
55
+ with Convolutional Neural Networks' Feng et al. CVPR'2018.
56
+
57
+ Args:
58
+ omega (float): Also referred to as width.
59
+ epsilon (float): Also referred to as curvature.
60
+ use_target_weight (bool): Option to use weighted MSE loss.
61
+ Different joint types may have different target weights.
62
+ loss_weight (float): Weight of the loss. Default: 1.0.
63
+ """
64
+
65
+ def __init__(self,
66
+ omega=10.0,
67
+ epsilon=2.0,
68
+ use_target_weight=False,
69
+ loss_weight=1.):
70
+ super().__init__()
71
+ self.omega = omega
72
+ self.epsilon = epsilon
73
+ self.use_target_weight = use_target_weight
74
+ self.loss_weight = loss_weight
75
+
76
+ # constant that smoothly links the piecewise-defined linear
77
+ # and nonlinear parts
78
+ self.C = self.omega * (1.0 - math.log(1.0 + self.omega / self.epsilon))
79
+
80
+ def criterion(self, pred, target):
81
+ """Criterion of wingloss.
82
+
83
+ Note:
84
+ - batch_size: N
85
+ - num_keypoints: K
86
+ - dimension of keypoints: D (D=2 or D=3)
87
+
88
+ Args:
89
+ pred (torch.Tensor[N, K, D]): Output regression.
90
+ target (torch.Tensor[N, K, D]): Target regression.
91
+ """
92
+ delta = (target - pred).abs()
93
+ losses = torch.where(
94
+ delta < self.omega,
95
+ self.omega * torch.log(1.0 + delta / self.epsilon), delta - self.C)
96
+ return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0)
97
+
98
+ def forward(self, output, target, target_weight=None):
99
+ """Forward function.
100
+
101
+ Note:
102
+ - batch_size: N
103
+ - num_keypoints: K
104
+ - dimension of keypoints: D (D=2 or D=3)
105
+
106
+ Args:
107
+ output (torch.Tensor[N, K, D]): Output regression.
108
+ target (torch.Tensor[N, K, D]): Target regression.
109
+ target_weight (torch.Tensor[N,K,D]):
110
+ Weights across different joint types.
111
+ """
112
+ if self.use_target_weight:
113
+ assert target_weight is not None
114
+ loss = self.criterion(output * target_weight,
115
+ target * target_weight)
116
+ else:
117
+ loss = self.criterion(output, target)
118
+
119
+ return loss * self.loss_weight
120
+
121
+
122
+
123
+ class SoftWingLoss(nn.Module):
124
+ """Soft Wing Loss 'Structure-Coherent Deep Feature Learning for Robust Face
125
+ Alignment' Lin et al. TIP'2021.
126
+
127
+ loss =
128
+ 1. |x| , if |x| < omega1
129
+ 2. omega2*ln(1+|x|/epsilon) + B, if |x| >= omega1
130
+
131
+ Args:
132
+ omega1 (float): The first threshold.
133
+ omega2 (float): The second threshold.
134
+ epsilon (float): Also referred to as curvature.
135
+ use_target_weight (bool): Option to use weighted MSE loss.
136
+ Different joint types may have different target weights.
137
+ loss_weight (float): Weight of the loss. Default: 1.0.
138
+ """
139
+
140
+ def __init__(self,
141
+ omega1=2.0,
142
+ omega2=20.0,
143
+ epsilon=0.5,
144
+ use_target_weight=False,
145
+ loss_weight=1.):
146
+ super().__init__()
147
+ self.omega1 = omega1
148
+ self.omega2 = omega2
149
+ self.epsilon = epsilon
150
+ self.use_target_weight = use_target_weight
151
+ self.loss_weight = loss_weight
152
+
153
+ # constant that smoothly links the piecewise-defined linear
154
+ # and nonlinear parts
155
+ self.B = self.omega1 - self.omega2 * math.log(1.0 + self.omega1 /
156
+ self.epsilon)
157
+
158
+ def criterion(self, pred, target):
159
+ """Criterion of wingloss.
160
+
161
+ Note:
162
+ batch_size: N
163
+ num_keypoints: K
164
+ dimension of keypoints: D (D=2 or D=3)
165
+
166
+ Args:
167
+ pred (torch.Tensor[N, K, D]): Output regression.
168
+ target (torch.Tensor[N, K, D]): Target regression.
169
+ """
170
+ delta = (target - pred).abs()
171
+ losses = torch.where(
172
+ delta < self.omega1, delta,
173
+ self.omega2 * torch.log(1.0 + delta / self.epsilon) + self.B)
174
+ return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0)
175
+
176
+ def forward(self, output, target, target_weight=None):
177
+ """Forward function.
178
+
179
+ Note:
180
+ batch_size: N
181
+ num_keypoints: K
182
+ dimension of keypoints: D (D=2 or D=3)
183
+
184
+ Args:
185
+ output (torch.Tensor[N, K, D]): Output regression.
186
+ target (torch.Tensor[N, K, D]): Target regression.
187
+ target_weight (torch.Tensor[N, K, D]):
188
+ Weights across different joint types.
189
+ """
190
+ if self.use_target_weight:
191
+ assert target_weight is not None
192
+ loss = self.criterion(output * target_weight,
193
+ target * target_weight)
194
+ else:
195
+ loss = self.criterion(output, target)
196
+
197
+ return loss * self.loss_weight
198
+
199
+
200
+ class MPJPELoss(nn.Module):
201
+ """MPJPE (Mean Per Joint Position Error) loss.
202
+
203
+ Args:
204
+ use_target_weight (bool): Option to use weighted MSE loss.
205
+ Different joint types may have different target weights.
206
+ loss_weight (float): Weight of the loss. Default: 1.0.
207
+ """
208
+
209
+ def __init__(self, use_target_weight=False, loss_weight=1.):
210
+ super().__init__()
211
+ self.use_target_weight = use_target_weight
212
+ self.loss_weight = loss_weight
213
+
214
+ def forward(self, output, target, target_weight=None):
215
+ """Forward function.
216
+
217
+ Note:
218
+ - batch_size: N
219
+ - num_keypoints: K
220
+ - dimension of keypoints: D (D=2 or D=3)
221
+
222
+ Args:
223
+ output (torch.Tensor[N, K, D]): Output regression.
224
+ target (torch.Tensor[N, K, D]): Target regression.
225
+ target_weight (torch.Tensor[N,K,D]):
226
+ Weights across different joint types.
227
+ """
228
+
229
+ if self.use_target_weight:
230
+ assert target_weight is not None
231
+ loss = torch.mean(
232
+ torch.norm((output - target) * target_weight, dim=-1))
233
+ else:
234
+ loss = torch.mean(torch.norm(output - target, dim=-1))
235
+
236
+ return loss * self.loss_weight
237
+
238
+
239
+ class L1Loss(nn.Module):
240
+ """L1Loss loss ."""
241
+
242
+ def __init__(self, use_target_weight=False, loss_weight=1.):
243
+ super().__init__()
244
+ self.criterion = F.l1_loss
245
+ self.use_target_weight = use_target_weight
246
+ self.loss_weight = loss_weight
247
+
248
+ def forward(self, output, target, target_weight=None):
249
+ """Forward function.
250
+
251
+ Note:
252
+ - batch_size: N
253
+ - num_keypoints: K
254
+
255
+ Args:
256
+ output (torch.Tensor[N, K, 2]): Output regression.
257
+ target (torch.Tensor[N, K, 2]): Target regression.
258
+ target_weight (torch.Tensor[N, K, 2]):
259
+ Weights across different joint types.
260
+ """
261
+ if self.use_target_weight:
262
+ assert target_weight is not None
263
+ loss = self.criterion(output * target_weight,
264
+ target * target_weight)
265
+ else:
266
+ loss = self.criterion(output, target)
267
+
268
+ return loss * self.loss_weight
269
+
270
+
271
+ class MSELoss(nn.Module):
272
+ """MSE loss for coordinate regression."""
273
+
274
+ def __init__(self, use_target_weight=False, loss_weight=1.):
275
+ super().__init__()
276
+ self.criterion = F.mse_loss
277
+ self.use_target_weight = use_target_weight
278
+ self.loss_weight = loss_weight
279
+
280
+ def forward(self, output, target, target_weight=None):
281
+ """Forward function.
282
+
283
+ Note:
284
+ - batch_size: N
285
+ - num_keypoints: K
286
+
287
+ Args:
288
+ output (torch.Tensor[N, K, 2]): Output regression.
289
+ target (torch.Tensor[N, K, 2]): Target regression.
290
+ target_weight (torch.Tensor[N, K, 2]):
291
+ Weights across different joint types.
292
+ """
293
+ if self.use_target_weight:
294
+ assert target_weight is not None
295
+ loss = self.criterion(output * target_weight,
296
+ target * target_weight)
297
+ else:
298
+ loss = self.criterion(output, target)
299
+
300
+ return loss * self.loss_weight
301
+
302
+
303
+ class BoneLoss(nn.Module):
304
+ """Bone length loss.
305
+
306
+ Args:
307
+ joint_parents (list): Indices of each joint's parent joint.
308
+ use_target_weight (bool): Option to use weighted bone loss.
309
+ Different bone types may have different target weights.
310
+ loss_weight (float): Weight of the loss. Default: 1.0.
311
+ """
312
+
313
+ def __init__(self, joint_parents, use_target_weight=False, loss_weight=1.):
314
+ super().__init__()
315
+ self.joint_parents = joint_parents
316
+ self.use_target_weight = use_target_weight
317
+ self.loss_weight = loss_weight
318
+
319
+ self.non_root_indices = []
320
+ for i in range(len(self.joint_parents)):
321
+ if i != self.joint_parents[i]:
322
+ self.non_root_indices.append(i)
323
+
324
+ def forward(self, output, target, target_weight=None):
325
+ """Forward function.
326
+
327
+ Note:
328
+ - batch_size: N
329
+ - num_keypoints: K
330
+ - dimension of keypoints: D (D=2 or D=3)
331
+
332
+ Args:
333
+ output (torch.Tensor[N, K, D]): Output regression.
334
+ target (torch.Tensor[N, K, D]): Target regression.
335
+ target_weight (torch.Tensor[N, K-1]):
336
+ Weights across different bone types.
337
+ """
338
+ output_bone = torch.norm(
339
+ output - output[:, self.joint_parents, :],
340
+ dim=-1)[:, self.non_root_indices]
341
+ target_bone = torch.norm(
342
+ target - target[:, self.joint_parents, :],
343
+ dim=-1)[:, self.non_root_indices]
344
+ if self.use_target_weight:
345
+ assert target_weight is not None
346
+ loss = torch.mean(
347
+ torch.abs((output_bone * target_weight).mean(dim=0) -
348
+ (target_bone * target_weight).mean(dim=0)))
349
+ else:
350
+ loss = torch.mean(
351
+ torch.abs(output_bone.mean(dim=0) - target_bone.mean(dim=0)))
352
+
353
+ return loss * self.loss_weight
354
+
355
+
356
+ class SemiSupervisionLoss(nn.Module):
357
+ """Semi-supervision loss for unlabeled data. It is composed of projection
358
+ loss and bone loss.
359
+
360
+ Paper ref: `3D human pose estimation in video with temporal convolutions
361
+ and semi-supervised training` Dario Pavllo et al. CVPR'2019.
362
+
363
+ Args:
364
+ joint_parents (list): Indices of each joint's parent joint.
365
+ projection_loss_weight (float): Weight for projection loss.
366
+ bone_loss_weight (float): Weight for bone loss.
367
+ warmup_iterations (int): Number of warmup iterations. In the first
368
+ `warmup_iterations` iterations, the model is trained only on
369
+ labeled data, and semi-supervision loss will be 0.
370
+ This is a workaround since currently we cannot access
371
+ epoch number in loss functions. Note that the iteration number in
372
+ an epoch can be changed due to different GPU numbers in multi-GPU
373
+ settings. So please set this parameter carefully.
374
+ warmup_iterations = dataset_size // samples_per_gpu // gpu_num
375
+ * warmup_epochs
376
+ """
377
+
378
+ def __init__(self,
379
+ joint_parents,
380
+ projection_loss_weight=1.,
381
+ bone_loss_weight=1.,
382
+ warmup_iterations=0):
383
+ super().__init__()
384
+ self.criterion_projection = MPJPELoss(
385
+ loss_weight=projection_loss_weight)
386
+ self.criterion_bone = BoneLoss(
387
+ joint_parents, loss_weight=bone_loss_weight)
388
+ self.warmup_iterations = warmup_iterations
389
+ self.num_iterations = 0
390
+
391
+ @staticmethod
392
+ def project_joints(x, intrinsics):
393
+ """Project 3D joint coordinates to 2D image plane using camera
394
+ intrinsic parameters.
395
+
396
+ Args:
397
+ x (torch.Tensor[N, K, 3]): 3D joint coordinates.
398
+ intrinsics (torch.Tensor[N, 4] | torch.Tensor[N, 9]): Camera
399
+ intrinsics: f (2), c (2), k (3), p (2).
400
+ """
401
+ while intrinsics.dim() < x.dim():
402
+ intrinsics.unsqueeze_(1)
403
+ f = intrinsics[..., :2]
404
+ c = intrinsics[..., 2:4]
405
+ _x = torch.clamp(x[:, :, :2] / x[:, :, 2:], -1, 1)
406
+ if intrinsics.shape[-1] == 9:
407
+ k = intrinsics[..., 4:7]
408
+ p = intrinsics[..., 7:9]
409
+
410
+ r2 = torch.sum(_x[:, :, :2]**2, dim=-1, keepdim=True)
411
+ radial = 1 + torch.sum(
412
+ k * torch.cat((r2, r2**2, r2**3), dim=-1),
413
+ dim=-1,
414
+ keepdim=True)
415
+ tan = torch.sum(p * _x, dim=-1, keepdim=True)
416
+ _x = _x * (radial + tan) + p * r2
417
+ _x = f * _x + c
418
+ return _x
419
+
420
+ def forward(self, output, target):
421
+ losses = dict()
422
+
423
+ self.num_iterations += 1
424
+ if self.num_iterations <= self.warmup_iterations:
425
+ return losses
426
+
427
+ labeled_pose = output['labeled_pose']
428
+ unlabeled_pose = output['unlabeled_pose']
429
+ unlabeled_traj = output['unlabeled_traj']
430
+ unlabeled_target_2d = target['unlabeled_target_2d']
431
+ intrinsics = target['intrinsics']
432
+
433
+ # projection loss
434
+ unlabeled_output = unlabeled_pose + unlabeled_traj
435
+ unlabeled_output_2d = self.project_joints(unlabeled_output, intrinsics)
436
+ loss_proj = self.criterion_projection(unlabeled_output_2d,
437
+ unlabeled_target_2d, None)
438
+ losses['proj_loss'] = loss_proj
439
+
440
+ # bone loss
441
+ loss_bone = self.criterion_bone(unlabeled_pose, labeled_pose, None)
442
+ losses['bone_loss'] = loss_bone
443
+
444
+ return losses
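These regression losses all take (output, target, target_weight) with keypoints of shape [N, K, D]; a compact sketch (illustrative shapes only, repo root assumed importable):

import torch
from easy_ViTPose.vit_models.losses.regression_loss import (SmoothL1Loss, WingLoss,
                                                            SoftWingLoss, MPJPELoss)

pred = torch.rand(4, 17, 2)      # N=4 samples, K=17 keypoints, D=2 coordinates
gt = torch.rand(4, 17, 2)
weight = torch.ones(4, 17, 2)    # per-keypoint weights

for criterion in (SmoothL1Loss(use_target_weight=True),
                  WingLoss(use_target_weight=True),
                  SoftWingLoss(use_target_weight=True),
                  MPJPELoss(use_target_weight=True)):
    print(type(criterion).__name__, criterion(pred, gt, weight).item())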
easy_ViTPose/vit_models/model.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch.nn as nn
2
+
3
+ from .backbone.vit import ViT
4
+ from .head.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead
5
+
6
+
7
+ __all__ = ['ViTPose']
8
+
9
+
10
+ class ViTPose(nn.Module):
11
+ def __init__(self, cfg: dict) -> None:
12
+ super(ViTPose, self).__init__()
13
+
14
+ backbone_cfg = {k: v for k, v in cfg['backbone'].items() if k != 'type'}
15
+ head_cfg = {k: v for k, v in cfg['keypoint_head'].items() if k != 'type'}
16
+
17
+ self.backbone = ViT(**backbone_cfg)
18
+ self.keypoint_head = TopdownHeatmapSimpleHead(**head_cfg)
19
+
20
+ def forward_features(self, x):
21
+ return self.backbone(x)
22
+
23
+ def forward(self, x):
24
+ return self.keypoint_head(self.backbone(x))
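The wrapper above only needs a cfg dict with 'backbone' and 'keypoint_head' entries; the authoritative dicts ship in the files under easy_ViTPose/configs/. Below is a hypothetical ViTPose-B-style configuration, shown purely to illustrate the expected structure (the keyword values are assumptions, not copied from this commit):

import torch
from easy_ViTPose.vit_models.model import ViTPose

# Hypothetical ViTPose-B-style cfg; the real values live in easy_ViTPose/configs/.
cfg = dict(
    backbone=dict(type='ViT', img_size=(256, 192), patch_size=16,
                  embed_dim=768, depth=12, num_heads=12, ratio=1,
                  mlp_ratio=4, qkv_bias=True, drop_path_rate=0.3),
    keypoint_head=dict(type='TopdownHeatmapSimpleHead', in_channels=768,
                       out_channels=17, num_deconv_layers=2,
                       num_deconv_filters=(256, 256), num_deconv_kernels=(4, 4),
                       extra=dict(final_conv_kernel=1)))

model = ViTPose(cfg).eval()
with torch.no_grad():
    heatmaps = model(torch.randn(1, 3, 256, 192))   # expected: (1, 17, 64, 48) keypoint heatmaps
print(heatmaps.shape)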
easy_ViTPose/vit_models/optimizer.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch.optim as optim
2
+
3
+ class LayerDecayOptimizer:
4
+ def __init__(self, optimizer, layerwise_decay_rate):
5
+ self.optimizer = optimizer
6
+ self.layerwise_decay_rate = layerwise_decay_rate
7
+ self.param_groups = optimizer.param_groups
8
+
9
+ def step(self, *args, **kwargs):
10
+ for i, group in enumerate(self.optimizer.param_groups):
11
+ group['lr'] *= self.layerwise_decay_rate[i]
12
+ self.optimizer.step(*args, **kwargs)
13
+
14
+ def zero_grad(self, *args, **kwargs):
15
+ self.optimizer.zero_grad(*args, **kwargs)
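LayerDecayOptimizer simply scales each parameter group's learning rate by its own decay factor before delegating to the wrapped optimizer; a small sketch (the group split and rates are illustrative):

import torch
import torch.optim as optim
from easy_ViTPose.vit_models.optimizer import LayerDecayOptimizer

layer1, layer2 = torch.nn.Linear(8, 8), torch.nn.Linear(8, 2)
base = optim.AdamW([{'params': layer1.parameters()},    # earlier layer: stronger decay
                    {'params': layer2.parameters()}],   # later layer: no decay
                   lr=1e-3)
opt = LayerDecayOptimizer(base, layerwise_decay_rate=[0.75, 1.0])

x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = ((layer2(layer1(x)) - y) ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()    # each group's lr is multiplied by its rate, then AdamW steps
print([g['lr'] for g in base.param_groups])

Note that, as written above, the scaling is applied on every step() call, so the per-group rates compound over the course of training.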
easy_ViTPose/vit_utils/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .util import *
2
+ from .top_down_eval import *
3
+ from .post_processing import *
4
+ from .visualization import *
5
+ from .dist_util import *
6
+ from .logging import *
easy_ViTPose/vit_utils/dist_util.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3
+ import functools
4
+ import os
5
+ import socket
6
+ import subprocess
7
+ from collections import OrderedDict
8
+ from typing import Callable, List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.multiprocessing as mp
12
+ from torch import distributed as dist
13
+ from torch._utils import (_flatten_dense_tensors, _take_tensors,
14
+ _unflatten_dense_tensors)
15
+
16
+
17
+ def is_mps_available() -> bool:
18
+ """Return True if mps devices exist.
19
+
20
+ It is specialized for Mac M1 chips and requires torch version 1.12 or higher.
21
+ """
22
+ try:
23
+ import torch
24
+ return hasattr(torch.backends,
25
+ 'mps') and torch.backends.mps.is_available()
26
+ except Exception:
27
+ return False
28
+
29
+ def _find_free_port() -> str:
30
+ # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
31
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
32
+ # Binding to port 0 will cause the OS to find an available port for us
33
+ sock.bind(('', 0))
34
+ port = sock.getsockname()[1]
35
+ sock.close()
36
+ # NOTE: there is still a chance the port could be taken by other processes.
37
+ return port
38
+
39
+
40
+ def _is_free_port(port: int) -> bool:
41
+ ips = socket.gethostbyname_ex(socket.gethostname())[-1]
42
+ ips.append('localhost')
43
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
44
+ return all(s.connect_ex((ip, port)) != 0 for ip in ips)
45
+
46
+
47
+ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
48
+ if mp.get_start_method(allow_none=True) is None:
49
+ mp.set_start_method('spawn')
50
+ if launcher == 'pytorch':
51
+ _init_dist_pytorch(backend, **kwargs)
52
+ elif launcher == 'mpi':
53
+ _init_dist_mpi(backend, **kwargs)
54
+ elif launcher == 'slurm':
55
+ _init_dist_slurm(backend, **kwargs)
56
+ else:
57
+ raise ValueError(f'Invalid launcher type: {launcher}')
58
+
59
+
60
+ def _init_dist_pytorch(backend: str, **kwargs) -> None:
61
+ # TODO: use local_rank instead of rank % num_gpus
62
+ rank = int(os.environ['RANK'])
63
+ num_gpus = torch.cuda.device_count()
64
+ torch.cuda.set_device(rank % num_gpus)
65
+ dist.init_process_group(backend=backend, **kwargs)
66
+
67
+
68
+ def _init_dist_mpi(backend: str, **kwargs) -> None:
69
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
70
+ torch.cuda.set_device(local_rank)
71
+ if 'MASTER_PORT' not in os.environ:
72
+ # 29500 is torch.distributed default port
73
+ os.environ['MASTER_PORT'] = '29500'
74
+ if 'MASTER_ADDR' not in os.environ:
75
+ raise KeyError('The environment variable MASTER_ADDR is not set')
76
+ os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
77
+ os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
78
+ dist.init_process_group(backend=backend, **kwargs)
79
+
80
+
81
+ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
82
+ """Initialize slurm distributed training environment.
83
+
84
+ If argument ``port`` is not specified, then the master port will be system
85
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
86
+ environment variable, then a default port ``29500`` will be used.
87
+
88
+ Args:
89
+ backend (str): Backend of torch.distributed.
90
+ port (int, optional): Master port. Defaults to None.
91
+ """
92
+ proc_id = int(os.environ['SLURM_PROCID'])
93
+ ntasks = int(os.environ['SLURM_NTASKS'])
94
+ node_list = os.environ['SLURM_NODELIST']
95
+ num_gpus = torch.cuda.device_count()
96
+ torch.cuda.set_device(proc_id % num_gpus)
97
+ addr = subprocess.getoutput(
98
+ f'scontrol show hostname {node_list} | head -n1')
99
+ # specify master port
100
+ if port is not None:
101
+ os.environ['MASTER_PORT'] = str(port)
102
+ elif 'MASTER_PORT' in os.environ:
103
+ pass # use MASTER_PORT in the environment variable
104
+ else:
105
+ # if torch.distributed default port(29500) is available
106
+ # then use it, else find a free port
107
+ if _is_free_port(29500):
108
+ os.environ['MASTER_PORT'] = '29500'
109
+ else:
110
+ os.environ['MASTER_PORT'] = str(_find_free_port())
111
+ # use MASTER_ADDR in the environment variable if it already exists
112
+ if 'MASTER_ADDR' not in os.environ:
113
+ os.environ['MASTER_ADDR'] = addr
114
+ os.environ['WORLD_SIZE'] = str(ntasks)
115
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
116
+ os.environ['RANK'] = str(proc_id)
117
+ dist.init_process_group(backend=backend)
118
+
119
+
120
+ def get_dist_info() -> Tuple[int, int]:
121
+ if dist.is_available() and dist.is_initialized():
122
+ rank = dist.get_rank()
123
+ world_size = dist.get_world_size()
124
+ else:
125
+ rank = 0
126
+ world_size = 1
127
+ return rank, world_size
128
+
129
+
130
+ def master_only(func: Callable) -> Callable:
131
+
132
+ @functools.wraps(func)
133
+ def wrapper(*args, **kwargs):
134
+ rank, _ = get_dist_info()
135
+ if rank == 0:
136
+ return func(*args, **kwargs)
137
+
138
+ return wrapper
139
+
140
+
141
+ def allreduce_params(params: List[torch.nn.Parameter],
142
+ coalesce: bool = True,
143
+ bucket_size_mb: int = -1) -> None:
144
+ """Allreduce parameters.
145
+
146
+ Args:
147
+ params (list[torch.nn.Parameter]): List of parameters or buffers
148
+ of a model.
149
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
150
+ Defaults to True.
151
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
152
+ Defaults to -1.
153
+ """
154
+ _, world_size = get_dist_info()
155
+ if world_size == 1:
156
+ return
157
+ params = [param.data for param in params]
158
+ if coalesce:
159
+ _allreduce_coalesced(params, world_size, bucket_size_mb)
160
+ else:
161
+ for tensor in params:
162
+ dist.all_reduce(tensor.div_(world_size))
163
+
164
+
165
+ def allreduce_grads(params: List[torch.nn.Parameter],
166
+ coalesce: bool = True,
167
+ bucket_size_mb: int = -1) -> None:
168
+ """Allreduce gradients.
169
+
170
+ Args:
171
+ params (list[torch.nn.Parameter]): List of parameters of a model.
172
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
173
+ Defaults to True.
174
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
175
+ Defaults to -1.
176
+ """
177
+ grads = [
178
+ param.grad.data for param in params
179
+ if param.requires_grad and param.grad is not None
180
+ ]
181
+ _, world_size = get_dist_info()
182
+ if world_size == 1:
183
+ return
184
+ if coalesce:
185
+ _allreduce_coalesced(grads, world_size, bucket_size_mb)
186
+ else:
187
+ for tensor in grads:
188
+ dist.all_reduce(tensor.div_(world_size))
189
+
190
+
191
+ def _allreduce_coalesced(tensors: torch.Tensor,
192
+ world_size: int,
193
+ bucket_size_mb: int = -1) -> None:
194
+ if bucket_size_mb > 0:
195
+ bucket_size_bytes = bucket_size_mb * 1024 * 1024
196
+ buckets = _take_tensors(tensors, bucket_size_bytes)
197
+ else:
198
+ buckets = OrderedDict()
199
+ for tensor in tensors:
200
+ tp = tensor.type()
201
+ if tp not in buckets:
202
+ buckets[tp] = []
203
+ buckets[tp].append(tensor)
204
+ buckets = buckets.values()
205
+
206
+ for bucket in buckets:
207
+ flat_tensors = _flatten_dense_tensors(bucket)
208
+ dist.all_reduce(flat_tensors)
209
+ flat_tensors.div_(world_size)
210
+ for tensor, synced in zip(
211
+ bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
212
+ tensor.copy_(synced)
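Outside of a torch.distributed run these helpers degrade gracefully; a quick sketch of the two most commonly used ones (illustrative, assuming the package and its dependencies are importable):

from easy_ViTPose.vit_utils.dist_util import get_dist_info, master_only

rank, world_size = get_dist_info()    # falls back to (0, 1) when no process group is initialized
print(f'rank {rank} / world size {world_size}')

@master_only
def log_once(msg):
    # Runs only on rank 0; on other ranks the wrapper returns None.
    print(msg)

log_once('saving checkpoint ...')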
easy_ViTPose/vit_utils/inference.py ADDED
@@ -0,0 +1,93 @@
1
+ import cv2
2
+ import numpy as np
3
+ import json
4
+
5
+
6
+ rotation_map = {
7
+ 0: None,
8
+ 90: cv2.ROTATE_90_COUNTERCLOCKWISE,
9
+ 180: cv2.ROTATE_180,
10
+ 270: cv2.ROTATE_90_CLOCKWISE
11
+ }
12
+
13
+ class NumpyEncoder(json.JSONEncoder):
14
+ def default(self, obj):
15
+ if isinstance(obj, np.ndarray):
16
+ return obj.tolist()
17
+ return json.JSONEncoder.default(self, obj)
18
+
19
+ def draw_bboxes(image, bounding_boxes, boxes_id, scores):
20
+ image_with_boxes = image.copy()
21
+
22
+ for bbox, bbox_id, score in zip(bounding_boxes, boxes_id, scores):
23
+ x1, y1, x2, y2 = bbox
24
+ cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), (128, 128, 0), 2)
25
+
26
+ label = f'#{bbox_id}: {score:.2f}'
27
+
28
+ (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
29
+ label_x = x1
30
+ label_y = y1 - 5 if y1 > 20 else y1 + 20
31
+
32
+ # Draw a filled rectangle as the background for the label
33
+ cv2.rectangle(image_with_boxes, (x1, label_y - label_height - 5),
34
+ (x1 + label_width, label_y + 5), (128, 128, 0), cv2.FILLED)
35
+ cv2.putText(image_with_boxes, label, (label_x, label_y),
36
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
37
+
38
+ return image_with_boxes
39
+
40
+
41
+ def pad_image(image: np.ndarray, aspect_ratio: float) -> tuple[np.ndarray, tuple[int, int]]:
42
+ # Get the current aspect ratio of the image
43
+ image_height, image_width = image.shape[:2]
44
+ current_aspect_ratio = image_width / image_height
45
+
46
+ left_pad = 0
47
+ top_pad = 0
48
+ # Determine whether to pad horizontally or vertically
49
+ if current_aspect_ratio < aspect_ratio:
50
+ # Pad horizontally
51
+ target_width = int(aspect_ratio * image_height)
52
+ pad_width = target_width - image_width
53
+ left_pad = pad_width // 2
54
+ right_pad = pad_width - left_pad
55
+
56
+ padded_image = np.pad(image,
57
+ pad_width=((0, 0), (left_pad, right_pad), (0, 0)),
58
+ mode='constant')
59
+ else:
60
+ # Pad vertically
61
+ target_height = int(image_width / aspect_ratio)
62
+ pad_height = target_height - image_height
63
+ top_pad = pad_height // 2
64
+ bottom_pad = pad_height - top_pad
65
+
66
+ padded_image = np.pad(image,
67
+ pad_width=((top_pad, bottom_pad), (0, 0), (0, 0)),
68
+ mode='constant')
69
+ return padded_image, (left_pad, top_pad)
70
+
71
+
72
+ class VideoReader(object):
73
+ def __init__(self, file_name, rotate=0):
74
+ self.file_name = file_name
75
+ self.rotate = rotation_map[rotate]
76
+ try: # OpenCV needs int to read from webcam
77
+ self.file_name = int(file_name)
78
+ except ValueError:
79
+ pass
80
+
81
+ def __iter__(self):
82
+ self.cap = cv2.VideoCapture(self.file_name)
83
+ if not self.cap.isOpened():
84
+ raise IOError('Video {} cannot be opened'.format(self.file_name))
85
+ return self
86
+
87
+ def __next__(self):
88
+ was_read, img = self.cap.read()
89
+ if not was_read:
90
+ raise StopIteration
91
+ if self.rotate is not None:
92
+ img = cv2.rotate(img, self.rotate)
93
+ return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
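
A minimal usage sketch of the inference helpers (not part of this commit): reading frames from a video, padding them to a fixed aspect ratio and drawing a dummy detection. The file name 'input.mp4' and the example box, id and score are illustrative assumptions.

from easy_ViTPose.vit_utils.inference import VideoReader, pad_image, draw_bboxes

for frame in VideoReader('input.mp4', rotate=0):        # frames come back as RGB
    padded, (left_pad, top_pad) = pad_image(frame, aspect_ratio=4 / 3)
    bboxes = [(50, 60, 200, 300)]                        # dummy xyxy box in pixels
    annotated = draw_bboxes(padded, bboxes, boxes_id=[1], scores=[0.90])
    break                                                # process a single frame
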
easy_ViTPose/vit_utils/logging.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import logging
3
+
4
+ import torch.distributed as dist
5
+
6
+ logger_initialized: dict = {}
7
+
8
+
9
+ def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
10
+ """Initialize and get a logger by name.
11
+
12
+ If the logger has not been initialized, this method will initialize the
13
+ logger by adding one or two handlers, otherwise the initialized logger will
14
+ be directly returned. During initialization, a StreamHandler will always be
15
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
16
+ will also be added.
17
+
18
+ Args:
19
+ name (str): Logger name.
20
+ log_file (str | None): The log filename. If specified, a FileHandler
21
+ will be added to the logger.
22
+ log_level (int): The logger level. Note that only the process of
23
+ rank 0 is affected, and other processes will set the level to
24
+ "Error" thus be silent most of the time.
25
+ file_mode (str): The file mode used in opening log file.
26
+ Defaults to 'w'.
27
+
28
+ Returns:
29
+ logging.Logger: The expected logger.
30
+ """
31
+ logger = logging.getLogger(name)
32
+ if name in logger_initialized:
33
+ return logger
34
+ # handle hierarchical names
35
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
36
+ # initialization since it is a child of "a".
37
+ for logger_name in logger_initialized:
38
+ if name.startswith(logger_name):
39
+ return logger
40
+
41
+ # handle duplicate logs to the console
42
+ # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
43
+ # to the root logger. As logger.propagate is True by default, this root
44
+ # level handler causes logging messages from rank>0 processes to
45
+ # unexpectedly show up on the console, creating much unwanted clutter.
46
+ # To fix this issue, we set the root logger's StreamHandler, if any, to log
47
+ # at the ERROR level.
48
+ for handler in logger.root.handlers:
49
+ if type(handler) is logging.StreamHandler:
50
+ handler.setLevel(logging.ERROR)
51
+
52
+ stream_handler = logging.StreamHandler()
53
+ handlers = [stream_handler]
54
+
55
+ if dist.is_available() and dist.is_initialized():
56
+ rank = dist.get_rank()
57
+ else:
58
+ rank = 0
59
+
60
+ # only rank 0 will add a FileHandler
61
+ if rank == 0 and log_file is not None:
62
+ # The default mode of the built-in FileHandler is 'a' (append); we
63
+ # expose file_mode so callers can switch between appending and the
64
+ # 'w' default used here.
65
+ file_handler = logging.FileHandler(log_file, file_mode)
66
+ handlers.append(file_handler)
67
+
68
+ formatter = logging.Formatter(
69
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
70
+ for handler in handlers:
71
+ handler.setFormatter(formatter)
72
+ handler.setLevel(log_level)
73
+ logger.addHandler(handler)
74
+
75
+ if rank == 0:
76
+ logger.setLevel(log_level)
77
+ else:
78
+ logger.setLevel(logging.ERROR)
79
+
80
+ logger_initialized[name] = True
81
+
82
+ return logger
83
+
84
+
85
+ def print_log(msg, logger=None, level=logging.INFO):
86
+ """Print a log message.
87
+
88
+ Args:
89
+ msg (str): The message to be logged.
90
+ logger (logging.Logger | str | None): The logger to be used.
91
+ Some special loggers are:
92
+
93
+ - "silent": no message will be printed.
94
+ - other str: the logger obtained with `get_root_logger(logger)`.
95
+ - None: The `print()` method will be used to print log messages.
96
+ level (int): Logging level. Only available when `logger` is a Logger
97
+ object or "root".
98
+ """
99
+ if logger is None:
100
+ print(msg)
101
+ elif isinstance(logger, logging.Logger):
102
+ logger.log(level, msg)
103
+ elif logger == 'silent':
104
+ pass
105
+ elif isinstance(logger, str):
106
+ _logger = get_logger(logger)
107
+ _logger.log(level, msg)
108
+ else:
109
+ raise TypeError(
110
+ 'logger should be either a logging.Logger object, str, '
111
+ f'"silent" or None, but got {type(logger)}')
112
+
113
+
114
+ def get_root_logger(log_file=None, log_level=logging.INFO):
115
+ """Use `get_logger` method in mmcv to get the root logger.
116
+
117
+ The logger will be initialized if it has not been initialized. By default a
118
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
119
+ also be added. The name of the root logger is the top-level package name,
120
+ e.g., "mmpose".
121
+
122
+ Args:
123
+ log_file (str | None): The log filename. If specified, a FileHandler
124
+ will be added to the root logger.
125
+ log_level (int): The root logger level. Note that only the process of
126
+ rank 0 is affected, while other processes will set the level to
127
+ "Error" and be silent most of the time.
128
+
129
+ Returns:
130
+ logging.Logger: The root logger.
131
+ """
132
+ return get_logger(__name__.split('.')[0], log_file, log_level)
133
+
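
A minimal usage sketch of the logging helpers (not part of this commit): building the package-level root logger and routing messages through print_log. The log file name 'train.log' is an illustrative assumption.

import logging
from easy_ViTPose.vit_utils.logging import get_root_logger, print_log

logger = get_root_logger(log_file='train.log', log_level=logging.INFO)
print_log('starting training', logger=logger)    # goes through the Logger object
print_log('console only message', logger=None)   # falls back to plain print()
print_log('suppressed message', logger='silent') # dropped entirely
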
easy_ViTPose/vit_utils/nms/__init__.py ADDED
File without changes
easy_ViTPose/vit_utils/nms/cpu_nms.c ADDED
The diff for this file is too large to render. See raw diff
 
easy_ViTPose/vit_utils/nms/cpu_nms.cpython-37m-x86_64-linux-gnu.so ADDED
Binary file (264 kB). View file