DeclK committed
Commit 6ed2820 · 1 Parent(s): da31fac
.gitignore ADDED
@@ -0,0 +1,7 @@
+ *.pth
+ *.pkl
+ *.mp4
+ *.onnx
+ *.ttf
+ tempt*
+ __pycache__
app.py ADDED
@@ -0,0 +1,109 @@
+ # Inference 2 videos and use dtw to match the pose keypoints.
+ from tools.inferencer import PoseInferencerV2
+ from tools.dtw import DTWForKeypoints
+ from tools.visualizer import FastVisualizer
+ from tools.utils import convert_video_to_playable_mp4
+ from argparse import ArgumentParser
+ from pathlib import Path
+ from tqdm import tqdm
+ import mmengine
+ import numpy as np
+ import mmcv
+ import cv2
+ import gradio as gr
+
+ def parse_args():
+     parser = ArgumentParser()
+     parser.add_argument('--config', type=str, default='configs/mark2.py')
+     parser.add_argument('--video1', type=str, default='assets/tennis1.mp4')
+     parser.add_argument('--video2', type=str, default='assets/tennis2.mp4')
+     return parser.parse_args()
+
+ def concat(img1, img2, height=1080):
+     # cv2 images are (height, width, channels)
+     h1, w1, _ = img1.shape
+     h2, w2, _ = img2.shape
+
+     # Scale both images to the same target height so they can be
+     # concatenated horizontally
+     scale1 = height / h1
+     scale2 = height / h2
+
+     # Resize the images (cv2.resize expects (width, height))
+     img1 = cv2.resize(img1, (int(w1 * scale1), int(h1 * scale1)))
+     img2 = cv2.resize(img2, (int(w2 * scale2), int(h2 * scale2)))
+
+     # Concatenate the images horizontally
+     image = cv2.hconcat([img1, img2])
+     return image
+
+ def draw(vis: FastVisualizer, img, keypoint, box, oks, oks_unnorm, draw_score_bar=True):
+     vis.set_image(img)
+     vis.draw_non_transparent_area(box)
+     if draw_score_bar:
+         vis.draw_score_bar(oks)
+     vis.draw_human_keypoints(keypoint, oks_unnorm)
+     return vis.get_image()
+
+ def main(video1, video2):
+     # build PoseInferencerV2
+     config = 'configs/mark2.py'
+     cfg = mmengine.Config.fromfile(config)
+     pose_inferencer = PoseInferencerV2(
+         cfg.det_cfg,
+         cfg.pose_cfg,
+         device='cpu')
+
+     v1 = mmcv.VideoReader(video1)
+     v2 = mmcv.VideoReader(video2)
+     video_writer = None
+
+     all_det1, all_pose1 = pose_inferencer.inference_video(video1)
+     all_det2, all_pose2 = pose_inferencer.inference_video(video2)
+
+     keypoints1 = np.stack([p.keypoints[0] for p in all_pose1])  # keep only the first predicted instance
+     keypoints2 = np.stack([p.keypoints[0] for p in all_pose2])
+     boxes1 = np.stack([d.bboxes[0] for d in all_det1])
+     boxes2 = np.stack([d.bboxes[0] for d in all_det2])
+
+     dtw_path, oks, oks_unnorm = DTWForKeypoints(keypoints1, keypoints2).get_dtw_path()
+
+     vis = FastVisualizer()
+
+     for i, j in tqdm(dtw_path):
+         frame1 = v1[i]
+         frame2 = v2[j]
+
+         frame1_ = draw(vis, frame1.copy(), keypoints1[i], boxes1[i],
+                        oks[i, j], oks_unnorm[i, j])
+         frame2_ = draw(vis, frame2.copy(), keypoints2[j], boxes2[j],
+                        oks[i, j], oks_unnorm[i, j], draw_score_bar=False)
+         # concatenate the two frames
+         frame = concat(frame1_, frame2_)
+         # draw logo
+         vis.set_image(frame)
+         frame = vis.draw_logo().get_image()
+         # write video
+         w, h = frame.shape[1], frame.shape[0]
+         if video_writer is None:
+             fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+             video_writer = cv2.VideoWriter('dtw_compare.mp4',
+                                            fourcc, v1.fps, (w, h))
+         video_writer.write(frame)
+     video_writer.release()
+     # output video file
+     convert_video_to_playable_mp4('dtw_compare.mp4')
+     output = str(Path('dtw_compare.mp4').resolve())
+     return output
+
+ if __name__ == '__main__':
+     config = 'configs/mark2.py'
+     cfg = mmengine.Config.fromfile(config)
+
+     inputs = [
+         gr.Video(label="Input video 1"),
+         gr.Video(label="Input video 2")
+     ]
+
+     output = gr.Video(label="Output video")
+
+     demo = gr.Interface(fn=main, inputs=inputs, outputs=output).queue()
+     demo.launch(share=True)
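A note on the data flow in main() above, inferred from how the loop indexes the arrays rather than from separate documentation: DTWForKeypoints.get_dtw_path() returns a frame-alignment path plus frame-by-frame similarity matrices, so oks[i, j] scores how well frame i of video 1 matches frame j of video 2. A toy sketch of those shapes (hypothetical values only):

import numpy as np

# Hypothetical similarity matrix: rows index frames of video 1,
# columns index frames of video 2 (the real one comes from DTWForKeypoints).
oks = np.array([[0.9, 0.2],
                [0.3, 0.8]])
dtw_path = [(0, 0), (1, 1)]   # matched frame-index pairs
for i, j in dtw_path:
    print(i, j, oks[i, j])    # the per-pair score drawn on the score bar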
assets/logo.png ADDED
assets/onnx_test.jpg ADDED
configs/mark1.py ADDED
@@ -0,0 +1,9 @@
+ det_cfg = dict(
+     model_cfg='model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py',
+     model_ckpt='/github/Tennis.ai/model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'
+ )
+
+ pose_cfg = dict(
+     model_cfg='model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-t_8xb256-420e_aic-coco-256x192.py',
+     model_ckpt='model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-tiny_simcc-aic-coco_pt-aic-coco_420e-256x192-cfc8f33d_20230126.pth'
+ )
configs/mark2.py ADDED
@@ -0,0 +1,10 @@
+ det_cfg = dict(
+     deploy_cfg='model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/detection_onnxruntime_static.py',
+     model_cfg='model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py',
+     backend_files=['model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/end2end.onnx']
+ )
+
+ pose_cfg = dict(
+     model_cfg='model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-t_8xb256-420e_aic-coco-256x192.py',
+     model_ckpt='model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-tiny_simcc-aic-coco_pt-aic-coco_420e-256x192-cfc8f33d_20230126.pth'
+ )
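The two configs differ only in how the detector is supplied: mark1.py points PoseInferencerV2 at a PyTorch checkpoint, while mark2.py points it at the ONNX model exported by convert_det.sh plus its mmdeploy config. A minimal sketch of how app.py and main.py consume either file (the paths and the 'cpu' device are taken from the code above; this is illustrative, not a documented API):

import mmengine
from tools.inferencer import PoseInferencerV2

cfg = mmengine.Config.fromfile('configs/mark2.py')   # or configs/mark1.py
inferencer = PoseInferencerV2(cfg.det_cfg, cfg.pose_cfg, device='cpu')
all_det, all_pose = inferencer.inference_video('assets/tennis1.mp4')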
convert_det.sh ADDED
@@ -0,0 +1,8 @@
+ python tools/deploy.py \
+     model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/detection_onnxruntime_static.py \
+     model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py \
+     model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth \
+     assets/onnx_test.jpg \
+     --work-dir model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco \
+     --device cpu \
+     --show
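The script only wraps tools/deploy.py; the end2end.onnx it writes is the file configs/mark2.py expects under backend_files. A quick, optional sanity check that the export loads (a sketch assuming onnxruntime from requirements.txt; the expected tensor names come from detection_onnxruntime_static.py below):

import onnxruntime as ort

sess = ort.InferenceSession(
    'model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/end2end.onnx',
    providers=['CPUExecutionProvider'])
print([i.name for i in sess.get_inputs()])    # expect ['input']
print([o.name for o in sess.get_outputs()])   # expect ['dets', 'labels']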
main.py ADDED
@@ -0,0 +1,96 @@
+ # Inference 2 videos and use dtw to match the pose keypoints.
+ from tools.inferencer import PoseInferencerV2
+ from tools.dtw import DTWForKeypoints
+ from tools.visualizer import FastVisualizer
+ from argparse import ArgumentParser
+ from tools.utils import convert_video_to_playable_mp4
+ from tqdm import tqdm
+ import mmengine
+ import numpy as np
+ import mmcv
+ import cv2
+
+ def parse_args():
+     parser = ArgumentParser()
+     parser.add_argument('--config', type=str, default='configs/mark2.py')
+     parser.add_argument('--video1', type=str, default='assets/tennis1.mp4')
+     parser.add_argument('--video2', type=str, default='assets/tennis2.mp4')
+     return parser.parse_args()
+
+ def concat(img1, img2, height=1080):
+     # cv2 images are (height, width, channels)
+     h1, w1, _ = img1.shape
+     h2, w2, _ = img2.shape
+
+     # Scale both images to the same target height so they can be
+     # concatenated horizontally
+     scale1 = height / h1
+     scale2 = height / h2
+
+     # Resize the images (cv2.resize expects (width, height))
+     img1 = cv2.resize(img1, (int(w1 * scale1), int(h1 * scale1)))
+     img2 = cv2.resize(img2, (int(w2 * scale2), int(h2 * scale2)))
+
+     # Concatenate the images horizontally
+     image = cv2.hconcat([img1, img2])
+     return image
+
+ def draw(vis: FastVisualizer, img, keypoint, box, oks, oks_unnorm, draw_score_bar=True):
+     vis.set_image(img)
+     vis.draw_non_transparent_area(box)
+     if draw_score_bar:
+         vis.draw_score_bar(oks)
+     vis.draw_human_keypoints(keypoint, oks_unnorm)
+     return vis.get_image()
+
+ def main(cfg):
+     # build PoseInferencerV2
+     pose_inferencer = PoseInferencerV2(
+         cfg.det_cfg,
+         cfg.pose_cfg,
+         device='cpu')
+
+     v1 = mmcv.VideoReader(cfg.video1)
+     v2 = mmcv.VideoReader(cfg.video2)
+     video_writer = None
+
+     all_det1, all_pose1 = pose_inferencer.inference_video(cfg.video1)
+     all_det2, all_pose2 = pose_inferencer.inference_video(cfg.video2)
+
+     keypoints1 = np.stack([p.keypoints[0] for p in all_pose1])  # keep only the first predicted instance
+     keypoints2 = np.stack([p.keypoints[0] for p in all_pose2])
+     boxes1 = np.stack([d.bboxes[0] for d in all_det1])
+     boxes2 = np.stack([d.bboxes[0] for d in all_det2])
+
+     dtw_path, oks, oks_unnorm = DTWForKeypoints(keypoints1, keypoints2).get_dtw_path()
+
+     vis = FastVisualizer()
+
+     for i, j in tqdm(dtw_path):
+         frame1 = v1[i]
+         frame2 = v2[j]
+
+         frame1_ = draw(vis, frame1.copy(), keypoints1[i], boxes1[i],
+                        oks[i, j], oks_unnorm[i, j])
+         frame2_ = draw(vis, frame2.copy(), keypoints2[j], boxes2[j],
+                        oks[i, j], oks_unnorm[i, j], draw_score_bar=False)
+         # concatenate the two frames
+         frame = concat(frame1_, frame2_)
+         # draw logo
+         vis.set_image(frame)
+         frame = vis.draw_logo().get_image()
+         # write video
+         w, h = frame.shape[1], frame.shape[0]
+         if video_writer is None:
+             fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+             video_writer = cv2.VideoWriter('dtw_compare.mp4',
+                                            fourcc, v1.fps, (w, h))
+         video_writer.write(frame)
+     video_writer.release()
+     convert_video_to_playable_mp4('dtw_compare.mp4')
+
+ if __name__ == '__main__':
+     args = parse_args()
+     cfg = mmengine.Config.fromfile(args.config)
+     cfg.video1 = args.video1
+     cfg.video2 = args.video2
+
+     main(cfg)
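main.py is the CLI counterpart of app.py: the same inference and visualization, but driven by argparse instead of Gradio. A programmatic equivalent of its entry point (the paths are just the parse_args defaults shown above):

import mmengine
from main import main

cfg = mmengine.Config.fromfile('configs/mark2.py')
cfg.video1 = 'assets/tennis1.mp4'
cfg.video2 = 'assets/tennis2.mp4'
main(cfg)   # writes dtw_compare.mp4 next to the script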
model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/detection_onnxruntime_static.py ADDED
@@ -0,0 +1,23 @@
+ onnx_config = dict(
+     type='onnx',
+     export_params=True,
+     keep_initializers_as_inputs=False,
+     opset_version=11,
+     save_file='end2end.onnx',
+     input_names=['input'],
+     output_names=['dets', 'labels'],
+     input_shape=None,
+     optimize=True)
+ codebase_config = dict(
+     type='mmdet',
+     task='ObjectDetection',
+     model_type='end2end',
+     post_processing=dict(
+         score_threshold=0.05,
+         confidence_threshold=0.005,
+         iou_threshold=0.5,
+         max_output_boxes_per_class=200,
+         pre_top_k=5000,
+         keep_top_k=100,
+         background_label_id=-1))
+ backend_config = dict(type='onnxruntime')
model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py ADDED
@@ -0,0 +1,345 @@
1
+ default_scope = 'mmdet'
2
+ default_hooks = dict(
3
+ timer=dict(type='IterTimerHook'),
4
+ logger=dict(type='LoggerHook', interval=50),
5
+ param_scheduler=dict(type='ParamSchedulerHook'),
6
+ checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3),
7
+ sampler_seed=dict(type='DistSamplerSeedHook'),
8
+ visualization=dict(type='DetVisualizationHook'))
9
+ env_cfg = dict(
10
+ cudnn_benchmark=False,
11
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
12
+ dist_cfg=dict(backend='nccl'))
13
+ vis_backends = [dict(type='LocalVisBackend')]
14
+ visualizer = dict(
15
+ type='DetLocalVisualizer',
16
+ vis_backends=[dict(type='LocalVisBackend')],
17
+ name='visualizer')
18
+ log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
19
+ log_level = 'INFO'
20
+ load_from = None
21
+ resume = False
22
+ train_cfg = dict(
23
+ type='EpochBasedTrainLoop',
24
+ max_epochs=300,
25
+ val_interval=10,
26
+ dynamic_intervals=[(280, 1)])
27
+ val_cfg = dict(type='ValLoop')
28
+ test_cfg = dict(type='TestLoop')
29
+ param_scheduler = [
30
+ dict(
31
+ type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0,
32
+ end=1000),
33
+ dict(
34
+ type='CosineAnnealingLR',
35
+ eta_min=0.0002,
36
+ begin=150,
37
+ end=300,
38
+ T_max=150,
39
+ by_epoch=True,
40
+ convert_to_iter_based=True)
41
+ ]
42
+ optim_wrapper = dict(
43
+ type='OptimWrapper',
44
+ optimizer=dict(type='AdamW', lr=0.004, weight_decay=0.05),
45
+ paramwise_cfg=dict(
46
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
47
+ auto_scale_lr = dict(enable=False, base_batch_size=16)
48
+ dataset_type = 'CocoDataset'
49
+ data_root = 'data/coco/'
50
+ backend_args = None
51
+ train_pipeline = [
52
+ dict(type='LoadImageFromFile', backend_args=None),
53
+ dict(type='LoadAnnotations', with_bbox=True),
54
+ dict(
55
+ type='CachedMosaic',
56
+ img_scale=(640, 640),
57
+ pad_val=114.0,
58
+ max_cached_images=20,
59
+ random_pop=False),
60
+ dict(
61
+ type='RandomResize',
62
+ scale=(1280, 1280),
63
+ ratio_range=(0.5, 2.0),
64
+ keep_ratio=True),
65
+ dict(type='RandomCrop', crop_size=(640, 640)),
66
+ dict(type='YOLOXHSVRandomAug'),
67
+ dict(type='RandomFlip', prob=0.5),
68
+ dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
69
+ dict(
70
+ type='CachedMixUp',
71
+ img_scale=(640, 640),
72
+ ratio_range=(1.0, 1.0),
73
+ max_cached_images=10,
74
+ random_pop=False,
75
+ pad_val=(114, 114, 114),
76
+ prob=0.5),
77
+ dict(type='PackDetInputs')
78
+ ]
79
+ test_pipeline = [
80
+ dict(type='LoadImageFromFile', backend_args=None),
81
+ dict(type='Resize', scale=(640, 640), keep_ratio=True),
82
+ dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
83
+ dict(
84
+ type='PackDetInputs',
85
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
86
+ 'scale_factor'))
87
+ ]
88
+ train_dataloader = dict(
89
+ batch_size=32,
90
+ num_workers=10,
91
+ persistent_workers=True,
92
+ sampler=dict(type='DefaultSampler', shuffle=True),
93
+ batch_sampler=None,
94
+ dataset=dict(
95
+ type='CocoDataset',
96
+ data_root='data/coco/',
97
+ ann_file='annotations/instances_train2017.json',
98
+ data_prefix=dict(img='train2017/'),
99
+ filter_cfg=dict(filter_empty_gt=True, min_size=32),
100
+ pipeline=[
101
+ dict(type='LoadImageFromFile', backend_args=None),
102
+ dict(type='LoadAnnotations', with_bbox=True),
103
+ dict(
104
+ type='CachedMosaic',
105
+ img_scale=(640, 640),
106
+ pad_val=114.0,
107
+ max_cached_images=20,
108
+ random_pop=False),
109
+ dict(
110
+ type='RandomResize',
111
+ scale=(1280, 1280),
112
+ ratio_range=(0.5, 2.0),
113
+ keep_ratio=True),
114
+ dict(type='RandomCrop', crop_size=(640, 640)),
115
+ dict(type='YOLOXHSVRandomAug'),
116
+ dict(type='RandomFlip', prob=0.5),
117
+ dict(
118
+ type='Pad', size=(640, 640),
119
+ pad_val=dict(img=(114, 114, 114))),
120
+ dict(
121
+ type='CachedMixUp',
122
+ img_scale=(640, 640),
123
+ ratio_range=(1.0, 1.0),
124
+ max_cached_images=10,
125
+ random_pop=False,
126
+ pad_val=(114, 114, 114),
127
+ prob=0.5),
128
+ dict(type='PackDetInputs')
129
+ ],
130
+ backend_args=None),
131
+ pin_memory=True)
132
+ val_dataloader = dict(
133
+ batch_size=5,
134
+ num_workers=10,
135
+ persistent_workers=True,
136
+ drop_last=False,
137
+ sampler=dict(type='DefaultSampler', shuffle=False),
138
+ dataset=dict(
139
+ type='CocoDataset',
140
+ data_root='data/coco/',
141
+ ann_file='annotations/instances_val2017.json',
142
+ data_prefix=dict(img='val2017/'),
143
+ test_mode=True,
144
+ pipeline=[
145
+ dict(type='LoadImageFromFile', backend_args=None),
146
+ dict(type='Resize', scale=(640, 640), keep_ratio=True),
147
+ dict(
148
+ type='Pad', size=(640, 640),
149
+ pad_val=dict(img=(114, 114, 114))),
150
+ dict(
151
+ type='PackDetInputs',
152
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
153
+ 'scale_factor'))
154
+ ],
155
+ backend_args=None))
156
+ test_dataloader = dict(
157
+ batch_size=5,
158
+ num_workers=10,
159
+ persistent_workers=True,
160
+ drop_last=False,
161
+ sampler=dict(type='DefaultSampler', shuffle=False),
162
+ dataset=dict(
163
+ type='CocoDataset',
164
+ data_root='data/coco/',
165
+ ann_file='annotations/instances_val2017.json',
166
+ data_prefix=dict(img='val2017/'),
167
+ test_mode=True,
168
+ pipeline=[
169
+ dict(type='LoadImageFromFile', backend_args=None),
170
+ dict(type='Resize', scale=(640, 640), keep_ratio=True),
171
+ dict(
172
+ type='Pad', size=(640, 640),
173
+ pad_val=dict(img=(114, 114, 114))),
174
+ dict(
175
+ type='PackDetInputs',
176
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
177
+ 'scale_factor'))
178
+ ],
179
+ backend_args=None))
180
+ val_evaluator = dict(
181
+ type='CocoMetric',
182
+ ann_file='data/coco/annotations/instances_val2017.json',
183
+ metric='bbox',
184
+ format_only=False,
185
+ backend_args=None,
186
+ proposal_nums=(100, 1, 10))
187
+ test_evaluator = dict(
188
+ type='CocoMetric',
189
+ ann_file='data/coco/annotations/instances_val2017.json',
190
+ metric='bbox',
191
+ format_only=False,
192
+ backend_args=None,
193
+ proposal_nums=(100, 1, 10))
194
+ tta_model = dict(
195
+ type='DetTTAModel',
196
+ tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.6), max_per_img=100))
197
+ img_scales = [(640, 640), (320, 320), (960, 960)]
198
+ tta_pipeline = [
199
+ dict(type='LoadImageFromFile', backend_args=None),
200
+ dict(
201
+ type='TestTimeAug',
202
+ transforms=[[{
203
+ 'type': 'Resize',
204
+ 'scale': (640, 640),
205
+ 'keep_ratio': True
206
+ }, {
207
+ 'type': 'Resize',
208
+ 'scale': (320, 320),
209
+ 'keep_ratio': True
210
+ }, {
211
+ 'type': 'Resize',
212
+ 'scale': (960, 960),
213
+ 'keep_ratio': True
214
+ }],
215
+ [{
216
+ 'type': 'RandomFlip',
217
+ 'prob': 1.0
218
+ }, {
219
+ 'type': 'RandomFlip',
220
+ 'prob': 0.0
221
+ }],
222
+ [{
223
+ 'type': 'Pad',
224
+ 'size': (960, 960),
225
+ 'pad_val': {
226
+ 'img': (114, 114, 114)
227
+ }
228
+ }],
229
+ [{
230
+ 'type':
231
+ 'PackDetInputs',
232
+ 'meta_keys':
233
+ ('img_id', 'img_path', 'ori_shape', 'img_shape',
234
+ 'scale_factor', 'flip', 'flip_direction')
235
+ }]])
236
+ ]
237
+ model = dict(
238
+ type='RTMDet',
239
+ data_preprocessor=dict(
240
+ type='DetDataPreprocessor',
241
+ mean=[103.53, 116.28, 123.675],
242
+ std=[57.375, 57.12, 58.395],
243
+ bgr_to_rgb=False,
244
+ batch_augments=None),
245
+ backbone=dict(
246
+ type='CSPNeXt',
247
+ arch='P5',
248
+ expand_ratio=0.5,
249
+ deepen_factor=0.167,
250
+ widen_factor=0.375,
251
+ channel_attention=True,
252
+ norm_cfg=dict(type='SyncBN'),
253
+ act_cfg=dict(type='SiLU', inplace=True),
254
+ init_cfg=dict(
255
+ type='Pretrained',
256
+ prefix='backbone.',
257
+ checkpoint=
258
+ 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth'
259
+ )),
260
+ neck=dict(
261
+ type='CSPNeXtPAFPN',
262
+ in_channels=[96, 192, 384],
263
+ out_channels=96,
264
+ num_csp_blocks=1,
265
+ expand_ratio=0.5,
266
+ norm_cfg=dict(type='SyncBN'),
267
+ act_cfg=dict(type='SiLU', inplace=True)),
268
+ bbox_head=dict(
269
+ type='RTMDetSepBNHead',
270
+ num_classes=80,
271
+ in_channels=96,
272
+ stacked_convs=2,
273
+ feat_channels=96,
274
+ anchor_generator=dict(
275
+ type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
276
+ bbox_coder=dict(type='DistancePointBBoxCoder'),
277
+ loss_cls=dict(
278
+ type='QualityFocalLoss',
279
+ use_sigmoid=True,
280
+ beta=2.0,
281
+ loss_weight=1.0),
282
+ loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
283
+ with_objectness=False,
284
+ exp_on_reg=False,
285
+ share_conv=True,
286
+ pred_kernel_size=1,
287
+ norm_cfg=dict(type='SyncBN'),
288
+ act_cfg=dict(type='SiLU', inplace=True)),
289
+ train_cfg=dict(
290
+ assigner=dict(type='DynamicSoftLabelAssigner', topk=13),
291
+ allowed_border=-1,
292
+ pos_weight=-1,
293
+ debug=False),
294
+ test_cfg=dict(
295
+ nms_pre=30000,
296
+ min_bbox_size=0,
297
+ score_thr=0.001,
298
+ nms=dict(type='nms', iou_threshold=0.65),
299
+ max_per_img=300))
300
+ train_pipeline_stage2 = [
301
+ dict(type='LoadImageFromFile', backend_args=None),
302
+ dict(type='LoadAnnotations', with_bbox=True),
303
+ dict(
304
+ type='RandomResize',
305
+ scale=(640, 640),
306
+ ratio_range=(0.5, 2.0),
307
+ keep_ratio=True),
308
+ dict(type='RandomCrop', crop_size=(640, 640)),
309
+ dict(type='YOLOXHSVRandomAug'),
310
+ dict(type='RandomFlip', prob=0.5),
311
+ dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
312
+ dict(type='PackDetInputs')
313
+ ]
314
+ max_epochs = 300
315
+ stage2_num_epochs = 20
316
+ base_lr = 0.004
317
+ interval = 10
318
+ custom_hooks = [
319
+ dict(
320
+ type='EMAHook',
321
+ ema_type='ExpMomentumEMA',
322
+ momentum=0.0002,
323
+ update_buffers=True,
324
+ priority=49),
325
+ dict(
326
+ type='PipelineSwitchHook',
327
+ switch_epoch=280,
328
+ switch_pipeline=[
329
+ dict(type='LoadImageFromFile', backend_args=None),
330
+ dict(type='LoadAnnotations', with_bbox=True),
331
+ dict(
332
+ type='RandomResize',
333
+ scale=(640, 640),
334
+ ratio_range=(0.5, 2.0),
335
+ keep_ratio=True),
336
+ dict(type='RandomCrop', crop_size=(640, 640)),
337
+ dict(type='YOLOXHSVRandomAug'),
338
+ dict(type='RandomFlip', prob=0.5),
339
+ dict(
340
+ type='Pad', size=(640, 640),
341
+ pad_val=dict(img=(114, 114, 114))),
342
+ dict(type='PackDetInputs')
343
+ ])
344
+ ]
345
+ checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth'
model_zoo/rtmpose/rtmpose-m_8xb256-420e_aic-coco-256x192/rtmpose-m_8xb256-420e_aic-coco-256x192.py ADDED
@@ -0,0 +1,391 @@
1
+ default_scope = 'mmpose'
2
+ default_hooks = dict(
3
+ timer=dict(type='IterTimerHook'),
4
+ logger=dict(type='LoggerHook', interval=50),
5
+ param_scheduler=dict(type='ParamSchedulerHook'),
6
+ checkpoint=dict(
7
+ type='CheckpointHook',
8
+ interval=10,
9
+ save_best='coco/AP',
10
+ rule='greater',
11
+ max_keep_ckpts=1),
12
+ sampler_seed=dict(type='DistSamplerSeedHook'),
13
+ visualization=dict(type='PoseVisualizationHook', enable=False))
14
+ custom_hooks = [
15
+ dict(
16
+ type='EMAHook',
17
+ ema_type='ExpMomentumEMA',
18
+ momentum=0.0002,
19
+ update_buffers=True,
20
+ priority=49),
21
+ dict(
22
+ type='mmdet.PipelineSwitchHook',
23
+ switch_epoch=390,
24
+ switch_pipeline=[
25
+ dict(type='LoadImage', backend_args=dict(backend='local')),
26
+ dict(type='GetBBoxCenterScale'),
27
+ dict(type='RandomFlip', direction='horizontal'),
28
+ dict(type='RandomHalfBody'),
29
+ dict(
30
+ type='RandomBBoxTransform',
31
+ shift_factor=0.0,
32
+ scale_factor=[0.75, 1.25],
33
+ rotate_factor=60),
34
+ dict(type='TopdownAffine', input_size=(192, 256)),
35
+ dict(type='mmdet.YOLOXHSVRandomAug'),
36
+ dict(
37
+ type='Albumentation',
38
+ transforms=[
39
+ dict(type='Blur', p=0.1),
40
+ dict(type='MedianBlur', p=0.1),
41
+ dict(
42
+ type='CoarseDropout',
43
+ max_holes=1,
44
+ max_height=0.4,
45
+ max_width=0.4,
46
+ min_holes=1,
47
+ min_height=0.2,
48
+ min_width=0.2,
49
+ p=0.5)
50
+ ]),
51
+ dict(
52
+ type='GenerateTarget',
53
+ encoder=dict(
54
+ type='SimCCLabel',
55
+ input_size=(192, 256),
56
+ sigma=(4.9, 5.66),
57
+ simcc_split_ratio=2.0,
58
+ normalize=False,
59
+ use_dark=False)),
60
+ dict(type='PackPoseInputs')
61
+ ])
62
+ ]
63
+ env_cfg = dict(
64
+ cudnn_benchmark=False,
65
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
66
+ dist_cfg=dict(backend='nccl'))
67
+ vis_backends = [dict(type='LocalVisBackend')]
68
+ visualizer = dict(
69
+ type='PoseLocalVisualizer',
70
+ vis_backends=[dict(type='LocalVisBackend')],
71
+ name='visualizer')
72
+ log_processor = dict(
73
+ type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
74
+ log_level = 'INFO'
75
+ load_from = None
76
+ resume = False
77
+ backend_args = dict(backend='local')
78
+ train_cfg = dict(by_epoch=True, max_epochs=420, val_interval=10)
79
+ val_cfg = dict()
80
+ test_cfg = dict()
81
+ max_epochs = 420
82
+ stage2_num_epochs = 30
83
+ base_lr = 0.004
84
+ randomness = dict(seed=21)
85
+ optim_wrapper = dict(
86
+ type='OptimWrapper',
87
+ optimizer=dict(type='AdamW', lr=0.004, weight_decay=0.05),
88
+ paramwise_cfg=dict(
89
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
90
+ param_scheduler = [
91
+ dict(
92
+ type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0,
93
+ end=1000),
94
+ dict(
95
+ type='CosineAnnealingLR',
96
+ eta_min=0.0002,
97
+ begin=210,
98
+ end=420,
99
+ T_max=210,
100
+ by_epoch=True,
101
+ convert_to_iter_based=True)
102
+ ]
103
+ auto_scale_lr = dict(base_batch_size=1024)
104
+ codec = dict(
105
+ type='SimCCLabel',
106
+ input_size=(192, 256),
107
+ sigma=(4.9, 5.66),
108
+ simcc_split_ratio=2.0,
109
+ normalize=False,
110
+ use_dark=False)
111
+ model = dict(
112
+ type='TopdownPoseEstimator',
113
+ data_preprocessor=dict(
114
+ type='PoseDataPreprocessor',
115
+ mean=[123.675, 116.28, 103.53],
116
+ std=[58.395, 57.12, 57.375],
117
+ bgr_to_rgb=True),
118
+ backbone=dict(
119
+ _scope_='mmdet',
120
+ type='CSPNeXt',
121
+ arch='P5',
122
+ expand_ratio=0.5,
123
+ deepen_factor=0.67,
124
+ widen_factor=0.75,
125
+ out_indices=(4, ),
126
+ channel_attention=True,
127
+ norm_cfg=dict(type='SyncBN'),
128
+ act_cfg=dict(type='SiLU'),
129
+ init_cfg=dict(
130
+ type='Pretrained',
131
+ prefix='backbone.',
132
+ checkpoint=
133
+ 'https://download.openmmlab.com/mmpose/v1/projects/rtmpose/cspnext-m_udp-aic-coco_210e-256x192-f2f7d6f6_20230130.pth'
134
+ )),
135
+ head=dict(
136
+ type='RTMCCHead',
137
+ in_channels=768,
138
+ out_channels=17,
139
+ input_size=(192, 256),
140
+ in_featuremap_size=(6, 8),
141
+ simcc_split_ratio=2.0,
142
+ final_layer_kernel_size=7,
143
+ gau_cfg=dict(
144
+ hidden_dims=256,
145
+ s=128,
146
+ expansion_factor=2,
147
+ dropout_rate=0.0,
148
+ drop_path=0.0,
149
+ act_fn='SiLU',
150
+ use_rel_bias=False,
151
+ pos_enc=False),
152
+ loss=dict(
153
+ type='KLDiscretLoss',
154
+ use_target_weight=True,
155
+ beta=10.0,
156
+ label_softmax=True),
157
+ decoder=dict(
158
+ type='SimCCLabel',
159
+ input_size=(192, 256),
160
+ sigma=(4.9, 5.66),
161
+ simcc_split_ratio=2.0,
162
+ normalize=False,
163
+ use_dark=False)),
164
+ test_cfg=dict(flip_test=True))
165
+ dataset_type = 'CocoDataset'
166
+ data_mode = 'topdown'
167
+ data_root = 'data/'
168
+ train_pipeline = [
169
+ dict(type='LoadImage', backend_args=dict(backend='local')),
170
+ dict(type='GetBBoxCenterScale'),
171
+ dict(type='RandomFlip', direction='horizontal'),
172
+ dict(type='RandomHalfBody'),
173
+ dict(
174
+ type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
175
+ dict(type='TopdownAffine', input_size=(192, 256)),
176
+ dict(type='mmdet.YOLOXHSVRandomAug'),
177
+ dict(
178
+ type='Albumentation',
179
+ transforms=[
180
+ dict(type='Blur', p=0.1),
181
+ dict(type='MedianBlur', p=0.1),
182
+ dict(
183
+ type='CoarseDropout',
184
+ max_holes=1,
185
+ max_height=0.4,
186
+ max_width=0.4,
187
+ min_holes=1,
188
+ min_height=0.2,
189
+ min_width=0.2,
190
+ p=1.0)
191
+ ]),
192
+ dict(
193
+ type='GenerateTarget',
194
+ encoder=dict(
195
+ type='SimCCLabel',
196
+ input_size=(192, 256),
197
+ sigma=(4.9, 5.66),
198
+ simcc_split_ratio=2.0,
199
+ normalize=False,
200
+ use_dark=False)),
201
+ dict(type='PackPoseInputs')
202
+ ]
203
+ val_pipeline = [
204
+ dict(type='LoadImage', backend_args=dict(backend='local')),
205
+ dict(type='GetBBoxCenterScale'),
206
+ dict(type='TopdownAffine', input_size=(192, 256)),
207
+ dict(type='PackPoseInputs')
208
+ ]
209
+ train_pipeline_stage2 = [
210
+ dict(type='LoadImage', backend_args=dict(backend='local')),
211
+ dict(type='GetBBoxCenterScale'),
212
+ dict(type='RandomFlip', direction='horizontal'),
213
+ dict(type='RandomHalfBody'),
214
+ dict(
215
+ type='RandomBBoxTransform',
216
+ shift_factor=0.0,
217
+ scale_factor=[0.75, 1.25],
218
+ rotate_factor=60),
219
+ dict(type='TopdownAffine', input_size=(192, 256)),
220
+ dict(type='mmdet.YOLOXHSVRandomAug'),
221
+ dict(
222
+ type='Albumentation',
223
+ transforms=[
224
+ dict(type='Blur', p=0.1),
225
+ dict(type='MedianBlur', p=0.1),
226
+ dict(
227
+ type='CoarseDropout',
228
+ max_holes=1,
229
+ max_height=0.4,
230
+ max_width=0.4,
231
+ min_holes=1,
232
+ min_height=0.2,
233
+ min_width=0.2,
234
+ p=0.5)
235
+ ]),
236
+ dict(
237
+ type='GenerateTarget',
238
+ encoder=dict(
239
+ type='SimCCLabel',
240
+ input_size=(192, 256),
241
+ sigma=(4.9, 5.66),
242
+ simcc_split_ratio=2.0,
243
+ normalize=False,
244
+ use_dark=False)),
245
+ dict(type='PackPoseInputs')
246
+ ]
247
+ dataset_coco = dict(
248
+ type='RepeatDataset',
249
+ dataset=dict(
250
+ type='CocoDataset',
251
+ data_root='data/',
252
+ data_mode='topdown',
253
+ ann_file='coco/annotations/person_keypoints_train2017.json',
254
+ data_prefix=dict(img='detection/coco/train2017/'),
255
+ pipeline=[]),
256
+ times=3)
257
+ dataset_aic = dict(
258
+ type='AicDataset',
259
+ data_root='data/',
260
+ data_mode='topdown',
261
+ ann_file='aic/annotations/aic_train.json',
262
+ data_prefix=dict(
263
+ img=
264
+ 'pose/ai_challenge/ai_challenger_keypoint_train_20170902/keypoint_train_images_20170902/'
265
+ ),
266
+ pipeline=[
267
+ dict(
268
+ type='KeypointConverter',
269
+ num_keypoints=17,
270
+ mapping=[(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12),
271
+ (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)])
272
+ ])
273
+ train_dataloader = dict(
274
+ batch_size=256,
275
+ num_workers=10,
276
+ persistent_workers=True,
277
+ sampler=dict(type='DefaultSampler', shuffle=True),
278
+ dataset=dict(
279
+ type='CombinedDataset',
280
+ metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
281
+ datasets=[
282
+ dict(
283
+ type='RepeatDataset',
284
+ dataset=dict(
285
+ type='CocoDataset',
286
+ data_root='data/',
287
+ data_mode='topdown',
288
+ ann_file='coco/annotations/person_keypoints_train2017.json',
289
+ data_prefix=dict(img='detection/coco/train2017/'),
290
+ pipeline=[]),
291
+ times=3),
292
+ dict(
293
+ type='AicDataset',
294
+ data_root='data/',
295
+ data_mode='topdown',
296
+ ann_file='aic/annotations/aic_train.json',
297
+ data_prefix=dict(
298
+ img=
299
+ 'pose/ai_challenge/ai_challenger_keypoint_train_20170902/keypoint_train_images_20170902/'
300
+ ),
301
+ pipeline=[
302
+ dict(
303
+ type='KeypointConverter',
304
+ num_keypoints=17,
305
+ mapping=[(0, 6), (1, 8), (2, 10), (3, 5), (4, 7),
306
+ (5, 9), (6, 12), (7, 14), (8, 16), (9, 11),
307
+ (10, 13), (11, 15)])
308
+ ])
309
+ ],
310
+ pipeline=[
311
+ dict(type='LoadImage', backend_args=dict(backend='local')),
312
+ dict(type='GetBBoxCenterScale'),
313
+ dict(type='RandomFlip', direction='horizontal'),
314
+ dict(type='RandomHalfBody'),
315
+ dict(
316
+ type='RandomBBoxTransform',
317
+ scale_factor=[0.6, 1.4],
318
+ rotate_factor=80),
319
+ dict(type='TopdownAffine', input_size=(192, 256)),
320
+ dict(type='mmdet.YOLOXHSVRandomAug'),
321
+ dict(
322
+ type='Albumentation',
323
+ transforms=[
324
+ dict(type='Blur', p=0.1),
325
+ dict(type='MedianBlur', p=0.1),
326
+ dict(
327
+ type='CoarseDropout',
328
+ max_holes=1,
329
+ max_height=0.4,
330
+ max_width=0.4,
331
+ min_holes=1,
332
+ min_height=0.2,
333
+ min_width=0.2,
334
+ p=1.0)
335
+ ]),
336
+ dict(
337
+ type='GenerateTarget',
338
+ encoder=dict(
339
+ type='SimCCLabel',
340
+ input_size=(192, 256),
341
+ sigma=(4.9, 5.66),
342
+ simcc_split_ratio=2.0,
343
+ normalize=False,
344
+ use_dark=False)),
345
+ dict(type='PackPoseInputs')
346
+ ],
347
+ test_mode=False))
348
+ val_dataloader = dict(
349
+ batch_size=64,
350
+ num_workers=10,
351
+ persistent_workers=True,
352
+ drop_last=False,
353
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
354
+ dataset=dict(
355
+ type='CocoDataset',
356
+ data_root='data/',
357
+ data_mode='topdown',
358
+ ann_file='coco/annotations/person_keypoints_val2017.json',
359
+ data_prefix=dict(img='detection/coco/val2017/'),
360
+ test_mode=True,
361
+ pipeline=[
362
+ dict(type='LoadImage', backend_args=dict(backend='local')),
363
+ dict(type='GetBBoxCenterScale'),
364
+ dict(type='TopdownAffine', input_size=(192, 256)),
365
+ dict(type='PackPoseInputs')
366
+ ]))
367
+ test_dataloader = dict(
368
+ batch_size=64,
369
+ num_workers=10,
370
+ persistent_workers=True,
371
+ drop_last=False,
372
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
373
+ dataset=dict(
374
+ type='CocoDataset',
375
+ data_root='data/',
376
+ data_mode='topdown',
377
+ ann_file='coco/annotations/person_keypoints_val2017.json',
378
+ data_prefix=dict(img='detection/coco/val2017/'),
379
+ test_mode=True,
380
+ pipeline=[
381
+ dict(type='LoadImage', backend_args=dict(backend='local')),
382
+ dict(type='GetBBoxCenterScale'),
383
+ dict(type='TopdownAffine', input_size=(192, 256)),
384
+ dict(type='PackPoseInputs')
385
+ ]))
386
+ val_evaluator = dict(
387
+ type='CocoMetric',
388
+ ann_file='data/coco/annotations/person_keypoints_val2017.json')
389
+ test_evaluator = dict(
390
+ type='CocoMetric',
391
+ ann_file='data/coco/annotations/person_keypoints_val2017.json')
model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-t_8xb256-420e_aic-coco-256x192.py ADDED
@@ -0,0 +1,385 @@
1
+ default_scope = 'mmpose'
2
+ default_hooks = dict(
3
+ timer=dict(type='IterTimerHook'),
4
+ logger=dict(type='LoggerHook', interval=50),
5
+ param_scheduler=dict(type='ParamSchedulerHook'),
6
+ checkpoint=dict(
7
+ type='CheckpointHook',
8
+ interval=10,
9
+ save_best='coco/AP',
10
+ rule='greater',
11
+ max_keep_ckpts=1),
12
+ sampler_seed=dict(type='DistSamplerSeedHook'),
13
+ visualization=dict(type='PoseVisualizationHook', enable=False))
14
+ custom_hooks = [
15
+ dict(
16
+ type='mmdet.PipelineSwitchHook',
17
+ switch_epoch=390,
18
+ switch_pipeline=[
19
+ dict(type='LoadImage', backend_args=dict(backend='local')),
20
+ dict(type='GetBBoxCenterScale'),
21
+ dict(type='RandomFlip', direction='horizontal'),
22
+ dict(type='RandomHalfBody'),
23
+ dict(
24
+ type='RandomBBoxTransform',
25
+ shift_factor=0.0,
26
+ scale_factor=[0.75, 1.25],
27
+ rotate_factor=60),
28
+ dict(type='TopdownAffine', input_size=(192, 256)),
29
+ dict(type='mmdet.YOLOXHSVRandomAug'),
30
+ dict(
31
+ type='Albumentation',
32
+ transforms=[
33
+ dict(type='Blur', p=0.1),
34
+ dict(type='MedianBlur', p=0.1),
35
+ dict(
36
+ type='CoarseDropout',
37
+ max_holes=1,
38
+ max_height=0.4,
39
+ max_width=0.4,
40
+ min_holes=1,
41
+ min_height=0.2,
42
+ min_width=0.2,
43
+ p=0.5)
44
+ ]),
45
+ dict(
46
+ type='GenerateTarget',
47
+ encoder=dict(
48
+ type='SimCCLabel',
49
+ input_size=(192, 256),
50
+ sigma=(4.9, 5.66),
51
+ simcc_split_ratio=2.0,
52
+ normalize=False,
53
+ use_dark=False)),
54
+ dict(type='PackPoseInputs')
55
+ ])
56
+ ]
57
+ env_cfg = dict(
58
+ cudnn_benchmark=False,
59
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
60
+ dist_cfg=dict(backend='nccl'))
61
+ vis_backends = [dict(type='LocalVisBackend')]
62
+ visualizer = dict(
63
+ type='PoseLocalVisualizer',
64
+ vis_backends=[dict(type='LocalVisBackend')],
65
+ name='visualizer')
66
+ log_processor = dict(
67
+ type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
68
+ log_level = 'INFO'
69
+ load_from = None
70
+ resume = False
71
+ backend_args = dict(backend='local')
72
+ train_cfg = dict(by_epoch=True, max_epochs=420, val_interval=10)
73
+ val_cfg = dict()
74
+ test_cfg = dict()
75
+ max_epochs = 420
76
+ stage2_num_epochs = 30
77
+ base_lr = 0.004
78
+ randomness = dict(seed=21)
79
+ optim_wrapper = dict(
80
+ type='OptimWrapper',
81
+ optimizer=dict(type='AdamW', lr=0.004, weight_decay=0.0),
82
+ paramwise_cfg=dict(
83
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
84
+ param_scheduler = [
85
+ dict(
86
+ type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0,
87
+ end=1000),
88
+ dict(
89
+ type='CosineAnnealingLR',
90
+ eta_min=0.0002,
91
+ begin=210,
92
+ end=420,
93
+ T_max=210,
94
+ by_epoch=True,
95
+ convert_to_iter_based=True)
96
+ ]
97
+ auto_scale_lr = dict(base_batch_size=1024)
98
+ codec = dict(
99
+ type='SimCCLabel',
100
+ input_size=(192, 256),
101
+ sigma=(4.9, 5.66),
102
+ simcc_split_ratio=2.0,
103
+ normalize=False,
104
+ use_dark=False)
105
+ model = dict(
106
+ type='TopdownPoseEstimator',
107
+ data_preprocessor=dict(
108
+ type='PoseDataPreprocessor',
109
+ mean=[123.675, 116.28, 103.53],
110
+ std=[58.395, 57.12, 57.375],
111
+ bgr_to_rgb=True),
112
+ backbone=dict(
113
+ _scope_='mmdet',
114
+ type='CSPNeXt',
115
+ arch='P5',
116
+ expand_ratio=0.5,
117
+ deepen_factor=0.167,
118
+ widen_factor=0.375,
119
+ out_indices=(4, ),
120
+ channel_attention=True,
121
+ norm_cfg=dict(type='SyncBN'),
122
+ act_cfg=dict(type='SiLU'),
123
+ init_cfg=dict(
124
+ type='Pretrained',
125
+ prefix='backbone.',
126
+ checkpoint=
127
+ 'https://download.openmmlab.com/mmpose/v1/projects/rtmpose/cspnext-tiny_udp-aic-coco_210e-256x192-cbed682d_20230130.pth'
128
+ )),
129
+ head=dict(
130
+ type='RTMCCHead',
131
+ in_channels=384,
132
+ out_channels=17,
133
+ input_size=(192, 256),
134
+ in_featuremap_size=(6, 8),
135
+ simcc_split_ratio=2.0,
136
+ final_layer_kernel_size=7,
137
+ gau_cfg=dict(
138
+ hidden_dims=256,
139
+ s=128,
140
+ expansion_factor=2,
141
+ dropout_rate=0.0,
142
+ drop_path=0.0,
143
+ act_fn='SiLU',
144
+ use_rel_bias=False,
145
+ pos_enc=False),
146
+ loss=dict(
147
+ type='KLDiscretLoss',
148
+ use_target_weight=True,
149
+ beta=10.0,
150
+ label_softmax=True),
151
+ decoder=dict(
152
+ type='SimCCLabel',
153
+ input_size=(192, 256),
154
+ sigma=(4.9, 5.66),
155
+ simcc_split_ratio=2.0,
156
+ normalize=False,
157
+ use_dark=False)),
158
+ test_cfg=dict(flip_test=True))
159
+ dataset_type = 'CocoDataset'
160
+ data_mode = 'topdown'
161
+ data_root = 'data/'
162
+ train_pipeline = [
163
+ dict(type='LoadImage', backend_args=dict(backend='local')),
164
+ dict(type='GetBBoxCenterScale'),
165
+ dict(type='RandomFlip', direction='horizontal'),
166
+ dict(type='RandomHalfBody'),
167
+ dict(
168
+ type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
169
+ dict(type='TopdownAffine', input_size=(192, 256)),
170
+ dict(type='mmdet.YOLOXHSVRandomAug'),
171
+ dict(
172
+ type='Albumentation',
173
+ transforms=[
174
+ dict(type='Blur', p=0.1),
175
+ dict(type='MedianBlur', p=0.1),
176
+ dict(
177
+ type='CoarseDropout',
178
+ max_holes=1,
179
+ max_height=0.4,
180
+ max_width=0.4,
181
+ min_holes=1,
182
+ min_height=0.2,
183
+ min_width=0.2,
184
+ p=1.0)
185
+ ]),
186
+ dict(
187
+ type='GenerateTarget',
188
+ encoder=dict(
189
+ type='SimCCLabel',
190
+ input_size=(192, 256),
191
+ sigma=(4.9, 5.66),
192
+ simcc_split_ratio=2.0,
193
+ normalize=False,
194
+ use_dark=False)),
195
+ dict(type='PackPoseInputs')
196
+ ]
197
+ val_pipeline = [
198
+ dict(type='LoadImage', backend_args=dict(backend='local')),
199
+ dict(type='GetBBoxCenterScale'),
200
+ dict(type='TopdownAffine', input_size=(192, 256)),
201
+ dict(type='PackPoseInputs')
202
+ ]
203
+ train_pipeline_stage2 = [
204
+ dict(type='LoadImage', backend_args=dict(backend='local')),
205
+ dict(type='GetBBoxCenterScale'),
206
+ dict(type='RandomFlip', direction='horizontal'),
207
+ dict(type='RandomHalfBody'),
208
+ dict(
209
+ type='RandomBBoxTransform',
210
+ shift_factor=0.0,
211
+ scale_factor=[0.75, 1.25],
212
+ rotate_factor=60),
213
+ dict(type='TopdownAffine', input_size=(192, 256)),
214
+ dict(type='mmdet.YOLOXHSVRandomAug'),
215
+ dict(
216
+ type='Albumentation',
217
+ transforms=[
218
+ dict(type='Blur', p=0.1),
219
+ dict(type='MedianBlur', p=0.1),
220
+ dict(
221
+ type='CoarseDropout',
222
+ max_holes=1,
223
+ max_height=0.4,
224
+ max_width=0.4,
225
+ min_holes=1,
226
+ min_height=0.2,
227
+ min_width=0.2,
228
+ p=0.5)
229
+ ]),
230
+ dict(
231
+ type='GenerateTarget',
232
+ encoder=dict(
233
+ type='SimCCLabel',
234
+ input_size=(192, 256),
235
+ sigma=(4.9, 5.66),
236
+ simcc_split_ratio=2.0,
237
+ normalize=False,
238
+ use_dark=False)),
239
+ dict(type='PackPoseInputs')
240
+ ]
241
+ dataset_coco = dict(
242
+ type='RepeatDataset',
243
+ dataset=dict(
244
+ type='CocoDataset',
245
+ data_root='data/',
246
+ data_mode='topdown',
247
+ ann_file='coco/annotations/person_keypoints_train2017.json',
248
+ data_prefix=dict(img='detection/coco/train2017/'),
249
+ pipeline=[]),
250
+ times=3)
251
+ dataset_aic = dict(
252
+ type='AicDataset',
253
+ data_root='data/',
254
+ data_mode='topdown',
255
+ ann_file='aic/annotations/aic_train.json',
256
+ data_prefix=dict(
257
+ img=
258
+ 'pose/ai_challenge/ai_challenger_keypoint_train_20170902/keypoint_train_images_20170902/'
259
+ ),
260
+ pipeline=[
261
+ dict(
262
+ type='KeypointConverter',
263
+ num_keypoints=17,
264
+ mapping=[(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12),
265
+ (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)])
266
+ ])
267
+ train_dataloader = dict(
268
+ batch_size=256,
269
+ num_workers=10,
270
+ persistent_workers=True,
271
+ sampler=dict(type='DefaultSampler', shuffle=True),
272
+ dataset=dict(
273
+ type='CombinedDataset',
274
+ metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
275
+ datasets=[
276
+ dict(
277
+ type='RepeatDataset',
278
+ dataset=dict(
279
+ type='CocoDataset',
280
+ data_root='data/',
281
+ data_mode='topdown',
282
+ ann_file='coco/annotations/person_keypoints_train2017.json',
283
+ data_prefix=dict(img='detection/coco/train2017/'),
284
+ pipeline=[]),
285
+ times=3),
286
+ dict(
287
+ type='AicDataset',
288
+ data_root='data/',
289
+ data_mode='topdown',
290
+ ann_file='aic/annotations/aic_train.json',
291
+ data_prefix=dict(
292
+ img=
293
+ 'pose/ai_challenge/ai_challenger_keypoint_train_20170902/keypoint_train_images_20170902/'
294
+ ),
295
+ pipeline=[
296
+ dict(
297
+ type='KeypointConverter',
298
+ num_keypoints=17,
299
+ mapping=[(0, 6), (1, 8), (2, 10), (3, 5), (4, 7),
300
+ (5, 9), (6, 12), (7, 14), (8, 16), (9, 11),
301
+ (10, 13), (11, 15)])
302
+ ])
303
+ ],
304
+ pipeline=[
305
+ dict(type='LoadImage', backend_args=dict(backend='local')),
306
+ dict(type='GetBBoxCenterScale'),
307
+ dict(type='RandomFlip', direction='horizontal'),
308
+ dict(type='RandomHalfBody'),
309
+ dict(
310
+ type='RandomBBoxTransform',
311
+ scale_factor=[0.6, 1.4],
312
+ rotate_factor=80),
313
+ dict(type='TopdownAffine', input_size=(192, 256)),
314
+ dict(type='mmdet.YOLOXHSVRandomAug'),
315
+ dict(
316
+ type='Albumentation',
317
+ transforms=[
318
+ dict(type='Blur', p=0.1),
319
+ dict(type='MedianBlur', p=0.1),
320
+ dict(
321
+ type='CoarseDropout',
322
+ max_holes=1,
323
+ max_height=0.4,
324
+ max_width=0.4,
325
+ min_holes=1,
326
+ min_height=0.2,
327
+ min_width=0.2,
328
+ p=1.0)
329
+ ]),
330
+ dict(
331
+ type='GenerateTarget',
332
+ encoder=dict(
333
+ type='SimCCLabel',
334
+ input_size=(192, 256),
335
+ sigma=(4.9, 5.66),
336
+ simcc_split_ratio=2.0,
337
+ normalize=False,
338
+ use_dark=False)),
339
+ dict(type='PackPoseInputs')
340
+ ],
341
+ test_mode=False))
342
+ val_dataloader = dict(
343
+ batch_size=64,
344
+ num_workers=10,
345
+ persistent_workers=True,
346
+ drop_last=False,
347
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
348
+ dataset=dict(
349
+ type='CocoDataset',
350
+ data_root='data/',
351
+ data_mode='topdown',
352
+ ann_file='coco/annotations/person_keypoints_val2017.json',
353
+ data_prefix=dict(img='detection/coco/val2017/'),
354
+ test_mode=True,
355
+ pipeline=[
356
+ dict(type='LoadImage', backend_args=dict(backend='local')),
357
+ dict(type='GetBBoxCenterScale'),
358
+ dict(type='TopdownAffine', input_size=(192, 256)),
359
+ dict(type='PackPoseInputs')
360
+ ]))
361
+ test_dataloader = dict(
362
+ batch_size=64,
363
+ num_workers=10,
364
+ persistent_workers=True,
365
+ drop_last=False,
366
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
367
+ dataset=dict(
368
+ type='CocoDataset',
369
+ data_root='data/',
370
+ data_mode='topdown',
371
+ ann_file='coco/annotations/person_keypoints_val2017.json',
372
+ data_prefix=dict(img='detection/coco/val2017/'),
373
+ test_mode=True,
374
+ pipeline=[
375
+ dict(type='LoadImage', backend_args=dict(backend='local')),
376
+ dict(type='GetBBoxCenterScale'),
377
+ dict(type='TopdownAffine', input_size=(192, 256)),
378
+ dict(type='PackPoseInputs')
379
+ ]))
380
+ val_evaluator = dict(
381
+ type='CocoMetric',
382
+ ann_file='data/coco/annotations/person_keypoints_val2017.json')
383
+ test_evaluator = dict(
384
+ type='CocoMetric',
385
+ ann_file='data/coco/annotations/person_keypoints_val2017.json')
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ openmim
+ torch
+ mmengine
+ mmcv
+ mmdet
+ mmpose
+ mmdeploy
+ onnxruntime
+ tqdm
+ scikit-image
+ easydict
+ gradio
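No versions are pinned here; the OpenMMLab packages above are what app.py, main.py and the tools/ modules import. A small, optional check (illustrative only) that they resolve before running the demo:

import mmcv, mmdet, mmdeploy, mmengine, mmpose, onnxruntime
for mod in (mmengine, mmcv, mmdet, mmpose, mmdeploy, onnxruntime):
    print(mod.__name__, getattr(mod, '__version__', 'unknown'))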
tools/apis.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ from mmengine.registry import MODELS
+ from mmengine.dataset import Compose, pseudo_collate
+ from mmengine.model.utils import revert_sync_batchnorm
+ from mmengine.registry import init_default_scope
+ from mmengine.runner import load_checkpoint
+ from mmengine.config import Config
+
+ from mmdeploy.utils import get_input_shape, load_config
+ from mmdeploy.apis.utils import build_task_processor
+
+ def build_model(cfg, checkpoint=None, device='cpu'):
+     """Build a model from a config and optionally load a checkpoint.
+     The checkpoint meta usually contains dataset class information.
+     """
+     if isinstance(cfg, str):
+         cfg = Config.fromfile(cfg)
+     # scope of the model, e.g. mmdet, mmseg, mmpose...
+     init_default_scope(cfg.default_scope)
+     model = MODELS.build(cfg.model)
+     model = revert_sync_batchnorm(model)
+     if checkpoint is not None:
+         ckpt = load_checkpoint(model, checkpoint,
+                                map_location='cpu')
+         checkpoint_meta = ckpt.get('meta', {})
+         # classes and palette are usually stored in checkpoint_meta
+         model.checkpoint_meta = checkpoint_meta
+     model.to(device)
+     model.eval()
+     return model
+
+ def inference(model, cfg, img):
+     """Given a model, config and image, return the inference results.
+     Models in OpenMMLab do not share the same inference API, so this
+     function is mainly kept as a memo.
+     """
+     if isinstance(cfg, str):
+         cfg = Config.fromfile(cfg)
+     # build the test pipeline
+     test_pipeline = cfg.test_dataloader.dataset.pipeline
+     # Use 'LoadImage' to handle both image arrays and image paths.
+     # This is specially needed for mmdet configs, which use 'LoadImageFromFile'.
+     for pipeline in test_pipeline:
+         if 'LoadImage' in pipeline['type']:
+             pipeline['type'] = 'mmpose.LoadImage'
+
+     init_default_scope(cfg.default_scope)
+     pipeline = Compose(test_pipeline)
+
+     if isinstance(img, str):
+         # img_id is unused, but kept to stay compatible with mmdet
+         data_info = dict(img_path=img, img_id=0)
+     else:
+         data_info = dict(img=img, img_id=0)
+
+     data = pipeline(data_info)
+     batch = pseudo_collate([data])
+
+     with torch.no_grad():
+         results = model.test_step(batch)
+
+     return results
+
+ def build_onnx_model_and_task_processor(model_cfg, deploy_cfg, backend_files, device):
+
+     deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
+
+     task_processor = build_task_processor(model_cfg, deploy_cfg, device)
+
+     model = task_processor.build_backend_model(
+         backend_files, task_processor.update_data_preprocessor)
+
+     return model, task_processor
+
+ def inference_onnx_model(model, task_processor, deploy_cfg, img):
+     input_shape = get_input_shape(deploy_cfg)
+     model_inputs, _ = task_processor.create_input(img, input_shape)
+
+     with torch.no_grad():
+         result = model.test_step(model_inputs)
+
+     return result
+
+ if __name__ == '__main__':
+     config = '/github/Tennis.ai/model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-t_8xb256-420e_aic-coco-256x192.py'
+     ckpt = '/github/Tennis.ai/model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-tiny_simcc-aic-coco_pt-aic-coco_420e-256x192-cfc8f33d_20230126.pth'
+     img = '/github/Tennis.ai/assets/000000197388.jpg'
+
+     detector = build_model(config, checkpoint=ckpt)
+     result = inference(detector, config, img)
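For reference, the ONNX helpers above line up with the fields in configs/mark2.py; an illustrative call (not part of the committed code) would look like:

from tools.apis import build_onnx_model_and_task_processor, inference_onnx_model

deploy_cfg = 'model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/detection_onnxruntime_static.py'
model_cfg = 'model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py'
backend_files = ['model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/end2end.onnx']

model, task_processor = build_onnx_model_and_task_processor(
    model_cfg, deploy_cfg, backend_files, device='cpu')
result = inference_onnx_model(model, task_processor, deploy_cfg, 'assets/onnx_test.jpg')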
tools/deploy.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ # Modified from mmdeploy/tools/deploy.py; some code was removed to focus only on ONNX export
3
+ import argparse
4
+ import logging
5
+ import os
6
+ import os.path as osp
7
+ from functools import partial
8
+
9
+ import mmengine
10
+ import torch.multiprocessing as mp
11
+ from torch.multiprocessing import Process, set_start_method
12
+
13
+ from mmdeploy.apis import (create_calib_input_data, extract_model,
14
+ get_predefined_partition_cfg, torch2onnx,
15
+ torch2torchscript, visualize_model)
16
+ from mmdeploy.apis.core import PIPELINE_MANAGER
17
+ from mmdeploy.apis.utils import to_backend
18
+ from mmdeploy.backend.sdk.export_info import export2SDK
19
+ from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
20
+ get_ir_config, get_partition_config,
21
+ get_root_logger, load_config, target_wrapper)
22
+
23
+
24
+ def parse_args():
25
+ parser = argparse.ArgumentParser(description='Export model to backends.')
26
+ parser.add_argument('deploy_cfg', help='deploy config path')
27
+ parser.add_argument('model_cfg', help='model config path')
28
+ parser.add_argument('checkpoint', help='model checkpoint path')
29
+ parser.add_argument('img', help='image used to convert the model')
30
+ parser.add_argument(
31
+ '--test-img',
32
+ default=None,
33
+ type=str,
34
+ nargs='+',
35
+ help='image used to test model')
36
+ parser.add_argument(
37
+ '--work-dir',
38
+ default=os.getcwd(),
39
+ help='the dir to save logs and models')
40
+ parser.add_argument(
41
+ '--calib-dataset-cfg',
42
+ help='dataset config path used to calibrate in int8 mode. If not \
43
+ specified, it will use "val" dataset in model config instead.',
44
+ default=None)
45
+ parser.add_argument(
46
+ '--device', help='device used for conversion', default='cpu')
47
+ parser.add_argument(
48
+ '--log-level',
49
+ help='set log level',
50
+ default='INFO',
51
+ choices=list(logging._nameToLevel.keys()))
52
+ parser.add_argument(
53
+ '--show', action='store_true', help='Show detection outputs')
54
+ parser.add_argument(
55
+ '--dump-info', action='store_true', help='Output information for SDK')
56
+ parser.add_argument(
57
+ '--quant-image-dir',
58
+ default=None,
59
+ help='Image directory for quantize model.')
60
+ parser.add_argument(
61
+ '--quant', action='store_true', help='Quantize model to low bit.')
62
+ parser.add_argument(
63
+ '--uri',
64
+ default='192.168.1.1:60000',
65
+ help='Remote ipv4:port or ipv6:port for inference on edge device.')
66
+ args = parser.parse_args()
67
+ return args
68
+
69
+
70
+ def create_process(name, target, args, kwargs, ret_value=None):
71
+ logger = get_root_logger()
72
+ logger.info(f'{name} start.')
73
+ log_level = logger.level
74
+
75
+ wrap_func = partial(target_wrapper, target, log_level, ret_value)
76
+
77
+ process = Process(target=wrap_func, args=args, kwargs=kwargs)
78
+ process.start()
79
+ process.join()
80
+
81
+ if ret_value is not None:
82
+ if ret_value.value != 0:
83
+ logger.error(f'{name} failed.')
84
+ exit(1)
85
+ else:
86
+ logger.info(f'{name} success.')
87
+
88
+
89
+ def torch2ir(ir_type: IR):
90
+ """Return the conversion function from torch to the intermediate
91
+ representation.
92
+
93
+ Args:
94
+ ir_type (IR): The type of the intermediate representation.
95
+ """
96
+ if ir_type == IR.ONNX:
97
+ return torch2onnx
98
+ elif ir_type == IR.TORCHSCRIPT:
99
+ return torch2torchscript
100
+ else:
101
+ raise KeyError(f'Unexpected IR type {ir_type}')
102
+
103
+
104
+ def main():
105
+ args = parse_args()
106
+ set_start_method('spawn', force=True)
107
+ logger = get_root_logger()
108
+ log_level = logging.getLevelName(args.log_level)
109
+ logger.setLevel(log_level)
110
+
111
+ pipeline_funcs = [
112
+ torch2onnx, torch2torchscript, extract_model, create_calib_input_data
113
+ ]
114
+ PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
115
+ PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)
116
+
117
+ deploy_cfg_path = args.deploy_cfg
118
+ model_cfg_path = args.model_cfg
119
+ checkpoint_path = args.checkpoint
120
+ quant = args.quant
121
+ quant_image_dir = args.quant_image_dir
122
+
123
+ # load deploy_cfg
124
+ deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
125
+
126
+ # create work_dir if not
127
+ mmengine.mkdir_or_exist(osp.abspath(args.work_dir))
128
+
129
+ if args.dump_info:
130
+ export2SDK(
131
+ deploy_cfg,
132
+ model_cfg,
133
+ args.work_dir,
134
+ pth=checkpoint_path,
135
+ device=args.device)
136
+
137
+ ret_value = mp.Value('d', 0, lock=False)
138
+
139
+ # convert to IR
140
+ ir_config = get_ir_config(deploy_cfg)
141
+ ir_save_file = ir_config['save_file']
142
+ ir_type = IR.get(ir_config['type'])
143
+ torch2ir(ir_type)(
144
+ args.img,
145
+ args.work_dir,
146
+ ir_save_file,
147
+ deploy_cfg_path,
148
+ model_cfg_path,
149
+ checkpoint_path,
150
+ device=args.device)
151
+
152
+ # convert backend
153
+ ir_files = [osp.join(args.work_dir, ir_save_file)]
154
+
155
+ # partition model
156
+ partition_cfgs = get_partition_config(deploy_cfg)
157
+
158
+ if partition_cfgs is not None:
159
+
160
+ if 'partition_cfg' in partition_cfgs:
161
+ partition_cfgs = partition_cfgs.get('partition_cfg', None)
162
+ else:
163
+ assert 'type' in partition_cfgs
164
+ partition_cfgs = get_predefined_partition_cfg(
165
+ deploy_cfg, partition_cfgs['type'])
166
+
167
+ origin_ir_file = ir_files[0]
168
+ ir_files = []
169
+ for partition_cfg in partition_cfgs:
170
+ save_file = partition_cfg['save_file']
171
+ save_path = osp.join(args.work_dir, save_file)
172
+ start = partition_cfg['start']
173
+ end = partition_cfg['end']
174
+ dynamic_axes = partition_cfg.get('dynamic_axes', None)
175
+
176
+ extract_model(
177
+ origin_ir_file,
178
+ start,
179
+ end,
180
+ dynamic_axes=dynamic_axes,
181
+ save_file=save_path)
182
+
183
+ ir_files.append(save_path)
184
+
185
+ backend_files = ir_files
186
+ # convert backend
187
+ backend = get_backend(deploy_cfg)
188
+
189
+ # convert to backend
190
+ PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
191
+ if backend == Backend.TENSORRT:
192
+ PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
193
+ backend_files = to_backend(
194
+ backend,
195
+ ir_files,
196
+ work_dir=args.work_dir,
197
+ deploy_cfg=deploy_cfg,
198
+ log_level=log_level,
199
+ device=args.device,
200
+ uri=args.uri)
201
+
202
+ if args.test_img is None:
203
+ args.test_img = args.img
204
+
205
+ extra = dict(
206
+ backend=backend,
207
+ output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
208
+ show_result=args.show)
209
+ if backend == Backend.SNPE:
210
+ extra['uri'] = args.uri
211
+
212
+ # get backend inference result, try render
213
+ create_process(
214
+ f'visualize {backend.value} model',
215
+ target=visualize_model,
216
+ args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
217
+ args.device),
218
+ kwargs=extra,
219
+ ret_value=ret_value)
220
+
221
+ # get pytorch model inference result, try visualize if possible
222
+ create_process(
223
+ 'visualize pytorch model',
224
+ target=visualize_model,
225
+ args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
226
+ args.test_img, args.device),
227
+ kwargs=dict(
228
+ backend=Backend.PYTORCH,
229
+ output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
230
+ show_result=args.show),
231
+ ret_value=ret_value)
232
+ logger.info('All process success.')
233
+
234
+
235
+ if __name__ == '__main__':
236
+ main()
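The tail of `create_process` above runs a pipeline stage in a child process and reports failure back to the parent through a shared `ret_value`. Below is a minimal, self-contained sketch of that pattern using only the standard library; `run_step`, `_wrapped` and `demo_step` are hypothetical names for illustration, not mmdeploy's API.

```python
import logging
from multiprocessing import Process, Value, set_start_method

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def _wrapped(target, ret_value, *args):
    # run the real target; flag failure through the shared value
    try:
        target(*args)
    except Exception:
        ret_value.value = 1
        raise

def run_step(name, target, args=(), ret_value=None):
    if ret_value is None:
        ret_value = Value('d', 0, lock=False)
    process = Process(target=_wrapped, args=(target, ret_value, *args))
    process.start()
    process.join()
    if ret_value.value != 0:
        logger.error(f'{name} failed.')
    else:
        logger.info(f'{name} success.')

def demo_step(msg):
    print(msg)

if __name__ == '__main__':
    # same start method as the script above
    set_start_method('spawn', force=True)
    run_step('demo step', target=demo_step, args=('hello',))
```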
tools/dtw.py ADDED
@@ -0,0 +1,116 @@
1
+ import numpy as np
2
+ from .utils import get_keypoint_weight
3
+
4
+
5
+ class DTWForKeypoints:
6
+ def __init__(self, keypoints1, keypoints2):
7
+ self.keypoints1 = keypoints1
8
+ self.keypoints2 = keypoints2
9
+
10
+ def get_dtw_path(self):
11
+
12
+ norm_kp1 = self.normalize_keypoints(self.keypoints1)
13
+ norm_kp2 = self.normalize_keypoints(self.keypoints2)
14
+
15
+ kp_weight = get_keypoint_weight()
16
+ oks, oks_unnorm = self.object_keypoint_similarity(norm_kp1,
17
+ norm_kp2, keypoint_weights=kp_weight)
18
+ print(f"OKS max {oks.max():.2f} min {oks.min():.2f}")
19
+
20
+ # do the DTW, and return the path
21
+ cost_matrix = 1 - oks
22
+ dtw_dist, dtw_path = self.dynamic_time_warp(cost_matrix)
23
+
24
+ return dtw_path, oks, oks_unnorm
25
+
26
+ def normalize_keypoints(self, keypoints):
27
+ centroid = keypoints.mean(axis=1)[:, None]
28
+ max_distance = np.max(np.sqrt(np.sum((keypoints - centroid) ** 2, axis=2)),
29
+ axis=1) + 1e-6
30
+
31
+ normalized_keypoints = (keypoints - centroid) / max_distance[:, None, None]
32
+ return normalized_keypoints
33
+
34
+ def keypoints_areas(self, keypoints):
35
+ min_coords = np.min(keypoints, axis=1)
36
+ max_coords = np.max(keypoints, axis=1)
37
+ areas = np.prod(max_coords - min_coords, axis=1)
38
+ return areas
39
+
40
+ def object_keypoint_similarity(self, keypoints1,
41
+ keypoints2,
42
+ scale_constant=0.2,
43
+ keypoint_weights=None):
44
+ """ Calculate the Object Keypoint Similarity (OKS) for multiple objects,
45
+ and add weight to each keypoint. Here we choose to normalize the points
46
+ using centroid and max distance instead of bounding box area.
47
+ """
48
+ # Compute squared distances between all pairs of keypoints
49
+ sq_diff = np.sum((keypoints1[:, None] - keypoints2) ** 2, axis=-1)
50
+
51
+ oks = np.exp(-sq_diff / (2 * scale_constant ** 2))
52
+ oks_unnorm = oks.copy()
53
+
54
+ if keypoint_weights is not None:
55
+ oks = oks * keypoint_weights
56
+ oks = np.sum(oks, axis=-1)
57
+ else:
58
+ oks = np.mean(oks, axis=-1)
59
+
60
+ return oks, oks_unnorm
61
+
62
+ def dynamic_time_warp(self, cost_matrix, R=1000):
63
+ """Compute the Dynamic Time Warping distance and path between two time series.
64
+ If the series are long, the Sakoe-Chiba band of radius R constrains the warping path,
65
+ so the time complexity is bounded by O(M*R).
66
+ """
67
+
68
+ M = len(self.keypoints1)
69
+ N = len(self.keypoints2)
70
+
71
+ # Initialize the distance matrix with infinity
72
+ D = np.full((M, N), np.inf)
73
+
74
+ # Initialize the first row and column of the matrix
75
+ D[0, 0] = cost_matrix[0, 0]
76
+ for i in range(1, M):
77
+ D[i, 0] = D[i - 1, 0] + cost_matrix[i, 0]
78
+
79
+ for j in range(1, N):
80
+ D[0, j] = D[0, j - 1] + cost_matrix[0, j]
81
+
82
+ # Fill the remaining elements of the matrix within the
83
+ # Sakoe-Chiba Band using dynamic programming
84
+ for i in range(1, M):
85
+ for j in range(max(1, i - R), min(N, i + R + 1)):
86
+ cost = cost_matrix[i, j]
87
+ D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
88
+
89
+ # Backtrack to find the optimal path
90
+ path = [(M - 1, N - 1)]
91
+ i, j = M - 1, N - 1
92
+ while i > 0 or j > 0:
93
+ min_idx = np.argmin([D[i - 1, j], D[i, j - 1], D[i - 1, j - 1]])
94
+ if min_idx == 0:
95
+ i -= 1
96
+ elif min_idx == 1:
97
+ j -= 1
98
+ else:
99
+ i -= 1
100
+ j -= 1
101
+ path.append((i, j))
102
+ path.reverse()
103
+
104
+ return D[-1, -1], path
105
+
106
+ if __name__ == '__main__':
107
+
108
+ from mmengine.fileio import load
109
+
110
+ keypoints1, kp1_scores = load('tennis1.pkl')
111
+ keypoints2, kp2_scores = load('tennis3.pkl')
112
+
113
+ # Align the two keypoint sequences with DTW
114
+ dtw = DTWForKeypoints(keypoints1, keypoints2)
115
+ path, oks, oks_unnorm = dtw.get_dtw_path()
116
+ print(path)
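As the docstrings above describe, the keypoints are first normalized by their centroid and maximum spread, a weighted OKS matrix is computed between every pair of frames, and DTW is run on `1 - OKS` to align the two sequences. A toy sketch of that flow with synthetic data (the arrays are random placeholders, not the project's tennis clips; run from the repository root):

```python
import numpy as np
from tools.dtw import DTWForKeypoints

rng = np.random.default_rng(0)
seq_a = rng.uniform(0, 640, size=(12, 17, 2))            # 12 frames of 17 COCO keypoints
seq_b = seq_a[::2] + rng.normal(0, 5, size=(6, 17, 2))   # a shorter, noisier copy

path, oks, oks_unnorm = DTWForKeypoints(seq_a, seq_b).get_dtw_path()

# `path` pairs frame indices of seq_a with frame indices of seq_b;
# oks[i, j] is the weighted similarity that app.py feeds to the score bar.
for i, j in path:
    print(i, j, round(float(oks[i, j]), 2))
```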
tools/inferencer.py ADDED
@@ -0,0 +1,154 @@
1
+ import numpy as np
2
+ import mmcv
3
+ from pathlib import Path
4
+ from collections import namedtuple
5
+ import cv2 as cv
6
+ from tqdm import tqdm
7
+ from mmengine.registry import init_default_scope
8
+ from mmengine.visualization import Visualizer
9
+ from mmpose.apis import inference_topdown, init_model
10
+ from mmdet.apis import inference_detector, init_detector
11
+ from .utils import filter_by_catgory, filter_by_score, Timer
12
+ from .apis import build_onnx_model_and_task_processor, inference_onnx_model
13
+
14
+
15
+ class PoseInferencer:
16
+ def __init__(self,
17
+ det_cfg,
18
+ pose_cfg,
19
+ device='cpu') -> None:
20
+ # init
21
+ self.det_model_cfg = det_cfg.model_cfg
22
+ self.det_model_ckpt = det_cfg.model_ckpt
23
+ self.pose_model_cfg = pose_cfg.model_cfg
24
+ self.pose_model_ckpt = pose_cfg.model_ckpt
25
+
26
+ self.detector = init_detector(self.det_model_cfg,
27
+ self.det_model_ckpt,
28
+ device=device)
29
+ self.pose_model = init_model(self.pose_model_cfg,
30
+ self.pose_model_ckpt,
31
+ device=device)
32
+
33
+ def process_one_image(self, img):
34
+ init_default_scope('mmdet')
35
+ det_result = inference_detector(self.detector, img)
36
+ det_inst = det_result.pred_instances.cpu().numpy()
37
+ bboxes, scores, labels = (det_inst.bboxes,
38
+ det_inst.scores,
39
+ det_inst.labels)
40
+ bboxes, scores, labels = filter_by_score(bboxes, scores,
41
+ labels, 0.5)
42
+ bboxes, scores, labels = filter_by_catgory(bboxes, scores, labels,
43
+ ['person'])
44
+ # inference with pose model
45
+ init_default_scope('mmpose')
46
+ pose_result = inference_topdown(self.pose_model, img, bboxes)
47
+ if len(pose_result) == 0:
48
+ # no detection place holder
49
+ keypoints = np.zeros((1, 17, 2))
50
+ pts_scores = np.zeros((1, 17))
51
+ bboxes = np.zeros((1, 4))
52
+ scores = np.zeros((1, ))
53
+ labels = np.zeros((1, ))
54
+ else:
55
+ keypoints = np.concatenate([r.pred_instances.keypoints
56
+ for r in pose_result])
57
+ pts_scores = np.concatenate([r.pred_instances.keypoint_scores
58
+ for r in pose_result])
59
+
60
+ DetInst = namedtuple('DetInst', ['bboxes', 'scores', 'labels'])
61
+ PoseInst = namedtuple('PoseInst', ['keypoints', 'pts_scores'])
62
+ return DetInst(bboxes, scores, labels), PoseInst(keypoints, pts_scores)
63
+
64
+ def inference_video(self, video_path):
65
+ """ Inference a video with detector and pose model
66
+ Return:
67
+ all_pose: a list of PoseInst, check the namedtuple definition
68
+ all_det: a list of DetInst
69
+ """
70
+ video_reader = mmcv.VideoReader(video_path)
71
+ all_pose, all_det = [], []
72
+
73
+ for frame in tqdm(video_reader):
74
+ # inference with detector
75
+ det, pose = self.process_one_image(frame)
76
+ all_pose.append(pose)
77
+ all_det.append(det)
78
+
79
+ return all_det, all_pose
80
+
81
+ class PoseInferencerV2:
82
+ """ V2 Use onnx for detection model, still use pytorch for pose model.
83
+ """
84
+ def __init__(self,
85
+ det_cfg,
86
+ pose_cfg,
87
+ device='cpu') -> None:
88
+ # init
89
+ self.det_deploy_cfg = det_cfg.deploy_cfg
90
+ self.det_model_cfg = det_cfg.model_cfg
91
+ self.det_backend_files = det_cfg.backend_files
92
+
93
+ self.pose_model_cfg = pose_cfg.model_cfg
94
+ self.pose_model_ckpt = pose_cfg.model_ckpt
95
+
96
+ self.detector, self.task_processor = \
97
+ build_onnx_model_and_task_processor(self.det_model_cfg,
98
+ self.det_deploy_cfg,
99
+ self.det_backend_files,
100
+ device)
101
+ self.pose_model = init_model(self.pose_model_cfg,
102
+ self.pose_model_ckpt,
103
+ device)
104
+
105
+ def process_one_image(self, img):
106
+ init_default_scope('mmdet')
107
+ det_result = inference_onnx_model(self.detector,
108
+ self.task_processor,
109
+ self.det_deploy_cfg,
110
+ img)
111
+ det_inst = det_result[0].pred_instances.cpu().numpy()
112
+ bboxes, scores, labels = (det_inst.bboxes,
113
+ det_inst.scores,
114
+ det_inst.labels)
115
+ bboxes, scores, labels = filter_by_score(bboxes, scores,
116
+ labels, 0.5)
117
+ bboxes, scores, labels = filter_by_catgory(bboxes, scores, labels,
118
+ ['person'])
119
+ # inference with pose model
120
+ init_default_scope('mmpose')
121
+ pose_result = inference_topdown(self.pose_model, img, bboxes)
122
+ if len(pose_result) == 0:
123
+ # no detection place holder
124
+ keypoints = np.zeros((1, 17, 2))
125
+ pts_scores = np.zeros((1, 17))
126
+ bboxes = np.zeros((1, 4))
127
+ scores = np.zeros((1, ))
128
+ labels = np.zeros((1, ))
129
+ else:
130
+ keypoints = np.concatenate([r.pred_instances.keypoints
131
+ for r in pose_result])
132
+ pts_scores = np.concatenate([r.pred_instances.keypoint_scores
133
+ for r in pose_result])
134
+
135
+ DetInst = namedtuple('DetInst', ['bboxes', 'scores', 'labels'])
136
+ PoseInst = namedtuple('PoseInst', ['keypoints', 'pts_scores'])
137
+ return DetInst(bboxes, scores, labels), PoseInst(keypoints, pts_scores)
138
+
139
+ def inference_video(self, video_path):
140
+ """ Inference a video with detector and pose model
141
+ Return:
142
+ all_pose: a list of PoseInst, check the namedtuple definition
143
+ all_det: a list of DetInst
144
+ """
145
+ video_reader = mmcv.VideoReader(video_path)
146
+ all_pose, all_det = [], []
147
+
148
+ for frame in tqdm(video_reader):
149
+ # inference with detector
150
+ det, pose = self.process_one_image(frame)
151
+ all_pose.append(pose)
152
+ all_det.append(det)
153
+
154
+ return all_det, all_pose
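Both inferencers follow the same two-stage flow: detect people, keep boxes scored above 0.5 and labelled 'person', then run top-down pose estimation on each box. A hedged usage sketch of the PyTorch variant follows; the config fields mirror what `__init__` reads, but the file paths are placeholders rather than files shipped in this commit.

```python
from easydict import EasyDict
from tools.inferencer import PoseInferencer   # run from the repository root

# hypothetical paths: substitute your own detector / pose configs and checkpoints
det_cfg = EasyDict(model_cfg='model_zoo/det/model.py',
                   model_ckpt='model_zoo/det/model.pth')
pose_cfg = EasyDict(model_cfg='model_zoo/pose/model.py',
                    model_ckpt='model_zoo/pose/model.pth')

inferencer = PoseInferencer(det_cfg, pose_cfg, device='cpu')
all_det, all_pose = inferencer.inference_video('some_video.mp4')

# one DetInst / PoseInst per frame
print(len(all_det), all_pose[0].keypoints.shape)
```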
tools/manager.py ADDED
@@ -0,0 +1,72 @@
1
+ import mim
2
+ from pathlib import Path
3
+ from mim.utils import get_installed_path, echo_success
4
+ from mmengine.config import Config
5
+
6
+ class Manager:
7
+
8
+ def __init__(self, path=None) -> None:
9
+ """
10
+ Params:
11
+ - path: root path of the project, used to save checkpoints and configs
12
+ """
13
+ if path:
14
+ self.path = Path(path)
15
+ else:
16
+ self.path = Path(__file__).parents[1]
17
+ self.keys = ['weight', 'config', 'model', 'training_data']
18
+
19
+ def get_model_infos(self, package_name, keyword: str=None):
20
+ """ because mim search is too strict,
21
+ I want to search by keyword, not a strict match
22
+ """
23
+ model_infos = mim.get_model_info(package_name)
24
+ model_names = model_infos.index
25
+ info_keys = model_infos.columns.tolist()
26
+ keys = self.intersect_keys(info_keys,
27
+ self.keys)
28
+ if keyword is None:
29
+ return model_infos.loc[:, keys]
30
+ # get valid names, which contains the keyword
31
+ valid_names = [name for name in model_names
32
+ if keyword in name]
33
+ filter_infos = model_infos.loc[valid_names, keys]
34
+ return filter_infos
35
+
36
+ def intersect_keys(self, keys1 , keys2):
37
+ return list(set(keys1) & set(keys2))
38
+
39
+ def download(self, package, model, config_only=False):
40
+ """ Use model names to download checkpoints and configs.
41
+ Args:
42
+ - package: package name, e.g. mmdet
43
+ - model: model name, e.g. faster_rcnn or faster_rcnn_r50_fpn_1x_coco
44
+ - config_only: only download configs, which is helpful when you
45
+ have already downloaded the checkpoints faster through other channels.
46
+ """
47
+ infos = self.get_model_infos(package, model)
48
+
49
+ for model, info in infos.iterrows():
50
+ # get destination path
51
+ hyper_name = info['model']
52
+ dst_path = self.path / 'model_zoo' / hyper_name / model
53
+ dst_path.mkdir(parents=True, exist_ok=True)
54
+
55
+ if config_only:
56
+ # get config path of the package
57
+ installed_path = Path(get_installed_path(package))
58
+ config_path = info['config']
59
+ config_path = installed_path / '.mim' / config_path
60
+ # build and dump config
61
+ config_obj = Config.fromfile(config_path)
62
+ saved_config_path = dst_path / f'{model}.py'
63
+ config_obj.dump(saved_config_path)
64
+ echo_success(
65
+ f'Successfully dumped {model}.py to {dst_path}')
66
+ else:
67
+ mim.download(package, [model], dest_root=dst_path)
68
+
69
+ if __name__ == '__main__':
70
+ m = Manager()
71
+ print(m.get_model_infos('mmdet', 'det'))
72
+ # m.download('mmpose', 'rtmpose-t_8xb256-420e_aic-coco-256x192', config_only=True)
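Manager is a thin wrapper over mim: `get_model_infos` filters mim's model metadata by a keyword in the model name, and `download` places checkpoints and configs under `<root>/model_zoo/<hyper_name>/<model>`. A hedged usage sketch; the model name in the comment is an example identifier, not a verified entry in mim's metadata.

```python
from tools.manager import Manager   # run from the repository root

m = Manager()                                      # defaults to the repository root
infos = m.get_model_infos('mmdet', keyword='rtmdet')
print(infos)                                       # rows whose model name contains 'rtmdet'

# download config + checkpoint into <root>/model_zoo/... (example name, uncomment to run)
# m.download('mmdet', 'rtmdet_tiny_8xb32-300e_coco')
```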
tools/utils.py ADDED
@@ -0,0 +1,120 @@
1
+ from mmdet.datasets import CocoDataset
2
+ import time
3
+ from pathlib import Path
4
+ from ffmpy import FFmpeg
5
+ import shutil
6
+ import tempfile
7
+ from easydict import EasyDict
8
+ import numpy as np
9
+
10
+ def coco_keypoint_id_table(reverse=False):
11
+ id2name = { 0: 'nose',
12
+ 1: 'left_eye',
13
+ 2: 'right_eye',
14
+ 3: 'left_ear',
15
+ 4: 'right_ear',
16
+ 5: 'left_shoulder',
17
+ 6: 'right_shoulder',
18
+ 7: 'left_elbow',
19
+ 8: 'right_elbow',
20
+ 9: 'left_wrist',
21
+ 10: 'right_wrist',
22
+ 11: 'left_hip',
23
+ 12: 'right_hip',
24
+ 13: 'left_knee',
25
+ 14: 'right_knee',
26
+ 15: 'left_ankle',
27
+ 16: 'right_ankle'}
28
+ if reverse:
29
+ return {v: k for k, v in id2name.items()}
30
+ return id2name
31
+
32
+ def get_skeleton():
33
+ """ My skeleton links, I deleted some links from default coco style.
34
+ """
35
+ SKELETON = EasyDict()
36
+ SKELETON.head = [[0,1], [0,2], [1,3], [2,4]]
37
+ SKELETON.left_arm = [[5, 7], [7, 9]]
38
+ SKELETON.right_arm = [[6, 8], [8, 10]]
39
+ SKELETON.left_leg = [[11, 13], [13, 15]]
40
+ SKELETON.right_leg = [[12, 14], [14, 16]]
41
+ SKELETON.body = [[5, 6], [5, 11], [6, 12], [11, 12]]
42
+ return SKELETON
43
+
44
+ def get_keypoint_weight(low_weight_ratio=0.1, mid_weight_ratio=0.5):
45
+ """ Get keypoint weight, used in object keypoint similarity,
46
+ `low_weight_names` are points I want to pay less attention to.
47
+ """
48
+ low_weight_names = ['nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear']
49
+ mid_weight_names = ['left_shoulder', 'right_shoulder', 'left_hip', 'right_hip']
50
+
51
+ logtis = np.ones(17)
52
+ name2id = coco_keypoint_id_table(reverse=True)
53
+
54
+ low_weight_id = [name2id[n] for n in low_weight_names]
55
+ mid_weight_id = [name2id[n] for n in mid_weight_names]
56
+ logtis[low_weight_id] = low_weight_ratio
57
+ logtis[mid_weight_id] = mid_weight_ratio
58
+
59
+ weights = logtis / np.sum(logtis)
60
+ return weights
61
+
62
+ def coco_cat_id_table():
63
+ classes = CocoDataset.METAINFO['classes']
64
+ id2name = {i: name for i, name in enumerate(classes)}
65
+
66
+ return id2name
67
+
68
+ def filter_by_catgory(bboxes, scores, labels, names):
69
+ """ Filter labels by classes
70
+ Args:
71
+ - labels: list of labels, each label is a dict
72
+ - classes: list of class names
73
+ """
74
+ id2name = coco_cat_id_table()
75
+ # names of labels
76
+ label_names = [id2name[id] for id in labels]
77
+ # filter by class names
78
+ mask = np.isin(label_names, names)
79
+ return bboxes[mask], scores[mask], labels[mask]
80
+
81
+ def filter_by_score(bboxes, scores, labels, score_thr):
82
+ """ Filter bboxes by score threshold
83
+ Args:
84
+ - bboxes, scores, labels: per-detection arrays of equal length
85
+ - score_thr: score threshold
86
+ """
87
+ mask = scores > score_thr
88
+ return bboxes[mask], scores[mask], labels[mask]
89
+
90
+ def convert_video_to_playable_mp4(video_path: str) -> str:
91
+ """ Copied from gradio
92
+ Convert the video to mp4. If something goes wrong return the original video.
93
+ """
94
+ try:
95
+ output_path = Path(video_path).with_suffix(".mp4")
96
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
97
+ shutil.copy2(video_path, tmp_file.name)
98
+ # ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4
99
+ ff = FFmpeg(
100
+ inputs={str(tmp_file.name): None},
101
+ outputs={str(output_path): None},
102
+ global_options="-y -loglevel quiet",
103
+ )
104
+ ff.run()
105
+ except Exception as e:
106
+ print(f"Error converting video to browser-playable format {str(e)}")
107
+ output_path = video_path
108
+ return str(output_path)
109
+
110
+ class Timer:
111
+ def __init__(self):
112
+ self.start_time = time.time()
113
+
114
+ def click(self):
115
+ used_time = time.time() - self.start_time
116
+ self.start_time = time.time()
117
+ return used_time
118
+
119
+ def start(self):
120
+ self.start_time = time.time()
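`get_keypoint_weight` builds the per-keypoint weights used by the OKS computation: face points get `low_weight_ratio`, shoulders and hips get `mid_weight_ratio`, the remaining joints keep weight 1, and everything is normalized to sum to 1. A quick sanity-check sketch (run from the repository root; importing tools.utils pulls in mmdet):

```python
import numpy as np
from tools.utils import get_keypoint_weight, coco_keypoint_id_table

w = get_keypoint_weight(low_weight_ratio=0.1, mid_weight_ratio=0.5)
id2name = coco_keypoint_id_table()

print(np.isclose(w.sum(), 1.0))          # True: weights are normalized
for i, name in id2name.items():
    print(f'{name:15s} {w[i]:.3f}')      # face < shoulders/hips < limbs
```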
tools/visualizer.py ADDED
@@ -0,0 +1,346 @@
1
+ import cv2
2
+ import numpy as np
3
+ from skimage import draw, io
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ from easydict import EasyDict
6
+ from typing import Union
7
+ from .utils import get_skeleton, Timer
8
+
9
+ class FastVisualizer:
10
+ """ Use skimage to draw, which is much faster than matplotlib, and
11
+ more beautiful than opencv.😎
12
+ """
13
+ # TODO: modify color input parameter
14
+ def __init__(self, image=None) -> None:
15
+ self.set_image(image)
16
+ self.colors = self.get_pallete()
17
+ self.skeleton = get_skeleton()
18
+ self.lvl_tresh = self.set_level([0.3, 0.6, 0.8])
19
+
20
+ def set_image(self, image: Union[str, np.ndarray]):
21
+ if isinstance(image, str):
22
+ self.image = cv2.imread(image)
23
+ elif isinstance(image, np.ndarray) or image is None:
24
+ self.image = image
25
+ else:
26
+ raise TypeError(f"Type {type(image)} is not supported")
27
+
28
+ def get_image(self):
29
+ return self.image
30
+
31
+ def draw_box(self, box_coord, color=(25, 113, 194), alpha=1.0):
32
+ """ Draw a box on the image
33
+ Args:
34
+ box_coord: a list of [xmin, ymin, xmax, ymax]
35
+ alpha: the alpha of the box
36
+ color: the edge color of the box
37
+ """
38
+ xmin, ymin, xmax, ymax = box_coord
39
+ rr, cc = draw.rectangle_perimeter((ymin, xmin), (ymax, xmax))
40
+ draw.set_color(self.image, (rr, cc), color, alpha=alpha)
41
+ return self
42
+
43
+ def draw_rectangle(self, box_coord, color=(25, 113, 194), alpha=1.0):
44
+ xmin, ymin, xmax, ymax = box_coord
45
+ rr, cc = draw.rectangle((ymin, xmin), (ymax, xmax))
46
+ draw.set_color(self.image, (rr, cc), color, alpha=alpha)
47
+ return self
48
+
49
+ def draw_point(self, point_coord, radius=5, color=(25, 113, 194), alpha=1.0):
50
+ """ Coord in (x, y) format, but will be converted to (y, x)
51
+ """
52
+ x, y = point_coord
53
+ rr, cc = draw.disk((y, x), radius=radius)
54
+ draw.set_color(self.image, (rr, cc), color, alpha=alpha)
55
+ return self
56
+
57
+ def draw_line(self, start_point, end_point, color=(25, 113, 194), alpha=1.0):
58
+ """ Not used, because I can't produce smooth line.
59
+ """
60
+ cv2.line(self.image, start_point, end_point, color.tolist(), 2,
61
+ cv2.LINE_AA)
62
+ return self
63
+
64
+ def draw_line_aa(self, start_point, end_point, color=(25, 113, 194), alpha=1.0):
65
+ """ Not used, because I can't produce smooth line.
66
+ """
67
+ x1, y1 = start_point
68
+ x2, y2 = end_point
69
+ rr, cc, val = draw.line_aa(y1, x1, y2, x2)
70
+ draw.set_color(self.image, (rr, cc), color, alpha=alpha)
71
+ return self
72
+
73
+ def draw_thick_line(self, start_point, end_point, thickness=1, color=(25, 113, 194), alpha=1.0):
74
+ """ Not used, because I can't produce smooth line.
75
+ """
76
+ x1, y1 = start_point
77
+ x2, y2 = end_point
78
+ dx, dy = x2 - x1, y2 - y1
79
+ length = np.sqrt(dx * dx + dy * dy)
80
+ cos, sin = dx / length, dy / length
81
+
82
+ half_t = thickness / 2.0
83
+ # Calculate the polygon vertices
84
+ vertices_x = [x1 - half_t * sin, x1 + half_t * sin,
85
+ x2 + half_t * sin, x2 - half_t * sin]
86
+ vertices_y = [y1 + half_t * cos, y1 - half_t * cos,
87
+ y2 - half_t * cos, y2 + half_t * cos]
88
+ rr, cc = draw.polygon(vertices_y, vertices_x)
89
+ draw.set_color(self.image, (rr, cc), color, alpha)
90
+
91
+ return self
92
+
93
+ def draw_text(self, text, position,
94
+ font_path='assets/SmileySans/SmileySans-Oblique.ttf',
95
+ font_size=20,
96
+ text_color=(255, 255, 255)):
97
+ """ Position is the left top corner of the text
98
+ """
99
+ # Convert the NumPy array to a PIL image
100
+ pil_image = Image.fromarray(np.uint8(self.image))
101
+ # Load the font (default is Arial)
102
+ font = ImageFont.truetype(font_path, font_size)
103
+ # Create a drawing object
104
+ draw = ImageDraw.Draw(pil_image)
105
+ # Add the text to the image
106
+ draw.text(position, text, font=font, fill=text_color)
107
+ # Convert the PIL image back to a NumPy array
108
+ result = np.array(pil_image)
109
+
110
+ self.image = result
111
+ return self
112
+
113
+ def xyhw_to_xyxy(self, box):
114
+ hw = box[2:]
115
+ x1y1 = box[:2] - hw / 2
116
+ x2y2 = box[:2] + hw / 2
117
+ return np.concatenate([x1y1, x2y2]).astype(np.int32)
118
+
119
+ def draw_line_in_discrete_style(self, start_point, end_point, size=2, sample_points=3,
120
+ color=(25, 113, 194), alpha=1.0):
121
+ """ When drawing continous line, it is super fuzzy, and I can't handle them
122
+ very well even tried OpneCV & PIL all kinds of ways. This is a workaround.
123
+ The discrete line will be represented with few sampled cubes along the line,
124
+ and it is exclusive with start & end points.
125
+ """
126
+ # sample points
127
+ points = np.linspace(start_point, end_point, sample_points + 2)[1:-1]
128
+ for p in points:
129
+ rectangle_xyhw = np.array((p[0], p[1], size, size))
130
+ rectangle_xyxy = self.xyhw_to_xyxy(rectangle_xyhw)
131
+ self.draw_rectangle(rectangle_xyxy, color, alpha)
132
+ return self
133
+
134
+ def draw_human_keypoints(self, keypoints, scores=None, factor=20, draw_skeleton=False):
135
+ """ Draw skeleton on the image, and give different color according
136
+ to similarity scores.
137
+ """
138
+ # get max length of skeleton
139
+ max_x, max_y = np.max(keypoints, axis=0)
140
+ min_x, min_y = np.min(keypoints, axis=0)
141
+ max_length = max(max_x - min_x, max_y - min_y)
142
+ if max_length < 1: return self
143
+ cube_size = max_length // factor
144
+ line_cube_size = cube_size // 2
145
+ # draw skeleton in discrete style
146
+ if draw_skeleton:
147
+ for key, links in self.skeleton.items():
148
+ links = np.array(links)
149
+ start_points = keypoints[links[:, 0]]
150
+ end_points = keypoints[links[:, 1]]
151
+ for s, e in zip(start_points, end_points):
152
+ self.draw_line_in_discrete_style(s, e, line_cube_size,
153
+ color=self.colors[key], alpha=0.9)
154
+ # draw points
155
+ if scores is None: # use vamos color
156
+ lvl_names = ['vamos'] * len(keypoints)
157
+ else: lvl_names = self.score_level_names(scores)
158
+
159
+ for idx, (point, lvl_name) in enumerate(zip(keypoints, lvl_names)):
160
+ if idx in set((1, 2, 3, 4)):
161
+ continue # do not draw eyes and ears
162
+ rectangle_xyhw = np.array((point[0], point[1], cube_size, cube_size))
163
+ rectangle_xyxy = self.xyhw_to_xyxy(rectangle_xyhw)
164
+ self.draw_rectangle(rectangle_xyxy,
165
+ color=self.colors[lvl_name],
166
+ alpha=0.8)
167
+ return self
168
+
169
+ def draw_score_bar(self, score, factor=50, bar_ratio=7):
170
+ """ Draw a score bar on the left top of the image.
171
+ factor: the value of image longer edge divided by the bar height
172
+ bar_ratio: the ratio of bar width to bar height
173
+ """
174
+ # calculate bar's height and width
175
+ long_edge = np.max(self.image.shape[:2])
176
+ short_edge = np.min(self.image.shape[:2])
177
+ bar_h = long_edge // factor
178
+ bar_w = bar_h * bar_ratio
179
+ if bar_w * 3 > short_edge:
180
+ # when the image width is not enough
181
+ bar_w = short_edge // 4
182
+ bar_h = bar_w // bar_ratio
183
+ cube_size = bar_h
184
+ # bar's base position
185
+ bar_start_point = (2*bar_h, 2*bar_h)
186
+ # draw bar horizontally, and record the position of each word
187
+ word_positions = []
188
+ box_coords = []
189
+ colors = [self.colors.bad, self.colors.good, self.colors.vamos]
190
+ for i, color in enumerate(colors):
191
+ x0, y0 = bar_start_point[0] + i*bar_w, bar_start_point[1]
192
+ x1, y1 = x0 + bar_w - 1, y0 + bar_h
193
+ box_coord = np.array((x0, y0, x1, y1), dtype=np.int32)
194
+ self.draw_rectangle(box_coord, color=color)
195
+
196
+ box_coords.append(box_coord)
197
+ word_positions.append(np.array((x0, y1 + bar_h // 2)))
198
+ # calculate cube position according to score
199
+ lvl, lvl_ratio, lvl_name = self.score_level(score)
200
+ # the first level start point is the first bar
201
+ cube_lvl_start_x0 = [box_coord[0] - cube_size // 2 if i != 0
202
+ else box_coord[0]
203
+ for i, box_coord in enumerate(box_coords)]
204
+ # process the last level, I want the cube stays in the bar
205
+ level_length = bar_w if lvl == 1 else bar_w - cube_size // 2
206
+ cube_x0 = cube_lvl_start_x0[lvl] + lvl_ratio * level_length
207
+ cube_y0 = bar_start_point[1] - bar_h // 2 - cube_size
208
+ cube_x1 = cube_x0 + cube_size
209
+ cube_y1 = cube_y0 + cube_size
210
+ # draw cube
211
+ self.draw_rectangle((cube_x0, cube_y0, cube_x1, cube_y1),
212
+ color=self.colors.cube)
213
+ # enlarge the box, to emphasize the level
214
+ enlarged_box = box_coords[lvl].copy()
215
+ enlarged_box[:2] = enlarged_box[:2] - bar_h // 8
216
+ enlarged_box[2:] = enlarged_box[2:] + bar_h // 8
217
+ self.draw_rectangle(enlarged_box, color=self.colors[lvl_name])
218
+
219
+ # draw text
220
+ if lvl_name == 'vamos':
221
+ lvl_name = 'vamos!!' # exciting!
222
+ self.draw_text(lvl_name.capitalize(),
223
+ word_positions[lvl],
224
+ font_size=bar_h * 2,
225
+ text_color=tuple(colors[lvl].tolist()))
226
+
227
+ return self
228
+
229
+ def draw_non_transparent_area(self, box_coord, alpha=0.2, extend_ratio=0.1):
230
+ """ Make image outside the box transparent using alpha blend
231
+ """
232
+ x1, y1, x2, y2 = box_coord.astype(np.int32)
233
+ # enlarge the box by extend_ratio (10% by default)
234
+ max_len = max((x2 - x1), (y2 - y1))
235
+ extend_len = int(max_len * extend_ratio)
236
+ x1, y1 = x1 - extend_len, y1 - extend_len
237
+ x2, y2 = x2 + extend_len, y2 + extend_len
238
+ # clip the box
239
+ h, w = self.image.shape[:2]
240
+ x1, y1, x2, y2 = np.clip((x1,y1,x2,y2), a_min=0,
241
+ a_max=(w,h,w,h))
242
+ # Create a white background color
243
+ bg_color = np.ones_like(self.image) * 255
244
+ # Copy the box region from the image
245
+ bg_color[y1:y2, x1:x2] = self.image[y1:y2, x1:x2]
246
+ # Alpha blend inplace
247
+ self.image[:] = self.image * alpha + bg_color * (1 - alpha)
248
+ return self
249
+
250
+ def draw_logo(self, logo='assets/logo.png', factor=30, shift=20):
251
+ """ Draw logo on the right bottom of the image.
252
+ """
253
+ H, W = self.image.shape[:2]
254
+ # load logo
255
+ logo_img = Image.open(logo)
256
+ # scale logo
257
+ logo_h = self.image.shape[0] // factor
258
+ scale_size = logo_h / logo_img.size[1]
259
+ logo_w = int(logo_img.size[0] * scale_size)
260
+ logo_img = logo_img.resize((logo_w, logo_h))
261
+ # convert to RGBA
262
+ image = Image.fromarray(self.image).convert("RGBA")
263
+ # alpha blend
264
+ image.alpha_composite(logo_img, (W - logo_w - shift,
265
+ H - logo_h - shift))
266
+ self.image = np.array(image.convert("RGB"))
267
+ return self
268
+
269
+ def score_level(self, score):
270
+ """ Return the level according to level thresh.
271
+ """
272
+ t = self.lvl_tresh
273
+ if score < t[1]: # t[0] might be bigger than 0
274
+ ratio = (score - t[0]) / (t[1] - t[0])
275
+ ratio = np.clip(ratio, a_min=0, a_max=1)
276
+ return 0, ratio, 'bad'
277
+ elif score < t[2]:
278
+ ratio = (score - t[1]) / (t[2] - t[1])
279
+ return 1, ratio, 'good'
280
+ else:
281
+ ratio = (score - t[2]) / (1 - t[2])
282
+ return 2, ratio, 'vamos'
283
+
284
+ def score_level_names(self, scores):
285
+ """ Get multiple score level, return numpy array.
286
+ np.vectorize does not speed up loop, but it is convenient.
287
+ """
288
+ t = self.lvl_tresh
289
+ func_lvl_name = lambda x: 'bad' if x < t[1] else 'good' \
290
+ if x < t[2] else 'vamos'
291
+ lvl_names = np.vectorize(func_lvl_name)(scores)
292
+ return lvl_names
293
+
294
+ def set_level(self, thresh):
295
+ """ Set level thresh for bad, good, vamos.
296
+ """
297
+ from collections import namedtuple
298
+ Level = namedtuple('Level', ['zero', 'good', 'vamos'])
299
+ return Level(thresh[0], thresh[1], thresh[2])
300
+
301
+ def get_pallete(self):
302
+ PALLETE = EasyDict()
303
+
304
+ # light set
305
+ # PALLETE.bad = np.array([253, 138, 138])
306
+ # PALLETE.good = np.array([168, 209, 209])
307
+ # PALLETE.vamos = np.array([241, 247, 181])
308
+ # PALLETE.cube = np.array([158, 161, 212])
309
+
310
+ # dark set, set 80% brightness
311
+ PALLETE.bad = np.array([204, 111, 111])
312
+ PALLETE.good = np.array([143, 179, 179])
313
+ # PALLETE.vamos = np.array([196, 204, 124]) # overridden by the blue tone below
314
+ PALLETE.vamos = np.array([109, 169, 228])
315
+ PALLETE.cube = np.array([152, 155, 204])
316
+
317
+ PALLETE.left_arm = np.array([218, 119, 242])
318
+ PALLETE.right_arm = np.array([151, 117, 250])
319
+ PALLETE.left_leg = np.array([255, 212, 59])
320
+ PALLETE.right_leg = np.array([255, 169, 77])
321
+
322
+ PALLETE.head = np.array([134, 142, 150])
323
+ PALLETE.body = np.array([134, 142, 150])
324
+
325
+ # convert rgb to bgr
326
+ for k, v in PALLETE.items():
327
+ PALLETE[k] = v[::-1]
328
+ return PALLETE
329
+
330
+ if __name__ == '__main__':
331
+ vis = FastVisualizer()
332
+
333
+ image = '/github/Tennis.ai/assets/tempt_test.png'
334
+ vis.set_image(image)
335
+ np.random.seed(0)
336
+ keypoints = np.random.randint(300, 600, (17, 2))
337
+ from utils import Timer
338
+ t = Timer()
339
+ t.start()
340
+ vis.draw_score_bar(0.94)
341
+ # vis.draw_skeleton(keypoints)
342
+ # vis.draw_non_transparent_area((0, 0, 100, 100), alpha=0.2)
343
+ vis.draw_logo()
344
+ cv2.imshow('test', vis.image)
345
+ cv2.waitKey(0)
346
+ cv2.destroyAllWindows()
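The score bar above relies on `score_level`, which maps an OKS score into the bad / good / vamos bands defined by `set_level([0.3, 0.6, 0.8])` and returns how far into the band the score sits. A stand-alone sketch of the same mapping, using those default thresholds without any drawing:

```python
import numpy as np

def score_level(score, thresh=(0.3, 0.6, 0.8)):
    t = thresh
    if score < t[1]:
        # below the 'good' threshold; clamp because t[0] may be above 0
        ratio = np.clip((score - t[0]) / (t[1] - t[0]), 0, 1)
        return 0, float(ratio), 'bad'
    elif score < t[2]:
        return 1, (score - t[1]) / (t[2] - t[1]), 'good'
    else:
        return 2, (score - t[2]) / (1 - t[2]), 'vamos'

for s in (0.25, 0.55, 0.70, 0.94):
    print(s, score_level(s))   # e.g. 0.94 -> (2, ~0.7, 'vamos')
```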