votuongquan2004@gmail.com commited on
Commit
547004e
1 Parent(s): 1baa500

update data.py

Browse files
Files changed (1) hide show
  1. utils/data.py +9 -9
utils/data.py CHANGED
@@ -161,21 +161,21 @@ def preprocess(
161
  '''
162
  inputs = extract_joints(source=source, keypoints_detector=keypoints_detector)
163
 
164
- # T = inputs.shape[1]
165
- # ori_data = inputs
166
- # for t in range(T - 1):
167
- # inputs[:, t, :, :] = ori_data[:, t + 1, :, :] - ori_data[:, t, :, :]
168
- # inputs[:, T - 1, :, :] = 0
169
 
170
  if random_choose:
171
  inputs = random_sample_np(inputs, window_size)
172
  else:
173
  inputs = uniform_sample_np(inputs, window_size)
174
 
175
- # if normalization:
176
- # assert inputs.shape[0] == 3
177
- # inputs[0, :, :, :] = inputs[0, :, :, :] - inputs[0, :, 0, 0].mean(axis=0)
178
- # inputs[1, :, :, :] = inputs[1, :, :, :] - inputs[1, :, 0, 0].mean(axis=0)
179
 
180
  return inputs[np.newaxis, :].astype(np.float32)
181
 
 
161
  '''
162
  inputs = extract_joints(source=source, keypoints_detector=keypoints_detector)
163
 
164
+ T = inputs.shape[1]
165
+ ori_data = inputs
166
+ for t in range(T - 1):
167
+ inputs[:, t, :, :] = ori_data[:, t + 1, :, :] - ori_data[:, t, :, :]
168
+ inputs[:, T - 1, :, :] = 0
169
 
170
  if random_choose:
171
  inputs = random_sample_np(inputs, window_size)
172
  else:
173
  inputs = uniform_sample_np(inputs, window_size)
174
 
175
+ if normalization:
176
+ assert inputs.shape[0] == 3
177
+ inputs[0, :, :, :] = inputs[0, :, :, :] - inputs[0, :, 0, 0].mean(axis=0)
178
+ inputs[1, :, :, :] = inputs[1, :, :, :] - inputs[1, :, 0, 0].mean(axis=0)
179
 
180
  return inputs[np.newaxis, :].astype(np.float32)
181