votuongquan2004@gmail.com commited on
Commit
aaa2ffd
1 Parent(s): 640b470

update data.py

Browse files
Files changed (1) hide show
  1. utils/data.py +29 -15
utils/data.py CHANGED
@@ -135,34 +135,48 @@ def extract_joints(
135
 
136
  def preprocess(
137
  source: str,
138
- data_args: dict,
139
  keypoints_detector,
140
- device: str = 'cpu',
141
- ) -> torch.Tensor:
 
 
142
  '''
143
  Preprocess the video.
144
-
145
  Parameters
146
  ----------
147
  source : str
148
  The path to the video.
149
-
 
 
 
 
 
 
 
150
  Returns
151
  -------
152
- dict
153
- The model inputs.
154
  '''
155
- print('Extracting joints from pose...')
156
- inputs = extract_joints_new(source=source, keypoints_detector=keypoints_detector)
157
  T = inputs.shape[1]
158
- print('Sampling video...')
159
- if data_args['random_choose']:
160
- inputs = random_sample_np(inputs, data_args['window_size'])
 
 
 
 
161
  else:
162
- inputs = uniform_sample_np(inputs, data_args['window_size'])
 
 
 
 
 
163
 
164
- print('Normalizing video...')
165
- print(inputs.shape, inputs)
166
  return np.squeeze(inputs).transpose(1, 2, 0).astype(np.float32)
167
 
168
 
 
135
 
136
  def preprocess(
137
  source: str,
 
138
  keypoints_detector,
139
+ normalization: bool = True,
140
+ random_choose: bool = True,
141
+ window_size: int = 120,
142
+ ) -> np.ndarray:
143
  '''
144
  Preprocess the video.
 
145
  Parameters
146
  ----------
147
  source : str
148
  The path to the video.
149
+ keypoints_detector : mediapipe.solutions.holistic.Holistic
150
+ The keypoints detector.
151
+ normalization : bool, default=True
152
+ Whether to normalize the data.
153
+ random_choose : bool, default=True
154
+ Whether to randomly sample the data.
155
+ window_size : int, default=120
156
+ The window size.
157
  Returns
158
  -------
159
+ np.ndarray
160
+ The processed inputs for model.
161
  '''
162
+ inputs = extract_joints(source=source, keypoints_detector=keypoints_detector)
163
+
164
  T = inputs.shape[1]
165
+ # ori_data = inputs
166
+ # for t in range(T - 1):
167
+ # inputs[:, t, :, :] = ori_data[:, t + 1, :, :] - ori_data[:, t, :, :]
168
+ # inputs[:, T - 1, :, :] = 0
169
+
170
+ if random_choose:
171
+ inputs = random_sample_np(inputs, window_size)
172
  else:
173
+ inputs = uniform_sample_np(inputs, window_size)
174
+
175
+ # if normalization:
176
+ # assert inputs.shape[0] == 3
177
+ # inputs[0, :, :, :] = inputs[0, :, :, :] - inputs[0, :, 0, 0].mean(axis=0)
178
+ # inputs[1, :, :, :] = inputs[1, :, :, :] - inputs[1, :, 0, 0].mean(axis=0)
179
 
 
 
180
  return np.squeeze(inputs).transpose(1, 2, 0).astype(np.float32)
181
 
182