onescotch commited on
Commit
2de1f98
1 Parent(s): 5e4861d

add huggingface implementation

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +2 -1
  2. app.py +130 -0
  3. assets/conversions.py +523 -0
  4. common/base.py +86 -0
  5. common/logger.py +50 -0
  6. common/nets/layer.py +53 -0
  7. common/nets/loss.py +30 -0
  8. common/nets/smpler_x.py +172 -0
  9. common/timer.py +38 -0
  10. common/utils/__init__.py +0 -0
  11. common/utils/dir.py +10 -0
  12. common/utils/distribute_utils.py +217 -0
  13. common/utils/human_model_files/smpl/SMPL_FEMALE.pkl +3 -0
  14. common/utils/human_model_files/smpl/SMPL_MALE.pkl +3 -0
  15. common/utils/human_model_files/smpl/SMPL_NEUTRAL.pkl +3 -0
  16. common/utils/human_model_files/smpl/smpl_uv.npz +3 -0
  17. common/utils/human_model_files/smplx/MANO_SMPLX_vertex_ids.pkl +3 -0
  18. common/utils/human_model_files/smplx/SMPL-X__FLAME_vertex_ids.npy +3 -0
  19. common/utils/human_model_files/smplx/SMPLX_FEMALE.npz +3 -0
  20. common/utils/human_model_files/smplx/SMPLX_MALE.npz +3 -0
  21. common/utils/human_model_files/smplx/SMPLX_NEUTRAL.npz +3 -0
  22. common/utils/human_model_files/smplx/SMPLX_NEUTRAL.pkl +3 -0
  23. common/utils/human_model_files/smplx/SMPLX_to_J14.pkl +3 -0
  24. common/utils/human_models.py +176 -0
  25. common/utils/inference_utils.py +153 -0
  26. common/utils/preprocessing.py +541 -0
  27. common/utils/smplx/LICENSE +58 -0
  28. common/utils/smplx/README.md +186 -0
  29. common/utils/smplx/examples/demo.py +180 -0
  30. common/utils/smplx/examples/demo_layers.py +181 -0
  31. common/utils/smplx/examples/vis_flame_vertices.py +92 -0
  32. common/utils/smplx/examples/vis_mano_vertices.py +99 -0
  33. common/utils/smplx/setup.py +79 -0
  34. common/utils/smplx/smplx/__init__.py +30 -0
  35. common/utils/smplx/smplx/body_models.py +2331 -0
  36. common/utils/smplx/smplx/joint_names.py +163 -0
  37. common/utils/smplx/smplx/lbs.py +404 -0
  38. common/utils/smplx/smplx/utils.py +125 -0
  39. common/utils/smplx/smplx/vertex_ids.py +77 -0
  40. common/utils/smplx/smplx/vertex_joint_selector.py +77 -0
  41. common/utils/smplx/tools/README.md +20 -0
  42. common/utils/smplx/tools/__init__.py +19 -0
  43. common/utils/smplx/tools/clean_ch.py +68 -0
  44. common/utils/smplx/tools/merge_smplh_mano.py +89 -0
  45. common/utils/transforms.py +172 -0
  46. common/utils/vis.py +183 -0
  47. main/SMPLer_X.py +468 -0
  48. main/_base_/datasets/300w.py +384 -0
  49. main/_base_/datasets/aflw.py +83 -0
  50. main/_base_/datasets/aic.py +140 -0
README.md CHANGED
@@ -4,7 +4,8 @@ emoji: ⚡
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.19.2
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
+ python_version: 3.8
8
+ sdk_version: 4.16.0
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import os.path as osp
4
+ from pathlib import Path
5
+ import cv2
6
+ import gradio as gr
7
+ import torch
8
+ import math
9
+
10
+ try:
11
+ import mmpose
12
+ except:
13
+ os.system('pip install /home/user/app/main/transformer_utils')
14
+
15
+ os.system('cp -rf /home/user/app/assets/conversions.py /home/user/.pyenv/versions/3.8.18/lib/python3.8/site-packages/torchgeometry/core/conversions.py')
16
+ DEFAULT_MODEL='smpler_x_h32'
17
+ OUT_FOLDER = '/home/user/app/demo_out'
18
+ os.makedirs(OUT_FOLDER, exist_ok=True)
19
+ num_gpus = 1 if torch.cuda.is_available() else -1
20
+ print("!!!", torch.cuda.is_available())
21
+ print(torch.cuda.device_count())
22
+ print(torch.version.cuda)
23
+ index = torch.cuda.current_device()
24
+ print(index)
25
+ print(torch.cuda.get_device_name(index))
26
+ from main.inference import Inferer
27
+ inferer = Inferer(DEFAULT_MODEL, num_gpus, OUT_FOLDER)
28
+
29
+ def infer(video_input, in_threshold=0.5, num_people="Single person", render_mesh=False):
30
+ os.system(f'rm -rf {OUT_FOLDER}/*')
31
+ multi_person = False if (num_people == "Single person") else True
32
+ cap = cv2.VideoCapture(video_input)
33
+ fps = math.ceil(cap.get(5))
34
+ width = int(cap.get(3))
35
+ height = int(cap.get(4))
36
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
37
+ video_path = osp.join(OUT_FOLDER, f'out.m4v')
38
+ final_video_path = osp.join(OUT_FOLDER, f'out.mp4')
39
+ video_output = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
40
+ success = 1
41
+ frame = 0
42
+ while success:
43
+ success, original_img = cap.read()
44
+ if not success:
45
+ break
46
+ frame += 1
47
+ img, mesh_paths, smplx_paths = inferer.infer(original_img, in_threshold, frame, multi_person, not(render_mesh))
48
+ video_output.write(img)
49
+ cap.release()
50
+ video_output.release()
51
+ cv2.destroyAllWindows()
52
+ os.system(f'ffmpeg -i {video_path} -c copy {final_video_path}')
53
+
54
+ #Compress mesh and smplx files
55
+ save_path_mesh = os.path.join(OUT_FOLDER, 'mesh')
56
+ save_mesh_file = os.path.join(OUT_FOLDER, 'mesh.zip')
57
+ os.makedirs(save_path_mesh, exist_ok= True)
58
+ save_path_smplx = os.path.join(OUT_FOLDER, 'smplx')
59
+ save_smplx_file = os.path.join(OUT_FOLDER, 'smplx.zip')
60
+ os.makedirs(save_path_smplx, exist_ok= True)
61
+ os.system(f'zip -r {save_mesh_file} {save_path_mesh}')
62
+ os.system(f'zip -r {save_smplx_file} {save_path_smplx}')
63
+ return video_path, save_mesh_file, save_smplx_file
64
+
65
+ TITLE = '''<h1 align="center">SMPLer-X: Scaling Up Expressive Human Pose and Shape Estimation</h1>'''
66
+ VIDEO = '''
67
+ <center><iframe width="960" height="540"
68
+ src="https://www.youtube.com/embed/DepTqbPpVzY?si=qSeQuX-bgm_rON7E"title="SMPLer-X: Scaling Up Expressive Human Pose and Shape Estimation" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen>
69
+ </iframe>
70
+ </center><br>'''
71
+ DESCRIPTION = '''
72
+ <b>Official Gradio demo</b> for <a href="https://caizhongang.com/projects/SMPLer-X/"><b>SMPLer-X: Scaling Up Expressive Human Pose and Shape Estimation</b></a>.<br>
73
+ <p>
74
+ Note: You can drop a video at the panel (or select one of the examples)
75
+ then you will get the 3D reconstructions of the detected human. ).
76
+ </p>
77
+ '''
78
+
79
+ with gr.Blocks(title="SMPLer-X", css=".gradio-container") as demo:
80
+
81
+ gr.Markdown(TITLE)
82
+ gr.HTML(VIDEO)
83
+ gr.Markdown(DESCRIPTION)
84
+
85
+ with gr.Row():
86
+ with gr.Column():
87
+ video_input = gr.Video(label="Input video", elem_classes="video")
88
+ threshold = gr.Slider(0, 1.0, value=0.5, label='BBox detection threshold')
89
+ with gr.Column(scale=2):
90
+ num_people = gr.Radio(
91
+ choices=["Single person", "Multiple people"],
92
+ value="Single person",
93
+ label="Number of people",
94
+ info="Choose how many people are there in the video. Choose 'single person' for faster inference.",
95
+ interactive=True,
96
+ scale=1,)
97
+ gr.HTML("""<br/>""")
98
+ mesh_as_vertices = gr.Checkbox(
99
+ label="Render mesh",
100
+ info="By default, the reconstructions of human bodies are expressed as vertices for faster inference. Check this option if you want to render the human body with mesh.",
101
+ interactive=True,
102
+ scale=1,)
103
+
104
+ send_button = gr.Button("Infer")
105
+ gr.HTML("""<br/>""")
106
+
107
+ with gr.Row():
108
+ with gr.Column():
109
+ video_output = gr.Video(elem_classes="video")
110
+ with gr.Column():
111
+ meshes_output = gr.File(label="3D meshes")
112
+ smplx_output = gr.File(label= "SMPL-X models")
113
+ # example_images = gr.Examples([])
114
+ send_button.click(fn=infer, inputs=[video_input, threshold, num_people, mesh_as_vertices], outputs=[video_output, meshes_output, smplx_output])
115
+ # with gr.Row():
116
+ example_videos = gr.Examples([
117
+ ['/home/user/app/assets/01.mp4'],
118
+ ['/home/user/app/assets/02.mp4'],
119
+ ['/home/user/app/assets/03.mp4'],
120
+ ['/home/user/app/assets/04.mp4'],
121
+ ['/home/user/app/assets/05.mp4'],
122
+ ['/home/user/app/assets/06.mp4'],
123
+ ['/home/user/app/assets/07.mp4'],
124
+ ['/home/user/app/assets/08.mp4'],
125
+ ['/home/user/app/assets/09.mp4'],
126
+ ],
127
+ inputs=[video_input, 0.5])
128
+
129
+ #demo.queue()
130
+ demo.launch(debug=True)
assets/conversions.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import torchgeometry as tgm
5
+
6
+ __all__ = [
7
+ # functional api
8
+ "pi",
9
+ "rad2deg",
10
+ "deg2rad",
11
+ "convert_points_from_homogeneous",
12
+ "convert_points_to_homogeneous",
13
+ "angle_axis_to_rotation_matrix",
14
+ "rotation_matrix_to_angle_axis",
15
+ "rotation_matrix_to_quaternion",
16
+ "quaternion_to_angle_axis",
17
+ "angle_axis_to_quaternion",
18
+ "rtvec_to_pose",
19
+ # layer api
20
+ "RadToDeg",
21
+ "DegToRad",
22
+ "ConvertPointsFromHomogeneous",
23
+ "ConvertPointsToHomogeneous",
24
+ ]
25
+
26
+
27
+ """Constant with number pi
28
+ """
29
+ pi = torch.Tensor([3.14159265358979323846])
30
+
31
+
32
+ def rad2deg(tensor):
33
+ r"""Function that converts angles from radians to degrees.
34
+
35
+ See :class:`~torchgeometry.RadToDeg` for details.
36
+
37
+ Args:
38
+ tensor (Tensor): Tensor of arbitrary shape.
39
+
40
+ Returns:
41
+ Tensor: Tensor with same shape as input.
42
+
43
+ Example:
44
+ >>> input = tgm.pi * torch.rand(1, 3, 3)
45
+ >>> output = tgm.rad2deg(input)
46
+ """
47
+ if not torch.is_tensor(tensor):
48
+ raise TypeError("Input type is not a torch.Tensor. Got {}"
49
+ .format(type(tensor)))
50
+
51
+ return 180. * tensor / pi.to(tensor.device).type(tensor.dtype)
52
+
53
+
54
+ def deg2rad(tensor):
55
+ r"""Function that converts angles from degrees to radians.
56
+
57
+ See :class:`~torchgeometry.DegToRad` for details.
58
+
59
+ Args:
60
+ tensor (Tensor): Tensor of arbitrary shape.
61
+
62
+ Returns:
63
+ Tensor: Tensor with same shape as input.
64
+
65
+ Examples::
66
+
67
+ >>> input = 360. * torch.rand(1, 3, 3)
68
+ >>> output = tgm.deg2rad(input)
69
+ """
70
+ if not torch.is_tensor(tensor):
71
+ raise TypeError("Input type is not a torch.Tensor. Got {}"
72
+ .format(type(tensor)))
73
+
74
+ return tensor * pi.to(tensor.device).type(tensor.dtype) / 180.
75
+
76
+
77
+ def convert_points_from_homogeneous(points):
78
+ r"""Function that converts points from homogeneous to Euclidean space.
79
+
80
+ See :class:`~torchgeometry.ConvertPointsFromHomogeneous` for details.
81
+
82
+ Examples::
83
+
84
+ >>> input = torch.rand(2, 4, 3) # BxNx3
85
+ >>> output = tgm.convert_points_from_homogeneous(input) # BxNx2
86
+ """
87
+ if not torch.is_tensor(points):
88
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
89
+ type(points)))
90
+ if len(points.shape) < 2:
91
+ raise ValueError("Input must be at least a 2D tensor. Got {}".format(
92
+ points.shape))
93
+
94
+ return points[..., :-1] / points[..., -1:]
95
+
96
+
97
+ def convert_points_to_homogeneous(points):
98
+ r"""Function that converts points from Euclidean to homogeneous space.
99
+
100
+ See :class:`~torchgeometry.ConvertPointsToHomogeneous` for details.
101
+
102
+ Examples::
103
+
104
+ >>> input = torch.rand(2, 4, 3) # BxNx3
105
+ >>> output = tgm.convert_points_to_homogeneous(input) # BxNx4
106
+ """
107
+ if not torch.is_tensor(points):
108
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
109
+ type(points)))
110
+ if len(points.shape) < 2:
111
+ raise ValueError("Input must be at least a 2D tensor. Got {}".format(
112
+ points.shape))
113
+
114
+ return nn.functional.pad(points, (0, 1), "constant", 1.0)
115
+
116
+
117
+ def angle_axis_to_rotation_matrix(angle_axis):
118
+ """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix
119
+
120
+ Args:
121
+ angle_axis (Tensor): tensor of 3d vector of axis-angle rotations.
122
+
123
+ Returns:
124
+ Tensor: tensor of 4x4 rotation matrices.
125
+
126
+ Shape:
127
+ - Input: :math:`(N, 3)`
128
+ - Output: :math:`(N, 4, 4)`
129
+
130
+ Example:
131
+ >>> input = torch.rand(1, 3) # Nx3
132
+ >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx4x4
133
+ """
134
+ def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6):
135
+ # We want to be careful to only evaluate the square root if the
136
+ # norm of the angle_axis vector is greater than zero. Otherwise
137
+ # we get a division by zero.
138
+ k_one = 1.0
139
+ theta = torch.sqrt(theta2)
140
+ wxyz = angle_axis / (theta + eps)
141
+ wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
142
+ cos_theta = torch.cos(theta)
143
+ sin_theta = torch.sin(theta)
144
+
145
+ r00 = cos_theta + wx * wx * (k_one - cos_theta)
146
+ r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
147
+ r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
148
+ r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
149
+ r11 = cos_theta + wy * wy * (k_one - cos_theta)
150
+ r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
151
+ r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
152
+ r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
153
+ r22 = cos_theta + wz * wz * (k_one - cos_theta)
154
+ rotation_matrix = torch.cat(
155
+ [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1)
156
+ return rotation_matrix.view(-1, 3, 3)
157
+
158
+ def _compute_rotation_matrix_taylor(angle_axis):
159
+ rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
160
+ k_one = torch.ones_like(rx)
161
+ rotation_matrix = torch.cat(
162
+ [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
163
+ return rotation_matrix.view(-1, 3, 3)
164
+
165
+ # stolen from ceres/rotation.h
166
+
167
+ _angle_axis = torch.unsqueeze(angle_axis, dim=1)
168
+ theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
169
+ theta2 = torch.squeeze(theta2, dim=1)
170
+
171
+ # compute rotation matrices
172
+ rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
173
+ rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)
174
+
175
+ # create mask to handle both cases
176
+ eps = 1e-6
177
+ mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
178
+ mask_pos = (mask).type_as(theta2)
179
+ mask_neg = (mask == False).type_as(theta2) # noqa
180
+
181
+ # create output pose matrix
182
+ batch_size = angle_axis.shape[0]
183
+ rotation_matrix = torch.eye(4).to(angle_axis.device).type_as(angle_axis)
184
+ rotation_matrix = rotation_matrix.view(1, 4, 4).repeat(batch_size, 1, 1)
185
+ # fill output matrix with masked values
186
+ rotation_matrix[..., :3, :3] = \
187
+ mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor
188
+ return rotation_matrix # Nx4x4
189
+
190
+
191
+ def rtvec_to_pose(rtvec):
192
+ """
193
+ Convert axis-angle rotation and translation vector to 4x4 pose matrix
194
+
195
+ Args:
196
+ rtvec (Tensor): Rodrigues vector transformations
197
+
198
+ Returns:
199
+ Tensor: transformation matrices
200
+
201
+ Shape:
202
+ - Input: :math:`(N, 6)`
203
+ - Output: :math:`(N, 4, 4)`
204
+
205
+ Example:
206
+ >>> input = torch.rand(3, 6) # Nx6
207
+ >>> output = tgm.rtvec_to_pose(input) # Nx4x4
208
+ """
209
+ assert rtvec.shape[-1] == 6, 'rtvec=[rx, ry, rz, tx, ty, tz]'
210
+ pose = angle_axis_to_rotation_matrix(rtvec[..., :3])
211
+ pose[..., :3, 3] = rtvec[..., 3:]
212
+ return pose
213
+
214
+
215
+ def rotation_matrix_to_angle_axis(rotation_matrix):
216
+ """Convert 3x4 rotation matrix to Rodrigues vector
217
+
218
+ Args:
219
+ rotation_matrix (Tensor): rotation matrix.
220
+
221
+ Returns:
222
+ Tensor: Rodrigues vector transformation.
223
+
224
+ Shape:
225
+ - Input: :math:`(N, 3, 4)`
226
+ - Output: :math:`(N, 3)`
227
+
228
+ Example:
229
+ >>> input = torch.rand(2, 3, 4) # Nx4x4
230
+ >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
231
+ """
232
+ # todo add check that matrix is a valid rotation matrix
233
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
234
+ return quaternion_to_angle_axis(quaternion)
235
+
236
+
237
+ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
238
+ """Convert 3x4 rotation matrix to 4d quaternion vector
239
+
240
+ This algorithm is based on algorithm described in
241
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
242
+
243
+ Args:
244
+ rotation_matrix (Tensor): the rotation matrix to convert.
245
+
246
+ Return:
247
+ Tensor: the rotation in quaternion
248
+
249
+ Shape:
250
+ - Input: :math:`(N, 3, 4)`
251
+ - Output: :math:`(N, 4)`
252
+
253
+ Example:
254
+ >>> input = torch.rand(4, 3, 4) # Nx3x4
255
+ >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
256
+ """
257
+ if not torch.is_tensor(rotation_matrix):
258
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
259
+ type(rotation_matrix)))
260
+
261
+ if len(rotation_matrix.shape) > 3:
262
+ raise ValueError(
263
+ "Input size must be a three dimensional tensor. Got {}".format(
264
+ rotation_matrix.shape))
265
+ if not rotation_matrix.shape[-2:] == (3, 4):
266
+ raise ValueError(
267
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(
268
+ rotation_matrix.shape))
269
+
270
+ rmat_t = torch.transpose(rotation_matrix, 1, 2)
271
+
272
+ mask_d2 = rmat_t[:, 2, 2] < eps
273
+
274
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
275
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
276
+
277
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
278
+ q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
279
+ t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
280
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
281
+ t0_rep = t0.repeat(4, 1).t()
282
+
283
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
284
+ q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
285
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
286
+ t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
287
+ t1_rep = t1.repeat(4, 1).t()
288
+
289
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
290
+ q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
291
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
292
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
293
+ t2_rep = t2.repeat(4, 1).t()
294
+
295
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
296
+ q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
297
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
298
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
299
+ t3_rep = t3.repeat(4, 1).t()
300
+
301
+ mask_c0 = mask_d2 * mask_d0_d1
302
+ mask_c1 = mask_d2 * ~(mask_d0_d1)
303
+ mask_c2 = ~(mask_d2) * mask_d0_nd1
304
+ mask_c3 = ~(mask_d2) * ~(mask_d0_nd1)
305
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
306
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
307
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
308
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
309
+
310
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
311
+ q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
312
+ t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
313
+ q *= 0.5
314
+ return q
315
+
316
+
317
+ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
318
+ """Convert quaternion vector to angle axis of rotation.
319
+
320
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
321
+
322
+ Args:
323
+ quaternion (torch.Tensor): tensor with quaternions.
324
+
325
+ Return:
326
+ torch.Tensor: tensor with angle axis of rotation.
327
+
328
+ Shape:
329
+ - Input: :math:`(*, 4)` where `*` means, any number of dimensions
330
+ - Output: :math:`(*, 3)`
331
+
332
+ Example:
333
+ >>> quaternion = torch.rand(2, 4) # Nx4
334
+ >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
335
+ """
336
+ if not torch.is_tensor(quaternion):
337
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
338
+ type(quaternion)))
339
+
340
+ if not quaternion.shape[-1] == 4:
341
+ raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
342
+ .format(quaternion.shape))
343
+ # unpack input and compute conversion
344
+ q1: torch.Tensor = quaternion[..., 1]
345
+ q2: torch.Tensor = quaternion[..., 2]
346
+ q3: torch.Tensor = quaternion[..., 3]
347
+ sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
348
+
349
+ sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
350
+ cos_theta: torch.Tensor = quaternion[..., 0]
351
+ two_theta: torch.Tensor = 2.0 * torch.where(
352
+ cos_theta < 0.0,
353
+ torch.atan2(-sin_theta, -cos_theta),
354
+ torch.atan2(sin_theta, cos_theta))
355
+
356
+ k_pos: torch.Tensor = two_theta / sin_theta
357
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
358
+ k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
359
+
360
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
361
+ angle_axis[..., 0] += q1 * k
362
+ angle_axis[..., 1] += q2 * k
363
+ angle_axis[..., 2] += q3 * k
364
+ return angle_axis
365
+
366
+ # based on:
367
+ # https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L138
368
+
369
+
370
+ def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor:
371
+ """Convert an angle axis to a quaternion.
372
+
373
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
374
+
375
+ Args:
376
+ angle_axis (torch.Tensor): tensor with angle axis.
377
+
378
+ Return:
379
+ torch.Tensor: tensor with quaternion.
380
+
381
+ Shape:
382
+ - Input: :math:`(*, 3)` where `*` means, any number of dimensions
383
+ - Output: :math:`(*, 4)`
384
+
385
+ Example:
386
+ >>> angle_axis = torch.rand(2, 4) # Nx4
387
+ >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx3
388
+ """
389
+ if not torch.is_tensor(angle_axis):
390
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
391
+ type(angle_axis)))
392
+
393
+ if not angle_axis.shape[-1] == 3:
394
+ raise ValueError("Input must be a tensor of shape Nx3 or 3. Got {}"
395
+ .format(angle_axis.shape))
396
+ # unpack input and compute conversion
397
+ a0: torch.Tensor = angle_axis[..., 0:1]
398
+ a1: torch.Tensor = angle_axis[..., 1:2]
399
+ a2: torch.Tensor = angle_axis[..., 2:3]
400
+ theta_squared: torch.Tensor = a0 * a0 + a1 * a1 + a2 * a2
401
+
402
+ theta: torch.Tensor = torch.sqrt(theta_squared)
403
+ half_theta: torch.Tensor = theta * 0.5
404
+
405
+ mask: torch.Tensor = theta_squared > 0.0
406
+ ones: torch.Tensor = torch.ones_like(half_theta)
407
+
408
+ k_neg: torch.Tensor = 0.5 * ones
409
+ k_pos: torch.Tensor = torch.sin(half_theta) / theta
410
+ k: torch.Tensor = torch.where(mask, k_pos, k_neg)
411
+ w: torch.Tensor = torch.where(mask, torch.cos(half_theta), ones)
412
+
413
+ quaternion: torch.Tensor = torch.zeros_like(angle_axis)
414
+ quaternion[..., 0:1] += a0 * k
415
+ quaternion[..., 1:2] += a1 * k
416
+ quaternion[..., 2:3] += a2 * k
417
+ return torch.cat([w, quaternion], dim=-1)
418
+
419
+ # TODO: add below funtionalities
420
+ # - pose_to_rtvec
421
+
422
+
423
+ # layer api
424
+
425
+
426
+ class RadToDeg(nn.Module):
427
+ r"""Creates an object that converts angles from radians to degrees.
428
+
429
+ Args:
430
+ tensor (Tensor): Tensor of arbitrary shape.
431
+
432
+ Returns:
433
+ Tensor: Tensor with same shape as input.
434
+
435
+ Examples::
436
+
437
+ >>> input = tgm.pi * torch.rand(1, 3, 3)
438
+ >>> output = tgm.RadToDeg()(input)
439
+ """
440
+
441
+ def __init__(self):
442
+ super(RadToDeg, self).__init__()
443
+
444
+ def forward(self, input):
445
+ return rad2deg(input)
446
+
447
+
448
+ class DegToRad(nn.Module):
449
+ r"""Function that converts angles from degrees to radians.
450
+
451
+ Args:
452
+ tensor (Tensor): Tensor of arbitrary shape.
453
+
454
+ Returns:
455
+ Tensor: Tensor with same shape as input.
456
+
457
+ Examples::
458
+
459
+ >>> input = 360. * torch.rand(1, 3, 3)
460
+ >>> output = tgm.DegToRad()(input)
461
+ """
462
+
463
+ def __init__(self):
464
+ super(DegToRad, self).__init__()
465
+
466
+ def forward(self, input):
467
+ return deg2rad(input)
468
+
469
+
470
+ class ConvertPointsFromHomogeneous(nn.Module):
471
+ r"""Creates a transformation that converts points from homogeneous to
472
+ Euclidean space.
473
+
474
+ Args:
475
+ points (Tensor): tensor of N-dimensional points.
476
+
477
+ Returns:
478
+ Tensor: tensor of N-1-dimensional points.
479
+
480
+ Shape:
481
+ - Input: :math:`(B, D, N)` or :math:`(D, N)`
482
+ - Output: :math:`(B, D, N + 1)` or :math:`(D, N + 1)`
483
+
484
+ Examples::
485
+
486
+ >>> input = torch.rand(2, 4, 3) # BxNx3
487
+ >>> transform = tgm.ConvertPointsFromHomogeneous()
488
+ >>> output = transform(input) # BxNx2
489
+ """
490
+
491
+ def __init__(self):
492
+ super(ConvertPointsFromHomogeneous, self).__init__()
493
+
494
+ def forward(self, input):
495
+ return convert_points_from_homogeneous(input)
496
+
497
+
498
+ class ConvertPointsToHomogeneous(nn.Module):
499
+ r"""Creates a transformation to convert points from Euclidean to
500
+ homogeneous space.
501
+
502
+ Args:
503
+ points (Tensor): tensor of N-dimensional points.
504
+
505
+ Returns:
506
+ Tensor: tensor of N+1-dimensional points.
507
+
508
+ Shape:
509
+ - Input: :math:`(B, D, N)` or :math:`(D, N)`
510
+ - Output: :math:`(B, D, N + 1)` or :math:`(D, N + 1)`
511
+
512
+ Examples::
513
+
514
+ >>> input = torch.rand(2, 4, 3) # BxNx3
515
+ >>> transform = tgm.ConvertPointsToHomogeneous()
516
+ >>> output = transform(input) # BxNx4
517
+ """
518
+
519
+ def __init__(self):
520
+ super(ConvertPointsToHomogeneous, self).__init__()
521
+
522
+ def forward(self, input):
523
+ return convert_points_to_homogeneous(input)
common/base.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import math
3
+ import abc
4
+ from torch.utils.data import DataLoader
5
+ import torch.optim
6
+ import torchvision.transforms as transforms
7
+ from timer import Timer
8
+ from logger import colorlogger
9
+ from torch.nn.parallel.data_parallel import DataParallel
10
+ from config import cfg
11
+ from SMPLer_X import get_model
12
+
13
+ # ddp
14
+ import torch.distributed as dist
15
+ from torch.utils.data import DistributedSampler
16
+ import torch.utils.data.distributed
17
+ from utils.distribute_utils import (
18
+ get_rank, is_main_process, time_synchronized, get_group_idx, get_process_groups
19
+ )
20
+ from mmcv.runner import get_dist_info
21
+
22
+ class Base(object):
23
+ __metaclass__ = abc.ABCMeta
24
+
25
+ def __init__(self, log_name='logs.txt'):
26
+ self.cur_epoch = 0
27
+
28
+ # timer
29
+ self.tot_timer = Timer()
30
+ self.gpu_timer = Timer()
31
+ self.read_timer = Timer()
32
+
33
+ # logger
34
+ self.logger = colorlogger(cfg.log_dir, log_name=log_name)
35
+
36
+ @abc.abstractmethod
37
+ def _make_batch_generator(self):
38
+ return
39
+
40
+ @abc.abstractmethod
41
+ def _make_model(self):
42
+ return
43
+
44
+ class Demoer(Base):
45
+ def __init__(self, test_epoch=None):
46
+ if test_epoch is not None:
47
+ self.test_epoch = int(test_epoch)
48
+ super(Demoer, self).__init__(log_name='test_logs.txt')
49
+
50
+ def _make_batch_generator(self, demo_scene):
51
+ # data load and construct batch generator
52
+ self.logger.info("Creating dataset...")
53
+ from data.UBody.UBody import UBody
54
+ testset_loader = UBody(transforms.ToTensor(), "demo", demo_scene) # eval(demoset)(transforms.ToTensor(), "demo")
55
+ batch_generator = DataLoader(dataset=testset_loader, batch_size=cfg.num_gpus * cfg.test_batch_size,
56
+ shuffle=False, num_workers=cfg.num_thread, pin_memory=True)
57
+
58
+ self.testset = testset_loader
59
+ self.batch_generator = batch_generator
60
+
61
+ def _make_model(self):
62
+ self.logger.info('Load checkpoint from {}'.format(cfg.pretrained_model_path))
63
+
64
+ # prepare network
65
+ self.logger.info("Creating graph...")
66
+ model = get_model('test')
67
+ model = DataParallel(model).to(cfg.device)
68
+ ckpt = torch.load(cfg.pretrained_model_path, map_location=cfg.device)
69
+
70
+ from collections import OrderedDict
71
+ new_state_dict = OrderedDict()
72
+ for k, v in ckpt['network'].items():
73
+ if 'module' not in k:
74
+ k = 'module.' + k
75
+ k = k.replace('module.backbone', 'module.encoder').replace('body_rotation_net', 'body_regressor').replace(
76
+ 'hand_rotation_net', 'hand_regressor')
77
+ new_state_dict[k] = v
78
+ model.load_state_dict(new_state_dict, strict=False)
79
+ model.eval()
80
+
81
+ self.model = model
82
+
83
+ def _evaluate(self, outs, cur_sample_idx):
84
+ eval_result = self.testset.evaluate(outs, cur_sample_idx)
85
+ return eval_result
86
+
common/logger.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ OK = '\033[92m'
5
+ WARNING = '\033[93m'
6
+ FAIL = '\033[91m'
7
+ END = '\033[0m'
8
+
9
+ PINK = '\033[95m'
10
+ BLUE = '\033[94m'
11
+ GREEN = OK
12
+ RED = FAIL
13
+ WHITE = END
14
+ YELLOW = WARNING
15
+
16
+ class colorlogger():
17
+ def __init__(self, log_dir, log_name='train_logs.txt'):
18
+ # set log
19
+ self._logger = logging.getLogger(log_name)
20
+ self._logger.setLevel(logging.INFO)
21
+ log_file = os.path.join(log_dir, log_name)
22
+ if not os.path.exists(log_dir):
23
+ os.makedirs(log_dir)
24
+ file_log = logging.FileHandler(log_file, mode='a')
25
+ file_log.setLevel(logging.INFO)
26
+ console_log = logging.StreamHandler()
27
+ console_log.setLevel(logging.INFO)
28
+ formatter = logging.Formatter(
29
+ "{}%(asctime)s{} %(message)s".format(GREEN, END),
30
+ "%m-%d %H:%M:%S")
31
+ file_log.setFormatter(formatter)
32
+ console_log.setFormatter(formatter)
33
+ self._logger.addHandler(file_log)
34
+ self._logger.addHandler(console_log)
35
+
36
+ def debug(self, msg):
37
+ self._logger.debug(str(msg))
38
+
39
+ def info(self, msg):
40
+ self._logger.info(str(msg))
41
+
42
+ def warning(self, msg):
43
+ self._logger.warning(WARNING + 'WRN: ' + str(msg) + END)
44
+
45
+ def critical(self, msg):
46
+ self._logger.critical(RED + 'CRI: ' + str(msg) + END)
47
+
48
+ def error(self, msg):
49
+ self._logger.error(RED + 'ERR: ' + str(msg) + END)
50
+
common/nets/layer.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ def make_linear_layers(feat_dims, relu_final=True, use_bn=False):
4
+ layers = []
5
+ for i in range(len(feat_dims)-1):
6
+ layers.append(nn.Linear(feat_dims[i], feat_dims[i+1]))
7
+
8
+ # Do not use ReLU for final estimation
9
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and relu_final):
10
+ if use_bn:
11
+ layers.append(nn.BatchNorm1d(feat_dims[i+1]))
12
+ layers.append(nn.ReLU(inplace=True))
13
+
14
+ return nn.Sequential(*layers)
15
+
16
+ def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
17
+ layers = []
18
+ for i in range(len(feat_dims)-1):
19
+ layers.append(
20
+ nn.Conv2d(
21
+ in_channels=feat_dims[i],
22
+ out_channels=feat_dims[i+1],
23
+ kernel_size=kernel,
24
+ stride=stride,
25
+ padding=padding
26
+ ))
27
+ # Do not use BN and ReLU for final estimation
28
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
29
+ layers.append(nn.BatchNorm2d(feat_dims[i+1]))
30
+ layers.append(nn.ReLU(inplace=True))
31
+
32
+ return nn.Sequential(*layers)
33
+
34
+ def make_deconv_layers(feat_dims, bnrelu_final=True):
35
+ layers = []
36
+ for i in range(len(feat_dims)-1):
37
+ layers.append(
38
+ nn.ConvTranspose2d(
39
+ in_channels=feat_dims[i],
40
+ out_channels=feat_dims[i+1],
41
+ kernel_size=4,
42
+ stride=2,
43
+ padding=1,
44
+ output_padding=0,
45
+ bias=False))
46
+
47
+ # Do not use BN and ReLU for final estimation
48
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
49
+ layers.append(nn.BatchNorm2d(feat_dims[i+1]))
50
+ layers.append(nn.ReLU(inplace=True))
51
+
52
+ return nn.Sequential(*layers)
53
+
common/nets/loss.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class CoordLoss(nn.Module):
5
+ def __init__(self):
6
+ super(CoordLoss, self).__init__()
7
+
8
+ def forward(self, coord_out, coord_gt, valid, is_3D=None):
9
+ loss = torch.abs(coord_out - coord_gt) * valid
10
+ if is_3D is not None:
11
+ loss_z = loss[:,:,2:] * is_3D[:,None,None].float()
12
+ loss = torch.cat((loss[:,:,:2], loss_z),2)
13
+ return loss
14
+
15
+ class ParamLoss(nn.Module):
16
+ def __init__(self):
17
+ super(ParamLoss, self).__init__()
18
+
19
+ def forward(self, param_out, param_gt, valid):
20
+ loss = torch.abs(param_out - param_gt) * valid
21
+ return loss
22
+
23
+ class CELoss(nn.Module):
24
+ def __init__(self):
25
+ super(CELoss, self).__init__()
26
+ self.ce_loss = nn.CrossEntropyLoss(reduction='none')
27
+
28
+ def forward(self, out, gt_index):
29
+ loss = self.ce_loss(out, gt_index)
30
+ return loss
common/nets/smpler_x.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from nets.layer import make_conv_layers, make_linear_layers, make_deconv_layers
5
+ from utils.transforms import sample_joint_features, soft_argmax_2d, soft_argmax_3d
6
+ from utils.human_models import smpl_x
7
+ from config import cfg
8
+ from mmcv.ops.roi_align import roi_align
9
+
10
+ class PositionNet(nn.Module):
11
+ def __init__(self, part, feat_dim=768):
12
+ super(PositionNet, self).__init__()
13
+ if part == 'body':
14
+ self.joint_num = len(smpl_x.pos_joint_part['body'])
15
+ self.hm_shape = cfg.output_hm_shape
16
+ elif part == 'hand':
17
+ self.joint_num = len(smpl_x.pos_joint_part['rhand'])
18
+ self.hm_shape = cfg.output_hand_hm_shape
19
+ self.conv = make_conv_layers([feat_dim, self.joint_num * self.hm_shape[0]], kernel=1, stride=1, padding=0, bnrelu_final=False)
20
+
21
+ def forward(self, img_feat):
22
+ joint_hm = self.conv(img_feat).view(-1, self.joint_num, self.hm_shape[0], self.hm_shape[1], self.hm_shape[2])
23
+ joint_coord = soft_argmax_3d(joint_hm)
24
+ joint_hm = F.softmax(joint_hm.view(-1, self.joint_num, self.hm_shape[0] * self.hm_shape[1] * self.hm_shape[2]), 2)
25
+ joint_hm = joint_hm.view(-1, self.joint_num, self.hm_shape[0], self.hm_shape[1], self.hm_shape[2])
26
+ return joint_hm, joint_coord
27
+
28
+ class HandRotationNet(nn.Module):
29
+ def __init__(self, part, feat_dim = 768):
30
+ super(HandRotationNet, self).__init__()
31
+ self.part = part
32
+ self.joint_num = len(smpl_x.pos_joint_part['rhand'])
33
+ self.hand_conv = make_conv_layers([feat_dim, 512], kernel=1, stride=1, padding=0)
34
+ self.hand_pose_out = make_linear_layers([self.joint_num * 515, len(smpl_x.orig_joint_part['rhand']) * 6], relu_final=False)
35
+ self.feat_dim = feat_dim
36
+
37
+ def forward(self, img_feat, joint_coord_img):
38
+ batch_size = img_feat.shape[0]
39
+ img_feat = self.hand_conv(img_feat)
40
+ img_feat_joints = sample_joint_features(img_feat, joint_coord_img[:, :, :2])
41
+ feat = torch.cat((img_feat_joints, joint_coord_img), 2) # batch_size, joint_num, 512+3
42
+ hand_pose = self.hand_pose_out(feat.view(batch_size, -1))
43
+ return hand_pose
44
+
45
+ class BodyRotationNet(nn.Module):
46
+ def __init__(self, feat_dim = 768):
47
+ super(BodyRotationNet, self).__init__()
48
+ self.joint_num = len(smpl_x.pos_joint_part['body'])
49
+ self.body_conv = make_linear_layers([feat_dim, 512], relu_final=False)
50
+ self.root_pose_out = make_linear_layers([self.joint_num * (512+3), 6], relu_final=False)
51
+ self.body_pose_out = make_linear_layers(
52
+ [self.joint_num * (512+3), (len(smpl_x.orig_joint_part['body']) - 1) * 6], relu_final=False) # without root
53
+ self.shape_out = make_linear_layers([feat_dim, smpl_x.shape_param_dim], relu_final=False)
54
+ self.cam_out = make_linear_layers([feat_dim, 3], relu_final=False)
55
+ self.feat_dim = feat_dim
56
+
57
+ def forward(self, body_pose_token, shape_token, cam_token, body_joint_img):
58
+ batch_size = body_pose_token.shape[0]
59
+
60
+ # shape parameter
61
+ shape_param = self.shape_out(shape_token)
62
+
63
+ # camera parameter
64
+ cam_param = self.cam_out(cam_token)
65
+
66
+ # body pose parameter
67
+ body_pose_token = self.body_conv(body_pose_token)
68
+ body_pose_token = torch.cat((body_pose_token, body_joint_img), 2)
69
+ root_pose = self.root_pose_out(body_pose_token.view(batch_size, -1))
70
+ body_pose = self.body_pose_out(body_pose_token.view(batch_size, -1))
71
+
72
+ return root_pose, body_pose, shape_param, cam_param
73
+
74
+ class FaceRegressor(nn.Module):
75
+ def __init__(self, feat_dim=768):
76
+ super(FaceRegressor, self).__init__()
77
+ self.expr_out = make_linear_layers([feat_dim, smpl_x.expr_code_dim], relu_final=False)
78
+ self.jaw_pose_out = make_linear_layers([feat_dim, 6], relu_final=False)
79
+
80
+ def forward(self, expr_token, jaw_pose_token):
81
+ expr_param = self.expr_out(expr_token) # expression parameter
82
+ jaw_pose = self.jaw_pose_out(jaw_pose_token) # jaw pose parameter
83
+ return expr_param, jaw_pose
84
+
85
+ class BoxNet(nn.Module):
86
+ def __init__(self, feat_dim=768):
87
+ super(BoxNet, self).__init__()
88
+ self.joint_num = len(smpl_x.pos_joint_part['body'])
89
+ self.deconv = make_deconv_layers([feat_dim + self.joint_num * cfg.output_hm_shape[0], 256, 256, 256])
90
+ self.bbox_center = make_conv_layers([256, 3], kernel=1, stride=1, padding=0, bnrelu_final=False)
91
+ self.lhand_size = make_linear_layers([256, 256, 2], relu_final=False)
92
+ self.rhand_size = make_linear_layers([256, 256, 2], relu_final=False)
93
+ self.face_size = make_linear_layers([256, 256, 2], relu_final=False)
94
+
95
+ def forward(self, img_feat, joint_hm):
96
+ joint_hm = joint_hm.view(joint_hm.shape[0], joint_hm.shape[1] * cfg.output_hm_shape[0], cfg.output_hm_shape[1], cfg.output_hm_shape[2])
97
+ img_feat = torch.cat((img_feat, joint_hm), 1)
98
+ img_feat = self.deconv(img_feat)
99
+
100
+ # bbox center
101
+ bbox_center_hm = self.bbox_center(img_feat)
102
+ bbox_center = soft_argmax_2d(bbox_center_hm)
103
+ lhand_center, rhand_center, face_center = bbox_center[:, 0, :], bbox_center[:, 1, :], bbox_center[:, 2, :]
104
+
105
+ # bbox size
106
+ lhand_feat = sample_joint_features(img_feat, lhand_center[:, None, :].detach())[:, 0, :]
107
+ lhand_size = self.lhand_size(lhand_feat)
108
+ rhand_feat = sample_joint_features(img_feat, rhand_center[:, None, :].detach())[:, 0, :]
109
+ rhand_size = self.rhand_size(rhand_feat)
110
+ face_feat = sample_joint_features(img_feat, face_center[:, None, :].detach())[:, 0, :]
111
+ face_size = self.face_size(face_feat)
112
+
113
+ lhand_center = lhand_center / 8
114
+ rhand_center = rhand_center / 8
115
+ face_center = face_center / 8
116
+ return lhand_center, lhand_size, rhand_center, rhand_size, face_center, face_size
117
+
118
+ class BoxSizeNet(nn.Module):
119
+ def __init__(self):
120
+ super(BoxSizeNet, self).__init__()
121
+ self.lhand_size = make_linear_layers([256, 256, 2], relu_final=False)
122
+ self.rhand_size = make_linear_layers([256, 256, 2], relu_final=False)
123
+ self.face_size = make_linear_layers([256, 256, 2], relu_final=False)
124
+
125
+ def forward(self, box_fea):
126
+ # box_fea: [bs, 3, C]
127
+ lhand_size = self.lhand_size(box_fea[:, 0])
128
+ rhand_size = self.rhand_size(box_fea[:, 1])
129
+ face_size = self.face_size(box_fea[:, 2])
130
+ return lhand_size, rhand_size, face_size
131
+
132
+ class HandRoI(nn.Module):
133
+ def __init__(self, feat_dim=768, upscale=4):
134
+ super(HandRoI, self).__init__()
135
+ self.upscale = upscale
136
+ if upscale==1:
137
+ self.deconv = make_conv_layers([feat_dim, feat_dim], kernel=1, stride=1, padding=0, bnrelu_final=False)
138
+ self.conv = make_conv_layers([feat_dim, feat_dim], kernel=1, stride=1, padding=0, bnrelu_final=False)
139
+ elif upscale==2:
140
+ self.deconv = make_deconv_layers([feat_dim, feat_dim//2])
141
+ self.conv = make_conv_layers([feat_dim//2, feat_dim], kernel=1, stride=1, padding=0, bnrelu_final=False)
142
+ elif upscale==4:
143
+ self.deconv = make_deconv_layers([feat_dim, feat_dim//2, feat_dim//4])
144
+ self.conv = make_conv_layers([feat_dim//4, feat_dim], kernel=1, stride=1, padding=0, bnrelu_final=False)
145
+ elif upscale==8:
146
+ self.deconv = make_deconv_layers([feat_dim, feat_dim//2, feat_dim//4, feat_dim//8])
147
+ self.conv = make_conv_layers([feat_dim//8, feat_dim], kernel=1, stride=1, padding=0, bnrelu_final=False)
148
+
149
+ def forward(self, img_feat, lhand_bbox, rhand_bbox):
150
+ lhand_bbox = torch.cat((torch.arange(lhand_bbox.shape[0]).float().to(cfg.device)[:, None], lhand_bbox),
151
+ 1) # batch_idx, xmin, ymin, xmax, ymax
152
+ rhand_bbox = torch.cat((torch.arange(rhand_bbox.shape[0]).float().to(cfg.device)[:, None], rhand_bbox),
153
+ 1) # batch_idx, xmin, ymin, xmax, ymax
154
+ img_feat = self.deconv(img_feat)
155
+ lhand_bbox_roi = lhand_bbox.clone()
156
+ lhand_bbox_roi[:, 1] = lhand_bbox_roi[:, 1] / cfg.input_body_shape[1] * cfg.output_hm_shape[2] * self.upscale
157
+ lhand_bbox_roi[:, 2] = lhand_bbox_roi[:, 2] / cfg.input_body_shape[0] * cfg.output_hm_shape[1] * self.upscale
158
+ lhand_bbox_roi[:, 3] = lhand_bbox_roi[:, 3] / cfg.input_body_shape[1] * cfg.output_hm_shape[2] * self.upscale
159
+ lhand_bbox_roi[:, 4] = lhand_bbox_roi[:, 4] / cfg.input_body_shape[0] * cfg.output_hm_shape[1] * self.upscale
160
+ assert (cfg.output_hm_shape[1]*self.upscale, cfg.output_hm_shape[2]*self.upscale) == (img_feat.shape[2], img_feat.shape[3])
161
+ lhand_img_feat = roi_align(img_feat, lhand_bbox_roi, (cfg.output_hand_hm_shape[1], cfg.output_hand_hm_shape[2]), 1.0, 0, 'avg', False)
162
+ lhand_img_feat = torch.flip(lhand_img_feat, [3]) # flip to the right hand
163
+
164
+ rhand_bbox_roi = rhand_bbox.clone()
165
+ rhand_bbox_roi[:, 1] = rhand_bbox_roi[:, 1] / cfg.input_body_shape[1] * cfg.output_hm_shape[2] * self.upscale
166
+ rhand_bbox_roi[:, 2] = rhand_bbox_roi[:, 2] / cfg.input_body_shape[0] * cfg.output_hm_shape[1] * self.upscale
167
+ rhand_bbox_roi[:, 3] = rhand_bbox_roi[:, 3] / cfg.input_body_shape[1] * cfg.output_hm_shape[2] * self.upscale
168
+ rhand_bbox_roi[:, 4] = rhand_bbox_roi[:, 4] / cfg.input_body_shape[0] * cfg.output_hm_shape[1] * self.upscale
169
+ rhand_img_feat = roi_align(img_feat, rhand_bbox_roi, (cfg.output_hand_hm_shape[1], cfg.output_hand_hm_shape[2]), 1.0, 0, 'avg', False)
170
+ hand_img_feat = torch.cat((lhand_img_feat, rhand_img_feat)) # [bs, c, cfg.output_hand_hm_shape[2]*scale, cfg.output_hand_hm_shape[1]*scale]
171
+ hand_img_feat = self.conv(hand_img_feat)
172
+ return hand_img_feat
common/timer.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Fast R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ import time
9
+
10
+ class Timer(object):
11
+ """A simple timer."""
12
+ def __init__(self):
13
+ self.total_time = 0.
14
+ self.calls = 0
15
+ self.start_time = 0.
16
+ self.diff = 0.
17
+ self.average_time = 0.
18
+ self.warm_up = 0
19
+
20
+ def tic(self):
21
+ # using time.time instead of time.clock because time time.clock
22
+ # does not normalize for multithreading
23
+ self.start_time = time.time()
24
+
25
+ def toc(self, average=True):
26
+ self.diff = time.time() - self.start_time
27
+ if self.warm_up < 10:
28
+ self.warm_up += 1
29
+ return self.diff
30
+ else:
31
+ self.total_time += self.diff
32
+ self.calls += 1
33
+ self.average_time = self.total_time / self.calls
34
+
35
+ if average:
36
+ return self.average_time
37
+ else:
38
+ return self.diff
common/utils/__init__.py ADDED
File without changes
common/utils/dir.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ def make_folder(folder_name):
5
+ os.makedirs(folder_name, exist_ok=True)
6
+
7
+ def add_pypath(path):
8
+ if path not in sys.path:
9
+ sys.path.insert(0, path)
10
+
common/utils/distribute_utils.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mmcv
2
+ import os
3
+ import os.path as osp
4
+ import pickle
5
+ import shutil
6
+ import tempfile
7
+ import time
8
+ import torch
9
+ import torch.distributed as dist
10
+ from mmcv.runner import get_dist_info
11
+ import random
12
+ import numpy as np
13
+ import subprocess
14
+
15
+ def set_seed(seed):
16
+ random.seed(seed)
17
+ np.random.seed(seed)
18
+ torch.manual_seed(seed)
19
+ torch.cuda.manual_seed_all(seed)
20
+ # torch.set_deterministic(True)
21
+
22
+
23
+ def time_synchronized():
24
+ torch.cuda.synchronize() if torch.cuda.is_available() else None
25
+ return time.time()
26
+
27
+
28
+ def setup_for_distributed(is_master):
29
+ """This function disables printing when not in master process."""
30
+ import builtins as __builtin__
31
+ builtin_print = __builtin__.print
32
+
33
+ def print(*args, **kwargs):
34
+ force = kwargs.pop('force', False)
35
+ if is_master or force:
36
+ builtin_print(*args, **kwargs)
37
+
38
+ __builtin__.print = print
39
+
40
+
41
+ def init_distributed_mode(port = None, master_port=29500):
42
+ """Initialize slurm distributed training environment.
43
+
44
+ If argument ``port`` is not specified, then the master port will be system
45
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
46
+ environment variable, then a default port ``29500`` will be used.
47
+
48
+ Args:
49
+ backend (str): Backend of torch.distributed.
50
+ port (int, optional): Master port. Defaults to None.
51
+ """
52
+ dist_backend = 'nccl'
53
+ proc_id = int(os.environ['SLURM_PROCID'])
54
+ ntasks = int(os.environ['SLURM_NTASKS'])
55
+ node_list = os.environ['SLURM_NODELIST']
56
+ num_gpus = torch.cuda.device_count()
57
+ torch.cuda.set_device(proc_id % num_gpus)
58
+ addr = subprocess.getoutput(
59
+ f'scontrol show hostname {node_list} | head -n1')
60
+ # specify master port
61
+ if port is not None:
62
+ os.environ['MASTER_PORT'] = str(port)
63
+ elif 'MASTER_PORT' in os.environ:
64
+ pass # use MASTER_PORT in the environment variable
65
+ else:
66
+ # 29500 is torch.distributed default port
67
+ os.environ['MASTER_PORT'] = str(master_port)
68
+ # use MASTER_ADDR in the environment variable if it already exists
69
+ if 'MASTER_ADDR' not in os.environ:
70
+ os.environ['MASTER_ADDR'] = addr
71
+ os.environ['WORLD_SIZE'] = str(ntasks)
72
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
73
+ os.environ['RANK'] = str(proc_id)
74
+ dist.init_process_group(backend=dist_backend)
75
+
76
+ distributed = True
77
+ gpu_idx = proc_id % num_gpus
78
+
79
+ return distributed, gpu_idx
80
+
81
+
82
+ def is_dist_avail_and_initialized():
83
+ if not dist.is_available():
84
+ return False
85
+ if not dist.is_initialized():
86
+ return False
87
+ return True
88
+
89
+
90
+ def get_world_size():
91
+ if not is_dist_avail_and_initialized():
92
+ return 1
93
+ return dist.get_world_size()
94
+
95
+
96
+ def get_rank():
97
+ if not is_dist_avail_and_initialized():
98
+ return 0
99
+ return dist.get_rank()
100
+
101
+ def get_process_groups():
102
+ world_size = int(os.environ['WORLD_SIZE'])
103
+ ranks = list(range(world_size))
104
+ num_gpus = torch.cuda.device_count()
105
+ num_nodes = world_size // num_gpus
106
+ if world_size % num_gpus != 0:
107
+ raise NotImplementedError('Not implemented for node not fully used.')
108
+
109
+ groups = []
110
+ for node_idx in range(num_nodes):
111
+ groups.append(ranks[node_idx*num_gpus : (node_idx+1)*num_gpus])
112
+ process_groups = [torch.distributed.new_group(group) for group in groups]
113
+
114
+ return process_groups
115
+
116
+ def get_group_idx():
117
+ num_gpus = torch.cuda.device_count()
118
+ proc_id = get_rank()
119
+ group_idx = proc_id // num_gpus
120
+
121
+ return group_idx
122
+
123
+
124
+ def is_main_process():
125
+ return get_rank() == 0
126
+
127
+ def cleanup():
128
+ dist.destroy_process_group()
129
+
130
+
131
+ def collect_results(result_part, size, tmpdir=None):
132
+ rank, world_size = get_dist_info()
133
+ # create a tmp dir if it is not specified
134
+ if tmpdir is None:
135
+ MAX_LEN = 512
136
+ # 32 is whitespace
137
+ dir_tensor = torch.full((MAX_LEN, ),
138
+ 32,
139
+ dtype=torch.uint8,
140
+ device='cuda')
141
+ if rank == 0:
142
+ tmpdir = tempfile.mkdtemp()
143
+ tmpdir = torch.tensor(
144
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
145
+ dir_tensor[:len(tmpdir)] = tmpdir
146
+ dist.broadcast(dir_tensor, 0)
147
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
148
+ else:
149
+ mmcv.mkdir_or_exist(tmpdir)
150
+ # dump the part result to the dir
151
+ mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
152
+ dist.barrier()
153
+ # collect all parts
154
+ if rank != 0:
155
+ return None
156
+ else:
157
+ # load results of all parts from tmp dir
158
+ part_list = []
159
+ for i in range(world_size):
160
+ part_file = osp.join(tmpdir, f'part_{i}.pkl')
161
+ part_list.append(mmcv.load(part_file))
162
+ # sort the results
163
+ ordered_results = []
164
+ for res in zip(*part_list):
165
+ ordered_results.extend(list(res))
166
+ # the dataloader may pad some samples
167
+ ordered_results = ordered_results[:size]
168
+ # remove tmp dir
169
+ shutil.rmtree(tmpdir)
170
+ return ordered_results
171
+
172
+
173
+ def all_gather(data):
174
+ """
175
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
176
+ Args:
177
+ data:
178
+ Any picklable object
179
+ Returns:
180
+ data_list(list):
181
+ List of data gathered from each rank
182
+ """
183
+ world_size = get_world_size()
184
+ if world_size == 1:
185
+ return [data]
186
+
187
+ # serialized to a Tensor
188
+ buffer = pickle.dumps(data)
189
+ storage = torch.ByteStorage.from_buffer(buffer)
190
+ tensor = torch.ByteTensor(storage).to('cuda')
191
+
192
+ # obtain Tensor size of each rank
193
+ local_size = torch.tensor([tensor.numel()], device='cuda')
194
+ size_list = [torch.tensor([0], device='cuda') for _ in range(world_size)]
195
+ dist.all_gather(size_list, local_size)
196
+ size_list = [int(size.item()) for size in size_list]
197
+ max_size = max(size_list)
198
+
199
+ # receiving Tensor from all ranks
200
+ # we pad the tensor because torch all_gather does not support
201
+ # gathering tensors of different shapes
202
+ tensor_list = []
203
+ for _ in size_list:
204
+ tensor_list.append(
205
+ torch.empty((max_size, ), dtype=torch.uint8, device='cuda'))
206
+ if local_size != max_size:
207
+ padding = torch.empty(
208
+ size=(max_size - local_size, ), dtype=torch.uint8, device='cuda')
209
+ tensor = torch.cat((tensor, padding), dim=0)
210
+ dist.all_gather(tensor_list, tensor)
211
+
212
+ data_list = []
213
+ for size, tensor in zip(size_list, tensor_list):
214
+ buffer = tensor.cpu().numpy().tobytes()[:size]
215
+ data_list.append(pickle.loads(buffer))
216
+
217
+ return data_list
common/utils/human_model_files/smpl/SMPL_FEMALE.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d4a1791b6b94880397e1a3a4539b703a228d2150c57de7b288389a8115f4ef0
3
+ size 247530000
common/utils/human_model_files/smpl/SMPL_MALE.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed4d55bb3041fefc6f73b70694d6c8edc1020c0d07340be5cc651cae2c6a6ae3
3
+ size 247101031
common/utils/human_model_files/smpl/SMPL_NEUTRAL.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4924f235e63f7c5d5b690acedf736419c2edb846a2d69fc0956169615fa75688
3
+ size 247186228
common/utils/human_model_files/smpl/smpl_uv.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb2a1aaf8be2091ebc4344daefae0622cc09252b33d4f6c36ea2c6541a01d469
3
+ size 1524004
common/utils/human_model_files/smplx/MANO_SMPLX_vertex_ids.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5abe70b6574de25470475091e8008314a5b90127eb48c3e63bfa0adf8c04dcf
3
+ size 13535
common/utils/human_model_files/smplx/SMPL-X__FLAME_vertex_ids.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e70cdc3659aae699b9732e8dd4af49106310c69b90dc83d9f73e96dbf871e49
3
+ size 40312
common/utils/human_model_files/smplx/SMPLX_FEMALE.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2a3686c9d6d218ff6822fba411c607a3c8125a70af340f384ce68bebecabe0e
3
+ size 108794146
common/utils/human_model_files/smplx/SMPLX_MALE.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab318e3f37d2bfaae26abf4e6fab445c2a610e1d63714794d60379cc263bc2a5
3
+ size 108753445
common/utils/human_model_files/smplx/SMPLX_NEUTRAL.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:376021446ddc86e99acacd795182bbef903e61d33b76b9d8b359c2b0865bd992
3
+ size 108752058
common/utils/human_model_files/smplx/SMPLX_NEUTRAL.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:381c808965deb4f5e845f8c3eddb0cd69930cc72e5774ce4f34c4ce3cf058361
3
+ size 544173380
common/utils/human_model_files/smplx/SMPLX_to_J14.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5df844ddea85b0a400a2e8dbe63d09d19f2b1b7ec0e0e952daeae08f83d82d61
3
+ size 4692193
common/utils/human_models.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import os.path as osp
4
+ from config import cfg
5
+ from utils.smplx import smplx
6
+ import pickle
7
+
8
+ class SMPLX(object):
9
+ def __init__(self):
10
+ self.layer_arg = {'create_global_orient': False, 'create_body_pose': False, 'create_left_hand_pose': False, 'create_right_hand_pose': False, 'create_jaw_pose': False, 'create_leye_pose': False, 'create_reye_pose': False, 'create_betas': False, 'create_expression': False, 'create_transl': False}
11
+ self.layer = {'neutral': smplx.create(cfg.human_model_path, 'smplx', gender='NEUTRAL', use_pca=False, use_face_contour=True, **self.layer_arg),
12
+ 'male': smplx.create(cfg.human_model_path, 'smplx', gender='MALE', use_pca=False, use_face_contour=True, **self.layer_arg),
13
+ 'female': smplx.create(cfg.human_model_path, 'smplx', gender='FEMALE', use_pca=False, use_face_contour=True, **self.layer_arg)
14
+ }
15
+ self.vertex_num = 10475
16
+ self.face = self.layer['neutral'].faces
17
+ self.shape_param_dim = 10
18
+ self.expr_code_dim = 10
19
+ with open(osp.join(cfg.human_model_path, 'smplx', 'SMPLX_to_J14.pkl'), 'rb') as f:
20
+ self.j14_regressor = pickle.load(f, encoding='latin1')
21
+ with open(osp.join(cfg.human_model_path, 'smplx', 'MANO_SMPLX_vertex_ids.pkl'), 'rb') as f:
22
+ self.hand_vertex_idx = pickle.load(f, encoding='latin1')
23
+ self.face_vertex_idx = np.load(osp.join(cfg.human_model_path, 'smplx', 'SMPL-X__FLAME_vertex_ids.npy'))
24
+ self.J_regressor = self.layer['neutral'].J_regressor.numpy()
25
+ self.J_regressor_idx = {'pelvis': 0, 'lwrist': 20, 'rwrist': 21, 'neck': 12}
26
+ self.orig_hand_regressor = self.make_hand_regressor()
27
+ #self.orig_hand_regressor = {'left': self.layer.J_regressor.numpy()[[20,37,38,39,25,26,27,28,29,30,34,35,36,31,32,33],:], 'right': self.layer.J_regressor.numpy()[[21,52,53,54,40,41,42,43,44,45,49,50,51,46,47,48],:]}
28
+
29
+ # original SMPLX joint set
30
+ self.orig_joint_num = 53 # 22 (body joints) + 30 (hand joints) + 1 (face jaw joint)
31
+ self.orig_joints_name = \
32
+ ('Pelvis', 'L_Hip', 'R_Hip', 'Spine_1', 'L_Knee', 'R_Knee', 'Spine_2', 'L_Ankle', 'R_Ankle', 'Spine_3', 'L_Foot', 'R_Foot', 'Neck', 'L_Collar', 'R_Collar', 'Head', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', # body joints
33
+ 'L_Index_1', 'L_Index_2', 'L_Index_3', 'L_Middle_1', 'L_Middle_2', 'L_Middle_3', 'L_Pinky_1', 'L_Pinky_2', 'L_Pinky_3', 'L_Ring_1', 'L_Ring_2', 'L_Ring_3', 'L_Thumb_1', 'L_Thumb_2', 'L_Thumb_3', # left hand joints
34
+ 'R_Index_1', 'R_Index_2', 'R_Index_3', 'R_Middle_1', 'R_Middle_2', 'R_Middle_3', 'R_Pinky_1', 'R_Pinky_2', 'R_Pinky_3', 'R_Ring_1', 'R_Ring_2', 'R_Ring_3', 'R_Thumb_1', 'R_Thumb_2', 'R_Thumb_3', # right hand joints
35
+ 'Jaw' # face jaw joint
36
+ )
37
+ self.orig_flip_pairs = \
38
+ ( (1,2), (4,5), (7,8), (10,11), (13,14), (16,17), (18,19), (20,21), # body joints
39
+ (22,37), (23,38), (24,39), (25,40), (26,41), (27,42), (28,43), (29,44), (30,45), (31,46), (32,47), (33,48), (34,49), (35,50), (36,51) # hand joints
40
+ )
41
+ self.orig_root_joint_idx = self.orig_joints_name.index('Pelvis')
42
+ self.orig_joint_part = \
43
+ {'body': range(self.orig_joints_name.index('Pelvis'), self.orig_joints_name.index('R_Wrist')+1),
44
+ 'lhand': range(self.orig_joints_name.index('L_Index_1'), self.orig_joints_name.index('L_Thumb_3')+1),
45
+ 'rhand': range(self.orig_joints_name.index('R_Index_1'), self.orig_joints_name.index('R_Thumb_3')+1),
46
+ 'face': range(self.orig_joints_name.index('Jaw'), self.orig_joints_name.index('Jaw')+1)}
47
+
48
+ # changed SMPLX joint set for the supervision
49
+ self.joint_num = 137 # 25 (body joints) + 40 (hand joints) + 72 (face keypoints)
50
+ self.joints_name = \
51
+ ('Pelvis', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Neck', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Big_toe', 'L_Small_toe', 'L_Heel', 'R_Big_toe', 'R_Small_toe', 'R_Heel', 'L_Ear', 'R_Ear', 'L_Eye', 'R_Eye', 'Nose',# body joints
52
+ 'L_Thumb_1', 'L_Thumb_2', 'L_Thumb_3', 'L_Thumb_4', 'L_Index_1', 'L_Index_2', 'L_Index_3', 'L_Index_4', 'L_Middle_1', 'L_Middle_2', 'L_Middle_3', 'L_Middle_4', 'L_Ring_1', 'L_Ring_2', 'L_Ring_3', 'L_Ring_4', 'L_Pinky_1', 'L_Pinky_2', 'L_Pinky_3', 'L_Pinky_4', # left hand joints
53
+ 'R_Thumb_1', 'R_Thumb_2', 'R_Thumb_3', 'R_Thumb_4', 'R_Index_1', 'R_Index_2', 'R_Index_3', 'R_Index_4', 'R_Middle_1', 'R_Middle_2', 'R_Middle_3', 'R_Middle_4', 'R_Ring_1', 'R_Ring_2', 'R_Ring_3', 'R_Ring_4', 'R_Pinky_1', 'R_Pinky_2', 'R_Pinky_3', 'R_Pinky_4', # right hand joints
54
+ *['Face_' + str(i) for i in range(1,73)] # face keypoints (too many keypoints... omit real names. have same name of keypoints defined in FLAME class)
55
+ )
56
+ self.root_joint_idx = self.joints_name.index('Pelvis')
57
+ self.lwrist_idx = self.joints_name.index('L_Wrist')
58
+ self.rwrist_idx = self.joints_name.index('R_Wrist')
59
+ self.neck_idx = self.joints_name.index('Neck')
60
+ self.flip_pairs = \
61
+ ( (1,2), (3,4), (5,6), (8,9), (10,11), (12,13), (14,17), (15,18), (16,19), (20,21), (22,23), # body joints
62
+ (25,45), (26,46), (27,47), (28,48), (29,49), (30,50), (31,51), (32,52), (33,53), (34,54), (35,55), (36,56), (37,57), (38,58), (39,59), (40,60), (41,61), (42,62), (43,63), (44,64), # hand joints
63
+ (67,68), # face eyeballs
64
+ (69,78), (70,77), (71,76), (72,75), (73,74), # face eyebrow
65
+ (83,87), (84,86), # face below nose
66
+ (88,97), (89,96), (90,95), (91,94), (92,99), (93,98), # face eyes
67
+ (100,106), (101,105), (102,104), (107,111), (108,110), # face mouth
68
+ (112,116), (113,115), (117,119), # face lip
69
+ (120,136), (121,135), (122,134), (123,133), (124,132), (125,131), (126,130), (127,129) # face contours
70
+ )
71
+ self.joint_idx = \
72
+ (0,1,2,4,5,7,8,12,16,17,18,19,20,21,60,61,62,63,64,65,59,58,57,56,55, # body joints
73
+ 37,38,39,66,25,26,27,67,28,29,30,68,34,35,36,69,31,32,33,70, # left hand joints
74
+ 52,53,54,71,40,41,42,72,43,44,45,73,49,50,51,74,46,47,48,75, # right hand joints
75
+ 22,15, # jaw, head
76
+ 57,56, # eyeballs
77
+ 76,77,78,79,80,81,82,83,84,85, # eyebrow
78
+ 86,87,88,89, # nose
79
+ 90,91,92,93,94, # below nose
80
+ 95,96,97,98,99,100,101,102,103,104,105,106, # eyes
81
+ 107, # right mouth
82
+ 108,109,110,111,112, # upper mouth
83
+ 113, # left mouth
84
+ 114,115,116,117,118, # lower mouth
85
+ 119, # right lip
86
+ 120,121,122, # upper lip
87
+ 123, # left lip
88
+ 124,125,126, # lower lip
89
+ 127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143 # face contour
90
+ )
91
+ self.joint_part = \
92
+ {'body': range(self.joints_name.index('Pelvis'), self.joints_name.index('Nose')+1),
93
+ 'lhand': range(self.joints_name.index('L_Thumb_1'), self.joints_name.index('L_Pinky_4')+1),
94
+ 'rhand': range(self.joints_name.index('R_Thumb_1'), self.joints_name.index('R_Pinky_4')+1),
95
+ 'hand': range(self.joints_name.index('L_Thumb_1'), self.joints_name.index('R_Pinky_4')+1),
96
+ 'face': range(self.joints_name.index('Face_1'), self.joints_name.index('Face_72')+1)}
97
+
98
+ # changed SMPLX joint set for PositionNet prediction
99
+ self.pos_joint_num = 65 # 25 (body joints) + 40 (hand joints)
100
+ self.pos_joints_name = \
101
+ ('Pelvis', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Neck', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Big_toe', 'L_Small_toe', 'L_Heel', 'R_Big_toe', 'R_Small_toe', 'R_Heel', 'L_Ear', 'R_Ear', 'L_Eye', 'R_Eye', 'Nose', # body joints
102
+ 'L_Thumb_1', 'L_Thumb_2', 'L_Thumb_3', 'L_Thumb_4', 'L_Index_1', 'L_Index_2', 'L_Index_3', 'L_Index_4', 'L_Middle_1', 'L_Middle_2', 'L_Middle_3', 'L_Middle_4', 'L_Ring_1', 'L_Ring_2', 'L_Ring_3', 'L_Ring_4', 'L_Pinky_1', 'L_Pinky_2', 'L_Pinky_3', 'L_Pinky_4', # left hand joints
103
+ 'R_Thumb_1', 'R_Thumb_2', 'R_Thumb_3', 'R_Thumb_4', 'R_Index_1', 'R_Index_2', 'R_Index_3', 'R_Index_4', 'R_Middle_1', 'R_Middle_2', 'R_Middle_3', 'R_Middle_4', 'R_Ring_1', 'R_Ring_2', 'R_Ring_3', 'R_Ring_4', 'R_Pinky_1', 'R_Pinky_2', 'R_Pinky_3', 'R_Pinky_4', # right hand joints
104
+ )
105
+ self.pos_joint_part = \
106
+ {'body': range(self.pos_joints_name.index('Pelvis'), self.pos_joints_name.index('Nose')+1),
107
+ 'lhand': range(self.pos_joints_name.index('L_Thumb_1'), self.pos_joints_name.index('L_Pinky_4')+1),
108
+ 'rhand': range(self.pos_joints_name.index('R_Thumb_1'), self.pos_joints_name.index('R_Pinky_4')+1),
109
+ 'hand': range(self.pos_joints_name.index('L_Thumb_1'), self.pos_joints_name.index('R_Pinky_4')+1)}
110
+ self.pos_joint_part['L_MCP'] = [self.pos_joints_name.index('L_Index_1') - len(self.pos_joint_part['body']),
111
+ self.pos_joints_name.index('L_Middle_1') - len(self.pos_joint_part['body']),
112
+ self.pos_joints_name.index('L_Ring_1') - len(self.pos_joint_part['body']),
113
+ self.pos_joints_name.index('L_Pinky_1') - len(self.pos_joint_part['body'])]
114
+ self.pos_joint_part['R_MCP'] = [self.pos_joints_name.index('R_Index_1') - len(self.pos_joint_part['body']) - len(self.pos_joint_part['lhand']),
115
+ self.pos_joints_name.index('R_Middle_1') - len(self.pos_joint_part['body']) - len(self.pos_joint_part['lhand']),
116
+ self.pos_joints_name.index('R_Ring_1') - len(self.pos_joint_part['body']) - len(self.pos_joint_part['lhand']),
117
+ self.pos_joints_name.index('R_Pinky_1') - len(self.pos_joint_part['body']) - len(self.pos_joint_part['lhand'])]
118
+
119
+ def make_hand_regressor(self):
120
+ regressor = self.layer['neutral'].J_regressor.numpy()
121
+ lhand_regressor = np.concatenate((regressor[[20,37,38,39],:],
122
+ np.eye(self.vertex_num)[5361,None],
123
+ regressor[[25,26,27],:],
124
+ np.eye(self.vertex_num)[4933,None],
125
+ regressor[[28,29,30],:],
126
+ np.eye(self.vertex_num)[5058,None],
127
+ regressor[[34,35,36],:],
128
+ np.eye(self.vertex_num)[5169,None],
129
+ regressor[[31,32,33],:],
130
+ np.eye(self.vertex_num)[5286,None]))
131
+ rhand_regressor = np.concatenate((regressor[[21,52,53,54],:],
132
+ np.eye(self.vertex_num)[8079,None],
133
+ regressor[[40,41,42],:],
134
+ np.eye(self.vertex_num)[7669,None],
135
+ regressor[[43,44,45],:],
136
+ np.eye(self.vertex_num)[7794,None],
137
+ regressor[[49,50,51],:],
138
+ np.eye(self.vertex_num)[7905,None],
139
+ regressor[[46,47,48],:],
140
+ np.eye(self.vertex_num)[8022,None]))
141
+ hand_regressor = {'left': lhand_regressor, 'right': rhand_regressor}
142
+ return hand_regressor
143
+
144
+
145
+ def reduce_joint_set(self, joint):
146
+ new_joint = []
147
+ for name in self.pos_joints_name:
148
+ idx = self.joints_name.index(name)
149
+ new_joint.append(joint[:,idx,:])
150
+ new_joint = torch.stack(new_joint,1)
151
+ return new_joint
152
+
153
+ class SMPL(object):
154
+ def __init__(self):
155
+ self.layer_arg = {'create_body_pose': False, 'create_betas': False, 'create_global_orient': False, 'create_transl': False}
156
+ self.layer = {'neutral': smplx.create(cfg.human_model_path, 'smpl', gender='NEUTRAL', **self.layer_arg), 'male': smplx.create(cfg.human_model_path, 'smpl', gender='MALE', **self.layer_arg), 'female': smplx.create(cfg.human_model_path, 'smpl', gender='FEMALE', **self.layer_arg)}
157
+ self.vertex_num = 6890
158
+ self.face = self.layer['neutral'].faces
159
+ self.shape_param_dim = 10
160
+ self.vposer_code_dim = 32
161
+
162
+ # original SMPL joint set
163
+ self.orig_joint_num = 24
164
+ self.orig_joints_name = ('Pelvis', 'L_Hip', 'R_Hip', 'Spine_1', 'L_Knee', 'R_Knee', 'Spine_2', 'L_Ankle', 'R_Ankle', 'Spine_3', 'L_Foot', 'R_Foot', 'Neck', 'L_Collar', 'R_Collar', 'Head', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand')
165
+ self.orig_flip_pairs = ( (1,2), (4,5), (7,8), (10,11), (13,14), (16,17), (18,19), (20,21), (22,23) )
166
+ self.orig_root_joint_idx = self.orig_joints_name.index('Pelvis')
167
+ self.orig_joint_regressor = self.layer['neutral'].J_regressor.numpy().astype(np.float32)
168
+
169
+ self.joint_num = self.orig_joint_num
170
+ self.joints_name = self.orig_joints_name
171
+ self.flip_pairs = self.orig_flip_pairs
172
+ self.root_joint_idx = self.orig_root_joint_idx
173
+ self.joint_regressor = self.orig_joint_regressor
174
+
175
+ smpl_x = SMPLX()
176
+ smpl = SMPL()
common/utils/inference_utils.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Union
2
+
3
+ def process_mmdet_results(mmdet_results: list,
4
+ cat_id: int = 0,
5
+ multi_person: bool = True) -> list:
6
+ """Process mmdet results, sort bboxes by area in descending order.
7
+
8
+ Args:
9
+ mmdet_results (list):
10
+ Result of mmdet.apis.inference_detector
11
+ when the input is a batch.
12
+ Shape of the nested lists is
13
+ (n_frame, n_category, n_human, 5).
14
+ cat_id (int, optional):
15
+ Category ID. This function will only select
16
+ the selected category, and drop the others.
17
+ Defaults to 0, ID of human category.
18
+ multi_person (bool, optional):
19
+ Whether to allow multi-person detection, which is
20
+ slower than single-person. If false, the function
21
+ only assure that the first person of each frame
22
+ has the biggest bbox.
23
+ Defaults to True.
24
+
25
+ Returns:
26
+ list:
27
+ A list of detected bounding boxes.
28
+ Shape of the nested lists is
29
+ (n_frame, n_human, 5)
30
+ and each bbox is (x, y, x, y, score).
31
+ """
32
+ ret_list = []
33
+ only_max_arg = not multi_person
34
+ # for _, frame_results in enumerate(mmdet_results):
35
+ cat_bboxes = mmdet_results[cat_id]
36
+ # import pdb; pdb.set_trace()
37
+ sorted_bbox = qsort_bbox_list(cat_bboxes, only_max_arg)
38
+
39
+ if only_max_arg:
40
+ ret_list.append(sorted_bbox[0:1])
41
+ else:
42
+ ret_list.append(sorted_bbox)
43
+ return ret_list
44
+
45
+
46
+ def qsort_bbox_list(bbox_list: list,
47
+ only_max: bool = False,
48
+ bbox_convention: Literal['xyxy', 'xywh'] = 'xyxy'):
49
+ """Sort a list of bboxes, by their area in pixel(W*H).
50
+
51
+ Args:
52
+ input_list (list):
53
+ A list of bboxes. Each item is a list of (x1, y1, x2, y2)
54
+ only_max (bool, optional):
55
+ If True, only assure the max element at first place,
56
+ others may not be well sorted.
57
+ If False, return a well sorted descending list.
58
+ Defaults to False.
59
+ bbox_convention (str, optional):
60
+ Bbox type, xyxy or xywh. Defaults to 'xyxy'.
61
+
62
+ Returns:
63
+ list:
64
+ A sorted(maybe not so well) descending list.
65
+ """
66
+ # import pdb; pdb.set_trace()
67
+ if len(bbox_list) <= 1:
68
+ return bbox_list
69
+ else:
70
+ bigger_list = []
71
+ less_list = []
72
+ anchor_index = int(len(bbox_list) / 2)
73
+ anchor_bbox = bbox_list[anchor_index]
74
+ anchor_area = get_area_of_bbox(anchor_bbox, bbox_convention)
75
+ for i in range(len(bbox_list)):
76
+ if i == anchor_index:
77
+ continue
78
+ tmp_bbox = bbox_list[i]
79
+ tmp_area = get_area_of_bbox(tmp_bbox, bbox_convention)
80
+ if tmp_area >= anchor_area:
81
+ bigger_list.append(tmp_bbox)
82
+ else:
83
+ less_list.append(tmp_bbox)
84
+ if only_max:
85
+ return qsort_bbox_list(bigger_list) + \
86
+ [anchor_bbox, ] + less_list
87
+ else:
88
+ return qsort_bbox_list(bigger_list) + \
89
+ [anchor_bbox, ] + qsort_bbox_list(less_list)
90
+
91
+ def get_area_of_bbox(
92
+ bbox: Union[list, tuple],
93
+ bbox_convention: Literal['xyxy', 'xywh'] = 'xyxy') -> float:
94
+ """Get the area of a bbox_xyxy.
95
+
96
+ Args:
97
+ (Union[list, tuple]):
98
+ A list of [x1, y1, x2, y2].
99
+ bbox_convention (str, optional):
100
+ Bbox type, xyxy or xywh. Defaults to 'xyxy'.
101
+
102
+ Returns:
103
+ float:
104
+ Area of the bbox(|y2-y1|*|x2-x1|).
105
+ """
106
+ # import pdb;pdb.set_trace()
107
+ if bbox_convention == 'xyxy':
108
+ return abs(bbox[2] - bbox[0]) * abs(bbox[3] - bbox[1])
109
+ elif bbox_convention == 'xywh':
110
+ return abs(bbox[2] * bbox[3])
111
+ else:
112
+ raise TypeError(f'Wrong bbox convention: {bbox_convention}')
113
+
114
+ def calculate_iou(bbox1, bbox2):
115
+ # Calculate the Intersection over Union (IoU) between two bounding boxes
116
+ x1 = max(bbox1[0], bbox2[0])
117
+ y1 = max(bbox1[1], bbox2[1])
118
+ x2 = min(bbox1[2], bbox2[2])
119
+ y2 = min(bbox1[3], bbox2[3])
120
+
121
+ intersection_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
122
+
123
+ bbox1_area = (bbox1[2] - bbox1[0] + 1) * (bbox1[3] - bbox1[1] + 1)
124
+ bbox2_area = (bbox2[2] - bbox2[0] + 1) * (bbox2[3] - bbox2[1] + 1)
125
+
126
+ union_area = bbox1_area + bbox2_area - intersection_area
127
+
128
+ iou = intersection_area / union_area
129
+ return iou
130
+
131
+
132
+ def non_max_suppression(bboxes, iou_threshold):
133
+ # Sort the bounding boxes by their confidence scores (e.g., the probability of containing an object)
134
+ bboxes = sorted(bboxes, key=lambda x: x[4], reverse=True)
135
+
136
+ # Initialize a list to store the selected bounding boxes
137
+ selected_bboxes = []
138
+
139
+ # Perform non-maximum suppression
140
+ while len(bboxes) > 0:
141
+ current_bbox = bboxes[0]
142
+ selected_bboxes.append(current_bbox)
143
+ bboxes = bboxes[1:]
144
+
145
+ remaining_bboxes = []
146
+ for bbox in bboxes:
147
+ iou = calculate_iou(current_bbox, bbox)
148
+ if iou < iou_threshold:
149
+ remaining_bboxes.append(bbox)
150
+
151
+ bboxes = remaining_bboxes
152
+
153
+ return selected_bboxes
common/utils/preprocessing.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import random
4
+ from config import cfg
5
+ import math
6
+ from utils.human_models import smpl_x, smpl
7
+ from utils.transforms import cam2pixel, transform_joint_to_other_db
8
+ from plyfile import PlyData, PlyElement
9
+ import torch
10
+
11
+
12
+ def load_img(path, order='RGB'):
13
+ img = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
14
+ if not isinstance(img, np.ndarray):
15
+ raise IOError("Fail to read %s" % path)
16
+
17
+ if order == 'RGB':
18
+ img = img[:, :, ::-1].copy()
19
+
20
+ img = img.astype(np.float32)
21
+ return img
22
+
23
+
24
+ def get_bbox(joint_img, joint_valid, extend_ratio=1.2):
25
+ x_img, y_img = joint_img[:, 0], joint_img[:, 1]
26
+ x_img = x_img[joint_valid == 1];
27
+ y_img = y_img[joint_valid == 1];
28
+ xmin = min(x_img);
29
+ ymin = min(y_img);
30
+ xmax = max(x_img);
31
+ ymax = max(y_img);
32
+
33
+ x_center = (xmin + xmax) / 2.;
34
+ width = xmax - xmin;
35
+ xmin = x_center - 0.5 * width * extend_ratio
36
+ xmax = x_center + 0.5 * width * extend_ratio
37
+
38
+ y_center = (ymin + ymax) / 2.;
39
+ height = ymax - ymin;
40
+ ymin = y_center - 0.5 * height * extend_ratio
41
+ ymax = y_center + 0.5 * height * extend_ratio
42
+
43
+ bbox = np.array([xmin, ymin, xmax - xmin, ymax - ymin]).astype(np.float32)
44
+ return bbox
45
+
46
+
47
+ def sanitize_bbox(bbox, img_width, img_height):
48
+ x, y, w, h = bbox
49
+ x1 = np.max((0, x))
50
+ y1 = np.max((0, y))
51
+ x2 = np.min((img_width - 1, x1 + np.max((0, w - 1))))
52
+ y2 = np.min((img_height - 1, y1 + np.max((0, h - 1))))
53
+ if w * h > 0 and x2 > x1 and y2 > y1:
54
+ bbox = np.array([x1, y1, x2 - x1, y2 - y1])
55
+ else:
56
+ bbox = None
57
+
58
+ return bbox
59
+
60
+
61
+ def process_bbox(bbox, img_width, img_height, ratio=1.25):
62
+ bbox = sanitize_bbox(bbox, img_width, img_height)
63
+ if bbox is None:
64
+ return bbox
65
+
66
+ # aspect ratio preserving bbox
67
+ w = bbox[2]
68
+ h = bbox[3]
69
+ c_x = bbox[0] + w / 2.
70
+ c_y = bbox[1] + h / 2.
71
+ aspect_ratio = cfg.input_img_shape[1] / cfg.input_img_shape[0]
72
+ if w > aspect_ratio * h:
73
+ h = w / aspect_ratio
74
+ elif w < aspect_ratio * h:
75
+ w = h * aspect_ratio
76
+ bbox[2] = w * ratio
77
+ bbox[3] = h * ratio
78
+ bbox[0] = c_x - bbox[2] / 2.
79
+ bbox[1] = c_y - bbox[3] / 2.
80
+
81
+ bbox = bbox.astype(np.float32)
82
+ return bbox
83
+
84
+
85
+ def get_aug_config():
86
+ scale_factor = 0.25
87
+ rot_factor = 30
88
+ color_factor = 0.2
89
+
90
+ scale = np.clip(np.random.randn(), -1.0, 1.0) * scale_factor + 1.0
91
+ rot = np.clip(np.random.randn(), -2.0,
92
+ 2.0) * rot_factor if random.random() <= 0.6 else 0
93
+ c_up = 1.0 + color_factor
94
+ c_low = 1.0 - color_factor
95
+ color_scale = np.array([random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)])
96
+ do_flip = random.random() <= 0.5
97
+
98
+ return scale, rot, color_scale, do_flip
99
+
100
+
101
+ def augmentation(img, bbox, data_split):
102
+ if getattr(cfg, 'no_aug', False):
103
+ scale, rot, color_scale, do_flip = 1.0, 0.0, np.array([1, 1, 1]), False
104
+ elif data_split == 'train':
105
+ scale, rot, color_scale, do_flip = get_aug_config()
106
+ else:
107
+ scale, rot, color_scale, do_flip = 1.0, 0.0, np.array([1, 1, 1]), False
108
+
109
+ img, trans, inv_trans = generate_patch_image(img, bbox, scale, rot, do_flip, cfg.input_img_shape)
110
+ img = np.clip(img * color_scale[None, None, :], 0, 255)
111
+ return img, trans, inv_trans, rot, do_flip
112
+
113
+
114
+ def generate_patch_image(cvimg, bbox, scale, rot, do_flip, out_shape):
115
+ img = cvimg.copy()
116
+ img_height, img_width, img_channels = img.shape
117
+
118
+ bb_c_x = float(bbox[0] + 0.5 * bbox[2])
119
+ bb_c_y = float(bbox[1] + 0.5 * bbox[3])
120
+ bb_width = float(bbox[2])
121
+ bb_height = float(bbox[3])
122
+
123
+ if do_flip:
124
+ img = img[:, ::-1, :]
125
+ bb_c_x = img_width - bb_c_x - 1
126
+
127
+ trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot)
128
+ img_patch = cv2.warpAffine(img, trans, (int(out_shape[1]), int(out_shape[0])), flags=cv2.INTER_LINEAR)
129
+ img_patch = img_patch.astype(np.float32)
130
+ inv_trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot,
131
+ inv=True)
132
+
133
+ return img_patch, trans, inv_trans
134
+
135
+
136
+ def rotate_2d(pt_2d, rot_rad):
137
+ x = pt_2d[0]
138
+ y = pt_2d[1]
139
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
140
+ xx = x * cs - y * sn
141
+ yy = x * sn + y * cs
142
+ return np.array([xx, yy], dtype=np.float32)
143
+
144
+
145
+ def gen_trans_from_patch_cv(c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False):
146
+ # augment size with scale
147
+ src_w = src_width * scale
148
+ src_h = src_height * scale
149
+ src_center = np.array([c_x, c_y], dtype=np.float32)
150
+
151
+ # augment rotation
152
+ rot_rad = np.pi * rot / 180
153
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
154
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
155
+
156
+ dst_w = dst_width
157
+ dst_h = dst_height
158
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
159
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
160
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
161
+
162
+ src = np.zeros((3, 2), dtype=np.float32)
163
+ src[0, :] = src_center
164
+ src[1, :] = src_center + src_downdir
165
+ src[2, :] = src_center + src_rightdir
166
+
167
+ dst = np.zeros((3, 2), dtype=np.float32)
168
+ dst[0, :] = dst_center
169
+ dst[1, :] = dst_center + dst_downdir
170
+ dst[2, :] = dst_center + dst_rightdir
171
+
172
+ if inv:
173
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
174
+ else:
175
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
176
+
177
+ trans = trans.astype(np.float32)
178
+ return trans
179
+
180
+
181
+ def process_db_coord(joint_img, joint_cam, joint_valid, do_flip, img_shape, flip_pairs, img2bb_trans, rot,
182
+ src_joints_name, target_joints_name):
183
+ joint_img_original = joint_img.copy()
184
+ joint_img, joint_cam, joint_valid = joint_img.copy(), joint_cam.copy(), joint_valid.copy()
185
+
186
+ # flip augmentation
187
+ if do_flip:
188
+ joint_cam[:, 0] = -joint_cam[:, 0]
189
+ joint_img[:, 0] = img_shape[1] - 1 - joint_img[:, 0]
190
+ for pair in flip_pairs:
191
+ joint_img[pair[0], :], joint_img[pair[1], :] = joint_img[pair[1], :].copy(), joint_img[pair[0], :].copy()
192
+ joint_cam[pair[0], :], joint_cam[pair[1], :] = joint_cam[pair[1], :].copy(), joint_cam[pair[0], :].copy()
193
+ joint_valid[pair[0], :], joint_valid[pair[1], :] = joint_valid[pair[1], :].copy(), joint_valid[pair[0],
194
+ :].copy()
195
+
196
+ # 3D data rotation augmentation
197
+ rot_aug_mat = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
198
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
199
+ [0, 0, 1]], dtype=np.float32)
200
+ joint_cam = np.dot(rot_aug_mat, joint_cam.transpose(1, 0)).transpose(1, 0)
201
+
202
+ # affine transformation
203
+ joint_img_xy1 = np.concatenate((joint_img[:, :2], np.ones_like(joint_img[:, :1])), 1)
204
+ joint_img[:, :2] = np.dot(img2bb_trans, joint_img_xy1.transpose(1, 0)).transpose(1, 0)
205
+ joint_img[:, 0] = joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2]
206
+ joint_img[:, 1] = joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1]
207
+
208
+ # check truncation
209
+ joint_trunc = joint_valid * ((joint_img_original[:, 0] > 0) * (joint_img[:, 0] >= 0) * (joint_img[:, 0] < cfg.output_hm_shape[2]) * \
210
+ (joint_img_original[:, 1] > 0) *(joint_img[:, 1] >= 0) * (joint_img[:, 1] < cfg.output_hm_shape[1]) * \
211
+ (joint_img_original[:, 2] > 0) *(joint_img[:, 2] >= 0) * (joint_img[:, 2] < cfg.output_hm_shape[0])).reshape(-1,
212
+ 1).astype(
213
+ np.float32)
214
+
215
+ # transform joints to target db joints
216
+ joint_img = transform_joint_to_other_db(joint_img, src_joints_name, target_joints_name)
217
+ joint_cam_wo_ra = transform_joint_to_other_db(joint_cam, src_joints_name, target_joints_name)
218
+ joint_valid = transform_joint_to_other_db(joint_valid, src_joints_name, target_joints_name)
219
+ joint_trunc = transform_joint_to_other_db(joint_trunc, src_joints_name, target_joints_name)
220
+
221
+ # root-alignment, for joint_cam input wo ra
222
+ joint_cam_ra = joint_cam_wo_ra.copy()
223
+ joint_cam_ra = joint_cam_ra - joint_cam_ra[smpl_x.root_joint_idx, None, :] # root-relative
224
+ joint_cam_ra[smpl_x.joint_part['lhand'], :] = joint_cam_ra[smpl_x.joint_part['lhand'], :] - joint_cam_ra[
225
+ smpl_x.lwrist_idx, None,
226
+ :] # left hand root-relative
227
+ joint_cam_ra[smpl_x.joint_part['rhand'], :] = joint_cam_ra[smpl_x.joint_part['rhand'], :] - joint_cam_ra[
228
+ smpl_x.rwrist_idx, None,
229
+ :] # right hand root-relative
230
+ joint_cam_ra[smpl_x.joint_part['face'], :] = joint_cam_ra[smpl_x.joint_part['face'], :] - joint_cam_ra[smpl_x.neck_idx,
231
+ None,
232
+ :] # face root-relative
233
+
234
+ return joint_img, joint_cam_wo_ra, joint_cam_ra, joint_valid, joint_trunc
235
+
236
+
237
+ def process_human_model_output(human_model_param, cam_param, do_flip, img_shape, img2bb_trans, rot, human_model_type, joint_img=None):
238
+ if human_model_type == 'smplx':
239
+ human_model = smpl_x
240
+ rotation_valid = np.ones((smpl_x.orig_joint_num), dtype=np.float32)
241
+ coord_valid = np.ones((smpl_x.joint_num), dtype=np.float32)
242
+
243
+ root_pose, body_pose, shape, trans = human_model_param['root_pose'], human_model_param['body_pose'], \
244
+ human_model_param['shape'], human_model_param['trans']
245
+ if 'lhand_pose' in human_model_param and human_model_param['lhand_valid']:
246
+ lhand_pose = human_model_param['lhand_pose']
247
+ else:
248
+ lhand_pose = np.zeros((3 * len(smpl_x.orig_joint_part['lhand'])), dtype=np.float32)
249
+ rotation_valid[smpl_x.orig_joint_part['lhand']] = 0
250
+ coord_valid[smpl_x.joint_part['lhand']] = 0
251
+ if 'rhand_pose' in human_model_param and human_model_param['rhand_valid']:
252
+ rhand_pose = human_model_param['rhand_pose']
253
+ else:
254
+ rhand_pose = np.zeros((3 * len(smpl_x.orig_joint_part['rhand'])), dtype=np.float32)
255
+ rotation_valid[smpl_x.orig_joint_part['rhand']] = 0
256
+ coord_valid[smpl_x.joint_part['rhand']] = 0
257
+ if 'jaw_pose' in human_model_param and 'expr' in human_model_param and human_model_param['face_valid']:
258
+ jaw_pose = human_model_param['jaw_pose']
259
+ expr = human_model_param['expr']
260
+ expr_valid = True
261
+ else:
262
+ jaw_pose = np.zeros((3), dtype=np.float32)
263
+ expr = np.zeros((smpl_x.expr_code_dim), dtype=np.float32)
264
+ rotation_valid[smpl_x.orig_joint_part['face']] = 0
265
+ coord_valid[smpl_x.joint_part['face']] = 0
266
+ expr_valid = False
267
+ if 'gender' in human_model_param:
268
+ gender = human_model_param['gender']
269
+ else:
270
+ gender = 'neutral'
271
+ root_pose = torch.FloatTensor(root_pose).view(1, 3) # (1,3)
272
+ body_pose = torch.FloatTensor(body_pose).view(-1, 3) # (21,3)
273
+ lhand_pose = torch.FloatTensor(lhand_pose).view(-1, 3) # (15,3)
274
+ rhand_pose = torch.FloatTensor(rhand_pose).view(-1, 3) # (15,3)
275
+ jaw_pose = torch.FloatTensor(jaw_pose).view(-1, 3) # (1,3)
276
+ shape = torch.FloatTensor(shape).view(1, -1) # SMPLX shape parameter
277
+ expr = torch.FloatTensor(expr).view(1, -1) # SMPLX expression parameter
278
+ trans = torch.FloatTensor(trans).view(1, -1) # translation vector
279
+
280
+ # apply camera extrinsic (rotation)
281
+ # merge root pose and camera rotation
282
+ if 'R' in cam_param:
283
+ R = np.array(cam_param['R'], dtype=np.float32).reshape(3, 3)
284
+ root_pose = root_pose.numpy()
285
+ root_pose, _ = cv2.Rodrigues(root_pose)
286
+ root_pose, _ = cv2.Rodrigues(np.dot(R, root_pose))
287
+ root_pose = torch.from_numpy(root_pose).view(1, 3)
288
+
289
+ # get mesh and joint coordinates
290
+ zero_pose = torch.zeros((1, 3)).float() # eye poses
291
+ with torch.no_grad():
292
+ output = smpl_x.layer[gender](betas=shape, body_pose=body_pose.view(1, -1), global_orient=root_pose,
293
+ transl=trans, left_hand_pose=lhand_pose.view(1, -1),
294
+ right_hand_pose=rhand_pose.view(1, -1), jaw_pose=jaw_pose.view(1, -1),
295
+ leye_pose=zero_pose, reye_pose=zero_pose, expression=expr)
296
+ mesh_cam = output.vertices[0].numpy()
297
+ joint_cam = output.joints[0].numpy()[smpl_x.joint_idx, :]
298
+
299
+ # apply camera exrinsic (translation)
300
+ # compenstate rotation (translation from origin to root joint was not cancled)
301
+ if 'R' in cam_param and 't' in cam_param:
302
+ R, t = np.array(cam_param['R'], dtype=np.float32).reshape(3, 3), np.array(cam_param['t'],
303
+ dtype=np.float32).reshape(1, 3)
304
+ root_cam = joint_cam[smpl_x.root_joint_idx, None, :]
305
+ joint_cam = joint_cam - root_cam + np.dot(R, root_cam.transpose(1, 0)).transpose(1, 0) + t
306
+ mesh_cam = mesh_cam - root_cam + np.dot(R, root_cam.transpose(1, 0)).transpose(1, 0) + t
307
+
308
+ # concat root, body, two hands, and jaw pose
309
+ pose = torch.cat((root_pose, body_pose, lhand_pose, rhand_pose, jaw_pose))
310
+
311
+ # joint coordinates
312
+ if 'focal' not in cam_param or 'princpt' not in cam_param:
313
+ assert joint_img is not None
314
+ else:
315
+ joint_img = cam2pixel(joint_cam, cam_param['focal'], cam_param['princpt'])
316
+
317
+ joint_img_original = joint_img.copy()
318
+
319
+ joint_cam = joint_cam - joint_cam[smpl_x.root_joint_idx, None, :] # root-relative
320
+ joint_cam[smpl_x.joint_part['lhand'], :] = joint_cam[smpl_x.joint_part['lhand'], :] - joint_cam[
321
+ smpl_x.lwrist_idx, None,
322
+ :] # left hand root-relative
323
+ joint_cam[smpl_x.joint_part['rhand'], :] = joint_cam[smpl_x.joint_part['rhand'], :] - joint_cam[
324
+ smpl_x.rwrist_idx, None,
325
+ :] # right hand root-relative
326
+ joint_cam[smpl_x.joint_part['face'], :] = joint_cam[smpl_x.joint_part['face'], :] - joint_cam[smpl_x.neck_idx,
327
+ None,
328
+ :] # face root-relative
329
+ joint_img[smpl_x.joint_part['body'], 2] = (joint_cam[smpl_x.joint_part['body'], 2].copy() / (
330
+ cfg.body_3d_size / 2) + 1) / 2. * cfg.output_hm_shape[0] # body depth discretize
331
+ joint_img[smpl_x.joint_part['lhand'], 2] = (joint_cam[smpl_x.joint_part['lhand'], 2].copy() / (
332
+ cfg.hand_3d_size / 2) + 1) / 2. * cfg.output_hm_shape[0] # left hand depth discretize
333
+ joint_img[smpl_x.joint_part['rhand'], 2] = (joint_cam[smpl_x.joint_part['rhand'], 2].copy() / (
334
+ cfg.hand_3d_size / 2) + 1) / 2. * cfg.output_hm_shape[0] # right hand depth discretize
335
+ joint_img[smpl_x.joint_part['face'], 2] = (joint_cam[smpl_x.joint_part['face'], 2].copy() / (
336
+ cfg.face_3d_size / 2) + 1) / 2. * cfg.output_hm_shape[0] # face depth discretize
337
+
338
+ elif human_model_type == 'smpl':
339
+ human_model = smpl
340
+ pose, shape, trans = human_model_param['pose'], human_model_param['shape'], human_model_param['trans']
341
+ if 'gender' in human_model_param:
342
+ gender = human_model_param['gender']
343
+ else:
344
+ gender = 'neutral'
345
+ pose = torch.FloatTensor(pose).view(-1, 3)
346
+ shape = torch.FloatTensor(shape).view(1, -1);
347
+ trans = torch.FloatTensor(trans).view(1, -1) # translation vector
348
+
349
+ # apply camera extrinsic (rotation)
350
+ # merge root pose and camera rotation
351
+ if 'R' in cam_param:
352
+ R = np.array(cam_param['R'], dtype=np.float32).reshape(3, 3)
353
+ root_pose = pose[smpl.orig_root_joint_idx, :].numpy()
354
+ root_pose, _ = cv2.Rodrigues(root_pose)
355
+ root_pose, _ = cv2.Rodrigues(np.dot(R, root_pose))
356
+ pose[smpl.orig_root_joint_idx] = torch.from_numpy(root_pose).view(3)
357
+
358
+ # get mesh and joint coordinates
359
+ root_pose = pose[smpl.orig_root_joint_idx].view(1, 3)
360
+ body_pose = torch.cat((pose[:smpl.orig_root_joint_idx, :], pose[smpl.orig_root_joint_idx + 1:, :])).view(1, -1)
361
+ with torch.no_grad():
362
+ output = smpl.layer[gender](betas=shape, body_pose=body_pose, global_orient=root_pose, transl=trans)
363
+ mesh_cam = output.vertices[0].numpy()
364
+ joint_cam = np.dot(smpl.joint_regressor, mesh_cam)
365
+
366
+ # apply camera exrinsic (translation)
367
+ # compenstate rotation (translation from origin to root joint was not cancled)
368
+ if 'R' in cam_param and 't' in cam_param:
369
+ R, t = np.array(cam_param['R'], dtype=np.float32).reshape(3, 3), np.array(cam_param['t'],
370
+ dtype=np.float32).reshape(1, 3)
371
+ root_cam = joint_cam[smpl.root_joint_idx, None, :]
372
+ joint_cam = joint_cam - root_cam + np.dot(R, root_cam.transpose(1, 0)).transpose(1, 0) + t
373
+ mesh_cam = mesh_cam - root_cam + np.dot(R, root_cam.transpose(1, 0)).transpose(1, 0) + t
374
+
375
+ # joint coordinates
376
+ if 'focal' not in cam_param or 'princpt' not in cam_param:
377
+ assert joint_img is not None
378
+ else:
379
+ joint_img = cam2pixel(joint_cam, cam_param['focal'], cam_param['princpt'])
380
+
381
+ joint_img_original = joint_img.copy()
382
+ joint_cam = joint_cam - joint_cam[smpl.root_joint_idx, None, :] # body root-relative
383
+ joint_img[:, 2] = (joint_cam[:, 2].copy() / (cfg.body_3d_size / 2) + 1) / 2. * cfg.output_hm_shape[
384
+ 0] # body depth discretize
385
+
386
+ elif human_model_type == 'mano':
387
+ human_model = mano
388
+ pose, shape, trans = human_model_param['pose'], human_model_param['shape'], human_model_param['trans']
389
+ hand_type = human_model_param['hand_type']
390
+ pose = torch.FloatTensor(pose).view(-1, 3)
391
+ shape = torch.FloatTensor(shape).view(1, -1);
392
+ trans = torch.FloatTensor(trans).view(1, -1) # translation vector
393
+
394
+ # apply camera extrinsic (rotation)
395
+ # merge root pose and camera rotation
396
+ if 'R' in cam_param:
397
+ R = np.array(cam_param['R'], dtype=np.float32).reshape(3, 3)
398
+ root_pose = pose[mano.orig_root_joint_idx, :].numpy()
399
+ root_pose, _ = cv2.Rodrigues(root_pose)
400
+ root_pose, _ = cv2.Rodrigues(np.dot(R, root_pose))
401
+ pose[mano.orig_root_joint_idx] = torch.from_numpy(root_pose).view(3)
402
+
403
+ # get mesh and joint coordinates
404
+ root_pose = pose[mano.orig_root_joint_idx].view(1, 3)
405
+ hand_pose = torch.cat((pose[:mano.orig_root_joint_idx, :], pose[mano.orig_root_joint_idx + 1:, :])).view(1, -1)
406
+ with torch.no_grad():
407
+ output = mano.layer[hand_type](betas=shape, hand_pose=hand_pose, global_orient=root_pose, transl=trans)
408
+ mesh_cam = output.vertices[0].numpy()
409
+ joint_cam = np.dot(mano.joint_regressor, mesh_cam)
410
+
411
+ # apply camera exrinsic (translation)
412
+ # compenstate rotation (translation from origin to root joint was not cancled)
413
+ if 'R' in cam_param and 't' in cam_param:
414
+ R, t = np.array(cam_param['R'], dtype=np.float32).reshape(3, 3), np.array(cam_param['t'],
415
+ dtype=np.float32).reshape(1, 3)
416
+ root_cam = joint_cam[mano.root_joint_idx, None, :]
417
+ joint_cam = joint_cam - root_cam + np.dot(R, root_cam.transpose(1, 0)).transpose(1, 0) + t
418
+ mesh_cam = mesh_cam - root_cam + np.dot(R, root_cam.transpose(1, 0)).transpose(1, 0) + t
419
+
420
+ # joint coordinates
421
+ if 'focal' not in cam_param or 'princpt' not in cam_param:
422
+ assert joint_img is not None
423
+ else:
424
+ joint_img = cam2pixel(joint_cam, cam_param['focal'], cam_param['princpt'])
425
+ joint_cam = joint_cam - joint_cam[mano.root_joint_idx, None, :] # hand root-relative
426
+ joint_img[:, 2] = (joint_cam[:, 2].copy() / (cfg.hand_3d_size / 2) + 1) / 2. * cfg.output_hm_shape[
427
+ 0] # hand depth discretize
428
+
429
+ mesh_cam_orig = mesh_cam.copy() # back-up the original one
430
+
431
+ ## so far, data augmentations are not applied yet
432
+ ## now, apply data augmentations
433
+
434
+ # image projection
435
+ if do_flip:
436
+ joint_cam[:, 0] = -joint_cam[:, 0]
437
+ joint_img[:, 0] = img_shape[1] - 1 - joint_img[:, 0]
438
+ for pair in human_model.flip_pairs:
439
+ joint_cam[pair[0], :], joint_cam[pair[1], :] = joint_cam[pair[1], :].copy(), joint_cam[pair[0], :].copy()
440
+ joint_img[pair[0], :], joint_img[pair[1], :] = joint_img[pair[1], :].copy(), joint_img[pair[0], :].copy()
441
+ if human_model_type == 'smplx':
442
+ coord_valid[pair[0]], coord_valid[pair[1]] = coord_valid[pair[1]].copy(), coord_valid[pair[0]].copy()
443
+
444
+ # x,y affine transform, root-relative depth
445
+ joint_img_xy1 = np.concatenate((joint_img[:, :2], np.ones_like(joint_img[:, 0:1])), 1)
446
+ joint_img[:, :2] = np.dot(img2bb_trans, joint_img_xy1.transpose(1, 0)).transpose(1, 0)[:, :2]
447
+ joint_img[:, 0] = joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2]
448
+ joint_img[:, 1] = joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1]
449
+
450
+ # check truncation
451
+ # TODO
452
+ joint_trunc = ((joint_img_original[:, 0] > 0) * (joint_img[:, 0] >= 0) * (joint_img[:, 0] < cfg.output_hm_shape[2]) * \
453
+ (joint_img_original[:, 1] > 0) * (joint_img[:, 1] >= 0) * (joint_img[:, 1] < cfg.output_hm_shape[1]) * \
454
+ (joint_img_original[:, 2] > 0) * (joint_img[:, 2] >= 0) * (joint_img[:, 2] < cfg.output_hm_shape[0])).reshape(-1, 1).astype(
455
+ np.float32)
456
+
457
+ # 3D data rotation augmentation
458
+ rot_aug_mat = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
459
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
460
+ [0, 0, 1]], dtype=np.float32)
461
+ # coordinate
462
+ joint_cam = np.dot(rot_aug_mat, joint_cam.transpose(1, 0)).transpose(1, 0)
463
+ # parameters
464
+ # flip pose parameter (axis-angle)
465
+ if do_flip:
466
+ for pair in human_model.orig_flip_pairs:
467
+ pose[pair[0], :], pose[pair[1], :] = pose[pair[1], :].clone(), pose[pair[0], :].clone()
468
+ if human_model_type == 'smplx':
469
+ rotation_valid[pair[0]], rotation_valid[pair[1]] = rotation_valid[pair[1]].copy(), rotation_valid[
470
+ pair[0]].copy()
471
+ pose[:, 1:3] *= -1 # multiply -1 to y and z axis of axis-angle
472
+
473
+ # rotate root pose
474
+ pose = pose.numpy()
475
+ root_pose = pose[human_model.orig_root_joint_idx, :]
476
+ root_pose, _ = cv2.Rodrigues(root_pose)
477
+ root_pose, _ = cv2.Rodrigues(np.dot(rot_aug_mat, root_pose))
478
+ pose[human_model.orig_root_joint_idx] = root_pose.reshape(3)
479
+
480
+ # change to mean shape if beta is too far from it
481
+ shape[(shape.abs() > 3).any(dim=1)] = 0.
482
+ shape = shape.numpy().reshape(-1)
483
+
484
+ # return results
485
+ if human_model_type == 'smplx':
486
+ pose = pose.reshape(-1)
487
+ expr = expr.numpy().reshape(-1)
488
+
489
+ return joint_img, joint_cam, joint_trunc, pose, shape, expr, rotation_valid, coord_valid, expr_valid, mesh_cam_orig
490
+ elif human_model_type == 'smpl':
491
+ pose = pose.reshape(-1)
492
+ return joint_img, joint_cam, joint_trunc, pose, shape, mesh_cam_orig
493
+ elif human_model_type == 'mano':
494
+ pose = pose.reshape(-1)
495
+ return joint_img, joint_cam, joint_trunc, pose, shape, mesh_cam_orig
496
+
497
+
498
+ def get_fitting_error_3D(db_joint, db_joint_from_fit, joint_valid):
499
+ # mask coordinate
500
+ db_joint = db_joint[np.tile(joint_valid, (1, 3)) == 1].reshape(-1, 3)
501
+ db_joint_from_fit = db_joint_from_fit[np.tile(joint_valid, (1, 3)) == 1].reshape(-1, 3)
502
+
503
+ db_joint_from_fit = db_joint_from_fit - np.mean(db_joint_from_fit, 0)[None, :] + np.mean(db_joint, 0)[None,
504
+ :] # translation alignment
505
+ error = np.sqrt(np.sum((db_joint - db_joint_from_fit) ** 2, 1)).mean()
506
+ return error
507
+
508
+
509
+ def load_obj(file_name):
510
+ v = []
511
+ obj_file = open(file_name)
512
+ for line in obj_file:
513
+ words = line.split(' ')
514
+ if words[0] == 'v':
515
+ x, y, z = float(words[1]), float(words[2]), float(words[3])
516
+ v.append(np.array([x, y, z]))
517
+ return np.stack(v)
518
+
519
+
520
+ def load_ply(file_name):
521
+ plydata = PlyData.read(file_name)
522
+ x = plydata['vertex']['x']
523
+ y = plydata['vertex']['y']
524
+ z = plydata['vertex']['z']
525
+ v = np.stack((x, y, z), 1)
526
+ return v
527
+
528
+ def resize_bbox(bbox, scale=1.2):
529
+ if isinstance(bbox, list):
530
+ x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
531
+ else:
532
+ x1, y1, x2, y2 = bbox
533
+ x_center = (x1+x2)/2.0
534
+ y_center = (y1+y2)/2.0
535
+ x_size, y_size = x2-x1, y2-y1
536
+ x1_resize = x_center-x_size/2.0*scale
537
+ x2_resize = x_center+x_size/2.0*scale
538
+ y1_resize = y_center - y_size / 2.0 * scale
539
+ y2_resize = y_center + y_size / 2.0 * scale
540
+ bbox[0], bbox[1], bbox[2], bbox[3] = x1_resize, y1_resize, x2_resize, y2_resize
541
+ return bbox
common/utils/smplx/LICENSE ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ License
2
+
3
+ Software Copyright License for non-commercial scientific research purposes
4
+ Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the SMPL-X/SMPLify-X model, data and software, (the "Model & Software"), including 3D meshes, blend weights, blend shapes, textures, software, scripts, and animations. By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License
5
+
6
+ Ownership / Licensees
7
+ The Software and the associated materials has been developed at the
8
+
9
+ Max Planck Institute for Intelligent Systems (hereinafter "MPI").
10
+
11
+ Any copyright or patent right is owned by and proprietary material of the
12
+
13
+ Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”)
14
+
15
+ hereinafter the “Licensor”.
16
+
17
+ License Grant
18
+ Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
19
+
20
+ To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization;
21
+ To use the Model & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
22
+ Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Model & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission.
23
+
24
+ The Model & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Model & Software to train methods/algorithms/neural networks/etc. for commercial use of any kind. By downloading the Model & Software, you agree not to reverse engineer it.
25
+
26
+ No Distribution
27
+ The Model & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
28
+
29
+ Disclaimer of Representations and Warranties
30
+ You expressly acknowledge and agree that the Model & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Model & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE MODEL & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Model & Software, (ii) that the use of the Model & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Model & Software will not cause any damage of any kind to you or a third party.
31
+
32
+ Limitation of Liability
33
+ Because this Model & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
34
+ Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
35
+ Patent claims generated through the usage of the Model & Software cannot be directed towards the copyright holders.
36
+ The Model & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Model & Software and is not responsible for any problems such modifications cause.
37
+
38
+ No Maintenance Services
39
+ You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Model & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Model & Software at any time.
40
+
41
+ Defects of the Model & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
42
+
43
+ Publications using the Model & Software
44
+ You acknowledge that the Model & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Model & Software.
45
+
46
+ Citation:
47
+
48
+
49
+ @inproceedings{SMPL-X:2019,
50
+ title = {Expressive Body Capture: 3D Hands, Face, and Body from a Single Image},
51
+ author = {Pavlakos, Georgios and Choutas, Vasileios and Ghorbani, Nima and Bolkart, Timo and Osman, Ahmed A. A. and Tzionas, Dimitrios and Black, Michael J.},
52
+ booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)},
53
+ year = {2019}
54
+ }
55
+ Commercial licensing opportunities
56
+ For commercial uses of the Software, please send email to ps-license@tue.mpg.de
57
+
58
+ This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
common/utils/smplx/README.md ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## SMPL-X: A new joint 3D model of the human body, face and hands together
2
+
3
+ [[Paper Page](https://smpl-x.is.tue.mpg.de)] [[Paper](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/497/SMPL-X.pdf)]
4
+ [[Supp. Mat.](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/498/SMPL-X-supp.pdf)]
5
+
6
+ ![SMPL-X Examples](./images/teaser_fig.png)
7
+
8
+ ## Table of Contents
9
+ * [License](#license)
10
+ * [Description](#description)
11
+ * [Installation](#installation)
12
+ * [Downloading the model](#downloading-the-model)
13
+ * [Loading SMPL-X, SMPL+H and SMPL](#loading-smpl-x-smplh-and-smpl)
14
+ * [SMPL and SMPL+H setup](#smpl-and-smplh-setup)
15
+ * [Model loading](https://github.com/vchoutas/smplx#model-loading)
16
+ * [MANO and FLAME correspondences](#mano-and-flame-correspondences)
17
+ * [Example](#example)
18
+ * [Citation](#citation)
19
+ * [Acknowledgments](#acknowledgments)
20
+ * [Contact](#contact)
21
+
22
+ ## License
23
+
24
+ Software Copyright License for **non-commercial scientific research purposes**.
25
+ Please read carefully the [terms and conditions](https://github.com/vchoutas/smplx/blob/master/LICENSE) and any accompanying documentation before you download and/or use the SMPL-X/SMPLify-X model, data and software, (the "Model & Software"), including 3D meshes, blend weights, blend shapes, textures, software, scripts, and animations. By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this [License](./LICENSE).
26
+
27
+ ## Disclaimer
28
+
29
+ The original images used for the figures 1 and 2 of the paper can be found in this link.
30
+ The images in the paper are used under license from gettyimages.com.
31
+ We have acquired the right to use them in the publication, but redistribution is not allowed.
32
+ Please follow the instructions on the given link to acquire right of usage.
33
+ Our results are obtained on the 483 × 724 pixels resolution of the original images.
34
+
35
+ ## Description
36
+
37
+ *SMPL-X* (SMPL eXpressive) is a unified body model with shape parameters trained jointly for the
38
+ face, hands and body. *SMPL-X* uses standard vertex based linear blend skinning with learned corrective blend
39
+ shapes, has N = 10, 475 vertices and K = 54 joints,
40
+ which include joints for the neck, jaw, eyeballs and fingers.
41
+ SMPL-X is defined by a function M(θ, β, ψ), where θ is the pose parameters, β the shape parameters and
42
+ ψ the facial expression parameters.
43
+
44
+
45
+ ## Installation
46
+
47
+ To install the model please follow the next steps in the specified order:
48
+ 1. To install from PyPi simply run:
49
+ ```Shell
50
+ pip install smplx[all]
51
+ ```
52
+ 2. Clone this repository and install it using the *setup.py* script:
53
+ ```Shell
54
+ git clone https://github.com/vchoutas/smplx
55
+ python setup.py install
56
+ ```
57
+
58
+ ## Downloading the model
59
+
60
+ To download the *SMPL-X* model go to [this project website](https://smpl-x.is.tue.mpg.de) and register to get access to the downloads section.
61
+
62
+ To download the *SMPL+H* model go to [this project website](http://mano.is.tue.mpg.de) and register to get access to the downloads section.
63
+
64
+ To download the *SMPL* model go to [this](http://smpl.is.tue.mpg.de) (male and female models) and [this](http://smplify.is.tue.mpg.de) (gender neutral model) project website and register to get access to the downloads section.
65
+
66
+ ## Loading SMPL-X, SMPL+H and SMPL
67
+
68
+ ### SMPL and SMPL+H setup
69
+
70
+ The loader gives the option to use any of the SMPL-X, SMPL+H, SMPL, and MANO models. Depending on the model you want to use, please follow the respective download instructions. To switch between MANO, SMPL, SMPL+H and SMPL-X just change the *model_path* or *model_type* parameters. For more details please check the docs of the model classes.
71
+ Before using SMPL and SMPL+H you should follow the instructions in [tools/README.md](./tools/README.md) to remove the
72
+ Chumpy objects from both model pkls, as well as merge the MANO parameters with SMPL+H.
73
+
74
+ ### Model loading
75
+
76
+ You can either use the [create](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L54)
77
+ function from [body_models](./smplx/body_models.py) or directly call the constructor for the
78
+ [SMPL](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L106),
79
+ [SMPL+H](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L395) and
80
+ [SMPL-X](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L628) model. The path to the model can either be the path to the file with the parameters or a directory with the following structure:
81
+ ```bash
82
+ models
83
+ ├── smpl
84
+ │   ├── SMPL_FEMALE.pkl
85
+ │   ��── SMPL_MALE.pkl
86
+ │   └── SMPL_NEUTRAL.pkl
87
+ ├── smplh
88
+ │   ├── SMPLH_FEMALE.pkl
89
+ │   └── SMPLH_MALE.pkl
90
+ ├── mano
91
+ | ├── MANO_RIGHT.pkl
92
+ | └── MANO_LEFT.pkl
93
+ └── smplx
94
+ ├── SMPLX_FEMALE.npz
95
+ ├── SMPLX_FEMALE.pkl
96
+ ├── SMPLX_MALE.npz
97
+ ├── SMPLX_MALE.pkl
98
+ ├── SMPLX_NEUTRAL.npz
99
+ └── SMPLX_NEUTRAL.pkl
100
+ ```
101
+
102
+
103
+ ## MANO and FLAME correspondences
104
+
105
+ The vertex correspondences between SMPL-X and MANO, FLAME can be downloaded
106
+ from [the project website](https://smpl-x.is.tue.mpg.de). If you have extracted
107
+ the correspondence data in the folder *correspondences*, then use the following
108
+ scripts to visualize them:
109
+
110
+ 1. To view MANO correspondences run the following command:
111
+
112
+ ```
113
+ python examples/vis_mano_vertices.py --model-folder $SMPLX_FOLDER --corr-fname correspondences/MANO_SMPLX_vertex_ids.pkl
114
+ ```
115
+
116
+ 2. To view FLAME correspondences run the following command:
117
+
118
+ ```
119
+ python examples/vis_flame_vertices.py --model-folder $SMPLX_FOLDER --corr-fname correspondences/SMPL-X__FLAME_vertex_ids.npy
120
+ ```
121
+
122
+ ## Example
123
+
124
+ After installing the *smplx* package and downloading the model parameters you should be able to run the *demo.py*
125
+ script to visualize the results. For this step you have to install the [pyrender](https://pyrender.readthedocs.io/en/latest/index.html) and [trimesh](https://trimsh.org/) packages.
126
+
127
+ `python examples/demo.py --model-folder $SMPLX_FOLDER --plot-joints=True --gender="neutral"`
128
+
129
+ ![SMPL-X Examples](./images/example.png)
130
+
131
+ ## Citation
132
+
133
+ Depending on which model is loaded for your project, i.e. SMPL-X or SMPL+H or SMPL, please cite the most relevant work below, listed in the same order:
134
+
135
+ ```
136
+ @inproceedings{SMPL-X:2019,
137
+ title = {Expressive Body Capture: 3D Hands, Face, and Body from a Single Image},
138
+ author = {Pavlakos, Georgios and Choutas, Vasileios and Ghorbani, Nima and Bolkart, Timo and Osman, Ahmed A. A. and Tzionas, Dimitrios and Black, Michael J.},
139
+ booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)},
140
+ year = {2019}
141
+ }
142
+ ```
143
+
144
+ ```
145
+ @article{MANO:SIGGRAPHASIA:2017,
146
+ title = {Embodied Hands: Modeling and Capturing Hands and Bodies Together},
147
+ author = {Romero, Javier and Tzionas, Dimitrios and Black, Michael J.},
148
+ journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)},
149
+ volume = {36},
150
+ number = {6},
151
+ series = {245:1--245:17},
152
+ month = nov,
153
+ year = {2017},
154
+ month_numeric = {11}
155
+ }
156
+ ```
157
+
158
+ ```
159
+ @article{SMPL:2015,
160
+ author = {Loper, Matthew and Mahmood, Naureen and Romero, Javier and Pons-Moll, Gerard and Black, Michael J.},
161
+ title = {{SMPL}: A Skinned Multi-Person Linear Model},
162
+ journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)},
163
+ month = oct,
164
+ number = {6},
165
+ pages = {248:1--248:16},
166
+ publisher = {ACM},
167
+ volume = {34},
168
+ year = {2015}
169
+ }
170
+ ```
171
+
172
+ This repository was originally developed for SMPL-X / SMPLify-X (CVPR 2019), you might be interested in having a look: [https://smpl-x.is.tue.mpg.de](https://smpl-x.is.tue.mpg.de).
173
+
174
+ ## Acknowledgments
175
+
176
+ ### Facial Contour
177
+
178
+ Special thanks to [Soubhik Sanyal](https://github.com/soubhiksanyal) for sharing the Tensorflow code used for the facial
179
+ landmarks.
180
+
181
+ ## Contact
182
+ The code of this repository was implemented by [Vassilis Choutas](vassilis.choutas@tuebingen.mpg.de).
183
+
184
+ For questions, please contact [smplx@tue.mpg.de](smplx@tue.mpg.de).
185
+
186
+ For commercial licensing (and all related questions for business applications), please contact [ps-licensing@tue.mpg.de](ps-licensing@tue.mpg.de).
common/utils/smplx/examples/demo.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ import os.path as osp
18
+ import argparse
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ import smplx
24
+
25
+
26
+ def main(model_folder,
27
+ model_type='smplx',
28
+ ext='npz',
29
+ gender='neutral',
30
+ plot_joints=False,
31
+ num_betas=10,
32
+ sample_shape=True,
33
+ sample_expression=True,
34
+ num_expression_coeffs=10,
35
+ plotting_module='pyrender',
36
+ use_face_contour=False):
37
+
38
+ model = smplx.create(model_folder, model_type=model_type,
39
+ gender=gender, use_face_contour=use_face_contour,
40
+ num_betas=num_betas,
41
+ num_expression_coeffs=num_expression_coeffs,
42
+ ext=ext)
43
+ print(model)
44
+
45
+ betas, expression = None, None
46
+ if sample_shape:
47
+ betas = torch.randn([1, model.num_betas], dtype=torch.float32)
48
+ if sample_expression:
49
+ expression = torch.randn(
50
+ [1, model.num_expression_coeffs], dtype=torch.float32)
51
+
52
+ output = model(betas=betas, expression=expression,
53
+ return_verts=True)
54
+ vertices = output.vertices.detach().cpu().numpy().squeeze()
55
+ joints = output.joints.detach().cpu().numpy().squeeze()
56
+
57
+ print('Vertices shape =', vertices.shape)
58
+ print('Joints shape =', joints.shape)
59
+
60
+ if plotting_module == 'pyrender':
61
+ import pyrender
62
+ import trimesh
63
+ vertex_colors = np.ones([vertices.shape[0], 4]) * [0.3, 0.3, 0.3, 0.8]
64
+ tri_mesh = trimesh.Trimesh(vertices, model.faces,
65
+ vertex_colors=vertex_colors)
66
+
67
+ mesh = pyrender.Mesh.from_trimesh(tri_mesh)
68
+
69
+ scene = pyrender.Scene()
70
+ scene.add(mesh)
71
+
72
+ if plot_joints:
73
+ sm = trimesh.creation.uv_sphere(radius=0.005)
74
+ sm.visual.vertex_colors = [0.9, 0.1, 0.1, 1.0]
75
+ tfs = np.tile(np.eye(4), (len(joints), 1, 1))
76
+ tfs[:, :3, 3] = joints
77
+ joints_pcl = pyrender.Mesh.from_trimesh(sm, poses=tfs)
78
+ scene.add(joints_pcl)
79
+
80
+ pyrender.Viewer(scene, use_raymond_lighting=True)
81
+ elif plotting_module == 'matplotlib':
82
+ from matplotlib import pyplot as plt
83
+ from mpl_toolkits.mplot3d import Axes3D
84
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
85
+
86
+ fig = plt.figure()
87
+ ax = fig.add_subplot(111, projection='3d')
88
+
89
+ mesh = Poly3DCollection(vertices[model.faces], alpha=0.1)
90
+ face_color = (1.0, 1.0, 0.9)
91
+ edge_color = (0, 0, 0)
92
+ mesh.set_edgecolor(edge_color)
93
+ mesh.set_facecolor(face_color)
94
+ ax.add_collection3d(mesh)
95
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')
96
+
97
+ if plot_joints:
98
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], alpha=0.1)
99
+ plt.show()
100
+ elif plotting_module == 'open3d':
101
+ import open3d as o3d
102
+
103
+ mesh = o3d.geometry.TriangleMesh()
104
+ mesh.vertices = o3d.utility.Vector3dVector(
105
+ vertices)
106
+ mesh.triangles = o3d.utility.Vector3iVector(model.faces)
107
+ mesh.compute_vertex_normals()
108
+ mesh.paint_uniform_color([0.3, 0.3, 0.3])
109
+
110
+ geometry = [mesh]
111
+ if plot_joints:
112
+ joints_pcl = o3d.geometry.PointCloud()
113
+ joints_pcl.points = o3d.utility.Vector3dVector(joints)
114
+ joints_pcl.paint_uniform_color([0.7, 0.3, 0.3])
115
+ geometry.append(joints_pcl)
116
+
117
+ o3d.visualization.draw_geometries(geometry)
118
+ else:
119
+ raise ValueError('Unknown plotting_module: {}'.format(plotting_module))
120
+
121
+
122
+ if __name__ == '__main__':
123
+ parser = argparse.ArgumentParser(description='SMPL-X Demo')
124
+
125
+ parser.add_argument('--model-folder', required=True, type=str,
126
+ help='The path to the model folder')
127
+ parser.add_argument('--model-type', default='smplx', type=str,
128
+ choices=['smpl', 'smplh', 'smplx', 'mano', 'flame'],
129
+ help='The type of model to load')
130
+ parser.add_argument('--gender', type=str, default='neutral',
131
+ help='The gender of the model')
132
+ parser.add_argument('--num-betas', default=10, type=int,
133
+ dest='num_betas',
134
+ help='Number of shape coefficients.')
135
+ parser.add_argument('--num-expression-coeffs', default=10, type=int,
136
+ dest='num_expression_coeffs',
137
+ help='Number of expression coefficients.')
138
+ parser.add_argument('--plotting-module', type=str, default='pyrender',
139
+ dest='plotting_module',
140
+ choices=['pyrender', 'matplotlib', 'open3d'],
141
+ help='The module to use for plotting the result')
142
+ parser.add_argument('--ext', type=str, default='npz',
143
+ help='Which extension to use for loading')
144
+ parser.add_argument('--plot-joints', default=False,
145
+ type=lambda arg: arg.lower() in ['true', '1'],
146
+ help='The path to the model folder')
147
+ parser.add_argument('--sample-shape', default=True,
148
+ dest='sample_shape',
149
+ type=lambda arg: arg.lower() in ['true', '1'],
150
+ help='Sample a random shape')
151
+ parser.add_argument('--sample-expression', default=True,
152
+ dest='sample_expression',
153
+ type=lambda arg: arg.lower() in ['true', '1'],
154
+ help='Sample a random expression')
155
+ parser.add_argument('--use-face-contour', default=False,
156
+ type=lambda arg: arg.lower() in ['true', '1'],
157
+ help='Compute the contour of the face')
158
+
159
+ args = parser.parse_args()
160
+
161
+ model_folder = osp.expanduser(osp.expandvars(args.model_folder))
162
+ model_type = args.model_type
163
+ plot_joints = args.plot_joints
164
+ use_face_contour = args.use_face_contour
165
+ gender = args.gender
166
+ ext = args.ext
167
+ plotting_module = args.plotting_module
168
+ num_betas = args.num_betas
169
+ num_expression_coeffs = args.num_expression_coeffs
170
+ sample_shape = args.sample_shape
171
+ sample_expression = args.sample_expression
172
+
173
+ main(model_folder, model_type, ext=ext,
174
+ gender=gender, plot_joints=plot_joints,
175
+ num_betas=num_betas,
176
+ num_expression_coeffs=num_expression_coeffs,
177
+ sample_shape=sample_shape,
178
+ sample_expression=sample_expression,
179
+ plotting_module=plotting_module,
180
+ use_face_contour=use_face_contour)
common/utils/smplx/examples/demo_layers.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ import os.path as osp
18
+ import argparse
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ import smplx
24
+
25
+
26
+ def main(model_folder,
27
+ model_type='smplx',
28
+ ext='npz',
29
+ gender='neutral',
30
+ plot_joints=False,
31
+ num_betas=10,
32
+ sample_shape=True,
33
+ sample_expression=True,
34
+ num_expression_coeffs=10,
35
+ plotting_module='pyrender',
36
+ use_face_contour=False):
37
+
38
+ model = smplx.build_layer(
39
+ model_folder, model_type=model_type,
40
+ gender=gender, use_face_contour=use_face_contour,
41
+ num_betas=num_betas,
42
+ num_expression_coeffs=num_expression_coeffs,
43
+ ext=ext)
44
+ print(model)
45
+
46
+ betas, expression = None, None
47
+ if sample_shape:
48
+ betas = torch.randn([1, model.num_betas], dtype=torch.float32)
49
+ if sample_expression:
50
+ expression = torch.randn(
51
+ [1, model.num_expression_coeffs], dtype=torch.float32)
52
+
53
+ output = model(betas=betas, expression=expression,
54
+ return_verts=True)
55
+ vertices = output.vertices.detach().cpu().numpy().squeeze()
56
+ joints = output.joints.detach().cpu().numpy().squeeze()
57
+
58
+ print('Vertices shape =', vertices.shape)
59
+ print('Joints shape =', joints.shape)
60
+
61
+ if plotting_module == 'pyrender':
62
+ import pyrender
63
+ import trimesh
64
+ vertex_colors = np.ones([vertices.shape[0], 4]) * [0.3, 0.3, 0.3, 0.8]
65
+ tri_mesh = trimesh.Trimesh(vertices, model.faces,
66
+ vertex_colors=vertex_colors)
67
+
68
+ mesh = pyrender.Mesh.from_trimesh(tri_mesh)
69
+
70
+ scene = pyrender.Scene()
71
+ scene.add(mesh)
72
+
73
+ if plot_joints:
74
+ sm = trimesh.creation.uv_sphere(radius=0.005)
75
+ sm.visual.vertex_colors = [0.9, 0.1, 0.1, 1.0]
76
+ tfs = np.tile(np.eye(4), (len(joints), 1, 1))
77
+ tfs[:, :3, 3] = joints
78
+ joints_pcl = pyrender.Mesh.from_trimesh(sm, poses=tfs)
79
+ scene.add(joints_pcl)
80
+
81
+ pyrender.Viewer(scene, use_raymond_lighting=True)
82
+ elif plotting_module == 'matplotlib':
83
+ from matplotlib import pyplot as plt
84
+ from mpl_toolkits.mplot3d import Axes3D
85
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
86
+
87
+ fig = plt.figure()
88
+ ax = fig.add_subplot(111, projection='3d')
89
+
90
+ mesh = Poly3DCollection(vertices[model.faces], alpha=0.1)
91
+ face_color = (1.0, 1.0, 0.9)
92
+ edge_color = (0, 0, 0)
93
+ mesh.set_edgecolor(edge_color)
94
+ mesh.set_facecolor(face_color)
95
+ ax.add_collection3d(mesh)
96
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')
97
+
98
+ if plot_joints:
99
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], alpha=0.1)
100
+ plt.show()
101
+ elif plotting_module == 'open3d':
102
+ import open3d as o3d
103
+
104
+ mesh = o3d.geometry.TriangleMesh()
105
+ mesh.vertices = o3d.utility.Vector3dVector(
106
+ vertices)
107
+ mesh.triangles = o3d.utility.Vector3iVector(model.faces)
108
+ mesh.compute_vertex_normals()
109
+ mesh.paint_uniform_color([0.3, 0.3, 0.3])
110
+
111
+ geometry = [mesh]
112
+ if plot_joints:
113
+ joints_pcl = o3d.geometry.PointCloud()
114
+ joints_pcl.points = o3d.utility.Vector3dVector(joints)
115
+ joints_pcl.paint_uniform_color([0.7, 0.3, 0.3])
116
+ geometry.append(joints_pcl)
117
+
118
+ o3d.visualization.draw_geometries(geometry)
119
+ else:
120
+ raise ValueError('Unknown plotting_module: {}'.format(plotting_module))
121
+
122
+
123
+ if __name__ == '__main__':
124
+ parser = argparse.ArgumentParser(description='SMPL-X Demo')
125
+
126
+ parser.add_argument('--model-folder', required=True, type=str,
127
+ help='The path to the model folder')
128
+ parser.add_argument('--model-type', default='smplx', type=str,
129
+ choices=['smpl', 'smplh', 'smplx', 'mano', 'flame'],
130
+ help='The type of model to load')
131
+ parser.add_argument('--gender', type=str, default='neutral',
132
+ help='The gender of the model')
133
+ parser.add_argument('--num-betas', default=10, type=int,
134
+ dest='num_betas',
135
+ help='Number of shape coefficients.')
136
+ parser.add_argument('--num-expression-coeffs', default=10, type=int,
137
+ dest='num_expression_coeffs',
138
+ help='Number of expression coefficients.')
139
+ parser.add_argument('--plotting-module', type=str, default='pyrender',
140
+ dest='plotting_module',
141
+ choices=['pyrender', 'matplotlib', 'open3d'],
142
+ help='The module to use for plotting the result')
143
+ parser.add_argument('--ext', type=str, default='npz',
144
+ help='Which extension to use for loading')
145
+ parser.add_argument('--plot-joints', default=False,
146
+ type=lambda arg: arg.lower() in ['true', '1'],
147
+ help='The path to the model folder')
148
+ parser.add_argument('--sample-shape', default=True,
149
+ dest='sample_shape',
150
+ type=lambda arg: arg.lower() in ['true', '1'],
151
+ help='Sample a random shape')
152
+ parser.add_argument('--sample-expression', default=True,
153
+ dest='sample_expression',
154
+ type=lambda arg: arg.lower() in ['true', '1'],
155
+ help='Sample a random expression')
156
+ parser.add_argument('--use-face-contour', default=False,
157
+ type=lambda arg: arg.lower() in ['true', '1'],
158
+ help='Compute the contour of the face')
159
+
160
+ args = parser.parse_args()
161
+
162
+ model_folder = osp.expanduser(osp.expandvars(args.model_folder))
163
+ model_type = args.model_type
164
+ plot_joints = args.plot_joints
165
+ use_face_contour = args.use_face_contour
166
+ gender = args.gender
167
+ ext = args.ext
168
+ plotting_module = args.plotting_module
169
+ num_betas = args.num_betas
170
+ num_expression_coeffs = args.num_expression_coeffs
171
+ sample_shape = args.sample_shape
172
+ sample_expression = args.sample_expression
173
+
174
+ main(model_folder, model_type, ext=ext,
175
+ gender=gender, plot_joints=plot_joints,
176
+ num_betas=num_betas,
177
+ num_expression_coeffs=num_expression_coeffs,
178
+ sample_shape=sample_shape,
179
+ sample_expression=sample_expression,
180
+ plotting_module=plotting_module,
181
+ use_face_contour=use_face_contour)
common/utils/smplx/examples/vis_flame_vertices.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ import os.path as osp
18
+ import argparse
19
+ import pickle
20
+
21
+ import numpy as np
22
+ import torch
23
+ import open3d as o3d
24
+
25
+ import smplx
26
+
27
+
28
+ def main(model_folder, corr_fname, ext='npz',
29
+ head_color=(0.3, 0.3, 0.6),
30
+ gender='neutral'):
31
+
32
+ head_idxs = np.load(corr_fname)
33
+
34
+ model = smplx.create(model_folder, model_type='smplx',
35
+ gender=gender,
36
+ ext=ext)
37
+ betas = torch.zeros([1, 10], dtype=torch.float32)
38
+ expression = torch.zeros([1, 10], dtype=torch.float32)
39
+
40
+ output = model(betas=betas, expression=expression,
41
+ return_verts=True)
42
+ vertices = output.vertices.detach().cpu().numpy().squeeze()
43
+ joints = output.joints.detach().cpu().numpy().squeeze()
44
+
45
+ print('Vertices shape =', vertices.shape)
46
+ print('Joints shape =', joints.shape)
47
+
48
+ mesh = o3d.geometry.TriangleMesh()
49
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
50
+ mesh.triangles = o3d.utility.Vector3iVector(model.faces)
51
+ mesh.compute_vertex_normals()
52
+
53
+ colors = np.ones_like(vertices) * [0.3, 0.3, 0.3]
54
+ colors[head_idxs] = head_color
55
+
56
+ mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
57
+
58
+ o3d.visualization.draw_geometries([mesh])
59
+
60
+
61
+ if __name__ == '__main__':
62
+ parser = argparse.ArgumentParser(description='SMPL-X Demo')
63
+
64
+ parser.add_argument('--model-folder', required=True, type=str,
65
+ help='The path to the model folder')
66
+ parser.add_argument('--corr-fname', required=True, type=str,
67
+ dest='corr_fname',
68
+ help='Filename with the head correspondences')
69
+ parser.add_argument('--gender', type=str, default='neutral',
70
+ help='The gender of the model')
71
+ parser.add_argument('--ext', type=str, default='npz',
72
+ help='Which extension to use for loading')
73
+ parser.add_argument('--head', default='right',
74
+ choices=['right', 'left'],
75
+ type=str, help='Which head to plot')
76
+ parser.add_argument('--head-color', type=float, nargs=3, dest='head_color',
77
+ default=(0.3, 0.3, 0.6),
78
+ help='Color for the head vertices')
79
+
80
+ args = parser.parse_args()
81
+
82
+ model_folder = osp.expanduser(osp.expandvars(args.model_folder))
83
+ corr_fname = args.corr_fname
84
+ gender = args.gender
85
+ ext = args.ext
86
+ head = args.head
87
+ head_color = args.head_color
88
+
89
+ main(model_folder, corr_fname, ext=ext,
90
+ head_color=head_color,
91
+ gender=gender
92
+ )
common/utils/smplx/examples/vis_mano_vertices.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ import os.path as osp
18
+ import argparse
19
+ import pickle
20
+
21
+ import numpy as np
22
+ import torch
23
+ import open3d as o3d
24
+
25
+ import smplx
26
+
27
+
28
+ def main(model_folder, corr_fname, ext='npz',
29
+ hand_color=(0.3, 0.3, 0.6),
30
+ gender='neutral', hand='right'):
31
+
32
+ with open(corr_fname, 'rb') as f:
33
+ idxs_data = pickle.load(f)
34
+ if hand == 'both':
35
+ hand_idxs = np.concatenate(
36
+ [idxs_data['left_hand'], idxs_data['right_hand']]
37
+ )
38
+ else:
39
+ hand_idxs = idxs_data[f'{hand}_hand']
40
+
41
+ model = smplx.create(model_folder, model_type='smplx',
42
+ gender=gender,
43
+ ext=ext)
44
+ betas = torch.zeros([1, 10], dtype=torch.float32)
45
+ expression = torch.zeros([1, 10], dtype=torch.float32)
46
+
47
+ output = model(betas=betas, expression=expression,
48
+ return_verts=True)
49
+ vertices = output.vertices.detach().cpu().numpy().squeeze()
50
+ joints = output.joints.detach().cpu().numpy().squeeze()
51
+
52
+ print('Vertices shape =', vertices.shape)
53
+ print('Joints shape =', joints.shape)
54
+
55
+ mesh = o3d.geometry.TriangleMesh()
56
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
57
+ mesh.triangles = o3d.utility.Vector3iVector(model.faces)
58
+ mesh.compute_vertex_normals()
59
+
60
+ colors = np.ones_like(vertices) * [0.3, 0.3, 0.3]
61
+ colors[hand_idxs] = hand_color
62
+
63
+ mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
64
+
65
+ o3d.visualization.draw_geometries([mesh])
66
+
67
+
68
+ if __name__ == '__main__':
69
+ parser = argparse.ArgumentParser(description='SMPL-X Demo')
70
+
71
+ parser.add_argument('--model-folder', required=True, type=str,
72
+ help='The path to the model folder')
73
+ parser.add_argument('--corr-fname', required=True, type=str,
74
+ dest='corr_fname',
75
+ help='Filename with the hand correspondences')
76
+ parser.add_argument('--gender', type=str, default='neutral',
77
+ help='The gender of the model')
78
+ parser.add_argument('--ext', type=str, default='npz',
79
+ help='Which extension to use for loading')
80
+ parser.add_argument('--hand', default='right',
81
+ choices=['right', 'left', 'both'],
82
+ type=str, help='Which hand to plot')
83
+ parser.add_argument('--hand-color', type=float, nargs=3, dest='hand_color',
84
+ default=(0.3, 0.3, 0.6),
85
+ help='Color for the hand vertices')
86
+
87
+ args = parser.parse_args()
88
+
89
+ model_folder = osp.expanduser(osp.expandvars(args.model_folder))
90
+ corr_fname = args.corr_fname
91
+ gender = args.gender
92
+ ext = args.ext
93
+ hand = args.hand
94
+ hand_color = args.hand_color
95
+
96
+ main(model_folder, corr_fname, ext=ext,
97
+ hand_color=hand_color,
98
+ gender=gender, hand=hand
99
+ )
common/utils/smplx/setup.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ import io
19
+ import os
20
+
21
+ from setuptools import setup
22
+
23
+ # Package meta-data.
24
+ NAME = 'smplx'
25
+ DESCRIPTION = 'PyTorch module for loading the SMPLX body model'
26
+ URL = 'http://smpl-x.is.tuebingen.mpg.de'
27
+ EMAIL = 'vassilis.choutas@tuebingen.mpg.de'
28
+ AUTHOR = 'Vassilis Choutas'
29
+ REQUIRES_PYTHON = '>=3.6.0'
30
+ VERSION = '0.1.21'
31
+
32
+ here = os.path.abspath(os.path.dirname(__file__))
33
+
34
+ try:
35
+ FileNotFoundError
36
+ except NameError:
37
+ FileNotFoundError = IOError
38
+
39
+ # Import the README and use it as the long-description.
40
+ # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
41
+ try:
42
+ with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
43
+ long_description = '\n' + f.read()
44
+ except FileNotFoundError:
45
+ long_description = DESCRIPTION
46
+
47
+ # Load the package's __version__.py module as a dictionary.
48
+ about = {}
49
+ if not VERSION:
50
+ with open(os.path.join(here, NAME, '__version__.py')) as f:
51
+ exec(f.read(), about)
52
+ else:
53
+ about['__version__'] = VERSION
54
+
55
+ pyrender_reqs = ['pyrender>=0.1.23', 'trimesh>=2.37.6', 'shapely']
56
+ matplotlib_reqs = ['matplotlib']
57
+ open3d_reqs = ['open3d-python']
58
+
59
+ setup(name=NAME,
60
+ version=about['__version__'],
61
+ description=DESCRIPTION,
62
+ long_description=long_description,
63
+ long_description_content_type='text/markdown',
64
+ author=AUTHOR,
65
+ author_email=EMAIL,
66
+ python_requires=REQUIRES_PYTHON,
67
+ url=URL,
68
+ install_requires=[
69
+ 'numpy>=1.16.2',
70
+ 'torch>=1.0.1.post2',
71
+ 'torchgeometry>=0.1.2'
72
+ ],
73
+ extras_require={
74
+ 'pyrender': pyrender_reqs,
75
+ 'open3d': open3d_reqs,
76
+ 'matplotlib': matplotlib_reqs,
77
+ 'all': pyrender_reqs + matplotlib_reqs + open3d_reqs
78
+ },
79
+ packages=['smplx', 'tools'])
common/utils/smplx/smplx/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from .body_models import (
18
+ create,
19
+ SMPL,
20
+ SMPLH,
21
+ SMPLX,
22
+ MANO,
23
+ FLAME,
24
+ build_layer,
25
+ SMPLLayer,
26
+ SMPLHLayer,
27
+ SMPLXLayer,
28
+ MANOLayer,
29
+ FLAMELayer,
30
+ )
common/utils/smplx/smplx/body_models.py ADDED
@@ -0,0 +1,2331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from typing import Optional, Dict, Union
18
+ import os
19
+ import os.path as osp
20
+
21
+ import pickle
22
+
23
+ import numpy as np
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+ from .lbs import (
29
+ lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords)
30
+
31
+ from .vertex_ids import vertex_ids as VERTEX_IDS
32
+ from .utils import (
33
+ Struct, to_np, to_tensor, Tensor, Array,
34
+ SMPLOutput,
35
+ SMPLHOutput,
36
+ SMPLXOutput,
37
+ MANOOutput,
38
+ FLAMEOutput,
39
+ find_joint_kin_chain)
40
+ from .vertex_joint_selector import VertexJointSelector
41
+ from config import cfg
42
+
43
+ class SMPL(nn.Module):
44
+
45
+ NUM_JOINTS = 23
46
+ NUM_BODY_JOINTS = 23
47
+ SHAPE_SPACE_DIM = 300
48
+
49
+ def __init__(
50
+ self, model_path: str,
51
+ data_struct: Optional[Struct] = None,
52
+ create_betas: bool = True,
53
+ betas: Optional[Tensor] = None,
54
+ num_betas: int = 10,
55
+ create_global_orient: bool = True,
56
+ global_orient: Optional[Tensor] = None,
57
+ create_body_pose: bool = True,
58
+ body_pose: Optional[Tensor] = None,
59
+ create_transl: bool = True,
60
+ transl: Optional[Tensor] = None,
61
+ dtype=torch.float32,
62
+ batch_size: int = 1,
63
+ joint_mapper=None,
64
+ gender: str = 'neutral',
65
+ vertex_ids: Dict[str, int] = None,
66
+ v_template: Optional[Union[Tensor, Array]] = None,
67
+ **kwargs
68
+ ) -> None:
69
+ ''' SMPL model constructor
70
+
71
+ Parameters
72
+ ----------
73
+ model_path: str
74
+ The path to the folder or to the file where the model
75
+ parameters are stored
76
+ data_struct: Strct
77
+ A struct object. If given, then the parameters of the model are
78
+ read from the object. Otherwise, the model tries to read the
79
+ parameters from the given `model_path`. (default = None)
80
+ create_global_orient: bool, optional
81
+ Flag for creating a member variable for the global orientation
82
+ of the body. (default = True)
83
+ global_orient: torch.tensor, optional, Bx3
84
+ The default value for the global orientation variable.
85
+ (default = None)
86
+ create_body_pose: bool, optional
87
+ Flag for creating a member variable for the pose of the body.
88
+ (default = True)
89
+ body_pose: torch.tensor, optional, Bx(Body Joints * 3)
90
+ The default value for the body pose variable.
91
+ (default = None)
92
+ num_betas: int, optional
93
+ Number of shape components to use
94
+ (default = 10).
95
+ create_betas: bool, optional
96
+ Flag for creating a member variable for the shape space
97
+ (default = True).
98
+ betas: torch.tensor, optional, Bx10
99
+ The default value for the shape member variable.
100
+ (default = None)
101
+ create_transl: bool, optional
102
+ Flag for creating a member variable for the translation
103
+ of the body. (default = True)
104
+ transl: torch.tensor, optional, Bx3
105
+ The default value for the transl variable.
106
+ (default = None)
107
+ dtype: torch.dtype, optional
108
+ The data type for the created variables
109
+ batch_size: int, optional
110
+ The batch size used for creating the member variables
111
+ joint_mapper: object, optional
112
+ An object that re-maps the joints. Useful if one wants to
113
+ re-order the SMPL joints to some other convention (e.g. MSCOCO)
114
+ (default = None)
115
+ gender: str, optional
116
+ Which gender to load
117
+ vertex_ids: dict, optional
118
+ A dictionary containing the indices of the extra vertices that
119
+ will be selected
120
+ '''
121
+
122
+ self.gender = gender
123
+
124
+ if data_struct is None:
125
+ if osp.isdir(model_path):
126
+ model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')
127
+ smpl_path = os.path.join(model_path, model_fn)
128
+ else:
129
+ smpl_path = model_path
130
+ assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
131
+ smpl_path)
132
+
133
+ with open(smpl_path, 'rb') as smpl_file:
134
+ data_struct = Struct(**pickle.load(smpl_file,
135
+ encoding='latin1'))
136
+
137
+ super(SMPL, self).__init__()
138
+ self.batch_size = batch_size
139
+ shapedirs = data_struct.shapedirs
140
+ if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM):
141
+ print(f'WARNING: You are using a {self.name()} model, with only'
142
+ ' 10 shape coefficients.')
143
+ num_betas = min(num_betas, 10)
144
+ else:
145
+ num_betas = min(num_betas, self.SHAPE_SPACE_DIM)
146
+
147
+ self._num_betas = num_betas
148
+ shapedirs = shapedirs[:, :, :num_betas]
149
+ # The shape components
150
+ self.register_buffer(
151
+ 'shapedirs',
152
+ to_tensor(to_np(shapedirs), dtype=dtype))
153
+
154
+ if vertex_ids is None:
155
+ # SMPL and SMPL-H share the same topology, so any extra joints can
156
+ # be drawn from the same place
157
+ vertex_ids = VERTEX_IDS['smplh']
158
+
159
+ self.dtype = dtype
160
+
161
+ self.joint_mapper = joint_mapper
162
+
163
+ self.vertex_joint_selector = VertexJointSelector(
164
+ vertex_ids=vertex_ids, **kwargs)
165
+
166
+ self.faces = data_struct.f
167
+ self.register_buffer('faces_tensor',
168
+ to_tensor(to_np(self.faces, dtype=np.int64),
169
+ dtype=torch.long))
170
+
171
+ if create_betas:
172
+ if betas is None:
173
+ default_betas = torch.zeros(
174
+ [batch_size, self.num_betas], dtype=dtype)
175
+ else:
176
+ if torch.is_tensor(betas):
177
+ default_betas = betas.clone().detach()
178
+ else:
179
+ default_betas = torch.tensor(betas, dtype=dtype)
180
+
181
+ self.register_parameter(
182
+ 'betas', nn.Parameter(default_betas, requires_grad=True))
183
+
184
+ # The tensor that contains the global rotation of the model
185
+ # It is separated from the pose of the joints in case we wish to
186
+ # optimize only over one of them
187
+ if create_global_orient:
188
+ if global_orient is None:
189
+ default_global_orient = torch.zeros(
190
+ [batch_size, 3], dtype=dtype)
191
+ else:
192
+ if torch.is_tensor(global_orient):
193
+ default_global_orient = global_orient.clone().detach()
194
+ else:
195
+ default_global_orient = torch.tensor(
196
+ global_orient, dtype=dtype)
197
+
198
+ global_orient = nn.Parameter(default_global_orient,
199
+ requires_grad=True)
200
+ self.register_parameter('global_orient', global_orient)
201
+
202
+ if create_body_pose:
203
+ if body_pose is None:
204
+ default_body_pose = torch.zeros(
205
+ [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
206
+ else:
207
+ if torch.is_tensor(body_pose):
208
+ default_body_pose = body_pose.clone().detach()
209
+ else:
210
+ default_body_pose = torch.tensor(body_pose,
211
+ dtype=dtype)
212
+ self.register_parameter(
213
+ 'body_pose',
214
+ nn.Parameter(default_body_pose, requires_grad=True))
215
+
216
+ if create_transl:
217
+ if transl is None:
218
+ default_transl = torch.zeros([batch_size, 3],
219
+ dtype=dtype,
220
+ requires_grad=True)
221
+ else:
222
+ default_transl = torch.tensor(transl, dtype=dtype)
223
+ self.register_parameter(
224
+ 'transl', nn.Parameter(default_transl, requires_grad=True))
225
+
226
+ if v_template is None:
227
+ v_template = data_struct.v_template
228
+ if not torch.is_tensor(v_template):
229
+ v_template = to_tensor(to_np(v_template), dtype=dtype)
230
+ # The vertices of the template model
231
+ self.register_buffer('v_template', v_template)
232
+
233
+ j_regressor = to_tensor(to_np(
234
+ data_struct.J_regressor), dtype=dtype)
235
+ self.register_buffer('J_regressor', j_regressor)
236
+
237
+ # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
238
+ num_pose_basis = data_struct.posedirs.shape[-1]
239
+ # 207 x 20670
240
+ posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
241
+ self.register_buffer('posedirs',
242
+ to_tensor(to_np(posedirs), dtype=dtype))
243
+
244
+ # indices of parents for each joints
245
+ parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
246
+ parents[0] = -1
247
+ self.register_buffer('parents', parents)
248
+
249
+ self.register_buffer(
250
+ 'lbs_weights', to_tensor(to_np(data_struct.weights), dtype=dtype))
251
+
252
+ @property
253
+ def num_betas(self):
254
+ return self._num_betas
255
+
256
+ @property
257
+ def num_expression_coeffs(self):
258
+ return 0
259
+
260
+ def create_mean_pose(self, data_struct) -> Tensor:
261
+ pass
262
+
263
+ def name(self) -> str:
264
+ return 'SMPL'
265
+
266
+ @torch.no_grad()
267
+ def reset_params(self, **params_dict) -> None:
268
+ for param_name, param in self.named_parameters():
269
+ if param_name in params_dict:
270
+ param[:] = torch.tensor(params_dict[param_name])
271
+ else:
272
+ param.fill_(0)
273
+
274
+ def get_num_verts(self) -> int:
275
+ return self.v_template.shape[0]
276
+
277
+ def get_num_faces(self) -> int:
278
+ return self.faces.shape[0]
279
+
280
+ def extra_repr(self) -> str:
281
+ msg = [
282
+ f'Gender: {self.gender.upper()}',
283
+ f'Number of joints: {self.J_regressor.shape[0]}',
284
+ f'Betas: {self.num_betas}',
285
+ ]
286
+ return '\n'.join(msg)
287
+
288
+ def forward(
289
+ self,
290
+ betas: Optional[Tensor] = None,
291
+ body_pose: Optional[Tensor] = None,
292
+ global_orient: Optional[Tensor] = None,
293
+ transl: Optional[Tensor] = None,
294
+ return_verts=True,
295
+ return_full_pose: bool = False,
296
+ pose2rot: bool = True,
297
+ **kwargs
298
+ ) -> SMPLOutput:
299
+ ''' Forward pass for the SMPL model
300
+
301
+ Parameters
302
+ ----------
303
+ global_orient: torch.tensor, optional, shape Bx3
304
+ If given, ignore the member variable and use it as the global
305
+ rotation of the body. Useful if someone wishes to predicts this
306
+ with an external model. (default=None)
307
+ betas: torch.tensor, optional, shape Bx10
308
+ If given, ignore the member variable `betas` and use it
309
+ instead. For example, it can used if shape parameters
310
+ `betas` are predicted from some external model.
311
+ (default=None)
312
+ body_pose: torch.tensor, optional, shape Bx(J*3)
313
+ If given, ignore the member variable `body_pose` and use it
314
+ instead. For example, it can used if someone predicts the
315
+ pose of the body joints are predicted from some external model.
316
+ It should be a tensor that contains joint rotations in
317
+ axis-angle format. (default=None)
318
+ transl: torch.tensor, optional, shape Bx3
319
+ If given, ignore the member variable `transl` and use it
320
+ instead. For example, it can used if the translation
321
+ `transl` is predicted from some external model.
322
+ (default=None)
323
+ return_verts: bool, optional
324
+ Return the vertices. (default=True)
325
+ return_full_pose: bool, optional
326
+ Returns the full axis-angle pose vector (default=False)
327
+
328
+ Returns
329
+ -------
330
+ '''
331
+ # If no shape and pose parameters are passed along, then use the
332
+ # ones from the module
333
+ global_orient = (global_orient if global_orient is not None else
334
+ self.global_orient)
335
+ body_pose = body_pose if body_pose is not None else self.body_pose
336
+ betas = betas if betas is not None else self.betas
337
+
338
+ apply_trans = transl is not None or hasattr(self, 'transl')
339
+ if transl is None and hasattr(self, 'transl'):
340
+ transl = self.transl
341
+
342
+ full_pose = torch.cat([global_orient, body_pose], dim=1)
343
+
344
+ batch_size = max(betas.shape[0], global_orient.shape[0],
345
+ body_pose.shape[0])
346
+
347
+ if betas.shape[0] != batch_size:
348
+ num_repeats = int(batch_size / betas.shape[0])
349
+ betas = betas.expand(num_repeats, -1)
350
+
351
+ vertices, joints = lbs(betas, full_pose, self.v_template,
352
+ self.shapedirs, self.posedirs,
353
+ self.J_regressor, self.parents,
354
+ self.lbs_weights, pose2rot=pose2rot)
355
+
356
+ joints = self.vertex_joint_selector(vertices, joints)
357
+ # Map the joints to the current dataset
358
+ if self.joint_mapper is not None:
359
+ joints = self.joint_mapper(joints)
360
+
361
+ if apply_trans:
362
+ joints += transl.unsqueeze(dim=1)
363
+ vertices += transl.unsqueeze(dim=1)
364
+
365
+ output = SMPLOutput(vertices=vertices if return_verts else None,
366
+ global_orient=global_orient,
367
+ body_pose=body_pose,
368
+ joints=joints,
369
+ betas=betas,
370
+ full_pose=full_pose if return_full_pose else None)
371
+
372
+ return output
373
+
374
+
375
+ class SMPLLayer(SMPL):
376
+ def __init__(
377
+ self,
378
+ *args,
379
+ **kwargs
380
+ ) -> None:
381
+ # Just create a SMPL module without any member variables
382
+ super(SMPLLayer, self).__init__(
383
+ create_body_pose=False,
384
+ create_betas=False,
385
+ create_global_orient=False,
386
+ create_transl=False,
387
+ *args,
388
+ **kwargs,
389
+ )
390
+
391
+ def forward(
392
+ self,
393
+ betas: Optional[Tensor] = None,
394
+ body_pose: Optional[Tensor] = None,
395
+ global_orient: Optional[Tensor] = None,
396
+ transl: Optional[Tensor] = None,
397
+ return_verts=True,
398
+ return_full_pose: bool = False,
399
+ pose2rot: bool = True,
400
+ **kwargs
401
+ ) -> SMPLOutput:
402
+ ''' Forward pass for the SMPL model
403
+
404
+ Parameters
405
+ ----------
406
+ global_orient: torch.tensor, optional, shape Bx3
407
+ If given, ignore the member variable and use it as the global
408
+ rotation of the body. Useful if someone wishes to predicts this
409
+ with an external model. (default=None)
410
+ betas: torch.tensor, optional, shape Bx10
411
+ If given, ignore the member variable `betas` and use it
412
+ instead. For example, it can used if shape parameters
413
+ `betas` are predicted from some external model.
414
+ (default=None)
415
+ body_pose: torch.tensor, optional, shape Bx(J*3)
416
+ If given, ignore the member variable `body_pose` and use it
417
+ instead. For example, it can used if someone predicts the
418
+ pose of the body joints are predicted from some external model.
419
+ It should be a tensor that contains joint rotations in
420
+ axis-angle format. (default=None)
421
+ transl: torch.tensor, optional, shape Bx3
422
+ If given, ignore the member variable `transl` and use it
423
+ instead. For example, it can used if the translation
424
+ `transl` is predicted from some external model.
425
+ (default=None)
426
+ return_verts: bool, optional
427
+ Return the vertices. (default=True)
428
+ return_full_pose: bool, optional
429
+ Returns the full axis-angle pose vector (default=False)
430
+
431
+ Returns
432
+ -------
433
+ '''
434
+ device, dtype = self.shapedirs.device, self.shapedirs.dtype
435
+ if global_orient is None:
436
+ batch_size = 1
437
+ global_orient = torch.zeros(3, device=device, dtype=dtype).view(
438
+ 1, 1, 3).expand(batch_size, 1, 1).contiguous()
439
+ else:
440
+ batch_size = global_orient.shape[0]
441
+ if body_pose is None:
442
+ body_pose = torch.zeros(3, device=device, dtype=dtype).view(
443
+ 1, 1, 3).expand(
444
+ batch_size, self.NUM_BODY_JOINTS, 1).contiguous()
445
+ if betas is None:
446
+ betas = torch.zeros([batch_size, self.num_betas],
447
+ dtype=dtype, device=device)
448
+ if transl is None:
449
+ transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
450
+ full_pose = torch.cat(
451
+ [global_orient.reshape(-1, 1, 3),
452
+ body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3)],
453
+ dim=1)
454
+
455
+ vertices, joints = lbs(betas, full_pose, self.v_template,
456
+ self.shapedirs, self.posedirs,
457
+ self.J_regressor, self.parents,
458
+ self.lbs_weights,
459
+ pose2rot=True)
460
+
461
+ joints = self.vertex_joint_selector(vertices, joints)
462
+ # Map the joints to the current dataset
463
+ if self.joint_mapper is not None:
464
+ joints = self.joint_mapper(joints)
465
+
466
+ if transl is not None:
467
+ joints += transl.unsqueeze(dim=1)
468
+ vertices += transl.unsqueeze(dim=1)
469
+
470
+ output = SMPLOutput(vertices=vertices if return_verts else None,
471
+ global_orient=global_orient,
472
+ body_pose=body_pose,
473
+ joints=joints,
474
+ betas=betas,
475
+ full_pose=full_pose if return_full_pose else None)
476
+
477
+ return output
478
+
479
+
480
+ class SMPLH(SMPL):
481
+
482
+ # The hand joints are replaced by MANO
483
+ NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2
484
+ NUM_HAND_JOINTS = 15
485
+ NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS
486
+
487
+ def __init__(
488
+ self, model_path,
489
+ data_struct: Optional[Struct] = None,
490
+ create_left_hand_pose: bool = True,
491
+ left_hand_pose: Optional[Tensor] = None,
492
+ create_right_hand_pose: bool = True,
493
+ right_hand_pose: Optional[Tensor] = None,
494
+ use_pca: bool = True,
495
+ num_pca_comps: int = 6,
496
+ flat_hand_mean: bool = False,
497
+ batch_size: int = 1,
498
+ gender: str = 'neutral',
499
+ dtype=torch.float32,
500
+ vertex_ids=None,
501
+ use_compressed: bool = True,
502
+ ext: str = 'pkl',
503
+ **kwargs
504
+ ) -> None:
505
+ ''' SMPLH model constructor
506
+
507
+ Parameters
508
+ ----------
509
+ model_path: str
510
+ The path to the folder or to the file where the model
511
+ parameters are stored
512
+ data_struct: Strct
513
+ A struct object. If given, then the parameters of the model are
514
+ read from the object. Otherwise, the model tries to read the
515
+ parameters from the given `model_path`. (default = None)
516
+ create_left_hand_pose: bool, optional
517
+ Flag for creating a member variable for the pose of the left
518
+ hand. (default = True)
519
+ left_hand_pose: torch.tensor, optional, BxP
520
+ The default value for the left hand pose member variable.
521
+ (default = None)
522
+ create_right_hand_pose: bool, optional
523
+ Flag for creating a member variable for the pose of the right
524
+ hand. (default = True)
525
+ right_hand_pose: torch.tensor, optional, BxP
526
+ The default value for the right hand pose member variable.
527
+ (default = None)
528
+ num_pca_comps: int, optional
529
+ The number of PCA components to use for each hand.
530
+ (default = 6)
531
+ flat_hand_mean: bool, optional
532
+ If False, then the pose of the hand is initialized to False.
533
+ batch_size: int, optional
534
+ The batch size used for creating the member variables
535
+ gender: str, optional
536
+ Which gender to load
537
+ dtype: torch.dtype, optional
538
+ The data type for the created variables
539
+ vertex_ids: dict, optional
540
+ A dictionary containing the indices of the extra vertices that
541
+ will be selected
542
+ '''
543
+
544
+ self.num_pca_comps = num_pca_comps
545
+ # If no data structure is passed, then load the data from the given
546
+ # model folder
547
+ if data_struct is None:
548
+ # Load the model
549
+ if osp.isdir(model_path):
550
+ model_fn = 'SMPLH_{}.{ext}'.format(gender.upper(), ext=ext)
551
+ smplh_path = os.path.join(model_path, model_fn)
552
+ else:
553
+ smplh_path = model_path
554
+ assert osp.exists(smplh_path), 'Path {} does not exist!'.format(
555
+ smplh_path)
556
+
557
+ if ext == 'pkl':
558
+ with open(smplh_path, 'rb') as smplh_file:
559
+ model_data = pickle.load(smplh_file, encoding='latin1')
560
+ elif ext == 'npz':
561
+ model_data = np.load(smplh_path, allow_pickle=True)
562
+ else:
563
+ raise ValueError('Unknown extension: {}'.format(ext))
564
+ data_struct = Struct(**model_data)
565
+
566
+ if vertex_ids is None:
567
+ vertex_ids = VERTEX_IDS['smplh']
568
+
569
+ super(SMPLH, self).__init__(
570
+ model_path=model_path,
571
+ data_struct=data_struct,
572
+ batch_size=batch_size, vertex_ids=vertex_ids, gender=gender,
573
+ use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs)
574
+
575
+ self.use_pca = use_pca
576
+ self.num_pca_comps = num_pca_comps
577
+ self.flat_hand_mean = flat_hand_mean
578
+
579
+ left_hand_components = data_struct.hands_componentsl[:num_pca_comps]
580
+ right_hand_components = data_struct.hands_componentsr[:num_pca_comps]
581
+
582
+ self.np_left_hand_components = left_hand_components
583
+ self.np_right_hand_components = right_hand_components
584
+ if self.use_pca:
585
+ self.register_buffer(
586
+ 'left_hand_components',
587
+ torch.tensor(left_hand_components, dtype=dtype))
588
+ self.register_buffer(
589
+ 'right_hand_components',
590
+ torch.tensor(right_hand_components, dtype=dtype))
591
+
592
+ if self.flat_hand_mean:
593
+ left_hand_mean = np.zeros_like(data_struct.hands_meanl)
594
+ else:
595
+ left_hand_mean = data_struct.hands_meanl
596
+
597
+ if self.flat_hand_mean:
598
+ right_hand_mean = np.zeros_like(data_struct.hands_meanr)
599
+ else:
600
+ right_hand_mean = data_struct.hands_meanr
601
+
602
+ self.register_buffer('left_hand_mean',
603
+ to_tensor(left_hand_mean, dtype=self.dtype))
604
+ self.register_buffer('right_hand_mean',
605
+ to_tensor(right_hand_mean, dtype=self.dtype))
606
+
607
+ # Create the buffers for the pose of the left hand
608
+ hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS
609
+ if create_left_hand_pose:
610
+ if left_hand_pose is None:
611
+ default_lhand_pose = torch.zeros([batch_size, hand_pose_dim],
612
+ dtype=dtype)
613
+ else:
614
+ default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype)
615
+
616
+ left_hand_pose_param = nn.Parameter(default_lhand_pose,
617
+ requires_grad=True)
618
+ self.register_parameter('left_hand_pose',
619
+ left_hand_pose_param)
620
+
621
+ if create_right_hand_pose:
622
+ if right_hand_pose is None:
623
+ default_rhand_pose = torch.zeros([batch_size, hand_pose_dim],
624
+ dtype=dtype)
625
+ else:
626
+ default_rhand_pose = torch.tensor(right_hand_pose, dtype=dtype)
627
+
628
+ right_hand_pose_param = nn.Parameter(default_rhand_pose,
629
+ requires_grad=True)
630
+ self.register_parameter('right_hand_pose',
631
+ right_hand_pose_param)
632
+
633
+ # Create the buffer for the mean pose.
634
+ pose_mean_tensor = self.create_mean_pose(
635
+ data_struct, flat_hand_mean=flat_hand_mean)
636
+ if not torch.is_tensor(pose_mean_tensor):
637
+ pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype)
638
+ self.register_buffer('pose_mean', pose_mean_tensor)
639
+
640
+ def create_mean_pose(self, data_struct, flat_hand_mean=False):
641
+ # Create the array for the mean pose. If flat_hand is false, then use
642
+ # the mean that is given by the data, rather than the flat open hand
643
+ global_orient_mean = torch.zeros([3], dtype=self.dtype)
644
+ body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3],
645
+ dtype=self.dtype)
646
+
647
+ pose_mean = torch.cat([global_orient_mean, body_pose_mean,
648
+ self.left_hand_mean,
649
+ self.right_hand_mean], dim=0)
650
+ return pose_mean
651
+
652
+ def name(self) -> str:
653
+ return 'SMPL+H'
654
+
655
+ def extra_repr(self):
656
+ msg = super(SMPLH, self).extra_repr()
657
+ msg = [msg]
658
+ if self.use_pca:
659
+ msg.append(f'Number of PCA components: {self.num_pca_comps}')
660
+ msg.append(f'Flat hand mean: {self.flat_hand_mean}')
661
+ return '\n'.join(msg)
662
+
663
+ def forward(
664
+ self,
665
+ betas: Optional[Tensor] = None,
666
+ global_orient: Optional[Tensor] = None,
667
+ body_pose: Optional[Tensor] = None,
668
+ left_hand_pose: Optional[Tensor] = None,
669
+ right_hand_pose: Optional[Tensor] = None,
670
+ transl: Optional[Tensor] = None,
671
+ return_verts: bool = True,
672
+ return_full_pose: bool = False,
673
+ pose2rot: bool = True,
674
+ **kwargs
675
+ ) -> SMPLHOutput:
676
+ '''
677
+ '''
678
+ # If no shape and pose parameters are passed along, then use the
679
+ # ones from the module
680
+ global_orient = (global_orient if global_orient is not None else
681
+ self.global_orient)
682
+ body_pose = body_pose if body_pose is not None else self.body_pose
683
+ betas = betas if betas is not None else self.betas
684
+ left_hand_pose = (left_hand_pose if left_hand_pose is not None else
685
+ self.left_hand_pose)
686
+ right_hand_pose = (right_hand_pose if right_hand_pose is not None else
687
+ self.right_hand_pose)
688
+
689
+ apply_trans = transl is not None or hasattr(self, 'transl')
690
+ if transl is None:
691
+ if hasattr(self, 'transl'):
692
+ transl = self.transl
693
+
694
+ if self.use_pca:
695
+ left_hand_pose = torch.einsum(
696
+ 'bi,ij->bj', [left_hand_pose, self.left_hand_components])
697
+ right_hand_pose = torch.einsum(
698
+ 'bi,ij->bj', [right_hand_pose, self.right_hand_components])
699
+
700
+ full_pose = torch.cat([global_orient, body_pose,
701
+ left_hand_pose,
702
+ right_hand_pose], dim=1)
703
+ full_pose += self.pose_mean
704
+
705
+ vertices, joints = lbs(self.betas, full_pose, self.v_template,
706
+ self.shapedirs, self.posedirs,
707
+ self.J_regressor, self.parents,
708
+ self.lbs_weights, pose2rot=pose2rot)
709
+
710
+ # Add any extra joints that might be needed
711
+ joints = self.vertex_joint_selector(vertices, joints)
712
+ if self.joint_mapper is not None:
713
+ joints = self.joint_mapper(joints)
714
+
715
+ if apply_trans:
716
+ joints += transl.unsqueeze(dim=1)
717
+ vertices += transl.unsqueeze(dim=1)
718
+
719
+ output = SMPLHOutput(vertices=vertices if return_verts else None,
720
+ joints=joints,
721
+ betas=betas,
722
+ global_orient=global_orient,
723
+ body_pose=body_pose,
724
+ left_hand_pose=left_hand_pose,
725
+ right_hand_pose=right_hand_pose,
726
+ full_pose=full_pose if return_full_pose else None)
727
+
728
+ return output
729
+
730
+
731
+ class SMPLHLayer(SMPLH):
732
+
733
+ def __init__(
734
+ self, *args, **kwargs
735
+ ) -> None:
736
+ ''' SMPL+H as a layer model constructor
737
+ '''
738
+ super(SMPLHLayer, self).__init__(
739
+ create_global_orient=False,
740
+ create_body_pose=False,
741
+ create_left_hand_pose=False,
742
+ create_right_hand_pose=False,
743
+ create_betas=False,
744
+ create_transl=False,
745
+ *args,
746
+ **kwargs)
747
+
748
+ def forward(
749
+ self,
750
+ betas: Optional[Tensor] = None,
751
+ global_orient: Optional[Tensor] = None,
752
+ body_pose: Optional[Tensor] = None,
753
+ left_hand_pose: Optional[Tensor] = None,
754
+ right_hand_pose: Optional[Tensor] = None,
755
+ transl: Optional[Tensor] = None,
756
+ return_verts: bool = True,
757
+ return_full_pose: bool = False,
758
+ pose2rot: bool = True,
759
+ **kwargs
760
+ ) -> SMPLHOutput:
761
+ '''
762
+ '''
763
+ device, dtype = self.shapedirs.device, self.shapedirs.dtype
764
+ if global_orient is None:
765
+ batch_size = 1
766
+ global_orient = torch.zeros(3, device=device, dtype=dtype).view(
767
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
768
+ else:
769
+ batch_size = global_orient.shape[0]
770
+ if body_pose is None:
771
+ body_pose = torch.zeros(3, device=device, dtype=dtype).view(
772
+ 1, 1, 3).expand(batch_size, 21, -1).contiguous()
773
+ if left_hand_pose is None:
774
+ left_hand_pose = torch.zeros(3, device=device, dtype=dtype).view(
775
+ 1, 1, 3).expand(batch_size, 15, -1).contiguous()
776
+ if right_hand_pose is None:
777
+ right_hand_pose = torch.zeros(3, device=device, dtype=dtype).view(
778
+ 1, 1, 3).expand(batch_size, 15, -1).contiguous()
779
+ if betas is None:
780
+ betas = torch.zeros([batch_size, self.num_betas],
781
+ dtype=dtype, device=device)
782
+ if transl is None:
783
+ transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
784
+
785
+ # Concatenate all pose vectors
786
+ full_pose = torch.cat(
787
+ [global_orient.reshape(-1, 1, 3),
788
+ body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3),
789
+ left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3),
790
+ right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3)],
791
+ dim=1)
792
+
793
+ vertices, joints = lbs(betas, full_pose, self.v_template,
794
+ self.shapedirs, self.posedirs,
795
+ self.J_regressor, self.parents,
796
+ self.lbs_weights, pose2rot=True)
797
+
798
+ # Add any extra joints that might be needed
799
+ joints = self.vertex_joint_selector(vertices, joints)
800
+ if self.joint_mapper is not None:
801
+ joints = self.joint_mapper(joints)
802
+
803
+ if transl is not None:
804
+ joints += transl.unsqueeze(dim=1)
805
+ vertices += transl.unsqueeze(dim=1)
806
+
807
+ output = SMPLHOutput(vertices=vertices if return_verts else None,
808
+ joints=joints,
809
+ betas=betas,
810
+ global_orient=global_orient,
811
+ body_pose=body_pose,
812
+ left_hand_pose=left_hand_pose,
813
+ right_hand_pose=right_hand_pose,
814
+ full_pose=full_pose if return_full_pose else None)
815
+
816
+ return output
817
+
818
+
819
+ class SMPLX(SMPLH):
820
+ '''
821
+ SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters
822
+ trained jointly for the face, hands and body.
823
+ SMPL-X uses standard vertex based linear blend skinning with learned
824
+ corrective blend shapes, has N=10475 vertices and K=54 joints,
825
+ which includes joints for the neck, jaw, eyeballs and fingers.
826
+ '''
827
+
828
+ NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS
829
+ NUM_HAND_JOINTS = 15
830
+ NUM_FACE_JOINTS = 3
831
+ NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS
832
+ EXPRESSION_SPACE_DIM = 100
833
+ NECK_IDX = 12
834
+
835
+ def __init__(
836
+ self, model_path: str,
837
+ num_expression_coeffs: int = 10,
838
+ create_expression: bool = True,
839
+ expression: Optional[Tensor] = None,
840
+ create_jaw_pose: bool = True,
841
+ jaw_pose: Optional[Tensor] = None,
842
+ create_leye_pose: bool = True,
843
+ leye_pose: Optional[Tensor] = None,
844
+ create_reye_pose=True,
845
+ reye_pose: Optional[Tensor] = None,
846
+ use_face_contour: bool = False,
847
+ batch_size: int = 1,
848
+ gender: str = 'neutral',
849
+ dtype=torch.float32,
850
+ ext: str = 'npz',
851
+ **kwargs
852
+ ) -> None:
853
+ ''' SMPLX model constructor
854
+
855
+ Parameters
856
+ ----------
857
+ model_path: str
858
+ The path to the folder or to the file where the model
859
+ parameters are stored
860
+ num_expression_coeffs: int, optional
861
+ Number of expression components to use
862
+ (default = 10).
863
+ create_expression: bool, optional
864
+ Flag for creating a member variable for the expression space
865
+ (default = True).
866
+ expression: torch.tensor, optional, Bx10
867
+ The default value for the expression member variable.
868
+ (default = None)
869
+ create_jaw_pose: bool, optional
870
+ Flag for creating a member variable for the jaw pose.
871
+ (default = False)
872
+ jaw_pose: torch.tensor, optional, Bx3
873
+ The default value for the jaw pose variable.
874
+ (default = None)
875
+ create_leye_pose: bool, optional
876
+ Flag for creating a member variable for the left eye pose.
877
+ (default = False)
878
+ leye_pose: torch.tensor, optional, Bx10
879
+ The default value for the left eye pose variable.
880
+ (default = None)
881
+ create_reye_pose: bool, optional
882
+ Flag for creating a member variable for the right eye pose.
883
+ (default = False)
884
+ reye_pose: torch.tensor, optional, Bx10
885
+ The default value for the right eye pose variable.
886
+ (default = None)
887
+ use_face_contour: bool, optional
888
+ Whether to compute the keypoints that form the facial contour
889
+ batch_size: int, optional
890
+ The batch size used for creating the member variables
891
+ gender: str, optional
892
+ Which gender to load
893
+ dtype: torch.dtype
894
+ The data type for the created variables
895
+ '''
896
+
897
+ # Load the model
898
+ if osp.isdir(model_path):
899
+ model_fn = 'SMPLX_{}.{ext}'.format(gender.upper(), ext=ext)
900
+ smplx_path = os.path.join(model_path, model_fn)
901
+ else:
902
+ smplx_path = model_path
903
+ assert osp.exists(smplx_path), 'Path {} does not exist!'.format(smplx_path)
904
+ if ext == 'pkl':
905
+ with open(smplx_path, 'rb') as smplx_file:
906
+ model_data = pickle.load(smplx_file, encoding='latin1')
907
+ elif ext == 'npz':
908
+ model_data = np.load(smplx_path, allow_pickle=True)
909
+ else:
910
+ raise ValueError('Unknown extension: {}'.format(ext))
911
+
912
+ data_struct = Struct(**model_data)
913
+
914
+ super(SMPLX, self).__init__(
915
+ model_path=model_path,
916
+ data_struct=data_struct,
917
+ dtype=dtype,
918
+ batch_size=batch_size,
919
+ vertex_ids=VERTEX_IDS['smplx'],
920
+ gender=gender, ext=ext,
921
+ **kwargs)
922
+
923
+ lmk_faces_idx = data_struct.lmk_faces_idx
924
+ self.register_buffer('lmk_faces_idx',
925
+ torch.tensor(lmk_faces_idx, dtype=torch.long))
926
+ lmk_bary_coords = data_struct.lmk_bary_coords
927
+ self.register_buffer('lmk_bary_coords',
928
+ torch.tensor(lmk_bary_coords, dtype=dtype))
929
+
930
+ self.use_face_contour = use_face_contour
931
+ if self.use_face_contour:
932
+ dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx
933
+ dynamic_lmk_faces_idx = torch.tensor(
934
+ dynamic_lmk_faces_idx,
935
+ dtype=torch.long)
936
+ self.register_buffer('dynamic_lmk_faces_idx',
937
+ dynamic_lmk_faces_idx)
938
+
939
+ dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords
940
+ dynamic_lmk_bary_coords = torch.tensor(
941
+ dynamic_lmk_bary_coords, dtype=dtype)
942
+ self.register_buffer('dynamic_lmk_bary_coords',
943
+ dynamic_lmk_bary_coords)
944
+
945
+ neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
946
+ self.register_buffer(
947
+ 'neck_kin_chain',
948
+ torch.tensor(neck_kin_chain, dtype=torch.long))
949
+
950
+ if create_jaw_pose:
951
+ if jaw_pose is None:
952
+ default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype)
953
+ else:
954
+ default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype)
955
+ jaw_pose_param = nn.Parameter(default_jaw_pose,
956
+ requires_grad=True)
957
+ self.register_parameter('jaw_pose', jaw_pose_param)
958
+
959
+ if create_leye_pose:
960
+ if leye_pose is None:
961
+ default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype)
962
+ else:
963
+ default_leye_pose = torch.tensor(leye_pose, dtype=dtype)
964
+ leye_pose_param = nn.Parameter(default_leye_pose,
965
+ requires_grad=True)
966
+ self.register_parameter('leye_pose', leye_pose_param)
967
+
968
+ if create_reye_pose:
969
+ if reye_pose is None:
970
+ default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype)
971
+ else:
972
+ default_reye_pose = torch.tensor(reye_pose, dtype=dtype)
973
+ reye_pose_param = nn.Parameter(default_reye_pose,
974
+ requires_grad=True)
975
+ self.register_parameter('reye_pose', reye_pose_param)
976
+
977
+ shapedirs = data_struct.shapedirs
978
+ if len(shapedirs.shape) < 3:
979
+ shapedirs = shapedirs[:, :, None]
980
+ if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM +
981
+ self.EXPRESSION_SPACE_DIM):
982
+ print(f'WARNING: You are using a {self.name()} model, with only'
983
+ ' 10 shape and 10 expression coefficients.')
984
+ expr_start_idx = 10
985
+ expr_end_idx = 20
986
+ num_expression_coeffs = min(num_expression_coeffs, 10)
987
+ else:
988
+ expr_start_idx = self.SHAPE_SPACE_DIM
989
+ expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs
990
+ num_expression_coeffs = min(
991
+ num_expression_coeffs, self.EXPRESSION_SPACE_DIM)
992
+
993
+ self._num_expression_coeffs = num_expression_coeffs
994
+
995
+ expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
996
+ self.register_buffer(
997
+ 'expr_dirs', to_tensor(to_np(expr_dirs), dtype=dtype))
998
+
999
+ if create_expression:
1000
+ if expression is None:
1001
+ default_expression = torch.zeros(
1002
+ [batch_size, self.num_expression_coeffs], dtype=dtype)
1003
+ else:
1004
+ default_expression = torch.tensor(expression, dtype=dtype)
1005
+ expression_param = nn.Parameter(default_expression,
1006
+ requires_grad=True)
1007
+ self.register_parameter('expression', expression_param)
1008
+
1009
+ def name(self) -> str:
1010
+ return 'SMPL-X'
1011
+
1012
+ @property
1013
+ def num_expression_coeffs(self):
1014
+ return self._num_expression_coeffs
1015
+
1016
+ def create_mean_pose(self, data_struct, flat_hand_mean=False):
1017
+ # Create the array for the mean pose. If flat_hand is false, then use
1018
+ # the mean that is given by the data, rather than the flat open hand
1019
+ global_orient_mean = torch.zeros([3], dtype=self.dtype)
1020
+ body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3],
1021
+ dtype=self.dtype)
1022
+ jaw_pose_mean = torch.zeros([3], dtype=self.dtype)
1023
+ leye_pose_mean = torch.zeros([3], dtype=self.dtype)
1024
+ reye_pose_mean = torch.zeros([3], dtype=self.dtype)
1025
+
1026
+ pose_mean = np.concatenate([global_orient_mean, body_pose_mean,
1027
+ jaw_pose_mean,
1028
+ leye_pose_mean, reye_pose_mean,
1029
+ self.left_hand_mean, self.right_hand_mean],
1030
+ axis=0)
1031
+
1032
+ return pose_mean
1033
+
1034
+ def extra_repr(self):
1035
+ msg = super(SMPLX, self).extra_repr()
1036
+ msg = [
1037
+ msg,
1038
+ f'Number of Expression Coefficients: {self.num_expression_coeffs}'
1039
+ ]
1040
+ return '\n'.join(msg)
1041
+
1042
+ def forward(
1043
+ self,
1044
+ betas: Optional[Tensor] = None,
1045
+ global_orient: Optional[Tensor] = None,
1046
+ body_pose: Optional[Tensor] = None,
1047
+ left_hand_pose: Optional[Tensor] = None,
1048
+ right_hand_pose: Optional[Tensor] = None,
1049
+ transl: Optional[Tensor] = None,
1050
+ expression: Optional[Tensor] = None,
1051
+ jaw_pose: Optional[Tensor] = None,
1052
+ leye_pose: Optional[Tensor] = None,
1053
+ reye_pose: Optional[Tensor] = None,
1054
+ return_verts: bool = True,
1055
+ return_full_pose: bool = False,
1056
+ pose2rot: bool = True,
1057
+ **kwargs
1058
+ ) -> SMPLXOutput:
1059
+ '''
1060
+ Forward pass for the SMPLX model
1061
+
1062
+ Parameters
1063
+ ----------
1064
+ global_orient: torch.tensor, optional, shape Bx3
1065
+ If given, ignore the member variable and use it as the global
1066
+ rotation of the body. Useful if someone wishes to predicts this
1067
+ with an external model. (default=None)
1068
+ betas: torch.tensor, optional, shape Bx10
1069
+ If given, ignore the member variable `betas` and use it
1070
+ instead. For example, it can used if shape parameters
1071
+ `betas` are predicted from some external model.
1072
+ (default=None)
1073
+ expression: torch.tensor, optional, shape Bx10
1074
+ If given, ignore the member variable `expression` and use it
1075
+ instead. For example, it can used if expression parameters
1076
+ `expression` are predicted from some external model.
1077
+ body_pose: torch.tensor, optional, shape Bx(J*3)
1078
+ If given, ignore the member variable `body_pose` and use it
1079
+ instead. For example, it can used if someone predicts the
1080
+ pose of the body joints are predicted from some external model.
1081
+ It should be a tensor that contains joint rotations in
1082
+ axis-angle format. (default=None)
1083
+ left_hand_pose: torch.tensor, optional, shape BxP
1084
+ If given, ignore the member variable `left_hand_pose` and
1085
+ use this instead. It should either contain PCA coefficients or
1086
+ joint rotations in axis-angle format.
1087
+ right_hand_pose: torch.tensor, optional, shape BxP
1088
+ If given, ignore the member variable `right_hand_pose` and
1089
+ use this instead. It should either contain PCA coefficients or
1090
+ joint rotations in axis-angle format.
1091
+ jaw_pose: torch.tensor, optional, shape Bx3
1092
+ If given, ignore the member variable `jaw_pose` and
1093
+ use this instead. It should either joint rotations in
1094
+ axis-angle format.
1095
+ transl: torch.tensor, optional, shape Bx3
1096
+ If given, ignore the member variable `transl` and use it
1097
+ instead. For example, it can used if the translation
1098
+ `transl` is predicted from some external model.
1099
+ (default=None)
1100
+ return_verts: bool, optional
1101
+ Return the vertices. (default=True)
1102
+ return_full_pose: bool, optional
1103
+ Returns the full axis-angle pose vector (default=False)
1104
+
1105
+ Returns
1106
+ -------
1107
+ output: ModelOutput
1108
+ A named tuple of type `ModelOutput`
1109
+ '''
1110
+
1111
+ # If no shape and pose parameters are passed along, then use the
1112
+ # ones from the module
1113
+ global_orient = (global_orient if global_orient is not None else
1114
+ self.global_orient)
1115
+ body_pose = body_pose if body_pose is not None else self.body_pose
1116
+ betas = betas if betas is not None else self.betas
1117
+
1118
+ left_hand_pose = (left_hand_pose if left_hand_pose is not None else
1119
+ self.left_hand_pose)
1120
+ right_hand_pose = (right_hand_pose if right_hand_pose is not None else
1121
+ self.right_hand_pose)
1122
+ jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
1123
+ leye_pose = leye_pose if leye_pose is not None else self.leye_pose
1124
+ reye_pose = reye_pose if reye_pose is not None else self.reye_pose
1125
+ expression = expression if expression is not None else self.expression
1126
+
1127
+ apply_trans = transl is not None or hasattr(self, 'transl')
1128
+ if transl is None:
1129
+ if hasattr(self, 'transl'):
1130
+ transl = self.transl
1131
+
1132
+ if self.use_pca:
1133
+ left_hand_pose = torch.einsum(
1134
+ 'bi,ij->bj', [left_hand_pose, self.left_hand_components])
1135
+ right_hand_pose = torch.einsum(
1136
+ 'bi,ij->bj', [right_hand_pose, self.right_hand_components])
1137
+
1138
+ full_pose = torch.cat([global_orient, body_pose,
1139
+ jaw_pose, leye_pose, reye_pose,
1140
+ left_hand_pose,
1141
+ right_hand_pose], dim=1)
1142
+
1143
+ # Add the mean pose of the model. Does not affect the body, only the
1144
+ # hands when flat_hand_mean == False
1145
+ full_pose += self.pose_mean
1146
+
1147
+ batch_size = max(betas.shape[0], global_orient.shape[0],
1148
+ body_pose.shape[0])
1149
+ # Concatenate the shape and expression coefficients
1150
+ scale = int(batch_size / betas.shape[0])
1151
+ if scale > 1:
1152
+ betas = betas.expand(scale, -1)
1153
+ shape_components = torch.cat([betas, expression], dim=-1)
1154
+
1155
+ shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
1156
+
1157
+ vertices, joints = lbs(shape_components, full_pose, self.v_template,
1158
+ shapedirs, self.posedirs,
1159
+ self.J_regressor, self.parents,
1160
+ self.lbs_weights, pose2rot=pose2rot,
1161
+ )
1162
+
1163
+ lmk_faces_idx = self.lmk_faces_idx.unsqueeze(
1164
+ dim=0).expand(batch_size, -1).contiguous()
1165
+ lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
1166
+ self.batch_size, 1, 1)
1167
+ if self.use_face_contour:
1168
+ lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
1169
+ vertices, full_pose, self.dynamic_lmk_faces_idx,
1170
+ self.dynamic_lmk_bary_coords,
1171
+ self.neck_kin_chain,
1172
+ pose2rot=True,
1173
+ )
1174
+ dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
1175
+
1176
+ lmk_faces_idx = torch.cat([lmk_faces_idx,
1177
+ dyn_lmk_faces_idx], 1)
1178
+ lmk_bary_coords = torch.cat(
1179
+ [lmk_bary_coords.expand(batch_size, -1, -1),
1180
+ dyn_lmk_bary_coords], 1)
1181
+
1182
+ landmarks = vertices2landmarks(vertices, self.faces_tensor,
1183
+ lmk_faces_idx,
1184
+ lmk_bary_coords)
1185
+
1186
+ # Add any extra joints that might be needed
1187
+ joints = self.vertex_joint_selector(vertices, joints)
1188
+ # Add the landmarks to the joints
1189
+ joints = torch.cat([joints, landmarks], dim=1)
1190
+ # Map the joints to the current dataset
1191
+
1192
+ if self.joint_mapper is not None:
1193
+ joints = self.joint_mapper(joints=joints, vertices=vertices)
1194
+
1195
+ if apply_trans:
1196
+ joints += transl.unsqueeze(dim=1)
1197
+ vertices += transl.unsqueeze(dim=1)
1198
+
1199
+ output = SMPLXOutput(vertices=vertices if return_verts else None,
1200
+ joints=joints,
1201
+ betas=betas,
1202
+ expression=expression,
1203
+ global_orient=global_orient,
1204
+ body_pose=body_pose,
1205
+ left_hand_pose=left_hand_pose,
1206
+ right_hand_pose=right_hand_pose,
1207
+ jaw_pose=jaw_pose,
1208
+ full_pose=full_pose if return_full_pose else None)
1209
+ return output
1210
+
1211
+
1212
+ class SMPLXLayer(SMPLX):
1213
+ def __init__(
1214
+ self,
1215
+ *args,
1216
+ **kwargs
1217
+ ) -> None:
1218
+ # Just create a SMPLX module without any member variables
1219
+ super(SMPLXLayer, self).__init__(
1220
+ create_global_orient=False,
1221
+ create_body_pose=False,
1222
+ create_left_hand_pose=False,
1223
+ create_right_hand_pose=False,
1224
+ create_jaw_pose=False,
1225
+ create_leye_pose=False,
1226
+ create_reye_pose=False,
1227
+ create_betas=False,
1228
+ create_expression=False,
1229
+ create_transl=False,
1230
+ *args, **kwargs,
1231
+ )
1232
+
1233
+ def forward(
1234
+ self,
1235
+ betas: Optional[Tensor] = None,
1236
+ global_orient: Optional[Tensor] = None,
1237
+ body_pose: Optional[Tensor] = None,
1238
+ left_hand_pose: Optional[Tensor] = None,
1239
+ right_hand_pose: Optional[Tensor] = None,
1240
+ transl: Optional[Tensor] = None,
1241
+ expression: Optional[Tensor] = None,
1242
+ jaw_pose: Optional[Tensor] = None,
1243
+ leye_pose: Optional[Tensor] = None,
1244
+ reye_pose: Optional[Tensor] = None,
1245
+ return_verts: bool = True,
1246
+ return_full_pose: bool = False,
1247
+ **kwargs
1248
+ ) -> SMPLXOutput:
1249
+ '''
1250
+ Forward pass for the SMPLX model
1251
+
1252
+ Parameters
1253
+ ----------
1254
+ global_orient: torch.tensor, optional, shape Bx3
1255
+ If given, ignore the member variable and use it as the global
1256
+ rotation of the body. Useful if someone wishes to predicts this
1257
+ with an external model. (default=None)
1258
+ betas: torch.tensor, optional, shape Bx10
1259
+ If given, ignore the member variable `betas` and use it
1260
+ instead. For example, it can used if shape parameters
1261
+ `betas` are predicted from some external model.
1262
+ (default=None)
1263
+ expression: torch.tensor, optional, shape Bx10
1264
+ If given, ignore the member variable `expression` and use it
1265
+ instead. For example, it can used if expression parameters
1266
+ `expression` are predicted from some external model.
1267
+ body_pose: torch.tensor, optional, shape Bx(J*3)
1268
+ If given, ignore the member variable `body_pose` and use it
1269
+ instead. For example, it can used if someone predicts the
1270
+ pose of the body joints are predicted from some external model.
1271
+ It should be a tensor that contains joint rotations in
1272
+ axis-angle format. (default=None)
1273
+ left_hand_pose: torch.tensor, optional, shape BxP
1274
+ If given, ignore the member variable `left_hand_pose` and
1275
+ use this instead. It should either contain PCA coefficients or
1276
+ joint rotations in axis-angle format.
1277
+ right_hand_pose: torch.tensor, optional, shape BxP
1278
+ If given, ignore the member variable `right_hand_pose` and
1279
+ use this instead. It should either contain PCA coefficients or
1280
+ joint rotations in axis-angle format.
1281
+ jaw_pose: torch.tensor, optional, shape Bx3x3
1282
+ If given, ignore the member variable `jaw_pose` and
1283
+ use this instead. It should either joint rotations in
1284
+ axis-angle format.
1285
+ transl: torch.tensor, optional, shape Bx3
1286
+ If given, ignore the member variable `transl` and use it
1287
+ instead. For example, it can used if the translation
1288
+ `transl` is predicted from some external model.
1289
+ (default=None)
1290
+ return_verts: bool, optional
1291
+ Return the vertices. (default=True)
1292
+ return_full_pose: bool, optional
1293
+ Returns the full pose vector (default=False)
1294
+ Returns
1295
+ -------
1296
+ output: ModelOutput
1297
+ A data class that contains the posed vertices and joints
1298
+ '''
1299
+ device, dtype = self.shapedirs.device, self.shapedirs.dtype
1300
+
1301
+ if global_orient is None:
1302
+ batch_size = 1
1303
+ global_orient = torch.zeros(3, device=device, dtype=dtype).view(
1304
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
1305
+ else:
1306
+ batch_size = global_orient.shape[0]
1307
+ if body_pose is None:
1308
+ body_pose = torch.zeros(3, device=device, dtype=dtype).view(
1309
+ 1, 1, 3).expand(
1310
+ batch_size, self.NUM_BODY_JOINTS, -1).contiguous()
1311
+ if left_hand_pose is None:
1312
+ left_hand_pose = torch.zeros(3, device=device, dtype=dtype).view(
1313
+ 1, 1, 3).expand(batch_size, 15, -1).contiguous()
1314
+ if right_hand_pose is None:
1315
+ right_hand_pose = torch.zeros(3, device=device, dtype=dtype).view(
1316
+ 1, 1, 3).expand(batch_size, 15, -1).contiguous()
1317
+ if jaw_pose is None:
1318
+ jaw_pose = torch.zeros(3, device=device, dtype=dtype).view(
1319
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
1320
+ if leye_pose is None:
1321
+ leye_pose = torch.zeros(3, device=device, dtype=dtype).view(
1322
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
1323
+ if reye_pose is None:
1324
+ reye_pose = torch.zeros(3, device=device, dtype=dtype).view(
1325
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
1326
+ if expression is None:
1327
+ expression = torch.zeros([batch_size, self.num_expression_coeffs],
1328
+ dtype=dtype, device=device)
1329
+ if betas is None:
1330
+ betas = torch.zeros([batch_size, self.num_betas],
1331
+ dtype=dtype, device=device)
1332
+ if transl is None:
1333
+ transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
1334
+
1335
+ # Concatenate all pose vectors
1336
+ full_pose = torch.cat(
1337
+ [global_orient.reshape(-1, 1, 3),
1338
+ body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3),
1339
+ jaw_pose.reshape(-1, 1, 3),
1340
+ leye_pose.reshape(-1, 1, 3),
1341
+ reye_pose.reshape(-1, 1, 3),
1342
+ left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3),
1343
+ right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3)],
1344
+ dim=1)
1345
+ shape_components = torch.cat([betas, expression], dim=-1)
1346
+
1347
+ shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
1348
+
1349
+ vertices, joints = lbs(shape_components, full_pose, self.v_template,
1350
+ shapedirs, self.posedirs,
1351
+ self.J_regressor, self.parents,
1352
+ self.lbs_weights, pose2rot=True)
1353
+
1354
+ lmk_faces_idx = self.lmk_faces_idx.unsqueeze(
1355
+ dim=0).expand(batch_size, -1).contiguous()
1356
+ lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
1357
+ self.batch_size, 1, 1)
1358
+ if self.use_face_contour:
1359
+ lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
1360
+ vertices, full_pose,
1361
+ self.dynamic_lmk_faces_idx,
1362
+ self.dynamic_lmk_bary_coords,
1363
+ self.neck_kin_chain,
1364
+ pose2rot=False,
1365
+ )
1366
+ dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
1367
+
1368
+ lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
1369
+ lmk_bary_coords = torch.cat(
1370
+ [lmk_bary_coords.expand(batch_size, -1, -1),
1371
+ dyn_lmk_bary_coords], 1)
1372
+
1373
+ landmarks = vertices2landmarks(vertices, self.faces_tensor,
1374
+ lmk_faces_idx,
1375
+ lmk_bary_coords)
1376
+
1377
+ # Add any extra joints that might be needed
1378
+ joints = self.vertex_joint_selector(vertices, joints)
1379
+ # Add the landmarks to the joints
1380
+ joints = torch.cat([joints, landmarks], dim=1)
1381
+ # Map the joints to the current dataset
1382
+
1383
+ if self.joint_mapper is not None:
1384
+ joints = self.joint_mapper(joints=joints, vertices=vertices)
1385
+
1386
+ if transl is not None:
1387
+ joints += transl.unsqueeze(dim=1)
1388
+ vertices += transl.unsqueeze(dim=1)
1389
+
1390
+ output = SMPLXOutput(vertices=vertices if return_verts else None,
1391
+ joints=joints,
1392
+ betas=betas,
1393
+ expression=expression,
1394
+ global_orient=global_orient,
1395
+ body_pose=body_pose,
1396
+ left_hand_pose=left_hand_pose,
1397
+ right_hand_pose=right_hand_pose,
1398
+ jaw_pose=jaw_pose,
1399
+ transl=transl,
1400
+ full_pose=full_pose if return_full_pose else None)
1401
+ return output
1402
+
1403
+
1404
+ class MANO(SMPL):
1405
+ # The hand joints are replaced by MANO
1406
+ NUM_BODY_JOINTS = 1
1407
+ NUM_HAND_JOINTS = 15
1408
+ NUM_JOINTS = NUM_BODY_JOINTS + NUM_HAND_JOINTS
1409
+
1410
+ def __init__(
1411
+ self,
1412
+ model_path: str,
1413
+ is_rhand: bool = True,
1414
+ data_struct: Optional[Struct] = None,
1415
+ create_hand_pose: bool = True,
1416
+ hand_pose: Optional[Tensor] = None,
1417
+ use_pca: bool = True,
1418
+ num_pca_comps: int = 6,
1419
+ flat_hand_mean: bool = False,
1420
+ batch_size: int = 1,
1421
+ dtype=torch.float32,
1422
+ vertex_ids=None,
1423
+ use_compressed: bool = True,
1424
+ ext: str = 'pkl',
1425
+ **kwargs
1426
+ ) -> None:
1427
+ ''' MANO model constructor
1428
+
1429
+ Parameters
1430
+ ----------
1431
+ model_path: str
1432
+ The path to the folder or to the file where the model
1433
+ parameters are stored
1434
+ data_struct: Strct
1435
+ A struct object. If given, then the parameters of the model are
1436
+ read from the object. Otherwise, the model tries to read the
1437
+ parameters from the given `model_path`. (default = None)
1438
+ create_hand_pose: bool, optional
1439
+ Flag for creating a member variable for the pose of the right
1440
+ hand. (default = True)
1441
+ hand_pose: torch.tensor, optional, BxP
1442
+ The default value for the right hand pose member variable.
1443
+ (default = None)
1444
+ num_pca_comps: int, optional
1445
+ The number of PCA components to use for each hand.
1446
+ (default = 6)
1447
+ flat_hand_mean: bool, optional
1448
+ If False, then the pose of the hand is initialized to False.
1449
+ batch_size: int, optional
1450
+ The batch size used for creating the member variables
1451
+ dtype: torch.dtype, optional
1452
+ The data type for the created variables
1453
+ vertex_ids: dict, optional
1454
+ A dictionary containing the indices of the extra vertices that
1455
+ will be selected
1456
+ '''
1457
+
1458
+ self.num_pca_comps = num_pca_comps
1459
+ self.is_rhand = is_rhand
1460
+ # If no data structure is passed, then load the data from the given
1461
+ # model folder
1462
+ if data_struct is None:
1463
+ # Load the model
1464
+ if osp.isdir(model_path):
1465
+ model_fn = 'MANO_{}.{ext}'.format(
1466
+ 'RIGHT' if is_rhand else 'LEFT', ext=ext)
1467
+ mano_path = os.path.join(model_path, model_fn)
1468
+ else:
1469
+ mano_path = model_path
1470
+ self.is_rhand = True if 'RIGHT' in os.path.basename(
1471
+ model_path) else False
1472
+ assert osp.exists(mano_path), 'Path {} does not exist!'.format(
1473
+ mano_path)
1474
+
1475
+ if ext == 'pkl':
1476
+ with open(mano_path, 'rb') as mano_file:
1477
+ model_data = pickle.load(mano_file, encoding='latin1')
1478
+ elif ext == 'npz':
1479
+ model_data = np.load(mano_path, allow_pickle=True)
1480
+ else:
1481
+ raise ValueError('Unknown extension: {}'.format(ext))
1482
+ data_struct = Struct(**model_data)
1483
+
1484
+ if vertex_ids is None:
1485
+ vertex_ids = VERTEX_IDS['smplh']
1486
+
1487
+ super(MANO, self).__init__(
1488
+ model_path=model_path, data_struct=data_struct,
1489
+ batch_size=batch_size, vertex_ids=vertex_ids,
1490
+ use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs)
1491
+
1492
+ # add only MANO tips to the extra joints
1493
+ self.vertex_joint_selector.extra_joints_idxs = to_tensor(
1494
+ list(VERTEX_IDS['mano'].values()), dtype=torch.long)
1495
+
1496
+ self.use_pca = use_pca
1497
+ self.num_pca_comps = num_pca_comps
1498
+ if self.num_pca_comps == 45:
1499
+ self.use_pca = False
1500
+ self.flat_hand_mean = flat_hand_mean
1501
+
1502
+ hand_components = data_struct.hands_components[:num_pca_comps]
1503
+
1504
+ self.np_hand_components = hand_components
1505
+
1506
+ if self.use_pca:
1507
+ self.register_buffer(
1508
+ 'hand_components',
1509
+ torch.tensor(hand_components, dtype=dtype))
1510
+
1511
+ if self.flat_hand_mean:
1512
+ hand_mean = np.zeros_like(data_struct.hands_mean)
1513
+ else:
1514
+ hand_mean = data_struct.hands_mean
1515
+
1516
+ self.register_buffer('hand_mean',
1517
+ to_tensor(hand_mean, dtype=self.dtype))
1518
+
1519
+ # Create the buffers for the pose of the left hand
1520
+ hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS
1521
+ if create_hand_pose:
1522
+ if hand_pose is None:
1523
+ default_hand_pose = torch.zeros([batch_size, hand_pose_dim],
1524
+ dtype=dtype)
1525
+ else:
1526
+ default_hand_pose = torch.tensor(hand_pose, dtype=dtype)
1527
+
1528
+ hand_pose_param = nn.Parameter(default_hand_pose,
1529
+ requires_grad=True)
1530
+ self.register_parameter('hand_pose',
1531
+ hand_pose_param)
1532
+
1533
+ # Create the buffer for the mean pose.
1534
+ pose_mean = self.create_mean_pose(
1535
+ data_struct, flat_hand_mean=flat_hand_mean)
1536
+ pose_mean_tensor = pose_mean.clone().to(dtype)
1537
+ # pose_mean_tensor = torch.tensor(pose_mean, dtype=dtype)
1538
+ self.register_buffer('pose_mean', pose_mean_tensor)
1539
+
1540
+ def name(self) -> str:
1541
+ return 'MANO'
1542
+
1543
+ def create_mean_pose(self, data_struct, flat_hand_mean=False):
1544
+ # Create the array for the mean pose. If flat_hand is false, then use
1545
+ # the mean that is given by the data, rather than the flat open hand
1546
+ global_orient_mean = torch.zeros([3], dtype=self.dtype)
1547
+ pose_mean = torch.cat([global_orient_mean, self.hand_mean], dim=0)
1548
+ return pose_mean
1549
+
1550
+ def extra_repr(self):
1551
+ msg = [super(MANO, self).extra_repr()]
1552
+ if self.use_pca:
1553
+ msg.append(f'Number of PCA components: {self.num_pca_comps}')
1554
+ msg.append(f'Flat hand mean: {self.flat_hand_mean}')
1555
+ return '\n'.join(msg)
1556
+
1557
+ def forward(
1558
+ self,
1559
+ betas: Optional[Tensor] = None,
1560
+ global_orient: Optional[Tensor] = None,
1561
+ hand_pose: Optional[Tensor] = None,
1562
+ transl: Optional[Tensor] = None,
1563
+ return_verts: bool = True,
1564
+ return_full_pose: bool = False,
1565
+ **kwargs
1566
+ ) -> MANOOutput:
1567
+ ''' Forward pass for the MANO model
1568
+ '''
1569
+ # If no shape and pose parameters are passed along, then use the
1570
+ # ones from the module
1571
+ global_orient = (global_orient if global_orient is not None else
1572
+ self.global_orient)
1573
+ betas = betas if betas is not None else self.betas
1574
+ hand_pose = (hand_pose if hand_pose is not None else
1575
+ self.hand_pose)
1576
+
1577
+ apply_trans = transl is not None or hasattr(self, 'transl')
1578
+ if transl is None:
1579
+ if hasattr(self, 'transl'):
1580
+ transl = self.transl
1581
+
1582
+ if self.use_pca:
1583
+ hand_pose = torch.einsum(
1584
+ 'bi,ij->bj', [hand_pose, self.hand_components])
1585
+
1586
+ full_pose = torch.cat([global_orient, hand_pose], dim=1)
1587
+ full_pose += self.pose_mean
1588
+
1589
+ vertices, joints = lbs(betas, full_pose, self.v_template,
1590
+ self.shapedirs, self.posedirs,
1591
+ self.J_regressor, self.parents,
1592
+ self.lbs_weights, pose2rot=True,
1593
+ )
1594
+
1595
+ # # Add pre-selected extra joints that might be needed
1596
+ # joints = self.vertex_joint_selector(vertices, joints)
1597
+
1598
+ if self.joint_mapper is not None:
1599
+ joints = self.joint_mapper(joints)
1600
+
1601
+ if apply_trans:
1602
+ joints = joints + transl.unsqueeze(dim=1)
1603
+ vertices = vertices + transl.unsqueeze(dim=1)
1604
+
1605
+ output = MANOOutput(vertices=vertices if return_verts else None,
1606
+ joints=joints if return_verts else None,
1607
+ betas=betas,
1608
+ global_orient=global_orient,
1609
+ hand_pose=hand_pose,
1610
+ full_pose=full_pose if return_full_pose else None)
1611
+
1612
+ return output
1613
+
1614
+
1615
+ class MANOLayer(MANO):
1616
+ def __init__(self, *args, **kwargs) -> None:
1617
+ ''' MANO as a layer model constructor
1618
+ '''
1619
+ super(MANOLayer, self).__init__(
1620
+ create_global_orient=False,
1621
+ create_hand_pose=False,
1622
+ create_betas=False,
1623
+ create_transl=False,
1624
+ *args, **kwargs)
1625
+
1626
+ def name(self) -> str:
1627
+ return 'MANO'
1628
+
1629
+ def forward(
1630
+ self,
1631
+ betas: Optional[Tensor] = None,
1632
+ global_orient: Optional[Tensor] = None,
1633
+ hand_pose: Optional[Tensor] = None,
1634
+ transl: Optional[Tensor] = None,
1635
+ return_verts: bool = True,
1636
+ return_full_pose: bool = False,
1637
+ **kwargs
1638
+ ) -> MANOOutput:
1639
+ ''' Forward pass for the MANO model
1640
+ '''
1641
+ device, dtype = self.shapedirs.device, self.shapedirs.dtype
1642
+ if global_orient is None:
1643
+ batch_size = 1
1644
+ global_orient = torch.zeros(3, device=device, dtype=dtype).view(
1645
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
1646
+ else:
1647
+ batch_size = global_orient.shape[0]
1648
+ if hand_pose is None:
1649
+ hand_pose = torch.zeros(3, device=device, dtype=dtype).view(
1650
+ 1, 1, 3).expand(batch_size, 15, -1).contiguous()
1651
+ if betas is None:
1652
+ betas = torch.zeros(
1653
+ [batch_size, self.num_betas], dtype=dtype, device=device)
1654
+ if transl is None:
1655
+ transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
1656
+
1657
+ full_pose = torch.cat([global_orient, hand_pose], dim=1)
1658
+ vertices, joints = lbs(betas, full_pose, self.v_template,
1659
+ self.shapedirs, self.posedirs,
1660
+ self.J_regressor, self.parents,
1661
+ self.lbs_weights, pose2rot=True)
1662
+
1663
+ if self.joint_mapper is not None:
1664
+ joints = self.joint_mapper(joints)
1665
+
1666
+ if transl is not None:
1667
+ joints = joints + transl.unsqueeze(dim=1)
1668
+ vertices = vertices + transl.unsqueeze(dim=1)
1669
+
1670
+ output = MANOOutput(
1671
+ vertices=vertices if return_verts else None,
1672
+ joints=joints if return_verts else None,
1673
+ betas=betas,
1674
+ global_orient=global_orient,
1675
+ hand_pose=hand_pose,
1676
+ full_pose=full_pose if return_full_pose else None)
1677
+
1678
+ return output
1679
+
1680
+
1681
+ class FLAME(SMPL):
1682
+ NUM_JOINTS = 5
1683
+ SHAPE_SPACE_DIM = 300
1684
+ EXPRESSION_SPACE_DIM = 100
1685
+ NECK_IDX = 0
1686
+
1687
+ def __init__(
1688
+ self,
1689
+ model_path: str,
1690
+ data_struct=None,
1691
+ num_expression_coeffs=10,
1692
+ create_expression: bool = True,
1693
+ expression: Optional[Tensor] = None,
1694
+ create_neck_pose: bool = True,
1695
+ neck_pose: Optional[Tensor] = None,
1696
+ create_jaw_pose: bool = True,
1697
+ jaw_pose: Optional[Tensor] = None,
1698
+ create_leye_pose: bool = True,
1699
+ leye_pose: Optional[Tensor] = None,
1700
+ create_reye_pose=True,
1701
+ reye_pose: Optional[Tensor] = None,
1702
+ use_face_contour=False,
1703
+ batch_size: int = 1,
1704
+ gender: str = 'neutral',
1705
+ dtype: torch.dtype = torch.float32,
1706
+ ext='pkl',
1707
+ **kwargs
1708
+ ) -> None:
1709
+ ''' FLAME model constructor
1710
+
1711
+ Parameters
1712
+ ----------
1713
+ model_path: str
1714
+ The path to the folder or to the file where the model
1715
+ parameters are stored
1716
+ num_expression_coeffs: int, optional
1717
+ Number of expression components to use
1718
+ (default = 10).
1719
+ create_expression: bool, optional
1720
+ Flag for creating a member variable for the expression space
1721
+ (default = True).
1722
+ expression: torch.tensor, optional, Bx10
1723
+ The default value for the expression member variable.
1724
+ (default = None)
1725
+ create_neck_pose: bool, optional
1726
+ Flag for creating a member variable for the neck pose.
1727
+ (default = False)
1728
+ neck_pose: torch.tensor, optional, Bx3
1729
+ The default value for the neck pose variable.
1730
+ (default = None)
1731
+ create_jaw_pose: bool, optional
1732
+ Flag for creating a member variable for the jaw pose.
1733
+ (default = False)
1734
+ jaw_pose: torch.tensor, optional, Bx3
1735
+ The default value for the jaw pose variable.
1736
+ (default = None)
1737
+ create_leye_pose: bool, optional
1738
+ Flag for creating a member variable for the left eye pose.
1739
+ (default = False)
1740
+ leye_pose: torch.tensor, optional, Bx10
1741
+ The default value for the left eye pose variable.
1742
+ (default = None)
1743
+ create_reye_pose: bool, optional
1744
+ Flag for creating a member variable for the right eye pose.
1745
+ (default = False)
1746
+ reye_pose: torch.tensor, optional, Bx10
1747
+ The default value for the right eye pose variable.
1748
+ (default = None)
1749
+ use_face_contour: bool, optional
1750
+ Whether to compute the keypoints that form the facial contour
1751
+ batch_size: int, optional
1752
+ The batch size used for creating the member variables
1753
+ gender: str, optional
1754
+ Which gender to load
1755
+ dtype: torch.dtype
1756
+ The data type for the created variables
1757
+ '''
1758
+ model_fn = f'FLAME_{gender.upper()}.{ext}'
1759
+ flame_path = os.path.join(model_path, model_fn)
1760
+ assert osp.exists(flame_path), 'Path {} does not exist!'.format(
1761
+ flame_path)
1762
+ if ext == 'npz':
1763
+ file_data = np.load(flame_path, allow_pickle=True)
1764
+ elif ext == 'pkl':
1765
+ with open(flame_path, 'rb') as smpl_file:
1766
+ file_data = pickle.load(smpl_file, encoding='latin1')
1767
+ else:
1768
+ raise ValueError('Unknown extension: {}'.format(ext))
1769
+ data_struct = Struct(**file_data)
1770
+
1771
+ super(FLAME, self).__init__(
1772
+ model_path=model_path,
1773
+ data_struct=data_struct,
1774
+ dtype=dtype,
1775
+ batch_size=batch_size,
1776
+ gender=gender,
1777
+ ext=ext,
1778
+ **kwargs)
1779
+
1780
+ self.use_face_contour = use_face_contour
1781
+
1782
+ self.vertex_joint_selector.extra_joints_idxs = to_tensor(
1783
+ [], dtype=torch.long)
1784
+
1785
+ if create_neck_pose:
1786
+ if neck_pose is None:
1787
+ default_neck_pose = torch.zeros([batch_size, 3], dtype=dtype)
1788
+ else:
1789
+ default_neck_pose = torch.tensor(neck_pose, dtype=dtype)
1790
+ neck_pose_param = nn.Parameter(
1791
+ default_neck_pose, requires_grad=True)
1792
+ self.register_parameter('neck_pose', neck_pose_param)
1793
+
1794
+ if create_jaw_pose:
1795
+ if jaw_pose is None:
1796
+ default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype)
1797
+ else:
1798
+ default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype)
1799
+ jaw_pose_param = nn.Parameter(default_jaw_pose,
1800
+ requires_grad=True)
1801
+ self.register_parameter('jaw_pose', jaw_pose_param)
1802
+
1803
+ if create_leye_pose:
1804
+ if leye_pose is None:
1805
+ default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype)
1806
+ else:
1807
+ default_leye_pose = torch.tensor(leye_pose, dtype=dtype)
1808
+ leye_pose_param = nn.Parameter(default_leye_pose,
1809
+ requires_grad=True)
1810
+ self.register_parameter('leye_pose', leye_pose_param)
1811
+
1812
+ if create_reye_pose:
1813
+ if reye_pose is None:
1814
+ default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype)
1815
+ else:
1816
+ default_reye_pose = torch.tensor(reye_pose, dtype=dtype)
1817
+ reye_pose_param = nn.Parameter(default_reye_pose,
1818
+ requires_grad=True)
1819
+ self.register_parameter('reye_pose', reye_pose_param)
1820
+
1821
+ shapedirs = data_struct.shapedirs
1822
+ if len(shapedirs.shape) < 3:
1823
+ shapedirs = shapedirs[:, :, None]
1824
+ if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM +
1825
+ self.EXPRESSION_SPACE_DIM):
1826
+ print(f'WARNING: You are using a {self.name()} model, with only'
1827
+ ' 10 shape and 10 expression coefficients.')
1828
+ expr_start_idx = 10
1829
+ expr_end_idx = 20
1830
+ num_expression_coeffs = min(num_expression_coeffs, 10)
1831
+ else:
1832
+ expr_start_idx = self.SHAPE_SPACE_DIM
1833
+ expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs
1834
+ num_expression_coeffs = min(
1835
+ num_expression_coeffs, self.EXPRESSION_SPACE_DIM)
1836
+
1837
+ self._num_expression_coeffs = num_expression_coeffs
1838
+
1839
+ expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]
1840
+ self.register_buffer(
1841
+ 'expr_dirs', to_tensor(to_np(expr_dirs), dtype=dtype))
1842
+
1843
+ if create_expression:
1844
+ if expression is None:
1845
+ default_expression = torch.zeros(
1846
+ [batch_size, self.num_expression_coeffs], dtype=dtype)
1847
+ else:
1848
+ default_expression = torch.tensor(expression, dtype=dtype)
1849
+ expression_param = nn.Parameter(default_expression,
1850
+ requires_grad=True)
1851
+ self.register_parameter('expression', expression_param)
1852
+
1853
+ # The pickle file that contains the barycentric coordinates for
1854
+ # regressing the landmarks
1855
+ landmark_bcoord_filename = osp.join(
1856
+ model_path, 'flame_static_embedding.pkl')
1857
+
1858
+ with open(landmark_bcoord_filename, 'rb') as fp:
1859
+ landmarks_data = pickle.load(fp, encoding='latin1')
1860
+
1861
+ lmk_faces_idx = landmarks_data['lmk_face_idx'].astype(np.int64)
1862
+ self.register_buffer('lmk_faces_idx',
1863
+ torch.tensor(lmk_faces_idx, dtype=torch.long))
1864
+ lmk_bary_coords = landmarks_data['lmk_b_coords']
1865
+ self.register_buffer('lmk_bary_coords',
1866
+ torch.tensor(lmk_bary_coords, dtype=dtype))
1867
+ if self.use_face_contour:
1868
+ face_contour_path = os.path.join(
1869
+ model_path, 'flame_dynamic_embedding.npy')
1870
+ contour_embeddings = np.load(face_contour_path,
1871
+ allow_pickle=True,
1872
+ encoding='latin1')[()]
1873
+
1874
+ dynamic_lmk_faces_idx = np.array(
1875
+ contour_embeddings['lmk_face_idx'], dtype=np.int64)
1876
+ dynamic_lmk_faces_idx = torch.tensor(
1877
+ dynamic_lmk_faces_idx,
1878
+ dtype=torch.long)
1879
+ self.register_buffer('dynamic_lmk_faces_idx',
1880
+ dynamic_lmk_faces_idx)
1881
+
1882
+ dynamic_lmk_b_coords = torch.tensor(
1883
+ contour_embeddings['lmk_b_coords'], dtype=dtype)
1884
+ self.register_buffer(
1885
+ 'dynamic_lmk_bary_coords', dynamic_lmk_b_coords)
1886
+
1887
+ neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)
1888
+ self.register_buffer(
1889
+ 'neck_kin_chain',
1890
+ torch.tensor(neck_kin_chain, dtype=torch.long))
1891
+
1892
+ @property
1893
+ def num_expression_coeffs(self):
1894
+ return self._num_expression_coeffs
1895
+
1896
+ def name(self) -> str:
1897
+ return 'FLAME'
1898
+
1899
+ def extra_repr(self):
1900
+ msg = [
1901
+ super(FLAME, self).extra_repr(),
1902
+ f'Number of Expression Coefficients: {self.num_expression_coeffs}',
1903
+ f'Use face contour: {self.use_face_contour}',
1904
+ ]
1905
+ return '\n'.join(msg)
1906
+
1907
+ def forward(
1908
+ self,
1909
+ betas: Optional[Tensor] = None,
1910
+ global_orient: Optional[Tensor] = None,
1911
+ neck_pose: Optional[Tensor] = None,
1912
+ transl: Optional[Tensor] = None,
1913
+ expression: Optional[Tensor] = None,
1914
+ jaw_pose: Optional[Tensor] = None,
1915
+ leye_pose: Optional[Tensor] = None,
1916
+ reye_pose: Optional[Tensor] = None,
1917
+ return_verts: bool = True,
1918
+ return_full_pose: bool = False,
1919
+ pose2rot: bool = True,
1920
+ **kwargs
1921
+ ) -> FLAMEOutput:
1922
+ '''
1923
+ Forward pass for the SMPLX model
1924
+
1925
+ Parameters
1926
+ ----------
1927
+ global_orient: torch.tensor, optional, shape Bx3
1928
+ If given, ignore the member variable and use it as the global
1929
+ rotation of the body. Useful if someone wishes to predicts this
1930
+ with an external model. (default=None)
1931
+ betas: torch.tensor, optional, shape Bx10
1932
+ If given, ignore the member variable `betas` and use it
1933
+ instead. For example, it can used if shape parameters
1934
+ `betas` are predicted from some external model.
1935
+ (default=None)
1936
+ expression: torch.tensor, optional, shape Bx10
1937
+ If given, ignore the member variable `expression` and use it
1938
+ instead. For example, it can used if expression parameters
1939
+ `expression` are predicted from some external model.
1940
+ jaw_pose: torch.tensor, optional, shape Bx3
1941
+ If given, ignore the member variable `jaw_pose` and
1942
+ use this instead. It should either joint rotations in
1943
+ axis-angle format.
1944
+ jaw_pose: torch.tensor, optional, shape Bx3
1945
+ If given, ignore the member variable `jaw_pose` and
1946
+ use this instead. It should either joint rotations in
1947
+ axis-angle format.
1948
+ transl: torch.tensor, optional, shape Bx3
1949
+ If given, ignore the member variable `transl` and use it
1950
+ instead. For example, it can used if the translation
1951
+ `transl` is predicted from some external model.
1952
+ (default=None)
1953
+ return_verts: bool, optional
1954
+ Return the vertices. (default=True)
1955
+ return_full_pose: bool, optional
1956
+ Returns the full axis-angle pose vector (default=False)
1957
+
1958
+ Returns
1959
+ -------
1960
+ output: ModelOutput
1961
+ A named tuple of type `ModelOutput`
1962
+ '''
1963
+
1964
+ # If no shape and pose parameters are passed along, then use the
1965
+ # ones from the module
1966
+ global_orient = (global_orient if global_orient is not None else
1967
+ self.global_orient)
1968
+ jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
1969
+ neck_pose = neck_pose if neck_pose is not None else self.neck_pose
1970
+
1971
+ leye_pose = leye_pose if leye_pose is not None else self.leye_pose
1972
+ reye_pose = reye_pose if reye_pose is not None else self.reye_pose
1973
+
1974
+ betas = betas if betas is not None else self.betas
1975
+ expression = expression if expression is not None else self.expression
1976
+
1977
+ apply_trans = transl is not None or hasattr(self, 'transl')
1978
+ if transl is None:
1979
+ if hasattr(self, 'transl'):
1980
+ transl = self.transl
1981
+
1982
+ full_pose = torch.cat(
1983
+ [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)
1984
+
1985
+ batch_size = max(betas.shape[0], global_orient.shape[0],
1986
+ jaw_pose.shape[0])
1987
+ # Concatenate the shape and expression coefficients
1988
+ scale = int(batch_size / betas.shape[0])
1989
+ if scale > 1:
1990
+ betas = betas.expand(scale, -1)
1991
+ shape_components = torch.cat([betas, expression], dim=-1)
1992
+ shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
1993
+
1994
+ vertices, joints = lbs(shape_components, full_pose, self.v_template,
1995
+ shapedirs, self.posedirs,
1996
+ self.J_regressor, self.parents,
1997
+ self.lbs_weights, pose2rot=pose2rot,
1998
+ )
1999
+
2000
+ lmk_faces_idx = self.lmk_faces_idx.unsqueeze(
2001
+ dim=0).expand(batch_size, -1).contiguous()
2002
+ lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
2003
+ self.batch_size, 1, 1)
2004
+ if self.use_face_contour:
2005
+ lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
2006
+ vertices, full_pose, self.dynamic_lmk_faces_idx,
2007
+ self.dynamic_lmk_bary_coords,
2008
+ self.neck_kin_chain,
2009
+ pose2rot=True,
2010
+ )
2011
+ dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
2012
+ lmk_faces_idx = torch.cat([lmk_faces_idx,
2013
+ dyn_lmk_faces_idx], 1)
2014
+ lmk_bary_coords = torch.cat(
2015
+ [lmk_bary_coords.expand(batch_size, -1, -1),
2016
+ dyn_lmk_bary_coords], 1)
2017
+
2018
+ landmarks = vertices2landmarks(vertices, self.faces_tensor,
2019
+ lmk_faces_idx,
2020
+ lmk_bary_coords)
2021
+
2022
+ # Add any extra joints that might be needed
2023
+ joints = self.vertex_joint_selector(vertices, joints)
2024
+ # Add the landmarks to the joints
2025
+ joints = torch.cat([joints, landmarks], dim=1)
2026
+
2027
+ # Map the joints to the current dataset
2028
+ if self.joint_mapper is not None:
2029
+ joints = self.joint_mapper(joints=joints, vertices=vertices)
2030
+
2031
+ if apply_trans:
2032
+ joints += transl.unsqueeze(dim=1)
2033
+ vertices += transl.unsqueeze(dim=1)
2034
+
2035
+ output = FLAMEOutput(vertices=vertices if return_verts else None,
2036
+ joints=joints,
2037
+ betas=betas,
2038
+ expression=expression,
2039
+ global_orient=global_orient,
2040
+ neck_pose=neck_pose,
2041
+ jaw_pose=jaw_pose,
2042
+ full_pose=full_pose if return_full_pose else None)
2043
+ return output
2044
+
2045
+
2046
+ class FLAMELayer(FLAME):
2047
+ def __init__(self, *args, **kwargs) -> None:
2048
+ ''' FLAME as a layer model constructor '''
2049
+ super(FLAMELayer, self).__init__(
2050
+ create_betas=False,
2051
+ create_expression=False,
2052
+ create_global_orient=False,
2053
+ create_neck_pose=False,
2054
+ create_jaw_pose=False,
2055
+ create_leye_pose=False,
2056
+ create_reye_pose=False,
2057
+ *args,
2058
+ **kwargs)
2059
+
2060
+ def forward(
2061
+ self,
2062
+ betas: Optional[Tensor] = None,
2063
+ global_orient: Optional[Tensor] = None,
2064
+ neck_pose: Optional[Tensor] = None,
2065
+ transl: Optional[Tensor] = None,
2066
+ expression: Optional[Tensor] = None,
2067
+ jaw_pose: Optional[Tensor] = None,
2068
+ leye_pose: Optional[Tensor] = None,
2069
+ reye_pose: Optional[Tensor] = None,
2070
+ return_verts: bool = True,
2071
+ return_full_pose: bool = False,
2072
+ pose2rot: bool = True,
2073
+ **kwargs
2074
+ ) -> FLAMEOutput:
2075
+ '''
2076
+ Forward pass for the SMPLX model
2077
+
2078
+ Parameters
2079
+ ----------
2080
+ global_orient: torch.tensor, optional, shape Bx3
2081
+ If given, ignore the member variable and use it as the global
2082
+ rotation of the body. Useful if someone wishes to predicts this
2083
+ with an external model. (default=None)
2084
+ betas: torch.tensor, optional, shape Bx10
2085
+ If given, ignore the member variable `betas` and use it
2086
+ instead. For example, it can used if shape parameters
2087
+ `betas` are predicted from some external model.
2088
+ (default=None)
2089
+ expression: torch.tensor, optional, shape Bx10
2090
+ If given, ignore the member variable `expression` and use it
2091
+ instead. For example, it can used if expression parameters
2092
+ `expression` are predicted from some external model.
2093
+ jaw_pose: torch.tensor, optional, shape Bx3
2094
+ If given, ignore the member variable `jaw_pose` and
2095
+ use this instead. It should either joint rotations in
2096
+ axis-angle format.
2097
+ jaw_pose: torch.tensor, optional, shape Bx3
2098
+ If given, ignore the member variable `jaw_pose` and
2099
+ use this instead. It should either joint rotations in
2100
+ axis-angle format.
2101
+ transl: torch.tensor, optional, shape Bx3
2102
+ If given, ignore the member variable `transl` and use it
2103
+ instead. For example, it can used if the translation
2104
+ `transl` is predicted from some external model.
2105
+ (default=None)
2106
+ return_verts: bool, optional
2107
+ Return the vertices. (default=True)
2108
+ return_full_pose: bool, optional
2109
+ Returns the full axis-angle pose vector (default=False)
2110
+
2111
+ Returns
2112
+ -------
2113
+ output: ModelOutput
2114
+ A named tuple of type `ModelOutput`
2115
+ '''
2116
+ device, dtype = self.shapedirs.device, self.shapedirs.dtype
2117
+ if global_orient is None:
2118
+ batch_size = 1
2119
+ global_orient = torch.zeros(3, device=device, dtype=dtype).view(
2120
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
2121
+ else:
2122
+ batch_size = global_orient.shape[0]
2123
+ if neck_pose is None:
2124
+ neck_pose = torch.zeros(3, device=device, dtype=dtype).view(
2125
+ 1, 1, 3).expand(batch_size, 1, -1).contiguous()
2126
+ if jaw_pose is None:
2127
+ jaw_pose = torch.zeros(3, device=device, dtype=dtype).view(
2128
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
2129
+ if leye_pose is None:
2130
+ leye_pose = torch.zeros(3, device=device, dtype=dtype).view(
2131
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
2132
+ if reye_pose is None:
2133
+ reye_pose = torch.zeros(3, device=device, dtype=dtype).view(
2134
+ 1, 1, 3).expand(batch_size, -1, -1).contiguous()
2135
+ if betas is None:
2136
+ betas = torch.zeros([batch_size, self.num_betas],
2137
+ dtype=dtype, device=device)
2138
+ if expression is None:
2139
+ expression = torch.zeros([batch_size, self.num_expression_coeffs],
2140
+ dtype=dtype, device=device)
2141
+ if transl is None:
2142
+ transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
2143
+
2144
+ full_pose = torch.cat(
2145
+ [global_orient, neck_pose, jaw_pose, leye_pose, reye_pose], dim=1)
2146
+
2147
+ shape_components = torch.cat([betas, expression], dim=-1)
2148
+ shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
2149
+
2150
+ vertices, joints = lbs(shape_components, full_pose, self.v_template,
2151
+ shapedirs, self.posedirs,
2152
+ self.J_regressor, self.parents,
2153
+ self.lbs_weights, pose2rot=True,
2154
+ )
2155
+
2156
+ lmk_faces_idx = self.lmk_faces_idx.unsqueeze(
2157
+ dim=0).expand(batch_size, -1).contiguous()
2158
+ lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
2159
+ self.batch_size, 1, 1)
2160
+ if self.use_face_contour:
2161
+ lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
2162
+ vertices, full_pose, self.dynamic_lmk_faces_idx,
2163
+ self.dynamic_lmk_bary_coords,
2164
+ self.neck_kin_chain,
2165
+ pose2rot=False,
2166
+ )
2167
+ dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords
2168
+ lmk_faces_idx = torch.cat([lmk_faces_idx,
2169
+ dyn_lmk_faces_idx], 1)
2170
+ lmk_bary_coords = torch.cat(
2171
+ [lmk_bary_coords.expand(batch_size, -1, -1),
2172
+ dyn_lmk_bary_coords], 1)
2173
+
2174
+ landmarks = vertices2landmarks(vertices, self.faces_tensor,
2175
+ lmk_faces_idx,
2176
+ lmk_bary_coords)
2177
+
2178
+ # Add any extra joints that might be needed
2179
+ joints = self.vertex_joint_selector(vertices, joints)
2180
+ # Add the landmarks to the joints
2181
+ joints = torch.cat([joints, landmarks], dim=1)
2182
+
2183
+ # Map the joints to the current dataset
2184
+ if self.joint_mapper is not None:
2185
+ joints = self.joint_mapper(joints=joints, vertices=vertices)
2186
+
2187
+ joints += transl.unsqueeze(dim=1)
2188
+ vertices += transl.unsqueeze(dim=1)
2189
+
2190
+ output = FLAMEOutput(vertices=vertices if return_verts else None,
2191
+ joints=joints,
2192
+ betas=betas,
2193
+ expression=expression,
2194
+ global_orient=global_orient,
2195
+ neck_pose=neck_pose,
2196
+ jaw_pose=jaw_pose,
2197
+ full_pose=full_pose if return_full_pose else None)
2198
+ return output
2199
+
2200
+
2201
+ def build_layer(
2202
+ model_path: str,
2203
+ model_type: str = 'smpl',
2204
+ **kwargs
2205
+ ) -> Union[SMPLLayer, SMPLHLayer, SMPLXLayer, MANOLayer, FLAMELayer]:
2206
+ ''' Method for creating a model from a path and a model type
2207
+
2208
+ Parameters
2209
+ ----------
2210
+ model_path: str
2211
+ Either the path to the model you wish to load or a folder,
2212
+ where each subfolder contains the differents types, i.e.:
2213
+ model_path:
2214
+ |
2215
+ |-- smpl
2216
+ |-- SMPL_FEMALE
2217
+ |-- SMPL_NEUTRAL
2218
+ |-- SMPL_MALE
2219
+ |-- smplh
2220
+ |-- SMPLH_FEMALE
2221
+ |-- SMPLH_MALE
2222
+ |-- smplx
2223
+ |-- SMPLX_FEMALE
2224
+ |-- SMPLX_NEUTRAL
2225
+ |-- SMPLX_MALE
2226
+ |-- mano
2227
+ |-- MANO RIGHT
2228
+ |-- MANO LEFT
2229
+ |-- flame
2230
+ |-- FLAME_FEMALE
2231
+ |-- FLAME_MALE
2232
+ |-- FLAME_NEUTRAL
2233
+
2234
+ model_type: str, optional
2235
+ When model_path is a folder, then this parameter specifies the
2236
+ type of model to be loaded
2237
+ **kwargs: dict
2238
+ Keyword arguments
2239
+
2240
+ Returns
2241
+ -------
2242
+ body_model: nn.Module
2243
+ The PyTorch module that implements the corresponding body model
2244
+ Raises
2245
+ ------
2246
+ ValueError: In case the model type is not one of SMPL, SMPLH,
2247
+ SMPLX, MANO or FLAME
2248
+ '''
2249
+
2250
+ if osp.isdir(model_path):
2251
+ model_path = os.path.join(model_path, model_type)
2252
+ else:
2253
+ model_type = osp.basename(model_path).split('_')[0].lower()
2254
+
2255
+ if model_type.lower() == 'smpl':
2256
+ return SMPLLayer(model_path, **kwargs)
2257
+ elif model_type.lower() == 'smplh':
2258
+ return SMPLHLayer(model_path, **kwargs)
2259
+ elif model_type.lower() == 'smplx':
2260
+ return SMPLXLayer(model_path, **kwargs)
2261
+ elif 'mano' in model_type.lower():
2262
+ return MANOLayer(model_path, **kwargs)
2263
+ elif 'flame' in model_type.lower():
2264
+ return FLAMELayer(model_path, **kwargs)
2265
+ else:
2266
+ raise ValueError(f'Unknown model type {model_type}, exiting!')
2267
+
2268
+
2269
+ def create(
2270
+ model_path: str,
2271
+ model_type: str = 'smpl',
2272
+ **kwargs
2273
+ ) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]:
2274
+ ''' Method for creating a model from a path and a model type
2275
+
2276
+ Parameters
2277
+ ----------
2278
+ model_path: str
2279
+ Either the path to the model you wish to load or a folder,
2280
+ where each subfolder contains the differents types, i.e.:
2281
+ model_path:
2282
+ |
2283
+ |-- smpl
2284
+ |-- SMPL_FEMALE
2285
+ |-- SMPL_NEUTRAL
2286
+ |-- SMPL_MALE
2287
+ |-- smplh
2288
+ |-- SMPLH_FEMALE
2289
+ |-- SMPLH_MALE
2290
+ |-- smplx
2291
+ |-- SMPLX_FEMALE
2292
+ |-- SMPLX_NEUTRAL
2293
+ |-- SMPLX_MALE
2294
+ |-- mano
2295
+ |-- MANO RIGHT
2296
+ |-- MANO LEFT
2297
+
2298
+ model_type: str, optional
2299
+ When model_path is a folder, then this parameter specifies the
2300
+ type of model to be loaded
2301
+ **kwargs: dict
2302
+ Keyword arguments
2303
+
2304
+ Returns
2305
+ -------
2306
+ body_model: nn.Module
2307
+ The PyTorch module that implements the corresponding body model
2308
+ Raises
2309
+ ------
2310
+ ValueError: In case the model type is not one of SMPL, SMPLH,
2311
+ SMPLX, MANO or FLAME
2312
+ '''
2313
+
2314
+ # If it's a folder, assume
2315
+ if osp.isdir(model_path):
2316
+ model_path = os.path.join(model_path, model_type)
2317
+ else:
2318
+ model_type = osp.basename(model_path).split('_')[0].lower()
2319
+
2320
+ if model_type.lower() == 'smpl':
2321
+ return SMPL(model_path, **kwargs)
2322
+ elif model_type.lower() == 'smplh':
2323
+ return SMPLH(model_path, **kwargs)
2324
+ elif model_type.lower() == 'smplx':
2325
+ return SMPLX(model_path, **kwargs)
2326
+ elif 'mano' in model_type.lower():
2327
+ return MANO(model_path, **kwargs)
2328
+ elif 'flame' in model_type.lower():
2329
+ return FLAME(model_path, **kwargs)
2330
+ else:
2331
+ raise ValueError(f'Unknown model type {model_type}, exiting!')
common/utils/smplx/smplx/joint_names.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ JOINT_NAMES = [
18
+ 'pelvis',
19
+ 'left_hip',
20
+ 'right_hip',
21
+ 'spine1',
22
+ 'left_knee',
23
+ 'right_knee',
24
+ 'spine2',
25
+ 'left_ankle',
26
+ 'right_ankle',
27
+ 'spine3',
28
+ 'left_foot',
29
+ 'right_foot',
30
+ 'neck',
31
+ 'left_collar',
32
+ 'right_collar',
33
+ 'head',
34
+ 'left_shoulder',
35
+ 'right_shoulder',
36
+ 'left_elbow',
37
+ 'right_elbow',
38
+ 'left_wrist',
39
+ 'right_wrist',
40
+ 'jaw',
41
+ 'left_eye_smplhf',
42
+ 'right_eye_smplhf',
43
+ 'left_index1',
44
+ 'left_index2',
45
+ 'left_index3',
46
+ 'left_middle1',
47
+ 'left_middle2',
48
+ 'left_middle3',
49
+ 'left_pinky1',
50
+ 'left_pinky2',
51
+ 'left_pinky3',
52
+ 'left_ring1',
53
+ 'left_ring2',
54
+ 'left_ring3',
55
+ 'left_thumb1',
56
+ 'left_thumb2',
57
+ 'left_thumb3',
58
+ 'right_index1',
59
+ 'right_index2',
60
+ 'right_index3',
61
+ 'right_middle1',
62
+ 'right_middle2',
63
+ 'right_middle3',
64
+ 'right_pinky1',
65
+ 'right_pinky2',
66
+ 'right_pinky3',
67
+ 'right_ring1',
68
+ 'right_ring2',
69
+ 'right_ring3',
70
+ 'right_thumb1',
71
+ 'right_thumb2',
72
+ 'right_thumb3',
73
+ 'nose',
74
+ 'right_eye',
75
+ 'left_eye',
76
+ 'right_ear',
77
+ 'left_ear',
78
+ 'left_big_toe',
79
+ 'left_small_toe',
80
+ 'left_heel',
81
+ 'right_big_toe',
82
+ 'right_small_toe',
83
+ 'right_heel',
84
+ 'left_thumb',
85
+ 'left_index',
86
+ 'left_middle',
87
+ 'left_ring',
88
+ 'left_pinky',
89
+ 'right_thumb',
90
+ 'right_index',
91
+ 'right_middle',
92
+ 'right_ring',
93
+ 'right_pinky',
94
+ 'right_eye_brow1',
95
+ 'right_eye_brow2',
96
+ 'right_eye_brow3',
97
+ 'right_eye_brow4',
98
+ 'right_eye_brow5',
99
+ 'left_eye_brow5',
100
+ 'left_eye_brow4',
101
+ 'left_eye_brow3',
102
+ 'left_eye_brow2',
103
+ 'left_eye_brow1',
104
+ 'nose1',
105
+ 'nose2',
106
+ 'nose3',
107
+ 'nose4',
108
+ 'right_nose_2',
109
+ 'right_nose_1',
110
+ 'nose_middle',
111
+ 'left_nose_1',
112
+ 'left_nose_2',
113
+ 'right_eye1',
114
+ 'right_eye2',
115
+ 'right_eye3',
116
+ 'right_eye4',
117
+ 'right_eye5',
118
+ 'right_eye6',
119
+ 'left_eye4',
120
+ 'left_eye3',
121
+ 'left_eye2',
122
+ 'left_eye1',
123
+ 'left_eye6',
124
+ 'left_eye5',
125
+ 'right_mouth_1',
126
+ 'right_mouth_2',
127
+ 'right_mouth_3',
128
+ 'mouth_top',
129
+ 'left_mouth_3',
130
+ 'left_mouth_2',
131
+ 'left_mouth_1',
132
+ 'left_mouth_5', # 59 in OpenPose output
133
+ 'left_mouth_4', # 58 in OpenPose output
134
+ 'mouth_bottom',
135
+ 'right_mouth_4',
136
+ 'right_mouth_5',
137
+ 'right_lip_1',
138
+ 'right_lip_2',
139
+ 'lip_top',
140
+ 'left_lip_2',
141
+ 'left_lip_1',
142
+ 'left_lip_3',
143
+ 'lip_bottom',
144
+ 'right_lip_3',
145
+ # Face contour
146
+ 'right_contour_1',
147
+ 'right_contour_2',
148
+ 'right_contour_3',
149
+ 'right_contour_4',
150
+ 'right_contour_5',
151
+ 'right_contour_6',
152
+ 'right_contour_7',
153
+ 'right_contour_8',
154
+ 'contour_middle',
155
+ 'left_contour_8',
156
+ 'left_contour_7',
157
+ 'left_contour_6',
158
+ 'left_contour_5',
159
+ 'left_contour_4',
160
+ 'left_contour_3',
161
+ 'left_contour_2',
162
+ 'left_contour_1',
163
+ ]
common/utils/smplx/smplx/lbs.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import print_function
19
+ from __future__ import division
20
+
21
+ from typing import Tuple, List
22
+ import numpy as np
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+
27
+ from .utils import rot_mat_to_euler, Tensor
28
+
29
+
30
+ def find_dynamic_lmk_idx_and_bcoords(
31
+ vertices: Tensor,
32
+ pose: Tensor,
33
+ dynamic_lmk_faces_idx: Tensor,
34
+ dynamic_lmk_b_coords: Tensor,
35
+ neck_kin_chain: List[int],
36
+ pose2rot: bool = True,
37
+ ) -> Tuple[Tensor, Tensor]:
38
+ ''' Compute the faces, barycentric coordinates for the dynamic landmarks
39
+
40
+
41
+ To do so, we first compute the rotation of the neck around the y-axis
42
+ and then use a pre-computed look-up table to find the faces and the
43
+ barycentric coordinates that will be used.
44
+
45
+ Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
46
+ for providing the original TensorFlow implementation and for the LUT.
47
+
48
+ Parameters
49
+ ----------
50
+ vertices: torch.tensor BxVx3, dtype = torch.float32
51
+ The tensor of input vertices
52
+ pose: torch.tensor Bx(Jx3), dtype = torch.float32
53
+ The current pose of the body model
54
+ dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
55
+ The look-up table from neck rotation to faces
56
+ dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
57
+ The look-up table from neck rotation to barycentric coordinates
58
+ neck_kin_chain: list
59
+ A python list that contains the indices of the joints that form the
60
+ kinematic chain of the neck.
61
+ dtype: torch.dtype, optional
62
+
63
+ Returns
64
+ -------
65
+ dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
66
+ A tensor of size BxL that contains the indices of the faces that
67
+ will be used to compute the current dynamic landmarks.
68
+ dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
69
+ A tensor of size BxL that contains the indices of the faces that
70
+ will be used to compute the current dynamic landmarks.
71
+ '''
72
+
73
+ dtype = vertices.dtype
74
+ batch_size = vertices.shape[0]
75
+
76
+ if pose2rot:
77
+ aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
78
+ neck_kin_chain)
79
+ rot_mats = batch_rodrigues(
80
+ aa_pose.view(-1, 3)).view(batch_size, -1, 3, 3)
81
+ else:
82
+ rot_mats = torch.index_select(
83
+ pose.view(batch_size, -1, 3, 3), 1, neck_kin_chain)
84
+
85
+ rel_rot_mat = torch.eye(
86
+ 3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0).repeat(
87
+ batch_size, 1, 1)
88
+ for idx in range(len(neck_kin_chain)):
89
+ rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
90
+
91
+ y_rot_angle = torch.round(
92
+ torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
93
+ max=39)).to(dtype=torch.long)
94
+ neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
95
+ mask = y_rot_angle.lt(-39).to(dtype=torch.long)
96
+ neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
97
+ y_rot_angle = (neg_mask * neg_vals +
98
+ (1 - neg_mask) * y_rot_angle)
99
+
100
+ dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
101
+ 0, y_rot_angle)
102
+ dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
103
+ 0, y_rot_angle)
104
+
105
+ return dyn_lmk_faces_idx, dyn_lmk_b_coords
106
+
107
+
108
+ def vertices2landmarks(
109
+ vertices: Tensor,
110
+ faces: Tensor,
111
+ lmk_faces_idx: Tensor,
112
+ lmk_bary_coords: Tensor
113
+ ) -> Tensor:
114
+ ''' Calculates landmarks by barycentric interpolation
115
+
116
+ Parameters
117
+ ----------
118
+ vertices: torch.tensor BxVx3, dtype = torch.float32
119
+ The tensor of input vertices
120
+ faces: torch.tensor Fx3, dtype = torch.long
121
+ The faces of the mesh
122
+ lmk_faces_idx: torch.tensor L, dtype = torch.long
123
+ The tensor with the indices of the faces used to calculate the
124
+ landmarks.
125
+ lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
126
+ The tensor of barycentric coordinates that are used to interpolate
127
+ the landmarks
128
+
129
+ Returns
130
+ -------
131
+ landmarks: torch.tensor BxLx3, dtype = torch.float32
132
+ The coordinates of the landmarks for each mesh in the batch
133
+ '''
134
+ # Extract the indices of the vertices for each face
135
+ # BxLx3
136
+ batch_size, num_verts = vertices.shape[:2]
137
+ device = vertices.device
138
+
139
+ lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
140
+ batch_size, -1, 3)
141
+
142
+ lmk_faces += torch.arange(
143
+ batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
144
+
145
+ lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
146
+ batch_size, -1, 3, 3)
147
+
148
+ landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
149
+ return landmarks
150
+
151
+
152
+ def lbs(
153
+ betas: Tensor,
154
+ pose: Tensor,
155
+ v_template: Tensor,
156
+ shapedirs: Tensor,
157
+ posedirs: Tensor,
158
+ J_regressor: Tensor,
159
+ parents: Tensor,
160
+ lbs_weights: Tensor,
161
+ pose2rot: bool = True,
162
+ ) -> Tuple[Tensor, Tensor]:
163
+ ''' Performs Linear Blend Skinning with the given shape and pose parameters
164
+
165
+ Parameters
166
+ ----------
167
+ betas : torch.tensor BxNB
168
+ The tensor of shape parameters
169
+ pose : torch.tensor Bx(J + 1) * 3
170
+ The pose parameters in axis-angle format
171
+ v_template torch.tensor BxVx3
172
+ The template mesh that will be deformed
173
+ shapedirs : torch.tensor 1xNB
174
+ The tensor of PCA shape displacements
175
+ posedirs : torch.tensor Px(V * 3)
176
+ The pose PCA coefficients
177
+ J_regressor : torch.tensor JxV
178
+ The regressor array that is used to calculate the joints from
179
+ the position of the vertices
180
+ parents: torch.tensor J
181
+ The array that describes the kinematic tree for the model
182
+ lbs_weights: torch.tensor N x V x (J + 1)
183
+ The linear blend skinning weights that represent how much the
184
+ rotation matrix of each part affects each vertex
185
+ pose2rot: bool, optional
186
+ Flag on whether to convert the input pose tensor to rotation
187
+ matrices. The default value is True. If False, then the pose tensor
188
+ should already contain rotation matrices and have a size of
189
+ Bx(J + 1)x9
190
+ dtype: torch.dtype, optional
191
+
192
+ Returns
193
+ -------
194
+ verts: torch.tensor BxVx3
195
+ The vertices of the mesh after applying the shape and pose
196
+ displacements.
197
+ joints: torch.tensor BxJx3
198
+ The joints of the model
199
+ '''
200
+
201
+ batch_size = max(betas.shape[0], pose.shape[0])
202
+ device, dtype = betas.device, betas.dtype
203
+
204
+ # Add shape contribution
205
+ v_shaped = v_template + blend_shapes(betas, shapedirs)
206
+
207
+ # Get the joints
208
+ # NxJx3 array
209
+ J = vertices2joints(J_regressor, v_shaped)
210
+
211
+ # 3. Add pose blend shapes
212
+ # N x J x 3 x 3
213
+ ident = torch.eye(3, dtype=dtype, device=device)
214
+ if pose2rot:
215
+ rot_mats = batch_rodrigues(pose.view(-1, 3)).view(
216
+ [batch_size, -1, 3, 3])
217
+
218
+ pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
219
+ # (N x P) x (P, V * 3) -> N x V x 3
220
+ pose_offsets = torch.matmul(
221
+ pose_feature, posedirs).view(batch_size, -1, 3)
222
+ else:
223
+ pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
224
+ rot_mats = pose.view(batch_size, -1, 3, 3)
225
+
226
+ pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
227
+ posedirs).view(batch_size, -1, 3)
228
+
229
+ v_posed = pose_offsets + v_shaped
230
+ # 4. Get the global joint location
231
+ J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
232
+
233
+ # 5. Do skinning:
234
+ # W is N x V x (J + 1)
235
+ W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
236
+ # (N x V x (J + 1)) x (N x (J + 1) x 16)
237
+ num_joints = J_regressor.shape[0]
238
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
239
+ .view(batch_size, -1, 4, 4)
240
+
241
+ homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
242
+ dtype=dtype, device=device)
243
+ v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
244
+ v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
245
+
246
+ verts = v_homo[:, :, :3, 0]
247
+
248
+ return verts, J_transformed
249
+
250
+
251
+ def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
252
+ ''' Calculates the 3D joint locations from the vertices
253
+
254
+ Parameters
255
+ ----------
256
+ J_regressor : torch.tensor JxV
257
+ The regressor array that is used to calculate the joints from the
258
+ position of the vertices
259
+ vertices : torch.tensor BxVx3
260
+ The tensor of mesh vertices
261
+
262
+ Returns
263
+ -------
264
+ torch.tensor BxJx3
265
+ The location of the joints
266
+ '''
267
+
268
+ return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
269
+
270
+
271
+ def blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor:
272
+ ''' Calculates the per vertex displacement due to the blend shapes
273
+
274
+
275
+ Parameters
276
+ ----------
277
+ betas : torch.tensor Bx(num_betas)
278
+ Blend shape coefficients
279
+ shape_disps: torch.tensor Vx3x(num_betas)
280
+ Blend shapes
281
+
282
+ Returns
283
+ -------
284
+ torch.tensor BxVx3
285
+ The per-vertex displacement due to shape deformation
286
+ '''
287
+
288
+ # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
289
+ # i.e. Multiply each shape displacement by its corresponding beta and
290
+ # then sum them.
291
+ blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
292
+ return blend_shape
293
+
294
+
295
+ def batch_rodrigues(
296
+ rot_vecs: Tensor,
297
+ epsilon: float = 1e-8,
298
+ ) -> Tensor:
299
+ ''' Calculates the rotation matrices for a batch of rotation vectors
300
+ Parameters
301
+ ----------
302
+ rot_vecs: torch.tensor Nx3
303
+ array of N axis-angle vectors
304
+ Returns
305
+ -------
306
+ R: torch.tensor Nx3x3
307
+ The rotation matrices for the given axis-angle parameters
308
+ '''
309
+
310
+ batch_size = rot_vecs.shape[0]
311
+ device, dtype = rot_vecs.device, rot_vecs.dtype
312
+
313
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
314
+ rot_dir = rot_vecs / angle
315
+
316
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
317
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
318
+
319
+ # Bx1 arrays
320
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
321
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
322
+
323
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
324
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
325
+ .view((batch_size, 3, 3))
326
+
327
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
328
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
329
+ return rot_mat
330
+
331
+
332
+ def transform_mat(R: Tensor, t: Tensor) -> Tensor:
333
+ ''' Creates a batch of transformation matrices
334
+ Args:
335
+ - R: Bx3x3 array of a batch of rotation matrices
336
+ - t: Bx3x1 array of a batch of translation vectors
337
+ Returns:
338
+ - T: Bx4x4 Transformation matrix
339
+ '''
340
+ # No padding left or right, only add an extra row
341
+ return torch.cat([F.pad(R, [0, 0, 0, 1]),
342
+ F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
343
+
344
+
345
+ def batch_rigid_transform(
346
+ rot_mats: Tensor,
347
+ joints: Tensor,
348
+ parents: Tensor,
349
+ dtype=torch.float32
350
+ ) -> Tensor:
351
+ """
352
+ Applies a batch of rigid transformations to the joints
353
+
354
+ Parameters
355
+ ----------
356
+ rot_mats : torch.tensor BxNx3x3
357
+ Tensor of rotation matrices
358
+ joints : torch.tensor BxNx3
359
+ Locations of joints
360
+ parents : torch.tensor BxN
361
+ The kinematic tree of each object
362
+ dtype : torch.dtype, optional:
363
+ The data type of the created tensors, the default is torch.float32
364
+
365
+ Returns
366
+ -------
367
+ posed_joints : torch.tensor BxNx3
368
+ The locations of the joints after applying the pose rotations
369
+ rel_transforms : torch.tensor BxNx4x4
370
+ The relative (with respect to the root joint) rigid transformations
371
+ for all the joints
372
+ """
373
+
374
+ joints = torch.unsqueeze(joints, dim=-1)
375
+
376
+ rel_joints = joints.clone()
377
+ rel_joints[:, 1:] -= joints[:, parents[1:]]
378
+
379
+ transforms_mat = transform_mat(
380
+ rot_mats.reshape(-1, 3, 3),
381
+ rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
382
+
383
+ transform_chain = [transforms_mat[:, 0]]
384
+ for i in range(1, parents.shape[0]):
385
+ # Subtract the joint location at the rest pose
386
+ # No need for rotation, since it's identity when at rest
387
+ curr_res = torch.matmul(transform_chain[parents[i]],
388
+ transforms_mat[:, i])
389
+ transform_chain.append(curr_res)
390
+
391
+ transforms = torch.stack(transform_chain, dim=1)
392
+
393
+ # The last column of the transformations contains the posed joints
394
+ posed_joints = transforms[:, :, :3, 3]
395
+
396
+ # The last column of the transformations contains the posed joints
397
+ posed_joints = transforms[:, :, :3, 3]
398
+
399
+ joints_homogen = F.pad(joints, [0, 0, 0, 1])
400
+
401
+ rel_transforms = transforms - F.pad(
402
+ torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
403
+
404
+ return posed_joints, rel_transforms
common/utils/smplx/smplx/utils.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from typing import NewType, Union, Optional
18
+ from dataclasses import dataclass, asdict, fields
19
+ import numpy as np
20
+ import torch
21
+
22
+ Tensor = NewType('Tensor', torch.Tensor)
23
+ Array = NewType('Array', np.ndarray)
24
+
25
+
26
+ @dataclass
27
+ class ModelOutput:
28
+ vertices: Optional[Tensor] = None
29
+ joints: Optional[Tensor] = None
30
+ full_pose: Optional[Tensor] = None
31
+ global_orient: Optional[Tensor] = None
32
+ transl: Optional[Tensor] = None
33
+
34
+ def __getitem__(self, key):
35
+ return getattr(self, key)
36
+
37
+ def get(self, key, default=None):
38
+ return getattr(self, key, default)
39
+
40
+ def __iter__(self):
41
+ return self.keys()
42
+
43
+ def keys(self):
44
+ keys = [t.name for t in fields(self)]
45
+ return iter(keys)
46
+
47
+ def values(self):
48
+ values = [getattr(self, t.name) for t in fields(self)]
49
+ return iter(values)
50
+
51
+ def items(self):
52
+ data = [(t.name, getattr(self, t.name)) for t in fields(self)]
53
+ return iter(data)
54
+
55
+
56
+ @dataclass
57
+ class SMPLOutput(ModelOutput):
58
+ betas: Optional[Tensor] = None
59
+ body_pose: Optional[Tensor] = None
60
+
61
+
62
+ @dataclass
63
+ class SMPLHOutput(SMPLOutput):
64
+ left_hand_pose: Optional[Tensor] = None
65
+ right_hand_pose: Optional[Tensor] = None
66
+ transl: Optional[Tensor] = None
67
+
68
+
69
+ @dataclass
70
+ class SMPLXOutput(SMPLHOutput):
71
+ expression: Optional[Tensor] = None
72
+ jaw_pose: Optional[Tensor] = None
73
+
74
+
75
+ @dataclass
76
+ class MANOOutput(ModelOutput):
77
+ betas: Optional[Tensor] = None
78
+ hand_pose: Optional[Tensor] = None
79
+
80
+
81
+ @dataclass
82
+ class FLAMEOutput(ModelOutput):
83
+ betas: Optional[Tensor] = None
84
+ expression: Optional[Tensor] = None
85
+ jaw_pose: Optional[Tensor] = None
86
+ neck_pose: Optional[Tensor] = None
87
+
88
+
89
+ def find_joint_kin_chain(joint_id, kinematic_tree):
90
+ kin_chain = []
91
+ curr_idx = joint_id
92
+ while curr_idx != -1:
93
+ kin_chain.append(curr_idx)
94
+ curr_idx = kinematic_tree[curr_idx]
95
+ return kin_chain
96
+
97
+
98
+ def to_tensor(
99
+ array: Union[Array, Tensor], dtype=torch.float32
100
+ ) -> Tensor:
101
+ if torch.is_tensor(array):
102
+ return array
103
+ else:
104
+ return torch.tensor(array, dtype=dtype)
105
+
106
+
107
+ class Struct(object):
108
+ def __init__(self, **kwargs):
109
+ for key, val in kwargs.items():
110
+ setattr(self, key, val)
111
+
112
+
113
+ def to_np(array, dtype=np.float32):
114
+ if 'scipy.sparse' in str(type(array)):
115
+ array = array.todense()
116
+ return np.array(array, dtype=dtype)
117
+
118
+
119
+ def rot_mat_to_euler(rot_mats):
120
+ # Calculates rotation matrix to euler angles
121
+ # Careful for extreme cases of eular angles like [0.0, pi, 0.0]
122
+
123
+ sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
124
+ rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
125
+ return torch.atan2(-rot_mats[:, 2, 0], sy)
common/utils/smplx/smplx/vertex_ids.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from __future__ import print_function
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+
21
+ # Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to
22
+ # MSCOCO and OpenPose joints
23
+ vertex_ids = {
24
+ 'smplh': {
25
+ 'nose': 332,
26
+ 'reye': 6260,
27
+ 'leye': 2800,
28
+ 'rear': 4071,
29
+ 'lear': 583,
30
+ 'rthumb': 6191,
31
+ 'rindex': 5782,
32
+ 'rmiddle': 5905,
33
+ 'rring': 6016,
34
+ 'rpinky': 6133,
35
+ 'lthumb': 2746,
36
+ 'lindex': 2319,
37
+ 'lmiddle': 2445,
38
+ 'lring': 2556,
39
+ 'lpinky': 2673,
40
+ 'LBigToe': 3216,
41
+ 'LSmallToe': 3226,
42
+ 'LHeel': 3387,
43
+ 'RBigToe': 6617,
44
+ 'RSmallToe': 6624,
45
+ 'RHeel': 6787
46
+ },
47
+ 'smplx': {
48
+ 'nose': 9120,
49
+ 'reye': 9929,
50
+ 'leye': 9448,
51
+ 'rear': 616,
52
+ 'lear': 6,
53
+ 'rthumb': 8079,
54
+ 'rindex': 7669,
55
+ 'rmiddle': 7794,
56
+ 'rring': 7905,
57
+ 'rpinky': 8022,
58
+ 'lthumb': 5361,
59
+ 'lindex': 4933,
60
+ 'lmiddle': 5058,
61
+ 'lring': 5169,
62
+ 'lpinky': 5286,
63
+ 'LBigToe': 5770,
64
+ 'LSmallToe': 5780,
65
+ 'LHeel': 8846,
66
+ 'RBigToe': 8463,
67
+ 'RSmallToe': 8474,
68
+ 'RHeel': 8635
69
+ },
70
+ 'mano': {
71
+ 'thumb': 744,
72
+ 'index': 320,
73
+ 'middle': 443,
74
+ 'ring': 554,
75
+ 'pinky': 671,
76
+ }
77
+ }
common/utils/smplx/smplx/vertex_joint_selector.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import print_function
19
+ from __future__ import division
20
+
21
+ import numpy as np
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
+ from .utils import to_tensor
27
+
28
+
29
+ class VertexJointSelector(nn.Module):
30
+
31
+ def __init__(self, vertex_ids=None,
32
+ use_hands=True,
33
+ use_feet_keypoints=True, **kwargs):
34
+ super(VertexJointSelector, self).__init__()
35
+
36
+ extra_joints_idxs = []
37
+
38
+ face_keyp_idxs = np.array([
39
+ vertex_ids['nose'],
40
+ vertex_ids['reye'],
41
+ vertex_ids['leye'],
42
+ vertex_ids['rear'],
43
+ vertex_ids['lear']], dtype=np.int64)
44
+
45
+ extra_joints_idxs = np.concatenate([extra_joints_idxs,
46
+ face_keyp_idxs])
47
+
48
+ if use_feet_keypoints:
49
+ feet_keyp_idxs = np.array([vertex_ids['LBigToe'],
50
+ vertex_ids['LSmallToe'],
51
+ vertex_ids['LHeel'],
52
+ vertex_ids['RBigToe'],
53
+ vertex_ids['RSmallToe'],
54
+ vertex_ids['RHeel']], dtype=np.int32)
55
+
56
+ extra_joints_idxs = np.concatenate(
57
+ [extra_joints_idxs, feet_keyp_idxs])
58
+
59
+ if use_hands:
60
+ self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky']
61
+
62
+ tips_idxs = []
63
+ for hand_id in ['l', 'r']:
64
+ for tip_name in self.tip_names:
65
+ tips_idxs.append(vertex_ids[hand_id + tip_name])
66
+
67
+ extra_joints_idxs = np.concatenate(
68
+ [extra_joints_idxs, tips_idxs])
69
+
70
+ self.register_buffer('extra_joints_idxs',
71
+ to_tensor(extra_joints_idxs, dtype=torch.long))
72
+
73
+ def forward(self, vertices, joints):
74
+ extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs)
75
+ joints = torch.cat([joints, extra_joints], dim=1)
76
+
77
+ return joints
common/utils/smplx/tools/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Removing Chumpy objects
2
+
3
+ In a Python 2 virtual environment with [Chumpy](https://github.com/mattloper/chumpy) installed run the following to remove any Chumpy objects from the model data:
4
+
5
+ ```bash
6
+ python tools/clean_ch.py --input-models path-to-models/*.pkl --output-folder output-folder
7
+ ```
8
+
9
+ ## Merging SMPL-H and MANO parameters
10
+
11
+ In order to use the given PyTorch SMPL-H module we first need to merge the SMPL-H and MANO parameters in a single file. After agreeing to the license and downloading the models, run the following command:
12
+
13
+ ```bash
14
+ python tools/merge_smplh_mano.py --smplh-fn SMPLH_FOLDER/SMPLH_GENDER.pkl \
15
+ --mano-left-fn MANO_FOLDER/MANO_LEFT.pkl \
16
+ --mano-right-fn MANO_FOLDER/MANO_RIGHT.pkl \
17
+ --output-folder OUTPUT_FOLDER
18
+ ```
19
+
20
+ where SMPLH_FOLDER is the folder with the SMPL-H files and MANO_FOLDER the one for the MANO files.
common/utils/smplx/tools/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ import clean_ch
19
+ import merge_smplh_mano
common/utils/smplx/tools/clean_ch.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ from __future__ import print_function
19
+ from __future__ import absolute_import
20
+ from __future__ import division
21
+
22
+ import argparse
23
+ import os
24
+ import os.path as osp
25
+
26
+ import pickle
27
+
28
+ from tqdm import tqdm
29
+ import numpy as np
30
+
31
+
32
+ def clean_fn(fn, output_folder='output'):
33
+ with open(fn, 'rb') as body_file:
34
+ body_data = pickle.load(body_file)
35
+
36
+ output_dict = {}
37
+ for key, data in body_data.iteritems():
38
+ if 'chumpy' in str(type(data)):
39
+ output_dict[key] = np.array(data)
40
+ else:
41
+ output_dict[key] = data
42
+
43
+ out_fn = osp.split(fn)[1]
44
+
45
+ out_path = osp.join(output_folder, out_fn)
46
+ with open(out_path, 'wb') as out_file:
47
+ pickle.dump(output_dict, out_file)
48
+
49
+
50
+ if __name__ == '__main__':
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument('--input-models', dest='input_models', nargs='+',
53
+ required=True, type=str,
54
+ help='The path to the model that will be processed')
55
+ parser.add_argument('--output-folder', dest='output_folder',
56
+ required=True, type=str,
57
+ help='The path to the output folder')
58
+
59
+ args = parser.parse_args()
60
+
61
+ input_models = args.input_models
62
+ output_folder = args.output_folder
63
+ if not osp.exists(output_folder):
64
+ print('Creating directory: {}'.format(output_folder))
65
+ os.makedirs(output_folder)
66
+
67
+ for input_model in input_models:
68
+ clean_fn(input_model, output_folder=output_folder)
common/utils/smplx/tools/merge_smplh_mano.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ from __future__ import print_function
19
+
20
+ import os
21
+ import os.path as osp
22
+ import pickle
23
+
24
+ import argparse
25
+
26
+ import numpy as np
27
+
28
+
29
+ def merge_models(smplh_fn, mano_left_fn, mano_right_fn,
30
+ output_folder='output'):
31
+
32
+ with open(smplh_fn, 'rb') as body_file:
33
+ body_data = pickle.load(body_file)
34
+
35
+ with open(mano_left_fn, 'rb') as lhand_file:
36
+ lhand_data = pickle.load(lhand_file)
37
+
38
+ with open(mano_right_fn, 'rb') as rhand_file:
39
+ rhand_data = pickle.load(rhand_file)
40
+
41
+ out_fn = osp.split(smplh_fn)[1]
42
+
43
+ output_data = body_data.copy()
44
+ output_data['hands_componentsl'] = lhand_data['hands_components']
45
+ output_data['hands_componentsr'] = rhand_data['hands_components']
46
+
47
+ output_data['hands_coeffsl'] = lhand_data['hands_coeffs']
48
+ output_data['hands_coeffsr'] = rhand_data['hands_coeffs']
49
+
50
+ output_data['hands_meanl'] = lhand_data['hands_mean']
51
+ output_data['hands_meanr'] = rhand_data['hands_mean']
52
+
53
+ for key, data in output_data.iteritems():
54
+ if 'chumpy' in str(type(data)):
55
+ output_data[key] = np.array(data)
56
+ else:
57
+ output_data[key] = data
58
+
59
+ out_path = osp.join(output_folder, out_fn)
60
+ print(out_path)
61
+ print('Saving to {}'.format(out_path))
62
+ with open(out_path, 'wb') as output_file:
63
+ pickle.dump(output_data, output_file)
64
+
65
+
66
+ if __name__ == '__main__':
67
+ parser = argparse.ArgumentParser()
68
+ parser.add_argument('--smplh-fn', dest='smplh_fn', required=True,
69
+ type=str, help='The path to the SMPLH model')
70
+ parser.add_argument('--mano-left-fn', dest='mano_left_fn', required=True,
71
+ type=str, help='The path to the left hand MANO model')
72
+ parser.add_argument('--mano-right-fn', dest='mano_right_fn', required=True,
73
+ type=str, help='The path to the right hand MANO model')
74
+ parser.add_argument('--output-folder', dest='output_folder',
75
+ required=True, type=str,
76
+ help='The path to the output folder')
77
+
78
+ args = parser.parse_args()
79
+
80
+ smplh_fn = args.smplh_fn
81
+ mano_left_fn = args.mano_left_fn
82
+ mano_right_fn = args.mano_right_fn
83
+ output_folder = args.output_folder
84
+
85
+ if not osp.exists(output_folder):
86
+ print('Creating directory: {}'.format(output_folder))
87
+ os.makedirs(output_folder)
88
+
89
+ merge_models(smplh_fn, mano_left_fn, mano_right_fn, output_folder)
common/utils/transforms.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import scipy
4
+ from config import cfg
5
+ from torch.nn import functional as F
6
+ import torchgeometry as tgm
7
+
8
+
9
+ def cam2pixel(cam_coord, f, c):
10
+ x = cam_coord[:, 0] / cam_coord[:, 2] * f[0] + c[0]
11
+ y = cam_coord[:, 1] / cam_coord[:, 2] * f[1] + c[1]
12
+ z = cam_coord[:, 2]
13
+ return np.stack((x, y, z), 1)
14
+
15
+
16
+ def pixel2cam(pixel_coord, f, c):
17
+ x = (pixel_coord[:, 0] - c[0]) / f[0] * pixel_coord[:, 2]
18
+ y = (pixel_coord[:, 1] - c[1]) / f[1] * pixel_coord[:, 2]
19
+ z = pixel_coord[:, 2]
20
+ return np.stack((x, y, z), 1)
21
+
22
+
23
+ def world2cam(world_coord, R, t):
24
+ cam_coord = np.dot(R, world_coord.transpose(1, 0)).transpose(1, 0) + t.reshape(1, 3)
25
+ return cam_coord
26
+
27
+
28
+ def cam2world(cam_coord, R, t):
29
+ world_coord = np.dot(np.linalg.inv(R), (cam_coord - t.reshape(1, 3)).transpose(1, 0)).transpose(1, 0)
30
+ return world_coord
31
+
32
+
33
+ def rigid_transform_3D(A, B):
34
+ n, dim = A.shape
35
+ centroid_A = np.mean(A, axis=0)
36
+ centroid_B = np.mean(B, axis=0)
37
+ H = np.dot(np.transpose(A - centroid_A), B - centroid_B) / n
38
+ U, s, V = np.linalg.svd(H)
39
+ R = np.dot(np.transpose(V), np.transpose(U))
40
+ if np.linalg.det(R) < 0:
41
+ s[-1] = -s[-1]
42
+ V[2] = -V[2]
43
+ R = np.dot(np.transpose(V), np.transpose(U))
44
+
45
+ varP = np.var(A, axis=0).sum()
46
+ c = 1 / varP * np.sum(s)
47
+
48
+ t = -np.dot(c * R, np.transpose(centroid_A)) + np.transpose(centroid_B)
49
+ return c, R, t
50
+
51
+
52
+ def rigid_align(A, B):
53
+ c, R, t = rigid_transform_3D(A, B)
54
+ A2 = np.transpose(np.dot(c * R, np.transpose(A))) + t
55
+ return A2
56
+
57
+
58
+ def transform_joint_to_other_db(src_joint, src_name, dst_name):
59
+ src_joint_num = len(src_name)
60
+ dst_joint_num = len(dst_name)
61
+
62
+ new_joint = np.zeros(((dst_joint_num,) + src_joint.shape[1:]), dtype=np.float32)
63
+ for src_idx in range(len(src_name)):
64
+ name = src_name[src_idx]
65
+ if name in dst_name:
66
+ dst_idx = dst_name.index(name)
67
+ new_joint[dst_idx] = src_joint[src_idx]
68
+
69
+ return new_joint
70
+
71
+
72
+ def rot6d_to_axis_angle(x):
73
+ batch_size = x.shape[0]
74
+
75
+ x = x.view(-1, 3, 2)
76
+ a1 = x[:, :, 0]
77
+ a2 = x[:, :, 1]
78
+ b1 = F.normalize(a1)
79
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
80
+ b3 = torch.cross(b1, b2)
81
+ rot_mat = torch.stack((b1, b2, b3), dim=-1) # 3x3 rotation matrix
82
+
83
+ rot_mat = torch.cat([rot_mat, torch.zeros((batch_size, 3, 1)).to(cfg.device).float()], 2) # 3x4 rotation matrix
84
+ axis_angle = tgm.rotation_matrix_to_angle_axis(rot_mat).reshape(-1, 3) # axis-angle
85
+ axis_angle[torch.isnan(axis_angle)] = 0.0
86
+ return axis_angle
87
+
88
+
89
+ def sample_joint_features(img_feat, joint_xy):
90
+ height, width = img_feat.shape[2:]
91
+ x = joint_xy[:, :, 0] / (width - 1) * 2 - 1
92
+ y = joint_xy[:, :, 1] / (height - 1) * 2 - 1
93
+ grid = torch.stack((x, y), 2)[:, :, None, :]
94
+ img_feat = F.grid_sample(img_feat, grid, align_corners=True)[:, :, :, 0] # batch_size, channel_dim, joint_num
95
+ img_feat = img_feat.permute(0, 2, 1).contiguous() # batch_size, joint_num, channel_dim
96
+ return img_feat
97
+
98
+
99
+ def soft_argmax_2d(heatmap2d):
100
+ batch_size = heatmap2d.shape[0]
101
+ height, width = heatmap2d.shape[2:]
102
+ heatmap2d = heatmap2d.reshape((batch_size, -1, height * width))
103
+ heatmap2d = F.softmax(heatmap2d, 2)
104
+ heatmap2d = heatmap2d.reshape((batch_size, -1, height, width))
105
+
106
+ accu_x = heatmap2d.sum(dim=(2))
107
+ accu_y = heatmap2d.sum(dim=(3))
108
+
109
+ accu_x = accu_x * torch.arange(width).float().to(cfg.device)[None, None, :]
110
+ accu_y = accu_y * torch.arange(height).float().to(cfg.device)[None, None, :]
111
+
112
+ accu_x = accu_x.sum(dim=2, keepdim=True)
113
+ accu_y = accu_y.sum(dim=2, keepdim=True)
114
+
115
+ coord_out = torch.cat((accu_x, accu_y), dim=2)
116
+ return coord_out
117
+
118
+
119
+ def soft_argmax_3d(heatmap3d):
120
+ batch_size = heatmap3d.shape[0]
121
+ depth, height, width = heatmap3d.shape[2:]
122
+ heatmap3d = heatmap3d.reshape((batch_size, -1, depth * height * width))
123
+ heatmap3d = F.softmax(heatmap3d, 2)
124
+ heatmap3d = heatmap3d.reshape((batch_size, -1, depth, height, width))
125
+
126
+ accu_x = heatmap3d.sum(dim=(2, 3))
127
+ accu_y = heatmap3d.sum(dim=(2, 4))
128
+ accu_z = heatmap3d.sum(dim=(3, 4))
129
+
130
+ accu_x = accu_x * torch.arange(width).float().to(cfg.device)[None, None, :]
131
+ accu_y = accu_y * torch.arange(height).float().to(cfg.device)[None, None, :]
132
+ accu_z = accu_z * torch.arange(depth).float().to(cfg.device)[None, None, :]
133
+
134
+ accu_x = accu_x.sum(dim=2, keepdim=True)
135
+ accu_y = accu_y.sum(dim=2, keepdim=True)
136
+ accu_z = accu_z.sum(dim=2, keepdim=True)
137
+
138
+ coord_out = torch.cat((accu_x, accu_y, accu_z), dim=2)
139
+ return coord_out
140
+
141
+
142
+ def restore_bbox(bbox_center, bbox_size, aspect_ratio, extension_ratio):
143
+ bbox = bbox_center.view(-1, 1, 2) + torch.cat((-bbox_size.view(-1, 1, 2) / 2., bbox_size.view(-1, 1, 2) / 2.),
144
+ 1) # xyxy in (cfg.output_hm_shape[2], cfg.output_hm_shape[1]) space
145
+ bbox[:, :, 0] = bbox[:, :, 0] / cfg.output_hm_shape[2] * cfg.input_body_shape[1]
146
+ bbox[:, :, 1] = bbox[:, :, 1] / cfg.output_hm_shape[1] * cfg.input_body_shape[0]
147
+ bbox = bbox.view(-1, 4)
148
+
149
+ # xyxy -> xywh
150
+ bbox[:, 2] = bbox[:, 2] - bbox[:, 0]
151
+ bbox[:, 3] = bbox[:, 3] - bbox[:, 1]
152
+
153
+ # aspect ratio preserving bbox
154
+ w = bbox[:, 2]
155
+ h = bbox[:, 3]
156
+ c_x = bbox[:, 0] + w / 2.
157
+ c_y = bbox[:, 1] + h / 2.
158
+
159
+ mask1 = w > (aspect_ratio * h)
160
+ mask2 = w < (aspect_ratio * h)
161
+ h[mask1] = w[mask1] / aspect_ratio
162
+ w[mask2] = h[mask2] * aspect_ratio
163
+
164
+ bbox[:, 2] = w * extension_ratio
165
+ bbox[:, 3] = h * extension_ratio
166
+ bbox[:, 0] = c_x - bbox[:, 2] / 2.
167
+ bbox[:, 1] = c_y - bbox[:, 3] / 2.
168
+
169
+ # xywh -> xyxy
170
+ bbox[:, 2] = bbox[:, 2] + bbox[:, 0]
171
+ bbox[:, 3] = bbox[:, 3] + bbox[:, 1]
172
+ return bbox
common/utils/vis.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ from mpl_toolkits.mplot3d import Axes3D
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib as mpl
7
+ import os
8
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
9
+ import pyrender
10
+ import trimesh
11
+ from config import cfg
12
+
13
+ def vis_keypoints_with_skeleton(img, kps, kps_lines, kp_thresh=0.4, alpha=1):
14
+ # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
15
+ cmap = plt.get_cmap('rainbow')
16
+ colors = [cmap(i) for i in np.linspace(0, 1, len(kps_lines) + 2)]
17
+ colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]
18
+
19
+ # Perform the drawing on a copy of the image, to allow for blending.
20
+ kp_mask = np.copy(img)
21
+
22
+ # Draw the keypoints.
23
+ for l in range(len(kps_lines)):
24
+ i1 = kps_lines[l][0]
25
+ i2 = kps_lines[l][1]
26
+ p1 = kps[0, i1].astype(np.int32), kps[1, i1].astype(np.int32)
27
+ p2 = kps[0, i2].astype(np.int32), kps[1, i2].astype(np.int32)
28
+ if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
29
+ cv2.line(
30
+ kp_mask, p1, p2,
31
+ color=colors[l], thickness=2, lineType=cv2.LINE_AA)
32
+ if kps[2, i1] > kp_thresh:
33
+ cv2.circle(
34
+ kp_mask, p1,
35
+ radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
36
+ if kps[2, i2] > kp_thresh:
37
+ cv2.circle(
38
+ kp_mask, p2,
39
+ radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
40
+
41
+ # Blend the keypoints.
42
+ return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
43
+
44
+ def vis_keypoints(img, kps, alpha=1, radius=3, color=None):
45
+ # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
46
+ cmap = plt.get_cmap('rainbow')
47
+ if color is None:
48
+ colors = [cmap(i) for i in np.linspace(0, 1, len(kps) + 2)]
49
+ colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]
50
+
51
+ # Perform the drawing on a copy of the image, to allow for blending.
52
+ kp_mask = np.copy(img)
53
+
54
+ # Draw the keypoints.
55
+ for i in range(len(kps)):
56
+ p = kps[i][0].astype(np.int32), kps[i][1].astype(np.int32)
57
+ if color is None:
58
+ cv2.circle(kp_mask, p, radius=radius, color=colors[i], thickness=-1, lineType=cv2.LINE_AA)
59
+ else:
60
+ cv2.circle(kp_mask, p, radius=radius, color=color, thickness=-1, lineType=cv2.LINE_AA)
61
+
62
+ # Blend the keypoints.
63
+ return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
64
+
65
+ def vis_mesh(img, mesh_vertex, alpha=0.5):
66
+ # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
67
+ cmap = plt.get_cmap('rainbow')
68
+ colors = [cmap(i) for i in np.linspace(0, 1, len(mesh_vertex))]
69
+ colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]
70
+
71
+ # Perform the drawing on a copy of the image, to allow for blending.
72
+ mask = np.copy(img)
73
+
74
+ # Draw the mesh
75
+ for i in range(len(mesh_vertex)):
76
+ p = mesh_vertex[i][0].astype(np.int32), mesh_vertex[i][1].astype(np.int32)
77
+ cv2.circle(mask, p, radius=1, color=colors[i], thickness=-1, lineType=cv2.LINE_AA)
78
+
79
+ # Blend the keypoints.
80
+ return cv2.addWeighted(img, 1.0 - alpha, mask, alpha, 0)
81
+
82
+ def vis_3d_skeleton(kpt_3d, kpt_3d_vis, kps_lines, filename=None):
83
+
84
+ fig = plt.figure()
85
+ ax = fig.add_subplot(111, projection='3d')
86
+
87
+ # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
88
+ cmap = plt.get_cmap('rainbow')
89
+ colors = [cmap(i) for i in np.linspace(0, 1, len(kps_lines) + 2)]
90
+ colors = [np.array((c[2], c[1], c[0])) for c in colors]
91
+
92
+ for l in range(len(kps_lines)):
93
+ i1 = kps_lines[l][0]
94
+ i2 = kps_lines[l][1]
95
+ x = np.array([kpt_3d[i1,0], kpt_3d[i2,0]])
96
+ y = np.array([kpt_3d[i1,1], kpt_3d[i2,1]])
97
+ z = np.array([kpt_3d[i1,2], kpt_3d[i2,2]])
98
+
99
+ if kpt_3d_vis[i1,0] > 0 and kpt_3d_vis[i2,0] > 0:
100
+ ax.plot(x, z, -y, c=colors[l], linewidth=2)
101
+ if kpt_3d_vis[i1,0] > 0:
102
+ ax.scatter(kpt_3d[i1,0], kpt_3d[i1,2], -kpt_3d[i1,1], c=colors[l], marker='o')
103
+ if kpt_3d_vis[i2,0] > 0:
104
+ ax.scatter(kpt_3d[i2,0], kpt_3d[i2,2], -kpt_3d[i2,1], c=colors[l], marker='o')
105
+
106
+ x_r = np.array([0, cfg.input_shape[1]], dtype=np.float32)
107
+ y_r = np.array([0, cfg.input_shape[0]], dtype=np.float32)
108
+ z_r = np.array([0, 1], dtype=np.float32)
109
+
110
+ if filename is None:
111
+ ax.set_title('3D vis')
112
+ else:
113
+ ax.set_title(filename)
114
+
115
+ ax.set_xlabel('X Label')
116
+ ax.set_ylabel('Z Label')
117
+ ax.set_zlabel('Y Label')
118
+ ax.legend()
119
+
120
+ plt.show()
121
+ cv2.waitKey(0)
122
+
123
+ def save_obj(v, f, file_name='output.obj'):
124
+ obj_file = open(file_name, 'w')
125
+ for i in range(len(v)):
126
+ obj_file.write('v ' + str(v[i][0]) + ' ' + str(v[i][1]) + ' ' + str(v[i][2]) + '\n')
127
+ for i in range(len(f)):
128
+ obj_file.write('f ' + str(f[i][0]+1) + '/' + str(f[i][0]+1) + ' ' + str(f[i][1]+1) + '/' + str(f[i][1]+1) + ' ' + str(f[i][2]+1) + '/' + str(f[i][2]+1) + '\n')
129
+ obj_file.close()
130
+
131
+
132
+ def perspective_projection(vertices, cam_param):
133
+ # vertices: [N, 3]
134
+ # cam_param: [3]
135
+ fx, fy= cam_param['focal']
136
+ cx, cy = cam_param['princpt']
137
+ vertices[:, 0] = vertices[:, 0] * fx / vertices[:, 2] + cx
138
+ vertices[:, 1] = vertices[:, 1] * fy / vertices[:, 2] + cy
139
+ return vertices
140
+
141
+
142
+ def render_mesh(img, mesh, face, cam_param, mesh_as_vertices=False):
143
+ if mesh_as_vertices:
144
+ # to run on cluster where headless pyrender is not supported for A100/V100
145
+ vertices_2d = perspective_projection(mesh, cam_param)
146
+ img = vis_keypoints(img, vertices_2d, alpha=0.8, radius=2, color=(0, 0, 255))
147
+ else:
148
+ # mesh
149
+ mesh = trimesh.Trimesh(mesh, face)
150
+ rot = trimesh.transformations.rotation_matrix(
151
+ np.radians(180), [1, 0, 0])
152
+ mesh.apply_transform(rot)
153
+ material = pyrender.MetallicRoughnessMaterial(metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=(1.0, 1.0, 0.9, 1.0))
154
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=False)
155
+ scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
156
+ scene.add(mesh, 'mesh')
157
+
158
+ focal, princpt = cam_param['focal'], cam_param['princpt']
159
+ camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
160
+ scene.add(camera)
161
+
162
+ # renderer
163
+ renderer = pyrender.OffscreenRenderer(viewport_width=img.shape[1], viewport_height=img.shape[0], point_size=1.0)
164
+
165
+ # light
166
+ light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
167
+ light_pose = np.eye(4)
168
+ light_pose[:3, 3] = np.array([0, -1, 1])
169
+ scene.add(light, pose=light_pose)
170
+ light_pose[:3, 3] = np.array([0, 1, 1])
171
+ scene.add(light, pose=light_pose)
172
+ light_pose[:3, 3] = np.array([1, 1, 2])
173
+ scene.add(light, pose=light_pose)
174
+
175
+ # render
176
+ rgb, depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
177
+ rgb = rgb[:,:,:3].astype(np.float32)
178
+ valid_mask = (depth > 0)[:,:,None]
179
+
180
+ # save to image
181
+ img = rgb * valid_mask + img * (1-valid_mask)
182
+
183
+ return img
main/SMPLer_X.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from nets.smpler_x import PositionNet, HandRotationNet, FaceRegressor, BoxNet, HandRoI, BodyRotationNet
5
+ from nets.loss import CoordLoss, ParamLoss, CELoss
6
+ from utils.human_models import smpl_x
7
+ from utils.transforms import rot6d_to_axis_angle, restore_bbox
8
+ from config import cfg
9
+ import math
10
+ import copy
11
+ from mmpose.models import build_posenet
12
+ from mmcv import Config
13
+
14
+ class Model(nn.Module):
15
+ def __init__(self, encoder, body_position_net, body_rotation_net, box_net, hand_position_net, hand_roi_net,
16
+ hand_rotation_net, face_regressor):
17
+ super(Model, self).__init__()
18
+
19
+ # body
20
+ self.encoder = encoder
21
+ self.body_position_net = body_position_net
22
+ self.body_regressor = body_rotation_net
23
+ self.box_net = box_net
24
+
25
+ # hand
26
+ self.hand_roi_net = hand_roi_net
27
+ self.hand_position_net = hand_position_net
28
+ self.hand_regressor = hand_rotation_net
29
+
30
+ # face
31
+ self.face_regressor = face_regressor
32
+
33
+ self.smplx_layer = copy.deepcopy(smpl_x.layer['neutral']).to(cfg.device)
34
+ self.coord_loss = CoordLoss()
35
+ self.param_loss = ParamLoss()
36
+ self.ce_loss = CELoss()
37
+
38
+ self.body_num_joints = len(smpl_x.pos_joint_part['body'])
39
+ self.hand_joint_num = len(smpl_x.pos_joint_part['rhand'])
40
+
41
+ self.neck = [self.box_net, self.hand_roi_net]
42
+
43
+ self.head = [self.body_position_net, self.body_regressor,
44
+ self.hand_position_net, self.hand_regressor,
45
+ self.face_regressor]
46
+
47
+ self.trainable_modules = [self.encoder, self.body_position_net, self.body_regressor,
48
+ self.box_net, self.hand_position_net,
49
+ self.hand_roi_net, self.hand_regressor, self.face_regressor]
50
+ self.special_trainable_modules = []
51
+
52
+ # backbone:
53
+ param_bb = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)
54
+ # neck
55
+ param_neck = 0
56
+ for module in self.neck:
57
+ param_neck += sum(p.numel() for p in module.parameters() if p.requires_grad)
58
+ # head
59
+ param_head = 0
60
+ for module in self.head:
61
+ param_head += sum(p.numel() for p in module.parameters() if p.requires_grad)
62
+
63
+ param_net = param_bb + param_neck + param_head
64
+
65
+ # print('#parameters:')
66
+ # print(f'{param_bb}, {param_neck}, {param_head}, {param_net}')
67
+
68
+ def get_camera_trans(self, cam_param):
69
+ # camera translation
70
+ t_xy = cam_param[:, :2]
71
+ gamma = torch.sigmoid(cam_param[:, 2]) # apply sigmoid to make it positive
72
+ k_value = torch.FloatTensor([math.sqrt(cfg.focal[0] * cfg.focal[1] * cfg.camera_3d_size * cfg.camera_3d_size / (
73
+ cfg.input_body_shape[0] * cfg.input_body_shape[1]))]).to(cfg.device).view(-1)
74
+ t_z = k_value * gamma
75
+ cam_trans = torch.cat((t_xy, t_z[:, None]), 1)
76
+ return cam_trans
77
+
78
+ def get_coord(self, root_pose, body_pose, lhand_pose, rhand_pose, jaw_pose, shape, expr, cam_trans, mode):
79
+ batch_size = root_pose.shape[0]
80
+ zero_pose = torch.zeros((1, 3)).float().to(cfg.device).repeat(batch_size, 1) # eye poses
81
+ output = self.smplx_layer(betas=shape, body_pose=body_pose, global_orient=root_pose, right_hand_pose=rhand_pose,
82
+ left_hand_pose=lhand_pose, jaw_pose=jaw_pose, leye_pose=zero_pose,
83
+ reye_pose=zero_pose, expression=expr)
84
+ # camera-centered 3D coordinate
85
+ mesh_cam = output.vertices
86
+ if mode == 'test' and cfg.testset == 'AGORA': # use 144 joints for AGORA evaluation
87
+ joint_cam = output.joints
88
+ else:
89
+ joint_cam = output.joints[:, smpl_x.joint_idx, :]
90
+
91
+ # project 3D coordinates to 2D space
92
+ if mode == 'train' and len(cfg.trainset_3d) == 1 and cfg.trainset_3d[0] == 'AGORA' and len(
93
+ cfg.trainset_2d) == 0: # prevent gradients from backpropagating to SMPLX paraemter regression module
94
+ x = (joint_cam[:, :, 0].detach() + cam_trans[:, None, 0]) / (
95
+ joint_cam[:, :, 2].detach() + cam_trans[:, None, 2] + 1e-4) * cfg.focal[0] + cfg.princpt[0]
96
+ y = (joint_cam[:, :, 1].detach() + cam_trans[:, None, 1]) / (
97
+ joint_cam[:, :, 2].detach() + cam_trans[:, None, 2] + 1e-4) * cfg.focal[1] + cfg.princpt[1]
98
+ else:
99
+ x = (joint_cam[:, :, 0] + cam_trans[:, None, 0]) / (joint_cam[:, :, 2] + cam_trans[:, None, 2] + 1e-4) * \
100
+ cfg.focal[0] + cfg.princpt[0]
101
+ y = (joint_cam[:, :, 1] + cam_trans[:, None, 1]) / (joint_cam[:, :, 2] + cam_trans[:, None, 2] + 1e-4) * \
102
+ cfg.focal[1] + cfg.princpt[1]
103
+ x = x / cfg.input_body_shape[1] * cfg.output_hm_shape[2]
104
+ y = y / cfg.input_body_shape[0] * cfg.output_hm_shape[1]
105
+ joint_proj = torch.stack((x, y), 2)
106
+
107
+ # root-relative 3D coordinates
108
+ root_cam = joint_cam[:, smpl_x.root_joint_idx, None, :]
109
+ joint_cam = joint_cam - root_cam
110
+ mesh_cam = mesh_cam + cam_trans[:, None, :] # for rendering
111
+ joint_cam_wo_ra = joint_cam.clone()
112
+
113
+ # left hand root (left wrist)-relative 3D coordinatese
114
+ lhand_idx = smpl_x.joint_part['lhand']
115
+ lhand_cam = joint_cam[:, lhand_idx, :]
116
+ lwrist_cam = joint_cam[:, smpl_x.lwrist_idx, None, :]
117
+ lhand_cam = lhand_cam - lwrist_cam
118
+ joint_cam = torch.cat((joint_cam[:, :lhand_idx[0], :], lhand_cam, joint_cam[:, lhand_idx[-1] + 1:, :]), 1)
119
+
120
+ # right hand root (right wrist)-relative 3D coordinatese
121
+ rhand_idx = smpl_x.joint_part['rhand']
122
+ rhand_cam = joint_cam[:, rhand_idx, :]
123
+ rwrist_cam = joint_cam[:, smpl_x.rwrist_idx, None, :]
124
+ rhand_cam = rhand_cam - rwrist_cam
125
+ joint_cam = torch.cat((joint_cam[:, :rhand_idx[0], :], rhand_cam, joint_cam[:, rhand_idx[-1] + 1:, :]), 1)
126
+
127
+ # face root (neck)-relative 3D coordinates
128
+ face_idx = smpl_x.joint_part['face']
129
+ face_cam = joint_cam[:, face_idx, :]
130
+ neck_cam = joint_cam[:, smpl_x.neck_idx, None, :]
131
+ face_cam = face_cam - neck_cam
132
+ joint_cam = torch.cat((joint_cam[:, :face_idx[0], :], face_cam, joint_cam[:, face_idx[-1] + 1:, :]), 1)
133
+
134
+ return joint_proj, joint_cam, joint_cam_wo_ra, mesh_cam
135
+
136
+ def generate_mesh_gt(self, targets, mode):
137
+ if 'smplx_mesh_cam' in targets:
138
+ return targets['smplx_mesh_cam']
139
+ nums = [3, 63, 45, 45, 3]
140
+ accu = []
141
+ temp = 0
142
+ for num in nums:
143
+ temp += num
144
+ accu.append(temp)
145
+ pose = targets['smplx_pose']
146
+ root_pose, body_pose, lhand_pose, rhand_pose, jaw_pose = \
147
+ pose[:, :accu[0]], pose[:, accu[0]:accu[1]], pose[:, accu[1]:accu[2]], pose[:, accu[2]:accu[3]], pose[:,
148
+ accu[3]:
149
+ accu[4]]
150
+ # print(lhand_pose)
151
+ shape = targets['smplx_shape']
152
+ expr = targets['smplx_expr']
153
+ cam_trans = targets['smplx_cam_trans']
154
+
155
+ # final output
156
+ joint_proj, joint_cam, joint_cam_wo_ra, mesh_cam = self.get_coord(root_pose, body_pose, lhand_pose, rhand_pose, jaw_pose, shape,
157
+ expr, cam_trans, mode)
158
+
159
+ return mesh_cam
160
+
161
+ def bbox_split(self, bbox):
162
+ # bbox:[bs, 3, 3]
163
+ lhand_bbox_center, rhand_bbox_center, face_bbox_center = \
164
+ bbox[:, 0, :2], bbox[:, 1, :2], bbox[:, 2, :2]
165
+ return lhand_bbox_center, rhand_bbox_center, face_bbox_center
166
+
167
+ def forward(self, inputs, targets, meta_info, mode):
168
+
169
+ body_img = F.interpolate(inputs['img'], cfg.input_body_shape)
170
+
171
+ # 1. Encoder
172
+ img_feat, task_tokens = self.encoder(body_img) # task_token:[bs, N, c]
173
+ shape_token, cam_token, expr_token, jaw_pose_token, hand_token, body_pose_token = \
174
+ task_tokens[:, 0], task_tokens[:, 1], task_tokens[:, 2], task_tokens[:, 3], task_tokens[:, 4:6], task_tokens[:, 6:]
175
+
176
+ # 2. Body Regressor
177
+ body_joint_hm, body_joint_img = self.body_position_net(img_feat)
178
+ root_pose, body_pose, shape, cam_param, = self.body_regressor(body_pose_token, shape_token, cam_token, body_joint_img.detach())
179
+ root_pose = rot6d_to_axis_angle(root_pose)
180
+ body_pose = rot6d_to_axis_angle(body_pose.reshape(-1, 6)).reshape(body_pose.shape[0], -1) # (N, J_R*3)
181
+ cam_trans = self.get_camera_trans(cam_param)
182
+
183
+ # 3. Hand and Face BBox Estimation
184
+ lhand_bbox_center, lhand_bbox_size, rhand_bbox_center, rhand_bbox_size, face_bbox_center, face_bbox_size = self.box_net(img_feat, body_joint_hm.detach())
185
+ lhand_bbox = restore_bbox(lhand_bbox_center, lhand_bbox_size, cfg.input_hand_shape[1] / cfg.input_hand_shape[0], 2.0).detach() # xyxy in (cfg.input_body_shape[1], cfg.input_body_shape[0]) space
186
+ rhand_bbox = restore_bbox(rhand_bbox_center, rhand_bbox_size, cfg.input_hand_shape[1] / cfg.input_hand_shape[0], 2.0).detach() # xyxy in (cfg.input_body_shape[1], cfg.input_body_shape[0]) space
187
+ face_bbox = restore_bbox(face_bbox_center, face_bbox_size, cfg.input_face_shape[1] / cfg.input_face_shape[0], 1.5).detach() # xyxy in (cfg.input_body_shape[1], cfg.input_body_shape[0]) space
188
+
189
+ # 4. Differentiable Feature-level Hand Crop-Upsample
190
+ # hand_feat: list, [bsx2, c, cfg.output_hm_shape[1]*scale, cfg.output_hm_shape[2]*scale]
191
+ hand_feat = self.hand_roi_net(img_feat, lhand_bbox, rhand_bbox) # hand_feat: flipped left hand + right hand
192
+
193
+ # 5. Hand/Face Regressor
194
+ # hand regressor
195
+ _, hand_joint_img = self.hand_position_net(hand_feat) # (2N, J_P, 3)
196
+ hand_pose = self.hand_regressor(hand_feat, hand_joint_img.detach())
197
+ hand_pose = rot6d_to_axis_angle(hand_pose.reshape(-1, 6)).reshape(hand_feat.shape[0], -1) # (2N, J_R*3)
198
+ # restore flipped left hand joint coordinates
199
+ batch_size = hand_joint_img.shape[0] // 2
200
+ lhand_joint_img = hand_joint_img[:batch_size, :, :]
201
+ lhand_joint_img = torch.cat((cfg.output_hand_hm_shape[2] - 1 - lhand_joint_img[:, :, 0:1], lhand_joint_img[:, :, 1:]), 2)
202
+ rhand_joint_img = hand_joint_img[batch_size:, :, :]
203
+ # restore flipped left hand joint rotations
204
+ batch_size = hand_pose.shape[0] // 2
205
+ lhand_pose = hand_pose[:batch_size, :].reshape(-1, len(smpl_x.orig_joint_part['lhand']), 3)
206
+ lhand_pose = torch.cat((lhand_pose[:, :, 0:1], -lhand_pose[:, :, 1:3]), 2).view(batch_size, -1)
207
+ rhand_pose = hand_pose[batch_size:, :]
208
+
209
+ # hand regressor
210
+ expr, jaw_pose = self.face_regressor(expr_token, jaw_pose_token)
211
+ jaw_pose = rot6d_to_axis_angle(jaw_pose)
212
+
213
+ # final output
214
+ joint_proj, joint_cam, joint_cam_wo_ra, mesh_cam = self.get_coord(root_pose, body_pose, lhand_pose, rhand_pose, jaw_pose, shape, expr, cam_trans, mode)
215
+ pose = torch.cat((root_pose, body_pose, lhand_pose, rhand_pose, jaw_pose), 1)
216
+ joint_img = torch.cat((body_joint_img, lhand_joint_img, rhand_joint_img), 1)
217
+
218
+ if mode == 'test' and 'smplx_pose' in targets:
219
+ mesh_pseudo_gt = self.generate_mesh_gt(targets, mode)
220
+
221
+ if mode == 'train':
222
+ # loss functions
223
+ loss = {}
224
+
225
+ smplx_kps_3d_weight = getattr(cfg, 'smplx_kps_3d_weight', 1.0)
226
+ smplx_kps_3d_weight = getattr(cfg, 'smplx_kps_weight', smplx_kps_3d_weight) # old config
227
+
228
+ smplx_kps_2d_weight = getattr(cfg, 'smplx_kps_2d_weight', 1.0)
229
+ net_kps_2d_weight = getattr(cfg, 'net_kps_2d_weight', 1.0)
230
+
231
+ smplx_pose_weight = getattr(cfg, 'smplx_pose_weight', 1.0)
232
+ smplx_shape_weight = getattr(cfg, 'smplx_loss_weight', 1.0)
233
+ # smplx_orient_weight = getattr(cfg, 'smplx_orient_weight', smplx_pose_weight) # if not specified, use the same weight as pose
234
+
235
+
236
+ # do not supervise root pose if original agora json is used
237
+ if getattr(cfg, 'agora_fix_global_orient_transl', False):
238
+ # loss['smplx_pose'] = self.param_loss(pose, targets['smplx_pose'], meta_info['smplx_pose_valid'])[:, 3:] * smplx_pose_weight
239
+ if hasattr(cfg, 'smplx_orient_weight'):
240
+ smplx_orient_weight = getattr(cfg, 'smplx_orient_weight')
241
+ loss['smplx_orient'] = self.param_loss(pose, targets['smplx_pose'], meta_info['smplx_pose_valid'])[:, :3] * smplx_orient_weight
242
+
243
+ loss['smplx_pose'] = self.param_loss(pose, targets['smplx_pose'], meta_info['smplx_pose_valid']) * smplx_pose_weight
244
+
245
+ else:
246
+ loss['smplx_pose'] = self.param_loss(pose, targets['smplx_pose'], meta_info['smplx_pose_valid'])[:, 3:] * smplx_pose_weight
247
+
248
+ loss['smplx_shape'] = self.param_loss(shape, targets['smplx_shape'],
249
+ meta_info['smplx_shape_valid'][:, None]) * smplx_shape_weight
250
+ loss['smplx_expr'] = self.param_loss(expr, targets['smplx_expr'], meta_info['smplx_expr_valid'][:, None])
251
+
252
+ # supervision for keypoints3d wo/ ra
253
+ loss['joint_cam'] = self.coord_loss(joint_cam_wo_ra, targets['joint_cam'], meta_info['joint_valid'] * meta_info['is_3D'][:, None, None]) * smplx_kps_3d_weight
254
+ # supervision for keypoints3d w/ ra
255
+ loss['smplx_joint_cam'] = self.coord_loss(joint_cam, targets['smplx_joint_cam'], meta_info['smplx_joint_valid']) * smplx_kps_3d_weight
256
+
257
+ if not (meta_info['lhand_bbox_valid'] == 0).all():
258
+ loss['lhand_bbox'] = (self.coord_loss(lhand_bbox_center, targets['lhand_bbox_center'], meta_info['lhand_bbox_valid'][:, None]) +
259
+ self.coord_loss(lhand_bbox_size, targets['lhand_bbox_size'], meta_info['lhand_bbox_valid'][:, None]))
260
+ if not (meta_info['rhand_bbox_valid'] == 0).all():
261
+ loss['rhand_bbox'] = (self.coord_loss(rhand_bbox_center, targets['rhand_bbox_center'], meta_info['rhand_bbox_valid'][:, None]) +
262
+ self.coord_loss(rhand_bbox_size, targets['rhand_bbox_size'], meta_info['rhand_bbox_valid'][:, None]))
263
+ if not (meta_info['face_bbox_valid'] == 0).all():
264
+ loss['face_bbox'] = (self.coord_loss(face_bbox_center, targets['face_bbox_center'], meta_info['face_bbox_valid'][:, None]) +
265
+ self.coord_loss(face_bbox_size, targets['face_bbox_size'], meta_info['face_bbox_valid'][:, None]))
266
+
267
+ # if (meta_info['face_bbox_valid'] == 0).all():
268
+ # out = {}
269
+ targets['original_joint_img'] = targets['joint_img'].clone()
270
+ targets['original_smplx_joint_img'] = targets['smplx_joint_img'].clone()
271
+ # out['original_joint_proj'] = joint_proj.clone()
272
+ if not (meta_info['lhand_bbox_valid'] + meta_info['rhand_bbox_valid'] == 0).all():
273
+
274
+ # change hand target joint_img and joint_trunc according to hand bbox (cfg.output_hm_shape -> downsampled hand bbox space)
275
+ for part_name, bbox in (('lhand', lhand_bbox), ('rhand', rhand_bbox)):
276
+ for coord_name, trunc_name in (('joint_img', 'joint_trunc'), ('smplx_joint_img', 'smplx_joint_trunc')):
277
+ x = targets[coord_name][:, smpl_x.joint_part[part_name], 0]
278
+ y = targets[coord_name][:, smpl_x.joint_part[part_name], 1]
279
+ z = targets[coord_name][:, smpl_x.joint_part[part_name], 2]
280
+ trunc = meta_info[trunc_name][:, smpl_x.joint_part[part_name], 0]
281
+
282
+ x -= (bbox[:, None, 0] / cfg.input_body_shape[1] * cfg.output_hm_shape[2])
283
+ x *= (cfg.output_hand_hm_shape[2] / (
284
+ (bbox[:, None, 2] - bbox[:, None, 0]) / cfg.input_body_shape[1] * cfg.output_hm_shape[
285
+ 2]))
286
+ y -= (bbox[:, None, 1] / cfg.input_body_shape[0] * cfg.output_hm_shape[1])
287
+ y *= (cfg.output_hand_hm_shape[1] / (
288
+ (bbox[:, None, 3] - bbox[:, None, 1]) / cfg.input_body_shape[0] * cfg.output_hm_shape[
289
+ 1]))
290
+ z *= cfg.output_hand_hm_shape[0] / cfg.output_hm_shape[0]
291
+ trunc *= ((x >= 0) * (x < cfg.output_hand_hm_shape[2]) * (y >= 0) * (
292
+ y < cfg.output_hand_hm_shape[1]))
293
+
294
+ coord = torch.stack((x, y, z), 2)
295
+ trunc = trunc[:, :, None]
296
+ targets[coord_name] = torch.cat((targets[coord_name][:, :smpl_x.joint_part[part_name][0], :], coord,
297
+ targets[coord_name][:, smpl_x.joint_part[part_name][-1] + 1:, :]),
298
+ 1)
299
+ meta_info[trunc_name] = torch.cat((meta_info[trunc_name][:, :smpl_x.joint_part[part_name][0], :],
300
+ trunc,
301
+ meta_info[trunc_name][:, smpl_x.joint_part[part_name][-1] + 1:,
302
+ :]), 1)
303
+
304
+ # change hand projected joint coordinates according to hand bbox (cfg.output_hm_shape -> hand bbox space)
305
+ for part_name, bbox in (('lhand', lhand_bbox), ('rhand', rhand_bbox)):
306
+ x = joint_proj[:, smpl_x.joint_part[part_name], 0]
307
+ y = joint_proj[:, smpl_x.joint_part[part_name], 1]
308
+
309
+ x -= (bbox[:, None, 0] / cfg.input_body_shape[1] * cfg.output_hm_shape[2])
310
+ x *= (cfg.output_hand_hm_shape[2] / (
311
+ (bbox[:, None, 2] - bbox[:, None, 0]) / cfg.input_body_shape[1] * cfg.output_hm_shape[2]))
312
+ y -= (bbox[:, None, 1] / cfg.input_body_shape[0] * cfg.output_hm_shape[1])
313
+ y *= (cfg.output_hand_hm_shape[1] / (
314
+ (bbox[:, None, 3] - bbox[:, None, 1]) / cfg.input_body_shape[0] * cfg.output_hm_shape[1]))
315
+
316
+ coord = torch.stack((x, y), 2)
317
+ trans = []
318
+ for bid in range(coord.shape[0]):
319
+ mask = meta_info['joint_trunc'][bid, smpl_x.joint_part[part_name], 0] == 1
320
+ if torch.sum(mask) == 0:
321
+ trans.append(torch.zeros((2)).float().to(cfg.device))
322
+ else:
323
+ trans.append((-coord[bid, mask, :2] + targets['joint_img'][:, smpl_x.joint_part[part_name], :][
324
+ bid, mask, :2]).mean(0))
325
+ trans = torch.stack(trans)[:, None, :]
326
+ coord = coord + trans # global translation alignment
327
+ joint_proj = torch.cat((joint_proj[:, :smpl_x.joint_part[part_name][0], :], coord,
328
+ joint_proj[:, smpl_x.joint_part[part_name][-1] + 1:, :]), 1)
329
+
330
+ if not (meta_info['face_bbox_valid'] == 0).all():
331
+ # change face projected joint coordinates according to face bbox (cfg.output_hm_shape -> face bbox space)
332
+ coord = joint_proj[:, smpl_x.joint_part['face'], :]
333
+ trans = []
334
+ for bid in range(coord.shape[0]):
335
+ mask = meta_info['joint_trunc'][bid, smpl_x.joint_part['face'], 0] == 1
336
+ if torch.sum(mask) == 0:
337
+ trans.append(torch.zeros((2)).float().to(cfg.device))
338
+ else:
339
+ trans.append((-coord[bid, mask, :2] + targets['joint_img'][:, smpl_x.joint_part['face'], :][bid,
340
+ mask, :2]).mean(0))
341
+ trans = torch.stack(trans)[:, None, :]
342
+ coord = coord + trans # global translation alignment
343
+ joint_proj = torch.cat((joint_proj[:, :smpl_x.joint_part['face'][0], :], coord,
344
+ joint_proj[:, smpl_x.joint_part['face'][-1] + 1:, :]), 1)
345
+
346
+ loss['joint_proj'] = self.coord_loss(joint_proj, targets['joint_img'][:, :, :2], meta_info['joint_trunc']) * smplx_kps_2d_weight
347
+ loss['joint_img'] = self.coord_loss(joint_img, smpl_x.reduce_joint_set(targets['joint_img']),
348
+ smpl_x.reduce_joint_set(meta_info['joint_trunc']), meta_info['is_3D']) * net_kps_2d_weight
349
+
350
+ loss['smplx_joint_img'] = self.coord_loss(joint_img, smpl_x.reduce_joint_set(targets['smplx_joint_img']),
351
+ smpl_x.reduce_joint_set(meta_info['smplx_joint_trunc'])) * net_kps_2d_weight
352
+
353
+ return loss
354
+ else:
355
+ # change hand output joint_img according to hand bbox
356
+ for part_name, bbox in (('lhand', lhand_bbox), ('rhand', rhand_bbox)):
357
+ joint_img[:, smpl_x.pos_joint_part[part_name], 0] *= (
358
+ ((bbox[:, None, 2] - bbox[:, None, 0]) / cfg.input_body_shape[1] * cfg.output_hm_shape[2]) /
359
+ cfg.output_hand_hm_shape[2])
360
+ joint_img[:, smpl_x.pos_joint_part[part_name], 0] += (
361
+ bbox[:, None, 0] / cfg.input_body_shape[1] * cfg.output_hm_shape[2])
362
+ joint_img[:, smpl_x.pos_joint_part[part_name], 1] *= (
363
+ ((bbox[:, None, 3] - bbox[:, None, 1]) / cfg.input_body_shape[0] * cfg.output_hm_shape[1]) /
364
+ cfg.output_hand_hm_shape[1])
365
+ joint_img[:, smpl_x.pos_joint_part[part_name], 1] += (
366
+ bbox[:, None, 1] / cfg.input_body_shape[0] * cfg.output_hm_shape[1])
367
+
368
+ # change input_body_shape to input_img_shape
369
+ for bbox in (lhand_bbox, rhand_bbox, face_bbox):
370
+ bbox[:, 0] *= cfg.input_img_shape[1] / cfg.input_body_shape[1]
371
+ bbox[:, 1] *= cfg.input_img_shape[0] / cfg.input_body_shape[0]
372
+ bbox[:, 2] *= cfg.input_img_shape[1] / cfg.input_body_shape[1]
373
+ bbox[:, 3] *= cfg.input_img_shape[0] / cfg.input_body_shape[0]
374
+
375
+ # test output
376
+ out = {}
377
+ out['img'] = inputs['img']
378
+ out['joint_img'] = joint_img
379
+ out['smplx_joint_proj'] = joint_proj
380
+ out['smplx_mesh_cam'] = mesh_cam
381
+ out['smplx_root_pose'] = root_pose
382
+ out['smplx_body_pose'] = body_pose
383
+ out['smplx_lhand_pose'] = lhand_pose
384
+ out['smplx_rhand_pose'] = rhand_pose
385
+ out['smplx_jaw_pose'] = jaw_pose
386
+ out['smplx_shape'] = shape
387
+ out['smplx_expr'] = expr
388
+ out['cam_trans'] = cam_trans
389
+ out['lhand_bbox'] = lhand_bbox
390
+ out['rhand_bbox'] = rhand_bbox
391
+ out['face_bbox'] = face_bbox
392
+ if 'smplx_shape' in targets:
393
+ out['smplx_shape_target'] = targets['smplx_shape']
394
+ if 'img_path' in meta_info:
395
+ out['img_path'] = meta_info['img_path']
396
+ if 'smplx_pose' in targets:
397
+ out['smplx_mesh_cam_pseudo_gt'] = mesh_pseudo_gt
398
+ if 'smplx_mesh_cam' in targets:
399
+ out['smplx_mesh_cam_target'] = targets['smplx_mesh_cam']
400
+ if 'smpl_mesh_cam' in targets:
401
+ out['smpl_mesh_cam_target'] = targets['smpl_mesh_cam']
402
+ if 'bb2img_trans' in meta_info:
403
+ out['bb2img_trans'] = meta_info['bb2img_trans']
404
+ if 'gt_smplx_transl' in meta_info:
405
+ out['gt_smplx_transl'] = meta_info['gt_smplx_transl']
406
+
407
+ return out
408
+
409
+ def init_weights(m):
410
+ try:
411
+ if type(m) == nn.ConvTranspose2d:
412
+ nn.init.normal_(m.weight, std=0.001)
413
+ elif type(m) == nn.Conv2d:
414
+ nn.init.normal_(m.weight, std=0.001)
415
+ nn.init.constant_(m.bias, 0)
416
+ elif type(m) == nn.BatchNorm2d:
417
+ nn.init.constant_(m.weight, 1)
418
+ nn.init.constant_(m.bias, 0)
419
+ elif type(m) == nn.Linear:
420
+ nn.init.normal_(m.weight, std=0.01)
421
+ nn.init.constant_(m.bias, 0)
422
+ except AttributeError:
423
+ pass
424
+
425
+
426
+ def get_model(mode):
427
+
428
+ # body
429
+ vit_cfg = Config.fromfile(cfg.encoder_config_file)
430
+ vit = build_posenet(vit_cfg.model)
431
+ body_position_net = PositionNet('body', feat_dim=cfg.feat_dim)
432
+ body_rotation_net = BodyRotationNet(feat_dim=cfg.feat_dim)
433
+ box_net = BoxNet(feat_dim=cfg.feat_dim)
434
+
435
+ # hand
436
+ hand_position_net = PositionNet('hand', feat_dim=cfg.feat_dim)
437
+ hand_roi_net = HandRoI(feat_dim=cfg.feat_dim, upscale=cfg.upscale)
438
+ hand_rotation_net = HandRotationNet('hand', feat_dim=cfg.feat_dim)
439
+
440
+ # face
441
+ face_regressor = FaceRegressor(feat_dim=cfg.feat_dim)
442
+
443
+ if mode == 'train':
444
+ # body
445
+ if not getattr(cfg, 'random_init', False):
446
+ encoder_pretrained_model = torch.load(cfg.encoder_pretrained_model_path)['state_dict']
447
+ vit.load_state_dict(encoder_pretrained_model, strict=False)
448
+ print(f"Initialize encoder from {cfg.encoder_pretrained_model_path}")
449
+ else:
450
+ print('Random init!!!!!!!')
451
+
452
+ body_position_net.apply(init_weights)
453
+ body_rotation_net.apply(init_weights)
454
+ box_net.apply(init_weights)
455
+
456
+ # hand
457
+ hand_position_net.apply(init_weights)
458
+ hand_roi_net.apply(init_weights)
459
+ hand_rotation_net.apply(init_weights)
460
+
461
+ # face
462
+ face_regressor.apply(init_weights)
463
+
464
+ encoder = vit.backbone
465
+
466
+ model = Model(encoder, body_position_net, body_rotation_net, box_net, hand_position_net, hand_roi_net, hand_rotation_net,
467
+ face_regressor)
468
+ return model
main/_base_/datasets/300w.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_info = dict(
2
+ dataset_name='300w',
3
+ paper_info=dict(
4
+ author='Sagonas, Christos and Antonakos, Epameinondas '
5
+ 'and Tzimiropoulos, Georgios and Zafeiriou, Stefanos '
6
+ 'and Pantic, Maja',
7
+ title='300 faces in-the-wild challenge: '
8
+ 'Database and results',
9
+ container='Image and vision computing',
10
+ year='2016',
11
+ homepage='https://ibug.doc.ic.ac.uk/resources/300-W/',
12
+ ),
13
+ keypoint_info={
14
+ 0:
15
+ dict(
16
+ name='kpt-0', id=0, color=[255, 255, 255], type='', swap='kpt-16'),
17
+ 1:
18
+ dict(
19
+ name='kpt-1', id=1, color=[255, 255, 255], type='', swap='kpt-15'),
20
+ 2:
21
+ dict(
22
+ name='kpt-2', id=2, color=[255, 255, 255], type='', swap='kpt-14'),
23
+ 3:
24
+ dict(
25
+ name='kpt-3', id=3, color=[255, 255, 255], type='', swap='kpt-13'),
26
+ 4:
27
+ dict(
28
+ name='kpt-4', id=4, color=[255, 255, 255], type='', swap='kpt-12'),
29
+ 5:
30
+ dict(
31
+ name='kpt-5', id=5, color=[255, 255, 255], type='', swap='kpt-11'),
32
+ 6:
33
+ dict(
34
+ name='kpt-6', id=6, color=[255, 255, 255], type='', swap='kpt-10'),
35
+ 7:
36
+ dict(name='kpt-7', id=7, color=[255, 255, 255], type='', swap='kpt-9'),
37
+ 8:
38
+ dict(name='kpt-8', id=8, color=[255, 255, 255], type='', swap=''),
39
+ 9:
40
+ dict(name='kpt-9', id=9, color=[255, 255, 255], type='', swap='kpt-7'),
41
+ 10:
42
+ dict(
43
+ name='kpt-10', id=10, color=[255, 255, 255], type='',
44
+ swap='kpt-6'),
45
+ 11:
46
+ dict(
47
+ name='kpt-11', id=11, color=[255, 255, 255], type='',
48
+ swap='kpt-5'),
49
+ 12:
50
+ dict(
51
+ name='kpt-12', id=12, color=[255, 255, 255], type='',
52
+ swap='kpt-4'),
53
+ 13:
54
+ dict(
55
+ name='kpt-13', id=13, color=[255, 255, 255], type='',
56
+ swap='kpt-3'),
57
+ 14:
58
+ dict(
59
+ name='kpt-14', id=14, color=[255, 255, 255], type='',
60
+ swap='kpt-2'),
61
+ 15:
62
+ dict(
63
+ name='kpt-15', id=15, color=[255, 255, 255], type='',
64
+ swap='kpt-1'),
65
+ 16:
66
+ dict(
67
+ name='kpt-16', id=16, color=[255, 255, 255], type='',
68
+ swap='kpt-0'),
69
+ 17:
70
+ dict(
71
+ name='kpt-17',
72
+ id=17,
73
+ color=[255, 255, 255],
74
+ type='',
75
+ swap='kpt-26'),
76
+ 18:
77
+ dict(
78
+ name='kpt-18',
79
+ id=18,
80
+ color=[255, 255, 255],
81
+ type='',
82
+ swap='kpt-25'),
83
+ 19:
84
+ dict(
85
+ name='kpt-19',
86
+ id=19,
87
+ color=[255, 255, 255],
88
+ type='',
89
+ swap='kpt-24'),
90
+ 20:
91
+ dict(
92
+ name='kpt-20',
93
+ id=20,
94
+ color=[255, 255, 255],
95
+ type='',
96
+ swap='kpt-23'),
97
+ 21:
98
+ dict(
99
+ name='kpt-21',
100
+ id=21,
101
+ color=[255, 255, 255],
102
+ type='',
103
+ swap='kpt-22'),
104
+ 22:
105
+ dict(
106
+ name='kpt-22',
107
+ id=22,
108
+ color=[255, 255, 255],
109
+ type='',
110
+ swap='kpt-21'),
111
+ 23:
112
+ dict(
113
+ name='kpt-23',
114
+ id=23,
115
+ color=[255, 255, 255],
116
+ type='',
117
+ swap='kpt-20'),
118
+ 24:
119
+ dict(
120
+ name='kpt-24',
121
+ id=24,
122
+ color=[255, 255, 255],
123
+ type='',
124
+ swap='kpt-19'),
125
+ 25:
126
+ dict(
127
+ name='kpt-25',
128
+ id=25,
129
+ color=[255, 255, 255],
130
+ type='',
131
+ swap='kpt-18'),
132
+ 26:
133
+ dict(
134
+ name='kpt-26',
135
+ id=26,
136
+ color=[255, 255, 255],
137
+ type='',
138
+ swap='kpt-17'),
139
+ 27:
140
+ dict(name='kpt-27', id=27, color=[255, 255, 255], type='', swap=''),
141
+ 28:
142
+ dict(name='kpt-28', id=28, color=[255, 255, 255], type='', swap=''),
143
+ 29:
144
+ dict(name='kpt-29', id=29, color=[255, 255, 255], type='', swap=''),
145
+ 30:
146
+ dict(name='kpt-30', id=30, color=[255, 255, 255], type='', swap=''),
147
+ 31:
148
+ dict(
149
+ name='kpt-31',
150
+ id=31,
151
+ color=[255, 255, 255],
152
+ type='',
153
+ swap='kpt-35'),
154
+ 32:
155
+ dict(
156
+ name='kpt-32',
157
+ id=32,
158
+ color=[255, 255, 255],
159
+ type='',
160
+ swap='kpt-34'),
161
+ 33:
162
+ dict(name='kpt-33', id=33, color=[255, 255, 255], type='', swap=''),
163
+ 34:
164
+ dict(
165
+ name='kpt-34',
166
+ id=34,
167
+ color=[255, 255, 255],
168
+ type='',
169
+ swap='kpt-32'),
170
+ 35:
171
+ dict(
172
+ name='kpt-35',
173
+ id=35,
174
+ color=[255, 255, 255],
175
+ type='',
176
+ swap='kpt-31'),
177
+ 36:
178
+ dict(
179
+ name='kpt-36',
180
+ id=36,
181
+ color=[255, 255, 255],
182
+ type='',
183
+ swap='kpt-45'),
184
+ 37:
185
+ dict(
186
+ name='kpt-37',
187
+ id=37,
188
+ color=[255, 255, 255],
189
+ type='',
190
+ swap='kpt-44'),
191
+ 38:
192
+ dict(
193
+ name='kpt-38',
194
+ id=38,
195
+ color=[255, 255, 255],
196
+ type='',
197
+ swap='kpt-43'),
198
+ 39:
199
+ dict(
200
+ name='kpt-39',
201
+ id=39,
202
+ color=[255, 255, 255],
203
+ type='',
204
+ swap='kpt-42'),
205
+ 40:
206
+ dict(
207
+ name='kpt-40',
208
+ id=40,
209
+ color=[255, 255, 255],
210
+ type='',
211
+ swap='kpt-47'),
212
+ 41:
213
+ dict(
214
+ name='kpt-41',
215
+ id=41,
216
+ color=[255, 255, 255],
217
+ type='',
218
+ swap='kpt-46'),
219
+ 42:
220
+ dict(
221
+ name='kpt-42',
222
+ id=42,
223
+ color=[255, 255, 255],
224
+ type='',
225
+ swap='kpt-39'),
226
+ 43:
227
+ dict(
228
+ name='kpt-43',
229
+ id=43,
230
+ color=[255, 255, 255],
231
+ type='',
232
+ swap='kpt-38'),
233
+ 44:
234
+ dict(
235
+ name='kpt-44',
236
+ id=44,
237
+ color=[255, 255, 255],
238
+ type='',
239
+ swap='kpt-37'),
240
+ 45:
241
+ dict(
242
+ name='kpt-45',
243
+ id=45,
244
+ color=[255, 255, 255],
245
+ type='',
246
+ swap='kpt-36'),
247
+ 46:
248
+ dict(
249
+ name='kpt-46',
250
+ id=46,
251
+ color=[255, 255, 255],
252
+ type='',
253
+ swap='kpt-41'),
254
+ 47:
255
+ dict(
256
+ name='kpt-47',
257
+ id=47,
258
+ color=[255, 255, 255],
259
+ type='',
260
+ swap='kpt-40'),
261
+ 48:
262
+ dict(
263
+ name='kpt-48',
264
+ id=48,
265
+ color=[255, 255, 255],
266
+ type='',
267
+ swap='kpt-54'),
268
+ 49:
269
+ dict(
270
+ name='kpt-49',
271
+ id=49,
272
+ color=[255, 255, 255],
273
+ type='',
274
+ swap='kpt-53'),
275
+ 50:
276
+ dict(
277
+ name='kpt-50',
278
+ id=50,
279
+ color=[255, 255, 255],
280
+ type='',
281
+ swap='kpt-52'),
282
+ 51:
283
+ dict(name='kpt-51', id=51, color=[255, 255, 255], type='', swap=''),
284
+ 52:
285
+ dict(
286
+ name='kpt-52',
287
+ id=52,
288
+ color=[255, 255, 255],
289
+ type='',
290
+ swap='kpt-50'),
291
+ 53:
292
+ dict(
293
+ name='kpt-53',
294
+ id=53,
295
+ color=[255, 255, 255],
296
+ type='',
297
+ swap='kpt-49'),
298
+ 54:
299
+ dict(
300
+ name='kpt-54',
301
+ id=54,
302
+ color=[255, 255, 255],
303
+ type='',
304
+ swap='kpt-48'),
305
+ 55:
306
+ dict(
307
+ name='kpt-55',
308
+ id=55,
309
+ color=[255, 255, 255],
310
+ type='',
311
+ swap='kpt-59'),
312
+ 56:
313
+ dict(
314
+ name='kpt-56',
315
+ id=56,
316
+ color=[255, 255, 255],
317
+ type='',
318
+ swap='kpt-58'),
319
+ 57:
320
+ dict(name='kpt-57', id=57, color=[255, 255, 255], type='', swap=''),
321
+ 58:
322
+ dict(
323
+ name='kpt-58',
324
+ id=58,
325
+ color=[255, 255, 255],
326
+ type='',
327
+ swap='kpt-56'),
328
+ 59:
329
+ dict(
330
+ name='kpt-59',
331
+ id=59,
332
+ color=[255, 255, 255],
333
+ type='',
334
+ swap='kpt-55'),
335
+ 60:
336
+ dict(
337
+ name='kpt-60',
338
+ id=60,
339
+ color=[255, 255, 255],
340
+ type='',
341
+ swap='kpt-64'),
342
+ 61:
343
+ dict(
344
+ name='kpt-61',
345
+ id=61,
346
+ color=[255, 255, 255],
347
+ type='',
348
+ swap='kpt-63'),
349
+ 62:
350
+ dict(name='kpt-62', id=62, color=[255, 255, 255], type='', swap=''),
351
+ 63:
352
+ dict(
353
+ name='kpt-63',
354
+ id=63,
355
+ color=[255, 255, 255],
356
+ type='',
357
+ swap='kpt-61'),
358
+ 64:
359
+ dict(
360
+ name='kpt-64',
361
+ id=64,
362
+ color=[255, 255, 255],
363
+ type='',
364
+ swap='kpt-60'),
365
+ 65:
366
+ dict(
367
+ name='kpt-65',
368
+ id=65,
369
+ color=[255, 255, 255],
370
+ type='',
371
+ swap='kpt-67'),
372
+ 66:
373
+ dict(name='kpt-66', id=66, color=[255, 255, 255], type='', swap=''),
374
+ 67:
375
+ dict(
376
+ name='kpt-67',
377
+ id=67,
378
+ color=[255, 255, 255],
379
+ type='',
380
+ swap='kpt-65'),
381
+ },
382
+ skeleton_info={},
383
+ joint_weights=[1.] * 68,
384
+ sigmas=[])
main/_base_/datasets/aflw.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_info = dict(
2
+ dataset_name='aflw',
3
+ paper_info=dict(
4
+ author='Koestinger, Martin and Wohlhart, Paul and '
5
+ 'Roth, Peter M and Bischof, Horst',
6
+ title='Annotated facial landmarks in the wild: '
7
+ 'A large-scale, real-world database for facial '
8
+ 'landmark localization',
9
+ container='2011 IEEE international conference on computer '
10
+ 'vision workshops (ICCV workshops)',
11
+ year='2011',
12
+ homepage='https://www.tugraz.at/institute/icg/research/'
13
+ 'team-bischof/lrs/downloads/aflw/',
14
+ ),
15
+ keypoint_info={
16
+ 0:
17
+ dict(name='kpt-0', id=0, color=[255, 255, 255], type='', swap='kpt-5'),
18
+ 1:
19
+ dict(name='kpt-1', id=1, color=[255, 255, 255], type='', swap='kpt-4'),
20
+ 2:
21
+ dict(name='kpt-2', id=2, color=[255, 255, 255], type='', swap='kpt-3'),
22
+ 3:
23
+ dict(name='kpt-3', id=3, color=[255, 255, 255], type='', swap='kpt-2'),
24
+ 4:
25
+ dict(name='kpt-4', id=4, color=[255, 255, 255], type='', swap='kpt-1'),
26
+ 5:
27
+ dict(name='kpt-5', id=5, color=[255, 255, 255], type='', swap='kpt-0'),
28
+ 6:
29
+ dict(
30
+ name='kpt-6', id=6, color=[255, 255, 255], type='', swap='kpt-11'),
31
+ 7:
32
+ dict(
33
+ name='kpt-7', id=7, color=[255, 255, 255], type='', swap='kpt-10'),
34
+ 8:
35
+ dict(name='kpt-8', id=8, color=[255, 255, 255], type='', swap='kpt-9'),
36
+ 9:
37
+ dict(name='kpt-9', id=9, color=[255, 255, 255], type='', swap='kpt-8'),
38
+ 10:
39
+ dict(
40
+ name='kpt-10', id=10, color=[255, 255, 255], type='',
41
+ swap='kpt-7'),
42
+ 11:
43
+ dict(
44
+ name='kpt-11', id=11, color=[255, 255, 255], type='',
45
+ swap='kpt-6'),
46
+ 12:
47
+ dict(
48
+ name='kpt-12',
49
+ id=12,
50
+ color=[255, 255, 255],
51
+ type='',
52
+ swap='kpt-14'),
53
+ 13:
54
+ dict(name='kpt-13', id=13, color=[255, 255, 255], type='', swap=''),
55
+ 14:
56
+ dict(
57
+ name='kpt-14',
58
+ id=14,
59
+ color=[255, 255, 255],
60
+ type='',
61
+ swap='kpt-12'),
62
+ 15:
63
+ dict(
64
+ name='kpt-15',
65
+ id=15,
66
+ color=[255, 255, 255],
67
+ type='',
68
+ swap='kpt-17'),
69
+ 16:
70
+ dict(name='kpt-16', id=16, color=[255, 255, 255], type='', swap=''),
71
+ 17:
72
+ dict(
73
+ name='kpt-17',
74
+ id=17,
75
+ color=[255, 255, 255],
76
+ type='',
77
+ swap='kpt-15'),
78
+ 18:
79
+ dict(name='kpt-18', id=18, color=[255, 255, 255], type='', swap='')
80
+ },
81
+ skeleton_info={},
82
+ joint_weights=[1.] * 19,
83
+ sigmas=[])
main/_base_/datasets/aic.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_info = dict(
2
+ dataset_name='aic',
3
+ paper_info=dict(
4
+ author='Wu, Jiahong and Zheng, He and Zhao, Bo and '
5
+ 'Li, Yixin and Yan, Baoming and Liang, Rui and '
6
+ 'Wang, Wenjia and Zhou, Shipei and Lin, Guosen and '
7
+ 'Fu, Yanwei and others',
8
+ title='Ai challenger: A large-scale dataset for going '
9
+ 'deeper in image understanding',
10
+ container='arXiv',
11
+ year='2017',
12
+ homepage='https://github.com/AIChallenger/AI_Challenger_2017',
13
+ ),
14
+ keypoint_info={
15
+ 0:
16
+ dict(
17
+ name='right_shoulder',
18
+ id=0,
19
+ color=[255, 128, 0],
20
+ type='upper',
21
+ swap='left_shoulder'),
22
+ 1:
23
+ dict(
24
+ name='right_elbow',
25
+ id=1,
26
+ color=[255, 128, 0],
27
+ type='upper',
28
+ swap='left_elbow'),
29
+ 2:
30
+ dict(
31
+ name='right_wrist',
32
+ id=2,
33
+ color=[255, 128, 0],
34
+ type='upper',
35
+ swap='left_wrist'),
36
+ 3:
37
+ dict(
38
+ name='left_shoulder',
39
+ id=3,
40
+ color=[0, 255, 0],
41
+ type='upper',
42
+ swap='right_shoulder'),
43
+ 4:
44
+ dict(
45
+ name='left_elbow',
46
+ id=4,
47
+ color=[0, 255, 0],
48
+ type='upper',
49
+ swap='right_elbow'),
50
+ 5:
51
+ dict(
52
+ name='left_wrist',
53
+ id=5,
54
+ color=[0, 255, 0],
55
+ type='upper',
56
+ swap='right_wrist'),
57
+ 6:
58
+ dict(
59
+ name='right_hip',
60
+ id=6,
61
+ color=[255, 128, 0],
62
+ type='lower',
63
+ swap='left_hip'),
64
+ 7:
65
+ dict(
66
+ name='right_knee',
67
+ id=7,
68
+ color=[255, 128, 0],
69
+ type='lower',
70
+ swap='left_knee'),
71
+ 8:
72
+ dict(
73
+ name='right_ankle',
74
+ id=8,
75
+ color=[255, 128, 0],
76
+ type='lower',
77
+ swap='left_ankle'),
78
+ 9:
79
+ dict(
80
+ name='left_hip',
81
+ id=9,
82
+ color=[0, 255, 0],
83
+ type='lower',
84
+ swap='right_hip'),
85
+ 10:
86
+ dict(
87
+ name='left_knee',
88
+ id=10,
89
+ color=[0, 255, 0],
90
+ type='lower',
91
+ swap='right_knee'),
92
+ 11:
93
+ dict(
94
+ name='left_ankle',
95
+ id=11,
96
+ color=[0, 255, 0],
97
+ type='lower',
98
+ swap='right_ankle'),
99
+ 12:
100
+ dict(
101
+ name='head_top',
102
+ id=12,
103
+ color=[51, 153, 255],
104
+ type='upper',
105
+ swap=''),
106
+ 13:
107
+ dict(name='neck', id=13, color=[51, 153, 255], type='upper', swap='')
108
+ },
109
+ skeleton_info={
110
+ 0:
111
+ dict(link=('right_wrist', 'right_elbow'), id=0, color=[255, 128, 0]),
112
+ 1: dict(
113
+ link=('right_elbow', 'right_shoulder'), id=1, color=[255, 128, 0]),
114
+ 2: dict(link=('right_shoulder', 'neck'), id=2, color=[51, 153, 255]),
115
+ 3: dict(link=('neck', 'left_shoulder'), id=3, color=[51, 153, 255]),
116
+ 4: dict(link=('left_shoulder', 'left_elbow'), id=4, color=[0, 255, 0]),
117
+ 5: dict(link=('left_elbow', 'left_wrist'), id=5, color=[0, 255, 0]),
118
+ 6: dict(link=('right_ankle', 'right_knee'), id=6, color=[255, 128, 0]),
119
+ 7: dict(link=('right_knee', 'right_hip'), id=7, color=[255, 128, 0]),
120
+ 8: dict(link=('right_hip', 'left_hip'), id=8, color=[51, 153, 255]),
121
+ 9: dict(link=('left_hip', 'left_knee'), id=9, color=[0, 255, 0]),
122
+ 10: dict(link=('left_knee', 'left_ankle'), id=10, color=[0, 255, 0]),
123
+ 11: dict(link=('head_top', 'neck'), id=11, color=[51, 153, 255]),
124
+ 12: dict(
125
+ link=('right_shoulder', 'right_hip'), id=12, color=[51, 153, 255]),
126
+ 13:
127
+ dict(link=('left_shoulder', 'left_hip'), id=13, color=[51, 153, 255])
128
+ },
129
+ joint_weights=[
130
+ 1., 1.2, 1.5, 1., 1.2, 1.5, 1., 1.2, 1.5, 1., 1.2, 1.5, 1., 1.
131
+ ],
132
+
133
+ # 'https://github.com/AIChallenger/AI_Challenger_2017/blob/master/'
134
+ # 'Evaluation/keypoint_eval/keypoint_eval.py#L50'
135
+ # delta = 2 x sigma
136
+ sigmas=[
137
+ 0.01388152, 0.01515228, 0.01057665, 0.01417709, 0.01497891, 0.01402144,
138
+ 0.03909642, 0.03686941, 0.01981803, 0.03843971, 0.03412318, 0.02415081,
139
+ 0.01291456, 0.01236173
140
+ ])