Commit e3641b1
Parent(s): 3964794
initial commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .dockerignore +8 -0
- .gitignore +5 -0
- Dockerfile +19 -0
- app.py +117 -0
- config.py +373 -0
- easy_ViTPose/__init__.py +5 -0
- easy_ViTPose/config.yaml +14 -0
- easy_ViTPose/configs/ViTPose_aic.py +20 -0
- easy_ViTPose/configs/ViTPose_ap10k.py +22 -0
- easy_ViTPose/configs/ViTPose_apt36k.py +22 -0
- easy_ViTPose/configs/ViTPose_coco.py +18 -0
- easy_ViTPose/configs/ViTPose_coco_25.py +20 -0
- easy_ViTPose/configs/ViTPose_common.py +195 -0
- easy_ViTPose/configs/ViTPose_mpii.py +18 -0
- easy_ViTPose/configs/ViTPose_wholebody.py +20 -0
- easy_ViTPose/configs/__init__.py +0 -0
- easy_ViTPose/datasets/COCO.py +556 -0
- easy_ViTPose/datasets/HumanPoseEstimation.py +17 -0
- easy_ViTPose/datasets/__init__.py +0 -0
- easy_ViTPose/easy_ViTPose.egg-info/PKG-INFO +4 -0
- easy_ViTPose/easy_ViTPose.egg-info/SOURCES.txt +35 -0
- easy_ViTPose/easy_ViTPose.egg-info/dependency_links.txt +1 -0
- easy_ViTPose/easy_ViTPose.egg-info/top_level.txt +2 -0
- easy_ViTPose/inference.py +334 -0
- easy_ViTPose/sort.py +266 -0
- easy_ViTPose/to_onnx.ipynb +0 -0
- easy_ViTPose/to_trt.ipynb +0 -0
- easy_ViTPose/train.py +174 -0
- easy_ViTPose/vit_models/__init__.py +8 -0
- easy_ViTPose/vit_models/backbone/__init__.py +0 -0
- easy_ViTPose/vit_models/backbone/vit.py +394 -0
- easy_ViTPose/vit_models/head/__init__.py +0 -0
- easy_ViTPose/vit_models/head/topdown_heatmap_base_head.py +120 -0
- easy_ViTPose/vit_models/head/topdown_heatmap_simple_head.py +334 -0
- easy_ViTPose/vit_models/losses/__init__.py +16 -0
- easy_ViTPose/vit_models/losses/classfication_loss.py +41 -0
- easy_ViTPose/vit_models/losses/heatmap_loss.py +83 -0
- easy_ViTPose/vit_models/losses/mesh_loss.py +402 -0
- easy_ViTPose/vit_models/losses/mse_loss.py +151 -0
- easy_ViTPose/vit_models/losses/multi_loss_factory.py +279 -0
- easy_ViTPose/vit_models/losses/regression_loss.py +444 -0
- easy_ViTPose/vit_models/model.py +24 -0
- easy_ViTPose/vit_models/optimizer.py +15 -0
- easy_ViTPose/vit_utils/__init__.py +6 -0
- easy_ViTPose/vit_utils/dist_util.py +212 -0
- easy_ViTPose/vit_utils/inference.py +93 -0
- easy_ViTPose/vit_utils/logging.py +133 -0
- easy_ViTPose/vit_utils/nms/__init__.py +0 -0
- easy_ViTPose/vit_utils/nms/cpu_nms.c +0 -0
- easy_ViTPose/vit_utils/nms/cpu_nms.cpython-37m-x86_64-linux-gnu.so +0 -0
.dockerignore
ADDED
@@ -0,0 +1,8 @@
+__pycache__
+pose_env_1
+testing
+vit_env
+vit_test
+test_vit_model.ipynb
+models
+models_2
.gitignore
ADDED
@@ -0,0 +1,5 @@
+__pycache__
+pose_env_1
+testing
+vit_env
+test_vit_model.ipynb
Dockerfile
ADDED
@@ -0,0 +1,19 @@
+FROM python:3.10
+
+WORKDIR /app
+
+COPY requirements.txt .
+
+RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
+RUN pip install --upgrade pip
+
+# --no-cache-dir
+RUN pip install -r requirements.txt
+
+COPY . .
+
+EXPOSE 7860
+
+ENV GRADIO_SERVER_NAME="0.0.0.0"
+ENV USE_NNPACK=0
+CMD ["python", "app.py"]
app.py
ADDED
@@ -0,0 +1,117 @@
+import gradio as gr
+from main_func import video_identity
+
+with gr.Blocks() as demo:
+
+    with gr.Row(variant='compact'):
+
+        with gr.Column():
+            gr.Markdown("#### Dynamic Time Warping:")
+
+            with gr.Row(variant='compact'):
+                dtw_mean = gr.Slider(
+                    value=0.5,
+                    minimum=0,
+                    maximum=1.0,
+                    step=0.05,
+                    label="Winsorize Mean"
+                )
+
+                dtw_filter = gr.Slider(
+                    value=3,
+                    minimum=1,
+                    maximum=20,
+                    step=1,
+                    label="Savitzky-Golay Filter"
+                )
+
+            gr.Markdown("#### Thresholds:")
+
+            with gr.Row(variant='compact'):
+                angles_sensitive = gr.Number(
+                    value=15,
+                    minimum=0,
+                    maximum=75,
+                    step=1,
+                    min_width=100,
+                    label="Sensitive"
+                )
+
+                angles_common = gr.Number(
+                    value=25,
+                    minimum=0,
+                    maximum=75,
+                    step=1,
+                    min_width=100,
+                    label="Standart"
+                )
+
+                angles_insensitive = gr.Number(
+                    value=45,
+                    minimum=0,
+                    maximum=75,
+                    step=1,
+                    min_width=100,
+                    label="Insensitive"
+                )
+
+            gr.Markdown("#### Patience:")
+
+            trigger_state = gr.Radio(value="three", choices=["three", "two"], label="Trigger Count")
+
+        input_teacher = gr.Video(show_share_button=False, show_download_button=False, sources=["upload"], label="Teacher's Video")
+        input_student = gr.Video(show_share_button=False, show_download_button=False, sources=["upload"], label="Student's Video")
+
+
+    with gr.Accordion("Clarifications:", open=True):
+        with gr.Accordion("Dynamic Time Warping:", open=False):
+            gr.Markdown("""
+            Dynamic Time Warping is an algorithm that performs frame-by-frame alignment for videos with different speeds.
+
+            - **Winsorized mean**: Determines the portion of DTW paths, sorted from best to worst, to use for generating the mean DTW alignment. Reasonable values range from 0.25 to 0.6.
+            - **Savitzky-Golay Filter**: Enhances the capabilities of the Winsorized mean, making DTW alignment more similar to a strict line. Reasonable values range from 2 to 10.
+            """)
+
+        with gr.Accordion("Thresholds:", open=False):
+            gr.Markdown("""
+            Thresholds are used to identify student errors in dance. If the difference in angle between the teacher's and student's videos exceeds this threshold, it is counted as an error.
+
+            - **Sensitive**: A threshold that is currently not used.
+            - **Standard**: A threshold for most angles. Reasonable values range from 20 to 40.
+            - **Insensitive**: A threshold for difficult areas, such as hands and toes. Reasonable values range from 35 to 55.
+            """)
+
+        with gr.Accordion("Patience:", open=False):
+            gr.Markdown("""
+            Patience helps prevent model errors by highlighting only errors detected in consecutive frames.
+
+            - **Three**: Utilizes 3 consecutive frames for error detection.
+            - **Two**: Utilizes 2 consecutive frames for error detection.
+
+            Both options can be used interchangeably.
+            """)
+
+
+
+    with gr.Row():
+        gr_button = gr.Button("Run Pose Comparison")
+
+    with gr.Row():
+        gr.HTML("<div style='height: 100px;'></div>")
+
+
+    with gr.Row():
+        output_merged = gr.Video(show_download_button=True)
+
+    with gr.Row():
+        general_log = gr.TextArea(lines=10, max_lines=9999, label="Error log")
+
+    gr_button.click(
+        fn=video_identity,
+        inputs=[dtw_mean, dtw_filter, angles_sensitive, angles_common, angles_insensitive, trigger_state, input_teacher, input_student],
+        outputs=[output_merged, general_log]
+    )
+
+
+if __name__ == "__main__":
+    demo.launch()
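The `video_identity` callback imported from `main_func` is not among the 50 files shown in this view. Below is a minimal sketch of the interface the `gr_button.click(...)` wiring above expects; the argument order mirrors the `inputs` list, and the body is placeholder logic only (names and behaviour are illustrative assumptions, not the actual implementation).

def video_identity(dtw_mean, dtw_filter,
                   angles_sensitive, angles_common, angles_insensitive,
                   trigger_state, teacher_video, student_video):
    # Placeholder stub: the real function runs the pose-comparison pipeline and must
    # return one value per Gradio output: (merged video path, error log text).
    merged_video_path = teacher_video      # placeholder, not the real behaviour
    error_log_text = "no errors detected"  # placeholder
    return merged_video_path, error_log_text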
config.py
ADDED
@@ -0,0 +1,373 @@
+CONNECTIONS_VIT_FULL = [
+    # head
+    (0, 2),
+    (0, 1),
+    (2, 4),
+    (1, 3),
+    (0, 6),
+    (0, 5),
+
+    # right arm
+    (6, 8),
+    (8, 10),
+
+    # right hand
+    (10, 112),
+
+    # Big toe 1
+    (112, 113),
+    (113, 114),
+    (114, 115),
+    (115, 116),
+
+    # toe 2
+    (112, 117),
+    (117, 118),
+    (118, 119),
+    (119, 120),
+
+    # toe 3
+    (112, 121),
+    (121, 122),
+    (122, 123),
+    (123, 124),
+
+    # toe 4
+    (112, 125),
+    (125, 126),
+    (126, 127),
+    (127, 128),
+
+    # toe 5
+    (112, 129),
+    (129, 130),
+    (130, 131),
+    (131, 132),
+
+
+
+    # left arm
+    (5, 7),
+    (7, 9),
+
+    # left hand
+    (9, 91),
+
+
+    # Big toe 1
+    (91, 92),
+    (92, 93),
+    (93, 94),
+    (94, 95),
+
+    # toe 2
+    (91, 96),
+    (96, 97),
+    (97, 98),
+    (98, 99),
+
+    # toe 3
+    (91, 100),
+    (100, 101),
+    (101, 102),
+    (102, 103),
+
+    # toe 4
+    (91, 104),
+    (104, 105),
+    (105, 106),
+    (106, 107),
+
+    # toe 5
+    (91, 108),
+    (108, 109),
+    (109, 110),
+    (110, 111),
+
+
+
+    # torso
+    (6, 5),
+    (12, 11),
+    (6, 12),
+    (5, 11),
+
+    # right leg
+    (12, 14),
+    (14, 16),
+
+    # right foot
+    (16, 22),
+    (22, 21),
+    (22, 20),
+
+
+    # left leg
+    (11, 13),
+    (13, 15),
+
+    # left foot
+    (15, 19),
+    (19, 18),
+    (19, 17),
+]
+
+EDGE_GROUPS_FOR_ERRORS = [
+    [0, 2, 4],
+    [0, 1, 3],
+
+    # neck
+    [6, 0, 2],
+    [5, 0, 1],
+
+    # right arm
+
+    # right shoulder
+    [5, 6, 8],
+
+    # right elbow
+    [6, 8, 10],
+
+    # right hand
+    [8, 10, 121],
+
+    [112, 114, 116],
+    [112, 117, 120],
+    [112, 121, 124],
+    [112, 125, 128],
+    [112, 129, 132],
+
+    # left arm
+
+    # left shoulder
+    [6, 5, 7],
+
+    # left elbow
+    [5, 7, 9],
+
+    # left hand
+    [7, 9, 100],
+
+    [91, 93, 95],
+    [91, 96, 99],
+    [91, 100, 103],
+    [91, 104, 107],
+    [91, 108, 111],
+
+
+    # right leg
+
+    # right upper-leg
+    [6, 12, 14],
+
+    # right middle-leg
+    [12, 14, 16],
+
+    # right lower-leg
+    [14, 16, 22],
+    [16, 22, 21],
+    [16, 22, 20],
+
+    # left leg
+
+    # left upper-leg
+    [5, 11, 13],
+
+    # left middle-leg
+    [11, 13, 15],
+
+    # left lower-leg
+    [13, 15, 19],
+    [15, 19, 17],
+    [15, 19, 18],
+
+]
+
+
+
+CONNECTIONS_FOR_ERROR = [
+    # head
+    (0, 2),
+    (2, 4),
+    (0, 1),
+    (1, 3),
+
+    # right arm
+    (6, 0),
+    (8, 6),
+    (10, 8),
+
+    # right hand
+    # (121, 10),
+
+    (112, 114),
+    (114, 116),
+
+    (112, 117),
+    (117, 120),
+
+    (112, 121),
+    (121, 124),
+
+    (112, 125),
+    (125, 128),
+
+    (112, 129),
+    (129, 132),
+
+    # left arm
+    (5, 0),
+    (7, 5),
+    (9, 7),
+
+    # left hand
+    # (100, 9),
+
+    (91, 93),
+    (93, 95),
+
+    (91, 96),
+    (96, 99),
+
+    (91, 100),
+    (100, 103),
+
+    (91, 104),
+    (104, 107),
+
+    (91, 108),
+    (108, 111),
+
+    # torso
+    (6, 12),
+    (5, 11),
+
+    # right leg
+    (12, 14),
+    (14, 16),
+
+    (16, 22),
+    (22, 21),
+    (22, 20),
+
+    # left leg
+    (11, 13),
+    (13, 15),
+
+    (15, 19),
+    (19, 17),
+    (19, 18),
+
+]
+
+def get_thresholds(sensetive_error, general_error, unsensetive_error):
+    thresholds = [
+        general_error,
+        general_error,
+        general_error,
+        general_error,
+
+        general_error,
+        general_error,
+
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+
+        general_error,
+        general_error,
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+
+        general_error,
+        general_error,
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+
+        general_error,
+        general_error,
+        unsensetive_error,
+        unsensetive_error,
+        unsensetive_error,
+    ]
+
+    return thresholds
+
+
+EDGE_GROUPS_FOR_SUMMARY = {
+    (2, 4): "Head position is incorrect",
+    (1, 3): "Head position is incorrect",
+
+    # neck
+
+    (0, 2): "Head position is incorrect",
+    (0, 1): "Head position is incorrect",
+
+    # right arm
+
+    # right shoulder
+    (6, 8): "Right shoulder position is incorrect",
+
+    # right elbow
+    (8, 10): "Right elbow position is incorrect",
+
+    # right hand
+    (10, 121): "Right hand's palm position is incorrect",
+
+    (114, 116): "Right thumb finger position is incorrect",
+    (117, 120): "Right index finger position is incorrect",
+    (121, 124): "Right middle finger position is incorrect",
+    (125, 128): "Right ring finger position is incorrect",
+    (129, 132): "Right pinky finger position is incorrect",
+
+    # left arm
+
+    # left shoulder
+    (5, 7): "Left shoulder position is incorrect",
+
+    # left elbow
+    (7, 9): "Left elbow position is incorrect",
+
+    # left hand
+    (9, 100): "Left hand palm position is incorrect",
+
+    (93, 95): "Left thumb finger position is incorrect",
+    (96, 99): "Left index finger position is incorrect",
+    (100, 103): "Left middle finger position is incorrect",
+    (104, 107): "Left ring finger position is incorrect",
+    (108, 111): "Left pinky finger position is incorrect",
+
+    # right leg
+
+    # right upper-leg
+    (12, 14): "Right thigh position is incorrect",
+
+    # right middle-leg
+    (14, 16): "Right shin position is incorrect",
+
+    # right lower-leg
+    (16, 22): "Right foot position is incorrect",
+    (22, 21): "Right shin position is incorrect",
+    (22, 20): "Right shin position is incorrect",
+
+    # left leg
+
+    # left upper-leg
+    (11, 13): "Left thigh position is incorrect",
+
+    # left middle-leg
+    (13, 15): "Left shin position is incorrect",
+
+    # left lower-leg
+    (15, 19): "Left foot position is incorrect",
+    (19, 17): "Left shin position is incorrect",
+    (19, 18): "Left shin position is incorrect"
+}
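The list returned by get_thresholds and EDGE_GROUPS_FOR_ERRORS both contain 30 entries, so they appear to be parallel lists: each threshold governs one angle triplet. A small sketch of that pairing, under the assumption that the correspondence is purely positional (not confirmed by the files in this view):

from config import EDGE_GROUPS_FOR_ERRORS, get_thresholds

thresholds = get_thresholds(sensetive_error=15, general_error=25, unsensetive_error=45)
assert len(thresholds) == len(EDGE_GROUPS_FOR_ERRORS) == 30

for (a, b, c), limit in zip(EDGE_GROUPS_FOR_ERRORS, thresholds):
    # Assumed interpretation: the angle at keypoint b (formed with a and c) is flagged
    # as an error when the teacher/student difference exceeds `limit` degrees.
    print(f"angle ({a}, {b}, {c}) -> threshold {limit}")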
easy_ViTPose/__init__.py
ADDED
@@ -0,0 +1,5 @@
+from .inference import VitInference
+
+__all__ = [
+    'VitInference'
+]
easy_ViTPose/config.yaml
ADDED
@@ -0,0 +1,14 @@
+# Train config ---------------------------------------
+log_level: logging.INFO
+seed: 0
+deterministic: True
+cudnn_benchmark: True  # Use cudnn
+resume_from: "ckpts/og-vitpose-s.pth"  # CKPT path
+# resume_from: False
+gpu_ids: [0]
+launcher: 'none'  # When distributed training ['none', 'pytorch', 'slurm', 'mpi']
+use_amp: True
+validate: True
+autoscale_lr: False
+dist_params:
+  ...
easy_ViTPose/configs/ViTPose_aic.py
ADDED
@@ -0,0 +1,20 @@
+from .ViTPose_common import *
+
+# Channel configuration
+channel_cfg = dict(
+    num_output_channels=14,
+    dataset_joints=14,
+    dataset_channel=[
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
+    ],
+    inference_channel=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
+
+# Set models channels
+data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+data_cfg['num_joints']= channel_cfg['dataset_joints']
+data_cfg['dataset_channel']= channel_cfg['dataset_channel']
+data_cfg['inference_channel']= channel_cfg['inference_channel']
+
+names = ['small', 'base', 'large', 'huge']
+for name in names:
+    globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
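The globals() loop is the pattern shared by all the dataset configs below: after the star-import of ViTPose_common, the module rewrites the out_channels of every model_{size} head to match the dataset's joint count (14 for AIC here). A short sketch of how a consumer would observe that, assuming the package is importable under the path used below:

from easy_ViTPose.configs import ViTPose_aic as cfg  # assumed import path

for name in ('small', 'base', 'large', 'huge'):
    model = getattr(cfg, f'model_{name}')  # dicts defined in ViTPose_common
    # every variant's keypoint head now predicts the 14 AIC keypoints
    assert model['keypoint_head']['out_channels'] == 14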
easy_ViTPose/configs/ViTPose_ap10k.py
ADDED
@@ -0,0 +1,22 @@
+from .ViTPose_common import *
+
+# Channel configuration
+channel_cfg = dict(
+    num_output_channels=17,
+    dataset_joints=17,
+    dataset_channel=[
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+    ],
+    inference_channel=[
+        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+    ])
+
+# Set models channels
+data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+data_cfg['num_joints']= channel_cfg['dataset_joints']
+data_cfg['dataset_channel']= channel_cfg['dataset_channel']
+data_cfg['inference_channel']= channel_cfg['inference_channel']
+
+names = ['small', 'base', 'large', 'huge']
+for name in names:
+    globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_apt36k.py
ADDED
@@ -0,0 +1,22 @@
+from .ViTPose_common import *
+
+# Channel configuration
+channel_cfg = dict(
+    num_output_channels=17,
+    dataset_joints=17,
+    dataset_channel=[
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+    ],
+    inference_channel=[
+        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+    ])
+
+# Set models channels
+data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+data_cfg['num_joints']= channel_cfg['dataset_joints']
+data_cfg['dataset_channel']= channel_cfg['dataset_channel']
+data_cfg['inference_channel']= channel_cfg['inference_channel']
+
+names = ['small', 'base', 'large', 'huge']
+for name in names:
+    globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_coco.py
ADDED
@@ -0,0 +1,18 @@
+from .ViTPose_common import *
+
+# Channel configuration
+channel_cfg = dict(
+    num_output_channels=17,
+    dataset_joints=17,
+    dataset_channel=list(range(17)),
+    inference_channel=list(range(17)))
+
+# Set models channels
+data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+data_cfg['num_joints']= channel_cfg['dataset_joints']
+data_cfg['dataset_channel']= channel_cfg['dataset_channel']
+data_cfg['inference_channel']= channel_cfg['inference_channel']
+
+names = ['small', 'base', 'large', 'huge']
+for name in names:
+    globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_coco_25.py
ADDED
@@ -0,0 +1,20 @@
+from .ViTPose_common import *
+
+# Channel configuration
+channel_cfg = dict(
+    num_output_channels=25,
+    dataset_joints=25,
+    dataset_channel=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                      16, 17, 18, 19, 20, 21, 22, 23, 24], ],
+    inference_channel=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                       16, 17, 18, 19, 20, 21, 22, 23, 24])
+
+# Set models channels
+data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+data_cfg['num_joints']= channel_cfg['dataset_joints']
+data_cfg['dataset_channel']= channel_cfg['dataset_channel']
+data_cfg['inference_channel']= channel_cfg['inference_channel']
+
+names = ['small', 'base', 'large', 'huge']
+for name in names:
+    globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_common.py
ADDED
@@ -0,0 +1,195 @@
+# Common configuration
+optimizer = dict(type='AdamW', lr=1e-3, betas=(0.9, 0.999), weight_decay=0.1,
+                 constructor='LayerDecayOptimizerConstructor',
+                 paramwise_cfg=dict(
+                     num_layers=12,
+                     layer_decay_rate=1 - 2e-4,
+                     custom_keys={
+                         'bias': dict(decay_multi=0.),
+                         'pos_embed': dict(decay_mult=0.),
+                         'relative_position_bias_table': dict(decay_mult=0.),
+                         'norm': dict(decay_mult=0.)
+                     }
+                 )
+                 )
+
+optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
+
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=300,
+    warmup_ratio=0.001,
+    step=[3])
+
+total_epochs = 4
+target_type = 'GaussianHeatmap'
+
+data_cfg = dict(
+    image_size=[192, 256],
+    heatmap_size=[48, 64],
+    soft_nms=False,
+    nms_thr=1.0,
+    oks_thr=0.9,
+    vis_thr=0.2,
+    use_gt_bbox=False,
+    det_bbox_thr=0.0,
+    bbox_file='data/coco/person_detection_results/'
+              'COCO_val2017_detections_AP_H_56_person.json',
+)
+
+data_root = '/home/adryw/dataset/COCO17'
+data = dict(
+    samples_per_gpu=64,
+    workers_per_gpu=6,
+    val_dataloader=dict(samples_per_gpu=128),
+    test_dataloader=dict(samples_per_gpu=128),
+    train=dict(
+        type='TopDownCocoDataset',
+        ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
+        img_prefix=f'{data_root}/train2017/',
+        data_cfg=data_cfg),
+    val=dict(
+        type='TopDownCocoDataset',
+        ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
+        img_prefix=f'{data_root}/val2017/',
+        data_cfg=data_cfg),
+    test=dict(
+        type='TopDownCocoDataset',
+        ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
+        img_prefix=f'{data_root}/val2017/',
+        data_cfg=data_cfg)
+)
+
+model_small = dict(
+    type='TopDown',
+    pretrained=None,
+    backbone=dict(
+        type='ViT',
+        img_size=(256, 192),
+        patch_size=16,
+        embed_dim=384,
+        depth=12,
+        num_heads=12,
+        ratio=1,
+        use_checkpoint=False,
+        mlp_ratio=4,
+        qkv_bias=True,
+        drop_path_rate=0.1,
+    ),
+    keypoint_head=dict(
+        type='TopdownHeatmapSimpleHead',
+        in_channels=384,
+        num_deconv_layers=2,
+        num_deconv_filters=(256, 256),
+        num_deconv_kernels=(4, 4),
+        extra=dict(final_conv_kernel=1, ),
+        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+    train_cfg=dict(),
+    test_cfg=dict(
+        flip_test=True,
+        post_process='default',
+        shift_heatmap=False,
+        target_type=target_type,
+        modulate_kernel=11,
+        use_udp=True))
+
+model_base = dict(
+    type='TopDown',
+    pretrained=None,
+    backbone=dict(
+        type='ViT',
+        img_size=(256, 192),
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        ratio=1,
+        use_checkpoint=False,
+        mlp_ratio=4,
+        qkv_bias=True,
+        drop_path_rate=0.3,
+    ),
+    keypoint_head=dict(
+        type='TopdownHeatmapSimpleHead',
+        in_channels=768,
+        num_deconv_layers=2,
+        num_deconv_filters=(256, 256),
+        num_deconv_kernels=(4, 4),
+        extra=dict(final_conv_kernel=1, ),
+        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+    train_cfg=dict(),
+    test_cfg=dict(
+        flip_test=True,
+        post_process='default',
+        shift_heatmap=False,
+        target_type=target_type,
+        modulate_kernel=11,
+        use_udp=True))
+
+model_large = dict(
+    type='TopDown',
+    pretrained=None,
+    backbone=dict(
+        type='ViT',
+        img_size=(256, 192),
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        ratio=1,
+        use_checkpoint=False,
+        mlp_ratio=4,
+        qkv_bias=True,
+        drop_path_rate=0.5,
+    ),
+    keypoint_head=dict(
+        type='TopdownHeatmapSimpleHead',
+        in_channels=1024,
+        num_deconv_layers=2,
+        num_deconv_filters=(256, 256),
+        num_deconv_kernels=(4, 4),
+        extra=dict(final_conv_kernel=1, ),
+        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+    train_cfg=dict(),
+    test_cfg=dict(
+        flip_test=True,
+        post_process='default',
+        shift_heatmap=False,
+        target_type=target_type,
+        modulate_kernel=11,
+        use_udp=True))
+
+model_huge = dict(
+    type='TopDown',
+    pretrained=None,
+    backbone=dict(
+        type='ViT',
+        img_size=(256, 192),
+        patch_size=16,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        ratio=1,
+        use_checkpoint=False,
+        mlp_ratio=4,
+        qkv_bias=True,
+        drop_path_rate=0.55,
+    ),
+    keypoint_head=dict(
+        type='TopdownHeatmapSimpleHead',
+        in_channels=1280,
+        num_deconv_layers=2,
+        num_deconv_filters=(256, 256),
+        num_deconv_kernels=(4, 4),
+        extra=dict(final_conv_kernel=1, ),
+        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+    train_cfg=dict(),
+    test_cfg=dict(
+        flip_test=True,
+        post_process='default',
+        shift_heatmap=False,
+        target_type=target_type,
+        modulate_kernel=11,
+        use_udp=True))
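A quick arithmetic check of the geometry these configs imply (illustrative only, not part of the repository): a 256x192 input split into 16x16 patches gives the ViT a 16x12 token grid, and the heatmaps in data_cfg are predicted at 1/4 of the input resolution.

img_h, img_w, patch = 256, 192, 16                    # backbone img_size and patch_size above
assert (img_h // patch, img_w // patch) == (16, 12)   # ViT token grid
assert [192 // 4, 256 // 4] == [48, 64]               # matches data_cfg heatmap_size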
easy_ViTPose/configs/ViTPose_mpii.py
ADDED
@@ -0,0 +1,18 @@
+from .ViTPose_common import *
+
+# Channel configuration
+channel_cfg = dict(
+    num_output_channels=16,
+    dataset_joints=16,
+    dataset_channel=list(range(16)),
+    inference_channel=list(range(16)))
+
+# Set models channels
+data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+data_cfg['num_joints']= channel_cfg['dataset_joints']
+data_cfg['dataset_channel']= channel_cfg['dataset_channel']
+data_cfg['inference_channel']= channel_cfg['inference_channel']
+
+names = ['small', 'base', 'large', 'huge']
+for name in names:
+    globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/ViTPose_wholebody.py
ADDED
@@ -0,0 +1,20 @@
+from .ViTPose_common import *
+
+# Channel configuration
+channel_cfg = dict(
+    num_output_channels=133,
+    dataset_joints=133,
+    dataset_channel=[
+        list(range(133)),
+    ],
+    inference_channel=list(range(133)))
+
+# Set models channels
+data_cfg['num_output_channels'] = channel_cfg['num_output_channels']
+data_cfg['num_joints']= channel_cfg['dataset_joints']
+data_cfg['dataset_channel']= channel_cfg['dataset_channel']
+data_cfg['inference_channel']= channel_cfg['inference_channel']
+
+names = ['small', 'base', 'large', 'huge']
+for name in names:
+    globals()[f'model_{name}']['keypoint_head']['out_channels'] = channel_cfg['num_output_channels']
easy_ViTPose/configs/__init__.py
ADDED
File without changes
easy_ViTPose/datasets/COCO.py
ADDED
@@ -0,0 +1,556 @@
+# Part of this code is derived/taken from https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
+import os
+import sys
+import pickle
+import random
+
+import cv2
+import json_tricks as json
+import numpy as np
+from pycocotools.coco import COCO
+from torchvision import transforms
+import torchvision.transforms.functional as F
+from tqdm import tqdm
+from PIL import Image
+
+from .HumanPoseEstimation import HumanPoseEstimationDataset as Dataset
+
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from vit_utils.transform import fliplr_joints, affine_transform, get_affine_transform
+
+import numpy as np
+
+
+class COCODataset(Dataset):
+    """
+    COCODataset class.
+    """
+
+    def __init__(self, root_path="./datasets/COCO", data_version="train2017",
+                 is_train=True, use_gt_bboxes=True, bbox_path="",
+                 image_width=288, image_height=384,
+                 scale=True, scale_factor=0.35, flip_prob=0.5, rotate_prob=0.5, rotation_factor=45., half_body_prob=0.3,
+                 use_different_joints_weight=False, heatmap_sigma=3, soft_nms=False):
+        """
+        Initializes a new COCODataset object.
+
+        Image and annotation indexes are loaded and stored in memory.
+        Annotations are preprocessed to have a simple list of annotations to iterate over.
+
+        Bounding boxes can be loaded from the ground truth or from a pickle file (in this case, no annotations are
+        provided).
+
+        Args:
+            root_path (str): dataset root path.
+                Default: "./datasets/COCO"
+            data_version (str): desired version/folder of COCO. Possible options are "train2017", "val2017".
+                Default: "train2017"
+            is_train (bool): train or eval mode. If true, train mode is used.
+                Default: True
+            use_gt_bboxes (bool): use ground truth bounding boxes. If False, bbox_path is required.
+                Default: True
+            bbox_path (str): bounding boxes pickle file path.
+                Default: ""
+            image_width (int): image width.
+                Default: 288
+            image_height (int): image height.
+                Default: ``384``
+            color_rgb (bool): rgb or bgr color mode. If True, rgb color mode is used.
+                Default: True
+            scale (bool): scale mode.
+                Default: True
+            scale_factor (float): scale factor.
+                Default: 0.35
+            flip_prob (float): flip probability.
+                Default: 0.5
+            rotate_prob (float): rotate probability.
+                Default: 0.5
+            rotation_factor (float): rotation factor.
+                Default: 45.
+            half_body_prob (float): half body probability.
+                Default: 0.3
+            use_different_joints_weight (bool): use different joints weights.
+                If true, the following joints weights will be used:
+                [1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2, 1.2, 1.5, 1.5]
+                Default: False
+            heatmap_sigma (float): sigma of the gaussian used to create the heatmap.
+                Default: 3
+            soft_nms (bool): enable soft non-maximum suppression.
+                Default: False
+        """
+        super(COCODataset, self).__init__()
+
+        self.root_path = root_path
+        self.data_version = data_version
+        self.is_train = is_train
+        self.use_gt_bboxes = use_gt_bboxes
+        self.bbox_path = bbox_path
+        self.scale = scale  # ToDo Check
+        self.scale_factor = scale_factor
+        self.flip_prob = flip_prob
+        self.rotate_prob = rotate_prob
+        self.rotation_factor = rotation_factor
+        self.half_body_prob = half_body_prob
+        self.use_different_joints_weight = use_different_joints_weight  # ToDo Check
+        self.heatmap_sigma = heatmap_sigma
+        self.soft_nms = soft_nms
+
+        # Image & annotation path
+        self.data_path = f"{root_path}/{data_version}"
+        self.annotation_path = f"{root_path}/annotations/person_keypoints_{data_version}.json"
+
+        self.image_size = (image_width, image_height)
+        self.aspect_ratio = image_width * 1.0 / image_height
+
+        self.heatmap_size = (int(image_width / 4), int(image_height / 4))
+        self.heatmap_type = 'gaussian'
+        self.pixel_std = 200  # I don't understand the meaning of pixel_std (=200) in the original implementation
+
+        self.num_joints = 25
+        self.num_joints_half_body = 15
+
+        # eye, ear, shoulder, elbow, wrist, hip, knee, ankle
+        self.flip_pairs = [[1, 2], [3, 4], [6, 7], [8, 9], [10, 11], [12, 13],
+                           [15, 16], [17, 18], [19, 22], [20, 23], [21, 24]]
+        self.upper_body_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+        self.lower_body_ids = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
+        self.joints_weight = np.array([1., 1., 1., 1., 1., 1., 1., 1., 1.2, 1.2,
+                                       1.5, 1.5, 1., 1., 1., 1.2, 1.2, 1.5, 1.5,
+                                       1.5, 1.5, 1.5, 1.5, 1.5,
+                                       1.5]).reshape((self.num_joints, 1))
+
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ])
+
+        # Load COCO dataset - Create COCO object then load images and annotations
+        self.coco = COCO(self.annotation_path)
+
+        # Create a list of annotations and the corresponding image (each image can contain more than one detection)
+
+        """ Load bboxes and joints
+        - if self.use_gt_bboxes -> Load GT bboxes and joints
+        - else -> Load pre-predicted bboxes by a detector (as YOLOv3) and null joints
+        """
+
+        if not self.use_gt_bboxes:
+            """
+            bboxes must be saved as the original COCO annotations
+            i.e. the format must be:
+            bboxes = {
+                '<imgId>': [
+                    {
+                        'id': <annId>,  # progressive id for debugging
+                        'clean_bbox': np.array([<x>, <y>, <w>, <h>])}
+                    },
+                    ...
+                ],
+                ...
+            }
+            """
+            with open(self.bbox_path, 'rb') as fd:
+                bboxes = pickle.load(fd)
+
+        self.data = []
+        # load annotations for each image of COCO
+        for imgId in tqdm(self.coco.getImgIds(), desc="Prepare images, annotations ... "):
+            ann_ids = self.coco.getAnnIds(imgIds=imgId, iscrowd=False)  # annotation ids
+            img = self.coco.loadImgs(imgId)[0]  # load img
+
+            if self.use_gt_bboxes:
+                objs = self.coco.loadAnns(ann_ids)
+
+                # sanitize bboxes
+                valid_objs = []
+                for obj in objs:
+                    # Skip non-person objects (it should never happen)
+                    if obj['category_id'] != 1:
+                        continue
+
+                    # ignore objs without keypoints annotation
+                    if max(obj['keypoints']) == 0 and max(obj['foot_kpts']) == 0:
+                        continue
+
+                    x, y, w, h = obj['bbox']
+                    x1 = np.max((0, x))
+                    y1 = np.max((0, y))
+                    x2 = np.min((img['width'] - 1, x1 + np.max((0, w - 1))))
+                    y2 = np.min((img['height'] - 1, y1 + np.max((0, h - 1))))
+
+                    # Use only valid bounding boxes
+                    if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
+                        obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
+                        valid_objs.append(obj)
+
+                objs = valid_objs
+
+            else:
+                objs = bboxes[imgId]
+
+            # for each annotation of this image, add the formatted annotation to self.data
+            for obj in objs:
+                joints = np.zeros((self.num_joints, 2), dtype=np.float)
+                joints_visibility = np.ones((self.num_joints, 2), dtype=np.float)
+
+                # Add foot data to keypoints
+                obj['keypoints'].extend(obj['foot_kpts'])
+
+                if self.use_gt_bboxes:
+                    """ COCO pre-processing
+
+                    - Moved above
+                    - Skip non-person objects (it should never happen)
+                    if obj['category_id'] != 1:
+                        continue
+
+                    # ignore objs without keypoints annotation
+                    if max(obj['keypoints']) == 0:
+                        continue
+                    """
+
+                    # Not all joints are already present, skip them
+                    vjoints = list(range(self.num_joints))
+                    vjoints.remove(5)
+                    vjoints.remove(14)
+
+                    for idx, pt in enumerate(vjoints):
+                        if pt == 5 or pt == 14:
+                            continue  # Neck and hip are manually filled
+                        joints[pt, 0] = obj['keypoints'][idx * 3 + 0]
+                        joints[pt, 1] = obj['keypoints'][idx * 3 + 1]
+                        t_vis = int(np.clip(obj['keypoints'][idx * 3 + 2], 0, 1))
+                        """
+                        - COCO:
+                        if visibility == 0 -> keypoint is not in the image.
+                        if visibility == 1 -> keypoint is in the image BUT not visible
+                        (e.g. behind an object).
+                        if visibility == 2 -> keypoint looks clearly
+                        (i.e. it is not hidden).
+                        """
+                        joints_visibility[pt, 0] = t_vis
+                        joints_visibility[pt, 1] = t_vis
+
+                center, scale = self._box2cs(obj['clean_bbox'][:4])
+
+                # Add neck and c-hip (check utils/visualization.py for keypoints)
+                joints[5, 0] = (joints[6, 0] + joints[7, 0]) / 2
+                joints[5, 1] = (joints[6, 1] + joints[7, 1]) / 2
+                joints_visibility[5, :] = min(joints_visibility[6, 0],
+                                              joints_visibility[7, 0])
+                joints[14, 0] = (joints[12, 0] + joints[13, 0]) / 2
+                joints[14, 1] = (joints[12, 1] + joints[13, 1]) / 2
+                joints_visibility[14, :] = min(joints_visibility[12, 0],
+                                               joints_visibility[13, 0])
+
+                self.data.append({
+                    'imgId': imgId,
+                    'annId': obj['id'],
+                    'imgPath': f"{self.root_path}/{self.data_version}/{imgId:012d}.jpg",
+                    'center': center,
+                    'scale': scale,
+                    'joints': joints,
+                    'joints_visibility': joints_visibility,
+                })
+
+        # Done check if we need prepare_data -> We should not
+        print('\nCOCO dataset loaded!')
+
+        # Default values
+        self.bbox_thre = 1.0
+        self.image_thre = 0.0
+        self.in_vis_thre = 0.2
+        self.nms_thre = 1.0
+        self.oks_thre = 0.9
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        # index = 0
+        joints_data = self.data[index].copy()
+
+        # Load image
+        try:
+            image = np.array(Image.open(joints_data['imgPath']))
+            if image.ndim == 2:
+                # Some images are grayscale and will fail the trasform, convert to RGB
+                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        except:
+            raise ValueError(f"Fail to read {joints_data['imgPath']}")
+
+        joints = joints_data['joints']
+        joints_vis = joints_data['joints_visibility']
+
+        c = joints_data['center']
+        s = joints_data['scale']
+        score = joints_data['score'] if 'score' in joints_data else 1
+        r = 0
+
+        # Apply data augmentation
+        if self.is_train:
+            if (self.half_body_prob and random.random() < self.half_body_prob and
+                    np.sum(joints_vis[:, 0]) > self.num_joints_half_body):
+                c_half_body, s_half_body = self._half_body_transform(joints, joints_vis)
+
+                if c_half_body is not None and s_half_body is not None:
+                    c, s = c_half_body, s_half_body
+
+            sf = self.scale_factor
+            rf = self.rotation_factor
+
+            if self.scale:
+                # A random scale factor in [1 - sf, 1 + sf]
+                s = s * np.clip(random.random() * sf + 1, 1 - sf, 1 + sf)
+
+            if self.rotate_prob and random.random() < self.rotate_prob:
+                # A random rotation factor in [-2 * rf, 2 * rf]
+                r = np.clip(random.random() * rf, -rf * 2, rf * 2)
+            else:
+                r = 0
+
+            if self.flip_prob and random.random() < self.flip_prob:
+                image = image[:, ::-1, :]
+                joints, joints_vis = fliplr_joints(joints, joints_vis,
+                                                   image.shape[1],
+                                                   self.flip_pairs)
+                c[0] = image.shape[1] - c[0] - 1
+
+        # Apply affine transform on joints and image
+        trans = get_affine_transform(c, s, self.pixel_std, r, self.image_size)
+        image = cv2.warpAffine(
+            image,
+            trans,
+            (int(self.image_size[0]), int(self.image_size[1])),
+            flags=cv2.INTER_LINEAR
+        )
+
+        for i in range(self.num_joints):
+            if joints_vis[i, 0] > 0.:
+                joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
+
+        # Convert image to tensor and normalize
+        if self.transform is not None:  # I could remove this check
+            image = self.transform(image)
+
+        target, target_weight = self._generate_target(joints, joints_vis)
+
+        # Update metadata
+        joints_data['joints'] = joints
+        joints_data['joints_visibility'] = joints_vis
+        joints_data['center'] = c
+        joints_data['scale'] = s
+        joints_data['rotation'] = r
+        joints_data['score'] = score
+
+        # from utils.visualization import draw_points_and_skeleton, joints_dict
+        # image = np.rollaxis(image.detach().cpu().numpy(), 0, 3)
+        # joints = np.hstack((joints[:, ::-1], joints_vis[:, 0][..., None]))
+        # image = draw_points_and_skeleton(image.copy(), joints,
+        #                                  joints_dict()['coco']['skeleton'],
+        #                                  person_index=0,
+        #                                  points_color_palette='gist_rainbow',
+        #                                  skeleton_color_palette='jet',
+        #                                  points_palette_samples=10,
+        #                                  confidence_threshold=0.4)
+        # cv2.imshow('', image)
+        # cv2.waitKey(0)
+
+        return image, target.astype(np.float32), target_weight.astype(np.float32), joints_data
+
+
+    # Private methods
+    def _box2cs(self, box):
+        x, y, w, h = box[:4]
+        return self._xywh2cs(x, y, w, h)
+
+    def _xywh2cs(self, x, y, w, h):
+        center = np.zeros((2,), dtype=np.float32)
+        center[0] = x + w * 0.5
+        center[1] = y + h * 0.5
+
+        if w > self.aspect_ratio * h:
+            h = w * 1.0 / self.aspect_ratio
+        elif w < self.aspect_ratio * h:
+            w = h * self.aspect_ratio
+        scale = np.array(
+            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
+            dtype=np.float32)
+        if center[0] != -1:
+            scale = scale * 1.25
+
+        return center, scale
+
+    def _half_body_transform(self, joints, joints_vis):
+        upper_joints = []
+        lower_joints = []
+        for joint_id in range(self.num_joints):
+            if joints_vis[joint_id][0] > 0:
+                if joint_id in self.upper_body_ids:
+                    upper_joints.append(joints[joint_id])
+                else:
+                    lower_joints.append(joints[joint_id])
+
+        if random.random() < 0.5 and len(upper_joints) > 2:
+            selected_joints = upper_joints
+        else:
+            selected_joints = lower_joints \
+                if len(lower_joints) > 2 else upper_joints
+
+        if len(selected_joints) < 2:
+            return None, None
+
+        selected_joints = np.array(selected_joints, dtype=np.float32)
+        center = selected_joints.mean(axis=0)[:2]
+
+        left_top = np.amin(selected_joints, axis=0)
+        right_bottom = np.amax(selected_joints, axis=0)
+
+        w = right_bottom[0] - left_top[0]
+        h = right_bottom[1] - left_top[1]
+
+        if w > self.aspect_ratio * h:
+            h = w * 1.0 / self.aspect_ratio
+        elif w < self.aspect_ratio * h:
+            w = h * self.aspect_ratio
+
+        scale = np.array(
+            [
+                w * 1.0 / self.pixel_std,
+                h * 1.0 / self.pixel_std
+            ],
+            dtype=np.float32
+        )
+
+        scale = scale * 1.5
+
+        return center, scale
+
+    def _generate_target(self, joints, joints_vis):
+        """
+        :param joints: [num_joints, 2]
+        :param joints_vis: [num_joints, 2]
+        :return: target, target_weight(1: visible, 0: invisible)
+        """
+        target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
+        target_weight[:, 0] = joints_vis[:, 0]
+
+        if self.heatmap_type == 'gaussian':
+            target = np.zeros((self.num_joints,
+                               self.heatmap_size[1],
+                               self.heatmap_size[0]),
+                              dtype=np.float32)
+
+            tmp_size = self.heatmap_sigma * 3
+
+            for joint_id in range(self.num_joints):
+                feat_stride = np.asarray(self.image_size) / np.asarray(self.heatmap_size)
+                mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
+                mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
+                # Check that any part of the gaussian is in-bounds
+                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
+                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
+                if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
+                        or br[0] < 0 or br[1] < 0:
+                    # If not, just return the image as is
+                    target_weight[joint_id] = 0
+                    continue
+
+                # # Generate gaussian
+                size = 2 * tmp_size + 1
+                x = np.arange(0, size, 1, np.float32)
+                y = x[:, np.newaxis]
+                x0 = y0 = size // 2
+                # The gaussian is not normalized, we want the center value to equal 1
+                g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.heatmap_sigma ** 2))
+
+                # Usable gaussian range
+                g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
+                g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
+                # Image range
+                img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
+                img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
+
+                v = target_weight[joint_id]
+                if v > 0.5:
+                    target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
+                        g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+        else:
+            raise NotImplementedError
+
+        if self.use_different_joints_weight:
+            target_weight = np.multiply(target_weight, self.joints_weight)
+
+        return target, target_weight
+
+    def _write_coco_keypoint_results(self, keypoints, res_file):
+        data_pack = [
+            {
+                'cat_id': 1,  # 1 == 'person'
+                'cls': 'person',
+                'ann_type': 'keypoints',
+                'keypoints': keypoints
+            }
+        ]
+
+        results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
+        with open(res_file, 'w') as f:
+            json.dump(results, f, sort_keys=True, indent=4)
+        try:
+            json.load(open(res_file))
+        except Exception:
+            content = []
+            with open(res_file, 'r') as f:
+                for line in f:
+                    content.append(line)
+            content[-1] = ']'
+            with open(res_file, 'w') as f:
+                for c in content:
+                    f.write(c)
+
+    def _coco_keypoint_results_one_category_kernel(self, data_pack):
+        cat_id = data_pack['cat_id']
+        keypoints = data_pack['keypoints']
+        cat_results = []
+
+        for img_kpts in keypoints:
+            if len(img_kpts) == 0:
+                continue
+
+            _key_points = np.array([img_kpts[k]['keypoints'] for k in range(len(img_kpts))], dtype=np.float32)
+            key_points = np.zeros((_key_points.shape[0], self.num_joints * 3), dtype=np.float32)
+
+            for ipt in range(self.num_joints):
+                key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]
+                key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]
+                key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2]  # keypoints score.
+
+            result = [
+                {
+                    'image_id': img_kpts[k]['image'],
+                    'category_id': cat_id,
+                    'keypoints': list(key_points[k]),
+                    'score': img_kpts[k]['score'].astype(np.float32),
+                    'center': list(img_kpts[k]['center']),
+                    'scale': list(img_kpts[k]['scale'])
+                }
+                for k in range(len(img_kpts))
+            ]
+            cat_results.extend(result)
+
+        return cat_results
+
+
+if __name__ == '__main__':
+    # from skimage import io
+    coco = COCODataset(root_path=f"{os.path.dirname(__file__)}/COCO", data_version="traincoex", rotate_prob=0., half_body_prob=0.)
+    item = coco[1]
+    # io.imsave("tmp.jpg", item[0].permute(1,2,0).numpy())
+    print()
+    print(item[1].shape)
+    print('ok!!')
+    # img = np.clip(np.transpose(item[0].numpy(), (1, 2, 0))[:, :, ::-1] * np.asarray([0.229, 0.224, 0.225]) +
+    #               np.asarray([0.485, 0.456, 0.406]), 0, 1) * 255
+    # cv2.imwrite('./tmp.png', img.astype(np.uint8))
+    # print(item[-1])
+    pass
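For reference, here is a standalone sketch of the Gaussian target construction that _generate_target performs per joint. It is simplified: it stamps the unnormalized Gaussian over the whole grid and skips the window clipping and visibility handling of the original; the default sizes mirror the class defaults (288x384 image, heatmaps at 1/4 resolution).

import numpy as np

def gaussian_heatmap(joint_xy, image_size=(288, 384), heatmap_size=(72, 96), sigma=3):
    # image_size / heatmap_size are (width, height); joint_xy is (x, y) in image pixels
    stride = np.asarray(image_size) / np.asarray(heatmap_size)
    mu_x = int(joint_xy[0] / stride[0] + 0.5)
    mu_y = int(joint_xy[1] / stride[1] + 0.5)
    xs = np.arange(heatmap_size[0])            # heatmap columns
    ys = np.arange(heatmap_size[1])[:, None]   # heatmap rows
    # Unnormalized Gaussian: the value at the joint location is exactly 1
    return np.exp(-((xs - mu_x) ** 2 + (ys - mu_y) ** 2) / (2 * sigma ** 2))

hm = gaussian_heatmap((144.0, 192.0))
print(hm.shape, hm.max())   # (96, 72), peak value 1.0 at the joint location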
easy_ViTPose/datasets/HumanPoseEstimation.py
ADDED
@@ -0,0 +1,17 @@
+from torch.utils.data import Dataset
+
+
+class HumanPoseEstimationDataset(Dataset):
+    """
+    HumanPoseEstimationDataset class.
+
+    Generic class for HPE datasets.
+    """
+    def __init__(self):
+        pass
+
+    def __len__(self):
+        pass
+
+    def __getitem__(self, item):
+        pass
easy_ViTPose/datasets/__init__.py
ADDED
File without changes
easy_ViTPose/easy_ViTPose.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,4 @@
+Metadata-Version: 2.1
+Name: easy-ViTPose
+Version: 0.1
+License-File: LICENSE
easy_ViTPose/easy_ViTPose.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,35 @@
LICENSE
README.md
setup.py
src/easy_ViTPose.egg-info/PKG-INFO
src/easy_ViTPose.egg-info/SOURCES.txt
src/easy_ViTPose.egg-info/dependency_links.txt
src/easy_ViTPose.egg-info/top_level.txt
src/vit_models/__init__.py
src/vit_models/model.py
src/vit_models/optimizer.py
src/vit_models/losses/__init__.py
src/vit_models/losses/classfication_loss.py
src/vit_models/losses/heatmap_loss.py
src/vit_models/losses/mesh_loss.py
src/vit_models/losses/mse_loss.py
src/vit_models/losses/multi_loss_factory.py
src/vit_models/losses/regression_loss.py
src/vit_utils/__init__.py
src/vit_utils/dist_util.py
src/vit_utils/inference.py
src/vit_utils/logging.py
src/vit_utils/top_down_eval.py
src/vit_utils/train_valid_fn.py
src/vit_utils/transform.py
src/vit_utils/util.py
src/vit_utils/visualization.py
src/vit_utils/nms/__init__.py
src/vit_utils/nms/nms.py
src/vit_utils/nms/nms_ori.py
src/vit_utils/nms/setup_linux.py
src/vit_utils/post_processing/__init__.py
src/vit_utils/post_processing/group.py
src/vit_utils/post_processing/nms.py
src/vit_utils/post_processing/one_euro_filter.py
src/vit_utils/post_processing/post_transforms.py
easy_ViTPose/easy_ViTPose.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
easy_ViTPose/easy_ViTPose.egg-info/top_level.txt
ADDED
@@ -0,0 +1,2 @@
vit_models
vit_utils
easy_ViTPose/inference.py
ADDED
@@ -0,0 +1,334 @@
import abc
import os
from typing import Optional
import typing

import cv2
import numpy as np
import torch

from ultralytics import YOLO

from .configs.ViTPose_common import data_cfg
from .sort import Sort
from .vit_models.model import ViTPose
from .vit_utils.inference import draw_bboxes, pad_image
from .vit_utils.top_down_eval import keypoints_from_heatmaps
from .vit_utils.util import dyn_model_import, infer_dataset_by_path
from .vit_utils.visualization import draw_points_and_skeleton, joints_dict

try:
    import torch_tensorrt
except ModuleNotFoundError:
    pass

try:
    import onnxruntime
except ModuleNotFoundError:
    pass

__all__ = ['VitInference']
np.bool = np.bool_
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


DETC_TO_YOLO_YOLOC = {
    'human': [0],
    'cat': [15],
    'dog': [16],
    'horse': [17],
    'sheep': [18],
    'cow': [19],
    'elephant': [20],
    'bear': [21],
    'zebra': [22],
    'giraffe': [23],
    'animals': [15, 16, 17, 18, 19, 20, 21, 22, 23]
}


class VitInference:
    """
    Class for performing inference using ViTPose models with YOLOv8 human detection and SORT tracking.

    Args:
        model (str): Path to the ViT model file (.pth, .onnx, .engine).
        yolo (str): Path of the YOLOv8 model to load.
        model_name (str, optional): Name of the ViT model architecture to use.
                                    Valid values are 's', 'b', 'l', 'h'.
                                    Defaults to None, is necessary when using .pth checkpoints.
        det_class (str, optional): the detection class. if None it is inferred by the dataset.
                                   valid values are 'human', 'cat', 'dog', 'horse', 'sheep',
                                   'cow', 'elephant', 'bear', 'zebra', 'giraffe',
                                   'animals' (which is all previous but human)
        dataset (str, optional): Name of the dataset. If None it's extracted from the file name.
                                 Valid values are 'coco', 'coco_25', 'wholebody', 'mpii',
                                 'ap10k', 'apt36k', 'aic'
        yolo_size (int, optional): Size of the input image for YOLOv8 model. Defaults to 320.
        device (str, optional): Device to use for inference. Defaults to 'cuda' if available, else 'cpu'.
        is_video (bool, optional): Flag indicating if the input is video. Defaults to False.
        single_pose (bool, optional): Flag indicating if the video (on images this flag has no effect)
                                      will contain a single pose.
                                      In this case the SORT tracker is not used (increasing performance)
                                      but people id tracking won't be consistent among frames.
        yolo_step (int, optional): The tracker can be used to predict the bboxes instead of yolo for performance,
                                   this flag specifies how often yolo is applied (e.g. 1 applies yolo every frame).
                                   This does not have any effect when is_video is False.
    """

    def __init__(self, model: str,
                 yolo: str,
                 model_name: Optional[str] = None,
                 det_class: Optional[str] = None,
                 dataset: Optional[str] = None,
                 yolo_size: Optional[int] = 320,
                 device: Optional[str] = None,
                 is_video: Optional[bool] = False,
                 single_pose: Optional[bool] = False,
                 yolo_step: Optional[int] = 1):
        assert os.path.isfile(model), f'The model file {model} does not exist'
        assert os.path.isfile(yolo), f'The YOLOv8 model {yolo} does not exist'

        # Device priority is cuda / mps / cpu
        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'

        self.device = device
        self.yolo = YOLO(yolo, task='detect')
        self.yolo_size = yolo_size
        self.yolo_step = yolo_step
        self.is_video = is_video
        self.single_pose = single_pose
        self.reset()

        # State saving during inference
        self.save_state = True  # Can be disabled manually
        self._img = None
        self._yolo_res = None
        self._tracker_res = None
        self._keypoints = None

        # Use extension to decide which kind of model has been loaded
        use_onnx = model.endswith('.onnx')
        use_trt = model.endswith('.engine')

        # Extract dataset name
        if dataset is None:
            dataset = infer_dataset_by_path(model)

        assert dataset in ['mpii', 'coco', 'coco_25', 'wholebody', 'aic', 'ap10k', 'apt36k'], \
            'The specified dataset is not valid'

        # Dataset can now be set for visualization
        self.dataset = dataset

        # if we picked the dataset switch to correct yolo classes if not set
        if det_class is None:
            det_class = 'animals' if dataset in ['ap10k', 'apt36k'] else 'human'
        self.yolo_classes = DETC_TO_YOLO_YOLOC[det_class]

        assert model_name in [None, 's', 'b', 'l', 'h'], \
            f'The model name {model_name} is not valid'

        # onnx / trt models do not require model_cfg specification
        if model_name is None:
            assert use_onnx or use_trt, \
                'Specify the model_name if not using onnx / trt'
        else:
            # Dynamically import the model class
            model_cfg = dyn_model_import(self.dataset, model_name)

        self.target_size = data_cfg['image_size']
        if use_onnx:
            self._ort_session = onnxruntime.InferenceSession(model,
                                                             providers=['CUDAExecutionProvider',
                                                                        'CPUExecutionProvider'])
            inf_fn = self._inference_onnx
        else:
            self._vit_pose = ViTPose(model_cfg)
            self._vit_pose.eval()

            if use_trt:
                self._vit_pose = torch.jit.load(model)
            else:
                ckpt = torch.load(model, map_location='cpu')
                if 'state_dict' in ckpt:
                    self._vit_pose.load_state_dict(ckpt['state_dict'])
                else:
                    self._vit_pose.load_state_dict(ckpt)
                self._vit_pose.to(torch.device(device))

            inf_fn = self._inference_torch

        # Override _inference abstract with selected engine
        self._inference = inf_fn  # type: ignore

    def reset(self):
        """
        Reset the inference class to be ready for a new video.
        This will reset the internal counter of frames, on videos
        this is necessary to reset the tracker.
        """
        min_hits = 3 if self.yolo_step == 1 else 1
        use_tracker = self.is_video and not self.single_pose
        self.tracker = Sort(max_age=self.yolo_step,
                            min_hits=min_hits,
                            iou_threshold=0.3) if use_tracker else None  # TODO: Params
        self.frame_counter = 0

    @classmethod
    def postprocess(cls, heatmaps, org_w, org_h):
        """
        Postprocess the heatmaps to obtain keypoints and their probabilities.

        Args:
            heatmaps (ndarray): Heatmap predictions from the model.
            org_w (int): Original width of the image.
            org_h (int): Original height of the image.

        Returns:
            ndarray: Processed keypoints with probabilities.
        """
        points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,
                                               center=np.array([[org_w // 2,
                                                                 org_h // 2]]),
                                               scale=np.array([[org_w, org_h]]),
                                               unbiased=True, use_udp=True)
        return np.concatenate([points[:, :, ::-1], prob], axis=2)

    @abc.abstractmethod
    def _inference(self, img: np.ndarray) -> np.ndarray:
        """
        Abstract method for performing inference on an image.
        It is overloaded by each inference engine.

        Args:
            img (ndarray): Input image for inference.

        Returns:
            ndarray: Inference results.
        """
        raise NotImplementedError

    def inference(self, img: np.ndarray) -> dict[typing.Any, typing.Any]:
        """
        Perform inference on the input image.

        Args:
            img (ndarray): Input image for inference in RGB format.

        Returns:
            dict[typing.Any, typing.Any]: Inference results.
        """

        # First use YOLOv8 for detection
        res_pd = np.empty((0, 5))
        results = None
        if (self.tracker is None or
                (self.frame_counter % self.yolo_step == 0 or self.frame_counter < 3)):
            results = self.yolo(img, verbose=False, imgsz=self.yolo_size,
                                device=self.device if self.device != 'cuda' else 0,
                                classes=self.yolo_classes)[0]
            res_pd = np.array([r[:5].tolist() for r in  # TODO: Confidence threshold
                               results.boxes.data.cpu().numpy() if r[4] > 0.35]).reshape((-1, 5))
        self.frame_counter += 1

        frame_keypoints = {}
        ids = None
        if self.tracker is not None:
            res_pd = self.tracker.update(res_pd)
            ids = res_pd[:, 5].astype(int).tolist()

        # Prepare boxes for inference
        bboxes = res_pd[:, :4].round().astype(int)
        scores = res_pd[:, 4].tolist()
        pad_bbox = 10

        if ids is None:
            ids = range(len(bboxes))

        for bbox, id in zip(bboxes, ids):
            # TODO: Slightly bigger bbox
            bbox[[0, 2]] = np.clip(bbox[[0, 2]] + [-pad_bbox, pad_bbox], 0, img.shape[1])
            bbox[[1, 3]] = np.clip(bbox[[1, 3]] + [-pad_bbox, pad_bbox], 0, img.shape[0])

            # Crop image and pad to 3/4 aspect ratio
            img_inf = img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
            img_inf, (left_pad, top_pad) = pad_image(img_inf, 3 / 4)

            keypoints = self._inference(img_inf)[0]
            # Transform keypoints to original image
            keypoints[:, :2] += bbox[:2][::-1] - [top_pad, left_pad]
            frame_keypoints[id] = keypoints

        if self.save_state:
            self._img = img
            self._yolo_res = results
            self._tracker_res = (bboxes, ids, scores)
            self._keypoints = frame_keypoints

        return frame_keypoints

    def draw(self, show_yolo=True, show_raw_yolo=False, confidence_threshold=0.5):
        """
        Draw keypoints and bounding boxes on the image.

        Args:
            show_yolo (bool, optional): Whether to show YOLOv8 bounding boxes. Default is True.
            show_raw_yolo (bool, optional): Whether to show raw YOLOv8 bounding boxes. Default is False.

        Returns:
            ndarray: Image with keypoints and bounding boxes drawn.
        """
        img = self._img.copy()
        bboxes, ids, scores = self._tracker_res

        if self._yolo_res is not None and (show_raw_yolo or (self.tracker is None and show_yolo)):
            img = np.array(self._yolo_res.plot())

        if show_yolo and self.tracker is not None:
            img = draw_bboxes(img, bboxes, ids, scores)

        img = np.array(img)[..., ::-1]  # RGB to BGR for cv2 modules
        for idx, k in self._keypoints.items():
            img = draw_points_and_skeleton(img.copy(), k,
                                           joints_dict()[self.dataset]['skeleton'],
                                           person_index=idx,
                                           points_color_palette='gist_rainbow',
                                           skeleton_color_palette='jet',
                                           points_palette_samples=10,
                                           confidence_threshold=confidence_threshold)
        return img[..., ::-1]  # Return RGB as original

    def pre_img(self, img):
        org_h, org_w = img.shape[:2]
        img_input = cv2.resize(img, self.target_size, interpolation=cv2.INTER_LINEAR) / 255
        img_input = ((img_input - MEAN) / STD).transpose(2, 0, 1)[None].astype(np.float32)
        return img_input, org_h, org_w

    @torch.no_grad()
    def _inference_torch(self, img: np.ndarray) -> np.ndarray:
        # Prepare input data
        img_input, org_h, org_w = self.pre_img(img)
        img_input = torch.from_numpy(img_input).to(torch.device(self.device))

        # Feed to model
        heatmaps = self._vit_pose(img_input).detach().cpu().numpy()
        return self.postprocess(heatmaps, org_w, org_h)

    def _inference_onnx(self, img: np.ndarray) -> np.ndarray:
        # Prepare input data
        img_input, org_h, org_w = self.pre_img(img)

        # Feed to model
        ort_inputs = {self._ort_session.get_inputs()[0].name: img_input}
        heatmaps = self._ort_session.run(None, ort_inputs)[0]
        return self.postprocess(heatmaps, org_w, org_h)
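For context, a minimal usage sketch of the class above on a single image. The checkpoint and YOLO weight paths are hypothetical placeholders, not files added by this commit; the keypoint layout follows postprocess(), which returns (y, x, confidence) per joint.

import cv2
from easy_ViTPose.inference import VitInference

# Hypothetical paths: substitute your own ViTPose model and YOLOv8 weights.
model = VitInference('vitpose-b-coco.onnx', 'yolov8n.pt', dataset='coco', is_video=False)

img = cv2.cvtColor(cv2.imread('example.jpg'), cv2.COLOR_BGR2RGB)  # the class expects RGB input
keypoints = model.inference(img)      # dict: detection id -> (num_joints, 3) array of (y, x, conf)
annotated = model.draw(show_yolo=True)  # RGB image with boxes and skeletons drawn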
easy_ViTPose/sort.py
ADDED
@@ -0,0 +1,266 @@
"""
SORT: A Simple, Online and Realtime Tracker
Copyright (C) 2016-2020 Alex Bewley alex@bewley.ai

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import print_function

import os
import numpy as np
import matplotlib

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage import io

import glob
import time
import argparse
from filterpy.kalman import KalmanFilter

np.random.seed(0)


def linear_assignment(cost_matrix):
    try:
        import lap
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
        return np.array([[y[i], i] for i in x if i >= 0])
    except ImportError:
        from scipy.optimize import linear_sum_assignment
        x, y = linear_sum_assignment(cost_matrix)
        return np.array(list(zip(x, y)))


def iou_batch(bb_test, bb_gt):
    """
    From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
    """
    bb_gt = np.expand_dims(bb_gt, 0)
    bb_test = np.expand_dims(bb_test, 1)

    xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
    yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
    xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
    yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
              + (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
    return(o)


def convert_bbox_to_z(bbox):
    """
    Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
    [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
    the aspect ratio
    """
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    x = bbox[0] + w/2.
    y = bbox[1] + h/2.
    s = w * h  # scale is just area
    r = w / float(h)
    return np.array([x, y, s, r]).reshape((4, 1))


def convert_x_to_bbox(x, score=None):
    """
    Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
    [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
    """
    w = np.sqrt(x[2] * x[3])
    h = x[2] / w
    if(score == None):
        return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2.]).reshape((1, 4))
    else:
        return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2., score]).reshape((1, 5))


class KalmanBoxTracker(object):
    """
    This class represents the internal state of individual tracked objects observed as bbox.
    """
    count = 0

    def __init__(self, bbox, score):
        """
        Initialises a tracker using initial bounding box.
        """
        # define constant velocity model
        self.kf = KalmanFilter(dim_x=7, dim_z=4)
        self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 1], [
            0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1]])
        self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])

        self.kf.R[2:, 2:] *= 10.
        self.kf.P[4:, 4:] *= 1000.  # give high uncertainty to the unobservable initial velocities
        self.kf.P *= 10.
        self.kf.Q[-1, -1] *= 0.01
        self.kf.Q[4:, 4:] *= 0.01

        self.kf.x[:4] = convert_bbox_to_z(bbox)
        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []
        self.hits = 0
        self.hit_streak = 0
        self.age = 0
        self.score = score

    def update(self, bbox, score):
        """
        Updates the state vector with observed bbox.
        """
        self.time_since_update = 0
        self.history = []
        self.hits += 1
        self.hit_streak += 1
        self.kf.update(convert_bbox_to_z(bbox))
        self.score = score

    def predict(self):
        """
        Advances the state vector and returns the predicted bounding box estimate.
        """
        if((self.kf.x[6]+self.kf.x[2]) <= 0):
            self.kf.x[6] *= 0.0
        self.kf.predict()
        self.age += 1
        if(self.time_since_update > 0):
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(convert_x_to_bbox(self.kf.x))
        return self.history[-1]

    def get_state(self):
        """
        Returns the current bounding box estimate.
        """
        return convert_x_to_bbox(self.kf.x)


def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """
    Assigns detections to tracked object (both represented as bounding boxes)

    Returns 3 lists of matches, unmatched_detections and unmatched_trackers
    """
    if(len(trackers) == 0):
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    iou_matrix = iou_batch(detections, trackers)

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if(d not in matched_indices[:, 0]):
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if(t not in matched_indices[:, 1]):
            unmatched_trackers.append(t)

    # filter out matched with low IOU
    matches = []
    for m in matched_indices:
        if(iou_matrix[m[0], m[1]] < iou_threshold):
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if(len(matches) == 0):
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


class Sort(object):
    def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
        """
        Sets key parameters for SORT
        """
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.frame_count = 0

    def update(self, dets=np.empty((0, 5))):
        """
        Params:
          dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
        Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections).
        Returns a similar array, where the last column is the object ID.

        NOTE: The number of objects returned may differ from the number of detections provided.
        """
        self.frame_count += 1
        empty_dets = dets.shape[0] == 0

        # get predicted locations from existing trackers.
        trks = np.zeros((len(self.trackers), 5))
        to_del = []
        ret = []
        for t, trk in enumerate(trks):
            pos = self.trackers[t].predict()[0]
            trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            self.trackers.pop(t)
        matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold)

        # update matched trackers with assigned detections
        for m in matched:
            self.trackers[m[1]].update(dets[m[0], :], dets[m[0], -1])

        # create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            trk = KalmanBoxTracker(dets[i, :], dets[i, -1])
            self.trackers.append(trk)

        i = len(self.trackers)
        unmatched = []
        for trk in reversed(self.trackers):
            d = trk.get_state()[0]
            if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                ret.append(np.concatenate((d, [trk.score, trk.id+1])).reshape(1, -1))  # +1 as MOT benchmark requires positive
            i -= 1
            # remove dead tracklet
            if(trk.time_since_update > self.max_age):
                self.trackers.pop(i)
                if empty_dets:
                    unmatched.append(np.concatenate((d, [trk.score, trk.id + 1])).reshape(1, -1))

        if len(ret):
            return np.concatenate(ret)
        elif empty_dets:
            return np.concatenate(unmatched) if len(unmatched) else np.empty((0, 6))
        return np.empty((0, 6))
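As a quick illustration of the update protocol described in the docstring above, here is a small sketch with made-up detections; the key point is that update() is called once per frame, even when there are no detections.

import numpy as np

tracker = Sort(max_age=1, min_hits=3, iou_threshold=0.3)

# Frame with two detections: each row is [x1, y1, x2, y2, score]
dets = np.array([[10., 20., 110., 220., 0.9],
                 [300., 40., 380., 260., 0.8]])
tracks = tracker.update(dets)          # rows are [x1, y1, x2, y2, score, track_id]

# A frame without detections still requires a call
tracks = tracker.update(np.empty((0, 5)))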
easy_ViTPose/to_onnx.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
easy_ViTPose/to_trt.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
easy_ViTPose/train.py
ADDED
@@ -0,0 +1,174 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time
import warnings
import click
import yaml

from glob import glob

import torch
import torch.distributed as dist

from vit_utils.util import init_random_seed, set_random_seed
from vit_utils.dist_util import get_dist_info, init_dist
from vit_utils.logging import get_root_logger

import configs.ViTPose_small_coco_256x192 as s_cfg
import configs.ViTPose_base_coco_256x192 as b_cfg
import configs.ViTPose_large_coco_256x192 as l_cfg
import configs.ViTPose_huge_coco_256x192 as h_cfg

from vit_models.model import ViTPose
from datasets.COCO import COCODataset
from vit_utils.train_valid_fn import train_model

CUR_PATH = osp.dirname(__file__)

@click.command()
@click.option('--config-path', type=click.Path(exists=True), default='config.yaml', required=True, help='train config file path')
@click.option('--model-name', type=str, default='b', required=True, help='[b: ViT-B, l: ViT-L, h: ViT-H]')
def main(config_path, model_name):

    cfg = {'b': b_cfg,
           's': s_cfg,
           'l': l_cfg,
           'h': h_cfg}.get(model_name.lower())
    # Load config.yaml
    with open(config_path, 'r') as f:
        cfg_yaml = yaml.load(f, Loader=yaml.SafeLoader)

    for k, v in cfg_yaml.items():
        if hasattr(cfg, k):
            raise ValueError(f"Already exists {k} in config")
        else:
            cfg.__setattr__(k, v)

    # set cudnn_benchmark
    if cfg.cudnn_benchmark:
        torch.backends.cudnn.benchmark = True

    # Set work directory (session-level)
    if not hasattr(cfg, 'work_dir'):
        cfg.__setattr__('work_dir', f"{CUR_PATH}/runs/train")

    if not osp.exists(cfg.work_dir):
        os.makedirs(cfg.work_dir)
    session_list = sorted(glob(f"{cfg.work_dir}/*"))
    if len(session_list) == 0:
        session = 1
    else:
        session = int(os.path.basename(session_list[-1])) + 1
    session_dir = osp.join(cfg.work_dir, str(session).zfill(3))
    os.makedirs(session_dir)
    cfg.__setattr__('work_dir', session_dir)

    if cfg.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # init distributed env first, since logger depends on the dist info.
    if cfg.launcher == 'none':
        distributed = False
        if len(cfg.gpu_ids) > 1:
            warnings.warn(
                f"We treat {cfg['gpu_ids']} as gpu-ids, and reset to "
                f"{cfg['gpu_ids'][0:1]} as gpu-ids to avoid potential error in "
                "non-distribute training time.")
            cfg.gpu_ids = cfg.gpu_ids[0:1]
    else:
        distributed = True
        init_dist(cfg.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(session_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()

    # log some basic info
    logger.info(f'Distributed training: {distributed}')

    # set random seeds
    seed = init_random_seed(cfg.seed)
    logger.info(f"Set random seed to {seed}, "
                f"deterministic: {cfg.deterministic}")
    set_random_seed(seed, deterministic=cfg.deterministic)
    meta['seed'] = seed

    # Set model
    model = ViTPose(cfg.model)
    if cfg.resume_from:
        # Load ckpt partially
        ckpt_state = torch.load(cfg.resume_from)['state_dict']
        ckpt_state.pop('keypoint_head.final_layer.bias')
        ckpt_state.pop('keypoint_head.final_layer.weight')
        model.load_state_dict(ckpt_state, strict=False)

    # freeze the backbone, leave the head to be finetuned
    model.backbone.frozen_stages = model.backbone.depth - 1
    model.backbone.freeze_ffn = True
    model.backbone.freeze_attn = True
    model.backbone._freeze_stages()

    # Set dataset
    datasets_train = COCODataset(
        root_path=cfg.data_root,
        data_version="feet_train",
        is_train=True,
        use_gt_bboxes=True,
        image_width=192,
        image_height=256,
        scale=True,
        scale_factor=0.35,
        flip_prob=0.5,
        rotate_prob=0.5,
        rotation_factor=45.,
        half_body_prob=0.3,
        use_different_joints_weight=True,
        heatmap_sigma=3,
        soft_nms=False
    )

    datasets_valid = COCODataset(
        root_path=cfg.data_root,
        data_version="feet_val",
        is_train=False,
        use_gt_bboxes=True,
        image_width=192,
        image_height=256,
        scale=False,
        scale_factor=0.35,
        flip_prob=0.5,
        rotate_prob=0.5,
        rotation_factor=45.,
        half_body_prob=0.3,
        use_different_joints_weight=True,
        heatmap_sigma=3,
        soft_nms=False
    )

    train_model(
        model=model,
        datasets_train=datasets_train,
        datasets_valid=datasets_valid,
        cfg=cfg,
        distributed=distributed,
        validate=cfg.validate,
        timestamp=timestamp,
        meta=meta
    )


if __name__ == '__main__':
    main()
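To make the work-directory handling in main() concrete, here is a small standalone sketch of the same session-numbering scheme; the 'runs/train' path is a placeholder, not a directory created by this commit.

import os
import os.path as osp
from glob import glob

# Mirrors the logic above: existing 001, 002, ... subfolders -> next session is 003.
work_dir = 'runs/train'  # hypothetical work dir
os.makedirs(work_dir, exist_ok=True)
session_list = sorted(glob(f"{work_dir}/*"))
session = 1 if not session_list else int(os.path.basename(session_list[-1])) + 1
session_dir = osp.join(work_dir, str(session).zfill(3))
print(session_dir)  # e.g. runs/train/001 on a fresh tree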
easy_ViTPose/vit_models/__init__.py
ADDED
@@ -0,0 +1,8 @@
import sys
import os.path as osp

sys.path.append(osp.dirname(osp.dirname(__file__)))

from vit_utils.util import load_checkpoint, resize, constant_init, normal_init
from vit_utils.top_down_eval import keypoints_from_heatmaps, pose_pck_accuracy
from vit_utils.post_processing import *
easy_ViTPose/vit_models/backbone/__init__.py
ADDED
File without changes
|
easy_ViTPose/vit_models/backbone/vit.py
ADDED
@@ -0,0 +1,394 @@
import math
import warnings

from itertools import repeat
import collections.abc

import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import Tensor

# from timm.models.layers import drop_path, to_2tuple, trunc_normal_

# from .base_backbone import BaseBackbone

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple

def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
    applied while sampling the normal with mean/std applied, therefore a, b args
    should be adjusted to match the range of mean, std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return 'p={}'.format(self.drop_prob)

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None,):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.dim = dim

        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads

        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, attn_head_dim=None
                 ):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
        )

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
        self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
        self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))

    def forward(self, x):
        x = self.proj(x)
        B, C, Hp, Wp = x.shape
        x = x.view(B, C, Hp * Wp).transpose(1, 2)
        return x, (Hp, Wp)


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class ViT(nn.Module):
    def __init__(self,
                 img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
                 frozen_stages=-1, ratio=1, last_norm=True,
                 patch_padding='pad', freeze_attn=False, freeze_ffn=False,
                 ):
        super(ViT, self).__init__()
        # Protect mutable default arguments

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.frozen_stages = frozen_stages
        self.use_checkpoint = use_checkpoint
        self.patch_padding = patch_padding
        self.freeze_attn = freeze_attn
        self.freeze_ffn = freeze_ffn
        self.depth = depth

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
        num_patches = self.patch_embed.num_patches

        # since the pretraining model has class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            )
            for i in range(depth)])

        self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)

        self._freeze_stages()

    def _freeze_stages(self):
        """Freeze parameters."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = self.blocks[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        if self.freeze_attn:
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.attn.eval()
                m.norm1.eval()
                for param in m.attn.parameters():
                    param.requires_grad = False
                for param in m.norm1.parameters():
                    param.requires_grad = False

        if self.freeze_ffn:
            self.pos_embed.requires_grad = False
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.mlp.eval()
                m.norm2.eval()
                for param in m.mlp.parameters():
                    param.requires_grad = False
                for param in m.norm2.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        super().init_weights(pretrained, patch_padding=self.patch_padding)

        if pretrained is None:
            def _init_weights(m):
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if isinstance(m, nn.Linear) and m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.LayerNorm):
                    nn.init.constant_(m.bias, 0)
                    nn.init.constant_(m.weight, 1.0)

            self.apply(_init_weights)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x):
        B, C, H, W = x.shape
        x, (Hp, Wp) = self.patch_embed(x)

        if self.pos_embed is not None:
            # fit for multiple GPU training
            # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
            x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]

        for blk in self.blocks:
            x = blk(x)

        x = self.last_norm(x)
        x = x.permute(0, 2, 1).view(B, -1, Hp, Wp).contiguous()
        return x

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        self._freeze_stages()
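A small sketch of running the backbone defined above on a ViTPose-style input. The hyperparameters mirror a ViT-B configuration at a 256x192 crop; the exact values are assumptions for illustration, not taken from this commit's configs.

import torch

# ViT-B-like backbone at ViTPose's 256x192 input resolution (values assumed for illustration).
backbone = ViT(img_size=(256, 192), patch_size=16, embed_dim=768, depth=12,
               num_heads=12, mlp_ratio=4., qkv_bias=True, drop_path_rate=0.3)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 256, 192))

# forward() returns a spatial feature map (B, embed_dim, Hp, Wp) -- here roughly a
# 16x12 patch grid -- which the heatmap head then deconvolves back to heatmap size.
print(feats.shape)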
easy_ViTPose/vit_models/head/__init__.py
ADDED
File without changes
|
easy_ViTPose/vit_models/head/topdown_heatmap_base_head.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import numpy as np
import torch.nn as nn

from .. import keypoints_from_heatmaps


class TopdownHeatmapBaseHead(nn.Module):
    """Base class for top-down heatmap heads.

    All top-down heatmap heads should subclass it.
    All subclasses should overwrite:

    Methods:`get_loss`, supporting to calculate loss.
    Methods:`get_accuracy`, supporting to calculate accuracy.
    Methods:`forward`, supporting to forward model.
    Methods:`inference_model`, supporting to inference model.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def get_loss(self, **kwargs):
        """Gets the loss."""

    @abstractmethod
    def get_accuracy(self, **kwargs):
        """Gets the accuracy."""

    @abstractmethod
    def forward(self, **kwargs):
        """Forward function."""

    @abstractmethod
    def inference_model(self, **kwargs):
        """Inference function."""

    def decode(self, img_metas, output, **kwargs):
        """Decode keypoints from heatmaps.

        Args:
            img_metas (list(dict)): Information about data augmentation.
                By default this includes:

                - "image_file": path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
            output (np.ndarray[N, K, H, W]): model predicted heatmaps.
        """
        batch_size = len(img_metas)

        if 'bbox_id' in img_metas[0]:
            bbox_ids = []
        else:
            bbox_ids = None

        c = np.zeros((batch_size, 2), dtype=np.float32)
        s = np.zeros((batch_size, 2), dtype=np.float32)
        image_paths = []
        score = np.ones(batch_size)
        for i in range(batch_size):
            c[i, :] = img_metas[i]['center']
            s[i, :] = img_metas[i]['scale']
            image_paths.append(img_metas[i]['image_file'])

            if 'bbox_score' in img_metas[i]:
                score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1)
            if bbox_ids is not None:
                bbox_ids.append(img_metas[i]['bbox_id'])

        preds, maxvals = keypoints_from_heatmaps(
            output,
            c,
            s,
            unbiased=self.test_cfg.get('unbiased_decoding', False),
            post_process=self.test_cfg.get('post_process', 'default'),
            kernel=self.test_cfg.get('modulate_kernel', 11),
            valid_radius_factor=self.test_cfg.get('valid_radius_factor',
                                                  0.0546875),
            use_udp=self.test_cfg.get('use_udp', False),
            target_type=self.test_cfg.get('target_type', 'GaussianHeatmap'))

        all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32)
        all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
        all_preds[:, :, 0:2] = preds[:, :, 0:2]
        all_preds[:, :, 2:3] = maxvals
        all_boxes[:, 0:2] = c[:, 0:2]
        all_boxes[:, 2:4] = s[:, 0:2]
        all_boxes[:, 4] = np.prod(s * 200.0, axis=1)
        all_boxes[:, 5] = score

        result = {}

        result['preds'] = all_preds
        result['boxes'] = all_boxes
        result['image_paths'] = image_paths
        result['bbox_ids'] = bbox_ids

        return result

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers."""
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding
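The deconv configuration table above is easiest to verify by calling the static helper directly. A minimal sketch (it assumes the class is importable from this package; the printed values simply restate the branches in `_get_deconv_cfg`):

    # Each kernel size maps to (kernel, padding, output_padding) so that a
    # stride-2 ConvTranspose2d exactly doubles the spatial resolution.
    for k in (4, 3, 2):
        kernel, padding, output_padding = TopdownHeatmapBaseHead._get_deconv_cfg(k)
        print(kernel, padding, output_padding)  # 4 -> (4, 1, 0), 3 -> (3, 1, 1), 2 -> (2, 0, 0)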
easy_ViTPose/vit_models/head/topdown_heatmap_simple_head.py
ADDED
@@ -0,0 +1,334 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from .. import constant_init, normal_init
|
5 |
+
|
6 |
+
from .. import pose_pck_accuracy, flip_back, resize
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from .topdown_heatmap_base_head import TopdownHeatmapBaseHead
|
9 |
+
|
10 |
+
|
11 |
+
class TopdownHeatmapSimpleHead(TopdownHeatmapBaseHead):
|
12 |
+
"""Top-down heatmap simple head. paper ref: Bin Xiao et al. ``Simple
|
13 |
+
Baselines for Human Pose Estimation and Tracking``.
|
14 |
+
|
15 |
+
TopdownHeatmapSimpleHead is consisted of (>=0) number of deconv layers
|
16 |
+
and a simple conv2d layer.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
in_channels (int): Number of input channels
|
20 |
+
out_channels (int): Number of output channels
|
21 |
+
num_deconv_layers (int): Number of deconv layers.
|
22 |
+
num_deconv_layers should >= 0. Note that 0 means
|
23 |
+
no deconv layers.
|
24 |
+
num_deconv_filters (list|tuple): Number of filters.
|
25 |
+
If num_deconv_layers > 0, the length of
|
26 |
+
num_deconv_kernels (list|tuple): Kernel sizes.
|
27 |
+
in_index (int|Sequence[int]): Input feature index. Default: 0
|
28 |
+
input_transform (str|None): Transformation type of input features.
|
29 |
+
Options: 'resize_concat', 'multiple_select', None.
|
30 |
+
Default: None.
|
31 |
+
|
32 |
+
- 'resize_concat': Multiple feature maps will be resized to the
|
33 |
+
same size as the first one and then concat together.
|
34 |
+
Usually used in FCN head of HRNet.
|
35 |
+
- 'multiple_select': Multiple feature maps will be bundle into
|
36 |
+
a list and passed into decode head.
|
37 |
+
- None: Only one select feature map is allowed.
|
38 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
39 |
+
Default: False.
|
40 |
+
loss_keypoint (dict): Config for keypoint loss. Default: None.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self,
|
44 |
+
in_channels,
|
45 |
+
out_channels,
|
46 |
+
num_deconv_layers=3,
|
47 |
+
num_deconv_filters=(256, 256, 256),
|
48 |
+
num_deconv_kernels=(4, 4, 4),
|
49 |
+
extra=None,
|
50 |
+
in_index=0,
|
51 |
+
input_transform=None,
|
52 |
+
align_corners=False,
|
53 |
+
loss_keypoint=None,
|
54 |
+
train_cfg=None,
|
55 |
+
test_cfg=None,
|
56 |
+
upsample=0,):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
self.in_channels = in_channels
|
60 |
+
self.loss = loss_keypoint
|
61 |
+
self.upsample = upsample
|
62 |
+
|
63 |
+
self.train_cfg = {} if train_cfg is None else train_cfg
|
64 |
+
self.test_cfg = {} if test_cfg is None else test_cfg
|
65 |
+
self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')
|
66 |
+
|
67 |
+
self._init_inputs(in_channels, in_index, input_transform)
|
68 |
+
self.in_index = in_index
|
69 |
+
self.align_corners = align_corners
|
70 |
+
|
71 |
+
if extra is not None and not isinstance(extra, dict):
|
72 |
+
raise TypeError('extra should be dict or None.')
|
73 |
+
|
74 |
+
if num_deconv_layers > 0:
|
75 |
+
self.deconv_layers = self._make_deconv_layer(
|
76 |
+
num_deconv_layers,
|
77 |
+
num_deconv_filters,
|
78 |
+
num_deconv_kernels,
|
79 |
+
)
|
80 |
+
elif num_deconv_layers == 0:
|
81 |
+
self.deconv_layers = nn.Identity()
|
82 |
+
else:
|
83 |
+
raise ValueError(
|
84 |
+
f'num_deconv_layers ({num_deconv_layers}) should >= 0.')
|
85 |
+
|
86 |
+
identity_final_layer = False
|
87 |
+
if extra is not None and 'final_conv_kernel' in extra:
|
88 |
+
assert extra['final_conv_kernel'] in [0, 1, 3]
|
89 |
+
if extra['final_conv_kernel'] == 3:
|
90 |
+
padding = 1
|
91 |
+
elif extra['final_conv_kernel'] == 1:
|
92 |
+
padding = 0
|
93 |
+
else:
|
94 |
+
# 0 for Identity mapping.
|
95 |
+
identity_final_layer = True
|
96 |
+
kernel_size = extra['final_conv_kernel']
|
97 |
+
else:
|
98 |
+
kernel_size = 1
|
99 |
+
padding = 0
|
100 |
+
|
101 |
+
if identity_final_layer:
|
102 |
+
self.final_layer = nn.Identity()
|
103 |
+
else:
|
104 |
+
conv_channels = num_deconv_filters[
|
105 |
+
-1] if num_deconv_layers > 0 else self.in_channels
|
106 |
+
|
107 |
+
layers = []
|
108 |
+
if extra is not None:
|
109 |
+
num_conv_layers = extra.get('num_conv_layers', 0)
|
110 |
+
num_conv_kernels = extra.get('num_conv_kernels',
|
111 |
+
[1] * num_conv_layers)
|
112 |
+
|
113 |
+
for i in range(num_conv_layers):
|
114 |
+
layers.append(
|
115 |
+
nn.Conv2d(in_channels=conv_channels,
|
116 |
+
out_channels=conv_channels,
|
117 |
+
kernel_size=num_conv_kernels[i],
|
118 |
+
stride=1,
|
119 |
+
padding=(num_conv_kernels[i] - 1) // 2)
|
120 |
+
)
|
121 |
+
layers.append(nn.BatchNorm2d(conv_channels))
|
122 |
+
layers.append(nn.ReLU(inplace=True))
|
123 |
+
|
124 |
+
layers.append(
|
125 |
+
nn.Conv2d(in_channels=conv_channels,
|
126 |
+
out_channels=out_channels,
|
127 |
+
kernel_size=kernel_size,
|
128 |
+
stride=1,
|
129 |
+
padding=padding)
|
130 |
+
)
|
131 |
+
|
132 |
+
if len(layers) > 1:
|
133 |
+
self.final_layer = nn.Sequential(*layers)
|
134 |
+
else:
|
135 |
+
self.final_layer = layers[0]
|
136 |
+
|
137 |
+
def get_loss(self, output, target, target_weight):
|
138 |
+
"""Calculate top-down keypoint loss.
|
139 |
+
|
140 |
+
Note:
|
141 |
+
- batch_size: N
|
142 |
+
- num_keypoints: K
|
143 |
+
- heatmaps height: H
|
144 |
+
- heatmaps weight: W
|
145 |
+
|
146 |
+
Args:
|
147 |
+
output (torch.Tensor[N,K,H,W]): Output heatmaps.
|
148 |
+
target (torch.Tensor[N,K,H,W]): Target heatmaps.
|
149 |
+
target_weight (torch.Tensor[N,K,1]):
|
150 |
+
Weights across different joint types.
|
151 |
+
"""
|
152 |
+
|
153 |
+
losses = dict()
|
154 |
+
|
155 |
+
assert not isinstance(self.loss, nn.Sequential)
|
156 |
+
assert target.dim() == 4 and target_weight.dim() == 3
|
157 |
+
losses['heatmap_loss'] = self.loss(output, target, target_weight)
|
158 |
+
|
159 |
+
return losses
|
160 |
+
|
161 |
+
def get_accuracy(self, output, target, target_weight):
|
162 |
+
"""Calculate accuracy for top-down keypoint loss.
|
163 |
+
|
164 |
+
Note:
|
165 |
+
- batch_size: N
|
166 |
+
- num_keypoints: K
|
167 |
+
- heatmaps height: H
|
168 |
+
- heatmaps weight: W
|
169 |
+
|
170 |
+
Args:
|
171 |
+
output (torch.Tensor[N,K,H,W]): Output heatmaps.
|
172 |
+
target (torch.Tensor[N,K,H,W]): Target heatmaps.
|
173 |
+
target_weight (torch.Tensor[N,K,1]):
|
174 |
+
Weights across different joint types.
|
175 |
+
"""
|
176 |
+
|
177 |
+
accuracy = dict()
|
178 |
+
|
179 |
+
if self.target_type == 'GaussianHeatmap':
|
180 |
+
_, avg_acc, _ = pose_pck_accuracy(
|
181 |
+
output.detach().cpu().numpy(),
|
182 |
+
target.detach().cpu().numpy(),
|
183 |
+
target_weight.detach().cpu().numpy().squeeze(-1) > 0)
|
184 |
+
accuracy['acc_pose'] = float(avg_acc)
|
185 |
+
|
186 |
+
return accuracy
|
187 |
+
|
188 |
+
def forward(self, x):
|
189 |
+
"""Forward function."""
|
190 |
+
x = self._transform_inputs(x)
|
191 |
+
x = self.deconv_layers(x)
|
192 |
+
x = self.final_layer(x)
|
193 |
+
return x
|
194 |
+
|
195 |
+
def inference_model(self, x, flip_pairs=None):
|
196 |
+
"""Inference function.
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
output_heatmap (np.ndarray): Output heatmaps.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
x (torch.Tensor[N,K,H,W]): Input features.
|
203 |
+
flip_pairs (None | list[tuple]):
|
204 |
+
Pairs of keypoints which are mirrored.
|
205 |
+
"""
|
206 |
+
output = self.forward(x)
|
207 |
+
|
208 |
+
if flip_pairs is not None:
|
209 |
+
output_heatmap = flip_back(
|
210 |
+
output.detach().cpu().numpy(),
|
211 |
+
flip_pairs,
|
212 |
+
target_type=self.target_type)
|
213 |
+
# feature is not aligned, shift flipped heatmap for higher accuracy
|
214 |
+
if self.test_cfg.get('shift_heatmap', False):
|
215 |
+
output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
|
216 |
+
else:
|
217 |
+
output_heatmap = output.detach().cpu().numpy()
|
218 |
+
return output_heatmap
|
219 |
+
|
220 |
+
def _init_inputs(self, in_channels, in_index, input_transform):
|
221 |
+
"""Check and initialize input transforms.
|
222 |
+
|
223 |
+
The in_channels, in_index and input_transform must match.
|
224 |
+
Specifically, when input_transform is None, only single feature map
|
225 |
+
will be selected. So in_channels and in_index must be of type int.
|
226 |
+
When input_transform is not None, in_channels and in_index must be
|
227 |
+
list or tuple, with the same length.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
in_channels (int|Sequence[int]): Input channels.
|
231 |
+
in_index (int|Sequence[int]): Input feature index.
|
232 |
+
input_transform (str|None): Transformation type of input features.
|
233 |
+
Options: 'resize_concat', 'multiple_select', None.
|
234 |
+
|
235 |
+
- 'resize_concat': Multiple feature maps will be resize to the
|
236 |
+
same size as first one and than concat together.
|
237 |
+
Usually used in FCN head of HRNet.
|
238 |
+
- 'multiple_select': Multiple feature maps will be bundle into
|
239 |
+
a list and passed into decode head.
|
240 |
+
- None: Only one select feature map is allowed.
|
241 |
+
"""
|
242 |
+
|
243 |
+
if input_transform is not None:
|
244 |
+
assert input_transform in ['resize_concat', 'multiple_select']
|
245 |
+
self.input_transform = input_transform
|
246 |
+
self.in_index = in_index
|
247 |
+
if input_transform is not None:
|
248 |
+
assert isinstance(in_channels, (list, tuple))
|
249 |
+
assert isinstance(in_index, (list, tuple))
|
250 |
+
assert len(in_channels) == len(in_index)
|
251 |
+
if input_transform == 'resize_concat':
|
252 |
+
self.in_channels = sum(in_channels)
|
253 |
+
else:
|
254 |
+
self.in_channels = in_channels
|
255 |
+
else:
|
256 |
+
assert isinstance(in_channels, int)
|
257 |
+
assert isinstance(in_index, int)
|
258 |
+
self.in_channels = in_channels
|
259 |
+
|
260 |
+
def _transform_inputs(self, inputs):
|
261 |
+
"""Transform inputs for decoder.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
inputs (list[Tensor] | Tensor): multi-level img features.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
Tensor: The transformed inputs
|
268 |
+
"""
|
269 |
+
if not isinstance(inputs, list):
|
270 |
+
if self.upsample > 0:
|
271 |
+
raise NotImplementedError
|
272 |
+
return inputs
|
273 |
+
|
274 |
+
if self.input_transform == 'resize_concat':
|
275 |
+
inputs = [inputs[i] for i in self.in_index]
|
276 |
+
upsampled_inputs = [
|
277 |
+
resize(
|
278 |
+
input=x,
|
279 |
+
size=inputs[0].shape[2:],
|
280 |
+
mode='bilinear',
|
281 |
+
align_corners=self.align_corners) for x in inputs
|
282 |
+
]
|
283 |
+
inputs = torch.cat(upsampled_inputs, dim=1)
|
284 |
+
elif self.input_transform == 'multiple_select':
|
285 |
+
inputs = [inputs[i] for i in self.in_index]
|
286 |
+
else:
|
287 |
+
inputs = inputs[self.in_index]
|
288 |
+
|
289 |
+
return inputs
|
290 |
+
|
291 |
+
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
|
292 |
+
"""Make deconv layers."""
|
293 |
+
if num_layers != len(num_filters):
|
294 |
+
error_msg = f'num_layers({num_layers}) ' \
|
295 |
+
f'!= length of num_filters({len(num_filters)})'
|
296 |
+
raise ValueError(error_msg)
|
297 |
+
if num_layers != len(num_kernels):
|
298 |
+
error_msg = f'num_layers({num_layers}) ' \
|
299 |
+
f'!= length of num_kernels({len(num_kernels)})'
|
300 |
+
raise ValueError(error_msg)
|
301 |
+
|
302 |
+
layers = []
|
303 |
+
for i in range(num_layers):
|
304 |
+
kernel, padding, output_padding = \
|
305 |
+
self._get_deconv_cfg(num_kernels[i])
|
306 |
+
|
307 |
+
planes = num_filters[i]
|
308 |
+
layers.append(
|
309 |
+
nn.ConvTranspose2d(in_channels=self.in_channels,
|
310 |
+
out_channels=planes,
|
311 |
+
kernel_size=kernel,
|
312 |
+
stride=2,
|
313 |
+
padding=padding,
|
314 |
+
output_padding=output_padding,
|
315 |
+
bias=False)
|
316 |
+
)
|
317 |
+
layers.append(nn.BatchNorm2d(planes))
|
318 |
+
layers.append(nn.ReLU(inplace=True))
|
319 |
+
self.in_channels = planes
|
320 |
+
|
321 |
+
return nn.Sequential(*layers)
|
322 |
+
|
323 |
+
def init_weights(self):
|
324 |
+
"""Initialize model weights."""
|
325 |
+
for _, m in self.deconv_layers.named_modules():
|
326 |
+
if isinstance(m, nn.ConvTranspose2d):
|
327 |
+
normal_init(m, std=0.001)
|
328 |
+
elif isinstance(m, nn.BatchNorm2d):
|
329 |
+
constant_init(m, 1)
|
330 |
+
for m in self.final_layer.modules():
|
331 |
+
if isinstance(m, nn.Conv2d):
|
332 |
+
normal_init(m, std=0.001, bias=0)
|
333 |
+
elif isinstance(m, nn.BatchNorm2d):
|
334 |
+
constant_init(m, 1)
|
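A minimal shape check for the head's forward path. This is a sketch only: the 768-channel input and the 24x18 feature grid are illustrative values, not taken from this repo's configs, and it assumes the package imports resolve as laid out in this commit.

    import torch
    # Default config: three stride-2 deconvs (256 filters each) + a 1x1 conv.
    head = TopdownHeatmapSimpleHead(in_channels=768, out_channels=17)
    feats = torch.randn(1, 768, 24, 18)   # single ViT-style feature map
    heatmaps = head(feats)                # 3 stride-2 deconvs -> 8x upsampling
    print(heatmaps.shape)                 # torch.Size([1, 17, 192, 144])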
easy_ViTPose/vit_models/losses/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .classfication_loss import BCELoss
from .heatmap_loss import AdaptiveWingLoss
from .mesh_loss import GANLoss, MeshLoss
from .mse_loss import JointsMSELoss, JointsOHKMMSELoss
from .multi_loss_factory import AELoss, HeatmapLoss, MultiLossFactory
from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, MSELoss,
                              SemiSupervisionLoss, SmoothL1Loss, SoftWingLoss,
                              WingLoss)

__all__ = [
    'JointsMSELoss', 'JointsOHKMMSELoss', 'HeatmapLoss', 'AELoss',
    'MultiLossFactory', 'MeshLoss', 'GANLoss', 'SmoothL1Loss', 'WingLoss',
    'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss',
    'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss'
]
easy_ViTPose/vit_models/losses/classfication_loss.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F


__all__ = ['BCELoss']


class BCELoss(nn.Module):
    """Binary Cross Entropy loss."""

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.binary_cross_entropy
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification.
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels.
        """

        if self.use_target_weight:
            assert target_weight is not None
            loss = self.criterion(output, target, reduction='none')
            if target_weight.dim() == 1:
                target_weight = target_weight[:, None]
            loss = (loss * target_weight).mean()
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
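A minimal usage sketch for BCELoss with per-label weights. The shapes (batch of 2, 3 labels) are illustrative assumptions; note that `F.binary_cross_entropy` expects probabilities, hence the sigmoid.

    import torch
    criterion = BCELoss(use_target_weight=True)
    pred = torch.sigmoid(torch.randn(2, 3))     # probabilities in (0, 1)
    gt = torch.randint(0, 2, (2, 3)).float()
    weights = torch.tensor([[1., 1., 0.], [1., 0., 1.]])
    print(criterion(pred, gt, weights))         # scalar; masked labels contribute 0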
easy_ViTPose/vit_models/losses/heatmap_loss.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss. paper ref: 'Adaptive Wing Loss for Robust Face
    Alignment via Heatmap Regression' Wang et al. ICCV'2019.

    Args:
        alpha (float), omega (float), epsilon (float), theta (float)
            are hyper-parameters.
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Criterion of adaptive wing loss.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
        """
        H, W = pred.shape[2:4]
        delta = (target - pred).abs()

        A = self.omega * (
            1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
        ) * (self.alpha - target) * (torch.pow(
            self.theta / self.epsilon,
            self.alpha - target - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(
            1 + torch.pow(self.theta / self.epsilon, self.alpha - target))

        losses = torch.where(
            delta < self.theta,
            self.omega *
            torch.log(1 +
                      torch.pow(delta / self.epsilon, self.alpha - target)),
            A * delta - C)

        return torch.mean(losses)

    def forward(self, output, target, target_weight):
        """Forward function.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            output (torch.Tensor[NxKxHxW]): Output heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1]):
                Weights across different joint types.
        """
        if self.use_target_weight:
            loss = self.criterion(output * target_weight.unsqueeze(-1),
                                  target * target_weight.unsqueeze(-1))
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
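A minimal usage sketch for AdaptiveWingLoss on dummy heatmaps. The sizes (N=2, K=17, 64x48) are illustrative assumptions; `target_weight` has shape [N, K, 1] as the docstring above requires.

    import torch
    criterion = AdaptiveWingLoss(use_target_weight=True)
    pred = torch.rand(2, 17, 64, 48)
    gt = torch.rand(2, 17, 64, 48)
    weight = torch.ones(2, 17, 1)               # broadcast over H and W
    print(criterion(pred, gt, weight))          # scalar loss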
easy_ViTPose/vit_models/losses/mesh_loss.py
ADDED
@@ -0,0 +1,402 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['MeshLoss', 'GANLoss']


def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.

    Based on Zhou et al., "On the Continuity of Rotation
    Representations in Neural Networks", CVPR 2019
    Input:
        (B,6) Batch of 6-D rotation representations
    Output:
        (B,3,3) Batch of corresponding rotation matrices
    """
    x = x.view(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    b3 = torch.cross(b1, b2)
    return torch.stack((b1, b2, b3), dim=-1)


def batch_rodrigues(theta):
    """Convert axis-angle representation to rotation matrix.
    Args:
        theta: size = [B, 3]
    Returns:
        Rotation matrix corresponding to the quaternion
        -- size = [B, 3, 3]
    """
    l2norm = torch.norm(theta + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(l2norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * normalized], dim=1)
    return quat_to_rotmat(quat)


def quat_to_rotmat(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [B, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion
        -- size = [B, 3, 3]
    """
    norm_quat = quat
    norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1],\
        norm_quat[:, 2], norm_quat[:, 3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack([
        w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
        w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
        w2 - x2 - y2 + z2
    ],
                         dim=1).view(B, 3, 3)
    return rotMat


def perspective_projection(points, rotation, translation, focal_length,
                           camera_center):
    """This function computes the perspective projection of a set of 3D points.

    Note:
        - batch size: B
        - point number: N

    Args:
        points (Tensor([B, N, 3])): A set of 3D points
        rotation (Tensor([B, 3, 3])): Camera rotation matrix
        translation (Tensor([B, 3])): Camera translation
        focal_length (Tensor([B,])): Focal length
        camera_center (Tensor([B, 2])): Camera center

    Returns:
        projected_points (Tensor([B, N, 2])): Projected 2D
            points in image space.
    """

    batch_size = points.shape[0]
    K = torch.zeros([batch_size, 3, 3], device=points.device)
    K[:, 0, 0] = focal_length
    K[:, 1, 1] = focal_length
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
    projected_points = projected_points[:, :, :-1]
    return projected_points


class MeshLoss(nn.Module):
    """Mix loss for 3D human mesh. It is composed of loss on 2D joints, 3D
    joints, mesh vertices and smpl parameters (if any).

    Args:
        joints_2d_loss_weight (float): Weight for loss on 2D joints.
        joints_3d_loss_weight (float): Weight for loss on 3D joints.
        vertex_loss_weight (float): Weight for loss on 3D vertices.
        smpl_pose_loss_weight (float): Weight for loss on SMPL
            pose parameters.
        smpl_beta_loss_weight (float): Weight for loss on SMPL
            shape parameters.
        img_res (int): Input image resolution.
        focal_length (float): Focal length of camera model. Default=5000.
    """

    def __init__(self,
                 joints_2d_loss_weight,
                 joints_3d_loss_weight,
                 vertex_loss_weight,
                 smpl_pose_loss_weight,
                 smpl_beta_loss_weight,
                 img_res,
                 focal_length=5000):

        super().__init__()
        # Per-vertex loss on the mesh
        self.criterion_vertex = nn.L1Loss(reduction='none')

        # Joints (2D and 3D) loss
        self.criterion_joints_2d = nn.SmoothL1Loss(reduction='none')
        self.criterion_joints_3d = nn.SmoothL1Loss(reduction='none')

        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss(reduction='none')

        self.joints_2d_loss_weight = joints_2d_loss_weight
        self.joints_3d_loss_weight = joints_3d_loss_weight
        self.vertex_loss_weight = vertex_loss_weight
        self.smpl_pose_loss_weight = smpl_pose_loss_weight
        self.smpl_beta_loss_weight = smpl_beta_loss_weight
        self.focal_length = focal_length
        self.img_res = img_res

    def joints_2d_loss(self, pred_joints_2d, gt_joints_2d, joints_2d_visible):
        """Compute 2D reprojection loss on the joints.

        The loss is weighted by joints_2d_visible.
        """
        conf = joints_2d_visible.float()
        loss = (conf *
                self.criterion_joints_2d(pred_joints_2d, gt_joints_2d)).mean()
        return loss

    def joints_3d_loss(self, pred_joints_3d, gt_joints_3d, joints_3d_visible):
        """Compute 3D joints loss for the examples that 3D joint annotations
        are available.

        The loss is weighted by joints_3d_visible.
        """
        conf = joints_3d_visible.float()
        if len(gt_joints_3d) > 0:
            gt_pelvis = (gt_joints_3d[:, 2, :] + gt_joints_3d[:, 3, :]) / 2
            gt_joints_3d = gt_joints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_joints_3d[:, 2, :] +
                           pred_joints_3d[:, 3, :]) / 2
            pred_joints_3d = pred_joints_3d - pred_pelvis[:, None, :]
            return (
                conf *
                self.criterion_joints_3d(pred_joints_3d, gt_joints_3d)).mean()
        return pred_joints_3d.sum() * 0

    def vertex_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute 3D vertex loss for the examples that 3D human mesh
        annotations are available.

        The loss is weighted by the has_smpl.
        """
        conf = has_smpl.float()
        loss_vertex = self.criterion_vertex(pred_vertices, gt_vertices)
        loss_vertex = (conf[:, None, None] * loss_vertex).mean()
        return loss_vertex

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        """Compute SMPL parameters loss for the examples that SMPL parameter
        annotations are available.

        The loss is weighted by has_smpl.
        """
        conf = has_smpl.float()
        gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)
        loss_regr_pose = self.criterion_regr(pred_rotmat, gt_rotmat)
        loss_regr_betas = self.criterion_regr(pred_betas, gt_betas)
        loss_regr_pose = (conf[:, None, None, None] * loss_regr_pose).mean()
        loss_regr_betas = (conf[:, None] * loss_regr_betas).mean()
        return loss_regr_pose, loss_regr_betas

    def project_points(self, points_3d, camera):
        """Perform orthographic projection of 3D points using the camera
        parameters, return projected 2D points in image plane.

        Note:
            - batch size: B
            - point number: N

        Args:
            points_3d (Tensor([B, N, 3])): 3D points.
            camera (Tensor([B, 3])): camera parameters with the
                3 channel as (scale, translation_x, translation_y)

        Returns:
            Tensor([B, N, 2]): projected 2D points \
                in image space.
        """
        batch_size = points_3d.shape[0]
        device = points_3d.device
        cam_t = torch.stack([
            camera[:, 1], camera[:, 2], 2 * self.focal_length /
            (self.img_res * camera[:, 0] + 1e-9)
        ],
                            dim=-1)
        camera_center = camera.new_zeros([batch_size, 2])
        rot_t = torch.eye(
            3, device=device,
            dtype=points_3d.dtype).unsqueeze(0).expand(batch_size, -1, -1)
        joints_2d = perspective_projection(
            points_3d,
            rotation=rot_t,
            translation=cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        return joints_2d

    def forward(self, output, target):
        """Forward function.

        Args:
            output (dict): dict of network predicted results.
                Keys: 'vertices', 'joints_3d', 'camera',
                'pose'(optional), 'beta'(optional)
            target (dict): dict of ground-truth labels.
                Keys: 'vertices', 'joints_3d', 'joints_3d_visible',
                'joints_2d', 'joints_2d_visible', 'pose', 'beta',
                'has_smpl'

        Returns:
            dict: dict of losses.
        """
        losses = {}

        # Per-vertex loss for the shape
        pred_vertices = output['vertices']

        gt_vertices = target['vertices']
        has_smpl = target['has_smpl']
        loss_vertex = self.vertex_loss(pred_vertices, gt_vertices, has_smpl)
        losses['vertex_loss'] = loss_vertex * self.vertex_loss_weight

        # Compute loss on SMPL parameters, if available
        if 'pose' in output.keys() and 'beta' in output.keys():
            pred_rotmat = output['pose']
            pred_betas = output['beta']
            gt_pose = target['pose']
            gt_betas = target['beta']
            loss_regr_pose, loss_regr_betas = self.smpl_losses(
                pred_rotmat, pred_betas, gt_pose, gt_betas, has_smpl)
            losses['smpl_pose_loss'] = \
                loss_regr_pose * self.smpl_pose_loss_weight
            losses['smpl_beta_loss'] = \
                loss_regr_betas * self.smpl_beta_loss_weight

        # Compute 3D joints loss
        pred_joints_3d = output['joints_3d']
        gt_joints_3d = target['joints_3d']
        joints_3d_visible = target['joints_3d_visible']
        loss_joints_3d = self.joints_3d_loss(pred_joints_3d, gt_joints_3d,
                                             joints_3d_visible)
        losses['joints_3d_loss'] = loss_joints_3d * self.joints_3d_loss_weight

        # Compute 2D reprojection loss for the 2D joints
        pred_camera = output['camera']
        gt_joints_2d = target['joints_2d']
        joints_2d_visible = target['joints_2d_visible']
        pred_joints_2d = self.project_points(pred_joints_3d, pred_camera)

        # Normalize keypoints to [-1,1]
        # The coordinate origin of pred_joints_2d is
        # the center of the input image.
        pred_joints_2d = 2 * pred_joints_2d / (self.img_res - 1)
        # The coordinate origin of gt_joints_2d is
        # the top left corner of the input image.
        gt_joints_2d = 2 * gt_joints_2d / (self.img_res - 1) - 1
        loss_joints_2d = self.joints_2d_loss(pred_joints_2d, gt_joints_2d,
                                             joints_2d_visible)
        losses['joints_2d_loss'] = loss_joints_2d * self.joints_2d_loss_weight

        return losses


class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    @staticmethod
    def _wgan_loss(input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, \
                otherwise, return Tensor.
        """

        if self.gan_type == 'wgan':
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
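A quick orthonormality check for rot6d_to_rotmat, as a sketch: any (B, 6) input works, the B=4 random vectors and the 1e-5 tolerance are illustrative choices.

    import torch
    x = torch.randn(4, 6)
    R = rot6d_to_rotmat(x)                      # (4, 3, 3) rotation matrices
    eye = torch.matmul(R, R.transpose(1, 2))    # R @ R^T should be ~identity
    print(torch.allclose(eye, torch.eye(3).expand(4, 3, 3), atol=1e-5))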
easy_ViTPose/vit_models/losses/mse_loss.py
ADDED
@@ -0,0 +1,151 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


__all__ = ['JointsMSELoss', 'JointsOHKMMSELoss']


class JointsMSELoss(nn.Module):
    """MSE loss for heatmaps.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight):
        """Forward function."""
        batch_size = output.size(0)
        num_joints = output.size(1)

        heatmaps_pred = output.reshape(
            (batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        loss = 0.

        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            if self.use_target_weight:
                loss += self.criterion(heatmap_pred * target_weight[:, idx],
                                       heatmap_gt * target_weight[:, idx])
            else:
                loss += self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints * self.loss_weight


class CombinedTargetMSELoss(nn.Module):
    """MSE loss for combined target.
    CombinedTarget: The combination of classification target
    (response map) and regression target (offset map).
    Paper ref: Huang et al. The Devil is in the Details: Delving into
    Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight):
        batch_size = output.size(0)
        num_channels = output.size(1)
        heatmaps_pred = output.reshape(
            (batch_size, num_channels, -1)).split(1, 1)
        heatmaps_gt = target.reshape(
            (batch_size, num_channels, -1)).split(1, 1)
        loss = 0.
        num_joints = num_channels // 3
        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx * 3].squeeze()
            heatmap_gt = heatmaps_gt[idx * 3].squeeze()
            offset_x_pred = heatmaps_pred[idx * 3 + 1].squeeze()
            offset_x_gt = heatmaps_gt[idx * 3 + 1].squeeze()
            offset_y_pred = heatmaps_pred[idx * 3 + 2].squeeze()
            offset_y_gt = heatmaps_gt[idx * 3 + 2].squeeze()
            if self.use_target_weight:
                heatmap_pred = heatmap_pred * target_weight[:, idx]
                heatmap_gt = heatmap_gt * target_weight[:, idx]
            # classification loss
            loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
            # regression loss
            loss += 0.5 * self.criterion(heatmap_gt * offset_x_pred,
                                         heatmap_gt * offset_x_gt)
            loss += 0.5 * self.criterion(heatmap_gt * offset_y_pred,
                                         heatmap_gt * offset_y_gt)
        return loss / num_joints * self.loss_weight


class JointsOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        topk (int): Only top k joint losses are kept.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, topk=8, loss_weight=1.):
        super().__init__()
        assert topk > 0
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, loss):
        """Online hard keypoint mining."""
        ohkm_loss = 0.
        N = len(loss)
        for i in range(N):
            sub_loss = loss[i]
            _, topk_idx = torch.topk(
                sub_loss, k=self.topk, dim=0, sorted=False)
            tmp_loss = torch.gather(sub_loss, 0, topk_idx)
            ohkm_loss += torch.sum(tmp_loss) / self.topk
        ohkm_loss /= N
        return ohkm_loss

    def forward(self, output, target, target_weight):
        """Forward function."""
        batch_size = output.size(0)
        num_joints = output.size(1)
        if num_joints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not be '
                             f'larger than num_joints ({num_joints}).')
        heatmaps_pred = output.reshape(
            (batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        losses = []
        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            if self.use_target_weight:
                losses.append(
                    self.criterion(heatmap_pred * target_weight[:, idx],
                                   heatmap_gt * target_weight[:, idx]))
            else:
                losses.append(self.criterion(heatmap_pred, heatmap_gt))

        losses = [loss.mean(dim=1).unsqueeze(dim=1) for loss in losses]
        losses = torch.cat(losses, dim=1)

        return self._ohkm(losses) * self.loss_weight
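A minimal usage sketch for JointsMSELoss, the standard heatmap training loss. The sizes (N=2, K=17, 64x48) are illustrative assumptions; `target_weight` has shape [N, K, 1].

    import torch
    criterion = JointsMSELoss(use_target_weight=True)
    pred = torch.rand(2, 17, 64, 48)
    gt = torch.rand(2, 17, 64, 48)
    weight = torch.ones(2, 17, 1)
    print(criterion(pred, gt, weight))   # scalar, averaged over joints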
easy_ViTPose/vit_models/losses/multi_loss_factory.py
ADDED
@@ -0,0 +1,279 @@
# ------------------------------------------------------------------------------
# Adapted from https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation
# Original licence: Copyright (c) Microsoft, under the MIT License.
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn


__all__ = ['HeatmapLoss', 'AELoss', 'MultiLossFactory']


def _make_input(t, requires_grad=False, device=torch.device('cpu')):
    """Make zero inputs for AE loss.

    Args:
        t (torch.Tensor): input
        requires_grad (bool): Option to use requires_grad.
        device: torch device

    Returns:
        torch.Tensor: zero input.
    """
    inp = torch.autograd.Variable(t, requires_grad=requires_grad)
    inp = inp.sum()
    inp = inp.to(device)
    return inp


class HeatmapLoss(nn.Module):
    """Accumulate the heatmap loss for each image in the batch.

    Args:
        supervise_empty (bool): Whether to supervise empty channels.
    """

    def __init__(self, supervise_empty=True):
        super().__init__()
        self.supervise_empty = supervise_empty

    def forward(self, pred, gt, mask):
        """Forward function.

        Note:
            - batch_size: N
            - heatmaps width: W
            - heatmaps height: H
            - max_num_people: M
            - num_keypoints: K

        Args:
            pred (torch.Tensor[N,K,H,W]): heatmap of output.
            gt (torch.Tensor[N,K,H,W]): target heatmap.
            mask (torch.Tensor[N,H,W]): mask of target.
        """
        assert pred.size() == gt.size(
        ), f'pred.size() is {pred.size()}, gt.size() is {gt.size()}'

        if not self.supervise_empty:
            empty_mask = (gt.sum(dim=[2, 3], keepdim=True) > 0).float()
            loss = ((pred - gt)**2) * empty_mask.expand_as(
                pred) * mask[:, None, :, :].expand_as(pred)
        else:
            loss = ((pred - gt)**2) * mask[:, None, :, :].expand_as(pred)
        loss = loss.mean(dim=3).mean(dim=2).mean(dim=1)
        return loss


class AELoss(nn.Module):
    """Associative Embedding loss.

    `Associative Embedding: End-to-End Learning for Joint Detection and
    Grouping <https://arxiv.org/abs/1611.05424v2>`_.
    """

    def __init__(self, loss_type):
        super().__init__()
        self.loss_type = loss_type

    def singleTagLoss(self, pred_tag, joints):
        """Associative embedding loss for one image.

        Note:
            - heatmaps width: W
            - heatmaps height: H
            - max_num_people: M
            - num_keypoints: K

        Args:
            pred_tag (torch.Tensor[KxHxW,1]): tag of output for one image.
            joints (torch.Tensor[M,K,2]): joints information for one image.
        """
        tags = []
        pull = 0
        for joints_per_person in joints:
            tmp = []
            for joint in joints_per_person:
                if joint[1] > 0:
                    tmp.append(pred_tag[joint[0]])
            if len(tmp) == 0:
                continue
            tmp = torch.stack(tmp)
            tags.append(torch.mean(tmp, dim=0))
            pull = pull + torch.mean((tmp - tags[-1].expand_as(tmp))**2)

        num_tags = len(tags)
        if num_tags == 0:
            return (
                _make_input(torch.zeros(1).float(), device=pred_tag.device),
                _make_input(torch.zeros(1).float(), device=pred_tag.device))
        elif num_tags == 1:
            return (_make_input(
                torch.zeros(1).float(), device=pred_tag.device), pull)

        tags = torch.stack(tags)

        size = (num_tags, num_tags)
        A = tags.expand(*size)
        B = A.permute(1, 0)

        diff = A - B

        if self.loss_type == 'exp':
            diff = torch.pow(diff, 2)
            push = torch.exp(-diff)
            push = torch.sum(push) - num_tags
        elif self.loss_type == 'max':
            diff = 1 - torch.abs(diff)
            push = torch.clamp(diff, min=0).sum() - num_tags
        else:
            raise ValueError('Unknown ae loss type')

        push_loss = push / ((num_tags - 1) * num_tags) * 0.5
        pull_loss = pull / (num_tags)

        return push_loss, pull_loss

    def forward(self, tags, joints):
        """Accumulate the tag loss for each image in the batch.

        Note:
            - batch_size: N
            - heatmaps width: W
            - heatmaps height: H
            - max_num_people: M
            - num_keypoints: K

        Args:
            tags (torch.Tensor[N,KxHxW,1]): tag channels of output.
            joints (torch.Tensor[N,M,K,2]): joints information.
        """
        pushes, pulls = [], []
        joints = joints.cpu().data.numpy()
        batch_size = tags.size(0)
        for i in range(batch_size):
            push, pull = self.singleTagLoss(tags[i], joints[i])
            pushes.append(push)
            pulls.append(pull)
        return torch.stack(pushes), torch.stack(pulls)


class MultiLossFactory(nn.Module):
    """Loss for bottom-up models.

    Args:
        num_joints (int): Number of keypoints.
        num_stages (int): Number of stages.
        ae_loss_type (str): Type of ae loss.
        with_ae_loss (list[bool]): Use ae loss or not in multi-heatmap.
        push_loss_factor (list[float]):
            Parameter of push loss in multi-heatmap.
        pull_loss_factor (list[float]):
            Parameter of pull loss in multi-heatmap.
        with_heatmaps_loss (list[bool]):
            Use heatmap loss or not in multi-heatmap.
        heatmaps_loss_factor (list[float]):
            Parameter of heatmap loss in multi-heatmap.
        supervise_empty (bool): Whether to supervise empty channels.
    """

    def __init__(self,
                 num_joints,
                 num_stages,
                 ae_loss_type,
                 with_ae_loss,
                 push_loss_factor,
                 pull_loss_factor,
                 with_heatmaps_loss,
                 heatmaps_loss_factor,
                 supervise_empty=True):
        super().__init__()

        assert isinstance(with_heatmaps_loss, (list, tuple)), \
            'with_heatmaps_loss should be a list or tuple'
        assert isinstance(heatmaps_loss_factor, (list, tuple)), \
            'heatmaps_loss_factor should be a list or tuple'
        assert isinstance(with_ae_loss, (list, tuple)), \
            'with_ae_loss should be a list or tuple'
        assert isinstance(push_loss_factor, (list, tuple)), \
            'push_loss_factor should be a list or tuple'
        assert isinstance(pull_loss_factor, (list, tuple)), \
            'pull_loss_factor should be a list or tuple'

        self.num_joints = num_joints
        self.num_stages = num_stages
        self.ae_loss_type = ae_loss_type
        self.with_ae_loss = with_ae_loss
        self.push_loss_factor = push_loss_factor
        self.pull_loss_factor = pull_loss_factor
        self.with_heatmaps_loss = with_heatmaps_loss
        self.heatmaps_loss_factor = heatmaps_loss_factor

        self.heatmaps_loss = \
            nn.ModuleList(
                [
                    HeatmapLoss(supervise_empty)
                    if with_heatmaps_loss else None
                    for with_heatmaps_loss in self.with_heatmaps_loss
                ]
            )

        self.ae_loss = \
            nn.ModuleList(
                [
                    AELoss(self.ae_loss_type) if with_ae_loss else None
                    for with_ae_loss in self.with_ae_loss
                ]
            )

    def forward(self, outputs, heatmaps, masks, joints):
        """Forward function to calculate losses.

        Note:
            - batch_size: N
            - heatmaps width: W
            - heatmaps height: H
            - max_num_people: M
            - num_keypoints: K
            - output_channel: C C=2K if use ae loss else K

        Args:
            outputs (list(torch.Tensor[N,C,H,W])): outputs of stages.
            heatmaps (list(torch.Tensor[N,K,H,W])): target of heatmaps.
            masks (list(torch.Tensor[N,H,W])): masks of heatmaps.
            joints (list(torch.Tensor[N,M,K,2])): joints of ae loss.
        """
        heatmaps_losses = []
        push_losses = []
        pull_losses = []
        for idx in range(len(outputs)):
            offset_feat = 0
            if self.heatmaps_loss[idx]:
                heatmaps_pred = outputs[idx][:, :self.num_joints]
                offset_feat = self.num_joints
                heatmaps_loss = self.heatmaps_loss[idx](heatmaps_pred,
                                                        heatmaps[idx],
                                                        masks[idx])
                heatmaps_loss = heatmaps_loss * self.heatmaps_loss_factor[idx]
                heatmaps_losses.append(heatmaps_loss)
            else:
                heatmaps_losses.append(None)

            if self.ae_loss[idx]:
                tags_pred = outputs[idx][:, offset_feat:]
                batch_size = tags_pred.size()[0]
                tags_pred = tags_pred.contiguous().view(batch_size, -1, 1)

                push_loss, pull_loss = self.ae_loss[idx](tags_pred,
                                                         joints[idx])
                push_loss = push_loss * self.push_loss_factor[idx]
                pull_loss = pull_loss * self.pull_loss_factor[idx]

                push_losses.append(push_loss)
                pull_losses.append(pull_loss)
            else:
                push_losses.append(None)
                pull_losses.append(None)

        return heatmaps_losses, push_losses, pull_losses
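A minimal usage sketch for the bottom-up HeatmapLoss above. The sizes (N=2, K=17, 128x128) are illustrative assumptions; a mask of ones keeps every pixel, and the result is one loss value per image.

    import torch
    criterion = HeatmapLoss(supervise_empty=True)
    pred = torch.rand(2, 17, 128, 128)
    gt = torch.rand(2, 17, 128, 128)
    mask = torch.ones(2, 128, 128)
    print(criterion(pred, gt, mask).shape)   # torch.Size([2])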
easy_ViTPose/vit_models/losses/regression_loss.py
ADDED
@@ -0,0 +1,444 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
__all__ = ['SmoothL1Loss', 'SoftWingLoss', 'SoftWingLoss',
|
10 |
+
'L1Loss', 'MPJPELoss', 'MSELoss', 'BoneLoss',
|
11 |
+
'SemiSupervisionLoss']
|
12 |
+
|
13 |
+
|
14 |
+
class SmoothL1Loss(nn.Module):
|
15 |
+
"""SmoothL1Loss loss.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
use_target_weight (bool): Option to use weighted MSE loss.
|
19 |
+
Different joint types may have different target weights.
|
20 |
+
loss_weight (float): Weight of the loss. Default: 1.0.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, use_target_weight=False, loss_weight=1.):
|
24 |
+
super().__init__()
|
25 |
+
self.criterion = F.smooth_l1_loss
|
26 |
+
self.use_target_weight = use_target_weight
|
27 |
+
self.loss_weight = loss_weight
|
28 |
+
|
29 |
+
def forward(self, output, target, target_weight=None):
|
30 |
+
"""Forward function.
|
31 |
+
|
32 |
+
Note:
|
33 |
+
- batch_size: N
|
34 |
+
- num_keypoints: K
|
35 |
+
- dimension of keypoints: D (D=2 or D=3)
|
36 |
+
|
37 |
+
Args:
|
38 |
+
output (torch.Tensor[N, K, D]): Output regression.
|
39 |
+
target (torch.Tensor[N, K, D]): Target regression.
|
40 |
+
target_weight (torch.Tensor[N, K, D]):
|
41 |
+
Weights across different joint types.
|
42 |
+
"""
|
43 |
+
if self.use_target_weight:
|
44 |
+
assert target_weight is not None
|
45 |
+
loss = self.criterion(output * target_weight,
|
46 |
+
target * target_weight)
|
47 |
+
else:
|
48 |
+
loss = self.criterion(output, target)
|
49 |
+
|
50 |
+
return loss * self.loss_weight
|
51 |
+
|
52 |
+
|
53 |
+
class WingLoss(nn.Module):
|
54 |
+
"""Wing Loss. paper ref: 'Wing Loss for Robust Facial Landmark Localisation
|
55 |
+
with Convolutional Neural Networks' Feng et al. CVPR'2018.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
omega (float): Also referred to as width.
|
59 |
+
epsilon (float): Also referred to as curvature.
|
60 |
+
use_target_weight (bool): Option to use weighted MSE loss.
|
61 |
+
Different joint types may have different target weights.
|
62 |
+
loss_weight (float): Weight of the loss. Default: 1.0.
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(self,
|
66 |
+
omega=10.0,
|
67 |
+
epsilon=2.0,
|
68 |
+
use_target_weight=False,
|
69 |
+
loss_weight=1.):
|
70 |
+
super().__init__()
|
71 |
+
self.omega = omega
|
72 |
+
self.epsilon = epsilon
|
73 |
+
self.use_target_weight = use_target_weight
|
74 |
+
self.loss_weight = loss_weight
|
75 |
+
|
76 |
+
# constant that smoothly links the piecewise-defined linear
|
77 |
+
# and nonlinear parts
|
78 |
+
self.C = self.omega * (1.0 - math.log(1.0 + self.omega / self.epsilon))
|
79 |
+
|
80 |
+
def criterion(self, pred, target):
|
81 |
+
"""Criterion of wingloss.
|
82 |
+
|
83 |
+
Note:
|
84 |
+
- batch_size: N
|
85 |
+
- num_keypoints: K
|
86 |
+
- dimension of keypoints: D (D=2 or D=3)
|
87 |
+
|
88 |
+
Args:
|
89 |
+
pred (torch.Tensor[N, K, D]): Output regression.
|
90 |
+
target (torch.Tensor[N, K, D]): Target regression.
|
91 |
+
"""
|
92 |
+
delta = (target - pred).abs()
|
93 |
+
losses = torch.where(
|
94 |
+
delta < self.omega,
|
95 |
+
self.omega * torch.log(1.0 + delta / self.epsilon), delta - self.C)
|
96 |
+
return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0)
|
97 |
+
|
98 |
+
def forward(self, output, target, target_weight=None):
|
99 |
+
"""Forward function.
|
100 |
+
|
101 |
+
Note:
|
102 |
+
- batch_size: N
|
103 |
+
- num_keypoints: K
|
104 |
+
- dimension of keypoints: D (D=2 or D=3)
|
105 |
+
|
106 |
+
Args:
|
107 |
+
output (torch.Tensor[N, K, D]): Output regression.
|
108 |
+
target (torch.Tensor[N, K, D]): Target regression.
|
109 |
+
target_weight (torch.Tensor[N,K,D]):
|
110 |
+
Weights across different joint types.
|
111 |
+
"""
|
112 |
+
if self.use_target_weight:
|
113 |
+
assert target_weight is not None
|
114 |
+
loss = self.criterion(output * target_weight,
|
115 |
+
target * target_weight)
|
116 |
+
else:
|
117 |
+
loss = self.criterion(output, target)
|
118 |
+
|
119 |
+
return loss * self.loss_weight
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
class SoftWingLoss(nn.Module):
|
124 |
+
"""Soft Wing Loss 'Structure-Coherent Deep Feature Learning for Robust Face
|
125 |
+
Alignment' Lin et al. TIP'2021.
|
126 |
+
|
127 |
+
loss =
|
128 |
+
1. |x| , if |x| < omega1
|
129 |
+
2. omega2*ln(1+|x|/epsilon) + B, if |x| >= omega1
|
130 |
+
|
131 |
+
Args:
|
132 |
+
omega1 (float): The first threshold.
|
133 |
+
omega2 (float): The second threshold.
|
134 |
+
epsilon (float): Also referred to as curvature.
|
135 |
+
use_target_weight (bool): Option to use weighted MSE loss.
|
136 |
+
Different joint types may have different target weights.
|
137 |
+
loss_weight (float): Weight of the loss. Default: 1.0.
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(self,
|
141 |
+
omega1=2.0,
|
142 |
+
omega2=20.0,
|
143 |
+
epsilon=0.5,
|
144 |
+
use_target_weight=False,
|
145 |
+
loss_weight=1.):
|
146 |
+
super().__init__()
|
147 |
+
self.omega1 = omega1
|
148 |
+
self.omega2 = omega2
|
149 |
+
self.epsilon = epsilon
|
150 |
+
self.use_target_weight = use_target_weight
|
151 |
+
self.loss_weight = loss_weight
|
152 |
+
|
153 |
+
# constant that smoothly links the piecewise-defined linear
|
154 |
+
# and nonlinear parts
|
155 |
+
self.B = self.omega1 - self.omega2 * math.log(1.0 + self.omega1 /
|
156 |
+
self.epsilon)
|
157 |
+
|
158 |
+
def criterion(self, pred, target):
|
159 |
+
"""Criterion of wingloss.
|
160 |
+
|
161 |
+
Note:
|
162 |
+
batch_size: N
|
163 |
+
num_keypoints: K
|
164 |
+
dimension of keypoints: D (D=2 or D=3)
|
165 |
+
|
166 |
+
Args:
|
167 |
+
pred (torch.Tensor[N, K, D]): Output regression.
|
168 |
+
target (torch.Tensor[N, K, D]): Target regression.
|
169 |
+
"""
|
170 |
+
delta = (target - pred).abs()
|
171 |
+
losses = torch.where(
|
172 |
+
delta < self.omega1, delta,
|
173 |
+
self.omega2 * torch.log(1.0 + delta / self.epsilon) + self.B)
|
174 |
+
return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0)
|
175 |
+
|
176 |
+
def forward(self, output, target, target_weight=None):
|
177 |
+
"""Forward function.
|
178 |
+
|
179 |
+
Note:
|
180 |
+
batch_size: N
|
181 |
+
num_keypoints: K
|
182 |
+
dimension of keypoints: D (D=2 or D=3)
|
183 |
+
|
184 |
+
Args:
|
185 |
+
output (torch.Tensor[N, K, D]): Output regression.
|
186 |
+
target (torch.Tensor[N, K, D]): Target regression.
|
187 |
+
target_weight (torch.Tensor[N, K, D]):
|
188 |
+
Weights across different joint types.
|
189 |
+
"""
|
190 |
+
if self.use_target_weight:
|
191 |
+
assert target_weight is not None
|
192 |
+
loss = self.criterion(output * target_weight,
|
193 |
+
target * target_weight)
|
194 |
+
else:
|
195 |
+
loss = self.criterion(output, target)
|
196 |
+
|
197 |
+
return loss * self.loss_weight
|
198 |
+
|
199 |
+
|
200 |
+
class MPJPELoss(nn.Module):
|
201 |
+
"""MPJPE (Mean Per Joint Position Error) loss.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
use_target_weight (bool): Option to use weighted MSE loss.
|
205 |
+
Different joint types may have different target weights.
|
206 |
+
loss_weight (float): Weight of the loss. Default: 1.0.
|
207 |
+
"""
|
208 |
+
|
209 |
+
def __init__(self, use_target_weight=False, loss_weight=1.):
|
210 |
+
super().__init__()
|
211 |
+
self.use_target_weight = use_target_weight
|
212 |
+
self.loss_weight = loss_weight
|
213 |
+
|
214 |
+
def forward(self, output, target, target_weight=None):
|
215 |
+
"""Forward function.
|
216 |
+
|
217 |
+
Note:
|
218 |
+
- batch_size: N
|
219 |
+
- num_keypoints: K
|
220 |
+
- dimension of keypoints: D (D=2 or D=3)
|
221 |
+
|
222 |
+
Args:
|
223 |
+
output (torch.Tensor[N, K, D]): Output regression.
|
224 |
+
target (torch.Tensor[N, K, D]): Target regression.
|
225 |
+
target_weight (torch.Tensor[N,K,D]):
|
226 |
+
Weights across different joint types.
|
227 |
+
"""
|
228 |
+
|
229 |
+
if self.use_target_weight:
|
230 |
+
assert target_weight is not None
|
231 |
+
loss = torch.mean(
|
232 |
+
torch.norm((output - target) * target_weight, dim=-1))
|
233 |
+
else:
|
234 |
+
loss = torch.mean(torch.norm(output - target, dim=-1))
|
235 |
+
|
236 |
+
return loss * self.loss_weight
|
237 |
+
|
238 |
+
|
239 |
+
class L1Loss(nn.Module):
|
240 |
+
"""L1Loss loss ."""
|
241 |
+
|
242 |
+
def __init__(self, use_target_weight=False, loss_weight=1.):
|
243 |
+
super().__init__()
|
244 |
+
self.criterion = F.l1_loss
|
245 |
+
self.use_target_weight = use_target_weight
|
246 |
+
self.loss_weight = loss_weight
|
247 |
+
|
248 |
+
def forward(self, output, target, target_weight=None):
|
249 |
+
"""Forward function.
|
250 |
+
|
251 |
+
Note:
|
252 |
+
- batch_size: N
|
253 |
+
- num_keypoints: K
|
254 |
+
|
255 |
+
Args:
|
256 |
+
output (torch.Tensor[N, K, 2]): Output regression.
|
257 |
+
target (torch.Tensor[N, K, 2]): Target regression.
|
258 |
+
target_weight (torch.Tensor[N, K, 2]):
|
259 |
+
Weights across different joint types.
|
260 |
+
"""
|
261 |
+
if self.use_target_weight:
|
262 |
+
assert target_weight is not None
|
263 |
+
loss = self.criterion(output * target_weight,
|
264 |
+
target * target_weight)
|
265 |
+
else:
|
266 |
+
loss = self.criterion(output, target)
|
267 |
+
|
268 |
+
return loss * self.loss_weight
|
269 |
+
|
270 |
+
|
271 |
+
class MSELoss(nn.Module):
|
272 |
+
"""MSE loss for coordinate regression."""
|
273 |
+
|
274 |
+
def __init__(self, use_target_weight=False, loss_weight=1.):
|
275 |
+
super().__init__()
|
276 |
+
self.criterion = F.mse_loss
|
277 |
+
self.use_target_weight = use_target_weight
|
278 |
+
self.loss_weight = loss_weight
|
279 |
+
|
280 |
+
def forward(self, output, target, target_weight=None):
|
281 |
+
"""Forward function.
|
282 |
+
|
283 |
+
Note:
|
284 |
+
- batch_size: N
|
285 |
+
- num_keypoints: K
|
286 |
+
|
287 |
+
Args:
|
288 |
+
output (torch.Tensor[N, K, 2]): Output regression.
|
289 |
+
target (torch.Tensor[N, K, 2]): Target regression.
|
290 |
+
target_weight (torch.Tensor[N, K, 2]):
|
291 |
+
Weights across different joint types.
|
292 |
+
"""
|
293 |
+
if self.use_target_weight:
|
294 |
+
assert target_weight is not None
|
295 |
+
loss = self.criterion(output * target_weight,
|
296 |
+
target * target_weight)
|
297 |
+
else:
|
298 |
+
loss = self.criterion(output, target)
|
299 |
+
|
300 |
+
return loss * self.loss_weight
|
301 |
+
|
302 |
+
|
303 |
+
class BoneLoss(nn.Module):
|
304 |
+
"""Bone length loss.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
joint_parents (list): Indices of each joint's parent joint.
|
308 |
+
use_target_weight (bool): Option to use weighted bone loss.
|
309 |
+
Different bone types may have different target weights.
|
310 |
+
loss_weight (float): Weight of the loss. Default: 1.0.
|
311 |
+
"""
|
312 |
+
|
313 |
+
def __init__(self, joint_parents, use_target_weight=False, loss_weight=1.):
|
314 |
+
super().__init__()
|
315 |
+
self.joint_parents = joint_parents
|
316 |
+
self.use_target_weight = use_target_weight
|
317 |
+
self.loss_weight = loss_weight
|
318 |
+
|
319 |
+
self.non_root_indices = []
|
320 |
+
for i in range(len(self.joint_parents)):
|
321 |
+
if i != self.joint_parents[i]:
|
322 |
+
self.non_root_indices.append(i)
|
323 |
+
|
324 |
+
def forward(self, output, target, target_weight=None):
|
325 |
+
"""Forward function.
|
326 |
+
|
327 |
+
Note:
|
328 |
+
- batch_size: N
|
329 |
+
- num_keypoints: K
|
330 |
+
- dimension of keypoints: D (D=2 or D=3)
|
331 |
+
|
332 |
+
Args:
|
333 |
+
output (torch.Tensor[N, K, D]): Output regression.
|
334 |
+
target (torch.Tensor[N, K, D]): Target regression.
|
335 |
+
target_weight (torch.Tensor[N, K-1]):
|
336 |
+
Weights across different bone types.
|
337 |
+
"""
|
338 |
+
output_bone = torch.norm(
|
339 |
+
output - output[:, self.joint_parents, :],
|
340 |
+
dim=-1)[:, self.non_root_indices]
|
341 |
+
target_bone = torch.norm(
|
342 |
+
target - target[:, self.joint_parents, :],
|
343 |
+
dim=-1)[:, self.non_root_indices]
|
344 |
+
if self.use_target_weight:
|
345 |
+
assert target_weight is not None
|
346 |
+
loss = torch.mean(
|
347 |
+
torch.abs((output_bone * target_weight).mean(dim=0) -
|
348 |
+
(target_bone * target_weight).mean(dim=0)))
|
349 |
+
else:
|
350 |
+
loss = torch.mean(
|
351 |
+
torch.abs(output_bone.mean(dim=0) - target_bone.mean(dim=0)))
|
352 |
+
|
353 |
+
return loss * self.loss_weight
|
354 |
+
|
355 |
+
|
356 |
+
class SemiSupervisionLoss(nn.Module):
|
357 |
+
"""Semi-supervision loss for unlabeled data. It is composed of projection
|
358 |
+
loss and bone loss.
|
359 |
+
|
360 |
+
Paper ref: `3D human pose estimation in video with temporal convolutions
|
361 |
+
and semi-supervised training` Dario Pavllo et al. CVPR'2019.
|
362 |
+
|
363 |
+
Args:
|
364 |
+
joint_parents (list): Indices of each joint's parent joint.
|
365 |
+
projection_loss_weight (float): Weight for projection loss.
|
366 |
+
bone_loss_weight (float): Weight for bone loss.
|
367 |
+
warmup_iterations (int): Number of warmup iterations. In the first
|
368 |
+
`warmup_iterations` iterations, the model is trained only on
|
369 |
+
labeled data, and semi-supervision loss will be 0.
|
370 |
+
This is a workaround since currently we cannot access
|
371 |
+
epoch number in loss functions. Note that the iteration number in
|
372 |
+
an epoch can be changed due to different GPU numbers in multi-GPU
|
373 |
+
settings. So please set this parameter carefully.
|
374 |
+
warmup_iterations = dataset_size // samples_per_gpu // gpu_num
|
375 |
+
* warmup_epochs
|
376 |
+
"""
|
377 |
+
|
378 |
+
def __init__(self,
|
379 |
+
joint_parents,
|
380 |
+
projection_loss_weight=1.,
|
381 |
+
bone_loss_weight=1.,
|
382 |
+
warmup_iterations=0):
|
383 |
+
super().__init__()
|
384 |
+
self.criterion_projection = MPJPELoss(
|
385 |
+
loss_weight=projection_loss_weight)
|
386 |
+
self.criterion_bone = BoneLoss(
|
387 |
+
joint_parents, loss_weight=bone_loss_weight)
|
388 |
+
self.warmup_iterations = warmup_iterations
|
389 |
+
self.num_iterations = 0
|
390 |
+
|
391 |
+
@staticmethod
|
392 |
+
def project_joints(x, intrinsics):
|
393 |
+
"""Project 3D joint coordinates to 2D image plane using camera
|
394 |
+
intrinsic parameters.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
x (torch.Tensor[N, K, 3]): 3D joint coordinates.
|
398 |
+
intrinsics (torch.Tensor[N, 4] | torch.Tensor[N, 9]): Camera
|
399 |
+
intrinsics: f (2), c (2), k (3), p (2).
|
400 |
+
"""
|
401 |
+
while intrinsics.dim() < x.dim():
|
402 |
+
intrinsics.unsqueeze_(1)
|
403 |
+
f = intrinsics[..., :2]
|
404 |
+
c = intrinsics[..., 2:4]
|
405 |
+
_x = torch.clamp(x[:, :, :2] / x[:, :, 2:], -1, 1)
|
406 |
+
if intrinsics.shape[-1] == 9:
|
407 |
+
k = intrinsics[..., 4:7]
|
408 |
+
p = intrinsics[..., 7:9]
|
409 |
+
|
410 |
+
r2 = torch.sum(_x[:, :, :2]**2, dim=-1, keepdim=True)
|
411 |
+
radial = 1 + torch.sum(
|
412 |
+
k * torch.cat((r2, r2**2, r2**3), dim=-1),
|
413 |
+
dim=-1,
|
414 |
+
keepdim=True)
|
415 |
+
tan = torch.sum(p * _x, dim=-1, keepdim=True)
|
416 |
+
_x = _x * (radial + tan) + p * r2
|
417 |
+
_x = f * _x + c
|
418 |
+
return _x
|
419 |
+
|
420 |
+
def forward(self, output, target):
|
421 |
+
losses = dict()
|
422 |
+
|
423 |
+
self.num_iterations += 1
|
424 |
+
if self.num_iterations <= self.warmup_iterations:
|
425 |
+
return losses
|
426 |
+
|
427 |
+
labeled_pose = output['labeled_pose']
|
428 |
+
unlabeled_pose = output['unlabeled_pose']
|
429 |
+
unlabeled_traj = output['unlabeled_traj']
|
430 |
+
unlabeled_target_2d = target['unlabeled_target_2d']
|
431 |
+
intrinsics = target['intrinsics']
|
432 |
+
|
433 |
+
# projection loss
|
434 |
+
unlabeled_output = unlabeled_pose + unlabeled_traj
|
435 |
+
unlabeled_output_2d = self.project_joints(unlabeled_output, intrinsics)
|
436 |
+
loss_proj = self.criterion_projection(unlabeled_output_2d,
|
437 |
+
unlabeled_target_2d, None)
|
438 |
+
losses['proj_loss'] = loss_proj
|
439 |
+
|
440 |
+
# bone loss
|
441 |
+
loss_bone = self.criterion_bone(unlabeled_pose, labeled_pose, None)
|
442 |
+
losses['bone_loss'] = loss_bone
|
443 |
+
|
444 |
+
return losses
|
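All of the regression losses above share the same (output, target, target_weight) call signature. A minimal sketch of exercising two of them on dummy tensors, assuming the repo root (and its dependencies) are on the Python path; the shapes are arbitrary:

import torch
from easy_ViTPose.vit_models.losses.regression_loss import SmoothL1Loss, WingLoss

N, K, D = 4, 17, 2                       # batch, keypoints, coordinate dims (arbitrary)
output = torch.randn(N, K, D)            # predicted keypoint coordinates
target = torch.randn(N, K, D)            # ground-truth coordinates
target_weight = torch.ones(N, K, D)      # per-joint weights (e.g. visibility)

loss_plain = SmoothL1Loss()(output, target)
loss_wing = WingLoss(use_target_weight=True)(output, target, target_weight)
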
easy_ViTPose/vit_models/model.py
ADDED
@@ -0,0 +1,24 @@
import torch.nn as nn

from .backbone.vit import ViT
from .head.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead


__all__ = ['ViTPose']


class ViTPose(nn.Module):
    def __init__(self, cfg: dict) -> None:
        super(ViTPose, self).__init__()

        backbone_cfg = {k: v for k, v in cfg['backbone'].items() if k != 'type'}
        head_cfg = {k: v for k, v in cfg['keypoint_head'].items() if k != 'type'}

        self.backbone = ViT(**backbone_cfg)
        self.keypoint_head = TopdownHeatmapSimpleHead(**head_cfg)

    def forward_features(self, x):
        return self.backbone(x)

    def forward(self, x):
        return self.keypoint_head(self.backbone(x))

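ViTPose simply plugs a ViT backbone into a deconvolutional heatmap head, stripping the mmcv-style 'type' keys from the config dicts before unpacking them as keyword arguments. A rough instantiation sketch; the backbone/head keyword names and values below follow common ViTPose-B settings and are assumptions here, so in practice the dicts from the repo's own config modules should be used instead:

import torch
from easy_ViTPose.vit_models.model import ViTPose

cfg = dict(
    backbone=dict(
        type='ViT',
        img_size=(256, 192), patch_size=16, embed_dim=768,
        depth=12, num_heads=12, ratio=1, mlp_ratio=4,
        qkv_bias=True, drop_path_rate=0.3),
    keypoint_head=dict(
        type='TopdownHeatmapSimpleHead',
        in_channels=768, num_deconv_layers=2,
        num_deconv_filters=(256, 256), num_deconv_kernels=(4, 4),
        extra=dict(final_conv_kernel=1), out_channels=17))

model = ViTPose(cfg).eval()
with torch.no_grad():
    heatmaps = model(torch.randn(1, 3, 256, 192))  # roughly (1, 17, 64, 48) heatmaps
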
easy_ViTPose/vit_models/optimizer.py
ADDED
@@ -0,0 +1,15 @@
import torch.optim as optim

class LayerDecayOptimizer:
    def __init__(self, optimizer, layerwise_decay_rate):
        self.optimizer = optimizer
        self.layerwise_decay_rate = layerwise_decay_rate
        self.param_groups = optimizer.param_groups

    def step(self, *args, **kwargs):
        for i, group in enumerate(self.optimizer.param_groups):
            group['lr'] *= self.layerwise_decay_rate[i]
        self.optimizer.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        self.optimizer.zero_grad(*args, **kwargs)

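LayerDecayOptimizer is a thin wrapper: every call to step() first multiplies each param group's learning rate by its per-group factor and then delegates to the wrapped optimizer, so the rates shrink geometrically over steps. A minimal sketch, assuming one decay factor per param group; the two-layer model and the split into two groups are purely illustrative:

import torch
import torch.nn as nn
from easy_ViTPose.vit_models.optimizer import LayerDecayOptimizer

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
param_groups = [
    {'params': model[0].parameters(), 'lr': 1e-4},  # earlier layer
    {'params': model[1].parameters(), 'lr': 1e-4},  # later layer
]
base_opt = torch.optim.AdamW(param_groups)

# One decay rate per param group; the factor is re-applied on every step().
opt = LayerDecayOptimizer(base_opt, layerwise_decay_rate=[0.75, 1.0])

loss = model(torch.randn(4, 8)).sum()
loss.backward()
opt.step()
opt.zero_grad()
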
easy_ViTPose/vit_utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .util import *
from .top_down_eval import *
from .post_processing import *
from .visualization import *
from .dist_util import *
from .logging import *

easy_ViTPose/vit_utils/dist_util.py
ADDED
@@ -0,0 +1,212 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import os
import socket
import subprocess
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple

import torch
import torch.multiprocessing as mp
from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)


def is_mps_available() -> bool:
    """Return True if mps devices exist.

    It's specialized for mac m1 chips and requires torch version 1.12 or higher.
    """
    try:
        import torch
        return hasattr(torch.backends,
                       'mps') and torch.backends.mps.is_available()
    except Exception:
        return False


def _find_free_port() -> str:
    # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(('', 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def _is_free_port(port: int) -> bool:
    ips = socket.gethostbyname_ex(socket.gethostname())[-1]
    ips.append('localhost')
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return all(s.connect_ex((ip, port)) != 0 for ip in ips)


def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend: str, **kwargs) -> None:
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend: str, **kwargs) -> None:
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    if 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # if torch.distributed default port(29500) is available
        # then use it, else find a free port
        if _is_free_port(29500):
            os.environ['MASTER_PORT'] = '29500'
        else:
            os.environ['MASTER_PORT'] = str(_find_free_port())
    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def get_dist_info() -> Tuple[int, int]:
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func: Callable) -> Callable:

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def allreduce_params(params: List[torch.nn.Parameter],
                     coalesce: bool = True,
                     bucket_size_mb: int = -1) -> None:
    """Allreduce parameters.

    Args:
        params (list[torch.nn.Parameter]): List of parameters or buffers
            of a model.
        coalesce (bool, optional): Whether allreduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        return
    params = [param.data for param in params]
    if coalesce:
        _allreduce_coalesced(params, world_size, bucket_size_mb)
    else:
        for tensor in params:
            dist.all_reduce(tensor.div_(world_size))


def allreduce_grads(params: List[torch.nn.Parameter],
                    coalesce: bool = True,
                    bucket_size_mb: int = -1) -> None:
    """Allreduce gradients.

    Args:
        params (list[torch.nn.Parameter]): List of parameters of a model.
        coalesce (bool, optional): Whether allreduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    grads = [
        param.grad.data for param in params
        if param.requires_grad and param.grad is not None
    ]
    _, world_size = get_dist_info()
    if world_size == 1:
        return
    if coalesce:
        _allreduce_coalesced(grads, world_size, bucket_size_mb)
    else:
        for tensor in grads:
            dist.all_reduce(tensor.div_(world_size))


def _allreduce_coalesced(tensors: torch.Tensor,
                         world_size: int,
                         bucket_size_mb: int = -1) -> None:
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        flat_tensors = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)

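A small sketch of how the helpers above compose in single- or multi-process runs; the decorated function and its message are illustrative names only:

from easy_ViTPose.vit_utils.dist_util import get_dist_info, master_only

@master_only
def log_once(msg):
    # Runs only on rank 0; on other ranks the decorated call is a no-op.
    print(msg)

rank, world_size = get_dist_info()  # (0, 1) when torch.distributed is not initialized
log_once(f'training on {world_size} process(es), this is rank {rank}')
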
easy_ViTPose/vit_utils/inference.py
ADDED
@@ -0,0 +1,93 @@
import cv2
import numpy as np
import json


rotation_map = {
    0: None,
    90: cv2.ROTATE_90_COUNTERCLOCKWISE,
    180: cv2.ROTATE_180,
    270: cv2.ROTATE_90_CLOCKWISE
}

class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

def draw_bboxes(image, bounding_boxes, boxes_id, scores):
    image_with_boxes = image.copy()

    for bbox, bbox_id, score in zip(bounding_boxes, boxes_id, scores):
        x1, y1, x2, y2 = bbox
        cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), (128, 128, 0), 2)

        label = f'#{bbox_id}: {score:.2f}'

        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        label_x = x1
        label_y = y1 - 5 if y1 > 20 else y1 + 20

        # Draw a filled rectangle as the background for the label
        cv2.rectangle(image_with_boxes, (x1, label_y - label_height - 5),
                      (x1 + label_width, label_y + 5), (128, 128, 0), cv2.FILLED)
        cv2.putText(image_with_boxes, label, (label_x, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    return image_with_boxes


def pad_image(image: np.ndarray, aspect_ratio: float) -> tuple:
    # Get the current aspect ratio of the image
    image_height, image_width = image.shape[:2]
    current_aspect_ratio = image_width / image_height

    left_pad = 0
    top_pad = 0
    # Determine whether to pad horizontally or vertically
    if current_aspect_ratio < aspect_ratio:
        # Pad horizontally
        target_width = int(aspect_ratio * image_height)
        pad_width = target_width - image_width
        left_pad = pad_width // 2
        right_pad = pad_width - left_pad

        padded_image = np.pad(image,
                              pad_width=((0, 0), (left_pad, right_pad), (0, 0)),
                              mode='constant')
    else:
        # Pad vertically
        target_height = int(image_width / aspect_ratio)
        pad_height = target_height - image_height
        top_pad = pad_height // 2
        bottom_pad = pad_height - top_pad

        padded_image = np.pad(image,
                              pad_width=((top_pad, bottom_pad), (0, 0), (0, 0)),
                              mode='constant')
    return padded_image, (left_pad, top_pad)


class VideoReader(object):
    def __init__(self, file_name, rotate=0):
        self.file_name = file_name
        self.rotate = rotation_map[rotate]
        try:  # OpenCV needs int to read from webcam
            self.file_name = int(file_name)
        except ValueError:
            pass

    def __iter__(self):
        self.cap = cv2.VideoCapture(self.file_name)
        if not self.cap.isOpened():
            raise IOError('Video {} cannot be opened'.format(self.file_name))
        return self

    def __next__(self):
        was_read, img = self.cap.read()
        if not was_read:
            raise StopIteration
        if self.rotate is not None:
            img = cv2.rotate(img, self.rotate)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

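A short sketch tying the helpers above together: iterate frames from a VideoReader and pad each one to a fixed aspect ratio before handing it to a pose model. The video path and the 3/4 aspect ratio are placeholders:

from easy_ViTPose.vit_utils.inference import VideoReader, pad_image

reader = VideoReader('input.mp4', rotate=0)  # placeholder path; rotate in {0, 90, 180, 270}
for frame in reader:                         # frames come back as RGB numpy arrays
    padded, (left_pad, top_pad) = pad_image(frame, aspect_ratio=3 / 4)
    # left_pad/top_pad let you map keypoints predicted on the padded frame
    # back to the original image coordinates.
    break
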
easy_ViTPose/vit_utils/logging.py
ADDED
@@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch.distributed as dist

logger_initialized: dict = {}


def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        # Here, the default behaviour of the official logger is 'a'. Thus, we
        # provide an interface to change the file mode to the default
        # behaviour.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True

    return logger


def print_log(msg, logger=None, level=logging.INFO):
    """Print a log message.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): The logger to be used.
            Some special loggers are:

            - "silent": no message will be printed.
            - other str: the logger obtained with `get_root_logger(logger)`.
            - None: The `print()` method will be used to print log messages.
        level (int): Logging level. Only available when `logger` is a Logger
            object or "root".
    """
    if logger is None:
        print(msg)
    elif isinstance(logger, logging.Logger):
        logger.log(level, msg)
    elif logger == 'silent':
        pass
    elif isinstance(logger, str):
        _logger = get_logger(logger)
        _logger.log(level, msg)
    else:
        raise TypeError(
            'logger should be either a logging.Logger object, str, '
            f'"silent" or None, but got {type(logger)}')


def get_root_logger(log_file=None, log_level=logging.INFO):
    """Use `get_logger` method in mmcv to get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added. The name of the root logger is the top-level package name,
    e.g., "mmpose".

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    return get_logger(__name__.split('.')[0], log_file, log_level)

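Typical use of the logging helpers above; the logger name and log file path are arbitrary examples:

import logging
from easy_ViTPose.vit_utils.logging import get_logger, print_log

logger = get_logger('easy_ViTPose', log_file='train.log', log_level=logging.INFO)
print_log('starting training', logger=logger)  # console, and train.log on rank 0
print_log('debug details', logger='silent')    # swallowed entirely
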
easy_ViTPose/vit_utils/nms/__init__.py
ADDED
File without changes
easy_ViTPose/vit_utils/nms/cpu_nms.c
ADDED
The diff for this file is too large to render.
easy_ViTPose/vit_utils/nms/cpu_nms.cpython-37m-x86_64-linux-gnu.so
ADDED
Binary file (264 kB).