dylanplummer committed on
Commit
43f7645
·
1 Parent(s): 3e7c114

Update to use new outputs

Browse files
Files changed (1) hide show
  1. app.py +65 -101
app.py CHANGED
@@ -10,7 +10,6 @@ import subprocess
10
  import matplotlib
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
13
- from matplotlib.animation import FuncAnimation
14
  from scipy.signal import medfilt
15
  from functools import partial
16
  from passlib.hash import pbkdf2_sha256
@@ -39,87 +38,22 @@ compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU", config=c
39
  a = os.path.join(os.path.dirname(__file__), "files", "dylan.mp4")
40
  b = os.path.join(os.path.dirname(__file__), "files", "train14.mp4")
41
 
42
- def sigmoid(x):
43
- return 1 / (1 + np.exp(-x))
44
-
45
-
46
- def confidence_analysis(periodicity, counts, frames, out_dir='confidence_animations', top_n=9):
47
- os.makedirs(out_dir, exist_ok=True)
48
- jump_arrs = []
49
- confidence_arrs = []
50
- current_jump = []
51
- current_confidence = []
52
- current_period = 1
53
- for i in range(len(periodicity)):
54
- if counts[i] < current_period:
55
- current_jump.append(np.array(frames[i]))
56
- current_confidence.append(periodicity[i])
57
- else:
58
- jump_arrs.append(current_jump)
59
- confidence_arrs.append(current_confidence)
60
- current_jump = [np.array(frames[i])]
61
- current_confidence = [periodicity[i]]
62
- current_period += 1
63
- avg_confidences = [np.median(x) for x in confidence_arrs]
64
- conf_order = np.argsort(avg_confidences)
65
- tiled_img = []
66
- tiled_confs = []
67
- for out_i, conf_idx in enumerate(conf_order):
68
- frames = np.array(jump_arrs[conf_idx])
69
- confidence = np.array(confidence_arrs[conf_idx])
70
- mean_confidence = np.median(confidence)
71
- tiled_img.append(frames)
72
- tiled_confs.append(mean_confidence)
73
- # fig, axs = plt.subplots(1, 1, figsize = (3, 3))
74
 
75
- # img_ax = axs
76
- # img_ax.imshow(np.zeros((128, 128, 3)))
 
 
 
 
 
 
 
77
 
