# coding=utf-8
# Copyright 2020 The Google AI Perception Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Estimate AIST++ SMPL-format Motion."""
import os
import pickle

from absl import app
from absl import flags
from absl import logging
from aist_plusplus.loader import AISTDataset
import numpy as np
from smplx import SMPL
import torch

FLAGS = flags.FLAGS
flags.DEFINE_string(
    'anno_dir',
    '/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/',
    'input local dictionary for AIST++ annotations.')
flags.DEFINE_string(
    'smpl_dir',
    '/usr/local/google/home/ruilongli/data/SMPL/',
    'input local dictionary that stores SMPL data.')
flags.DEFINE_string(
    'save_dir',
    '/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/motions/',
    'output local dictionary that stores AIST++ SMPL-format motion data.')
flags.DEFINE_list(
    'sequence_names',
    None,
    'list of sequence names to be processed. None means to process all.')
flags.DEFINE_string(
    'save_dir_gcs',
    None,
    'output GCS directory.')

# Fixed seeds so repeated runs start from the same random state.
np.random.seed(0)
torch.manual_seed(0)

# Per-dataset joint indices, listed in the order of the unified definition
# documented in `unify_joint_mappings`.
_JOINT_MAPPINGS = {
    'openpose25': (
        0, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21,
        22, 24,
    ),
    'smpl': (
        24, 17, 19, 21, 16, 18, 20, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 31,
        32, 34,
    ),
    'coco': (
        0, 5, 7, 9, 6, 8, 10, 11, 13, 15, 12, 14, 16, 1, 2, 3, 4,
    ),
}


def unify_joint_mappings(dataset='openpose25'):
    """Unify different joint definitions.

    Output unified definition:
        ['Nose',
         'RShoulder', 'RElbow', 'RWrist',
         'LShoulder', 'LElbow', 'LWrist',
         'RHip', 'RKnee', 'RAnkle',
         'LHip', 'LKnee', 'LAnkle',
         'REye', 'LEye',
         'REar', 'LEar',
         'LBigToe', 'LHeel',
         'RBigToe', 'RHeel',]

    The `coco` mapping has 17 entries, so it only covers the first 17 names
    of the list above; the other two mappings have all 21.

    Args:
        dataset: `openpose25`, `coco`(17) and `smpl`.

    Returns:
        A freshly allocated int32 array of indices that maps the joints of
        `dataset` to the unified definition.

    Raises:
        ValueError: if `dataset` is not one of the supported names.
    """
    try:
        indices = _JOINT_MAPPINGS[dataset]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, so every bad input still
        # surfaces as the documented ValueError.
        raise ValueError(f'{dataset} is not supported') from None
    return np.array(indices, dtype=np.int32)


