Yzy00518 commited on
Commit
94512e7
·
1 Parent(s): 24e6431

Upload src/inference/joint2smplx.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/inference/joint2smplx.py +207 -0
src/inference/joint2smplx.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ import pickle
7
+ from scipy.interpolate import interp1d
8
+
9
+ #############Import fast smplx(modified from original ver)
10
+ local_smplx_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', 'deps/smplx'))
11
+ sys.path.insert(0, local_smplx_path)
12
+ import smplx_fast
13
+
14
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
15
+ from utils.transforms import matrix_to_axis_angle, rotation_6d_to_matrix
16
+ from utils.constants import pelvis_shift, relaxed_hand_pose, SELECTED_JOINTS24
17
+
18
+
19
+ ###########This model is used to predict the initial pose for the optimization###########
20
+ class JointsToSMPLX(nn.Module):
21
+ def __init__(self, input_dim, output_dim, hidden_dim, **kwargs):
22
+ super().__init__()
23
+ self.layers = nn.Sequential(
24
+ nn.Linear(input_dim, hidden_dim),
25
+ nn.BatchNorm1d(hidden_dim),
26
+ nn.ReLU(),
27
+ nn.Linear(hidden_dim, hidden_dim),
28
+ nn.BatchNorm1d(hidden_dim),
29
+ nn.ReLU(),
30
+ nn.Linear(hidden_dim, output_dim),
31
+ )
32
+
33
+ def forward(self, x):
34
+ return self.layers(x)
35
+
36
+ def get_j2s_model(ckpt_path,
37
+ input_dim=72,
38
+ output_dim=132,
39
+ hidden_dim=64,
40
+ device='cpu'):
41
+ model_joints_to_smplx = JointsToSMPLX(input_dim=input_dim,
42
+ output_dim=output_dim,
43
+ hidden_dim=hidden_dim
44
+ )
45
+ if device == 'cpu':
46
+ map_location = torch.device('cpu')
47
+ else:
48
+ map_location = device
49
+
50
+ model_joints_to_smplx.load_state_dict(torch.load(ckpt_path, map_location=map_location))
51
+ model_joints_to_smplx.eval()
52
+ return model_joints_to_smplx
53
+
54
+ ###########This model is used to predict the initial pose for the optimization###########
55
+
56
+
57
+ def optimize_smpl(pose_pred, joints, joints_ind, smplx_path, print_loss=True):
58
+ device = joints.device
59
+ len = joints.shape[0]
60
+
61
+ smpl_model = smplx_fast.create(smplx_path,
62
+ model_type='smplx_joint_only',
63
+ gender='male', ext='npz',
64
+ num_betas=10,
65
+ use_pca=False,
66
+ create_global_orient=True,
67
+ create_body_pose=True,
68
+ create_betas=True,
69
+ create_left_hand_pose=True,
70
+ create_right_hand_pose=True,
71
+ create_expression=True,
72
+ create_jaw_pose=True,
73
+ create_leye_pose=True,
74
+ create_reye_pose=True,
75
+ create_transl=True,
76
+ batch_size=len,
77
+ ).to(device)
78
+ smpl_model.eval()
79
+
80
+ joints = joints.reshape(len, -1, 3) + torch.tensor(pelvis_shift).to(device)
81
+ pose_input = torch.nn.Parameter(pose_pred.detach(), requires_grad=True)
82
+ transl = torch.nn.Parameter(torch.zeros(pose_pred.shape[0], 3).to(device), requires_grad=True)
83
+ left_hand = torch.from_numpy(relaxed_hand_pose[:45].reshape(1, -1).repeat(pose_pred.shape[0], axis=0)).to(device)
84
+ right_hand = torch.from_numpy(relaxed_hand_pose[45:].reshape(1, -1).repeat(pose_pred.shape[0], axis=0)).to(device)
85
+ optimizer = torch.optim.Adam(params=[pose_input, transl], lr=0.05)
86
+ loss_fn = nn.MSELoss()
87
+ vertices_output = None
88
+
89
+ for step in range(120):
90
+ smpl_output = smpl_model(transl=transl,
91
+ body_pose=pose_input[:, 3:],
92
+ global_orient=pose_input[:, :3],
93
+ return_verts=True,
94
+ left_hand_pose=left_hand,# @ left_hand_components[:hand_pca],
95
+ right_hand_pose=right_hand,# @ right_hand_components[:hand_pca],
96
+ )
97
+ joints_output = smpl_output[:, joints_ind].reshape(len, -1, 3)
98
+ loss = loss_fn(joints[:, :], joints_output[:, :])
99
+ optimizer.zero_grad()
100
+ loss.backward()
101
+ optimizer.step()
102
+
103
+ if print_loss:
104
+ print(loss.item(), flush=True)
105
+
106
+ return pose_input.detach().cpu().numpy(), \
107
+ transl.detach().cpu().numpy(), \
108
+ left_hand.detach().cpu().numpy(), \
109
+ right_hand.detach().cpu().numpy(), \
110
+ vertices_output
111
+
112
+
113
+ def joints_to_smpl(model, joints, joints_ind, interp_s, smplx_path, print_loss=True):
114
+ joints = interpolate_joints(joints, scale=interp_s)
115
+ input_len = joints.shape[0]
116
+ joints = joints.reshape(input_len, -1, 3)
117
+ joints = joints.permute(1, 0, 2)
118
+ trans_np = joints[0].detach().cpu().numpy()
119
+ joints = joints - joints[0]
120
+ joints = joints.permute(1, 0, 2)
121
+ joints = joints.reshape(input_len, -1)
122
+ pose_pred = model(joints)
123
+
124
+ pose_pred = pose_pred.reshape(-1, 6)
125
+ pose_pred = matrix_to_axis_angle(rotation_6d_to_matrix(pose_pred)).reshape(input_len, -1)
126
+ pose_output, transl, left_hand, right_hand, vertices = optimize_smpl(pose_pred,
127
+ joints,
128
+ joints_ind,
129
+ smplx_path,
130
+ print_loss=print_loss)
131
+ transl = trans_np - np.array(pelvis_shift) + transl
132
+ return pose_output, transl, left_hand, right_hand, vertices
133
+
134
+ def interpolate_joints(joints, scale):
135
+ if scale == 1:
136
+ return joints
137
+ device = joints.device
138
+ joints = joints.detach().cpu().numpy()
139
+ in_len = joints.shape[0]
140
+ out_len = int(in_len * scale)
141
+ joints = joints.reshape(in_len, -1)
142
+ x = np.array(range(in_len))
143
+ xnew = np.linspace(0, in_len - 1, out_len)
144
+ f = interp1d(x, joints, axis=0)
145
+ joints_new = f(xnew)
146
+ joints_new = torch.from_numpy(joints_new).to(device).float()
147
+
148
+ return joints_new
149
+
150
+
151
+
152
+
153
+ def process_file(file_path, # input dir
154
+ file_name, # input file
155
+ save_path, # output dir
156
+ JointsToSMPLX_model_path, # JointsToSMPLX weight
157
+ smplx_path, # smplx weight
158
+ key_list = ['generated_samples', 'original_samples'],
159
+ joints_ind = SELECTED_JOINTS24,
160
+ interp_s=2, # 2*10=20 fps
161
+ ):
162
+
163
+
164
+ data = np.load(os.path.join(file_path, file_name), allow_pickle=True)
165
+ model = get_j2s_model(ckpt_path=JointsToSMPLX_model_path, device='cpu')
166
+
167
+ for key in key_list: # original_samples, generated_samples, GT
168
+ if key in data:
169
+ joints = torch.tensor(data[key], dtype=torch.float32).reshape(-1, 72)
170
+
171
+ print_loss=False
172
+ if key == 'generated_samples':
173
+ print_loss=True
174
+
175
+ pose, transl, left_hand, right_hand, vertices = joints_to_smpl(model,
176
+ joints,
177
+ joints_ind,
178
+ interp_s,
179
+ smplx_path,
180
+ print_loss=print_loss)
181
+ try:
182
+ data_text = data['text']
183
+ except:
184
+ data_text = None
185
+
186
+ output_data = {
187
+ 'body_pose': pose[:, 3:],
188
+ 'global_orient': pose[:, :3],
189
+ 'transl': transl,
190
+ 'left_hand': left_hand,
191
+ 'right_hand': right_hand,
192
+ 'vertices': vertices,
193
+ 'text': data_text,
194
+ }
195
+
196
+ if key == 'generated_samples':
197
+ try:
198
+ output_data['mask'] = data['mask']
199
+ except:
200
+ output_data['mask'] = None
201
+
202
+ if not os.path.exists(os.path.join(save_path, key)):
203
+ os.makedirs(os.path.join(save_path, key))
204
+
205
+ output_file = os.path.join(os.path.join(save_path, key), file_name)
206
+ with open(output_file, 'wb') as file:
207
+ pickle.dump(output_data, file)