# meng2003's picture
# Upload 357 files
# 2d5fdd1
# coding=utf-8
# Copyright 2020 The Google AI Perception Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AIST++ Dataset Loader."""
import json
import os
import pickle
import aniposelib
import numpy as np
class AISTDataset:
    """A dataset class for loading, processing and plotting AIST++.

    An instance records the annotation sub-directory paths under `anno_dir`
    and the sequence <-> environment mapping read from `cameras/mapping.txt`.
    The `load_*` helpers are classmethods that take the directory to read
    from explicitly, so they can be used without constructing an instance.
    """

    # View tags. `get_video_name` swaps the 'cAll' placeholder of a sequence
    # name for a view tag; presumably one tag per camera (not enforced here).
    VIEWS = ['c01', 'c02', 'c03', 'c04', 'c05', 'c06', 'c07', 'c08', 'c09']

    def __init__(self, anno_dir):
        """Record annotation paths and load the sequence/environment mapping.

        Args:
          anno_dir: Root directory of the annotations. It must contain
            `cameras/mapping.txt`, whose rows are whitespace-separated
            `<sequence name> <environment name>` pairs.

        Raises:
          AssertionError: If `anno_dir` does not exist.
        """
        # NOTE: kept as an assert (not a raise) so callers that catch
        # AssertionError keep working.
        assert os.path.exists(anno_dir), f'Data does not exist at {anno_dir}!'

        # Init paths
        self.camera_dir = os.path.join(anno_dir, 'cameras/')
        self.motion_dir = os.path.join(anno_dir, 'motions/')
        self.keypoint3d_dir = os.path.join(anno_dir, 'keypoints3d/')
        self.keypoint2d_dir = os.path.join(anno_dir, 'keypoints2d/')
        self.filter_file = os.path.join(anno_dir, 'ignore_list.txt')

        # Load environment setting mapping
        self.mapping_seq2env = {}  # sequence name -> env name
        self.mapping_env2seq = {}  # env name -> a list of sequence names

        env_mapping_file = os.path.join(self.camera_dir, 'mapping.txt')
        # ndmin=2 keeps a single-row file as shape (1, 2). Without it loadtxt
        # returns a 1-D array and the loop below would iterate over the two
        # strings instead of over one (seq_name, env_name) row.
        env_mapping = np.loadtxt(env_mapping_file, dtype=str, ndmin=2)
        for seq_name, env_name in env_mapping:
            self.mapping_seq2env[seq_name] = env_name
            self.mapping_env2seq.setdefault(env_name, []).append(seq_name)

    @staticmethod
    def _load_pickle(file_path):
        """Return the object unpickled from `file_path`.

        Raises:
          AssertionError: If `file_path` does not exist.
        """
        assert os.path.exists(file_path), f'File {file_path} does not exist!'
        # SECURITY: pickle.load can execute arbitrary code while
        # deserializing; only point this at annotation files you trust.
        with open(file_path, 'rb') as f:
            return pickle.load(f)

    @classmethod
    def get_video_name(cls, seq_name, view):
        """Get AIST video name from AIST++ sequence name.

        Every occurrence of 'cAll' in `seq_name` is replaced by `view`.
        """
        return seq_name.replace('cAll', view)

    @classmethod
    def get_seq_name(cls, video_name):
        """Get AIST++ sequence name from AIST video name.

        The name is split on '_'. The view tag is taken from position 1 when
        there are exactly three tags and from position 2 otherwise, and that
        position is replaced by 'cAll'.

        Returns:
          A `(seq_name, view)` tuple.
        """
        tags = video_name.split('_')
        view_index = 1 if len(tags) == 3 else 2
        view = tags[view_index]
        tags[view_index] = 'cAll'
        return '_'.join(tags), view

    @classmethod
    def load_camera_group(cls, camera_dir, env_name):
        """Load a set of cameras in the environment.

        Reads `<camera_dir>/<env_name>.json`, a list of dicts with the keys
        'name', 'size', 'matrix', 'rotation', 'translation' and
        'distortions', and builds one aniposelib camera per entry.

        Returns:
          An `aniposelib.cameras.CameraGroup` holding all loaded cameras.

        Raises:
          AssertionError: If the JSON file does not exist.
        """
        file_path = os.path.join(camera_dir, f'{env_name}.json')
        assert os.path.exists(file_path), f'File {file_path} does not exist!'
        with open(file_path, 'r') as f:
            params = json.load(f)
        cameras = [
            aniposelib.cameras.Camera(
                name=param_dict['name'],
                size=param_dict['size'],
                matrix=param_dict['matrix'],
                rvec=param_dict['rotation'],
                tvec=param_dict['translation'],
                dist=param_dict['distortions'],
            )
            for param_dict in params
        ]
        return aniposelib.cameras.CameraGroup(cameras)

    @classmethod
    def load_motion(cls, motion_dir, seq_name):
        """Load a motion sequence represented using SMPL format.

        Reads `<motion_dir>/<seq_name>.pkl`.

        Returns:
          A `(smpl_poses, smpl_scaling, smpl_trans)` tuple.

        Raises:
          AssertionError: If the pickle file does not exist.
        """
        data = cls._load_pickle(os.path.join(motion_dir, f'{seq_name}.pkl'))
        smpl_poses = data['smpl_poses']  # (N, 24, 3)
        smpl_scaling = data['smpl_scaling']  # (1,)
        smpl_trans = data['smpl_trans']  # (N, 3)
        return smpl_poses, smpl_scaling, smpl_trans

    @classmethod
    def load_keypoint3d(cls, keypoint_dir, seq_name, use_optim=False):
        """Load a 3D keypoint sequence represented using COCO format.

        Reads `<keypoint_dir>/<seq_name>.pkl`.

        Args:
          keypoint_dir: Directory holding the 3D keypoint pickles.
          seq_name: Sequence name (file stem).
          use_optim: If true, return the 'keypoints3d_optim' entry instead
            of 'keypoints3d'.

        Raises:
          AssertionError: If the pickle file does not exist.
        """
        data = cls._load_pickle(os.path.join(keypoint_dir, f'{seq_name}.pkl'))
        key = 'keypoints3d_optim' if use_optim else 'keypoints3d'
        return data[key]  # (N, 17, 3)

    @classmethod
    def load_keypoint2d(cls, keypoint_dir, seq_name):
        """Load a 2D keypoint sequence represented using COCO format.

        Reads `<keypoint_dir>/<seq_name>.pkl`.

        Returns:
          A `(keypoints2d, det_scores, timestamps)` tuple.

        Raises:
          AssertionError: If the pickle file does not exist.
        """
        data = cls._load_pickle(os.path.join(keypoint_dir, f'{seq_name}.pkl'))
        keypoints2d = data['keypoints2d']  # (nviews, N, 17, 3)
        det_scores = data['det_scores']  # (nviews, N)
        timestamps = data['timestamps']  # (N,)
        return keypoints2d, det_scores, timestamps