Adi committed on
Commit d63dc03 · 1 Parent(s): 0c222df

feature: adding files for rendering

Files changed (50)
  1. .gitignore +2 -1
  2. body_models/smpl/J_regressor_extra.npy +3 -0
  3. body_models/smpl/SMPL_NEUTRAL.pkl +3 -0
  4. body_models/smpl/kintree_table.pkl +3 -0
  5. body_models/smpl/smplfaces.npy +3 -0
  6. dataset/prepare/download_smpl.sh +13 -0
  7. exit/__pycache__/utils.cpython-312.pyc +0 -0
  8. exit/__pycache__/utils.cpython-38.pyc +0 -0
  9. generate.py +9 -0
  10. models/__pycache__/encdec.cpython-312.pyc +0 -0
  11. models/__pycache__/encdec.cpython-38.pyc +0 -0
  12. models/__pycache__/pos_encoding.cpython-312.pyc +0 -0
  13. models/__pycache__/pos_encoding.cpython-38.pyc +0 -0
  14. models/__pycache__/quantize_cnn.cpython-312.pyc +0 -0
  15. models/__pycache__/quantize_cnn.cpython-38.pyc +0 -0
  16. models/__pycache__/resnet.cpython-312.pyc +0 -0
  17. models/__pycache__/resnet.cpython-38.pyc +0 -0
  18. models/__pycache__/t2m_trans.cpython-312.pyc +0 -0
  19. models/__pycache__/t2m_trans.cpython-38.pyc +0 -0
  20. models/__pycache__/t2m_trans_uplow.cpython-312.pyc +0 -0
  21. models/__pycache__/t2m_trans_uplow.cpython-38.pyc +0 -0
  22. models/__pycache__/vqvae.cpython-312.pyc +0 -0
  23. models/__pycache__/vqvae.cpython-38.pyc +0 -0
  24. models/__pycache__/vqvae_sep.cpython-312.pyc +0 -0
  25. models/__pycache__/vqvae_sep.cpython-38.pyc +0 -0
  26. models/rotation2xyz.py +92 -0
  27. models/smpl.py +97 -0
  28. options/__pycache__/option_transformer.cpython-312.pyc +0 -0
  29. options/__pycache__/option_transformer.cpython-38.pyc +0 -0
  30. render_final.py +189 -0
  31. utils/__pycache__/humanml_utils.cpython-312.pyc +0 -0
  32. utils/__pycache__/humanml_utils.cpython-38.pyc +0 -0
  33. utils/__pycache__/motion_process.cpython-312.pyc +0 -0
  34. utils/__pycache__/motion_process.cpython-38.pyc +0 -0
  35. utils/__pycache__/quaternion.cpython-312.pyc +0 -0
  36. utils/__pycache__/quaternion.cpython-38.pyc +0 -0
  37. utils/config.py +17 -0
  38. utils/rotation_conversions.py +532 -0
  39. visualization/plot_3d_global.py +129 -0
  40. visualize/joints2smpl/smpl_models/SMPL_downsample_index.pkl +3 -0
  41. visualize/joints2smpl/smpl_models/gmm_08.pkl +3 -0
  42. visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 +3 -0
  43. visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl +3 -0
  44. visualize/joints2smpl/src/config.py +40 -0
  45. visualize/joints2smpl/src/customloss.py +222 -0
  46. visualize/joints2smpl/src/prior.py +230 -0
  47. visualize/joints2smpl/src/smplify.py +279 -0
  48. visualize/render_mesh.py +33 -0
  49. visualize/simplify_loc2rot.py +131 -0
  50. visualize/vis_utils.py +66 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
 /.DS_STORE
-/.env
+/.env
+output/
body_models/smpl/J_regressor_extra.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc968ea4f9855571e82f90203280836b01f13ee42a8e1b89d8d580b801242a89
+size 496160
body_models/smpl/SMPL_NEUTRAL.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98e65c74ad9b998783132f00880d1025a8d64b158e040e6ef13a557e5098bc42
+size 39001280
body_models/smpl/kintree_table.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62116ec76c6192ae912557122ea935267ba7188144efb9306ea1366f0e50d4d2
+size 349
body_models/smpl/smplfaces.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ee8e99db736acf178a6078ab5710ca942edc3738d34c72f41a35c40b370e045
+size 165440
dataset/prepare/download_smpl.sh ADDED
@@ -0,0 +1,13 @@
+
+mkdir -p body_models
+cd body_models/
+
+echo -e "The smpl files will be stored in the 'body_models/smpl/' folder\n"
+gdown 1INYlGA76ak_cKGzvpOV2Pe6RkYTlXTW2
+rm -rf smpl
+
+unzip smpl.zip
+echo -e "Cleaning\n"
+rm smpl.zip
+
+echo -e "Downloading done!"
exit/__pycache__/utils.cpython-312.pyc ADDED
Binary file (14.1 kB)
 
exit/__pycache__/utils.cpython-38.pyc ADDED
Binary file (8.27 kB)
 
generate.py CHANGED
@@ -6,6 +6,7 @@ import models.t2m_trans as trans
 import models.t2m_trans_uplow as trans_uplow
 import numpy as np
 from exit.utils import visualize_2motions
+from exit.utils import recover_from_ric
 import options.option_transformer as option_trans


@@ -288,6 +289,14 @@ if __name__ == '__main__':

     mmm = MMM(args)
     pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)
+    num_joints = 22
+
+    pred_pose = pred_pose[:args.length, :].detach().cpu()
+
+    converted_pose = recover_from_ric(pred_pose[0].detach().cpu(), num_joints).unsqueeze(0).numpy()
+
+    np.save('./output/mmm-pred.npy', converted_pose)
+    print('File saved successfully')

     std = np.load('./exit/t2m-std.npy')
     mean = np.load('./exit/t2m-mean.npy')
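
With this change, generate.py recovers the 22 joint positions from the predicted RIC features and writes them to ./output/mmm-pred.npy, which render_final.py (added below) consumes. A minimal end-to-end sketch; the --text and --length option names are assumptions inferred from args.text and args.length above:

import subprocess

# Hypothetical run: text prompt -> ./output/mmm-pred.npy -> rendered mesh GIF.
subprocess.run(["python", "generate.py",
                "--text", "a person walks forward and waves",
                "--length", "96"], check=True)
subprocess.run(["python", "render_final.py",
                "./output/mmm-pred.npy", "--outdir", "./output"], check=True)
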
models/__pycache__/encdec.cpython-312.pyc ADDED
Binary file (4.18 kB)
 
models/__pycache__/encdec.cpython-38.pyc ADDED
Binary file (2.56 kB)
 
models/__pycache__/pos_encoding.cpython-312.pyc ADDED
Binary file (2.92 kB)
 
models/__pycache__/pos_encoding.cpython-38.pyc ADDED
Binary file (1.78 kB)
 
models/__pycache__/quantize_cnn.cpython-312.pyc ADDED
Binary file (23.2 kB)
 
models/__pycache__/quantize_cnn.cpython-38.pyc ADDED
Binary file (11 kB)
 
models/__pycache__/resnet.cpython-312.pyc ADDED
Binary file (4.69 kB)
 
models/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (2.81 kB)
 
models/__pycache__/t2m_trans.cpython-312.pyc ADDED
Binary file (35.1 kB)
 
models/__pycache__/t2m_trans.cpython-38.pyc ADDED
Binary file (18.2 kB)
 
models/__pycache__/t2m_trans_uplow.cpython-312.pyc ADDED
Binary file (32.8 kB)
 
models/__pycache__/t2m_trans_uplow.cpython-38.pyc ADDED
Binary file (16.9 kB)
 
models/__pycache__/vqvae.cpython-312.pyc ADDED
Binary file (5.6 kB)
 
models/__pycache__/vqvae.cpython-38.pyc ADDED
Binary file (3.56 kB)
 
models/__pycache__/vqvae_sep.cpython-312.pyc ADDED
Binary file (14.1 kB)
 
models/__pycache__/vqvae_sep.cpython-38.pyc ADDED
Binary file (7.32 kB)
 
models/rotation2xyz.py ADDED
@@ -0,0 +1,92 @@
+# This code is based on https://github.com/Mathux/ACTOR.git
+import torch
+import utils.rotation_conversions as geometry
+
+
+from models.smpl import SMPL, JOINTSTYPE_ROOT
+# from .get_model import JOINTSTYPES
+JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"]
+
+
+class Rotation2xyz:
+    def __init__(self, device, dataset='amass'):
+        self.device = device
+        self.dataset = dataset
+        self.smpl_model = SMPL().eval().to(device)
+
+    def __call__(self, x, mask, pose_rep, translation, glob,
+                 jointstype, vertstrans, betas=None, beta=0,
+                 glob_rot=None, get_rotations_back=False, **kwargs):
+        if pose_rep == "xyz":
+            return x
+
+        if mask is None:
+            mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device)
+
+        if not glob and glob_rot is None:
+            raise TypeError("You must specify global rotation if glob is False")
+
+        if jointstype not in JOINTSTYPES:
+            raise NotImplementedError("This jointstype is not implemented.")
+
+        if translation:
+            x_translations = x[:, -1, :3]
+            x_rotations = x[:, :-1]
+        else:
+            x_rotations = x
+
+        x_rotations = x_rotations.permute(0, 3, 1, 2)
+        nsamples, time, njoints, feats = x_rotations.shape
+
+        # Compute rotations (convert only masked sequences output)
+        if pose_rep == "rotvec":
+            rotations = geometry.axis_angle_to_matrix(x_rotations[mask])
+        elif pose_rep == "rotmat":
+            rotations = x_rotations[mask].view(-1, njoints, 3, 3)
+        elif pose_rep == "rotquat":
+            rotations = geometry.quaternion_to_matrix(x_rotations[mask])
+        elif pose_rep == "rot6d":
+            rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])
+        else:
+            raise NotImplementedError("No geometry for this one.")
+
+        if not glob:
+            global_orient = torch.tensor(glob_rot, device=x.device)
+            global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3)
+            global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
+        else:
+            global_orient = rotations[:, 0]
+            rotations = rotations[:, 1:]
+
+        if betas is None:
+            betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas],
+                                dtype=rotations.dtype, device=rotations.device)
+            betas[:, 1] = beta
+        # import ipdb; ipdb.set_trace()
+        out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas)
+
+        # get the desirable joints
+        joints = out[jointstype]
+
+        x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype)
+        x_xyz[~mask] = 0
+        x_xyz[mask] = joints
+
+        x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous()
+
+        # the first translation root at the origin on the prediction
+        if jointstype != "vertices":
+            rootindex = JOINTSTYPE_ROOT[jointstype]
+            x_xyz = x_xyz - x_xyz[:, [rootindex], :, :]
+
+        if translation and vertstrans:
+            # the first translation root at the origin
+            x_translations = x_translations - x_translations[:, :, [0]]
+
+            # add the translation to all the joints
+            x_xyz = x_xyz + x_translations[:, None, :, :]
+
+        if get_rotations_back:
+            return x_xyz, rotations, global_orient
+        else:
+            return x_xyz
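
A minimal sketch of how this class is driven in render_final.py below: 6D joint rotations plus a trailing translation channel go in, SMPL vertices come out. The (batch, 24 joints + 1 translation slot, 6, frames) layout is an assumption inferred from the translation slicing above, and the example needs the SMPL files from dataset/prepare/download_smpl.sh plus the smplx package.

import torch
from models.rotation2xyz import Rotation2xyz

rot2xyz = Rotation2xyz(device=torch.device('cpu'))

# Identity rotation in the 6D representation, repeated for 24 SMPL joints plus
# one extra "joint" whose first three features carry the root translation.
ident6d = torch.tensor([1., 0., 0., 0., 1., 0.])
motion = ident6d.view(1, 1, 6, 1).repeat(1, 25, 1, 60)   # (batch, joints+1, 6, frames)

vertices = rot2xyz(motion, mask=None, pose_rep='rot6d', translation=True,
                   glob=True, jointstype='vertices', vertstrans=True)
print(vertices.shape)   # expected (1, 6890, 3, 60) for the neutral SMPL mesh
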
models/smpl.py ADDED
@@ -0,0 +1,97 @@
+# This code is based on https://github.com/Mathux/ACTOR.git
+import numpy as np
+import torch
+
+import contextlib
+
+from smplx import SMPLLayer as _SMPLLayer
+from smplx.lbs import vertices2joints
+
+
+# action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38]
+# change 0 and 8
+action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38]
+
+from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA
+
+JOINTSTYPE_ROOT = {"a2m": 0, # action2motion
+                   "smpl": 0,
+                   "a2mpl": 0, # set(smpl, a2m)
+                   "vibe": 8} # 0 is the 8 position: OP MidHip below
+
+JOINT_MAP = {
+    'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
+    'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
+    'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
+    'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
+    'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
+    'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
+    'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
+    'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
+    'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
+    'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
+    'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
+    'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
+    'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
+    'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
+    'Spine (H36M)': 51, 'Jaw (H36M)': 52,
+    'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
+    'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
+}
+
+JOINT_NAMES = [
+    'OP Nose', 'OP Neck', 'OP RShoulder',
+    'OP RElbow', 'OP RWrist', 'OP LShoulder',
+    'OP LElbow', 'OP LWrist', 'OP MidHip',
+    'OP RHip', 'OP RKnee', 'OP RAnkle',
+    'OP LHip', 'OP LKnee', 'OP LAnkle',
+    'OP REye', 'OP LEye', 'OP REar',
+    'OP LEar', 'OP LBigToe', 'OP LSmallToe',
+    'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel',
+    'Right Ankle', 'Right Knee', 'Right Hip',
+    'Left Hip', 'Left Knee', 'Left Ankle',
+    'Right Wrist', 'Right Elbow', 'Right Shoulder',
+    'Left Shoulder', 'Left Elbow', 'Left Wrist',
+    'Neck (LSP)', 'Top of Head (LSP)',
+    'Pelvis (MPII)', 'Thorax (MPII)',
+    'Spine (H36M)', 'Jaw (H36M)',
+    'Head (H36M)', 'Nose', 'Left Eye',
+    'Right Eye', 'Left Ear', 'Right Ear'
+]
+
+
+# adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints
+class SMPL(_SMPLLayer):
+    """ Extension of the official SMPL implementation to support more joints """
+
+    def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):
+        kwargs["model_path"] = model_path
+
+        # remove the verbosity for the 10-shapes beta parameters
+        with contextlib.redirect_stdout(None):
+            super(SMPL, self).__init__(**kwargs)
+
+        J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)
+        self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
+        vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES])
+        a2m_indexes = vibe_indexes[action2motion_joints]
+        smpl_indexes = np.arange(24)
+        a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes])
+
+        self.maps = {"vibe": vibe_indexes,
+                     "a2m": a2m_indexes,
+                     "smpl": smpl_indexes,
+                     "a2mpl": a2mpl_indexes}
+
+    def forward(self, *args, **kwargs):
+        smpl_output = super(SMPL, self).forward(*args, **kwargs)
+
+        extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
+        all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
+
+        output = {"vertices": smpl_output.vertices}
+
+        for joinstype, indexes in self.maps.items():
+            output[joinstype] = all_joints[:, indexes]
+
+        return output
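
A minimal sketch of querying the extended layer (assumes SMPL_NEUTRAL.pkl and J_regressor_extra.npy added in this commit are in place and smplx is installed); the forward pass returns a dict keyed by the joint sets built in __init__:

import torch
from models.smpl import SMPL, JOINTSTYPE_ROOT

smpl = SMPL().eval()
batch = 2
body_pose = torch.eye(3).expand(batch, 23, 3, 3)       # identity rotation matrices, 23 body joints
global_orient = torch.eye(3).expand(batch, 1, 3, 3)

out = smpl(body_pose=body_pose, global_orient=global_orient)
print(out["vertices"].shape)    # expected (2, 6890, 3)
print(out["smpl"].shape)        # expected (2, 24, 3)  -- canonical SMPL joints
print(out["vibe"].shape)        # expected (2, 49, 3)  -- VIBE/SPIN joint set
print(JOINTSTYPE_ROOT["vibe"])  # 8 (OP MidHip)
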
options/__pycache__/option_transformer.cpython-312.pyc ADDED
Binary file (5.99 kB)
 
options/__pycache__/option_transformer.cpython-38.pyc ADDED
Binary file (3.41 kB)
 
render_final.py ADDED
@@ -0,0 +1,189 @@
1
+ from models.rotation2xyz import Rotation2xyz
2
+ import numpy as np
3
+ from trimesh import Trimesh
4
+ import os
5
+ os.environ['PYOPENGL_PLATFORM'] = "osmesa"
6
+
7
+ import torch
8
+ from visualize.simplify_loc2rot import joints2smpl
9
+ import pyrender
10
+ import matplotlib.pyplot as plt
11
+
12
+ import io
13
+ import imageio
14
+ from shapely import geometry
15
+ import trimesh
16
+ from pyrender.constants import RenderFlags
17
+ import math
18
+ # import ffmpeg
19
+ from PIL import Image
20
+ import argparse
21
+
22
+ class WeakPerspectiveCamera(pyrender.Camera):
23
+ def __init__(self,
24
+ scale,
25
+ translation,
26
+ znear=pyrender.camera.DEFAULT_Z_NEAR,
27
+ zfar=None,
28
+ name=None):
29
+ super(WeakPerspectiveCamera, self).__init__(
30
+ znear=znear,
31
+ zfar=zfar,
32
+ name=name,
33
+ )
34
+ self.scale = scale
35
+ self.translation = translation
36
+
37
+ def get_projection_matrix(self, width=None, height=None):
38
+ P = np.eye(4)
39
+ P[0, 0] = self.scale[0]
40
+ P[1, 1] = self.scale[1]
41
+ P[0, 3] = self.translation[0] * self.scale[0]
42
+ P[1, 3] = -self.translation[1] * self.scale[1]
43
+ P[2, 2] = -1
44
+ return P
45
+
46
+ def render(motions, outdir='test_vis', device_id=0, name=None, pred=True):
47
+ frames, njoints, nfeats = motions.shape
48
+ MINS = motions.min(axis=0).min(axis=0)
49
+ MAXS = motions.max(axis=0).max(axis=0)
50
+
51
+ height_offset = MINS[1]
52
+ motions[:, :, 1] -= height_offset
53
+
54
+ j2s = joints2smpl(num_frames=frames, device_id=device_id, cuda=False)
55
+ rot2xyz = Rotation2xyz(device=torch.device('cpu'))
56
+ faces = rot2xyz.smpl_model.faces
57
+
58
+ filepath = os.path.join(outdir, f'{name}.pt')
59
+ if not os.path.exists(filepath):
60
+ print(f'Running SMPLify, it may take a few minutes.')
61
+ motion_tensor, opt_dict = j2s.joint2smpl(motions)
62
+ vertices = rot2xyz(torch.tensor(motion_tensor).clone(), mask=None,
63
+ pose_rep='rot6d', translation=True, glob=True,
64
+ jointstype='vertices', vertstrans=True)
65
+ torch.save(vertices, filepath)
66
+ else:
67
+ vertices = torch.load(filepath)
68
+
69
+ if not os.path.exists(outdir):
70
+ os.makedirs(outdir)
71
+
72
+ frames = vertices.shape[3] # shape: 1, nb_frames, 3, nb_joints
73
+ print (vertices.shape)
74
+ MINS = torch.min(torch.min(vertices[0], axis=0)[0], axis=1)[0]
75
+ MAXS = torch.max(torch.max(vertices[0], axis=0)[0], axis=1)[0]
76
+ # vertices[:,:,1,:] -= MINS[1] + 1e-5
77
+
78
+
79
+ out_list = []
80
+
81
+ minx = MINS[0] - 0.5
82
+ maxx = MAXS[0] + 0.5
83
+ minz = MINS[2] - 0.5
84
+ maxz = MAXS[2] + 0.5
85
+ polygon = geometry.Polygon([[minx, minz], [minx, maxz], [maxx, maxz], [maxx, minz]])
86
+ polygon_mesh = trimesh.creation.extrude_polygon(polygon, 1e-5)
87
+
88
+ vid = []
89
+ for i in range(frames):
90
+ if i % 10 == 0:
91
+ print(i)
92
+
93
+ mesh = Trimesh(vertices=vertices[0, :, :, i].squeeze().tolist(), faces=faces)
94
+
95
+ base_color = (0.11, 0.53, 0.8, 0.5)
96
+ ## OPAQUE rendering without alpha
97
+ ## BLEND rendering consider alpha
98
+ material = pyrender.MetallicRoughnessMaterial(
99
+ metallicFactor=0.7,
100
+ alphaMode='OPAQUE',
101
+ baseColorFactor=base_color
102
+ )
103
+
104
+
105
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
106
+
107
+ polygon_mesh.visual.face_colors = [0, 0, 0, 0.21]
108
+ polygon_render = pyrender.Mesh.from_trimesh(polygon_mesh, smooth=False)
109
+
110
+ bg_color = [1, 1, 1, 0.8]
111
+ scene = pyrender.Scene(bg_color=bg_color, ambient_light=(0.4, 0.4, 0.4))
112
+
113
+ sx, sy, tx, ty = [0.75, 0.75, 0, 0.10]
114
+
115
+ camera = pyrender.PerspectiveCamera(yfov=(np.pi / 3.0))
116
+
117
+ light = pyrender.DirectionalLight(color=[1,1,1], intensity=300)
118
+
119
+ scene.add(mesh)
120
+
121
+ c = np.pi / 2
122
+
123
+ scene.add(polygon_render, pose=np.array([[ 1, 0, 0, 0],
124
+
125
+ [ 0, np.cos(c), -np.sin(c), MINS[1].cpu().numpy()],
126
+
127
+ [ 0, np.sin(c), np.cos(c), 0],
128
+
129
+ [ 0, 0, 0, 1]]))
130
+
131
+ light_pose = np.eye(4)
132
+ light_pose[:3, 3] = [0, -1, 1]
133
+ scene.add(light, pose=light_pose.copy())
134
+
135
+ light_pose[:3, 3] = [0, 1, 1]
136
+ scene.add(light, pose=light_pose.copy())
137
+
138
+ light_pose[:3, 3] = [1, 1, 2]
139
+ scene.add(light, pose=light_pose.copy())
140
+
141
+
142
+ c = -np.pi / 6
143
+
144
+ scene.add(camera, pose=[[ 1, 0, 0, (minx+maxx).cpu().numpy()/2],
145
+
146
+ [ 0, np.cos(c), -np.sin(c), 1.5],
147
+
148
+ [ 0, np.sin(c), np.cos(c), max(4, minz.cpu().numpy()+(1.5-MINS[1].cpu().numpy())*2, (maxx-minx).cpu().numpy())],
149
+
150
+ [ 0, 0, 0, 1]
151
+ ])
152
+
153
+ # render scene
154
+ r = pyrender.OffscreenRenderer(960, 960)
155
+
156
+ color, _ = r.render(scene, flags=RenderFlags.RGBA)
157
+ # Image.fromarray(color).save(outdir+'/'+name+'_'+str(i)+'.png')
158
+
159
+ vid.append(color)
160
+
161
+ r.delete()
162
+
163
+ out = np.stack(vid, axis=0)
164
+ gif_path = os.path.join(outdir, f'{name}.gif')
165
+ imageio.mimsave(gif_path, out, fps=20)
166
+
167
+
168
+
169
+
170
+
171
+ if __name__ == "__main__":
172
+ parser = argparse.ArgumentParser(description="Render 3D motion from a numpy file.")
173
+ parser.add_argument("filepath", type=str, help="Path to the numpy file containing the motion data.")
174
+ parser.add_argument("--outdir", type=str, default="./output", help="Output directory for rendered files.")
175
+ args = parser.parse_args()
176
+
177
+ # Ensure output directory exists
178
+ if not os.path.exists(args.outdir):
179
+ os.makedirs(args.outdir)
180
+
181
+ # Load the motion data from the provided file path
182
+ motions = np.load(args.filepath)
183
+ print('Loaded motion data from:', args.filepath, 'Shape:', motions.shape)
184
+
185
+ # Extract filename for naming output files
186
+ filename = os.path.basename(args.filepath).replace('.npy', '')
187
+
188
+ # Render the motion
189
+ render(motions[0], outdir=args.outdir, device_id=0, name=filename)
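
Besides the CLI entry point above, render() can be called directly from Python. A minimal sketch; the ./output/mmm-pred.npy path matches what generate.py now saves, and the (1, frames, 22, 3) layout is an assumption based on that change:

import numpy as np
from render_final import render

motions = np.load('./output/mmm-pred.npy')   # assumed shape (1, frames, 22, 3)
# Writes ./output/mmm-pred.pt (cached SMPLify vertices) and ./output/mmm-pred.gif.
render(motions[0], outdir='./output', device_id=0, name='mmm-pred')
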
utils/__pycache__/humanml_utils.cpython-312.pyc ADDED
Binary file (2.64 kB)
 
utils/__pycache__/humanml_utils.cpython-38.pyc ADDED
Binary file (1.84 kB)
 
utils/__pycache__/motion_process.cpython-312.pyc ADDED
Binary file (3.2 kB)
 
utils/__pycache__/motion_process.cpython-38.pyc ADDED
Binary file (1.74 kB)
 
utils/__pycache__/quaternion.cpython-312.pyc ADDED
Binary file (23.5 kB)
 
utils/__pycache__/quaternion.cpython-38.pyc ADDED
Binary file (11.6 kB)
 
utils/config.py ADDED
@@ -0,0 +1,17 @@
+import os
+
+SMPL_DATA_PATH = "./body_models/smpl"
+
+SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl")
+SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl")
+JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy')
+
+ROT_CONVENTION_TO_ROT_NUMBER = {
+    'legacy': 23,
+    'no_hands': 21,
+    'full_hands': 51,
+    'mitten_hands': 33,
+}
+
+GENDERS = ['neutral', 'male', 'female']
+NUM_BETAS = 10
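
These constants are what models/smpl.py imports; a quick sketch of the resolved paths, assuming the files fetched by dataset/prepare/download_smpl.sh are in place:

from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA, NUM_BETAS

print(SMPL_MODEL_PATH)              # ./body_models/smpl/SMPL_NEUTRAL.pkl
print(JOINT_REGRESSOR_TRAIN_EXTRA)  # ./body_models/smpl/J_regressor_extra.npy
print(NUM_BETAS)                    # 10
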
utils/rotation_conversions.py ADDED
@@ -0,0 +1,532 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ # Check PYTORCH3D_LICENCE before use
3
+
4
+ import functools
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+
11
+ """
12
+ The transformation matrices returned from the functions in this file assume
13
+ the points on which the transformation will be applied are column vectors.
14
+ i.e. the R matrix is structured as
15
+ R = [
16
+ [Rxx, Rxy, Rxz],
17
+ [Ryx, Ryy, Ryz],
18
+ [Rzx, Rzy, Rzz],
19
+ ] # (3, 3)
20
+ This matrix can be applied to column vectors by post multiplication
21
+ by the points e.g.
22
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
23
+ transformed_points = R * points
24
+ To apply the same matrix to points which are row vectors, the R matrix
25
+ can be transposed and pre multiplied by the points:
26
+ e.g.
27
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
28
+ transformed_points = points * R.transpose(1, 0)
29
+ """
30
+
31
+
32
+ def quaternion_to_matrix(quaternions):
33
+ """
34
+ Convert rotations given as quaternions to rotation matrices.
35
+ Args:
36
+ quaternions: quaternions with real part first,
37
+ as tensor of shape (..., 4).
38
+ Returns:
39
+ Rotation matrices as tensor of shape (..., 3, 3).
40
+ """
41
+ r, i, j, k = torch.unbind(quaternions, -1)
42
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
43
+
44
+ o = torch.stack(
45
+ (
46
+ 1 - two_s * (j * j + k * k),
47
+ two_s * (i * j - k * r),
48
+ two_s * (i * k + j * r),
49
+ two_s * (i * j + k * r),
50
+ 1 - two_s * (i * i + k * k),
51
+ two_s * (j * k - i * r),
52
+ two_s * (i * k - j * r),
53
+ two_s * (j * k + i * r),
54
+ 1 - two_s * (i * i + j * j),
55
+ ),
56
+ -1,
57
+ )
58
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
59
+
60
+
61
+ def _copysign(a, b):
62
+ """
63
+ Return a tensor where each element has the absolute value taken from the,
64
+ corresponding element of a, with sign taken from the corresponding
65
+ element of b. This is like the standard copysign floating-point operation,
66
+ but is not careful about negative 0 and NaN.
67
+ Args:
68
+ a: source tensor.
69
+ b: tensor whose signs will be used, of the same shape as a.
70
+ Returns:
71
+ Tensor of the same shape as a with the signs of b.
72
+ """
73
+ signs_differ = (a < 0) != (b < 0)
74
+ return torch.where(signs_differ, -a, a)
75
+
76
+
77
+ def _sqrt_positive_part(x):
78
+ """
79
+ Returns torch.sqrt(torch.max(0, x))
80
+ but with a zero subgradient where x is 0.
81
+ """
82
+ ret = torch.zeros_like(x)
83
+ positive_mask = x > 0
84
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
85
+ return ret
86
+
87
+
88
+ def matrix_to_quaternion(matrix):
89
+ """
90
+ Convert rotations given as rotation matrices to quaternions.
91
+ Args:
92
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
93
+ Returns:
94
+ quaternions with real part first, as tensor of shape (..., 4).
95
+ """
96
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
97
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
98
+ m00 = matrix[..., 0, 0]
99
+ m11 = matrix[..., 1, 1]
100
+ m22 = matrix[..., 2, 2]
101
+ o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
102
+ x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
103
+ y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
104
+ z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
105
+ o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
106
+ o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
107
+ o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
108
+ return torch.stack((o0, o1, o2, o3), -1)
109
+
110
+
111
+ def _axis_angle_rotation(axis: str, angle):
112
+ """
113
+ Return the rotation matrices for one of the rotations about an axis
114
+ of which Euler angles describe, for each value of the angle given.
115
+ Args:
116
+ axis: Axis label "X" or "Y or "Z".
117
+ angle: any shape tensor of Euler angles in radians
118
+ Returns:
119
+ Rotation matrices as tensor of shape (..., 3, 3).
120
+ """
121
+
122
+ cos = torch.cos(angle)
123
+ sin = torch.sin(angle)
124
+ one = torch.ones_like(angle)
125
+ zero = torch.zeros_like(angle)
126
+
127
+ if axis == "X":
128
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
129
+ if axis == "Y":
130
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
131
+ if axis == "Z":
132
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
133
+
134
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
135
+
136
+
137
+ def euler_angles_to_matrix(euler_angles, convention: str):
138
+ """
139
+ Convert rotations given as Euler angles in radians to rotation matrices.
140
+ Args:
141
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
142
+ convention: Convention string of three uppercase letters from
143
+ {"X", "Y", and "Z"}.
144
+ Returns:
145
+ Rotation matrices as tensor of shape (..., 3, 3).
146
+ """
147
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
148
+ raise ValueError("Invalid input euler angles.")
149
+ if len(convention) != 3:
150
+ raise ValueError("Convention must have 3 letters.")
151
+ if convention[1] in (convention[0], convention[2]):
152
+ raise ValueError(f"Invalid convention {convention}.")
153
+ for letter in convention:
154
+ if letter not in ("X", "Y", "Z"):
155
+ raise ValueError(f"Invalid letter {letter} in convention string.")
156
+ matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
157
+ return functools.reduce(torch.matmul, matrices)
158
+
159
+
160
+ def _angle_from_tan(
161
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
162
+ ):
163
+ """
164
+ Extract the first or third Euler angle from the two members of
165
+ the matrix which are positive constant times its sine and cosine.
166
+ Args:
167
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
168
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
169
+ convention.
170
+ data: Rotation matrices as tensor of shape (..., 3, 3).
171
+ horizontal: Whether we are looking for the angle for the third axis,
172
+ which means the relevant entries are in the same row of the
173
+ rotation matrix. If not, they are in the same column.
174
+ tait_bryan: Whether the first and third axes in the convention differ.
175
+ Returns:
176
+ Euler Angles in radians for each matrix in data as a tensor
177
+ of shape (...).
178
+ """
179
+
180
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
181
+ if horizontal:
182
+ i2, i1 = i1, i2
183
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
184
+ if horizontal == even:
185
+ return torch.atan2(data[..., i1], data[..., i2])
186
+ if tait_bryan:
187
+ return torch.atan2(-data[..., i2], data[..., i1])
188
+ return torch.atan2(data[..., i2], -data[..., i1])
189
+
190
+
191
+ def _index_from_letter(letter: str):
192
+ if letter == "X":
193
+ return 0
194
+ if letter == "Y":
195
+ return 1
196
+ if letter == "Z":
197
+ return 2
198
+
199
+
200
+ def matrix_to_euler_angles(matrix, convention: str):
201
+ """
202
+ Convert rotations given as rotation matrices to Euler angles in radians.
203
+ Args:
204
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
205
+ convention: Convention string of three uppercase letters.
206
+ Returns:
207
+ Euler angles in radians as tensor of shape (..., 3).
208
+ """
209
+ if len(convention) != 3:
210
+ raise ValueError("Convention must have 3 letters.")
211
+ if convention[1] in (convention[0], convention[2]):
212
+ raise ValueError(f"Invalid convention {convention}.")
213
+ for letter in convention:
214
+ if letter not in ("X", "Y", "Z"):
215
+ raise ValueError(f"Invalid letter {letter} in convention string.")
216
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
217
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
218
+ i0 = _index_from_letter(convention[0])
219
+ i2 = _index_from_letter(convention[2])
220
+ tait_bryan = i0 != i2
221
+ if tait_bryan:
222
+ central_angle = torch.asin(
223
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
224
+ )
225
+ else:
226
+ central_angle = torch.acos(matrix[..., i0, i0])
227
+
228
+ o = (
229
+ _angle_from_tan(
230
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
231
+ ),
232
+ central_angle,
233
+ _angle_from_tan(
234
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
235
+ ),
236
+ )
237
+ return torch.stack(o, -1)
238
+
239
+
240
+ def random_quaternions(
241
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
242
+ ):
243
+ """
244
+ Generate random quaternions representing rotations,
245
+ i.e. versors with nonnegative real part.
246
+ Args:
247
+ n: Number of quaternions in a batch to return.
248
+ dtype: Type to return.
249
+ device: Desired device of returned tensor. Default:
250
+ uses the current device for the default tensor type.
251
+ requires_grad: Whether the resulting tensor should have the gradient
252
+ flag set.
253
+ Returns:
254
+ Quaternions as tensor of shape (N, 4).
255
+ """
256
+ o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
257
+ s = (o * o).sum(1)
258
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
259
+ return o
260
+
261
+
262
+ def random_rotations(
263
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
264
+ ):
265
+ """
266
+ Generate random rotations as 3x3 rotation matrices.
267
+ Args:
268
+ n: Number of rotation matrices in a batch to return.
269
+ dtype: Type to return.
270
+ device: Device of returned tensor. Default: if None,
271
+ uses the current device for the default tensor type.
272
+ requires_grad: Whether the resulting tensor should have the gradient
273
+ flag set.
274
+ Returns:
275
+ Rotation matrices as tensor of shape (n, 3, 3).
276
+ """
277
+ quaternions = random_quaternions(
278
+ n, dtype=dtype, device=device, requires_grad=requires_grad
279
+ )
280
+ return quaternion_to_matrix(quaternions)
281
+
282
+
283
+ def random_rotation(
284
+ dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
285
+ ):
286
+ """
287
+ Generate a single random 3x3 rotation matrix.
288
+ Args:
289
+ dtype: Type to return
290
+ device: Device of returned tensor. Default: if None,
291
+ uses the current device for the default tensor type
292
+ requires_grad: Whether the resulting tensor should have the gradient
293
+ flag set
294
+ Returns:
295
+ Rotation matrix as tensor of shape (3, 3).
296
+ """
297
+ return random_rotations(1, dtype, device, requires_grad)[0]
298
+
299
+
300
+ def standardize_quaternion(quaternions):
301
+ """
302
+ Convert a unit quaternion to a standard form: one in which the real
303
+ part is non negative.
304
+ Args:
305
+ quaternions: Quaternions with real part first,
306
+ as tensor of shape (..., 4).
307
+ Returns:
308
+ Standardized quaternions as tensor of shape (..., 4).
309
+ """
310
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
311
+
312
+
313
+ def quaternion_raw_multiply(a, b):
314
+ """
315
+ Multiply two quaternions.
316
+ Usual torch rules for broadcasting apply.
317
+ Args:
318
+ a: Quaternions as tensor of shape (..., 4), real part first.
319
+ b: Quaternions as tensor of shape (..., 4), real part first.
320
+ Returns:
321
+ The product of a and b, a tensor of quaternions shape (..., 4).
322
+ """
323
+ aw, ax, ay, az = torch.unbind(a, -1)
324
+ bw, bx, by, bz = torch.unbind(b, -1)
325
+ ow = aw * bw - ax * bx - ay * by - az * bz
326
+ ox = aw * bx + ax * bw + ay * bz - az * by
327
+ oy = aw * by - ax * bz + ay * bw + az * bx
328
+ oz = aw * bz + ax * by - ay * bx + az * bw
329
+ return torch.stack((ow, ox, oy, oz), -1)
330
+
331
+
332
+ def quaternion_multiply(a, b):
333
+ """
334
+ Multiply two quaternions representing rotations, returning the quaternion
335
+ representing their composition, i.e. the versor with nonnegative real part.
336
+ Usual torch rules for broadcasting apply.
337
+ Args:
338
+ a: Quaternions as tensor of shape (..., 4), real part first.
339
+ b: Quaternions as tensor of shape (..., 4), real part first.
340
+ Returns:
341
+ The product of a and b, a tensor of quaternions of shape (..., 4).
342
+ """
343
+ ab = quaternion_raw_multiply(a, b)
344
+ return standardize_quaternion(ab)
345
+
346
+
347
+ def quaternion_invert(quaternion):
348
+ """
349
+ Given a quaternion representing rotation, get the quaternion representing
350
+ its inverse.
351
+ Args:
352
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
353
+ first, which must be versors (unit quaternions).
354
+ Returns:
355
+ The inverse, a tensor of quaternions of shape (..., 4).
356
+ """
357
+
358
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
359
+
360
+
361
+ def quaternion_apply(quaternion, point):
362
+ """
363
+ Apply the rotation given by a quaternion to a 3D point.
364
+ Usual torch rules for broadcasting apply.
365
+ Args:
366
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
367
+ point: Tensor of 3D points of shape (..., 3).
368
+ Returns:
369
+ Tensor of rotated points of shape (..., 3).
370
+ """
371
+ if point.size(-1) != 3:
372
+ raise ValueError(f"Points are not in 3D, f{point.shape}.")
373
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
374
+ point_as_quaternion = torch.cat((real_parts, point), -1)
375
+ out = quaternion_raw_multiply(
376
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
377
+ quaternion_invert(quaternion),
378
+ )
379
+ return out[..., 1:]
380
+
381
+
382
+ def axis_angle_to_matrix(axis_angle):
383
+ """
384
+ Convert rotations given as axis/angle to rotation matrices.
385
+ Args:
386
+ axis_angle: Rotations given as a vector in axis angle form,
387
+ as a tensor of shape (..., 3), where the magnitude is
388
+ the angle turned anticlockwise in radians around the
389
+ vector's direction.
390
+ Returns:
391
+ Rotation matrices as tensor of shape (..., 3, 3).
392
+ """
393
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
394
+
395
+
396
+ def matrix_to_axis_angle(matrix):
397
+ """
398
+ Convert rotations given as rotation matrices to axis/angle.
399
+ Args:
400
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
401
+ Returns:
402
+ Rotations given as a vector in axis angle form, as a tensor
403
+ of shape (..., 3), where the magnitude is the angle
404
+ turned anticlockwise in radians around the vector's
405
+ direction.
406
+ """
407
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
408
+
409
+
410
+ def axis_angle_to_quaternion(axis_angle):
411
+ """
412
+ Convert rotations given as axis/angle to quaternions.
413
+ Args:
414
+ axis_angle: Rotations given as a vector in axis angle form,
415
+ as a tensor of shape (..., 3), where the magnitude is
416
+ the angle turned anticlockwise in radians around the
417
+ vector's direction.
418
+ Returns:
419
+ quaternions with real part first, as tensor of shape (..., 4).
420
+ """
421
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
422
+ half_angles = 0.5 * angles
423
+ eps = 1e-6
424
+ small_angles = angles.abs() < eps
425
+ sin_half_angles_over_angles = torch.empty_like(angles)
426
+ sin_half_angles_over_angles[~small_angles] = (
427
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
428
+ )
429
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
430
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
431
+ sin_half_angles_over_angles[small_angles] = (
432
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
433
+ )
434
+ quaternions = torch.cat(
435
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
436
+ )
437
+ return quaternions
438
+
439
+
440
+ def quaternion_to_axis_angle(quaternions):
441
+ """
442
+ Convert rotations given as quaternions to axis/angle.
443
+ Args:
444
+ quaternions: quaternions with real part first,
445
+ as tensor of shape (..., 4).
446
+ Returns:
447
+ Rotations given as a vector in axis angle form, as a tensor
448
+ of shape (..., 3), where the magnitude is the angle
449
+ turned anticlockwise in radians around the vector's
450
+ direction.
451
+ """
452
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
453
+ half_angles = torch.atan2(norms, quaternions[..., :1])
454
+ angles = 2 * half_angles
455
+ eps = 1e-6
456
+ small_angles = angles.abs() < eps
457
+ sin_half_angles_over_angles = torch.empty_like(angles)
458
+ sin_half_angles_over_angles[~small_angles] = (
459
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
460
+ )
461
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
462
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
463
+ sin_half_angles_over_angles[small_angles] = (
464
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
465
+ )
466
+ return quaternions[..., 1:] / sin_half_angles_over_angles
467
+
468
+
469
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
470
+ """
471
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
472
+ using Gram--Schmidt orthogonalisation per Section B of [1].
473
+ Args:
474
+ d6: 6D rotation representation, of size (*, 6)
475
+ Returns:
476
+ batch of rotation matrices of size (*, 3, 3)
477
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
478
+ On the Continuity of Rotation Representations in Neural Networks.
479
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
480
+ Retrieved from http://arxiv.org/abs/1812.07035
481
+ """
482
+
483
+ a1, a2 = d6[..., :3], d6[..., 3:]
484
+ b1 = F.normalize(a1, dim=-1)
485
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
486
+ b2 = F.normalize(b2, dim=-1)
487
+ b3 = torch.cross(b1, b2, dim=-1)
488
+ return torch.stack((b1, b2, b3), dim=-2)
489
+
490
+
491
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
492
+ """
493
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
494
+ by dropping the last row. Note that 6D representation is not unique.
495
+ Args:
496
+ matrix: batch of rotation matrices of size (*, 3, 3)
497
+ Returns:
498
+ 6D rotation representation, of size (*, 6)
499
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
500
+ On the Continuity of Rotation Representations in Neural Networks.
501
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
502
+ Retrieved from http://arxiv.org/abs/1812.07035
503
+ """
504
+ return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
505
+
506
+ def canonicalize_smplh(poses, trans = None):
507
+ bs, nframes, njoints = poses.shape[:3]
508
+
509
+ global_orient = poses[:, :, 0]
510
+
511
+ # first global rotations
512
+ rot2d = matrix_to_axis_angle(global_orient[:, 0])
513
+ #rot2d[:, :2] = 0 # Remove the rotation along the vertical axis
514
+ rot2d = axis_angle_to_matrix(rot2d)
515
+
516
+ # Rotate the global rotation to eliminate Z rotations
517
+ global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient)
518
+
519
+ # Construct canonicalized version of x
520
+ xc = torch.cat((global_orient[:, :, None], poses[:, :, 1:]), dim=2)
521
+
522
+ if trans is not None:
523
+ vel = trans[:, 1:] - trans[:, :-1]
524
+ # Turn the translation as well
525
+ vel = torch.einsum("ikj,ilk->ilj", rot2d, vel)
526
+ trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device),
527
+ torch.cumsum(vel, 1)), 1)
528
+ return xc, trans
529
+ else:
530
+ return xc
531
+
532
+
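
A quick round-trip sanity check for the conversion helpers above (axis-angle to matrix to quaternion and back, plus the 6D representation consumed by Rotation2xyz):

import torch
import utils.rotation_conversions as geometry

aa = torch.tensor([[0.0, 1.2, 0.0]])              # 1.2 rad about the Y axis
R = geometry.axis_angle_to_matrix(aa)             # (1, 3, 3)
q = geometry.matrix_to_quaternion(R)              # (1, 4), real part first
aa_back = geometry.quaternion_to_axis_angle(q)
print(torch.allclose(aa, aa_back, atol=1e-5))     # expected True

d6 = geometry.matrix_to_rotation_6d(R)            # first two rows of R, flattened
print(torch.allclose(geometry.rotation_6d_to_matrix(d6), R, atol=1e-5))  # expected True
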
visualization/plot_3d_global.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import io
5
+ import matplotlib
6
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
7
+ import mpl_toolkits.mplot3d.axes3d as p3
8
+ from textwrap import wrap
9
+ import imageio
10
+
11
+ def plot_3d_motion(args, figsize=(10, 10), fps=120, radius=4):
12
+ matplotlib.use('Agg')
13
+
14
+
15
+ joints, out_name, title = args
16
+
17
+ data = joints.copy().reshape(len(joints), -1, 3)
18
+
19
+ nb_joints = joints.shape[1]
20
+ smpl_kinetic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] if nb_joints == 21 else [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
21
+ limits = 1000 if nb_joints == 21 else 2
22
+ MINS = data.min(axis=0).min(axis=0)
23
+ MAXS = data.max(axis=0).max(axis=0)
24
+ colors = ['red', 'blue', 'black', 'red', 'blue',
25
+ 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
26
+ 'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
27
+ frame_number = data.shape[0]
28
+ # print(data.shape)
29
+
30
+ height_offset = MINS[1]
31
+ data[:, :, 1] -= height_offset
32
+ trajec = data[:, 0, [0, 2]]
33
+
34
+ data[..., 0] -= data[:, 0:1, 0]
35
+ data[..., 2] -= data[:, 0:1, 2]
36
+
37
+ def update(index):
38
+
39
+ def init():
40
+ ax.set_xlim(-limits, limits)
41
+ ax.set_ylim(-limits, limits)
42
+ ax.set_zlim(0, limits)
43
+ ax.grid(b=False)
44
+ def plot_xzPlane(minx, maxx, miny, minz, maxz):
45
+ ## Plot a plane XZ
46
+ verts = [
47
+ [minx, miny, minz],
48
+ [minx, miny, maxz],
49
+ [maxx, miny, maxz],
50
+ [maxx, miny, minz]
51
+ ]
52
+ xz_plane = Poly3DCollection([verts])
53
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
54
+ ax.add_collection3d(xz_plane)
55
+ fig = plt.figure(figsize=(480/96., 320/96.), dpi=96) if nb_joints == 21 else plt.figure(figsize=(10, 10), dpi=96)
56
+ if title is not None :
57
+ wraped_title = '\n'.join(wrap(title, 40))
58
+ fig.suptitle(wraped_title, fontsize=16)
59
+ ax = p3.Axes3D(fig)
60
+
61
+ init()
62
+
63
+ ax.lines = []
64
+ ax.collections = []
65
+ ax.view_init(elev=110, azim=-90)
66
+ ax.dist = 7.5
67
+ # ax =
68
+ plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
69
+ MAXS[2] - trajec[index, 1])
70
+ # ax.scatter(data[index, :22, 0], data[index, :22, 1], data[index, :22, 2], color='black', s=3)
71
+
72
+ if index > 1:
73
+ ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]),
74
+ trajec[:index, 1] - trajec[index, 1], linewidth=1.0,
75
+ color='blue')
76
+ # ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2])
77
+
78
+ for i, (chain, color) in enumerate(zip(smpl_kinetic_chain, colors)):
79
+ # print(color)
80
+ if i < 5:
81
+ linewidth = 4.0
82
+ else:
83
+ linewidth = 2.0
84
+ ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
85
+ color=color)
86
+ # print(trajec[:index, 0].shape)
87
+
88
+ plt.axis('off')
89
+ ax.set_xticklabels([])
90
+ ax.set_yticklabels([])
91
+ ax.set_zticklabels([])
92
+
93
+ if out_name is not None :
94
+ plt.savefig(out_name, dpi=96)
95
+ plt.close()
96
+
97
+ else :
98
+ io_buf = io.BytesIO()
99
+ fig.savefig(io_buf, format='raw', dpi=96)
100
+ io_buf.seek(0)
101
+ # print(fig.bbox.bounds)
102
+ arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
103
+ newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
104
+ io_buf.close()
105
+ plt.close()
106
+ return arr
107
+
108
+ out = []
109
+ for i in range(frame_number) :
110
+ out.append(update(i))
111
+ out = np.stack(out, axis=0)
112
+ return torch.from_numpy(out)
113
+
114
+
115
+ def draw_to_batch(smpl_joints_batch, title_batch=None, outname=None) :
116
+
117
+ batch_size = len(smpl_joints_batch)
118
+ out = []
119
+ for i in range(batch_size) :
120
+ out.append(plot_3d_motion([smpl_joints_batch[i], None, title_batch[i] if title_batch is not None else None]))
121
+ if outname is not None:
122
+ imageio.mimsave(outname[i], np.array(out[-1]), fps=20)
123
+ out = torch.stack(out, axis=0)
124
+ return out
125
+
126
+
127
+
128
+
129
+
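
A minimal sketch of the stick-figure plotting entry point; the input is assumed to be the (1, frames, 22, 3) joints array saved by generate.py, and the output file name is illustrative:

import numpy as np
from visualization.plot_3d_global import draw_to_batch

joints = np.load('./output/mmm-pred.npy')          # assumed (1, frames, 22, 3)
frames = draw_to_batch(joints,
                       title_batch=['a person walks forward and waves'],
                       outname=['./output/mmm-pred-skeleton.gif'])
print(frames.shape)                                # expected (1, frames, H, W, 4) RGBA frames
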
visualize/joints2smpl/smpl_models/SMPL_downsample_index.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5b783c1677079397ee4bc26df5c72d73b8bb393bea41fa295b951187443daec
+size 3556
visualize/joints2smpl/smpl_models/gmm_08.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1374908aae055a2afa01a2cd9a169bc6cfec1ceb7aa590e201a47b383060491
+size 839127
visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac9b474c74daec0253ed084720f662059336e976850f08a4a9a3f76d06613776
+size 4848
visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb69c10801205c9cfb5353fdeb1b9cc5ade53d14c265c3339421cdde8b9c91e7
+size 1323168
visualize/joints2smpl/src/config.py ADDED
@@ -0,0 +1,40 @@
+import numpy as np
+
+# Map joints Name to SMPL joints idx
+JOINT_MAP = {
+    'MidHip': 0,
+    'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10,
+    'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11,
+    'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 'LHand': 22,
+    'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 'RHand': 23,
+    'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15,
+    'LCollar': 13, 'Rcollar': 14,
+    'Nose': 24, 'REye': 26, 'LEye': 26, 'REar': 27, 'LEar': 28,
+    'LHeel': 31, 'RHeel': 34,
+    'OP RShoulder': 17, 'OP LShoulder': 16,
+    'OP RHip': 2, 'OP LHip': 1,
+    'OP Neck': 12,
+}
+
+full_smpl_idx = range(24)
+key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20]
+
+
+AMASS_JOINT_MAP = {
+    'MidHip': 0,
+    'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10,
+    'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11,
+    'LShoulder': 16, 'LElbow': 18, 'LWrist': 20,
+    'RShoulder': 17, 'RElbow': 19, 'RWrist': 21,
+    'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15,
+    'LCollar': 13, 'Rcollar': 14,
+}
+amass_idx = range(22)
+amass_smpl_idx = range(22)
+
+
+SMPL_MODEL_DIR = "./body_models/"
+GMM_MODEL_DIR = "./visualize/joints2smpl/smpl_models/"
+SMPL_MEAN_FILE = "./visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5"
+# for collsion
+Part_Seg_DIR = "./visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl"
visualize/joints2smpl/src/customloss.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from visualize.joints2smpl.src import config
4
+
5
+ # Guassian
6
+ def gmof(x, sigma):
7
+ """
8
+ Geman-McClure error function
9
+ """
10
+ x_squared = x ** 2
11
+ sigma_squared = sigma ** 2
12
+ return (sigma_squared * x_squared) / (sigma_squared + x_squared)
13
+
14
+ # angle prior
15
+ def angle_prior(pose):
16
+ """
17
+ Angle prior that penalizes unnatural bending of the knees and elbows
18
+ """
19
+ # We subtract 3 because pose does not include the global rotation of the model
20
+ return torch.exp(
21
+ pose[:, [55 - 3, 58 - 3, 12 - 3, 15 - 3]] * torch.tensor([1., -1., -1, -1.], device=pose.device)) ** 2
22
+
23
+
24
+ def perspective_projection(points, rotation, translation,
25
+ focal_length, camera_center):
26
+ """
27
+ This function computes the perspective projection of a set of points.
28
+ Input:
29
+ points (bs, N, 3): 3D points
30
+ rotation (bs, 3, 3): Camera rotation
31
+ translation (bs, 3): Camera translation
32
+ focal_length (bs,) or scalar: Focal length
33
+ camera_center (bs, 2): Camera center
34
+ """
35
+ batch_size = points.shape[0]
36
+ K = torch.zeros([batch_size, 3, 3], device=points.device)
37
+ K[:, 0, 0] = focal_length
38
+ K[:, 1, 1] = focal_length
39
+ K[:, 2, 2] = 1.
40
+ K[:, :-1, -1] = camera_center
41
+
42
+ # Transform points
43
+ points = torch.einsum('bij,bkj->bki', rotation, points)
44
+ points = points + translation.unsqueeze(1)
45
+
46
+ # Apply perspective distortion
47
+ projected_points = points / points[:, :, -1].unsqueeze(-1)
48
+
49
+ # Apply camera intrinsics
50
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
51
+
52
+ return projected_points[:, :, :-1]
53
+
54
+
55
+ def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_center,
56
+ joints_2d, joints_conf, pose_prior,
57
+ focal_length=5000, sigma=100, pose_prior_weight=4.78,
58
+ shape_prior_weight=5, angle_prior_weight=15.2,
59
+ output='sum'):
60
+ """
61
+ Loss function for body fitting
62
+ """
63
+ batch_size = body_pose.shape[0]
64
+ rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(batch_size, -1, -1)
65
+
66
+ projected_joints = perspective_projection(model_joints, rotation, camera_t,
67
+ focal_length, camera_center)
68
+
69
+ # Weighted robust reprojection error
70
+ reprojection_error = gmof(projected_joints - joints_2d, sigma)
71
+ reprojection_loss = (joints_conf ** 2) * reprojection_error.sum(dim=-1)
72
+
73
+ # Pose prior loss
74
+ pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas)
75
+
76
+ # Angle prior for knees and elbows
77
+ angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1)
78
+
79
+ # Regularizer to prevent betas from taking large values
80
+ shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1)
81
+
82
+ total_loss = reprojection_loss.sum(dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss
83
+
84
+ if output == 'sum':
85
+ return total_loss.sum()
86
+ elif output == 'reprojection':
87
+ return reprojection_loss
88
+
89
+
90
+ # --- get camera fitting loss -----
91
+ def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_center,
92
+ joints_2d, joints_conf,
93
+ focal_length=5000, depth_loss_weight=100):
94
+ """
95
+ Loss function for camera optimization.
96
+ """
97
+ # Project model joints
98
+ batch_size = model_joints.shape[0]
99
+ rotation = torch.eye(3, device=model_joints.device).unsqueeze(0).expand(batch_size, -1, -1)
100
+ projected_joints = perspective_projection(model_joints, rotation, camera_t,
101
+ focal_length, camera_center)
102
+
103
+ # get the indexed four
104
+ op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder']
105
+ op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints]
106
+ gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder']
107
+ gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints]
108
+
109
+ reprojection_error_op = (joints_2d[:, op_joints_ind] -
110
+ projected_joints[:, op_joints_ind]) ** 2
111
+ reprojection_error_gt = (joints_2d[:, gt_joints_ind] -
112
+ projected_joints[:, gt_joints_ind]) ** 2
113
+
114
+ # Check if for each example in the batch all 4 OpenPose detections are valid, otherwise use the GT detections
115
+ # OpenPose joints are more reliable for this task, so we prefer to use them if possible
116
+ is_valid = (joints_conf[:, op_joints_ind].min(dim=-1)[0][:, None, None] > 0).float()
117
+ reprojection_loss = (is_valid * reprojection_error_op + (1 - is_valid) * reprojection_error_gt).sum(dim=(1, 2))
118
+
119
+ # Loss that penalizes deviation from depth estimate
120
+ depth_loss = (depth_loss_weight ** 2) * (camera_t[:, 2] - camera_t_est[:, 2]) ** 2
121
+
122
+ total_loss = reprojection_loss + depth_loss
123
+ return total_loss.sum()
124
+
125
+
126
+
127
+ # #####--- body fitiing loss -----
128
+ def body_fitting_loss_3d(body_pose, preserve_pose,
129
+ betas, model_joints, camera_translation,
130
+ j3d, pose_prior,
131
+ joints3d_conf,
132
+ sigma=100, pose_prior_weight=4.78*1.5,
133
+ shape_prior_weight=5.0, angle_prior_weight=15.2,
134
+ joint_loss_weight=500.0,
135
+ pose_preserve_weight=0.0,
136
+ use_collision=False,
137
+ model_vertices=None, model_faces=None,
138
+ search_tree=None, pen_distance=None, filter_faces=None,
139
+ collision_loss_weight=1000
140
+ ):
141
+ """
142
+ Loss function for body fitting
143
+ """
144
+ batch_size = body_pose.shape[0]
145
+
146
+ #joint3d_loss = (joint_loss_weight ** 2) * gmof((model_joints + camera_translation) - j3d, sigma).sum(dim=-1)
147
+
148
+ joint3d_error = gmof((model_joints + camera_translation) - j3d, sigma)
149
+
150
+ joint3d_loss_part = (joints3d_conf ** 2) * joint3d_error.sum(dim=-1)
151
+ joint3d_loss = ((joint_loss_weight ** 2) * joint3d_loss_part).sum(dim=-1)
152
+
153
+ # Pose prior loss
154
+ pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas)
155
+ # Angle prior for knees and elbows
156
+ angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1)
157
+ # Regularizer to prevent betas from taking large values
158
+ shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1)
159
+
160
+ collision_loss = 0.0
161
+ # Calculate the loss due to interpenetration
162
+ if use_collision:
163
+ triangles = torch.index_select(
164
+ model_vertices, 1,
165
+ model_faces).view(batch_size, -1, 3, 3)
166
+
167
+ with torch.no_grad():
168
+ collision_idxs = search_tree(triangles)
169
+
170
+ # Remove unwanted collisions
171
+ if filter_faces is not None:
172
+ collision_idxs = filter_faces(collision_idxs)
173
+
174
+ if collision_idxs.ge(0).sum().item() > 0:
175
+ collision_loss = torch.sum(collision_loss_weight * pen_distance(triangles, collision_idxs))
176
+
177
+ pose_preserve_loss = (pose_preserve_weight ** 2) * ((body_pose - preserve_pose) ** 2).sum(dim=-1)
178
+
179
+ # print('joint3d_loss', joint3d_loss.shape)
180
+ # print('pose_prior_loss', pose_prior_loss.shape)
181
+ # print('angle_prior_loss', angle_prior_loss.shape)
182
+ # print('shape_prior_loss', shape_prior_loss.shape)
183
+ # print('collision_loss', collision_loss)
184
+ # print('pose_preserve_loss', pose_preserve_loss.shape)
185
+
186
+ total_loss = joint3d_loss + pose_prior_loss + angle_prior_loss + shape_prior_loss + collision_loss + pose_preserve_loss
187
+
188
+ return total_loss.sum()
189
+
190
+
191
+ # #####--- get camera fitting loss -----
192
+ def camera_fitting_loss_3d(model_joints, camera_t, camera_t_est,
193
+ j3d, joints_category="orig", depth_loss_weight=100.0):
194
+ """
195
+ Loss function for camera optimization.
196
+ """
197
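+ # The "camera" here is a pure 3D translation applied to the model joints, which is then aligned to the target torso joints.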
+ model_joints = model_joints + camera_t
198
+ # # get the indexed four
199
+ # op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder']
200
+ # op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints]
201
+ #
202
+ # j3d_error_loss = (j3d[:, op_joints_ind] -
203
+ # model_joints[:, op_joints_ind]) ** 2
204
+
205
+ gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder']
206
+ gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints]
207
+
208
+ if joints_category=="orig":
209
+ select_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints]
210
+ elif joints_category=="AMASS":
211
+ select_joints_ind = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints]
212
+ else:
213
+ print("NO SUCH JOINTS CATEGORY!")
214
+
215
+ j3d_error_loss = (j3d[:, select_joints_ind] -
216
+ model_joints[:, gt_joints_ind]) ** 2
217
+
218
+ # Loss that penalizes deviation from depth estimate
219
+ depth_loss = (depth_loss_weight**2) * (camera_t - camera_t_est)**2
220
+
221
+ total_loss = j3d_error_loss + depth_loss
222
+ return total_loss.sum()
visualize/joints2smpl/src/prior.py ADDED
@@ -0,0 +1,230 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import print_function
19
+ from __future__ import division
20
+
21
+ import sys
22
+ import os
23
+
24
+ import time
25
+ import pickle
26
+
27
+ import numpy as np
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+
32
+ DEFAULT_DTYPE = torch.float32
33
+
34
+
35
+ def create_prior(prior_type, **kwargs):
36
+ if prior_type == 'gmm':
37
+ prior = MaxMixturePrior(**kwargs)
38
+ elif prior_type == 'l2':
39
+ return L2Prior(**kwargs)
40
+ elif prior_type == 'angle':
41
+ return SMPLifyAnglePrior(**kwargs)
42
+ elif prior_type == 'none' or prior_type is None:
43
+ # Don't use any pose prior
44
+ def no_prior(*args, **kwargs):
45
+ return 0.0
46
+ prior = no_prior
47
+ else:
48
+ raise ValueError('Prior {}'.format(prior_type) + ' is not implemented')
49
+ return prior
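+ # Usage sketch: create_prior('gmm', prior_folder=..., num_gaussians=8) yields the same MaxMixturePrior that smplify.py constructs directly with config.GMM_MODEL_DIR.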
50
+
51
+
52
+ class SMPLifyAnglePrior(nn.Module):
53
+ def __init__(self, dtype=torch.float32, **kwargs):
54
+ super(SMPLifyAnglePrior, self).__init__()
55
+
56
+ # Indices for the rotation angle of
57
+ # 55: left elbow, 90deg bend at -np.pi/2
58
+ # 58: right elbow, 90deg bend at np.pi/2
59
+ # 12: left knee, 90deg bend at np.pi/2
60
+ # 15: right knee, 90deg bend at np.pi/2
61
+ angle_prior_idxs = np.array([55, 58, 12, 15], dtype=np.int64)
62
+ angle_prior_idxs = torch.tensor(angle_prior_idxs, dtype=torch.long)
63
+ self.register_buffer('angle_prior_idxs', angle_prior_idxs)
64
+
65
+ angle_prior_signs = np.array([1, -1, -1, -1],
66
+ dtype=np.float32 if dtype == torch.float32
67
+ else np.float64)
68
+ angle_prior_signs = torch.tensor(angle_prior_signs,
69
+ dtype=dtype)
70
+ self.register_buffer('angle_prior_signs', angle_prior_signs)
71
+
72
+ def forward(self, pose, with_global_pose=False):
73
+ ''' Returns the angle prior loss for the given pose
74
+
75
+ Args:
76
+ pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle
77
+ representation of the rotations of the joints of the SMPL model.
78
+ Kwargs:
79
+ with_global_pose: Whether the pose vector also contains the global
80
+ orientation of the SMPL model. If not then the indices must be
81
+ corrected.
82
+ Returns:
83
+ A size (B) tensor containing the angle prior loss for each element
84
+ in the batch.
85
+ '''
86
+ angle_prior_idxs = self.angle_prior_idxs - (not with_global_pose) * 3
87
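+ # Exponential penalty on elbows/knees bending in the anatomically implausible direction; the per-joint signs flip which direction is penalized, and squaring keeps the loss smooth.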
+ return torch.exp(pose[:, angle_prior_idxs] *
88
+ self.angle_prior_signs).pow(2)
89
+
90
+
91
+ class L2Prior(nn.Module):
92
+ def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs):
93
+ super(L2Prior, self).__init__()
94
+
95
+ def forward(self, module_input, *args):
96
+ return torch.sum(module_input.pow(2))
97
+
98
+
99
+ class MaxMixturePrior(nn.Module):
100
+
101
+ def __init__(self, prior_folder='prior',
102
+ num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16,
103
+ use_merged=True,
104
+ **kwargs):
105
+ super(MaxMixturePrior, self).__init__()
106
+
107
+ if dtype == DEFAULT_DTYPE:
108
+ np_dtype = np.float32
109
+ elif dtype == torch.float64:
110
+ np_dtype = np.float64
111
+ else:
112
+ print('Unknown float type {}, exiting!'.format(dtype))
113
+ sys.exit(-1)
114
+
115
+ self.num_gaussians = num_gaussians
116
+ self.epsilon = epsilon
117
+ self.use_merged = use_merged
118
+ gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians)
119
+
120
+ full_gmm_fn = os.path.join(prior_folder, gmm_fn)
121
+ if not os.path.exists(full_gmm_fn):
122
+ print('The path to the mixture prior "{}"'.format(full_gmm_fn) +
123
+ ' does not exist, exiting!')
124
+ sys.exit(-1)
125
+
126
+ with open(full_gmm_fn, 'rb') as f:
127
+ gmm = pickle.load(f, encoding='latin1')
128
+
129
+ if type(gmm) == dict:
130
+ means = gmm['means'].astype(np_dtype)
131
+ covs = gmm['covars'].astype(np_dtype)
132
+ weights = gmm['weights'].astype(np_dtype)
133
+ elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)):
134
+ means = gmm.means_.astype(np_dtype)
135
+ covs = gmm.covars_.astype(np_dtype)
136
+ weights = gmm.weights_.astype(np_dtype)
137
+ else:
138
+ print('Unknown type for the prior: {}, exiting!'.format(type(gmm)))
139
+ sys.exit(-1)
140
+
141
+ self.register_buffer('means', torch.tensor(means, dtype=dtype))
142
+
143
+ self.register_buffer('covs', torch.tensor(covs, dtype=dtype))
144
+
145
+ precisions = [np.linalg.inv(cov) for cov in covs]
146
+ precisions = np.stack(precisions).astype(np_dtype)
147
+
148
+ self.register_buffer('precisions',
149
+ torch.tensor(precisions, dtype=dtype))
150
+
151
+ # The constant term:
152
+ sqrdets = np.array([(np.sqrt(np.linalg.det(c)))
153
+ for c in gmm['covars']])
154
+ const = (2 * np.pi)**(69 / 2.)
155
+
156
+ nll_weights = np.asarray(gmm['weights'] / (const *
157
+ (sqrdets / sqrdets.min())))
158
+ nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0)
159
+ self.register_buffer('nll_weights', nll_weights)
160
+
161
+ weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0)
162
+ self.register_buffer('weights', weights)
163
+
164
+ self.register_buffer('pi_term',
165
+ torch.log(torch.tensor(2 * np.pi, dtype=dtype)))
166
+
167
+ cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon)
168
+ for cov in covs]
169
+ self.register_buffer('cov_dets',
170
+ torch.tensor(cov_dets, dtype=dtype))
171
+
172
+ # The dimensionality of the random variable
173
+ self.random_var_dim = self.means.shape[1]
174
+
175
+ def get_mean(self):
176
+ ''' Returns the mean of the mixture '''
177
+ mean_pose = torch.matmul(self.weights, self.means)
178
+ return mean_pose
179
+
180
+ def merged_log_likelihood(self, pose, betas):
181
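+ # Max-of-mixture approximation: score the pose under every Gaussian at once and keep only the best (minimum negative log-likelihood) component, instead of the exact log-sum-exp over the mixture.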
+ diff_from_mean = pose.unsqueeze(dim=1) - self.means
182
+
183
+ prec_diff_prod = torch.einsum('mij,bmj->bmi',
184
+ [self.precisions, diff_from_mean])
185
+ diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1)
186
+
187
+ curr_loglikelihood = 0.5 * diff_prec_quadratic - \
188
+ torch.log(self.nll_weights)
189
+ # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) +
190
+ # self.random_var_dim * self.pi_term +
191
+ # diff_prec_quadratic
192
+ # ) - torch.log(self.weights)
193
+
194
+ min_likelihood, _ = torch.min(curr_loglikelihood, dim=1)
195
+ return min_likelihood
196
+
197
+ def log_likelihood(self, pose, betas, *args, **kwargs):
198
+ ''' Create graph operation for negative log-likelihood calculation
199
+ '''
200
+ likelihoods = []
201
+
202
+ for idx in range(self.num_gaussians):
203
+ mean = self.means[idx]
204
+ prec = self.precisions[idx]
205
+ cov = self.covs[idx]
206
+ diff_from_mean = pose - mean
207
+
208
+ curr_loglikelihood = torch.einsum('bj,ji->bi',
209
+ [diff_from_mean, prec])
210
+ curr_loglikelihood = torch.einsum('bi,bi->b',
211
+ [curr_loglikelihood,
212
+ diff_from_mean])
213
+ cov_term = torch.log(torch.det(cov) + self.epsilon)
214
+ curr_loglikelihood += 0.5 * (cov_term +
215
+ self.random_var_dim *
216
+ self.pi_term)
217
+ likelihoods.append(curr_loglikelihood)
218
+
219
+ log_likelihoods = torch.stack(likelihoods, dim=1)
220
+ min_idx = torch.argmin(log_likelihoods, dim=1)
221
+ weight_component = self.nll_weights[:, min_idx]
222
+ weight_component = -torch.log(weight_component)
223
+
224
+ return weight_component + log_likelihoods[:, min_idx]
225
+
226
+ def forward(self, pose, betas):
227
+ if self.use_merged:
228
+ return self.merged_log_likelihood(pose, betas)
229
+ else:
230
+ return self.log_likelihood(pose, betas)
visualize/joints2smpl/src/smplify.py ADDED
@@ -0,0 +1,279 @@
1
+ import torch
2
+ import os, sys
3
+ import pickle
4
+ import smplx
5
+ import numpy as np
6
+
7
+ sys.path.append(os.path.dirname(__file__))
8
+ from customloss import (camera_fitting_loss,
9
+ body_fitting_loss,
10
+ camera_fitting_loss_3d,
11
+ body_fitting_loss_3d,
12
+ )
13
+ from prior import MaxMixturePrior
14
+ from visualize.joints2smpl.src import config
15
+
16
+
17
+
18
+ @torch.no_grad()
19
+ def guess_init_3d(model_joints,
20
+ j3d,
21
+ joints_category="orig"):
22
+ """Initialize the camera translation via triangle similarity, by using the torso joints .
23
+ :param model_joints: SMPL model with pre joints
24
+ :param j3d: 25x3 array of Kinect Joints
25
+ :returns: 3D vector corresponding to the estimated camera translation
26
+ """
27
+ # indices of the four torso joints (hips and shoulders)
28
+ gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder']
29
+ gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints]
30
+
31
+ if joints_category=="orig":
32
+ joints_ind_category = [config.JOINT_MAP[joint] for joint in gt_joints]
33
+ elif joints_category=="AMASS":
34
+ joints_ind_category = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints]
35
+ else:
36
+ print("NO SUCH JOINTS CATEGORY!")
37
+
38
+ sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(dim=1)
39
+ init_t = sum_init_t / 4.0
40
+ return init_t
41
+
42
+
43
+ # SMPLify 3D
44
+ class SMPLify3D():
45
+ """Implementation of SMPLify, use 3D joints."""
46
+
47
+ def __init__(self,
48
+ smplxmodel,
49
+ step_size=1e-2,
50
+ batch_size=1,
51
+ num_iters=100,
52
+ use_collision=False,
53
+ use_lbfgs=True,
54
+ joints_category="orig",
55
+ device=torch.device('cuda:0'),
56
+ ):
57
+
58
+ # Store options
59
+ self.batch_size = batch_size
60
+ self.device = device
61
+ self.step_size = step_size
62
+
63
+ self.num_iters = num_iters
64
+ # --- choose optimizer
65
+ self.use_lbfgs = use_lbfgs
66
+ # GMM pose prior
67
+ self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR,
68
+ num_gaussians=8,
69
+ dtype=torch.float32).to(device)
70
+ # collision part
71
+ self.use_collision = use_collision
72
+ if self.use_collision:
73
+ self.part_segm_fn = config.Part_Seg_DIR
74
+
75
+ # load the SMPL body model
76
+ self.smpl = smplxmodel
77
+
78
+ self.model_faces = smplxmodel.faces_tensor.view(-1)
79
+
80
+ # select joint category
81
+ self.joints_category = joints_category
82
+
83
+ if joints_category=="orig":
84
+ self.smpl_index = config.full_smpl_idx
85
+ self.corr_index = config.full_smpl_idx
86
+ elif joints_category=="AMASS":
87
+ self.smpl_index = config.amass_smpl_idx
88
+ self.corr_index = config.amass_idx
89
+ else:
90
+ self.smpl_index = None
91
+ self.corr_index = None
92
+ print("NO SUCH JOINTS CATEGORY!")
93
+
94
+ # ---- main fitting entry point ------
95
+ def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0):
96
+ """Perform body fitting.
97
+ Input:
98
+ init_pose: SMPL pose estimate
99
+ init_betas: SMPL betas estimate
100
+ init_cam_t: Camera translation estimate
101
+ j3d: 3D joints (keypoints)
102
+ conf_3d: confidence for 3d joints
103
+ seq_ind: index of the sequence
104
+ Returns:
105
+ vertices: Vertices of optimized shape
106
+ joints: 3D joints of optimized shape
107
+ pose: SMPL pose parameters of optimized shape
108
+ betas: SMPL beta parameters of optimized shape
109
+ camera_translation: Camera translation
110
+ """
111
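+ # The fit runs in two stages: (1) camera translation and global orientation, (2) full body pose; betas are optimized only when seq_ind == 0, i.e. on the first frame of a sequence.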
+
112
+ # optional mesh interpenetration (collision) handling
113
+ search_tree = None
114
+ pen_distance = None
115
+ filter_faces = None
116
+
117
+ if self.use_collision:
118
+ from mesh_intersection.bvh_search_tree import BVH
119
+ import mesh_intersection.loss as collisions_loss
120
+ from mesh_intersection.filter_faces import FilterFaces
121
+
122
+ search_tree = BVH(max_collisions=8)
123
+
124
+ pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
125
+ sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True)
126
+
127
+ if self.part_segm_fn:
128
+ # Read the part segmentation
129
+ part_segm_fn = os.path.expandvars(self.part_segm_fn)
130
+ with open(part_segm_fn, 'rb') as faces_parents_file:
131
+ face_segm_data = pickle.load(faces_parents_file, encoding='latin1')
132
+ faces_segm = face_segm_data['segm']
133
+ faces_parents = face_segm_data['parents']
134
+ # Create the module used to filter invalid collision pairs
135
+ filter_faces = FilterFaces(
136
+ faces_segm=faces_segm, faces_parents=faces_parents,
137
+ ign_part_pairs=None).to(device=self.device)
138
+
139
+
140
+ # Split SMPL pose to body pose and global orientation
141
+ body_pose = init_pose[:, 3:].detach().clone()
142
+ global_orient = init_pose[:, :3].detach().clone()
143
+ betas = init_betas.detach().clone()
144
+
145
+ # use guess_init_3d to get the initial camera translation
146
+ smpl_output = self.smpl(global_orient=global_orient,
147
+ body_pose=body_pose,
148
+ betas=betas)
149
+ model_joints = smpl_output.joints
150
+
151
+ init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).unsqueeze(1).detach()
152
+ camera_translation = init_cam_t.clone()
153
+
154
+ preserve_pose = init_pose[:, 3:].detach().clone()
155
+ # -------------Step 1: Optimize camera translation and body orientation--------
156
+ # Optimize only camera translation and body orientation
157
+ body_pose.requires_grad = False
158
+ betas.requires_grad = False
159
+ global_orient.requires_grad = True
160
+ camera_translation.requires_grad = True
161
+
162
+ camera_opt_params = [global_orient, camera_translation]
163
+
164
+ if self.use_lbfgs:
165
+ camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters,
166
+ lr=self.step_size, line_search_fn='strong_wolfe')
167
+ for i in range(10):
168
+ def closure():
169
+ camera_optimizer.zero_grad()
170
+ smpl_output = self.smpl(global_orient=global_orient,
171
+ body_pose=body_pose,
172
+ betas=betas)
173
+ model_joints = smpl_output.joints
174
+ # print('model_joints', model_joints.shape)
175
+ # print('camera_translation', camera_translation.shape)
176
+ # print('init_cam_t', init_cam_t.shape)
177
+ # print('j3d', j3d.shape)
178
+ loss = camera_fitting_loss_3d(model_joints, camera_translation,
179
+ init_cam_t, j3d, self.joints_category)
180
+ loss.backward()
181
+ return loss
182
+
183
+ camera_optimizer.step(closure)
184
+ else:
185
+ camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999))
186
+
187
+ for i in range(20):
188
+ smpl_output = self.smpl(global_orient=global_orient,
189
+ body_pose=body_pose,
190
+ betas=betas)
191
+ model_joints = smpl_output.joints
192
+
193
+ loss = camera_fitting_loss_3d(model_joints[:, self.smpl_index], camera_translation,
194
+ init_cam_t, j3d[:, self.corr_index], self.joints_category)
195
+ camera_optimizer.zero_grad()
196
+ loss.backward()
197
+ camera_optimizer.step()
198
+
199
+ # Fix camera translation after optimizing camera
200
+ # --------Step 2: Optimize body joints --------------------------
201
+ # Optimize only the body pose and global orientation of the body
202
+ body_pose.requires_grad = True
203
+ global_orient.requires_grad = True
204
+ camera_translation.requires_grad = True
205
+
206
+ # --- when fitting a sequence, optimize the shape only on its first frame
207
+ if seq_ind == 0:
208
+ betas.requires_grad = True
209
+ body_opt_params = [body_pose, betas, global_orient, camera_translation]
210
+ else:
211
+ betas.requires_grad = False
212
+ body_opt_params = [body_pose, global_orient, camera_translation]
213
+
214
+ if self.use_lbfgs:
215
+ body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters,
216
+ lr=self.step_size, line_search_fn='strong_wolfe')
217
+ for i in range(self.num_iters):
218
+ def closure():
219
+ body_optimizer.zero_grad()
220
+ smpl_output = self.smpl(global_orient=global_orient,
221
+ body_pose=body_pose,
222
+ betas=betas)
223
+ model_joints = smpl_output.joints
224
+ model_vertices = smpl_output.vertices
225
+
226
+ loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation,
227
+ j3d[:, self.corr_index], self.pose_prior,
228
+ joints3d_conf=conf_3d,
229
+ joint_loss_weight=600.0,
230
+ pose_preserve_weight=5.0,
231
+ use_collision=self.use_collision,
232
+ model_vertices=model_vertices, model_faces=self.model_faces,
233
+ search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces)
234
+ loss.backward()
235
+ return loss
236
+
237
+ body_optimizer.step(closure)
238
+ else:
239
+ body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999))
240
+
241
+ for i in range(self.num_iters):
242
+ smpl_output = self.smpl(global_orient=global_orient,
243
+ body_pose=body_pose,
244
+ betas=betas)
245
+ model_joints = smpl_output.joints
246
+ model_vertices = smpl_output.vertices
247
+
248
+ loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation,
249
+ j3d[:, self.corr_index], self.pose_prior,
250
+ joints3d_conf=conf_3d,
251
+ joint_loss_weight=600.0,
252
+ use_collision=self.use_collision,
253
+ model_vertices=model_vertices, model_faces=self.model_faces,
254
+ search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces)
255
+ body_optimizer.zero_grad()
256
+ loss.backward()
257
+ body_optimizer.step()
258
+
259
+ # Get final loss value
260
+ with torch.no_grad():
261
+ smpl_output = self.smpl(global_orient=global_orient,
262
+ body_pose=body_pose,
263
+ betas=betas, return_full_pose=True)
264
+ model_joints = smpl_output.joints
265
+ model_vertices = smpl_output.vertices
266
+
267
+ final_loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation,
268
+ j3d[:, self.corr_index], self.pose_prior,
269
+ joints3d_conf=conf_3d,
270
+ joint_loss_weight=600.0,
271
+ use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces,
272
+ search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces)
273
+
274
+ vertices = smpl_output.vertices.detach()
275
+ joints = smpl_output.joints.detach()
276
+ pose = torch.cat([global_orient, body_pose], dim=-1).detach()
277
+ betas = betas.detach()
278
+
279
+ return vertices, joints, pose, betas, camera_translation, final_loss
visualize/render_mesh.py ADDED
@@ -0,0 +1,33 @@
1
+ import argparse
2
+ import os
3
+ from visualize import vis_utils
4
+ import shutil
5
+ from tqdm import tqdm
6
+
7
+ if __name__ == '__main__':
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument("--input_path", type=str, required=True, help='stick figure mp4 file to be rendered.')
10
+ parser.add_argument("--cuda", type=bool, default=True, help='')
11
+ parser.add_argument("--device", type=int, default=0, help='')
12
+ params = parser.parse_args()
13
+
14
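+ # Example invocation (sketch; the path is illustrative): python -m visualize.render_mesh --input_path output/sample00_rep00.mp4
+ # The clip must be named sample<i>_rep<j>.mp4 and sit next to a results.npy in the same directory.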
+ assert params.input_path.endswith('.mp4')
15
+ parsed_name = os.path.basename(params.input_path).replace('.mp4', '').replace('sample', '').replace('rep', '')
16
+ sample_i, rep_i = [int(e) for e in parsed_name.split('_')]
17
+ npy_path = os.path.join(os.path.dirname(params.input_path), 'results.npy')
18
+ out_npy_path = params.input_path.replace('.mp4', '_smpl_params.npy')
19
+ assert os.path.exists(npy_path)
20
+ results_dir = params.input_path.replace('.mp4', '_obj')
21
+ if os.path.exists(results_dir):
22
+ shutil.rmtree(results_dir)
23
+ os.makedirs(results_dir)
24
+
25
+ npy2obj = vis_utils.npy2obj(npy_path, sample_i, rep_i,
26
+ device=params.device, cuda=params.cuda)
27
+
28
+ print('Saving obj files to [{}]'.format(os.path.abspath(results_dir)))
29
+ for frame_i in tqdm(range(npy2obj.real_num_frames)):
30
+ npy2obj.save_obj(os.path.join(results_dir, 'frame{:03d}.obj'.format(frame_i)), frame_i)
31
+
32
+ print('Saving SMPL params to [{}]'.format(os.path.abspath(out_npy_path)))
33
+ npy2obj.save_npy(out_npy_path)
visualize/simplify_loc2rot.py ADDED
@@ -0,0 +1,131 @@
1
+ import numpy as np
2
+ import os
3
+ import torch
4
+ from visualize.joints2smpl.src import config
5
+ import smplx
6
+ import h5py
7
+ from visualize.joints2smpl.src.smplify import SMPLify3D
8
+ from tqdm import tqdm
9
+ import utils.rotation_conversions as geometry
10
+ import argparse
11
+
12
+
13
+ class joints2smpl:
14
+
15
+ def __init__(self, num_frames, device_id, cuda=True):
16
+ self.device = torch.device("cuda:" + str(device_id) if cuda else "cpu")
17
+ # self.device = torch.device("cpu")
18
+ self.batch_size = num_frames
19
+ self.num_joints = 22 # for HumanML3D
20
+ self.joint_category = "AMASS"
21
+ self.num_smplify_iters = 120
22
+ self.fix_foot = False
23
+ print(config.SMPL_MODEL_DIR)
24
+ smplmodel = smplx.create(config.SMPL_MODEL_DIR,
25
+ model_type="smpl", gender="neutral", ext="pkl",
26
+ batch_size=self.batch_size).to(self.device)
27
+
28
+ # ## --- load the mean pose as original ----
29
+ smpl_mean_file = config.SMPL_MEAN_FILE
30
+
31
+ file = h5py.File(smpl_mean_file, 'r')
32
+ self.init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device)
33
+ self.init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device)
34
+ self.cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(self.device)
35
+ #
36
+
37
+ # # #-------------initialize SMPLify
38
+ self.smplify = SMPLify3D(smplxmodel=smplmodel,
39
+ batch_size=self.batch_size,
40
+ joints_category=self.joint_category,
41
+ num_iters=self.num_smplify_iters,
42
+ device=self.device)
43
+
44
+
45
+ def npy2smpl(self, npy_path):
46
+ out_path = npy_path.replace('.npy', '_rot.npy')
47
+ motions = np.load(npy_path, allow_pickle=True)[None][0]
48
+ # print_batch('', motions)
49
+ n_samples = motions['motion'].shape[0]
50
+ all_thetas = []
51
+ for sample_i in tqdm(range(n_samples)):
52
+ thetas, _ = self.joint2smpl(motions['motion'][sample_i].transpose(2, 0, 1)) # [nframes, njoints, 3]
53
+ all_thetas.append(thetas.cpu().numpy())
54
+ motions['motion'] = np.concatenate(all_thetas, axis=0)
55
+ print('motions', motions['motion'].shape)
56
+
57
+ print(f'Saving [{out_path}]')
58
+ np.save(out_path, motions)
59
+ exit()
60
+
61
+
62
+
63
+ def joint2smpl(self, input_joints, init_params=None):
64
+ _smplify = self.smplify # if init_params is None else self.smplify_fast
65
+ pred_pose = torch.zeros(self.batch_size, 72).to(self.device)
66
+ pred_betas = torch.zeros(self.batch_size, 10).to(self.device)
67
+ pred_cam_t = torch.zeros(self.batch_size, 3).to(self.device)
68
+ keypoints_3d = torch.zeros(self.batch_size, self.num_joints, 3).to(self.device)
69
+
70
+ # run the whole seqs
71
+ num_seqs = input_joints.shape[0]
72
+
73
+
74
+ # joints3d = input_joints[idx] # *1.2 #scale problem [check first]
75
+ keypoints_3d = torch.Tensor(input_joints).to(self.device).float()
76
+
77
+ # if idx == 0:
78
+ if init_params is None:
79
+ pred_betas = self.init_mean_shape
80
+ pred_pose = self.init_mean_pose
81
+ pred_cam_t = self.cam_trans_zero
82
+ else:
83
+ pred_betas = init_params['betas']
84
+ pred_pose = init_params['pose']
85
+ pred_cam_t = init_params['cam']
86
+
87
+ if self.joint_category == "AMASS":
88
+ confidence_input = torch.ones(self.num_joints)
89
+ # optionally up-weight the feet and ankles
90
+ if self.fix_foot == True:
91
+ confidence_input[7] = 1.5
92
+ confidence_input[8] = 1.5
93
+ confidence_input[10] = 1.5
94
+ confidence_input[11] = 1.5
95
+ else:
96
+ print("Such category not settle down!")
97
+
98
+ new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \
99
+ new_opt_cam_t, new_opt_joint_loss = _smplify(
100
+ pred_pose.detach(),
101
+ pred_betas.detach(),
102
+ pred_cam_t.detach(),
103
+ keypoints_3d,
104
+ conf_3d=confidence_input.to(self.device),
105
+ # seq_ind=idx
106
+ )
107
+
108
+ thetas = new_opt_pose.reshape(self.batch_size, 24, 3)
109
+ thetas = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(thetas)) # [bs, 24, 6]
110
+ root_loc = keypoints_3d[:, 0].clone() # [bs, 3]
111
+ root_loc = torch.cat([root_loc, torch.zeros_like(root_loc)], dim=-1).unsqueeze(1) # [bs, 1, 6]
112
+ thetas = torch.cat([thetas, root_loc], dim=1).unsqueeze(0).permute(0, 2, 3, 1) # [1, 25, 6, nframes]
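+ # The appended 25th row carries the root translation (xyz padded with zeros to 6 channels); downstream code in vis_utils reads it back from motion[:, -1, :3, :].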
113
+
114
+ return thetas.clone().detach(), {'pose': new_opt_joints[0, :24].flatten().clone().detach(), 'betas': new_opt_betas.clone().detach(), 'cam': new_opt_cam_t.clone().detach()}
115
+
116
+
117
+ if __name__ == '__main__':
118
+ parser = argparse.ArgumentParser()
119
+ parser.add_argument("--input_path", type=str, required=True, help='Blender file or dir with blender files')
120
+ parser.add_argument("--cuda", type=bool, default=True, help='')
121
+ parser.add_argument("--device", type=int, default=0, help='')
122
+ params = parser.parse_args()
123
+
124
+ simplify = joints2smpl(device_id=params.device, cuda=params.cuda)
125
+
126
+ if os.path.isfile(params.input_path) and params.input_path.endswith('.npy'):
127
+ simplify.npy2smpl(params.input_path)
128
+ elif os.path.isdir(params.input_path):
129
+ files = [os.path.join(params.input_path, f) for f in os.listdir(params.input_path) if f.endswith('.npy')]
130
+ for f in files:
131
+ simplify.npy2smpl(f)
visualize/vis_utils.py ADDED
@@ -0,0 +1,66 @@
1
+ from models.rotation2xyz import Rotation2xyz
2
+ import numpy as np
3
+ from trimesh import Trimesh
4
+ import os
5
+ import torch
6
+ from visualize.simplify_loc2rot import joints2smpl
7
+
8
+ class npy2obj:
9
+ def __init__(self, npy_path, sample_idx, rep_idx, device=0, cuda=True):
10
+ self.npy_path = npy_path
11
+ self.motions = np.load(self.npy_path, allow_pickle=True)
12
+ if self.npy_path.endswith('.npz'):
13
+ self.motions = self.motions['arr_0']
14
+ self.motions = self.motions[None][0]
15
+ self.rot2xyz = Rotation2xyz(device='cpu')
16
+ self.faces = self.rot2xyz.smpl_model.faces
17
+ self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape
18
+ self.opt_cache = {}
19
+ self.sample_idx = sample_idx
20
+ self.total_num_samples = self.motions['num_samples']
21
+ self.rep_idx = rep_idx
22
+ self.absl_idx = self.rep_idx*self.total_num_samples + self.sample_idx
23
+ self.num_frames = self.motions['motion'][self.absl_idx].shape[-1]
24
+ self.j2s = joints2smpl(num_frames=self.num_frames, device_id=device, cuda=cuda)
25
+
26
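+ # Raw xyz joint input (nfeats == 3) is lifted to SMPL rotations via SMPLify; rot6d input (nfeats == 6) is used as-is.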
+ if self.nfeats == 3:
27
+ print(f'Running SMPLify For sample [{sample_idx}], repetition [{rep_idx}], it may take a few minutes.')
28
+ motion_tensor, opt_dict = self.j2s.joint2smpl(self.motions['motion'][self.absl_idx].transpose(2, 0, 1)) # [nframes, njoints, 3]
29
+ self.motions['motion'] = motion_tensor.cpu().numpy()
30
+ elif self.nfeats == 6:
31
+ self.motions['motion'] = self.motions['motion'][[self.absl_idx]]
32
+ self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape
33
+ self.real_num_frames = self.motions['lengths'][self.absl_idx]
34
+
35
+ self.vertices = self.rot2xyz(torch.tensor(self.motions['motion']), mask=None,
36
+ pose_rep='rot6d', translation=True, glob=True,
37
+ jointstype='vertices',
38
+ # jointstype='smpl', # for joint locations
39
+ vertstrans=True)
40
+ self.root_loc = self.motions['motion'][:, -1, :3, :].reshape(1, 1, 3, -1)
41
+ self.vertices += self.root_loc
42
+
43
+ def get_vertices(self, sample_i, frame_i):
44
+ return self.vertices[sample_i, :, :, frame_i].squeeze().tolist()
45
+
46
+ def get_trimesh(self, sample_i, frame_i):
47
+ return Trimesh(vertices=self.get_vertices(sample_i, frame_i),
48
+ faces=self.faces)
49
+
50
+ def save_obj(self, save_path, frame_i):
51
+ mesh = self.get_trimesh(0, frame_i)
52
+ with open(save_path, 'w') as fw:
53
+ mesh.export(fw, 'obj')
54
+ return save_path
55
+
56
+ def save_npy(self, save_path):
57
+ data_dict = {
58
+ 'motion': self.motions['motion'][0, :, :, :self.real_num_frames],
59
+ 'thetas': self.motions['motion'][0, :-1, :, :self.real_num_frames],
60
+ 'root_translation': self.motions['motion'][0, -1, :3, :self.real_num_frames],
61
+ 'faces': self.faces,
62
+ 'vertices': self.vertices[0, :, :, :self.real_num_frames],
63
+ 'text': self.motions['text'][0],
64
+ 'length': self.real_num_frames,
65
+ }
66
+ np.save(save_path, data_dict)