78
- # def animate(i):
79
- # img = frames[i]
80
- # #img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
81
- # print(np.min(img), np.mean(img), np.max(img))
82
- # img_ax.imshow(np.clip(img, 0, 255))
83
- # print(i, end=' ')
84
- # img_ax.set_xticks([])
85
- # img_ax.set_yticks([])
86
- # img_ax.set_title(f'Confidence: {mean_confidence:.2f}')
87
-
88
-
89
- # anim = FuncAnimation(fig, animate, frames=len(frames), interval=200)
90
- # anim.save(f'{out_dir}/conf_{out_i}.gif', writer=None)
91
- # plt.close(fig)
92
- if top_n > 10:
93
- break
94
- longest_len = max([len(x) for x in tiled_img])
95
- looped_tiled_img = []
96
- for frames in tiled_img:
97
- looped_tiled_img.append(np.concatenate([frames, frames[::-1]] * (longest_len // len(frames) + 1), axis=0)[:longest_len])
98
- # animate each tile
99
- n_rows = int(np.ceil(np.sqrt(top_n)))
100
- n_cols = int(np.ceil(top_n / n_rows))
101
- fig, axs = plt.subplots(n_rows, n_cols, figsize = (n_cols * 2, n_rows * 2))
102
- for i in range(n_rows):
103
- for j in range(n_cols):
104
- if i * n_cols + j < len(looped_tiled_img):
105
- img_ax = axs[i][j]
106
- img_ax.imshow(np.zeros((128, 128, 3)))
107
- def animate(i):
108
- print(i, end=' ')
109
- for row in range(n_rows):
110
- for col in range(n_cols):
111
- img = looped_tiled_img[row * n_cols + col][i]
112
- img_ax = axs[row][col]
113
- img_ax.imshow(np.clip(img, 0, 255))
114
- img_ax.set_xticks([])
115
- img_ax.set_yticks([])
116
- img_ax.set_title(f'Conf: {tiled_confs[row * n_cols + col]:.2f}')
117
- anim = FuncAnimation(fig, animate, frames=longest_len, interval=200)
118
- anim.save(f'{out_dir}/tiled_conf.gif', writer=None)
119
- plt.close(fig)
120
 
121
 
122
- def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, miss_threshold=0.85, median_pred_filter=True, center_crop=True, both_feet=True, api_call=False):
123
  print(x)
124
  #api = HfApi(token=os.environ['DATASET_SECRET'])
125
  #out_file = str(uuid.uuid1())
@@ -145,44 +79,46 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
145
  break
146
  frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
147
  img = Image.fromarray(frame)
148
- width, height = img.size
149
- if width > height:
150
- img = img.resize((int(width / (height / img_size)), img_size))
151
- else:
152
- img = img.resize((img_size, int(height / (width / img_size))))
153
- all_frames.append(np.uint8(img))
154
  frame_i += 1
155
  cap.release()
156
 
157
  # Get output layer
158
  output_layer_period_length = compiled_model_ir.output(0)
159
  output_layer_periodicity = compiled_model_ir.output(1)
 
 
160
  length = len(all_frames)
161
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
162
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
 
163
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
 
164
  for _ in range(seq_len + stride_length): # pad full sequence
165
  all_frames.append(all_frames[-1])
166
  batch_list = []
167
  idx_list = []
168
- print(length, stride_length, stride_pad)
169
  for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
170
  batch = all_frames[i:i + seq_len]
171
  Xlist = []
172
  for img in batch:
173
  transforms_list = []
174
- if center_crop:
175
- #transforms_list.append(SquarePad())
176
- transforms_list.append(transforms.CenterCrop((img_size, img_size)))
177
- else:
178
- transforms_list.append(transforms.Resize((img_size, img_size)))
 
 
 
 
179
 
180
 
181
  transforms_list += [
182
  transforms.ToTensor()]
183
  #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
184
  preprocess = transforms.Compose(transforms_list)
185
- frameTensor = preprocess(Image.fromarray(img)).unsqueeze(0)
186
  Xlist.append(frameTensor)
187
 
188
  if len(Xlist) < seq_len:
@@ -198,12 +134,16 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
198
  result = compiled_model_ir(batch_X)
199
  y1pred = result[output_layer_period_length]
200
  y2pred = result[output_layer_periodicity]
201
- for y1, y2, idx in zip(y1pred, y2pred, idx_list):
 
202
  periodLength = y1.squeeze()
203
  periodicity = y2.squeeze()
 
204
  period_lengths[idx:idx+seq_len] += periodLength
205
  periodicities[idx:idx+seq_len] += periodicity
 
206
  period_length_overlaps[idx:idx+seq_len] += 1
 
207
  batch_list = []
208
  idx_list = []
209
  if len(batch_list) != 0: # still some leftover frames
@@ -214,15 +154,23 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
214
  result = compiled_model_ir(batch_X)
215
  y1pred = result[output_layer_period_length]
216
  y2pred = result[output_layer_periodicity]
217
- for y1, y2, idx in zip(y1pred, y2pred, idx_list):
 
218
  periodLength = y1.squeeze()
219
  periodicity = y2.squeeze()
 
220
  period_lengths[idx:idx+seq_len] += periodLength
221
  periodicities[idx:idx+seq_len] += periodicity
 
222
  period_length_overlaps[idx:idx+seq_len] += 1
 
223
 
224
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
225
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
 
 
 
 
226
 
227
  if median_pred_filter:
228
  periodicity = medfilt(periodicity, 5)
@@ -252,7 +200,9 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
252
  return f"{count_pred:.2f}"
253
  else:
254
  return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
255
- np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', '')
 
 
256
 
257
 
258
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.05), 0, 8)
@@ -305,8 +255,17 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
305
  histnorm='percent',
306
  title="Distribution of jumping speed (jumps-per-second)",
307
  range_x=[np.min(jumps_per_second[jumps_per_second > 0]) - 0.5, np.max(jumps_per_second) + 0.5])
 
 
 
 
 
 
 
 
 
308
 
309
- return count_msg, fig, hist, periodLength
310
 
311
 
312
  DESCRIPTION = '# NextJump'
@@ -318,10 +277,10 @@ with gr.Blocks() as demo:
318
  gr.Markdown(DESCRIPTION)
319
  with gr.Column():
320
  with gr.Row():
321
- in_video = gr.Video(label="Input Video", elem_id='input-video', format='mp4', width=400, scale=2)
322
 
323
  with gr.Row():
324
- run_button = gr.Button(label="Run", elem_id='run-button', style=dict(full_width=False), scale=1)
325
  api_dummy_button = gr.Button(label="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
326
  count_only = gr.Checkbox(label="Count Only", visible=False)
327
  api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)
@@ -334,8 +293,13 @@ with gr.Blocks() as demo:
334
  periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
