update
Browse files- .gitignore +7 -0
- app.py +109 -0
- assets/logo.png +0 -0
- assets/onnx_test.jpg +0 -0
- configs/mark1.py +9 -0
- configs/mark2.py +10 -0
- convert_det.sh +8 -0
- main.py +96 -0
- model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/detection_onnxruntime_static.py +23 -0
- model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py +345 -0
- model_zoo/rtmpose/rtmpose-m_8xb256-420e_aic-coco-256x192/rtmpose-m_8xb256-420e_aic-coco-256x192.py +391 -0
- model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-t_8xb256-420e_aic-coco-256x192.py +385 -0
- requirements.txt +12 -0
- tools/apis.py +90 -0
- tools/deploy.py +236 -0
- tools/dtw.py +116 -0
- tools/inferencer.py +154 -0
- tools/manager.py +72 -0
- tools/utils.py +120 -0
- tools/visualizer.py +346 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pth
|
2 |
+
*.pkl
|
3 |
+
*.mp4
|
4 |
+
*.onnx
|
5 |
+
*.ttf
|
6 |
+
tempt*
|
7 |
+
**pycache**
|
app.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Inference 2 videos and use dtw to match the pose keypoints.
|
2 |
+
from tools.inferencer import PoseInferencerV2
|
3 |
+
from tools.dtw import DTWForKeypoints
|
4 |
+
from tools.visualizer import FastVisualizer
|
5 |
+
from tools.utils import convert_video_to_playable_mp4
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from pathlib import Path
|
8 |
+
from tqdm import tqdm
|
9 |
+
import mmengine
|
10 |
+
import numpy as np
|
11 |
+
import mmcv
|
12 |
+
import cv2
|
13 |
+
import gradio as gr
|
14 |
+
|
15 |
+
def parse_args():
|
16 |
+
parser = ArgumentParser()
|
17 |
+
parser.add_argument('--config', type=str, default='configs/mark2.py')
|
18 |
+
parser.add_argument('--video1', type=str, default='assets/tennis1.mp4')
|
19 |
+
parser.add_argument('--video2', type=str, default='assets/tennis2.mp4')
|
20 |
+
return parser.parse_args()
|
21 |
+
|
22 |
+
def concat(img1, img2, height=1080):
|
23 |
+
w1, h1, _ = img1.shape
|
24 |
+
w2, h2, _ = img2.shape
|
25 |
+
|
26 |
+
# Calculate the scaling factor for each image
|
27 |
+
scale1 = height / img1.shape[1]
|
28 |
+
scale2 = height / img2.shape[1]
|
29 |
+
|
30 |
+
# Resize the images
|
31 |
+
img1 = cv2.resize(img1, (int(h1*scale1), int(w1*scale1)))
|
32 |
+
img2 = cv2.resize(img2, (int(h2*scale2), int(w2*scale2)))
|
33 |
+
|
34 |
+
# Concatenate the images horizontally
|
35 |
+
image = cv2.hconcat([img1, img2])
|
36 |
+
return image
|
37 |
+
|
38 |
+
def draw(vis: FastVisualizer, img, keypoint, box, oks, oks_unnorm, draw_score_bar=True):
|
39 |
+
vis.set_image(img)
|
40 |
+
vis.draw_non_transparent_area(box)
|
41 |
+
if draw_score_bar:
|
42 |
+
vis.draw_score_bar(oks)
|
43 |
+
vis.draw_human_keypoints(keypoint, oks_unnorm)
|
44 |
+
return vis.get_image()
|
45 |
+
|
46 |
+
def main(video1, video2):
|
47 |
+
# build PoseInferencerV2
|
48 |
+
config = 'configs/mark2.py'
|
49 |
+
cfg = mmengine.Config.fromfile(config)
|
50 |
+
pose_inferencer = PoseInferencerV2(
|
51 |
+
cfg.det_cfg,
|
52 |
+
cfg.pose_cfg,
|
53 |
+
device='cpu')
|
54 |
+
|
55 |
+
v1 = mmcv.VideoReader(video1)
|
56 |
+
v2 = mmcv.VideoReader(video2)
|
57 |
+
video_writer = None
|
58 |
+
|
59 |
+
all_det1, all_pose1 = pose_inferencer.inference_video(video1)
|
60 |
+
all_det2, all_pose2 = pose_inferencer.inference_video(video2)
|
61 |
+
|
62 |
+
keypoints1 = np.stack([p.keypoints[0] for p in all_pose1]) # forced the first pred
|
63 |
+
keypoints2 = np.stack([p.keypoints[0] for p in all_pose2])
|
64 |
+
boxes1 = np.stack([d.bboxes[0] for d in all_det1])
|
65 |
+
boxes2 = np.stack([d.bboxes[0] for d in all_det2])
|
66 |
+
|
67 |
+
dtw_path, oks, oks_unnorm = DTWForKeypoints(keypoints1, keypoints2).get_dtw_path()
|
68 |
+
|
69 |
+
vis = FastVisualizer()
|
70 |
+
|
71 |
+
for i, j in tqdm(dtw_path):
|
72 |
+
frame1 = v1[i]
|
73 |
+
frame2 = v2[j]
|
74 |
+
|
75 |
+
frame1_ = draw(vis, frame1.copy(), keypoints1[i], boxes1[i],
|
76 |
+
oks[i, j], oks_unnorm[i, j])
|
77 |
+
frame2_ = draw(vis, frame2.copy(), keypoints2[j], boxes2[j],
|
78 |
+
oks[i, j], oks_unnorm[i, j], draw_score_bar=False)
|
79 |
+
# concate two frames
|
80 |
+
frame = concat(frame1_, frame2_)
|
81 |
+
# draw logo
|
82 |
+
vis.set_image(frame)
|
83 |
+
frame = vis.draw_logo().get_image()
|
84 |
+
# write video
|
85 |
+
w, h = frame.shape[1], frame.shape[0]
|
86 |
+
if video_writer is None:
|
87 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
88 |
+
video_writer = cv2.VideoWriter('dtw_compare.mp4',
|
89 |
+
fourcc, v1.fps, (w, h))
|
90 |
+
video_writer.write(frame)
|
91 |
+
video_writer.release()
|
92 |
+
# output video file
|
93 |
+
convert_video_to_playable_mp4('dtw_compare.mp4')
|
94 |
+
output = str(Path('dtw_compare.mp4').resolve())
|
95 |
+
return output
|
96 |
+
|
97 |
+
if __name__ == '__main__':
|
98 |
+
config = 'configs/mark2.py'
|
99 |
+
cfg = mmengine.Config.fromfile(config)
|
100 |
+
|
101 |
+
inputs = [
|
102 |
+
gr.Video(label="Input video 1"),
|
103 |
+
gr.Video(label="Input video 2")
|
104 |
+
]
|
105 |
+
|
106 |
+
output = gr.Video(label="Output video")
|
107 |
+
|
108 |
+
demo = gr.Interface(fn=main, inputs=inputs, outputs=output).queue()
|
109 |
+
demo.launch(share=True)
|
assets/logo.png
ADDED
![]() |
assets/onnx_test.jpg
ADDED
![]() |
configs/mark1.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
det_cfg = dict(
|
2 |
+
model_cfg='model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py',
|
3 |
+
model_ckpt='/github/Tennis.ai/model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'
|
4 |
+
)
|
5 |
+
|
6 |
+
pose_cfg = dict(
|
7 |
+
model_cfg='model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-t_8xb256-420e_aic-coco-256x192.py',
|
8 |
+
model_ckpt='model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-tiny_simcc-aic-coco_pt-aic-coco_420e-256x192-cfc8f33d_20230126.pth'
|
9 |
+
)
|
configs/mark2.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
det_cfg = dict(
|
2 |
+
deploy_cfg='model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/detection_onnxruntime_static.py',
|
3 |
+
model_cfg='model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py',
|
4 |
+
backend_files=['model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/end2end.onnx']
|
5 |
+
)
|
6 |
+
|
7 |
+
pose_cfg = dict(
|
8 |
+
model_cfg='model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-t_8xb256-420e_aic-coco-256x192.py',
|
9 |
+
model_ckpt='model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-tiny_simcc-aic-coco_pt-aic-coco_420e-256x192-cfc8f33d_20230126.pth'
|
10 |
+
)
|
convert_det.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python tools/deploy.py \
|
2 |
+
model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/detection_onnxruntime_static.py \
|
3 |
+
model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py \
|
4 |
+
model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth \
|
5 |
+
assets/onnx_test.jpg \
|
6 |
+
--work-dir model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco \
|
7 |
+
--device cpu \
|
8 |
+
--show
|
main.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Inference 2 videos and use dtw to match the pose keypoints.
|
2 |
+
from tools.inferencer import PoseInferencerV2
|
3 |
+
from tools.dtw import DTWForKeypoints
|
4 |
+
from tools.visualizer import FastVisualizer
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
from tools.utils import convert_video_to_playable_mp4
|
7 |
+
from tqdm import tqdm
|
8 |
+
import mmengine
|
9 |
+
import numpy as np
|
10 |
+
import mmcv
|
11 |
+
import cv2
|
12 |
+
|
13 |
+
def parse_args():
|
14 |
+
parser = ArgumentParser()
|
15 |
+
parser.add_argument('--config', type=str, default='configs/mark2.py')
|
16 |
+
parser.add_argument('--video1', type=str, default='assets/tennis1.mp4')
|
17 |
+
parser.add_argument('--video2', type=str, default='assets/tennis2.mp4')
|
18 |
+
return parser.parse_args()
|
19 |
+
|
20 |
+
def concat(img1, img2, height=1080):
|
21 |
+
w1, h1, _ = img1.shape
|
22 |
+
w2, h2, _ = img2.shape
|
23 |
+
|
24 |
+
# Calculate the scaling factor for each image
|
25 |
+
scale1 = height / img1.shape[1]
|
26 |
+
scale2 = height / img2.shape[1]
|
27 |
+
|
28 |
+
# Resize the images
|
29 |
+
img1 = cv2.resize(img1, (int(h1*scale1), int(w1*scale1)))
|
30 |
+
img2 = cv2.resize(img2, (int(h2*scale2), int(w2*scale2)))
|
31 |
+
|
32 |
+
# Concatenate the images horizontally
|
33 |
+
image = cv2.hconcat([img1, img2])
|
34 |
+
return image
|
35 |
+
|
36 |
+
def draw(vis: FastVisualizer, img, keypoint, box, oks, oks_unnorm, draw_score_bar=True):
|
37 |
+
vis.set_image(img)
|
38 |
+
vis.draw_non_transparent_area(box)
|
39 |
+
if draw_score_bar:
|
40 |
+
vis.draw_score_bar(oks)
|
41 |
+
vis.draw_human_keypoints(keypoint, oks_unnorm)
|
42 |
+
return vis.get_image()
|
43 |
+
|
44 |
+
def main(cfg):
|
45 |
+
# build PoseInferencerV2
|
46 |
+
pose_inferencer = PoseInferencerV2(
|
47 |
+
cfg.det_cfg,
|
48 |
+
cfg.pose_cfg,
|
49 |
+
device='cpu')
|
50 |
+
|
51 |
+
v1 = mmcv.VideoReader(cfg.video1)
|
52 |
+
v2 = mmcv.VideoReader(cfg.video2)
|
53 |
+
video_writer = None
|
54 |
+
|
55 |
+
all_det1, all_pose1 = pose_inferencer.inference_video(cfg.video1)
|
56 |
+
all_det2, all_pose2 = pose_inferencer.inference_video(cfg.video2)
|
57 |
+
|
58 |
+
keypoints1 = np.stack([p.keypoints[0] for p in all_pose1]) # forced the first pred
|
59 |
+
keypoints2 = np.stack([p.keypoints[0] for p in all_pose2])
|
60 |
+
boxes1 = np.stack([d.bboxes[0] for d in all_det1])
|
61 |
+
boxes2 = np.stack([d.bboxes[0] for d in all_det2])
|
62 |
+
|
63 |
+
dtw_path, oks, oks_unnorm = DTWForKeypoints(keypoints1, keypoints2).get_dtw_path()
|
64 |
+
|
65 |
+
vis = FastVisualizer()
|
66 |
+
|
67 |
+
for i, j in tqdm(dtw_path):
|
68 |
+
frame1 = v1[i]
|
69 |
+
frame2 = v2[j]
|
70 |
+
|
71 |
+
frame1_ = draw(vis, frame1.copy(), keypoints1[i], boxes1[i],
|
72 |
+
oks[i, j], oks_unnorm[i, j])
|
73 |
+
frame2_ = draw(vis, frame2.copy(), keypoints2[j], boxes2[j],
|
74 |
+
oks[i, j], oks_unnorm[i, j], draw_score_bar=False)
|
75 |
+
# concate two frames
|
76 |
+
frame = concat(frame1_, frame2_)
|
77 |
+
# draw logo
|
78 |
+
vis.set_image(frame)
|
79 |
+
frame = vis.draw_logo().get_image()
|
80 |
+
# write video
|
81 |
+
w, h = frame.shape[1], frame.shape[0]
|
82 |
+
if video_writer is None:
|
83 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
84 |
+
video_writer = cv2.VideoWriter('dtw_compare.mp4',
|
85 |
+
fourcc, v1.fps, (w, h))
|
86 |
+
video_writer.write(frame)
|
87 |
+
video_writer.release()
|
88 |
+
convert_video_to_playable_mp4('dtw_compare.mp4')
|
89 |
+
|
90 |
+
if __name__ == '__main__':
|
91 |
+
args = parse_args()
|
92 |
+
cfg = mmengine.Config.fromfile(args.config)
|
93 |
+
cfg.video1 = args.video1
|
94 |
+
cfg.video2 = args.video2
|
95 |
+
|
96 |
+
main(cfg)
|
model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/detection_onnxruntime_static.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
onnx_config = dict(
|
2 |
+
type='onnx',
|
3 |
+
export_params=True,
|
4 |
+
keep_initializers_as_inputs=False,
|
5 |
+
opset_version=11,
|
6 |
+
save_file='end2end.onnx',
|
7 |
+
input_names=['input'],
|
8 |
+
output_names=['dets', 'labels'],
|
9 |
+
input_shape=None,
|
10 |
+
optimize=True)
|
11 |
+
codebase_config = dict(
|
12 |
+
type='mmdet',
|
13 |
+
task='ObjectDetection',
|
14 |
+
model_type='end2end',
|
15 |
+
post_processing=dict(
|
16 |
+
score_threshold=0.05,
|
17 |
+
confidence_threshold=0.005,
|
18 |
+
iou_threshold=0.5,
|
19 |
+
max_output_boxes_per_class=200,
|
20 |
+
pre_top_k=5000,
|
21 |
+
keep_top_k=100,
|
22 |
+
background_label_id=-1))
|
23 |
+
backend_config = dict(type='onnxruntime')
|
model_zoo/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_scope = 'mmdet'
|
2 |
+
default_hooks = dict(
|
3 |
+
timer=dict(type='IterTimerHook'),
|
4 |
+
logger=dict(type='LoggerHook', interval=50),
|
5 |
+
param_scheduler=dict(type='ParamSchedulerHook'),
|
6 |
+
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3),
|
7 |
+
sampler_seed=dict(type='DistSamplerSeedHook'),
|
8 |
+
visualization=dict(type='DetVisualizationHook'))
|
9 |
+
env_cfg = dict(
|
10 |
+
cudnn_benchmark=False,
|
11 |
+
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
12 |
+
dist_cfg=dict(backend='nccl'))
|
13 |
+
vis_backends = [dict(type='LocalVisBackend')]
|
14 |
+
visualizer = dict(
|
15 |
+
type='DetLocalVisualizer',
|
16 |
+
vis_backends=[dict(type='LocalVisBackend')],
|
17 |
+
name='visualizer')
|
18 |
+
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
|
19 |
+
log_level = 'INFO'
|
20 |
+
load_from = None
|
21 |
+
resume = False
|
22 |
+
train_cfg = dict(
|
23 |
+
type='EpochBasedTrainLoop',
|
24 |
+
max_epochs=300,
|
25 |
+
val_interval=10,
|
26 |
+
dynamic_intervals=[(280, 1)])
|
27 |
+
val_cfg = dict(type='ValLoop')
|
28 |
+
test_cfg = dict(type='TestLoop')
|
29 |
+
param_scheduler = [
|
30 |
+
dict(
|
31 |
+
type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0,
|
32 |
+
end=1000),
|
33 |
+
dict(
|
34 |
+
type='CosineAnnealingLR',
|
35 |
+
eta_min=0.0002,
|
36 |
+
begin=150,
|
37 |
+
end=300,
|
38 |
+
T_max=150,
|
39 |
+
by_epoch=True,
|
40 |
+
convert_to_iter_based=True)
|
41 |
+
]
|
42 |
+
optim_wrapper = dict(
|
43 |
+
type='OptimWrapper',
|
44 |
+
optimizer=dict(type='AdamW', lr=0.004, weight_decay=0.05),
|
45 |
+
paramwise_cfg=dict(
|
46 |
+
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
|
47 |
+
auto_scale_lr = dict(enable=False, base_batch_size=16)
|
48 |
+
dataset_type = 'CocoDataset'
|
49 |
+
data_root = 'data/coco/'
|
50 |
+
backend_args = None
|
51 |
+
train_pipeline = [
|
52 |
+
dict(type='LoadImageFromFile', backend_args=None),
|
53 |
+
dict(type='LoadAnnotations', with_bbox=True),
|
54 |
+
dict(
|
55 |
+
type='CachedMosaic',
|
56 |
+
img_scale=(640, 640),
|
57 |
+
pad_val=114.0,
|
58 |
+
max_cached_images=20,
|
59 |
+
random_pop=False),
|
60 |
+
dict(
|
61 |
+
type='RandomResize',
|
62 |
+
scale=(1280, 1280),
|
63 |
+
ratio_range=(0.5, 2.0),
|
64 |
+
keep_ratio=True),
|
65 |
+
dict(type='RandomCrop', crop_size=(640, 640)),
|
66 |
+
dict(type='YOLOXHSVRandomAug'),
|
67 |
+
dict(type='RandomFlip', prob=0.5),
|
68 |
+
dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
|
69 |
+
dict(
|
70 |
+
type='CachedMixUp',
|
71 |
+
img_scale=(640, 640),
|
72 |
+
ratio_range=(1.0, 1.0),
|
73 |
+
max_cached_images=10,
|
74 |
+
random_pop=False,
|
75 |
+
pad_val=(114, 114, 114),
|
76 |
+
prob=0.5),
|
77 |
+
dict(type='PackDetInputs')
|
78 |
+
]
|
79 |
+
test_pipeline = [
|
80 |
+
dict(type='LoadImageFromFile', backend_args=None),
|
81 |
+
dict(type='Resize', scale=(640, 640), keep_ratio=True),
|
82 |
+
dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
|
83 |
+
dict(
|
84 |
+
type='PackDetInputs',
|
85 |
+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
|
86 |
+
'scale_factor'))
|
87 |
+
]
|
88 |
+
train_dataloader = dict(
|
89 |
+
batch_size=32,
|
90 |
+
num_workers=10,
|
91 |
+
persistent_workers=True,
|
92 |
+
sampler=dict(type='DefaultSampler', shuffle=True),
|
93 |
+
batch_sampler=None,
|
94 |
+
dataset=dict(
|
95 |
+
type='CocoDataset',
|
96 |
+
data_root='data/coco/',
|
97 |
+
ann_file='annotations/instances_train2017.json',
|
98 |
+
data_prefix=dict(img='train2017/'),
|
99 |
+
filter_cfg=dict(filter_empty_gt=True, min_size=32),
|
100 |
+
pipeline=[
|
101 |
+
dict(type='LoadImageFromFile', backend_args=None),
|
102 |
+
dict(type='LoadAnnotations', with_bbox=True),
|
103 |
+
dict(
|
104 |
+
type='CachedMosaic',
|
105 |
+
img_scale=(640, 640),
|
106 |
+
pad_val=114.0,
|
107 |
+
max_cached_images=20,
|
108 |
+
random_pop=False),
|
109 |
+
dict(
|
110 |
+
type='RandomResize',
|
111 |
+
scale=(1280, 1280),
|
112 |
+
ratio_range=(0.5, 2.0),
|
113 |
+
keep_ratio=True),
|
114 |
+
dict(type='RandomCrop', crop_size=(640, 640)),
|
115 |
+
dict(type='YOLOXHSVRandomAug'),
|
116 |
+
dict(type='RandomFlip', prob=0.5),
|
117 |
+
dict(
|
118 |
+
type='Pad', size=(640, 640),
|
119 |
+
pad_val=dict(img=(114, 114, 114))),
|
120 |
+
dict(
|
121 |
+
type='CachedMixUp',
|
122 |
+
img_scale=(640, 640),
|
123 |
+
ratio_range=(1.0, 1.0),
|
124 |
+
max_cached_images=10,
|
125 |
+
random_pop=False,
|
126 |
+
pad_val=(114, 114, 114),
|
127 |
+
prob=0.5),
|
128 |
+
dict(type='PackDetInputs')
|
129 |
+
],
|
130 |
+
backend_args=None),
|
131 |
+
pin_memory=True)
|
132 |
+
val_dataloader = dict(
|
133 |
+
batch_size=5,
|
134 |
+
num_workers=10,
|
135 |
+
persistent_workers=True,
|
136 |
+
drop_last=False,
|
137 |
+
sampler=dict(type='DefaultSampler', shuffle=False),
|
138 |
+
dataset=dict(
|
139 |
+
type='CocoDataset',
|
140 |
+
data_root='data/coco/',
|
141 |
+
ann_file='annotations/instances_val2017.json',
|
142 |
+
data_prefix=dict(img='val2017/'),
|
143 |
+
test_mode=True,
|
144 |
+
pipeline=[
|
145 |
+
dict(type='LoadImageFromFile', backend_args=None),
|
146 |
+
dict(type='Resize', scale=(640, 640), keep_ratio=True),
|
147 |
+
dict(
|
148 |
+
type='Pad', size=(640, 640),
|
149 |
+
pad_val=dict(img=(114, 114, 114))),
|
150 |
+
dict(
|
151 |
+
type='PackDetInputs',
|
152 |
+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
|
153 |
+
'scale_factor'))
|
154 |
+
],
|
155 |
+
backend_args=None))
|
156 |
+
test_dataloader = dict(
|
157 |
+
batch_size=5,
|
158 |
+
num_workers=10,
|
159 |
+
persistent_workers=True,
|
160 |
+
drop_last=False,
|
161 |
+
sampler=dict(type='DefaultSampler', shuffle=False),
|
162 |
+
dataset=dict(
|
163 |
+
type='CocoDataset',
|
164 |
+
data_root='data/coco/',
|
165 |
+
ann_file='annotations/instances_val2017.json',
|
166 |
+
data_prefix=dict(img='val2017/'),
|
167 |
+
test_mode=True,
|
168 |
+
pipeline=[
|
169 |
+
dict(type='LoadImageFromFile', backend_args=None),
|
170 |
+
dict(type='Resize', scale=(640, 640), keep_ratio=True),
|
171 |
+
dict(
|
172 |
+
type='Pad', size=(640, 640),
|
173 |
+
pad_val=dict(img=(114, 114, 114))),
|
174 |
+
dict(
|
175 |
+
type='PackDetInputs',
|
176 |
+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
|
177 |
+
'scale_factor'))
|
178 |
+
],
|
179 |
+
backend_args=None))
|
180 |
+
val_evaluator = dict(
|
181 |
+
type='CocoMetric',
|
182 |
+
ann_file='data/coco/annotations/instances_val2017.json',
|
183 |
+
metric='bbox',
|
184 |
+
format_only=False,
|
185 |
+
backend_args=None,
|
186 |
+
proposal_nums=(100, 1, 10))
|
187 |
+
test_evaluator = dict(
|
188 |
+
type='CocoMetric',
|
189 |
+
ann_file='data/coco/annotations/instances_val2017.json',
|
190 |
+
metric='bbox',
|
191 |
+
format_only=False,
|
192 |
+
backend_args=None,
|
193 |
+
proposal_nums=(100, 1, 10))
|
194 |
+
tta_model = dict(
|
195 |
+
type='DetTTAModel',
|
196 |
+
tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.6), max_per_img=100))
|
197 |
+
img_scales = [(640, 640), (320, 320), (960, 960)]
|
198 |
+
tta_pipeline = [
|
199 |
+
dict(type='LoadImageFromFile', backend_args=None),
|
200 |
+
dict(
|
201 |
+
type='TestTimeAug',
|
202 |
+
transforms=[[{
|
203 |
+
'type': 'Resize',
|
204 |
+
'scale': (640, 640),
|
205 |
+
'keep_ratio': True
|
206 |
+
}, {
|
207 |
+
'type': 'Resize',
|
208 |
+
'scale': (320, 320),
|
209 |
+
'keep_ratio': True
|
210 |
+
}, {
|
211 |
+
'type': 'Resize',
|
212 |
+
'scale': (960, 960),
|
213 |
+
'keep_ratio': True
|
214 |
+
}],
|
215 |
+
[{
|
216 |
+
'type': 'RandomFlip',
|
217 |
+
'prob': 1.0
|
218 |
+
}, {
|
219 |
+
'type': 'RandomFlip',
|
220 |
+
'prob': 0.0
|
221 |
+
}],
|
222 |
+
[{
|
223 |
+
'type': 'Pad',
|
224 |
+
'size': (960, 960),
|
225 |
+
'pad_val': {
|
226 |
+
'img': (114, 114, 114)
|
227 |
+
}
|
228 |
+
}],
|
229 |
+
[{
|
230 |
+
'type':
|
231 |
+
'PackDetInputs',
|
232 |
+
'meta_keys':
|
233 |
+
('img_id', 'img_path', 'ori_shape', 'img_shape',
|
234 |
+
'scale_factor', 'flip', 'flip_direction')
|
235 |
+
}]])
|
236 |
+
]
|
237 |
+
model = dict(
|
238 |
+
type='RTMDet',
|
239 |
+
data_preprocessor=dict(
|
240 |
+
type='DetDataPreprocessor',
|
241 |
+
mean=[103.53, 116.28, 123.675],
|
242 |
+
std=[57.375, 57.12, 58.395],
|
243 |
+
bgr_to_rgb=False,
|
244 |
+
batch_augments=None),
|
245 |
+
backbone=dict(
|
246 |
+
type='CSPNeXt',
|
247 |
+
arch='P5',
|
248 |
+
expand_ratio=0.5,
|
249 |
+
deepen_factor=0.167,
|
250 |
+
widen_factor=0.375,
|
251 |
+
channel_attention=True,
|
252 |
+
norm_cfg=dict(type='SyncBN'),
|
253 |
+
act_cfg=dict(type='SiLU', inplace=True),
|
254 |
+
init_cfg=dict(
|
255 |
+
type='Pretrained',
|
256 |
+
prefix='backbone.',
|
257 |
+
checkpoint=
|
258 |
+
'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth'
|
259 |
+
)),
|
260 |
+
neck=dict(
|
261 |
+
type='CSPNeXtPAFPN',
|
262 |
+
in_channels=[96, 192, 384],
|
263 |
+
out_channels=96,
|
264 |
+
num_csp_blocks=1,
|
265 |
+
expand_ratio=0.5,
|
266 |
+
norm_cfg=dict(type='SyncBN'),
|
267 |
+
act_cfg=dict(type='SiLU', inplace=True)),
|
268 |
+
bbox_head=dict(
|
269 |
+
type='RTMDetSepBNHead',
|
270 |
+
num_classes=80,
|
271 |
+
in_channels=96,
|
272 |
+
stacked_convs=2,
|
273 |
+
feat_channels=96,
|
274 |
+
anchor_generator=dict(
|
275 |
+
type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
|
276 |
+
bbox_coder=dict(type='DistancePointBBoxCoder'),
|
277 |
+
loss_cls=dict(
|
278 |
+
type='QualityFocalLoss',
|
279 |
+
use_sigmoid=True,
|
280 |
+
beta=2.0,
|
281 |
+
loss_weight=1.0),
|
282 |
+
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
|
283 |
+
with_objectness=False,
|
284 |
+
exp_on_reg=False,
|
285 |
+
share_conv=True,
|
286 |
+
pred_kernel_size=1,
|
287 |
+
norm_cfg=dict(type='SyncBN'),
|
288 |
+
act_cfg=dict(type='SiLU', inplace=True)),
|
289 |
+
train_cfg=dict(
|
290 |
+
assigner=dict(type='DynamicSoftLabelAssigner', topk=13),
|
291 |
+
allowed_border=-1,
|
292 |
+
pos_weight=-1,
|
293 |
+
debug=False),
|
294 |
+
test_cfg=dict(
|
295 |
+
nms_pre=30000,
|
296 |
+
min_bbox_size=0,
|
297 |
+
score_thr=0.001,
|
298 |
+
nms=dict(type='nms', iou_threshold=0.65),
|
299 |
+
max_per_img=300))
|
300 |
+
train_pipeline_stage2 = [
|
301 |
+
dict(type='LoadImageFromFile', backend_args=None),
|
302 |
+
dict(type='LoadAnnotations', with_bbox=True),
|
303 |
+
dict(
|
304 |
+
type='RandomResize',
|
305 |
+
scale=(640, 640),
|
306 |
+
ratio_range=(0.5, 2.0),
|
307 |
+
keep_ratio=True),
|
308 |
+
dict(type='RandomCrop', crop_size=(640, 640)),
|
309 |
+
dict(type='YOLOXHSVRandomAug'),
|
310 |
+
dict(type='RandomFlip', prob=0.5),
|
311 |
+
dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
|
312 |
+
dict(type='PackDetInputs')
|
313 |
+
]
|
314 |
+
max_epochs = 300
|
315 |
+
stage2_num_epochs = 20
|
316 |
+
base_lr = 0.004
|
317 |
+
interval = 10
|
318 |
+
custom_hooks = [
|
319 |
+
dict(
|
320 |
+
type='EMAHook',
|
321 |
+
ema_type='ExpMomentumEMA',
|
322 |
+
momentum=0.0002,
|
323 |
+
update_buffers=True,
|
324 |
+
priority=49),
|
325 |
+
dict(
|
326 |
+
type='PipelineSwitchHook',
|
327 |
+
switch_epoch=280,
|
328 |
+
switch_pipeline=[
|
329 |
+
dict(type='LoadImageFromFile', backend_args=None),
|
330 |
+
dict(type='LoadAnnotations', with_bbox=True),
|
331 |
+
dict(
|
332 |
+
type='RandomResize',
|
333 |
+
scale=(640, 640),
|
334 |
+
ratio_range=(0.5, 2.0),
|
335 |
+
keep_ratio=True),
|
336 |
+
dict(type='RandomCrop', crop_size=(640, 640)),
|
337 |
+
dict(type='YOLOXHSVRandomAug'),
|
338 |
+
dict(type='RandomFlip', prob=0.5),
|
339 |
+
dict(
|
340 |
+
type='Pad', size=(640, 640),
|
341 |
+
pad_val=dict(img=(114, 114, 114))),
|
342 |
+
dict(type='PackDetInputs')
|
343 |
+
])
|
344 |
+
]
|
345 |
+
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth'
|
model_zoo/rtmpose/rtmpose-m_8xb256-420e_aic-coco-256x192/rtmpose-m_8xb256-420e_aic-coco-256x192.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_scope = 'mmpose'
|
2 |
+
default_hooks = dict(
|
3 |
+
timer=dict(type='IterTimerHook'),
|
4 |
+
logger=dict(type='LoggerHook', interval=50),
|
5 |
+
param_scheduler=dict(type='ParamSchedulerHook'),
|
6 |
+
checkpoint=dict(
|
7 |
+
type='CheckpointHook',
|
8 |
+
interval=10,
|
9 |
+
save_best='coco/AP',
|
10 |
+
rule='greater',
|
11 |
+
max_keep_ckpts=1),
|
12 |
+
sampler_seed=dict(type='DistSamplerSeedHook'),
|
13 |
+
visualization=dict(type='PoseVisualizationHook', enable=False))
|
14 |
+
custom_hooks = [
|
15 |
+
dict(
|
16 |
+
type='EMAHook',
|
17 |
+
ema_type='ExpMomentumEMA',
|
18 |
+
momentum=0.0002,
|
19 |
+
update_buffers=True,
|
20 |
+
priority=49),
|
21 |
+
dict(
|
22 |
+
type='mmdet.PipelineSwitchHook',
|
23 |
+
switch_epoch=390,
|
24 |
+
switch_pipeline=[
|
25 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
26 |
+
dict(type='GetBBoxCenterScale'),
|
27 |
+
dict(type='RandomFlip', direction='horizontal'),
|
28 |
+
dict(type='RandomHalfBody'),
|
29 |
+
dict(
|
30 |
+
type='RandomBBoxTransform',
|
31 |
+
shift_factor=0.0,
|
32 |
+
scale_factor=[0.75, 1.25],
|
33 |
+
rotate_factor=60),
|
34 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
35 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
36 |
+
dict(
|
37 |
+
type='Albumentation',
|
38 |
+
transforms=[
|
39 |
+
dict(type='Blur', p=0.1),
|
40 |
+
dict(type='MedianBlur', p=0.1),
|
41 |
+
dict(
|
42 |
+
type='CoarseDropout',
|
43 |
+
max_holes=1,
|
44 |
+
max_height=0.4,
|
45 |
+
max_width=0.4,
|
46 |
+
min_holes=1,
|
47 |
+
min_height=0.2,
|
48 |
+
min_width=0.2,
|
49 |
+
p=0.5)
|
50 |
+
]),
|
51 |
+
dict(
|
52 |
+
type='GenerateTarget',
|
53 |
+
encoder=dict(
|
54 |
+
type='SimCCLabel',
|
55 |
+
input_size=(192, 256),
|
56 |
+
sigma=(4.9, 5.66),
|
57 |
+
simcc_split_ratio=2.0,
|
58 |
+
normalize=False,
|
59 |
+
use_dark=False)),
|
60 |
+
dict(type='PackPoseInputs')
|
61 |
+
])
|
62 |
+
]
|
63 |
+
env_cfg = dict(
|
64 |
+
cudnn_benchmark=False,
|
65 |
+
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
66 |
+
dist_cfg=dict(backend='nccl'))
|
67 |
+
vis_backends = [dict(type='LocalVisBackend')]
|
68 |
+
visualizer = dict(
|
69 |
+
type='PoseLocalVisualizer',
|
70 |
+
vis_backends=[dict(type='LocalVisBackend')],
|
71 |
+
name='visualizer')
|
72 |
+
log_processor = dict(
|
73 |
+
type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
|
74 |
+
log_level = 'INFO'
|
75 |
+
load_from = None
|
76 |
+
resume = False
|
77 |
+
backend_args = dict(backend='local')
|
78 |
+
train_cfg = dict(by_epoch=True, max_epochs=420, val_interval=10)
|
79 |
+
val_cfg = dict()
|
80 |
+
test_cfg = dict()
|
81 |
+
max_epochs = 420
|
82 |
+
stage2_num_epochs = 30
|
83 |
+
base_lr = 0.004
|
84 |
+
randomness = dict(seed=21)
|
85 |
+
optim_wrapper = dict(
|
86 |
+
type='OptimWrapper',
|
87 |
+
optimizer=dict(type='AdamW', lr=0.004, weight_decay=0.05),
|
88 |
+
paramwise_cfg=dict(
|
89 |
+
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
|
90 |
+
param_scheduler = [
|
91 |
+
dict(
|
92 |
+
type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0,
|
93 |
+
end=1000),
|
94 |
+
dict(
|
95 |
+
type='CosineAnnealingLR',
|
96 |
+
eta_min=0.0002,
|
97 |
+
begin=210,
|
98 |
+
end=420,
|
99 |
+
T_max=210,
|
100 |
+
by_epoch=True,
|
101 |
+
convert_to_iter_based=True)
|
102 |
+
]
|
103 |
+
auto_scale_lr = dict(base_batch_size=1024)
|
104 |
+
codec = dict(
|
105 |
+
type='SimCCLabel',
|
106 |
+
input_size=(192, 256),
|
107 |
+
sigma=(4.9, 5.66),
|
108 |
+
simcc_split_ratio=2.0,
|
109 |
+
normalize=False,
|
110 |
+
use_dark=False)
|
111 |
+
model = dict(
|
112 |
+
type='TopdownPoseEstimator',
|
113 |
+
data_preprocessor=dict(
|
114 |
+
type='PoseDataPreprocessor',
|
115 |
+
mean=[123.675, 116.28, 103.53],
|
116 |
+
std=[58.395, 57.12, 57.375],
|
117 |
+
bgr_to_rgb=True),
|
118 |
+
backbone=dict(
|
119 |
+
_scope_='mmdet',
|
120 |
+
type='CSPNeXt',
|
121 |
+
arch='P5',
|
122 |
+
expand_ratio=0.5,
|
123 |
+
deepen_factor=0.67,
|
124 |
+
widen_factor=0.75,
|
125 |
+
out_indices=(4, ),
|
126 |
+
channel_attention=True,
|
127 |
+
norm_cfg=dict(type='SyncBN'),
|
128 |
+
act_cfg=dict(type='SiLU'),
|
129 |
+
init_cfg=dict(
|
130 |
+
type='Pretrained',
|
131 |
+
prefix='backbone.',
|
132 |
+
checkpoint=
|
133 |
+
'https://download.openmmlab.com/mmpose/v1/projects/rtmpose/cspnext-m_udp-aic-coco_210e-256x192-f2f7d6f6_20230130.pth'
|
134 |
+
)),
|
135 |
+
head=dict(
|
136 |
+
type='RTMCCHead',
|
137 |
+
in_channels=768,
|
138 |
+
out_channels=17,
|
139 |
+
input_size=(192, 256),
|
140 |
+
in_featuremap_size=(6, 8),
|
141 |
+
simcc_split_ratio=2.0,
|
142 |
+
final_layer_kernel_size=7,
|
143 |
+
gau_cfg=dict(
|
144 |
+
hidden_dims=256,
|
145 |
+
s=128,
|
146 |
+
expansion_factor=2,
|
147 |
+
dropout_rate=0.0,
|
148 |
+
drop_path=0.0,
|
149 |
+
act_fn='SiLU',
|
150 |
+
use_rel_bias=False,
|
151 |
+
pos_enc=False),
|
152 |
+
loss=dict(
|
153 |
+
type='KLDiscretLoss',
|
154 |
+
use_target_weight=True,
|
155 |
+
beta=10.0,
|
156 |
+
label_softmax=True),
|
157 |
+
decoder=dict(
|
158 |
+
type='SimCCLabel',
|
159 |
+
input_size=(192, 256),
|
160 |
+
sigma=(4.9, 5.66),
|
161 |
+
simcc_split_ratio=2.0,
|
162 |
+
normalize=False,
|
163 |
+
use_dark=False)),
|
164 |
+
test_cfg=dict(flip_test=True))
|
165 |
+
dataset_type = 'CocoDataset'
|
166 |
+
data_mode = 'topdown'
|
167 |
+
data_root = 'data/'
|
168 |
+
train_pipeline = [
|
169 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
170 |
+
dict(type='GetBBoxCenterScale'),
|
171 |
+
dict(type='RandomFlip', direction='horizontal'),
|
172 |
+
dict(type='RandomHalfBody'),
|
173 |
+
dict(
|
174 |
+
type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
|
175 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
176 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
177 |
+
dict(
|
178 |
+
type='Albumentation',
|
179 |
+
transforms=[
|
180 |
+
dict(type='Blur', p=0.1),
|
181 |
+
dict(type='MedianBlur', p=0.1),
|
182 |
+
dict(
|
183 |
+
type='CoarseDropout',
|
184 |
+
max_holes=1,
|
185 |
+
max_height=0.4,
|
186 |
+
max_width=0.4,
|
187 |
+
min_holes=1,
|
188 |
+
min_height=0.2,
|
189 |
+
min_width=0.2,
|
190 |
+
p=1.0)
|
191 |
+
]),
|
192 |
+
dict(
|
193 |
+
type='GenerateTarget',
|
194 |
+
encoder=dict(
|
195 |
+
type='SimCCLabel',
|
196 |
+
input_size=(192, 256),
|
197 |
+
sigma=(4.9, 5.66),
|
198 |
+
simcc_split_ratio=2.0,
|
199 |
+
normalize=False,
|
200 |
+
use_dark=False)),
|
201 |
+
dict(type='PackPoseInputs')
|
202 |
+
]
|
203 |
+
val_pipeline = [
|
204 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
205 |
+
dict(type='GetBBoxCenterScale'),
|
206 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
207 |
+
dict(type='PackPoseInputs')
|
208 |
+
]
|
209 |
+
train_pipeline_stage2 = [
|
210 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
211 |
+
dict(type='GetBBoxCenterScale'),
|
212 |
+
dict(type='RandomFlip', direction='horizontal'),
|
213 |
+
dict(type='RandomHalfBody'),
|
214 |
+
dict(
|
215 |
+
type='RandomBBoxTransform',
|
216 |
+
shift_factor=0.0,
|
217 |
+
scale_factor=[0.75, 1.25],
|
218 |
+
rotate_factor=60),
|
219 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
220 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
221 |
+
dict(
|
222 |
+
type='Albumentation',
|
223 |
+
transforms=[
|
224 |
+
dict(type='Blur', p=0.1),
|
225 |
+
dict(type='MedianBlur', p=0.1),
|
226 |
+
dict(
|
227 |
+
type='CoarseDropout',
|
228 |
+
max_holes=1,
|
229 |
+
max_height=0.4,
|
230 |
+
max_width=0.4,
|
231 |
+
min_holes=1,
|
232 |
+
min_height=0.2,
|
233 |
+
min_width=0.2,
|
234 |
+
p=0.5)
|
235 |
+
]),
|
236 |
+
dict(
|
237 |
+
type='GenerateTarget',
|
238 |
+
encoder=dict(
|
239 |
+
type='SimCCLabel',
|
240 |
+
input_size=(192, 256),
|
241 |
+
sigma=(4.9, 5.66),
|
242 |
+
simcc_split_ratio=2.0,
|
243 |
+
normalize=False,
|
244 |
+
use_dark=False)),
|
245 |
+
dict(type='PackPoseInputs')
|
246 |
+
]
|
247 |
+
dataset_coco = dict(
|
248 |
+
type='RepeatDataset',
|
249 |
+
dataset=dict(
|
250 |
+
type='CocoDataset',
|
251 |
+
data_root='data/',
|
252 |
+
data_mode='topdown',
|
253 |
+
ann_file='coco/annotations/person_keypoints_train2017.json',
|
254 |
+
data_prefix=dict(img='detection/coco/train2017/'),
|
255 |
+
pipeline=[]),
|
256 |
+
times=3)
|
257 |
+
dataset_aic = dict(
|
258 |
+
type='AicDataset',
|
259 |
+
data_root='data/',
|
260 |
+
data_mode='topdown',
|
261 |
+
ann_file='aic/annotations/aic_train.json',
|
262 |
+
data_prefix=dict(
|
263 |
+
img=
|
264 |
+
'pose/ai_challenge/ai_challenger_keypoint_train_20170902/keypoint_train_images_20170902/'
|
265 |
+
),
|
266 |
+
pipeline=[
|
267 |
+
dict(
|
268 |
+
type='KeypointConverter',
|
269 |
+
num_keypoints=17,
|
270 |
+
mapping=[(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12),
|
271 |
+
(7, 14), (8, 16), (9, 11), (10, 13), (11, 15)])
|
272 |
+
])
|
273 |
+
train_dataloader = dict(
|
274 |
+
batch_size=256,
|
275 |
+
num_workers=10,
|
276 |
+
persistent_workers=True,
|
277 |
+
sampler=dict(type='DefaultSampler', shuffle=True),
|
278 |
+
dataset=dict(
|
279 |
+
type='CombinedDataset',
|
280 |
+
metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
|
281 |
+
datasets=[
|
282 |
+
dict(
|
283 |
+
type='RepeatDataset',
|
284 |
+
dataset=dict(
|
285 |
+
type='CocoDataset',
|
286 |
+
data_root='data/',
|
287 |
+
data_mode='topdown',
|
288 |
+
ann_file='coco/annotations/person_keypoints_train2017.json',
|
289 |
+
data_prefix=dict(img='detection/coco/train2017/'),
|
290 |
+
pipeline=[]),
|
291 |
+
times=3),
|
292 |
+
dict(
|
293 |
+
type='AicDataset',
|
294 |
+
data_root='data/',
|
295 |
+
data_mode='topdown',
|
296 |
+
ann_file='aic/annotations/aic_train.json',
|
297 |
+
data_prefix=dict(
|
298 |
+
img=
|
299 |
+
'pose/ai_challenge/ai_challenger_keypoint_train_20170902/keypoint_train_images_20170902/'
|
300 |
+
),
|
301 |
+
pipeline=[
|
302 |
+
dict(
|
303 |
+
type='KeypointConverter',
|
304 |
+
num_keypoints=17,
|
305 |
+
mapping=[(0, 6), (1, 8), (2, 10), (3, 5), (4, 7),
|
306 |
+
(5, 9), (6, 12), (7, 14), (8, 16), (9, 11),
|
307 |
+
(10, 13), (11, 15)])
|
308 |
+
])
|
309 |
+
],
|
310 |
+
pipeline=[
|
311 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
312 |
+
dict(type='GetBBoxCenterScale'),
|
313 |
+
dict(type='RandomFlip', direction='horizontal'),
|
314 |
+
dict(type='RandomHalfBody'),
|
315 |
+
dict(
|
316 |
+
type='RandomBBoxTransform',
|
317 |
+
scale_factor=[0.6, 1.4],
|
318 |
+
rotate_factor=80),
|
319 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
320 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
321 |
+
dict(
|
322 |
+
type='Albumentation',
|
323 |
+
transforms=[
|
324 |
+
dict(type='Blur', p=0.1),
|
325 |
+
dict(type='MedianBlur', p=0.1),
|
326 |
+
dict(
|
327 |
+
type='CoarseDropout',
|
328 |
+
max_holes=1,
|
329 |
+
max_height=0.4,
|
330 |
+
max_width=0.4,
|
331 |
+
min_holes=1,
|
332 |
+
min_height=0.2,
|
333 |
+
min_width=0.2,
|
334 |
+
p=1.0)
|
335 |
+
]),
|
336 |
+
dict(
|
337 |
+
type='GenerateTarget',
|
338 |
+
encoder=dict(
|
339 |
+
type='SimCCLabel',
|
340 |
+
input_size=(192, 256),
|
341 |
+
sigma=(4.9, 5.66),
|
342 |
+
simcc_split_ratio=2.0,
|
343 |
+
normalize=False,
|
344 |
+
use_dark=False)),
|
345 |
+
dict(type='PackPoseInputs')
|
346 |
+
],
|
347 |
+
test_mode=False))
|
348 |
+
val_dataloader = dict(
|
349 |
+
batch_size=64,
|
350 |
+
num_workers=10,
|
351 |
+
persistent_workers=True,
|
352 |
+
drop_last=False,
|
353 |
+
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
|
354 |
+
dataset=dict(
|
355 |
+
type='CocoDataset',
|
356 |
+
data_root='data/',
|
357 |
+
data_mode='topdown',
|
358 |
+
ann_file='coco/annotations/person_keypoints_val2017.json',
|
359 |
+
data_prefix=dict(img='detection/coco/val2017/'),
|
360 |
+
test_mode=True,
|
361 |
+
pipeline=[
|
362 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
363 |
+
dict(type='GetBBoxCenterScale'),
|
364 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
365 |
+
dict(type='PackPoseInputs')
|
366 |
+
]))
|
367 |
+
test_dataloader = dict(
|
368 |
+
batch_size=64,
|
369 |
+
num_workers=10,
|
370 |
+
persistent_workers=True,
|
371 |
+
drop_last=False,
|
372 |
+
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
|
373 |
+
dataset=dict(
|
374 |
+
type='CocoDataset',
|
375 |
+
data_root='data/',
|
376 |
+
data_mode='topdown',
|
377 |
+
ann_file='coco/annotations/person_keypoints_val2017.json',
|
378 |
+
data_prefix=dict(img='detection/coco/val2017/'),
|
379 |
+
test_mode=True,
|
380 |
+
pipeline=[
|
381 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
382 |
+
dict(type='GetBBoxCenterScale'),
|
383 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
384 |
+
dict(type='PackPoseInputs')
|
385 |
+
]))
|
386 |
+
val_evaluator = dict(
|
387 |
+
type='CocoMetric',
|
388 |
+
ann_file='data/coco/annotations/person_keypoints_val2017.json')
|
389 |
+
test_evaluator = dict(
|
390 |
+
type='CocoMetric',
|
391 |
+
ann_file='data/coco/annotations/person_keypoints_val2017.json')
|
model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-t_8xb256-420e_aic-coco-256x192.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_scope = 'mmpose'
|
2 |
+
default_hooks = dict(
|
3 |
+
timer=dict(type='IterTimerHook'),
|
4 |
+
logger=dict(type='LoggerHook', interval=50),
|
5 |
+
param_scheduler=dict(type='ParamSchedulerHook'),
|
6 |
+
checkpoint=dict(
|
7 |
+
type='CheckpointHook',
|
8 |
+
interval=10,
|
9 |
+
save_best='coco/AP',
|
10 |
+
rule='greater',
|
11 |
+
max_keep_ckpts=1),
|
12 |
+
sampler_seed=dict(type='DistSamplerSeedHook'),
|
13 |
+
visualization=dict(type='PoseVisualizationHook', enable=False))
|
14 |
+
custom_hooks = [
|
15 |
+
dict(
|
16 |
+
type='mmdet.PipelineSwitchHook',
|
17 |
+
switch_epoch=390,
|
18 |
+
switch_pipeline=[
|
19 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
20 |
+
dict(type='GetBBoxCenterScale'),
|
21 |
+
dict(type='RandomFlip', direction='horizontal'),
|
22 |
+
dict(type='RandomHalfBody'),
|
23 |
+
dict(
|
24 |
+
type='RandomBBoxTransform',
|
25 |
+
shift_factor=0.0,
|
26 |
+
scale_factor=[0.75, 1.25],
|
27 |
+
rotate_factor=60),
|
28 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
29 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
30 |
+
dict(
|
31 |
+
type='Albumentation',
|
32 |
+
transforms=[
|
33 |
+
dict(type='Blur', p=0.1),
|
34 |
+
dict(type='MedianBlur', p=0.1),
|
35 |
+
dict(
|
36 |
+
type='CoarseDropout',
|
37 |
+
max_holes=1,
|
38 |
+
max_height=0.4,
|
39 |
+
max_width=0.4,
|
40 |
+
min_holes=1,
|
41 |
+
min_height=0.2,
|
42 |
+
min_width=0.2,
|
43 |
+
p=0.5)
|
44 |
+
]),
|
45 |
+
dict(
|
46 |
+
type='GenerateTarget',
|
47 |
+
encoder=dict(
|
48 |
+
type='SimCCLabel',
|
49 |
+
input_size=(192, 256),
|
50 |
+
sigma=(4.9, 5.66),
|
51 |
+
simcc_split_ratio=2.0,
|
52 |
+
normalize=False,
|
53 |
+
use_dark=False)),
|
54 |
+
dict(type='PackPoseInputs')
|
55 |
+
])
|
56 |
+
]
|
57 |
+
env_cfg = dict(
|
58 |
+
cudnn_benchmark=False,
|
59 |
+
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
60 |
+
dist_cfg=dict(backend='nccl'))
|
61 |
+
vis_backends = [dict(type='LocalVisBackend')]
|
62 |
+
visualizer = dict(
|
63 |
+
type='PoseLocalVisualizer',
|
64 |
+
vis_backends=[dict(type='LocalVisBackend')],
|
65 |
+
name='visualizer')
|
66 |
+
log_processor = dict(
|
67 |
+
type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
|
68 |
+
log_level = 'INFO'
|
69 |
+
load_from = None
|
70 |
+
resume = False
|
71 |
+
backend_args = dict(backend='local')
|
72 |
+
train_cfg = dict(by_epoch=True, max_epochs=420, val_interval=10)
|
73 |
+
val_cfg = dict()
|
74 |
+
test_cfg = dict()
|
75 |
+
max_epochs = 420
|
76 |
+
stage2_num_epochs = 30
|
77 |
+
base_lr = 0.004
|
78 |
+
randomness = dict(seed=21)
|
79 |
+
optim_wrapper = dict(
|
80 |
+
type='OptimWrapper',
|
81 |
+
optimizer=dict(type='AdamW', lr=0.004, weight_decay=0.0),
|
82 |
+
paramwise_cfg=dict(
|
83 |
+
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
|
84 |
+
param_scheduler = [
|
85 |
+
dict(
|
86 |
+
type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0,
|
87 |
+
end=1000),
|
88 |
+
dict(
|
89 |
+
type='CosineAnnealingLR',
|
90 |
+
eta_min=0.0002,
|
91 |
+
begin=210,
|
92 |
+
end=420,
|
93 |
+
T_max=210,
|
94 |
+
by_epoch=True,
|
95 |
+
convert_to_iter_based=True)
|
96 |
+
]
|
97 |
+
auto_scale_lr = dict(base_batch_size=1024)
|
98 |
+
codec = dict(
|
99 |
+
type='SimCCLabel',
|
100 |
+
input_size=(192, 256),
|
101 |
+
sigma=(4.9, 5.66),
|
102 |
+
simcc_split_ratio=2.0,
|
103 |
+
normalize=False,
|
104 |
+
use_dark=False)
|
105 |
+
model = dict(
|
106 |
+
type='TopdownPoseEstimator',
|
107 |
+
data_preprocessor=dict(
|
108 |
+
type='PoseDataPreprocessor',
|
109 |
+
mean=[123.675, 116.28, 103.53],
|
110 |
+
std=[58.395, 57.12, 57.375],
|
111 |
+
bgr_to_rgb=True),
|
112 |
+
backbone=dict(
|
113 |
+
_scope_='mmdet',
|
114 |
+
type='CSPNeXt',
|
115 |
+
arch='P5',
|
116 |
+
expand_ratio=0.5,
|
117 |
+
deepen_factor=0.167,
|
118 |
+
widen_factor=0.375,
|
119 |
+
out_indices=(4, ),
|
120 |
+
channel_attention=True,
|
121 |
+
norm_cfg=dict(type='SyncBN'),
|
122 |
+
act_cfg=dict(type='SiLU'),
|
123 |
+
init_cfg=dict(
|
124 |
+
type='Pretrained',
|
125 |
+
prefix='backbone.',
|
126 |
+
checkpoint=
|
127 |
+
'https://download.openmmlab.com/mmpose/v1/projects/rtmpose/cspnext-tiny_udp-aic-coco_210e-256x192-cbed682d_20230130.pth'
|
128 |
+
)),
|
129 |
+
head=dict(
|
130 |
+
type='RTMCCHead',
|
131 |
+
in_channels=384,
|
132 |
+
out_channels=17,
|
133 |
+
input_size=(192, 256),
|
134 |
+
in_featuremap_size=(6, 8),
|
135 |
+
simcc_split_ratio=2.0,
|
136 |
+
final_layer_kernel_size=7,
|
137 |
+
gau_cfg=dict(
|
138 |
+
hidden_dims=256,
|
139 |
+
s=128,
|
140 |
+
expansion_factor=2,
|
141 |
+
dropout_rate=0.0,
|
142 |
+
drop_path=0.0,
|
143 |
+
act_fn='SiLU',
|
144 |
+
use_rel_bias=False,
|
145 |
+
pos_enc=False),
|
146 |
+
loss=dict(
|
147 |
+
type='KLDiscretLoss',
|
148 |
+
use_target_weight=True,
|
149 |
+
beta=10.0,
|
150 |
+
label_softmax=True),
|
151 |
+
decoder=dict(
|
152 |
+
type='SimCCLabel',
|
153 |
+
input_size=(192, 256),
|
154 |
+
sigma=(4.9, 5.66),
|
155 |
+
simcc_split_ratio=2.0,
|
156 |
+
normalize=False,
|
157 |
+
use_dark=False)),
|
158 |
+
test_cfg=dict(flip_test=True))
|
159 |
+
dataset_type = 'CocoDataset'
|
160 |
+
data_mode = 'topdown'
|
161 |
+
data_root = 'data/'
|
162 |
+
train_pipeline = [
|
163 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
164 |
+
dict(type='GetBBoxCenterScale'),
|
165 |
+
dict(type='RandomFlip', direction='horizontal'),
|
166 |
+
dict(type='RandomHalfBody'),
|
167 |
+
dict(
|
168 |
+
type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
|
169 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
170 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
171 |
+
dict(
|
172 |
+
type='Albumentation',
|
173 |
+
transforms=[
|
174 |
+
dict(type='Blur', p=0.1),
|
175 |
+
dict(type='MedianBlur', p=0.1),
|
176 |
+
dict(
|
177 |
+
type='CoarseDropout',
|
178 |
+
max_holes=1,
|
179 |
+
max_height=0.4,
|
180 |
+
max_width=0.4,
|
181 |
+
min_holes=1,
|
182 |
+
min_height=0.2,
|
183 |
+
min_width=0.2,
|
184 |
+
p=1.0)
|
185 |
+
]),
|
186 |
+
dict(
|
187 |
+
type='GenerateTarget',
|
188 |
+
encoder=dict(
|
189 |
+
type='SimCCLabel',
|
190 |
+
input_size=(192, 256),
|
191 |
+
sigma=(4.9, 5.66),
|
192 |
+
simcc_split_ratio=2.0,
|
193 |
+
normalize=False,
|
194 |
+
use_dark=False)),
|
195 |
+
dict(type='PackPoseInputs')
|
196 |
+
]
|
197 |
+
val_pipeline = [
|
198 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
199 |
+
dict(type='GetBBoxCenterScale'),
|
200 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
201 |
+
dict(type='PackPoseInputs')
|
202 |
+
]
|
203 |
+
train_pipeline_stage2 = [
|
204 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
205 |
+
dict(type='GetBBoxCenterScale'),
|
206 |
+
dict(type='RandomFlip', direction='horizontal'),
|
207 |
+
dict(type='RandomHalfBody'),
|
208 |
+
dict(
|
209 |
+
type='RandomBBoxTransform',
|
210 |
+
shift_factor=0.0,
|
211 |
+
scale_factor=[0.75, 1.25],
|
212 |
+
rotate_factor=60),
|
213 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
214 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
215 |
+
dict(
|
216 |
+
type='Albumentation',
|
217 |
+
transforms=[
|
218 |
+
dict(type='Blur', p=0.1),
|
219 |
+
dict(type='MedianBlur', p=0.1),
|
220 |
+
dict(
|
221 |
+
type='CoarseDropout',
|
222 |
+
max_holes=1,
|
223 |
+
max_height=0.4,
|
224 |
+
max_width=0.4,
|
225 |
+
min_holes=1,
|
226 |
+
min_height=0.2,
|
227 |
+
min_width=0.2,
|
228 |
+
p=0.5)
|
229 |
+
]),
|
230 |
+
dict(
|
231 |
+
type='GenerateTarget',
|
232 |
+
encoder=dict(
|
233 |
+
type='SimCCLabel',
|
234 |
+
input_size=(192, 256),
|
235 |
+
sigma=(4.9, 5.66),
|
236 |
+
simcc_split_ratio=2.0,
|
237 |
+
normalize=False,
|
238 |
+
use_dark=False)),
|
239 |
+
dict(type='PackPoseInputs')
|
240 |
+
]
|
241 |
+
dataset_coco = dict(
|
242 |
+
type='RepeatDataset',
|
243 |
+
dataset=dict(
|
244 |
+
type='CocoDataset',
|
245 |
+
data_root='data/',
|
246 |
+
data_mode='topdown',
|
247 |
+
ann_file='coco/annotations/person_keypoints_train2017.json',
|
248 |
+
data_prefix=dict(img='detection/coco/train2017/'),
|
249 |
+
pipeline=[]),
|
250 |
+
times=3)
|
251 |
+
dataset_aic = dict(
|
252 |
+
type='AicDataset',
|
253 |
+
data_root='data/',
|
254 |
+
data_mode='topdown',
|
255 |
+
ann_file='aic/annotations/aic_train.json',
|
256 |
+
data_prefix=dict(
|
257 |
+
img=
|
258 |
+
'pose/ai_challenge/ai_challenger_keypoint_train_20170902/keypoint_train_images_20170902/'
|
259 |
+
),
|
260 |
+
pipeline=[
|
261 |
+
dict(
|
262 |
+
type='KeypointConverter',
|
263 |
+
num_keypoints=17,
|
264 |
+
mapping=[(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12),
|
265 |
+
(7, 14), (8, 16), (9, 11), (10, 13), (11, 15)])
|
266 |
+
])
|
267 |
+
train_dataloader = dict(
|
268 |
+
batch_size=256,
|
269 |
+
num_workers=10,
|
270 |
+
persistent_workers=True,
|
271 |
+
sampler=dict(type='DefaultSampler', shuffle=True),
|
272 |
+
dataset=dict(
|
273 |
+
type='CombinedDataset',
|
274 |
+
metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
|
275 |
+
datasets=[
|
276 |
+
dict(
|
277 |
+
type='RepeatDataset',
|
278 |
+
dataset=dict(
|
279 |
+
type='CocoDataset',
|
280 |
+
data_root='data/',
|
281 |
+
data_mode='topdown',
|
282 |
+
ann_file='coco/annotations/person_keypoints_train2017.json',
|
283 |
+
data_prefix=dict(img='detection/coco/train2017/'),
|
284 |
+
pipeline=[]),
|
285 |
+
times=3),
|
286 |
+
dict(
|
287 |
+
type='AicDataset',
|
288 |
+
data_root='data/',
|
289 |
+
data_mode='topdown',
|
290 |
+
ann_file='aic/annotations/aic_train.json',
|
291 |
+
data_prefix=dict(
|
292 |
+
img=
|
293 |
+
'pose/ai_challenge/ai_challenger_keypoint_train_20170902/keypoint_train_images_20170902/'
|
294 |
+
),
|
295 |
+
pipeline=[
|
296 |
+
dict(
|
297 |
+
type='KeypointConverter',
|
298 |
+
num_keypoints=17,
|
299 |
+
mapping=[(0, 6), (1, 8), (2, 10), (3, 5), (4, 7),
|
300 |
+
(5, 9), (6, 12), (7, 14), (8, 16), (9, 11),
|
301 |
+
(10, 13), (11, 15)])
|
302 |
+
])
|
303 |
+
],
|
304 |
+
pipeline=[
|
305 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
306 |
+
dict(type='GetBBoxCenterScale'),
|
307 |
+
dict(type='RandomFlip', direction='horizontal'),
|
308 |
+
dict(type='RandomHalfBody'),
|
309 |
+
dict(
|
310 |
+
type='RandomBBoxTransform',
|
311 |
+
scale_factor=[0.6, 1.4],
|
312 |
+
rotate_factor=80),
|
313 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
314 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
315 |
+
dict(
|
316 |
+
type='Albumentation',
|
317 |
+
transforms=[
|
318 |
+
dict(type='Blur', p=0.1),
|
319 |
+
dict(type='MedianBlur', p=0.1),
|
320 |
+
dict(
|
321 |
+
type='CoarseDropout',
|
322 |
+
max_holes=1,
|
323 |
+
max_height=0.4,
|
324 |
+
max_width=0.4,
|
325 |
+
min_holes=1,
|
326 |
+
min_height=0.2,
|
327 |
+
min_width=0.2,
|
328 |
+
p=1.0)
|
329 |
+
]),
|
330 |
+
dict(
|
331 |
+
type='GenerateTarget',
|
332 |
+
encoder=dict(
|
333 |
+
type='SimCCLabel',
|
334 |
+
input_size=(192, 256),
|
335 |
+
sigma=(4.9, 5.66),
|
336 |
+
simcc_split_ratio=2.0,
|
337 |
+
normalize=False,
|
338 |
+
use_dark=False)),
|
339 |
+
dict(type='PackPoseInputs')
|
340 |
+
],
|
341 |
+
test_mode=False))
|
342 |
+
val_dataloader = dict(
|
343 |
+
batch_size=64,
|
344 |
+
num_workers=10,
|
345 |
+
persistent_workers=True,
|
346 |
+
drop_last=False,
|
347 |
+
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
|
348 |
+
dataset=dict(
|
349 |
+
type='CocoDataset',
|
350 |
+
data_root='data/',
|
351 |
+
data_mode='topdown',
|
352 |
+
ann_file='coco/annotations/person_keypoints_val2017.json',
|
353 |
+
data_prefix=dict(img='detection/coco/val2017/'),
|
354 |
+
test_mode=True,
|
355 |
+
pipeline=[
|
356 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
357 |
+
dict(type='GetBBoxCenterScale'),
|
358 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
359 |
+
dict(type='PackPoseInputs')
|
360 |
+
]))
|
361 |
+
test_dataloader = dict(
|
362 |
+
batch_size=64,
|
363 |
+
num_workers=10,
|
364 |
+
persistent_workers=True,
|
365 |
+
drop_last=False,
|
366 |
+
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
|
367 |
+
dataset=dict(
|
368 |
+
type='CocoDataset',
|
369 |
+
data_root='data/',
|
370 |
+
data_mode='topdown',
|
371 |
+
ann_file='coco/annotations/person_keypoints_val2017.json',
|
372 |
+
data_prefix=dict(img='detection/coco/val2017/'),
|
373 |
+
test_mode=True,
|
374 |
+
pipeline=[
|
375 |
+
dict(type='LoadImage', backend_args=dict(backend='local')),
|
376 |
+
dict(type='GetBBoxCenterScale'),
|
377 |
+
dict(type='TopdownAffine', input_size=(192, 256)),
|
378 |
+
dict(type='PackPoseInputs')
|
379 |
+
]))
|
380 |
+
val_evaluator = dict(
|
381 |
+
type='CocoMetric',
|
382 |
+
ann_file='data/coco/annotations/person_keypoints_val2017.json')
|
383 |
+
test_evaluator = dict(
|
384 |
+
type='CocoMetric',
|
385 |
+
ann_file='data/coco/annotations/person_keypoints_val2017.json')
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
openmim
|
2 |
+
torch
|
3 |
+
mmengine
|
4 |
+
mmcv
|
5 |
+
mmdet
|
6 |
+
mmpose
|
7 |
+
mmdeploy
|
8 |
+
onnxruntime
|
9 |
+
tqdm
|
10 |
+
scikit-image
|
11 |
+
easydict
|
12 |
+
gradio
|
tools/apis.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from mmengine.registry import MODELS
|
3 |
+
from mmengine.dataset import Compose, pseudo_collate
|
4 |
+
from mmengine.model.utils import revert_sync_batchnorm
|
5 |
+
from mmengine.registry import init_default_scope
|
6 |
+
from mmengine.runner import load_checkpoint
|
7 |
+
from mmengine.config import Config
|
8 |
+
|
9 |
+
from mmdeploy.utils import get_input_shape, load_config
|
10 |
+
from mmdeploy.apis.utils import build_task_processor
|
11 |
+
|
12 |
+
def build_model(cfg, checkpoint=None, device='cpu'):
|
13 |
+
""" Build model from config and load checkpoint
|
14 |
+
checkpoint_meta usually contains dataset classes information
|
15 |
+
"""
|
16 |
+
if isinstance(cfg, str):
|
17 |
+
cfg = Config.fromfile(cfg)
|
18 |
+
# scope of model, e.g. mmdet, mmseg, mmpose...
|
19 |
+
init_default_scope(cfg.default_scope)
|
20 |
+
model = MODELS.build(cfg.model)
|
21 |
+
model = revert_sync_batchnorm(model)
|
22 |
+
if checkpoint is not None:
|
23 |
+
ckpt = load_checkpoint(model, checkpoint,
|
24 |
+
map_location='cpu')
|
25 |
+
checkpoint_meta = ckpt.get('meta', {})
|
26 |
+
# usually classes and pallate are in checkpoint_meta
|
27 |
+
model.checkpoint_meta = checkpoint_meta
|
28 |
+
model.to(device)
|
29 |
+
model.eval()
|
30 |
+
return model
|
31 |
+
|
32 |
+
def inference(model, cfg, img):
|
33 |
+
""" Given model, config and image, return inference results.
|
34 |
+
Models in mmlab does not share the same inference api. So this
|
35 |
+
function is just a memo for me...
|
36 |
+
"""
|
37 |
+
if isinstance(cfg, str):
|
38 |
+
cfg = Config.fromfile(cfg)
|
39 |
+
# process pipline
|
40 |
+
test_pipeline = cfg.test_dataloader.dataset.pipeline
|
41 |
+
# Use 'LoadImage' to handle both cases of img and img_path
|
42 |
+
# This is specially designed for mmdet config, which uses 'LoadImageFromFile'
|
43 |
+
for pipeline in test_pipeline:
|
44 |
+
if 'LoadImage' in pipeline['type']:
|
45 |
+
pipeline['type'] = 'mmpose.LoadImage'
|
46 |
+
|
47 |
+
init_default_scope(cfg.default_scope)
|
48 |
+
pipeline = Compose(test_pipeline)
|
49 |
+
|
50 |
+
if isinstance(img, str):
|
51 |
+
# img_id is useless...but to be compatible with mmdet
|
52 |
+
data_info = dict(img_path=img, img_id=0)
|
53 |
+
else:
|
54 |
+
data_info = dict(img=img, img_id=0)
|
55 |
+
|
56 |
+
data = pipeline(data_info)
|
57 |
+
batch = pseudo_collate([data])
|
58 |
+
|
59 |
+
with torch.no_grad():
|
60 |
+
results = model.test_step(batch)
|
61 |
+
|
62 |
+
return results
|
63 |
+
|
64 |
+
def build_onnx_model_and_task_processor(model_cfg, deploy_cfg, backend_files, device):
|
65 |
+
|
66 |
+
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
67 |
+
|
68 |
+
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
69 |
+
|
70 |
+
model = task_processor.build_backend_model(
|
71 |
+
backend_files, task_processor.update_data_preprocessor)
|
72 |
+
|
73 |
+
return model, task_processor
|
74 |
+
|
75 |
+
def inference_onnx_model(model, task_processor, deploy_cfg, img):
|
76 |
+
input_shape = get_input_shape(deploy_cfg)
|
77 |
+
model_inputs, _ = task_processor.create_input(img, input_shape)
|
78 |
+
|
79 |
+
with torch.no_grad():
|
80 |
+
result = model.test_step(model_inputs)
|
81 |
+
|
82 |
+
return result
|
83 |
+
|
84 |
+
if __name__ == '__main__':
|
85 |
+
config = '/github/Tennis.ai/model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-t_8xb256-420e_aic-coco-256x192.py'
|
86 |
+
ckpt = '/github/Tennis.ai/model_zoo/rtmpose/rtmpose-t_8xb256-420e_aic-coco-256x192/rtmpose-tiny_simcc-aic-coco_pt-aic-coco_420e-256x192-cfc8f33d_20230126.pth'
|
87 |
+
img = '/github/Tennis.ai/assets/000000197388.jpg'
|
88 |
+
|
89 |
+
detector = build_model(config, checkpoint=ckpt)
|
90 |
+
result = inference(detector, config, img)
|
tools/deploy.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
# Modified from mmdeploy/tools/deploy.py, removed some codes to only focus on ONNX report
|
3 |
+
import argparse
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import os.path as osp
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import mmengine
|
10 |
+
import torch.multiprocessing as mp
|
11 |
+
from torch.multiprocessing import Process, set_start_method
|
12 |
+
|
13 |
+
from mmdeploy.apis import (create_calib_input_data, extract_model,
|
14 |
+
get_predefined_partition_cfg, torch2onnx,
|
15 |
+
torch2torchscript, visualize_model)
|
16 |
+
from mmdeploy.apis.core import PIPELINE_MANAGER
|
17 |
+
from mmdeploy.apis.utils import to_backend
|
18 |
+
from mmdeploy.backend.sdk.export_info import export2SDK
|
19 |
+
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
|
20 |
+
get_ir_config, get_partition_config,
|
21 |
+
get_root_logger, load_config, target_wrapper)
|
22 |
+
|
23 |
+
|
24 |
+
def parse_args():
|
25 |
+
parser = argparse.ArgumentParser(description='Export model to backends.')
|
26 |
+
parser.add_argument('deploy_cfg', help='deploy config path')
|
27 |
+
parser.add_argument('model_cfg', help='model config path')
|
28 |
+
parser.add_argument('checkpoint', help='model checkpoint path')
|
29 |
+
parser.add_argument('img', help='image used to convert model model')
|
30 |
+
parser.add_argument(
|
31 |
+
'--test-img',
|
32 |
+
default=None,
|
33 |
+
type=str,
|
34 |
+
nargs='+',
|
35 |
+
help='image used to test model')
|
36 |
+
parser.add_argument(
|
37 |
+
'--work-dir',
|
38 |
+
default=os.getcwd(),
|
39 |
+
help='the dir to save logs and models')
|
40 |
+
parser.add_argument(
|
41 |
+
'--calib-dataset-cfg',
|
42 |
+
help='dataset config path used to calibrate in int8 mode. If not \
|
43 |
+
specified, it will use "val" dataset in model config instead.',
|
44 |
+
default=None)
|
45 |
+
parser.add_argument(
|
46 |
+
'--device', help='device used for conversion', default='cpu')
|
47 |
+
parser.add_argument(
|
48 |
+
'--log-level',
|
49 |
+
help='set log level',
|
50 |
+
default='INFO',
|
51 |
+
choices=list(logging._nameToLevel.keys()))
|
52 |
+
parser.add_argument(
|
53 |
+
'--show', action='store_true', help='Show detection outputs')
|
54 |
+
parser.add_argument(
|
55 |
+
'--dump-info', action='store_true', help='Output information for SDK')
|
56 |
+
parser.add_argument(
|
57 |
+
'--quant-image-dir',
|
58 |
+
default=None,
|
59 |
+
help='Image directory for quantize model.')
|
60 |
+
parser.add_argument(
|
61 |
+
'--quant', action='store_true', help='Quantize model to low bit.')
|
62 |
+
parser.add_argument(
|
63 |
+
'--uri',
|
64 |
+
default='192.168.1.1:60000',
|
65 |
+
help='Remote ipv4:port or ipv6:port for inference on edge device.')
|
66 |
+
args = parser.parse_args()
|
67 |
+
return args
|
68 |
+
|
69 |
+
|
70 |
+
def create_process(name, target, args, kwargs, ret_value=None):
|
71 |
+
logger = get_root_logger()
|
72 |
+
logger.info(f'{name} start.')
|
73 |
+
log_level = logger.level
|
74 |
+
|
75 |
+
wrap_func = partial(target_wrapper, target, log_level, ret_value)
|
76 |
+
|
77 |
+
process = Process(target=wrap_func, args=args, kwargs=kwargs)
|
78 |
+
process.start()
|
79 |
+
process.join()
|
80 |
+
|
81 |
+
if ret_value is not None:
|
82 |
+
if ret_value.value != 0:
|
83 |
+
logger.error(f'{name} failed.')
|
84 |
+
exit(1)
|
85 |
+
else:
|
86 |
+
logger.info(f'{name} success.')
|
87 |
+
|
88 |
+
|
89 |
+
def torch2ir(ir_type: IR):
|
90 |
+
"""Return the conversion function from torch to the intermediate
|
91 |
+
representation.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
ir_type (IR): The type of the intermediate representation.
|
95 |
+
"""
|
96 |
+
if ir_type == IR.ONNX:
|
97 |
+
return torch2onnx
|
98 |
+
elif ir_type == IR.TORCHSCRIPT:
|
99 |
+
return torch2torchscript
|
100 |
+
else:
|
101 |
+
raise KeyError(f'Unexpected IR type {ir_type}')
|
102 |
+
|
103 |
+
|
104 |
+
def main():
|
105 |
+
args = parse_args()
|
106 |
+
set_start_method('spawn', force=True)
|
107 |
+
logger = get_root_logger()
|
108 |
+
log_level = logging.getLevelName(args.log_level)
|
109 |
+
logger.setLevel(log_level)
|
110 |
+
|
111 |
+
pipeline_funcs = [
|
112 |
+
torch2onnx, torch2torchscript, extract_model, create_calib_input_data
|
113 |
+
]
|
114 |
+
PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
|
115 |
+
PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)
|
116 |
+
|
117 |
+
deploy_cfg_path = args.deploy_cfg
|
118 |
+
model_cfg_path = args.model_cfg
|
119 |
+
checkpoint_path = args.checkpoint
|
120 |
+
quant = args.quant
|
121 |
+
quant_image_dir = args.quant_image_dir
|
122 |
+
|
123 |
+
# load deploy_cfg
|
124 |
+
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
|
125 |
+
|
126 |
+
# create work_dir if not
|
127 |
+
mmengine.mkdir_or_exist(osp.abspath(args.work_dir))
|
128 |
+
|
129 |
+
if args.dump_info:
|
130 |
+
export2SDK(
|
131 |
+
deploy_cfg,
|
132 |
+
model_cfg,
|
133 |
+
args.work_dir,
|
134 |
+
pth=checkpoint_path,
|
135 |
+
device=args.device)
|
136 |
+
|
137 |
+
ret_value = mp.Value('d', 0, lock=False)
|
138 |
+
|
139 |
+
# convert to IR
|
140 |
+
ir_config = get_ir_config(deploy_cfg)
|
141 |
+
ir_save_file = ir_config['save_file']
|
142 |
+
ir_type = IR.get(ir_config['type'])
|
143 |
+
torch2ir(ir_type)(
|
144 |
+
args.img,
|
145 |
+
args.work_dir,
|
146 |
+
ir_save_file,
|
147 |
+
deploy_cfg_path,
|
148 |
+
model_cfg_path,
|
149 |
+
checkpoint_path,
|
150 |
+
device=args.device)
|
151 |
+
|
152 |
+
# convert backend
|
153 |
+
ir_files = [osp.join(args.work_dir, ir_save_file)]
|
154 |
+
|
155 |
+
# partition model
|
156 |
+
partition_cfgs = get_partition_config(deploy_cfg)
|
157 |
+
|
158 |
+
if partition_cfgs is not None:
|
159 |
+
|
160 |
+
if 'partition_cfg' in partition_cfgs:
|
161 |
+
partition_cfgs = partition_cfgs.get('partition_cfg', None)
|
162 |
+
else:
|
163 |
+
assert 'type' in partition_cfgs
|
164 |
+
partition_cfgs = get_predefined_partition_cfg(
|
165 |
+
deploy_cfg, partition_cfgs['type'])
|
166 |
+
|
167 |
+
origin_ir_file = ir_files[0]
|
168 |
+
ir_files = []
|
169 |
+
for partition_cfg in partition_cfgs:
|
170 |
+
save_file = partition_cfg['save_file']
|
171 |
+
save_path = osp.join(args.work_dir, save_file)
|
172 |
+
start = partition_cfg['start']
|
173 |
+
end = partition_cfg['end']
|
174 |
+
dynamic_axes = partition_cfg.get('dynamic_axes', None)
|
175 |
+
|
176 |
+
extract_model(
|
177 |
+
origin_ir_file,
|
178 |
+
start,
|
179 |
+
end,
|
180 |
+
dynamic_axes=dynamic_axes,
|
181 |
+
save_file=save_path)
|
182 |
+
|
183 |
+
ir_files.append(save_path)
|
184 |
+
|
185 |
+
backend_files = ir_files
|
186 |
+
# convert backend
|
187 |
+
backend = get_backend(deploy_cfg)
|
188 |
+
|
189 |
+
# convert to backend
|
190 |
+
PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
|
191 |
+
if backend == Backend.TENSORRT:
|
192 |
+
PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
|
193 |
+
backend_files = to_backend(
|
194 |
+
backend,
|
195 |
+
ir_files,
|
196 |
+
work_dir=args.work_dir,
|
197 |
+
deploy_cfg=deploy_cfg,
|
198 |
+
log_level=log_level,
|
199 |
+
device=args.device,
|
200 |
+
uri=args.uri)
|
201 |
+
|
202 |
+
if args.test_img is None:
|
203 |
+
args.test_img = args.img
|
204 |
+
|
205 |
+
extra = dict(
|
206 |
+
backend=backend,
|
207 |
+
output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
|
208 |
+
show_result=args.show)
|
209 |
+
if backend == Backend.SNPE:
|
210 |
+
extra['uri'] = args.uri
|
211 |
+
|
212 |
+
# get backend inference result, try render
|
213 |
+
create_process(
|
214 |
+
f'visualize {backend.value} model',
|
215 |
+
target=visualize_model,
|
216 |
+
args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
|
217 |
+
args.device),
|
218 |
+
kwargs=extra,
|
219 |
+
ret_value=ret_value)
|
220 |
+
|
221 |
+
# get pytorch model inference result, try visualize if possible
|
222 |
+
create_process(
|
223 |
+
'visualize pytorch model',
|
224 |
+
target=visualize_model,
|
225 |
+
args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
|
226 |
+
args.test_img, args.device),
|
227 |
+
kwargs=dict(
|
228 |
+
backend=Backend.PYTORCH,
|
229 |
+
output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
|
230 |
+
show_result=args.show),
|
231 |
+
ret_value=ret_value)
|
232 |
+
logger.info('All process success.')
|
233 |
+
|
234 |
+
|
235 |
+
if __name__ == '__main__':
|
236 |
+
main()
|
tools/dtw.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from .utils import get_keypoint_weight
|
3 |
+
|
4 |
+
|
5 |
+
class DTWForKeypoints:
|
6 |
+
def __init__(self, keypoints1, keypoints2):
|
7 |
+
self.keypoints1 = keypoints1
|
8 |
+
self.keypoints2 = keypoints2
|
9 |
+
|
10 |
+
def get_dtw_path(self):
|
11 |
+
|
12 |
+
norm_kp1 = self.normalize_keypoints(self.keypoints1)
|
13 |
+
norm_kp2 = self.normalize_keypoints(self.keypoints2)
|
14 |
+
|
15 |
+
kp_weight = get_keypoint_weight()
|
16 |
+
oks, oks_unnorm = self.object_keypoint_similarity(norm_kp1,
|
17 |
+
norm_kp2, keypoint_weights=kp_weight)
|
18 |
+
print(f"OKS max {oks.max():.2f} min {oks.min():.2f}")
|
19 |
+
|
20 |
+
# do the DTW, and return the path
|
21 |
+
cost_matrix = 1 - oks
|
22 |
+
dtw_dist, dtw_path = self.dynamic_time_warp(cost_matrix)
|
23 |
+
|
24 |
+
return dtw_path, oks, oks_unnorm
|
25 |
+
|
26 |
+
def normalize_keypoints(self, keypoints):
|
27 |
+
centroid = keypoints.mean(axis=1)[:, None]
|
28 |
+
max_distance = np.max(np.sqrt(np.sum((keypoints - centroid) ** 2, axis=2)),
|
29 |
+
axis=1) + 1e-6
|
30 |
+
|
31 |
+
normalized_keypoints = (keypoints - centroid) / max_distance[:, None, None]
|
32 |
+
return normalized_keypoints
|
33 |
+
|
34 |
+
def keypoints_areas(self, keypoints):
|
35 |
+
min_coords = np.min(keypoints, axis=1)
|
36 |
+
max_coords = np.max(keypoints, axis=1)
|
37 |
+
areas = np.prod(max_coords - min_coords, axis=1)
|
38 |
+
return areas
|
39 |
+
|
40 |
+
def object_keypoint_similarity(self, keypoints1,
|
41 |
+
keypoints2,
|
42 |
+
scale_constant=0.2,
|
43 |
+
keypoint_weights=None):
|
44 |
+
""" Calculate the Object Keypoint Similarity (OKS) for multiple objects,
|
45 |
+
and add weight to each keypoint. Here we choose to normalize the points
|
46 |
+
using centroid and max distance instead of bounding box area.
|
47 |
+
"""
|
48 |
+
# Compute squared distances between all pairs of keypoints
|
49 |
+
sq_diff = np.sum((keypoints1[:, None] - keypoints2) ** 2, axis=-1)
|
50 |
+
|
51 |
+
oks = np.exp(-sq_diff / (2 * scale_constant ** 2))
|
52 |
+
oks_unnorm = oks.copy()
|
53 |
+
|
54 |
+
if keypoint_weights is not None:
|
55 |
+
oks = oks * keypoint_weights
|
56 |
+
oks = np.sum(oks, axis=-1)
|
57 |
+
else:
|
58 |
+
oks = np.mean(oks, axis=-1)
|
59 |
+
|
60 |
+
return oks, oks_unnorm
|
61 |
+
|
62 |
+
def dynamic_time_warp(self, cost_matrix, R=1000):
|
63 |
+
"""Compute the Dynamic Time Warping distance and path between two time series.
|
64 |
+
If the time series is too long, it will use the Sakoe-Chiba Band constraint,
|
65 |
+
so time complexity is bounded at O(MR).
|
66 |
+
"""
|
67 |
+
|
68 |
+
M = len(self.keypoints1)
|
69 |
+
N = len(self.keypoints2)
|
70 |
+
|
71 |
+
# Initialize the distance matrix with infinity
|
72 |
+
D = np.full((M, N), np.inf)
|
73 |
+
|
74 |
+
# Initialize the first row and column of the matrix
|
75 |
+
D[0, 0] = cost_matrix[0, 0]
|
76 |
+
for i in range(1, M):
|
77 |
+
D[i, 0] = D[i - 1, 0] + cost_matrix[i, 0]
|
78 |
+
|
79 |
+
for j in range(1, N):
|
80 |
+
D[0, j] = D[0, j - 1] + cost_matrix[0, j]
|
81 |
+
|
82 |
+
# Fill the remaining elements of the matrix within the
|
83 |
+
# Sakoe-Chiba Band using dynamic programming
|
84 |
+
for i in range(1, M):
|
85 |
+
for j in range(max(1, i - R), min(N, i + R + 1)):
|
86 |
+
cost = cost_matrix[i, j]
|
87 |
+
D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
|
88 |
+
|
89 |
+
# Backtrack to find the optimal path
|
90 |
+
path = [(M - 1, N - 1)]
|
91 |
+
i, j = M - 1, N - 1
|
92 |
+
while i > 0 or j > 0:
|
93 |
+
min_idx = np.argmin([D[i - 1, j], D[i, j - 1], D[i - 1, j - 1]])
|
94 |
+
if min_idx == 0:
|
95 |
+
i -= 1
|
96 |
+
elif min_idx == 1:
|
97 |
+
j -= 1
|
98 |
+
else:
|
99 |
+
i -= 1
|
100 |
+
j -= 1
|
101 |
+
path.append((i, j))
|
102 |
+
path.reverse()
|
103 |
+
|
104 |
+
return D[-1, -1], path
|
105 |
+
|
106 |
+
if __name__ == '__main__':
|
107 |
+
|
108 |
+
from mmengine.fileio import load
|
109 |
+
|
110 |
+
keypoints1, kp1_scores = load('tennis1.pkl')
|
111 |
+
keypoints2, kp2_scores = load('tennis3.pkl')
|
112 |
+
|
113 |
+
# Normalize the keypoints
|
114 |
+
dtw = DTWForKeypoints(keypoints1, keypoints2)
|
115 |
+
path = dtw.get_dtw_path()
|
116 |
+
print(path)
|
tools/inferencer.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import mmcv
|
3 |
+
from pathlib import Path
|
4 |
+
from collections import namedtuple
|
5 |
+
import cv2 as cv
|
6 |
+
from tqdm import tqdm
|
7 |
+
from mmengine.registry import init_default_scope
|
8 |
+
from mmengine.visualization import Visualizer
|
9 |
+
from mmpose.apis import inference_topdown, init_model
|
10 |
+
from mmdet.apis import inference_detector, init_detector
|
11 |
+
from .utils import filter_by_catgory, filter_by_score, Timer
|
12 |
+
from .apis import build_onnx_model_and_task_processor, inference_onnx_model
|
13 |
+
|
14 |
+
|
15 |
+
class PoseInferencer:
|
16 |
+
def __init__(self,
|
17 |
+
det_cfg,
|
18 |
+
pose_cfg,
|
19 |
+
device='cpu') -> None:
|
20 |
+
# init
|
21 |
+
self.det_model_cfg = det_cfg.model_cfg
|
22 |
+
self.det_model_ckpt = det_cfg.model_ckpt
|
23 |
+
self.pose_model_cfg = pose_cfg.model_cfg
|
24 |
+
self.pose_model_ckpt = pose_cfg.model_ckpt
|
25 |
+
|
26 |
+
self.detector = init_detector(self.det_model_cfg,
|
27 |
+
self.det_model_ckpt,
|
28 |
+
device=device)
|
29 |
+
self.pose_model = init_model(self.pose_model_cfg,
|
30 |
+
self.pose_model_ckpt,
|
31 |
+
device=device)
|
32 |
+
|
33 |
+
def process_one_image(self, img):
|
34 |
+
init_default_scope('mmdet')
|
35 |
+
det_result = inference_detector(self.detector, img)
|
36 |
+
det_inst = det_result.pred_instances.cpu().numpy()
|
37 |
+
bboxes, scores, labels = (det_inst.bboxes,
|
38 |
+
det_inst.scores,
|
39 |
+
det_inst.labels)
|
40 |
+
bboxes, scores, labels = filter_by_score(bboxes, scores,
|
41 |
+
labels, 0.5)
|
42 |
+
bboxes, scores, labels = filter_by_catgory(bboxes, scores, labels,
|
43 |
+
['person'])
|
44 |
+
# inference with pose model
|
45 |
+
init_default_scope('mmpose')
|
46 |
+
pose_result = inference_topdown(self.pose_model, img, bboxes)
|
47 |
+
if len(pose_result) == 0:
|
48 |
+
# no detection place holder
|
49 |
+
keypoints = np.zeros((1, 17, 2))
|
50 |
+
pts_scores = np.zeros((1, 17))
|
51 |
+
bboxes = np.zeros((1, 4))
|
52 |
+
scores = np.zeros((1, ))
|
53 |
+
labels = np.zeros((1, ))
|
54 |
+
else:
|
55 |
+
keypoints = np.concatenate([r.pred_instances.keypoints
|
56 |
+
for r in pose_result])
|
57 |
+
pts_scores = np.concatenate([r.pred_instances.keypoint_scores
|
58 |
+
for r in pose_result])
|
59 |
+
|
60 |
+
DetInst = namedtuple('DetInst', ['bboxes', 'scores', 'labels'])
|
61 |
+
PoseInst = namedtuple('PoseInst', ['keypoints', 'pts_scores'])
|
62 |
+
return DetInst(bboxes, scores, labels), PoseInst(keypoints, pts_scores)
|
63 |
+
|
64 |
+
def inference_video(self, video_path):
|
65 |
+
""" Inference a video with detector and pose model
|
66 |
+
Return:
|
67 |
+
all_pose: a list of PoseInst, check the namedtuple definition
|
68 |
+
all_det: a list of DetInst
|
69 |
+
"""
|
70 |
+
video_reader = mmcv.VideoReader(video_path)
|
71 |
+
all_pose, all_det = [], []
|
72 |
+
|
73 |
+
for frame in tqdm(video_reader):
|
74 |
+
# inference with detector
|
75 |
+
det, pose = self.process_one_image(frame)
|
76 |
+
all_pose.append(pose)
|
77 |
+
all_det.append(det)
|
78 |
+
|
79 |
+
return all_det, all_pose
|
80 |
+
|
81 |
+
class PoseInferencerV2:
|
82 |
+
""" V2 Use onnx for detection model, still use pytorch for pose model.
|
83 |
+
"""
|
84 |
+
def __init__(self,
|
85 |
+
det_cfg,
|
86 |
+
pose_cfg,
|
87 |
+
device='cpu') -> None:
|
88 |
+
# init
|
89 |
+
self.det_deploy_cfg = det_cfg.deploy_cfg
|
90 |
+
self.det_model_cfg = det_cfg.model_cfg
|
91 |
+
self.det_backend_files = det_cfg.backend_files
|
92 |
+
|
93 |
+
self.pose_model_cfg = pose_cfg.model_cfg
|
94 |
+
self.pose_model_ckpt = pose_cfg.model_ckpt
|
95 |
+
|
96 |
+
self.detector, self.task_processor = \
|
97 |
+
build_onnx_model_and_task_processor(self.det_model_cfg,
|
98 |
+
self.det_deploy_cfg,
|
99 |
+
self.det_backend_files,
|
100 |
+
device)
|
101 |
+
self.pose_model = init_model(self.pose_model_cfg,
|
102 |
+
self.pose_model_ckpt,
|
103 |
+
device)
|
104 |
+
|
105 |
+
def process_one_image(self, img):
|
106 |
+
init_default_scope('mmdet')
|
107 |
+
det_result = inference_onnx_model(self.detector,
|
108 |
+
self.task_processor,
|
109 |
+
self.det_deploy_cfg,
|
110 |
+
img)
|
111 |
+
det_inst = det_result[0].pred_instances.cpu().numpy()
|
112 |
+
bboxes, scores, labels = (det_inst.bboxes,
|
113 |
+
det_inst.scores,
|
114 |
+
det_inst.labels)
|
115 |
+
bboxes, scores, labels = filter_by_score(bboxes, scores,
|
116 |
+
labels, 0.5)
|
117 |
+
bboxes, scores, labels = filter_by_catgory(bboxes, scores, labels,
|
118 |
+
['person'])
|
119 |
+
# inference with pose model
|
120 |
+
init_default_scope('mmpose')
|
121 |
+
pose_result = inference_topdown(self.pose_model, img, bboxes)
|
122 |
+
if len(pose_result) == 0:
|
123 |
+
# no detection place holder
|
124 |
+
keypoints = np.zeros((1, 17, 2))
|
125 |
+
pts_scores = np.zeros((1, 17))
|
126 |
+
bboxes = np.zeros((1, 4))
|
127 |
+
scores = np.zeros((1, ))
|
128 |
+
labels = np.zeros((1, ))
|
129 |
+
else:
|
130 |
+
keypoints = np.concatenate([r.pred_instances.keypoints
|
131 |
+
for r in pose_result])
|
132 |
+
pts_scores = np.concatenate([r.pred_instances.keypoint_scores
|
133 |
+
for r in pose_result])
|
134 |
+
|
135 |
+
DetInst = namedtuple('DetInst', ['bboxes', 'scores', 'labels'])
|
136 |
+
PoseInst = namedtuple('PoseInst', ['keypoints', 'pts_scores'])
|
137 |
+
return DetInst(bboxes, scores, labels), PoseInst(keypoints, pts_scores)
|
138 |
+
|
139 |
+
def inference_video(self, video_path):
|
140 |
+
""" Inference a video with detector and pose model
|
141 |
+
Return:
|
142 |
+
all_pose: a list of PoseInst, check the namedtuple definition
|
143 |
+
all_det: a list of DetInst
|
144 |
+
"""
|
145 |
+
video_reader = mmcv.VideoReader(video_path)
|
146 |
+
all_pose, all_det = [], []
|
147 |
+
|
148 |
+
for frame in tqdm(video_reader):
|
149 |
+
# inference with detector
|
150 |
+
det, pose = self.process_one_image(frame)
|
151 |
+
all_pose.append(pose)
|
152 |
+
all_det.append(det)
|
153 |
+
|
154 |
+
return all_det, all_pose
|
tools/manager.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mim
|
2 |
+
from pathlib import Path
|
3 |
+
from mim.utils import get_installed_path, echo_success
|
4 |
+
from mmengine.config import Config
|
5 |
+
|
6 |
+
class Manager:
|
7 |
+
|
8 |
+
def __init__(self, path=None) -> None:
|
9 |
+
"""
|
10 |
+
Params:
|
11 |
+
- path: root path of projects to save checkpoints and configs
|
12 |
+
"""
|
13 |
+
if path:
|
14 |
+
self.path = Path(path)
|
15 |
+
else:
|
16 |
+
self.path = Path(__file__).parents[1]
|
17 |
+
self.keys = ['weight', 'config', 'model', 'training_data']
|
18 |
+
|
19 |
+
def get_model_infos(self, package_name, keyword: str=None):
|
20 |
+
""" because mim search is too strict,
|
21 |
+
I want to search by keyword, not a strict match
|
22 |
+
"""
|
23 |
+
model_infos = mim.get_model_info(package_name)
|
24 |
+
model_names = model_infos.index
|
25 |
+
info_keys = model_infos.columns.tolist()
|
26 |
+
keys = self.intersect_keys(info_keys,
|
27 |
+
self.keys)
|
28 |
+
if keyword is None:
|
29 |
+
return model_infos[:, keys]
|
30 |
+
# get valid names, which contains the keyword
|
31 |
+
valid_names = [name for name in model_names
|
32 |
+
if keyword in name]
|
33 |
+
filter_infos = model_infos.loc[valid_names, keys]
|
34 |
+
return filter_infos
|
35 |
+
|
36 |
+
def intersect_keys(self, keys1 , keys2):
|
37 |
+
return list(set(keys1) & set(keys2))
|
38 |
+
|
39 |
+
def download(self, package, model, config_only=False):
|
40 |
+
""" Use model names to download checkpoints and configs.
|
41 |
+
Args:
|
42 |
+
- package: package name, e.g. mmdet
|
43 |
+
- model: model name, e.g. faster_rcnn or faster_rcnn_r50_fpn_1x_coco
|
44 |
+
- config_only: only download configs, which is helpful when you
|
45 |
+
already download checkpoints fast through other ways.
|
46 |
+
"""
|
47 |
+
infos = self.get_model_infos(package, model)
|
48 |
+
|
49 |
+
for model, info in infos.iterrows():
|
50 |
+
# get destination path
|
51 |
+
hyper_name = info['model']
|
52 |
+
dst_path = self.path / 'model_zoo' / hyper_name / model
|
53 |
+
dst_path.mkdir(parents=True, exist_ok=True)
|
54 |
+
|
55 |
+
if config_only:
|
56 |
+
# get config path of the package
|
57 |
+
installed_path = Path(get_installed_path(package))
|
58 |
+
config_path = info['config']
|
59 |
+
config_path = installed_path / '.mim' / config_path
|
60 |
+
# build and dump config
|
61 |
+
config_obj = Config.fromfile(config_path)
|
62 |
+
saved_config_path = dst_path / f'{model}.py'
|
63 |
+
config_obj.dump(saved_config_path)
|
64 |
+
echo_success(
|
65 |
+
f'Successfully dumped {model}.py to {dst_path}')
|
66 |
+
else:
|
67 |
+
mim.download(package, [model], dest_root=dst_path)
|
68 |
+
|
69 |
+
if __name__ == '__main__':
|
70 |
+
m = Manager()
|
71 |
+
print(m.get_model_infos('mmdet', 'det'))
|
72 |
+
# m.download('mmpose', 'rtmpose-t_8xb256-420e_aic-coco-256x192', config_only=True)
|
tools/utils.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmdet.datasets import CocoDataset
|
2 |
+
import time
|
3 |
+
from pathlib import Path
|
4 |
+
from ffmpy import FFmpeg
|
5 |
+
import shutil
|
6 |
+
import tempfile
|
7 |
+
from easydict import EasyDict
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
def coco_keypoint_id_table(reverse=False):
|
11 |
+
id2name = { 0: 'nose',
|
12 |
+
1: 'left_eye',
|
13 |
+
2: 'right_eye',
|
14 |
+
3: 'left_ear',
|
15 |
+
4: 'right_ear',
|
16 |
+
5: 'left_shoulder',
|
17 |
+
6: 'right_shoulder',
|
18 |
+
7: 'left_elbow',
|
19 |
+
8: 'right_elbow',
|
20 |
+
9: 'left_wrist',
|
21 |
+
10: 'right_wrist',
|
22 |
+
11: 'left_hip',
|
23 |
+
12: 'right_hip',
|
24 |
+
13: 'left_knee',
|
25 |
+
14: 'right_knee',
|
26 |
+
15: 'left_ankle',
|
27 |
+
16: 'right_ankle'}
|
28 |
+
if reverse:
|
29 |
+
return {v: k for k, v in id2name.items()}
|
30 |
+
return id2name
|
31 |
+
|
32 |
+
def get_skeleton():
|
33 |
+
""" My skeleton links, I deleted some links from default coco style.
|
34 |
+
"""
|
35 |
+
SKELETON = EasyDict()
|
36 |
+
SKELETON.head = [[0,1], [0,2], [1,3], [2,4]]
|
37 |
+
SKELETON.left_arm = [[5, 7], [7, 9]]
|
38 |
+
SKELETON.right_arm = [[6, 8], [8, 10]]
|
39 |
+
SKELETON.left_leg = [[11, 13], [13, 15]]
|
40 |
+
SKELETON.right_leg = [[12, 14], [14, 16]]
|
41 |
+
SKELETON.body = [[5, 6], [5, 11], [6, 12], [11, 12]]
|
42 |
+
return SKELETON
|
43 |
+
|
44 |
+
def get_keypoint_weight(low_weight_ratio=0.1, mid_weight_ratio=0.5):
|
45 |
+
""" Get keypoint weight, used in object keypoint similarity,
|
46 |
+
`low_weight_names` are points I want to pay less attention.
|
47 |
+
"""
|
48 |
+
low_weight_names = ['nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear']
|
49 |
+
mid_weight_names = ['left_shoulder', 'right_shoulder', 'left_hip', 'right_hip']
|
50 |
+
|
51 |
+
logtis = np.ones(17)
|
52 |
+
name2id = coco_keypoint_id_table(reverse=True)
|
53 |
+
|
54 |
+
low_weight_id = [name2id[n] for n in low_weight_names]
|
55 |
+
mid_weight_id = [name2id[n] for n in mid_weight_names]
|
56 |
+
logtis[low_weight_id] = low_weight_ratio
|
57 |
+
logtis[mid_weight_id] = mid_weight_ratio
|
58 |
+
|
59 |
+
weights = logtis / np.sum(logtis)
|
60 |
+
return weights
|
61 |
+
|
62 |
+
def coco_cat_id_table():
|
63 |
+
classes = CocoDataset.METAINFO['classes']
|
64 |
+
id2name = {i: name for i, name in enumerate(classes)}
|
65 |
+
|
66 |
+
return id2name
|
67 |
+
|
68 |
+
def filter_by_catgory(bboxes, scores, labels, names):
|
69 |
+
""" Filter labels by classes
|
70 |
+
Args:
|
71 |
+
- labels: list of labels, each label is a dict
|
72 |
+
- classes: list of class names
|
73 |
+
"""
|
74 |
+
id2name = coco_cat_id_table()
|
75 |
+
# names of labels
|
76 |
+
label_names = [id2name[id] for id in labels]
|
77 |
+
# filter by class names
|
78 |
+
mask = np.isin(label_names, names)
|
79 |
+
return bboxes[mask], scores[mask], labels[mask]
|
80 |
+
|
81 |
+
def filter_by_score(bboxes, scores, labels, score_thr):
|
82 |
+
""" Filter bboxes by score threshold
|
83 |
+
Args:
|
84 |
+
- bboxes: list of bboxes, each bbox is a dict
|
85 |
+
- score_thr: score threshold
|
86 |
+
"""
|
87 |
+
mask = scores > score_thr
|
88 |
+
return bboxes[mask], scores[mask], labels[mask]
|
89 |
+
|
90 |
+
def convert_video_to_playable_mp4(video_path: str) -> str:
|
91 |
+
""" Copied from gradio
|
92 |
+
Convert the video to mp4. If something goes wrong return the original video.
|
93 |
+
"""
|
94 |
+
try:
|
95 |
+
output_path = Path(video_path).with_suffix(".mp4")
|
96 |
+
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
|
97 |
+
shutil.copy2(video_path, tmp_file.name)
|
98 |
+
# ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4
|
99 |
+
ff = FFmpeg(
|
100 |
+
inputs={str(tmp_file.name): None},
|
101 |
+
outputs={str(output_path): None},
|
102 |
+
global_options="-y -loglevel quiet",
|
103 |
+
)
|
104 |
+
ff.run()
|
105 |
+
except:
|
106 |
+
print(f"Error converting video to browser-playable format {str(e)}")
|
107 |
+
output_path = video_path
|
108 |
+
return str(output_path)
|
109 |
+
|
110 |
+
class Timer:
|
111 |
+
def __init__(self):
|
112 |
+
self.start_time = time.time()
|
113 |
+
|
114 |
+
def click(self):
|
115 |
+
used_time = time.time() - self.start_time
|
116 |
+
self.start_time = time.time()
|
117 |
+
return used_time
|
118 |
+
|
119 |
+
def start(self):
|
120 |
+
self.start_time = time.time()
|
tools/visualizer.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from skimage import draw, io
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
from easydict import EasyDict
|
6 |
+
from typing import Union
|
7 |
+
from .utils import get_skeleton, Timer
|
8 |
+
|
9 |
+
class FastVisualizer:
|
10 |
+
""" Use skimage to draw, which is much faster than matplotlib, and
|
11 |
+
more beatiful than opencv.😎
|
12 |
+
"""
|
13 |
+
# TODO: modify color input parameter
|
14 |
+
def __init__(self, image=None) -> None:
|
15 |
+
self.set_image(image)
|
16 |
+
self.colors = self.get_pallete()
|
17 |
+
self.skeleton = get_skeleton()
|
18 |
+
self.lvl_tresh = self.set_level([0.3, 0.6, 0.8])
|
19 |
+
|
20 |
+
def set_image(self, image: Union[str, np.ndarray]):
|
21 |
+
if isinstance(image, str):
|
22 |
+
self.image = cv2.imread(image)
|
23 |
+
elif isinstance(image, np.ndarray) or image is None:
|
24 |
+
self.image = image
|
25 |
+
else:
|
26 |
+
raise TypeError(f"Type {type(image)} is not supported")
|
27 |
+
|
28 |
+
def get_image(self):
|
29 |
+
return self.image
|
30 |
+
|
31 |
+
def draw_box(self, box_coord, color=(25, 113, 194), alpha=1.0):
|
32 |
+
""" Draw a box on the image
|
33 |
+
Args:
|
34 |
+
box_coord: a list of [xmin, ymin, xmax, ymax]
|
35 |
+
alpha: the alpha of the box
|
36 |
+
color: the edge color of the box
|
37 |
+
"""
|
38 |
+
xmin, ymin, xmax, ymax = box_coord
|
39 |
+
rr, cc = draw.rectangle_perimeter((ymin, xmin), (ymax, xmax))
|
40 |
+
draw.set_color(self.image, (rr, cc), color, alpha=alpha)
|
41 |
+
return self
|
42 |
+
|
43 |
+
def draw_rectangle(self, box_coord, color=(25, 113, 194), alpha=1.0):
|
44 |
+
xmin, ymin, xmax, ymax = box_coord
|
45 |
+
rr, cc = draw.rectangle((ymin, xmin), (ymax, xmax))
|
46 |
+
draw.set_color(self.image, (rr, cc), color, alpha=alpha)
|
47 |
+
return self
|
48 |
+
|
49 |
+
def draw_point(self, point_coord, radius=5, color=(25, 113, 194), alpha=1.0):
|
50 |
+
""" Coord in (x, y) format, but will be converted to (y, x)
|
51 |
+
"""
|
52 |
+
x, y = point_coord
|
53 |
+
rr, cc = draw.disk((y, x), radius=radius)
|
54 |
+
draw.set_color(self.image, (rr, cc), color, alpha=alpha)
|
55 |
+
return self
|
56 |
+
|
57 |
+
def draw_line(self, start_point, end_point, color=(25, 113, 194), alpha=1.0):
|
58 |
+
""" Not used, because I can't produce smooth line.
|
59 |
+
"""
|
60 |
+
cv2.line(self.image, start_point, end_point, color.tolist(), 2,
|
61 |
+
cv2.LINE_AA)
|
62 |
+
return self
|
63 |
+
|
64 |
+
def draw_line_aa(self, start_point, end_point, color=(25, 113, 194), alpha=1.0):
|
65 |
+
""" Not used, because I can't produce smooth line.
|
66 |
+
"""
|
67 |
+
x1, y1 = start_point
|
68 |
+
x2, y2 = end_point
|
69 |
+
rr, cc, val = draw.line_aa(y1, x1, y2, x2)
|
70 |
+
draw.set_color(self.image, (rr, cc), color, alpha=alpha)
|
71 |
+
return self
|
72 |
+
|
73 |
+
def draw_thick_line(self, start_point, end_point, thickness=1, color=(25, 113, 194), alpha=1.0):
|
74 |
+
""" Not used, because I can't produce smooth line.
|
75 |
+
"""
|
76 |
+
x1, y1 = start_point
|
77 |
+
x2, y2 = end_point
|
78 |
+
dx, dy = x2 - x1, y2 - y1
|
79 |
+
length = np.sqrt(dx * dx + dy * dy)
|
80 |
+
cos, sin = dx / length, dy / length
|
81 |
+
|
82 |
+
half_t = thickness / 2.0
|
83 |
+
# Calculate the polygon vertices
|
84 |
+
vertices_x = [x1 - half_t * sin, x1 + half_t * sin,
|
85 |
+
x2 + half_t * sin, x2 - half_t * sin]
|
86 |
+
vertices_y = [y1 + half_t * cos, y1 - half_t * cos,
|
87 |
+
y2 - half_t * cos, y2 + half_t * cos]
|
88 |
+
rr, cc = draw.polygon(vertices_y, vertices_x)
|
89 |
+
draw.set_color(self.image, (rr, cc), color, alpha)
|
90 |
+
|
91 |
+
return self
|
92 |
+
|
93 |
+
def draw_text(self, text, position,
|
94 |
+
font_path='assets/SmileySans/SmileySans-Oblique.ttf',
|
95 |
+
font_size=20,
|
96 |
+
text_color=(255, 255, 255)):
|
97 |
+
""" Position is the left top corner of the text
|
98 |
+
"""
|
99 |
+
# Convert the NumPy array to a PIL image
|
100 |
+
pil_image = Image.fromarray(np.uint8(self.image))
|
101 |
+
# Load the font (default is Arial)
|
102 |
+
font = ImageFont.truetype(font_path, font_size)
|
103 |
+
# Create a drawing object
|
104 |
+
draw = ImageDraw.Draw(pil_image)
|
105 |
+
# Add the text to the image
|
106 |
+
draw.text(position, text, font=font, fill=text_color)
|
107 |
+
# Convert the PIL image back to a NumPy array
|
108 |
+
result = np.array(pil_image)
|
109 |
+
|
110 |
+
self.image = result
|
111 |
+
return self
|
112 |
+
|
113 |
+
def xyhw_to_xyxy(self, box):
|
114 |
+
hw = box[2:]
|
115 |
+
x1y1 = box[:2] - hw / 2
|
116 |
+
x2y2 = box[:2] + hw / 2
|
117 |
+
return np.concatenate([x1y1, x2y2]).astype(np.int32)
|
118 |
+
|
119 |
+
def draw_line_in_discrete_style(self, start_point, end_point, size=2, sample_points=3,
|
120 |
+
color=(25, 113, 194), alpha=1.0):
|
121 |
+
""" When drawing continous line, it is super fuzzy, and I can't handle them
|
122 |
+
very well even tried OpneCV & PIL all kinds of ways. This is a workaround.
|
123 |
+
The discrete line will be represented with few sampled cubes along the line,
|
124 |
+
and it is exclusive with start & end points.
|
125 |
+
"""
|
126 |
+
# sample points
|
127 |
+
points = np.linspace(start_point, end_point, sample_points + 2)[1:-1]
|
128 |
+
for p in points:
|
129 |
+
rectangle_xyhw = np.array((p[0], p[1], size, size))
|
130 |
+
rectangle_xyxy = self.xyhw_to_xyxy(rectangle_xyhw)
|
131 |
+
self.draw_rectangle(rectangle_xyxy, color, alpha)
|
132 |
+
return self
|
133 |
+
|
134 |
+
def draw_human_keypoints(self, keypoints, scores=None, factor=20, draw_skeleton=False):
|
135 |
+
""" Draw skeleton on the image, and give different color according
|
136 |
+
to similarity scores.
|
137 |
+
"""
|
138 |
+
# get max length of skeleton
|
139 |
+
max_x, max_y = np.max(keypoints, axis=0)
|
140 |
+
min_x, min_y = np.min(keypoints, axis=0)
|
141 |
+
max_length = max(max_x - min_x, max_y - min_y)
|
142 |
+
if max_length < 1: return self
|
143 |
+
cube_size = max_length // factor
|
144 |
+
line_cube_size = cube_size // 2
|
145 |
+
# draw skeleton in discrete style
|
146 |
+
if draw_skeleton:
|
147 |
+
for key, links in self.skeleton.items():
|
148 |
+
links = np.array(links)
|
149 |
+
start_points = keypoints[links[:, 0]]
|
150 |
+
end_points = keypoints[links[:, 1]]
|
151 |
+
for s, e in zip(start_points, end_points):
|
152 |
+
self.draw_line_in_discrete_style(s, e, line_cube_size,
|
153 |
+
color=self.colors[key], alpha=0.9)
|
154 |
+
# draw points
|
155 |
+
if scores is None: # use vamos color
|
156 |
+
lvl_names = ['vamos'] * len(keypoints)
|
157 |
+
else: lvl_names = self.score_level_names(scores)
|
158 |
+
|
159 |
+
for idx, (point, lvl_name) in enumerate(zip(keypoints, lvl_names)):
|
160 |
+
if idx in set((1, 2, 3, 4)):
|
161 |
+
continue # do not draw eyes and years
|
162 |
+
rectangle_xyhw = np.array((point[0], point[1], cube_size, cube_size))
|
163 |
+
rectangle_xyxy = self.xyhw_to_xyxy(rectangle_xyhw)
|
164 |
+
self.draw_rectangle(rectangle_xyxy,
|
165 |
+
color=self.colors[lvl_name],
|
166 |
+
alpha=0.8)
|
167 |
+
return self
|
168 |
+
|
169 |
+
def draw_score_bar(self, score, factor=50, bar_ratio=7):
|
170 |
+
""" Draw a score bar on the left top of the image.
|
171 |
+
factor: the value of image longer edge divided by the bar height
|
172 |
+
bar_ratio: the ratio of bar width to bar height
|
173 |
+
"""
|
174 |
+
# calculate bar's height and width
|
175 |
+
long_edge = np.max(self.image.shape[:2])
|
176 |
+
short_edge = np.min(self.image.shape[:2])
|
177 |
+
bar_h = long_edge // factor
|
178 |
+
bar_w = bar_h * bar_ratio
|
179 |
+
if bar_w * 3 > short_edge:
|
180 |
+
# when the image width is not enough
|
181 |
+
bar_w = short_edge // 4
|
182 |
+
bar_h = bar_w // bar_ratio
|
183 |
+
cube_size = bar_h
|
184 |
+
# bar's base position
|
185 |
+
bar_start_point = (2*bar_h, 2*bar_h)
|
186 |
+
# draw bar horizontally, and record the position of each word
|
187 |
+
word_positions = []
|
188 |
+
box_coords = []
|
189 |
+
colors = [self.colors.bad, self.colors.good, self.colors.vamos]
|
190 |
+
for i, color in enumerate(colors):
|
191 |
+
x0, y0 = bar_start_point[0] + i*bar_w, bar_start_point[1]
|
192 |
+
x1, y1 = x0 + bar_w - 1, y0 + bar_h
|
193 |
+
box_coord = np.array((x0, y0, x1, y1), dtype=np.int32)
|
194 |
+
self.draw_rectangle(box_coord, color=color)
|
195 |
+
|
196 |
+
box_coords.append(box_coord)
|
197 |
+
word_positions.append(np.array((x0, y1 + bar_h // 2)))
|
198 |
+
# calculate cube position according to score
|
199 |
+
lvl, lvl_ratio, lvl_name = self.score_level(score)
|
200 |
+
# the first level start point is the first bar
|
201 |
+
cube_lvl_start_x0 = [box_coord[0] - cube_size // 2 if i != 0
|
202 |
+
else box_coord[0]
|
203 |
+
for i, box_coord in enumerate(box_coords)]
|
204 |
+
# process the last level, I want the cube stays in the bar
|
205 |
+
level_length = bar_w if lvl == 1 else bar_w - cube_size // 2
|
206 |
+
cube_x0 = cube_lvl_start_x0[lvl] + lvl_ratio * level_length
|
207 |
+
cube_y0 = bar_start_point[1] - bar_h // 2 - cube_size
|
208 |
+
cube_x1 = cube_x0 + cube_size
|
209 |
+
cube_y1 = cube_y0 + cube_size
|
210 |
+
# draw cube
|
211 |
+
self.draw_rectangle((cube_x0, cube_y0, cube_x1, cube_y1),
|
212 |
+
color=self.colors.cube)
|
213 |
+
# enlarge the box, to emphasize the level
|
214 |
+
enlarged_box = box_coords[lvl].copy()
|
215 |
+
enlarged_box[:2] = enlarged_box[:2] - bar_h // 8
|
216 |
+
enlarged_box[2:] = enlarged_box[2:] + bar_h // 8
|
217 |
+
self.draw_rectangle(enlarged_box, color=self.colors[lvl_name])
|
218 |
+
|
219 |
+
# draw text
|
220 |
+
if lvl_name == 'vamos':
|
221 |
+
lvl_name = 'vamos!!' # exciting!
|
222 |
+
self.draw_text(lvl_name.capitalize(),
|
223 |
+
word_positions[lvl],
|
224 |
+
font_size=bar_h * 2,
|
225 |
+
text_color=tuple(colors[lvl].tolist()))
|
226 |
+
|
227 |
+
return self
|
228 |
+
|
229 |
+
def draw_non_transparent_area(self, box_coord, alpha=0.2, extend_ratio=0.1):
|
230 |
+
""" Make image outside the box transparent using alpha blend
|
231 |
+
"""
|
232 |
+
x1, y1, x2, y2 = box_coord.astype(np.int32)
|
233 |
+
# enlarge the box for 10%
|
234 |
+
max_len = max((x2 - x1), (y2 - y1))
|
235 |
+
extend_len = int(max_len * extend_ratio)
|
236 |
+
x1, y1 = x1 - extend_len, y1 - extend_len
|
237 |
+
x2, y2 = x2 + extend_len, y2 + extend_len
|
238 |
+
# clip the box
|
239 |
+
h, w = self.image.shape[:2]
|
240 |
+
x1, y1, x2, y2 = np.clip((x1,y1,x2,y2), a_min=0,
|
241 |
+
a_max=(w,h,w,h))
|
242 |
+
# Create a white background color
|
243 |
+
bg_color = np.ones_like(self.image) * 255
|
244 |
+
# Copy the box region from the image
|
245 |
+
bg_color[y1:y2, x1:x2] = self.image[y1:y2, x1:x2]
|
246 |
+
# Alpha blend inplace
|
247 |
+
self.image[:] = self.image * alpha + bg_color * (1 - alpha)
|
248 |
+
return self
|
249 |
+
|
250 |
+
def draw_logo(self, logo='assets/logo.png', factor=30, shift=20):
|
251 |
+
""" Draw logo on the right bottom of the image.
|
252 |
+
"""
|
253 |
+
H, W = self.image.shape[:2]
|
254 |
+
# load logo
|
255 |
+
logo_img = Image.open(logo)
|
256 |
+
# scale logo
|
257 |
+
logo_h = self.image.shape[0] // factor
|
258 |
+
scale_size = logo_h / logo_img.size[1]
|
259 |
+
logo_w = int(logo_img.size[0] * scale_size)
|
260 |
+
logo_img = logo_img.resize((logo_w, logo_h))
|
261 |
+
# convert to RGBA
|
262 |
+
image = Image.fromarray(self.image).convert("RGBA")
|
263 |
+
# alpha blend
|
264 |
+
image.alpha_composite(logo_img, (W - logo_w - shift,
|
265 |
+
H - logo_h - shift))
|
266 |
+
self.image = np.array(image.convert("RGB"))
|
267 |
+
return self
|
268 |
+
|
269 |
+
def score_level(self, score):
|
270 |
+
""" Return the level according to level thresh.
|
271 |
+
"""
|
272 |
+
t = self.lvl_tresh
|
273 |
+
if score < t[1]: # t[0] might bigger than 0
|
274 |
+
ratio = (score - t[0]) / (t[1] - t[0])
|
275 |
+
ratio = np.clip(ratio, a_min=0, a_max=1)
|
276 |
+
return 0, ratio, 'bad'
|
277 |
+
elif score < t[2]:
|
278 |
+
ratio = (score - t[1]) / (t[2] - t[1])
|
279 |
+
return 1, ratio, 'good'
|
280 |
+
else:
|
281 |
+
ratio = (score - t[2]) / (1 - t[2])
|
282 |
+
return 2, ratio, 'vamos'
|
283 |
+
|
284 |
+
def score_level_names(self, scores):
|
285 |
+
""" Get multiple score level, return numpy array.
|
286 |
+
np.vectorize does not speed up loop, but it is convenient.
|
287 |
+
"""
|
288 |
+
t = self.lvl_tresh
|
289 |
+
func_lvl_name = lambda x: 'bad' if x < t[1] else 'good' \
|
290 |
+
if x < t[2] else 'vamos'
|
291 |
+
lvl_names = np.vectorize(func_lvl_name)(scores)
|
292 |
+
return lvl_names
|
293 |
+
|
294 |
+
def set_level(self, thresh):
|
295 |
+
""" Set level thresh for bad, good, vamos.
|
296 |
+
"""
|
297 |
+
from collections import namedtuple
|
298 |
+
Level = namedtuple('Level', ['zero', 'good', 'vamos'])
|
299 |
+
return Level(thresh[0], thresh[1], thresh[2])
|
300 |
+
|
301 |
+
def get_pallete(self):
|
302 |
+
PALLETE = EasyDict()
|
303 |
+
|
304 |
+
# light set
|
305 |
+
# PALLETE.bad = np.array([253, 138, 138])
|
306 |
+
# PALLETE.good = np.array([168, 209, 209])
|
307 |
+
# PALLETE.vamos = np.array([241, 247, 181])
|
308 |
+
# PALLETE.cube = np.array([158, 161, 212])
|
309 |
+
|
310 |
+
# dark set, set 80% brightness
|
311 |
+
PALLETE.bad = np.array([204, 111, 111])
|
312 |
+
PALLETE.good = np.array([143, 179, 179])
|
313 |
+
PALLETE.vamos = np.array([196, 204, 124])
|
314 |
+
PALLETE.vamos = np.array([109, 169, 228])
|
315 |
+
PALLETE.cube = np.array([152, 155, 204])
|
316 |
+
|
317 |
+
PALLETE.left_arm = np.array([218, 119, 242])
|
318 |
+
PALLETE.right_arm = np.array([151, 117, 250])
|
319 |
+
PALLETE.left_leg = np.array([255, 212, 59])
|
320 |
+
PALLETE.right_leg = np.array([255, 169, 77])
|
321 |
+
|
322 |
+
PALLETE.head = np.array([134, 142, 150])
|
323 |
+
PALLETE.body = np.array([134, 142, 150])
|
324 |
+
|
325 |
+
# convert rgb to bgr
|
326 |
+
for k, v in PALLETE.items():
|
327 |
+
PALLETE[k] = v[::-1]
|
328 |
+
return PALLETE
|
329 |
+
|
330 |
+
if __name__ == '__main__':
|
331 |
+
vis = FastVisualizer()
|
332 |
+
|
333 |
+
image = '/github/Tennis.ai/assets/tempt_test.png'
|
334 |
+
vis.set_image(image)
|
335 |
+
np.random.seed(0)
|
336 |
+
keypoints = np.random.randint(300, 600, (17, 2))
|
337 |
+
from utils import Timer
|
338 |
+
t= Timer()
|
339 |
+
t.start()
|
340 |
+
vis.draw_score_bar(0.94)
|
341 |
+
# vis.draw_skeleton(keypoints)
|
342 |
+
# vis.draw_non_transparent_area((0, 0, 100, 100), alpha=0.2)
|
343 |
+
vis.draw_logo()
|
344 |
+
cv2.imshow('test', vis.image)
|
345 |
+
cv2.waitKey(0)
|
346 |
+
cv2.destroyAllWindows()
|