Spaces:
Runtime error
Runtime error
File size: 3,814 Bytes
2d5fdd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# coding=utf-8
# Copyright 2020 The Google AI Perception Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test code for running visualizer."""
import os
from absl import app
from absl import flags
from aist_plusplus.loader import AISTDataset
from aist_plusplus.visualizer import plot_on_video
from smplx import SMPL
import torch
FLAGS = flags.FLAGS

# NOTE: the path defaults below point at the original author's machine; pass
# --anno_dir / --video_dir / --smpl_dir / --save_dir when running elsewhere.
flags.DEFINE_string(
    'anno_dir',
    '/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/',
    'input local directory for AIST++ annotations.')
flags.DEFINE_string(
    'video_dir',
    '/usr/local/google/home/ruilongli/data/AIST_plusplus/refined_10M_all_video/',
    'input local directory for AIST Dance Videos.')
flags.DEFINE_string(
    'smpl_dir',
    '/usr/local/google/home/ruilongli/data/SMPL/',
    'input local directory that stores SMPL data.')
flags.DEFINE_string(
    'video_name',
    'gWA_sFM_c01_d27_mWA2_ch21',
    'input video name to be visualized.')
flags.DEFINE_string(
    'save_dir',
    '/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/tmp/',
    'output local directory that stores AIST++ visualization.')
# Selects which point set main() overlays on the video.
flags.DEFINE_enum(
    'mode', '2D', ['2D', '3D', 'SMPL'],
    'visualize 3D or 2D keypoints, or SMPL joints on image plane.')
def _project_to_view(aist_dataset, seq_name, view_idx, keypoints3d):
  """Projects 3D points into the image plane of one camera view.

  Args:
    aist_dataset: an `AISTDataset`; its `mapping_seq2env` and `camera_dir`
      are used to load the camera group recorded for `seq_name`.
    seq_name: sequence name whose environment selects the camera calibration.
    view_idx: index (into the projected camera axis) of the view to keep.
    keypoints3d: array of rank 3, (nframes, njoints, 3) — the last axis is
      presumably xyz; TODO confirm against the loaders.

  Returns:
    Array of shape (nframes, njoints, 2) holding the projection for camera
    `view_idx`.
  """
  nframes, njoints, _ = keypoints3d.shape
  env_name = aist_dataset.mapping_seq2env[seq_name]
  cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir, env_name)
  keypoints2d = cgroup.project(keypoints3d)
  # The leading (camera) axis is inferred with -1 instead of hard-coding 9,
  # assuming `project` returns one 2D point per input point per camera.
  return keypoints2d.reshape(-1, nframes, njoints, 2)[view_idx]


def main(_):
  """Overlays keypoints on `FLAGS.video_name` and saves the rendered video.

  Depending on `FLAGS.mode`, the overlaid points are:
    * '2D':   raw 2D keypoint detections for this camera view;
    * '3D':   temporally optimized 3D keypoints projected into this view;
    * 'SMPL': joints of the fitted SMPL model projected into this view.
  The output is written to `<FLAGS.save_dir>/<FLAGS.video_name>.mp4`.

  Raises:
    ValueError: if `FLAGS.mode` is not one of '2D', '3D', 'SMPL'.
  """
  # Parsing data info.
  aist_dataset = AISTDataset(FLAGS.anno_dir)
  video_path = os.path.join(FLAGS.video_dir, f'{FLAGS.video_name}.mp4')
  seq_name, view = AISTDataset.get_seq_name(FLAGS.video_name)
  view_idx = AISTDataset.VIEWS.index(view)

  # Parsing keypoints.
  if FLAGS.mode == '2D':  # raw keypoints detection results.
    keypoints2d, _, _ = AISTDataset.load_keypoint2d(
        aist_dataset.keypoint2d_dir, seq_name)
    # Keep this view and only the first two channels of the last axis
    # (presumably x, y; any extra channel such as a score is dropped).
    keypoints2d = keypoints2d[view_idx, :, :, 0:2]
  elif FLAGS.mode == '3D':  # 3D keypoints with temporal optimization.
    keypoints3d = AISTDataset.load_keypoint3d(
        aist_dataset.keypoint3d_dir, seq_name, use_optim=True)
    keypoints2d = _project_to_view(
        aist_dataset, seq_name, view_idx, keypoints3d)
  elif FLAGS.mode == 'SMPL':  # SMPL joints
    smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
        aist_dataset.motion_dir, seq_name)
    smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
    # NOTE(review): `scaling` is not a parameter of the stock smplx
    # SMPL.forward (which accepts **kwargs); confirm the installed smplx
    # build actually applies it rather than silently ignoring it.
    keypoints3d = smpl.forward(
        global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
        body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
        transl=torch.from_numpy(smpl_trans).float(),
        scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(),
    ).joints.detach().numpy()
    keypoints2d = _project_to_view(
        aist_dataset, seq_name, view_idx, keypoints3d)
  else:
    # The enum flag normally prevents this; fail loudly instead of hitting an
    # unbound `keypoints2d` below.
    raise ValueError(f'Unknown mode: {FLAGS.mode}')

  # Visualize.
  os.makedirs(FLAGS.save_dir, exist_ok=True)
  save_path = os.path.join(FLAGS.save_dir, f'{FLAGS.video_name}.mp4')
  # Output frame rate is fixed at 60 fps.
  plot_on_video(keypoints2d, video_path, save_path, fps=60)
if __name__ == '__main__':
  # absl parses the command-line flags, then invokes main().
  app.run(main)
|