335
  #with gr.Column(min_width=480):
336
  #out_video = gr.PlayableVideo(label="Output Video", elem_id='output-video', format='mp4')
337
- out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
338
- out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
 
 
 
 
 
339
 
340
  with gr.Accordion(label="Instructions and more information", open=False):
341
  instructions = "## Instructions:"
@@ -362,10 +326,10 @@ with gr.Blocks() as demo:
362
  [b, False, True, -1, True, 1.0, 0.95],
363
  ],
364
  inputs=[in_video],
365
- outputs=[out_text, out_plot, out_hist],
366
  fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
367
 
368
- run_button.click(demo_inference, [in_video], outputs=[out_text, out_plot, out_hist])
369
  api_inference = partial(inference, api_call=True)
370
  api_dummy_button.click(api_inference, [in_video, count_only, api_token], outputs=[period_length], api_name='inference')
371
 
 
10
  import matplotlib
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
 
13
  from scipy.signal import medfilt
14
  from functools import partial
15
  from passlib.hash import pbkdf2_sha256
 
38
  a = os.path.join(os.path.dirname(__file__), "files", "dylan.mp4")
39
  b = os.path.join(os.path.dirname(__file__), "files", "train14.mp4")
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
class SquarePad:
    """Pad a PIL image with black borders so it becomes exactly square.

    Adapted from:
    https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
    """

    def __call__(self, image):
        """Return *image* padded to max(w, h) x max(w, h).

        Args:
            image: a PIL Image (must expose ``.size`` as ``(width, height)``).

        Returns:
            The padded image, same type as the input.
        """
        w, h = image.size
        max_wh = max(w, h)
        hp = (max_wh - w) // 2
        vp = (max_wh - h) // 2
        # Put the remainder on the right/bottom so an odd size difference
        # still yields an exactly square output.  The original
        # ``(hp, vp, hp, vp)`` lost one pixel whenever ``max_wh - w`` (or
        # ``max_wh - h``) was odd, producing a non-square result.
        # Tuple order for torchvision's pad is (left, top, right, bottom).
        padding = (hp, vp, max_wh - w - hp, max_wh - h - vp)
        # NOTE(review): ``F`` is presumably ``torchvision.transforms.functional``
        # imported at module level (not visible in this chunk) — confirm.
        return F.pad(image, padding, 0, 'constant')
51
 