class SMPLRegressor:
    """SMPL fitting based on 3D keypoints."""

    def __init__(self, smpl_model_path, smpl_model_gener='MALE'):
        """Stores the fitting hyper-parameters and model location.

        Args:
            smpl_model_path: path handed to `SMPL(model_path=...)`.
            smpl_model_gener: gender handed to `SMPL(gender=...)`. The
                parameter name is kept as-is for backward compatibility.
        """
        # Fitting hyper-parameters
        self.base_lr = 100.0
        self.niter = 10000
        self.metric = torch.nn.MSELoss()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.smpl_model_path = smpl_model_path
        self.smpl_model_gender = smpl_model_gener
        # Mapping to unify joint definitions
        self.joints_mapping_smpl = unify_joint_mappings(dataset='smpl')

    def get_optimizer(self, smpl, step, base_lr):
        """Set up the optimizer with a warm-up learning rate schedule.

        The schedule unlocks parameters progressively:
          * step < 100: only `transl` and `scaling` are updated.
          * 100 <= step < 400: `global_orient` is updated as well.
          * step >= 400: `body_pose` is updated as well.
        `betas` always has a learning rate of 0.0.

        Args:
            smpl: model exposing `transl`, `scaling`, `global_orient`,
                `body_pose` and `betas` parameters.
            step: current (0-based) fitting iteration.
            base_lr: learning rate applied to `transl`; the other groups use
                fixed fractions of it.

        Returns:
            A new `torch.optim.SGD` with one parameter group per SMPL
            parameter.
        """
        orient_lr = base_lr * 0.001 if step >= 100 else 0.0
        pose_lr = base_lr * 0.001 if step >= 400 else 0.0
        return torch.optim.SGD([
            {'params': [smpl.transl], 'lr': base_lr},
            {'params': [smpl.scaling], 'lr': base_lr * 0.01},
            {'params': [smpl.global_orient], 'lr': orient_lr},
            {'params': [smpl.body_pose], 'lr': pose_lr},
            {'params': [smpl.betas], 'lr': 0.0},
        ])

    def fit(self, keypoints3d, dtype='coco', verbose=True):
        """Run fitting to optimize the SMPL parameters.

        Args:
            keypoints3d: numpy array of shape [N, njoints, 3].
            dtype: joint format of `keypoints3d`; only `coco` is accepted.
            verbose: if True, log the loss every 10 steps.

        Returns:
            A tuple `(smpl, loss)`: the fitted SMPL model and the loss of the
            last iteration as a Python float (NaN if `self.niter` is 0, in
            which case no fitting step was run).

        Raises:
            AssertionError: if `dtype` is not `coco` or `keypoints3d` is not
                3-dimensional.
        """
        assert dtype == 'coco', 'only support coco format for now.'
        assert len(keypoints3d.shape) == 3, (
            'input shape should be [N, njoints, 3]')

        # Reorder the target joints into the unified definition.
        mapping_target = unify_joint_mappings(dataset=dtype)
        keypoints3d = keypoints3d[:, mapping_target, :]
        keypoints3d = torch.from_numpy(keypoints3d).float().to(self.device)
        batch_size, njoints = keypoints3d.shape[0:2]

        # Init learnable smpl model
        smpl = SMPL(
            model_path=self.smpl_model_path,
            gender=self.smpl_model_gender,
            batch_size=batch_size).to(self.device)

        # The target may cover fewer joints than the SMPL mapping (coco has
        # 17), so only the matching prefix is compared. Loop-invariant, so it
        # is computed once.
        smpl_joint_ids = self.joints_mapping_smpl[:njoints]

        # Start fitting
        loss = None
        for step in range(self.niter):
            # A fresh optimizer is built each step so that the step-dependent
            # learning rates apply; no momentum is configured on it.
            optimizer = self.get_optimizer(smpl, step, self.base_lr)

            output = smpl.forward()
            joints = output.joints[:, smpl_joint_ids, :]
            loss = self.metric(joints, keypoints3d)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if verbose and step % 10 == 0:
                logging.info('step %03d; loss %.3f;', step, loss.item())

        # Return results. `loss` stays None only when no iteration ran.
        final_loss = float('nan') if loss is None else loss.item()
        return smpl, final_loss


def main(_):
    """Fits SMPL to every requested AIST++ sequence and saves the results.

    For each sequence a pickle file `<save_dir>/<seq_name>.pkl` is written
    with the keys `smpl_poses`, `smpl_scaling`, `smpl_trans` and `smpl_loss`.
    If `--save_dir_gcs` is set, `save_dir` is uploaded afterwards.
    """
    aist_dataset = AISTDataset(FLAGS.anno_dir)
    smpl_regressor = SMPLRegressor(FLAGS.smpl_dir, 'MALE')

    # An unset (or empty) flag means "process every known sequence".
    seq_names = FLAGS.sequence_names or aist_dataset.mapping_seq2env.keys()

    # The output folder does not depend on the sequence; create it once.
    os.makedirs(FLAGS.save_dir, exist_ok=True)

    for seq_name in seq_names:
        logging.info('processing %s', seq_name)

        # load 3D keypoints
        keypoints3d = AISTDataset.load_keypoint3d(
            aist_dataset.keypoint3d_dir, seq_name, use_optim=True)

        # SMPL fitting
        smpl, loss = smpl_regressor.fit(
            keypoints3d, dtype='coco', verbose=True)

        # One last time forward
        with torch.no_grad():
            _ = smpl.forward()
        body_pose = smpl.body_pose.detach().cpu().numpy()
        global_orient = smpl.global_orient.detach().cpu().numpy()
        smpl_poses = np.concatenate([global_orient, body_pose], axis=1)
        # NOTE(review): `scaling` is read from the model here, so the
        # installed SMPL class must expose it -- verify against the smplx
        # build in use.
        smpl_scaling = smpl.scaling.detach().cpu().numpy()
        smpl_trans = smpl.transl.detach().cpu().numpy()

        motion_file = os.path.join(FLAGS.save_dir, f'{seq_name}.pkl')
        with open(motion_file, 'wb') as f:
            pickle.dump({
                'smpl_poses': smpl_poses,
                'smpl_scaling': smpl_scaling,
                'smpl_trans': smpl_trans,
                'smpl_loss': loss,
            }, f, protocol=pickle.HIGHEST_PROTOCOL)

    # upload results to GCS
    if FLAGS.save_dir_gcs:
        # Imported lazily so the dependency is only needed when uploading.
        import gcs_utils
        gcs_utils.upload_files_to_gcs(
            local_folder=FLAGS.save_dir,
            gcs_path=FLAGS.save_dir_gcs)


if __name__ == '__main__':
    app.run(main)