Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2020 The Google AI Perception Team Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Test code for running visualizer.""" | |
import os | |
from absl import app | |
from absl import flags | |
from aist_plusplus.loader import AISTDataset | |
from aist_plusplus.visualizer import plot_on_video | |
from smplx import SMPL | |
import torch | |
FLAGS = flags.FLAGS

# Command-line flags consumed by main().
# NOTE: the directory defaults below are absolute, machine-specific paths;
# pass the flags explicitly when those paths do not exist locally.
flags.DEFINE_string(
    'anno_dir',
    '/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/',
    'input local directory for AIST++ annotations.')
flags.DEFINE_string(
    'video_dir',
    '/usr/local/google/home/ruilongli/data/AIST_plusplus/refined_10M_all_video/',
    'input local directory for AIST Dance Videos.')
flags.DEFINE_string(
    'smpl_dir',
    '/usr/local/google/home/ruilongli/data/SMPL/',
    'input local directory that stores SMPL data.')
# Name of the video without extension; main() appends '.mp4' to it.
flags.DEFINE_string(
    'video_name',
    'gWA_sFM_c01_d27_mWA2_ch21',
    'input video name to be visualized.')
# Created by main() when it does not exist yet.
flags.DEFINE_string(
    'save_dir',
    '/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/tmp/',
    'output local directory that stores AIST++ visualization.')
# Selects which keypoint source main() draws on top of the video.
flags.DEFINE_enum(
    'mode', '2D', ['2D', '3D', 'SMPL'],
    'visualize 3D or 2D keypoints, or SMPL joints on image plane.')
# Size of the leading axis of the projected keypoints array. That axis is
# indexed with a position taken from AISTDataset.VIEWS, so it is presumably
# one entry per camera view -- confirm against the camera group in use.
_NUM_VIEWS = 9


def _project_to_view(aist_dataset, seq_name, keypoints3d, view_idx):
  """Projects 3D keypoints with the sequence's camera group; keeps one view.

  Args:
    aist_dataset: AISTDataset instance; its `mapping_seq2env` and `camera_dir`
      attributes are read to locate the camera group.
    seq_name: sequence name used as key into `mapping_seq2env`.
    keypoints3d: array whose shape unpacks as (nframes, njoints, _).
    view_idx: index into the leading axis of the projected array.

  Returns:
    The `view_idx`-th slice of the projection reshaped to
    (_NUM_VIEWS, nframes, njoints, 2).
  """
  nframes, njoints, _ = keypoints3d.shape
  env_name = aist_dataset.mapping_seq2env[seq_name]
  cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir, env_name)
  projected = cgroup.project(keypoints3d)
  return projected.reshape(_NUM_VIEWS, nframes, njoints, 2)[view_idx]


def main(_):
  """Draws the keypoints selected by --mode on a video and saves the result.

  Reads `<video_dir>/<video_name>.mp4`, obtains per-frame 2D keypoints for the
  camera view encoded in the video name, and writes the visualization to
  `<save_dir>/<video_name>.mp4`.

  Args:
    _: unused positional arguments passed in by `app.run`.
  """
  # Parsing data info.
  aist_dataset = AISTDataset(FLAGS.anno_dir)
  video_path = os.path.join(FLAGS.video_dir, f'{FLAGS.video_name}.mp4')
  seq_name, view = AISTDataset.get_seq_name(FLAGS.video_name)
  view_idx = AISTDataset.VIEWS.index(view)

  # Parsing keypoints.
  if FLAGS.mode == '2D':  # raw keypoints detection results.
    keypoints2d, _, _ = AISTDataset.load_keypoint2d(
        aist_dataset.keypoint2d_dir, seq_name)
    # Keep only this view and the first two values of the last axis.
    keypoints2d = keypoints2d[view_idx, :, :, 0:2]
  elif FLAGS.mode == '3D':  # 3D keypoints with temporal optimization.
    keypoints3d = AISTDataset.load_keypoint3d(
        aist_dataset.keypoint3d_dir, seq_name, use_optim=True)
    keypoints2d = _project_to_view(
        aist_dataset, seq_name, keypoints3d, view_idx)
  elif FLAGS.mode == 'SMPL':  # SMPL joints
    smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
        aist_dataset.motion_dir, seq_name)
    smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
    # Column 0 of the poses feeds global_orient, the remaining columns
    # feed body_pose.
    keypoints3d = smpl.forward(
        global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
        body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
        transl=torch.from_numpy(smpl_trans).float(),
        scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(),
    ).joints.detach().numpy()
    keypoints2d = _project_to_view(
        aist_dataset, seq_name, keypoints3d, view_idx)

  # Visualize.
  os.makedirs(FLAGS.save_dir, exist_ok=True)
  save_path = os.path.join(FLAGS.save_dir, f'{FLAGS.video_name}.mp4')
  plot_on_video(keypoints2d, video_path, save_path, fps=60)
# Script entry point: app.run is handed main as the function to execute.
if __name__ == '__main__':
  app.run(main)