52
def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + e^(-x))``, applied elementwise.

    Works on scalars and numpy arrays; output lies in the open
    interval (0, 1).
    """
    denom = 1.0 + np.exp(-x)
    return 1.0 / denom
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
+ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, miss_threshold=0.8, median_pred_filter=True, center_crop=True, both_feet=True, api_call=False):
57
  print(x)
58
  #api = HfApi(token=os.environ['DATASET_SECRET'])
59
  #out_file = str(uuid.uuid1())
 
79
  break
80
  frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
81
  img = Image.fromarray(frame)
82
+ all_frames.append(img)
 
 
 
 
 
83
  frame_i += 1
84
  cap.release()
85
 
86
  # Get output layer
87
  output_layer_period_length = compiled_model_ir.output(0)
88
  output_layer_periodicity = compiled_model_ir.output(1)
89
+ output_layer_marks = compiled_model_ir.output(2)
90
+ output_layer_event_type = compiled_model_ir.output(3)
91
  length = len(all_frames)
92
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
93
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
94
+ event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 4))
95
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
96
+ event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 4))
97
  for _ in range(seq_len + stride_length): # pad full sequence
98
  all_frames.append(all_frames[-1])
99
  batch_list = []
100
  idx_list = []
 
101
  for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
102
  batch = all_frames[i:i + seq_len]
103
  Xlist = []
104
  for img in batch:
105
  transforms_list = []
106
+ # if center_crop:
107
+ # if width > height:
108
+ # transforms_list.append(transforms.Resize((int(width / (height / img_size)), img_size)))
109
+ # else:
110
+ # transforms_list.append(transforms.Resize((img_size, int(height / (width / img_size)))))
111
+ # transforms_list.append(transforms.CenterCrop((img_size, img_size)))
112
+ # else:
113
+ transforms_list.append(SquarePad())
114
+ transforms_list.append(transforms.Resize((img_size, img_size)))
115
 
116
 
117
  transforms_list += [
118
  transforms.ToTensor()]
119
  #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
120
  preprocess = transforms.Compose(transforms_list)
121
+ frameTensor = preprocess(img).unsqueeze(0)
122
  Xlist.append(frameTensor)
123
 
124
  if len(Xlist) < seq_len:
 
134
  result = compiled_model_ir(batch_X)
135
  y1pred = result[output_layer_period_length]
136
  y2pred = result[output_layer_periodicity]
137
+ y4pred = result[output_layer_event_type]
138
+ for y1, y2, y4, idx in zip(y1pred, y2pred, y4pred, idx_list):
139
  periodLength = y1.squeeze()
140
  periodicity = y2.squeeze()
141
+ event_type = y4.squeeze()
142
  period_lengths[idx:idx+seq_len] += periodLength
143
  periodicities[idx:idx+seq_len] += periodicity
144
+ event_type_logits[idx:idx+seq_len] += event_type
145
  period_length_overlaps[idx:idx+seq_len] += 1
146
+ event_type_logit_overlaps[idx:idx+seq_len] += 1
147
  batch_list = []
148
  idx_list = []
149
  if len(batch_list) != 0: # still some leftover frames
 
154
  result = compiled_model_ir(batch_X)
155
  y1pred = result[output_layer_period_length]
156
  y2pred = result[output_layer_periodicity]
157
+ y4pred = result[output_layer_event_type]
158
+ for y1, y2, y4, idx in zip(y1pred, y2pred, y4pred, idx_list):
159
  periodLength = y1.squeeze()
160
  periodicity = y2.squeeze()
161
+ event_type = y4.squeeze()
162
  period_lengths[idx:idx+seq_len] += periodLength
163
  periodicities[idx:idx+seq_len] += periodicity
164
+ event_type_logits[idx:idx+seq_len] += event_type
165
  period_length_overlaps[idx:idx+seq_len] += 1
166
+ event_type_logit_overlaps[idx:idx+seq_len] += 1
167
 
168
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
169
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
170
+ event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
171
+ event_type_logits = np.mean(event_type_logits, axis=0)
172
+ # softmax of event type logits
173
+ event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
174
 
175
  if median_pred_filter:
176
  periodicity = medfilt(periodicity, 5)
 
200
  return f"{count_pred:.2f}"
201
  else:
202
  return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
203
+ np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
204
+ f"{count_pred:.2f}", \
205
+ f"single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}"
206
 
207
 
208
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.05), 0, 8)
 
255
  histnorm='percent',
256
  title="Distribution of jumping speed (jumps-per-second)",
257
  range_x=[np.min(jumps_per_second[jumps_per_second > 0]) - 0.5, np.max(jumps_per_second) + 0.5])
258
+
259
+ # make a bar plot of the event type distribution
260
+
261
+ bar = px.bar(x=['single rope speed', 'double dutch', 'double unders', 'single bounce'],
262
+ y=event_type_probs,
263
+ template="plotly_dark",
264
+ title="Event Type Distribution",
265
+ labels={'x': 'event type', 'y': 'probability'},
266
+ range_y=[0, 1])
267
 
268
+ return count_msg, fig, hist, bar
269
 
270
 
271
  DESCRIPTION = '# NextJump'
 
277
  gr.Markdown(DESCRIPTION)
278
  with gr.Column():
279
  with gr.Row():
280
+ in_video = gr.Video(label="Input Video", elem_id='input-video', format='mp4', width=400, height=400)
281
 
282
  with gr.Row():
283
+ run_button = gr.Button(label="Run", elem_id='run-button', scale=1)
284
  api_dummy_button = gr.Button(label="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
285
  count_only = gr.Checkbox(label="Count Only", visible=False)
286
  api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)
 
293
  periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
294
  #with gr.Column(min_width=480):
295
  #out_video = gr.PlayableVideo(label="Output Video", elem_id='output-video', format='mp4')
296
+ with gr.Row():
297
+ out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
298
+ with gr.Row():
299
+ with gr.Column():
300
+ out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
301
+ with gr.Column():
302
+ out_event_type_dist = gr.Plot(label="Event Type Distribution", elem_id='output-event-type-dist')
303
 
304
  with gr.Accordion(label="Instructions and more information", open=False):
305
  instructions = "## Instructions:"
 
326
  [b, False, True, -1, True, 1.0, 0.95],
327
  ],
328
  inputs=[in_video],
329
+ outputs=[out_text, out_plot, out_hist, out_event_type_dist],
330
  fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
331
 
332
+ run_button.click(demo_inference, [in_video], outputs=[out_text, out_plot, out_hist, out_event_type_dist])
333
  api_inference = partial(inference, api_call=True)
334
  api_dummy_button.click(api_inference, [in_video, count_only, api_token], outputs=[period_length], api_name='inference')
335