Yuliang commited on
Commit
eaf88bc
1 Parent(s): ed43914

optimize rot_6d instead of rot_mat

Browse files
apps/infer.py CHANGED
@@ -26,6 +26,7 @@ from lib.dataset.mesh_util import (
26
  unwrap,
27
  remesh,
28
  tensor2variable,
 
29
  )
30
 
31
  from lib.dataset.TestDataset import TestDataset
@@ -165,12 +166,16 @@ def generate_model(in_path, model_type):
165
  for _ in loop_smpl:
166
 
167
  optimizer_smpl.zero_grad()
 
 
 
 
168
 
169
  if dataset_param["hps_type"] != "pixie":
170
  smpl_out = dataset.smpl_model(
171
  betas=optimed_betas,
172
- body_pose=optimed_pose,
173
- global_orient=optimed_orient,
174
  pose2rot=False,
175
  )
176
 
@@ -180,8 +185,8 @@ def generate_model(in_path, model_type):
180
  smpl_verts, _, _ = dataset.smpl_model(
181
  shape_params=optimed_betas,
182
  expression_params=tensor2variable(data["exp"], device),
183
- body_pose=optimed_pose,
184
- global_pose=optimed_orient,
185
  jaw_pose=tensor2variable(data["jaw_pose"], device),
186
  left_hand_pose=tensor2variable(
187
  data["left_hand_pose"], device),
@@ -316,8 +321,8 @@ def generate_model(in_path, model_type):
316
  f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb")
317
 
318
  smpl_info = {'betas': optimed_betas,
319
- 'pose': optimed_pose,
320
- 'orient': optimed_orient,
321
  'trans': optimed_trans}
322
 
323
  np.save(
 
26
  unwrap,
27
  remesh,
28
  tensor2variable,
29
+ rot6d_to_rotmat
30
  )
31
 
32
  from lib.dataset.TestDataset import TestDataset
 
166
  for _ in loop_smpl:
167
 
168
  optimizer_smpl.zero_grad()
169
+
170
+ # 6d_rot to rot_mat
171
+ optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1,6)).unsqueeze(0)
172
+ optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1,6)).unsqueeze(0)
173
 
174
  if dataset_param["hps_type"] != "pixie":
175
  smpl_out = dataset.smpl_model(
176
  betas=optimed_betas,
177
+ body_pose=optimed_pose_mat,
178
+ global_orient=optimed_orient_mat,
179
  pose2rot=False,
180
  )
181
 
 
185
  smpl_verts, _, _ = dataset.smpl_model(
186
  shape_params=optimed_betas,
187
  expression_params=tensor2variable(data["exp"], device),
188
+ body_pose=optimed_pose_mat,
189
+ global_pose=optimed_orient_mat,
190
  jaw_pose=tensor2variable(data["jaw_pose"], device),
191
  left_hand_pose=tensor2variable(
192
  data["left_hand_pose"], device),
 
321
  f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb")
322
 
323
  smpl_info = {'betas': optimed_betas,
324
+ 'pose': optimed_pose_mat,
325
+ 'orient': optimed_orient_mat,
326
  'trans': optimed_trans}
327
 
328
  np.save(
lib/common/train_util.py CHANGED
@@ -32,6 +32,8 @@ import os
32
  from termcolor import colored
33
 
34
 
 
 
35
  def reshape_sample_tensor(sample_tensor, num_views):
36
  if num_views == 1:
37
  return sample_tensor
 
32
  from termcolor import colored
33
 
34
 
35
+
36
+
37
  def reshape_sample_tensor(sample_tensor, num_views):
38
  if num_views == 1:
39
  return sample_tensor
lib/dataset/TestDataset.py CHANGED
@@ -240,6 +240,11 @@ class TestDataset():
240
  # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
241
  # global_orient - [1, 1, 3, 3]
242
  # smpl_verts - [1, 6890, 3] / [1, 10475, 3]
 
 
 
 
 
243
 
244
  return data_dict
245
 
 
240
  # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
241
  # global_orient - [1, 1, 3, 3]
242
  # smpl_verts - [1, 6890, 3] / [1, 10475, 3]
243
+
244
+ # from rot_mat to rot_6d for better optimization
245
+ N_body = data_dict["body_pose"].shape[1]
246
+ data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(1, N_body,-1)
247
+ data_dict["global_orient"] = data_dict["global_orient"][:, :, :, :2].reshape(1, 1,-1)
248
 
249
  return data_dict
250
 
lib/dataset/mesh_util.py CHANGED
@@ -44,6 +44,22 @@ from pytorch3d.loss import (
44
 
45
  from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def tensor2variable(tensor, device):
49
  # [1,23,3,3]
 
44
 
45
  from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
46
 
47
+ def rot6d_to_rotmat(x):
48
+ """Convert 6D rotation representation to 3x3 rotation matrix.
49
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
50
+ Input:
51
+ (B,6) Batch of 6-D rotation representations
52
+ Output:
53
+ (B,3,3) Batch of corresponding rotation matrices
54
+ """
55
+ x = x.view(-1, 3, 2)
56
+ a1 = x[:, :, 0]
57
+ a2 = x[:, :, 1]
58
+ b1 = F.normalize(a1)
59
+ b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
60
+ b3 = torch.cross(b1, b2)
61
+ return torch.stack((b1, b2, b3), dim=-1)
62
+
63
 
64
  def tensor2variable(tensor, device):
65
  # [1,23,3,3]