File size: 5,060 Bytes
2d5fdd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# coding=utf-8
# Copyright 2020 The Google AI Perception Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process frame-by-frame keypoints detection results to pkl."""
import glob
import json
import multiprocessing
import os
import pickle

from absl import app
from absl import flags
from absl import logging
from aist_plusplus.loader import AISTDataset
import numpy as np

FLAGS = flags.FLAGS
flags.DEFINE_string(
    'keypoints_dir',
    '/usr/local/google/home/ruilongli/data/AIST_plusplus_v4/posenet_2stage_pose_10M_60fps_all/',
    # Fixed help-text typo: 'dictionary' -> 'directory'.
    'input local directory that stores 2D keypoints detection results in json.'
)
flags.DEFINE_string(
    'save_dir',
    '/usr/local/google/home/ruilongli/data/public/aist_plusplus_final/keypoints2d/',
    'output local directory that stores 2D keypoints detection results in pkl.'
)
# Seed numpy's global RNG so any stochastic processing is reproducible.
np.random.seed(0)


def array_nan(shape, dtype=np.float32):
  """Return an array of the given shape filled entirely with NaN."""
  return np.full(shape, np.nan, dtype=dtype)


def load_keypoints2d_file(file_path, njoints=17):
  """Load one frame's 2D keypoint detection result from a json file.

  Only a single person is kept: when the detector reported several people
  in the frame, the one with the highest detection score is selected.

  Args:
    file_path: path to the json detection file.
    njoints: number of joints in the keypoint definition.

  Returns:
    A tuple `(keypoint, det_score)` where `keypoint` is a float32
    `np.ndarray` of shape [njoints, 3] (all-NaN when the file cannot be
    read or contains no person) and `det_score` is the selected person's
    detection score (0.0 when no person is available).
  """
  nan_keypoint = np.full((njoints, 3), np.nan, dtype=np.float32)

  try:
    with open(file_path, 'r') as f:
      data = json.load(f)
  except Exception as e:  # pylint: disable=broad-except
    # Best effort: a missing or corrupted file is logged and the frame is
    # marked as NaN rather than aborting the whole sequence.
    logging.warning(e)
    return nan_keypoint, 0.0

  det_scores = np.array(data['detection_scores'])
  if det_scores.shape[0] == 0:
    # No person detected in this frame.
    return nan_keypoint, 0.0

  # One or more people detected: keep the highest-scoring one.
  best = np.argmax(det_scores)
  keypoints = np.array(data['keypoints']).reshape((-1, njoints, 3))
  return keypoints[best], det_scores[best]


def load_keypoints2d(data_dir, seq_name, njoints=17):
  """Load 2D keypoints predictions for a set of multi-view videos.

  Args:
    data_dir: directory containing one sub-directory of json files per video.
    seq_name: sequence name, expanded into one video name per camera view.
    njoints: number of joints in the keypoint definition.

  Returns:
    A tuple `(keypoints2d, det_scores, timestamps)` with shapes
    (nviews, N, njoints, 3), (nviews, N) and (N,), where N is the number of
    distinct timestamps found across all views. Frames missing in a view
    come back as NaN keypoints with a 0.0 detection score.
  """
  # Parsing sequence name to multi-view video names.
  video_names = [AISTDataset.get_video_name(seq_name, view)
                 for view in AISTDataset.VIEWS]

  # In case frames are missing, we first scan all views to get a union
  # of timestamps.
  timestamps = set()
  for video_name in video_names:
    for path in glob.glob(os.path.join(data_dir, video_name, '*.json')):
      # File names look like '<video_name>_<timestamp>.json'. Split the
      # extension off the *basename* (the original split the full path on
      # '.', which breaks if any directory component contains a dot).
      stem = os.path.splitext(os.path.basename(path))[0]
      timestamps.add(int(stem.split('_')[-1]))
  # NOTE(review): int() drops any zero padding; if the detector emits padded
  # timestamps (e.g. '_000123'), the path reconstruction below would miss
  # them -- confirm the naming scheme of the json files.
  timestamps = np.array(sorted(timestamps))  # (N,)

  # Then we load all frames according to timestamps.
  keypoints2d = []
  det_scores = []
  for video_name in video_names:
    keypoints2d_per_view = []
    det_scores_per_view = []
    for ts in timestamps:
      path = os.path.join(data_dir, video_name, f'{video_name}_{ts}.json')
      keypoint, det_score = load_keypoints2d_file(path, njoints=njoints)
      keypoints2d_per_view.append(keypoint)
      det_scores_per_view.append(det_score)
    keypoints2d.append(keypoints2d_per_view)
    det_scores.append(det_scores_per_view)

  keypoints2d = np.array(
      keypoints2d, dtype=np.float32)  # (nviews, N, njoints, 3)
  det_scores = np.array(
      det_scores, dtype=np.float32)  # (nviews, N)
  return keypoints2d, det_scores, timestamps


def process_and_save(seq_name):
  """Load one sequence's multi-view 2D keypoints and pickle them to disk."""
  keypoints2d, det_scores, timestamps = load_keypoints2d(
      FLAGS.keypoints_dir, seq_name=seq_name, njoints=17)

  os.makedirs(FLAGS.save_dir, exist_ok=True)
  out_path = os.path.join(FLAGS.save_dir, f'{seq_name}.pkl')
  payload = {
      'keypoints2d': keypoints2d,
      'det_scores': det_scores,
      'timestamps': timestamps,
  }
  with open(out_path, 'wb') as f:
    pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)


def main(_):
  """Convert every detected sequence under `keypoints_dir` into a pkl file."""
  # Per-view video directories are named with six underscore-separated
  # fields; anything else under keypoints_dir is skipped.
  video_names = [
      video_name for video_name in os.listdir(FLAGS.keypoints_dir)
      if len(video_name.split('_')) == 6
  ]
  # Several camera views map to one sequence, so deduplicate the names.
  seq_names = list({
      AISTDataset.get_seq_name(video_name)[0] for video_name in video_names})

  # Use the pool as a context manager so it is terminated and its workers
  # reaped even if a task raises (the original never closed/joined it).
  with multiprocessing.Pool(16) as pool:
    pool.map(process_and_save, seq_names)


if __name__ == '__main__':
  app.run(main)