Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2020 The Google AI Perception Team Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Estimate AIST++ SMPL-format Motion.""" | |
| import os | |
| import pickle | |
| from absl import app | |
| from absl import flags | |
| from absl import logging | |
| from aist_plusplus.loader import AISTDataset | |
| import numpy as np | |
| from smplx import SMPL | |
| import torch | |
FLAGS = flags.FLAGS

# NOTE(review): the path defaults below point at one developer's machine;
# pass --anno_dir / --smpl_dir / --save_dir explicitly when running elsewhere.
flags.DEFINE_string(
    'anno_dir',
    '/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/',
    'input local directory for AIST++ annotations.')
flags.DEFINE_string(
    'smpl_dir',
    '/usr/local/google/home/ruilongli/data/SMPL/',
    'input local directory that stores SMPL data.')
flags.DEFINE_string(
    'save_dir',
    '/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/motions/',
    'output local directory that stores AIST++ SMPL-format motion data.')
flags.DEFINE_list(
    'sequence_names',
    None,
    'list of sequence names to be processed. None means to process all.')
flags.DEFINE_string(
    'save_dir_gcs',
    None,
    'output GCS directory.')

# Seed the global NumPy and PyTorch RNGs so repeated runs start identically.
np.random.seed(0)
torch.manual_seed(0)
def unify_joint_mappings(dataset='openpose25'):
  """Return indices that reorder a dataset's joints into the unified layout.

  Unified joint order:
    ['Nose',
     'RShoulder', 'RElbow', 'RWrist',
     'LShoulder', 'LElbow', 'LWrist',
     'RHip', 'RKnee', 'RAnkle',
     'LHip', 'LKnee', 'LAnkle',
     'REye', 'LEye',
     'REar', 'LEar',
     'LBigToe', 'LHeel',
     'RBigToe', 'RHeel',]

  `openpose25` and `smpl` fill all 21 unified slots; `coco` fills only the
  first 17 (it has no foot joints), so its mapping is a prefix of the layout.

  Args:
    dataset: `openpose25`, `coco` (17 joints) or `smpl`.

  Returns:
    A fresh int32 numpy array whose i-th entry is the index, in `dataset`'s
    own joint ordering, of unified joint i.

  Raises:
    ValueError: if `dataset` is not one of the supported names.
  """
  # Checked in this order; each row lists source indices in unified order.
  known_layouts = (
      ('openpose25', (0,
                      2, 3, 4,
                      5, 6, 7,
                      9, 10, 11,
                      12, 13, 14,
                      15, 16,
                      17, 18,
                      19, 21,
                      22, 24)),
      ('smpl', (24,
                17, 19, 21,
                16, 18, 20,
                2, 5, 8,
                1, 4, 7,
                25, 26,
                27, 28,
                29, 31,
                32, 34)),
      # NOTE(review): in the standard COCO-17 keypoint order, 5/7/9 and
      # 11/13/15 are the *left* arm/leg and 1/3 the left eye/ear, yet they
      # fill the R* slots here (and vice versa). Confirm this left/right
      # swap matches the AIST++ keypoint convention before relying on it.
      ('coco', (0,
                5, 7, 9,
                6, 8, 10,
                11, 13, 15,
                12, 14, 16,
                1, 2,
                3, 4)),
  )
  for layout_name, source_indices in known_layouts:
    if dataset == layout_name:
      return np.array(source_indices, dtype=np.int32)
  raise ValueError(f'{dataset} is not supported')
class SMPLRegressor:
  """SMPL fitting based on 3D keypoints.

  Optimizes an SMPL model's translation, scaling, global orientation and body
  pose with SGD so that its joints (in the unified joint order) match target
  keypoints under an MSE loss. `betas` always has learning rate 0, so body
  shape is never changed.
  """

  # Step at which global orientation starts training, and step at which body
  # pose starts training (translation/scaling train from step 0).
  _ORIENT_START_STEP = 100
  _POSE_START_STEP = 400

  def __init__(self, smpl_model_path, smpl_model_gener='MALE'):
    """Stores fitting hyper-parameters and SMPL model settings.

    Args:
      smpl_model_path: passed as `model_path` when constructing `SMPL`.
      smpl_model_gener: passed as `gender` when constructing `SMPL` (the
        misspelled parameter name is kept for backward compatibility).
    """
    # Fitting hyper-parameters
    self.base_lr = 100.0
    self.niter = 10000
    self.metric = torch.nn.MSELoss()
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.smpl_model_path = smpl_model_path
    self.smpl_model_gender = smpl_model_gener
    # Indices that pick SMPL output joints in the unified joint order.
    self.joints_mapping_smpl = unify_joint_mappings(dataset='smpl')

  def _stage_lrs(self, step, base_lr):
    """Returns the learning rates in effect at `step`.

    The tuple is ordered (transl, scaling, global_orient, body_pose, betas).
    Parameters are unfrozen in stages: translation and scaling from step 0,
    global orientation from `_ORIENT_START_STEP`, body pose from
    `_POSE_START_STEP`; betas are never trained.
    """
    if step < self._ORIENT_START_STEP:
      orient_lr, pose_lr = 0.0, 0.0
    elif step < self._POSE_START_STEP:
      orient_lr, pose_lr = base_lr * 0.001, 0.0
    else:
      orient_lr, pose_lr = base_lr * 0.001, base_lr * 0.001
    return (base_lr, base_lr * 0.01, orient_lr, pose_lr, 0.0)

  def get_optimizer(self, smpl, step, base_lr):
    """Builds an SGD optimizer with the staged learning rates for `step`.

    Args:
      smpl: model exposing `transl`, `scaling`, `global_orient`, `body_pose`
        and `betas` parameters.
      step: current optimization step (selects the schedule stage).
      base_lr: learning rate of `transl`; the other groups are scaled from it.

    Returns:
      A `torch.optim.SGD` with one parameter group per SMPL parameter, in the
      order transl, scaling, global_orient, body_pose, betas.
    """
    params = (smpl.transl, smpl.scaling, smpl.global_orient, smpl.body_pose,
              smpl.betas)
    lrs = self._stage_lrs(step, base_lr)
    return torch.optim.SGD(
        [{'params': [param], 'lr': lr} for param, lr in zip(params, lrs)])

  def fit(self, keypoints3d, dtype='coco', verbose=True):
    """Run fitting to optimize the SMPL parameters.

    Args:
      keypoints3d: numpy array, [N, njoints, 3] per the assertion message
        (only rank 3 is checked), in the joint format named by `dtype`.
      dtype: keypoint format; only 'coco' is supported.
      verbose: if True, logs the loss every 10 steps.

    Returns:
      (smpl, loss): the fitted `SMPL` module (batch size N) and, as a float,
      the MSE loss computed in the last step (before that step's update).
      If `self.niter <= 0`, the loss of the initial parameters is returned.
    """
    assert dtype == 'coco', 'only support coco format for now.'
    assert len(keypoints3d.shape) == 3, 'input shape should be [N, njoints, 3]'
    mapping_target = unify_joint_mappings(dataset=dtype)
    keypoints3d = keypoints3d[:, mapping_target, :]
    keypoints3d = torch.from_numpy(keypoints3d).float().to(self.device)
    batch_size, njoints = keypoints3d.shape[0:2]
    # The target's unified order is a prefix of the full layout, so the first
    # `njoints` SMPL indices line up with the target joints. Loop-invariant.
    smpl_joint_ids = self.joints_mapping_smpl[:njoints]

    # Init learnable smpl model
    smpl = SMPL(
        model_path=self.smpl_model_path,
        gender=self.smpl_model_gender,
        batch_size=batch_size).to(self.device)

    optimizer = None
    active_lrs = None
    loss = None
    for step in range(self.niter):
      lrs = self._stage_lrs(step, self.base_lr)
      if lrs != active_lrs:
        # Plain SGD (no momentum) keeps no state between steps, so rebuilding
        # only when the schedule changes is equivalent to rebuilding every
        # step, without thousands of redundant constructions.
        optimizer = self.get_optimizer(smpl, step, self.base_lr)
        active_lrs = lrs
      output = smpl.forward()
      joints = output.joints[:, smpl_joint_ids, :]
      loss = self.metric(joints, keypoints3d)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      if verbose and step % 10 == 0:
        logging.info('step %03d; loss %.3f;', step, loss.item())

    if loss is None:
      # No step ran (niter <= 0): report the loss of the initial parameters
      # instead of failing on an unbound `loss`.
      with torch.no_grad():
        joints = smpl.forward().joints[:, smpl_joint_ids, :]
        loss = self.metric(joints, keypoints3d)
    return smpl, loss.item()
def _save_motion(smpl, loss, motion_file):
  """Pickles one sequence's fitted SMPL parameters and final loss.

  Writes a dict with keys `smpl_poses` (global_orient and body_pose
  concatenated along axis 1), `smpl_scaling`, `smpl_trans` and `smpl_loss`.
  """
  body_pose = smpl.body_pose.detach().cpu().numpy()
  global_orient = smpl.global_orient.detach().cpu().numpy()
  smpl_poses = np.concatenate([global_orient, body_pose], axis=1)
  smpl_scaling = smpl.scaling.detach().cpu().numpy()
  smpl_trans = smpl.transl.detach().cpu().numpy()
  with open(motion_file, 'wb') as f:
    pickle.dump({
        'smpl_poses': smpl_poses,
        'smpl_scaling': smpl_scaling,
        'smpl_trans': smpl_trans,
        'smpl_loss': loss,
    }, f, protocol=pickle.HIGHEST_PROTOCOL)


def main(_):
  """Fits SMPL motion to AIST++ 3D keypoints and saves one pickle per sequence.

  Processes `FLAGS.sequence_names`, or every sequence in the dataset's
  `mapping_seq2env` when that flag is unset or empty, writing
  `<save_dir>/<seq_name>.pkl`. If `FLAGS.save_dir_gcs` is set, the output
  folder is uploaded to GCS once all sequences are written.
  """
  aist_dataset = AISTDataset(FLAGS.anno_dir)
  smpl_regressor = SMPLRegressor(FLAGS.smpl_dir, 'MALE')

  if FLAGS.sequence_names:
    seq_names = FLAGS.sequence_names
  else:
    seq_names = aist_dataset.mapping_seq2env.keys()

  # The output directory only needs to be created once.
  os.makedirs(FLAGS.save_dir, exist_ok=True)

  for seq_name in seq_names:
    logging.info('processing %s', seq_name)
    # load 3D keypoints
    keypoints3d = AISTDataset.load_keypoint3d(
        aist_dataset.keypoint3d_dir, seq_name, use_optim=True)
    # SMPL fitting
    smpl, loss = smpl_regressor.fit(keypoints3d, dtype='coco', verbose=True)
    motion_file = os.path.join(FLAGS.save_dir, f'{seq_name}.pkl')
    _save_motion(smpl, loss, motion_file)

  # Upload once after every sequence is written; uploading inside the loop
  # would re-send the whole, growing folder for each sequence.
  if FLAGS.save_dir_gcs:
    import gcs_utils  # Imported lazily: only needed when uploading.
    gcs_utils.upload_files_to_gcs(
        local_folder=FLAGS.save_dir,
        gcs_path=FLAGS.save_dir_gcs)


if __name__ == '__main__':
  app.run(main)