# coding=utf-8
# Copyright 2020 The Google AI Perception Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process frame-by-frame keypoints detection results to pkl."""
import glob
import json
import multiprocessing
import os
import pickle
from absl import app
from absl import flags
from absl import logging
from aist_plusplus.loader import AISTDataset
import numpy as np
FLAGS = flags.FLAGS
flags.DEFINE_string(
'keypoints_dir',
'/usr/local/google/home/ruilongli/data/AIST_plusplus_v4/posenet_2stage_pose_10M_60fps_all/',
'input local dictionary that stores 2D keypoints detection results in json.'
)
flags.DEFINE_string(
'save_dir',
'/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/keypoints2d/',
'output local dictionary that stores 2D keypoints detection results in pkl.'
)
np.random.seed(0)
def array_nan(shape, dtype=np.float32):
  """Create an array of the given shape with every entry set to NaN."""
  return np.full(shape, np.nan, dtype=dtype)
def load_keypoints2d_file(file_path, njoints=17):
  """Load one frame's 2D keypoints from a keypoint-detection json file.

  Only one person is extracted from the results. If there are multiple
  persons in the prediction results, we select the one with the highest
  detection score. A missing/unreadable file, or a file with zero
  detections, yields an all-NaN keypoint array and a score of 0.0.

  Args:
    file_path: the json file path.
    njoints: number of joints in the keypoint defination.

  Returns:
    A `np.array` with the shape of [njoints, 3].
  """
  # Defaults used for every failure mode: NaN keypoints, zero score.
  nan_keypoint = np.full((njoints, 3), np.nan, dtype=np.float32)
  no_score = 0.0

  try:
    with open(file_path, 'r') as f:
      data = json.load(f)
  except Exception as e:  # pylint: disable=broad-except
    # Best-effort: a missing frame is expected and handled by the caller.
    logging.warning(e)
    return nan_keypoint, no_score

  scores = np.array(data['detection_scores'])
  all_keypoints = np.array(data['keypoints']).reshape((-1, njoints, 3))

  if scores.shape[0] == 0:
    # No person in this image; mark the frame with NaNs.
    return nan_keypoint, no_score

  # One or more detections: keep the highest-scoring person.
  best = np.argmax(scores)
  return all_keypoints[best], scores[best]
def load_keypoints2d(data_dir, seq_name, njoints=17):
  """Load 2D keypoints predictions for a set of multi-view videos.

  Args:
    data_dir: directory holding one sub-directory of per-frame json
      detection results per multi-view video.
    seq_name: the sequence name; expanded into per-view video names via
      `AISTDataset.get_video_name`.
    njoints: number of joints in the keypoint defination.

  Returns:
    A tuple of:
      keypoints2d: `np.float32` array, (nviews, N, njoints, 3); missing
        frames are NaN.
      det_scores: `np.float32` array, (nviews, N); missing frames are 0.
      timestamps: sorted `np.array` of the N frame timestamps (union
        over all views).
  """
  # Parsing sequence name to multi-view video names.
  video_names = [AISTDataset.get_video_name(seq_name, view)
                 for view in AISTDataset.VIEWS]

  # In case frames are missing, we first scan all views to get a union
  # of timestamps. File names look like `<video_name>_<timestamp>.json`.
  timestamps = set()
  for video_name in video_names:
    for path in glob.glob(os.path.join(data_dir, video_name, '*.json')):
      # Parse the timestamp out of the basename. The previous
      # `path.split('.')[0]` cut the FULL path at its first dot, so any
      # '.' inside `data_dir` silently corrupted the timestamps.
      stem = os.path.splitext(os.path.basename(path))[0]
      timestamps.add(int(stem.split('_')[-1]))
  timestamps = np.array(sorted(timestamps))  # (N,)

  # Then we load all frames according to timestamps; a frame that is
  # missing in some view comes back as NaN keypoints with score 0.0.
  keypoints2d = []
  det_scores = []
  for video_name in video_names:
    keypoints2d_per_view = []
    det_scores_per_view = []
    for ts in timestamps:
      path = os.path.join(data_dir, video_name, f'{video_name}_{ts}.json')
      keypoint, det_score = load_keypoints2d_file(path, njoints=njoints)
      keypoints2d_per_view.append(keypoint)
      det_scores_per_view.append(det_score)
    keypoints2d.append(keypoints2d_per_view)
    det_scores.append(det_scores_per_view)
  keypoints2d = np.array(
      keypoints2d, dtype=np.float32)  # (nviews, N, njoints, 3)
  det_scores = np.array(
      det_scores, dtype=np.float32)  # (nviews, N)
  return keypoints2d, det_scores, timestamps
def process_and_save(seq_name):
  """Load 2D keypoints for one sequence and write them to a pickle file."""
  keypoints2d, det_scores, timestamps = load_keypoints2d(
      FLAGS.keypoints_dir, seq_name=seq_name, njoints=17)
  os.makedirs(FLAGS.save_dir, exist_ok=True)
  out_path = os.path.join(FLAGS.save_dir, f'{seq_name}.pkl')
  payload = {
      'keypoints2d': keypoints2d,
      'det_scores': det_scores,
      'timestamps': timestamps,
  }
  with open(out_path, 'wb') as out_file:
    pickle.dump(payload, out_file, protocol=pickle.HIGHEST_PROTOCOL)
def main(_):
  """Convert every sequence's per-frame json detections into pkl files."""
  video_names = os.listdir(FLAGS.keypoints_dir)
  # Valid video names have exactly six '_'-separated fields; skip
  # anything else lying in the directory.
  video_names = [
      video_name for video_name in video_names
      if len(video_name.split('_')) == 6
  ]
  seq_names = list(set(
      AISTDataset.get_seq_name(video_name)[0] for video_name in video_names))
  # Use the pool as a context manager so the worker processes are always
  # reaped, even when a worker raises. The original created the pool and
  # never called close()/join(), leaking the workers until interpreter
  # exit.
  with multiprocessing.Pool(16) as pool:
    pool.map(process_and_save, seq_names)
if __name__ == '__main__':
app.run(main)