dylanplummer commited on
Commit
60119fd
·
verified ·
1 Parent(s): 0e66fd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -34
app.py CHANGED
@@ -30,11 +30,14 @@ from hls_download import download_clips
30
  #plt.style.use('dark_background')
31
 
32
  LOCAL = False
33
- IMG_SIZE = 256
34
  CACHE_API_CALLS = True
35
  os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
36
 
37
- onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
 
 
 
38
  if torch.cuda.is_available():
39
  print("Using CUDA")
40
  providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
@@ -234,8 +237,7 @@ def detect_relay_beeps(video_path, event_start, relay_length=30, n_jumpers=4, be
234
 
235
 
236
  def inference(in_video, stream_url, start_time, end_time, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay,
237
- count_only_api, api_key,
238
- img_size=256, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
239
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
240
  api_call=False,
241
  progress=gr.Progress()):
@@ -269,7 +271,7 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
269
  seconds = length / fps
270
  all_frames = []
271
  frame_i = 0
272
- resize_amount = max((img_size + 64) / frame_width, (img_size + 64) / frame_height)
273
  while cap.isOpened():
274
  frame_i += 1
275
 
@@ -286,18 +288,21 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
286
  frame = cv2.resize(frame, (0, 0), fx=resize_amount, fy=resize_amount, interpolation=cv2.INTER_CUBIC)
287
  frame_center_x = frame.shape[1] // 2
288
  frame_center_y = frame.shape[0] // 2
289
- crop_x = frame_center_x - img_size // 2
290
- crop_y = frame_center_y - img_size // 2
291
- frame = frame[crop_y:crop_y+img_size, crop_x:crop_x+img_size]
292
  all_frames.append(frame)
293
 
294
  cap.release()
295
 
296
  length = len(all_frames)
297
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
 
298
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
299
  full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
300
  event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 7))
 
 
301
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
302
  event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
303
  for _ in range(seq_len + stride_length): # pad full sequence
@@ -309,7 +314,7 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
309
  for i in range(0, length + stride_length - stride_pad, stride_length):
310
  batch = all_frames[i:i + seq_len]
311
  Xlist = []
312
- preprocess_tasks = [(idx, executor.submit(preprocess_image, img, img_size)) for idx, img in enumerate(batch)]
313
  for idx, future in sorted(preprocess_tasks, key=lambda x: x[0]):
314
  Xlist.append(future.result())
315
 
@@ -342,23 +347,35 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
342
  y2_out = outputs[1]
343
  y3_out = outputs[2]
344
  y4_out = outputs[3]
345
- for y1, y2, y3, y4, idx in zip(y1_out, y2_out, y3_out, y4_out, idx_list):
 
 
346
  periodLength = y1.squeeze()
347
  periodicity = y2.squeeze()
348
  marks = y3.squeeze()
349
  event_type = y4.squeeze()
350
- period_lengths[idx:idx+seq_len] += periodLength
 
 
 
351
  periodicities[idx:idx+seq_len] += periodicity
352
  full_marks[idx:idx+seq_len] += marks
353
  event_type_logits[idx:idx+seq_len] += event_type
 
 
354
  period_length_overlaps[idx:idx+seq_len] += 1
355
  event_type_logit_overlaps[idx:idx+seq_len] += 1
356
  del y1_out, y2_out, y3_out, y4_out # free up memory
357
 
358
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
 
359
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
360
  full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
361
  per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
 
 
 
 
362
  event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
363
  # softmax of event type logits
364
  event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
@@ -469,33 +486,44 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
469
 
470
 
471
 
472
- fig, axs = plt.subplots(4, 1, figsize=(12, 10)) # Added a plot for count
473
 
474
- # Ensure data exists before plotting
475
- axs[0].plot(periodLength)
476
- axs[0].set_title(f"Stream 0 - Period Length")
 
 
477
 
478
- axs[1].plot(periodicity)
479
- axs[1].set_title("Stream 0 - Periodicity")
480
- axs[1].set_ylim(0, 1)
481
- axs[1].axhline(miss_threshold, color='r', linestyle=':', label=f'Miss Thresh ({miss_threshold})')
482
 
483
 
484
- axs[2].plot(full_marks, label='Raw Marks')
485
- marks_peaks_vis, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
486
- axs[2].plot(marks_peaks_vis, np.array(full_marks)[marks_peaks_vis], "x", label='Detected Peaks')
487
- axs[2].set_title("Stream 0 - Marks")
488
- axs[2].set_ylim(0, 1)
489
- axs[2].axhline(marks_threshold, color='r', linestyle=':', label=f'Mark Thresh ({marks_threshold})')
490
 
 
 
 
 
 
 
 
491
 
492
- axs[3].plot(count)
493
- axs[3].set_title("Stream 0 - Calculated Count")
494
 
495
- plt.tight_layout()
 
496
 
497
- plt.savefig('plot.png')
498
- plt.close()
 
 
 
499
 
500
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
501
  jumping_speed = np.copy(jumps_per_second)
@@ -508,6 +536,8 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
508
  'jumping speed': jumping_speed,
509
  'jumps per second': jumps_per_second,
510
  'periodicity': periodicity,
 
 
511
  'miss': misses,
512
  'frame_type': frame_type,
513
  'event_type': per_frame_event_types,
@@ -569,6 +599,74 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
569
  title='event type'
570
  ))
571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
  hist = px.histogram(df,
573
  x="jumps per second",
574
  template="plotly_dark",
@@ -589,9 +687,9 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
589
  except FileNotFoundError:
590
  pass
591
 
592
- return in_video, count_msg, fig, hist, bar
593
 
594
-
595
  with gr.Blocks() as demo:
596
  with gr.Row():
597
  in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
@@ -628,6 +726,11 @@ with gr.Blocks() as demo:
628
  periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
629
  with gr.Row():
630
  out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
 
 
 
 
 
631
  with gr.Row():
632
  with gr.Column():
633
  out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
@@ -638,7 +741,7 @@ with gr.Blocks() as demo:
638
  demo_inference = partial(inference, count_only_api=False, api_key=None)
639
 
640
  run_button.click(demo_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
641
- outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
642
  api_inference = partial(inference, api_call=True)
643
  api_dummy_button.click(api_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
644
  outputs=[period_length], api_name='inference')
@@ -650,7 +753,7 @@ with gr.Blocks() as demo:
650
  ]
651
  gr.Examples(examples,
652
  inputs=[in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
653
- outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist],
654
  fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
655
 
656
 
 
30
  #plt.style.use('dark_background')
31
 
32
  LOCAL = False
33
+ IMG_SIZE = 192
34
  CACHE_API_CALLS = True
35
  os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
36
 
37
+ onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename="nextjump_192.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
38
+
39
+ #onnx_file = 'nextjump.onnx'
40
+
41
  if torch.cuda.is_available():
42
  print("Using CUDA")
43
  providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
 
237
 
238
 
239
  def inference(in_video, stream_url, start_time, end_time, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay,
240
+ count_only_api, api_key, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
 
241
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
242
  api_call=False,
243
  progress=gr.Progress()):
 
271
  seconds = length / fps
272
  all_frames = []
273
  frame_i = 0
274
+ resize_amount = max((IMG_SIZE + 64) / frame_width, (IMG_SIZE + 64) / frame_height)
275
  while cap.isOpened():
276
  frame_i += 1
277
 
 
288
  frame = cv2.resize(frame, (0, 0), fx=resize_amount, fy=resize_amount, interpolation=cv2.INTER_CUBIC)
289
  frame_center_x = frame.shape[1] // 2
290
  frame_center_y = frame.shape[0] // 2
291
+ crop_x = frame_center_x - IMG_SIZE // 2
292
+ crop_y = frame_center_y - IMG_SIZE // 2
293
+ frame = frame[crop_y:crop_y+IMG_SIZE, crop_x:crop_x+IMG_SIZE]
294
  all_frames.append(frame)
295
 
296
  cap.release()
297
 
298
  length = len(all_frames)
299
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
300
+ period_lengths_rope = np.zeros(len(all_frames) + seq_len + stride_length)
301
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
302
  full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
303
  event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 7))
304
+ phase_sin = np.zeros(len(all_frames) + seq_len + stride_length)
305
+ phase_cos = np.zeros(len(all_frames) + seq_len + stride_length)
306
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
307
  event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
308
  for _ in range(seq_len + stride_length): # pad full sequence
 
314
  for i in range(0, length + stride_length - stride_pad, stride_length):
315
  batch = all_frames[i:i + seq_len]
316
  Xlist = []
317
+ preprocess_tasks = [(idx, executor.submit(preprocess_image, img, IMG_SIZE)) for idx, img in enumerate(batch)]
318
  for idx, future in sorted(preprocess_tasks, key=lambda x: x[0]):
319
  Xlist.append(future.result())
320
 
 
347
  y2_out = outputs[1]
348
  y3_out = outputs[2]
349
  y4_out = outputs[3]
350
+ y5_out = outputs[4]
351
+ y6_out = outputs[5]
352
+ for y1, y2, y3, y4, y5, y6, idx in zip(y1_out, y2_out, y3_out, y4_out, y5_out, y6_out, idx_list):
353
  periodLength = y1.squeeze()
354
  periodicity = y2.squeeze()
355
  marks = y3.squeeze()
356
  event_type = y4.squeeze()
357
+ foot_type = y5.squeeze()
358
+ phase = y6.squeeze()
359
+ period_lengths[idx:idx+seq_len] += periodLength[:, 0]
360
+ period_lengths_rope[idx:idx+seq_len] += periodLength[:, 1]
361
  periodicities[idx:idx+seq_len] += periodicity
362
  full_marks[idx:idx+seq_len] += marks
363
  event_type_logits[idx:idx+seq_len] += event_type
364
+ phase_sin[idx:idx+seq_len] += phase[:, 1]
365
+ phase_cos[idx:idx+seq_len] += phase[:, 0]
366
  period_length_overlaps[idx:idx+seq_len] += 1
367
  event_type_logit_overlaps[idx:idx+seq_len] += 1
368
  del y1_out, y2_out, y3_out, y4_out # free up memory
369
 
370
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
371
+ periodLength_rope = np.divide(period_lengths_rope, period_length_overlaps, where=period_length_overlaps!=0)[:length]
372
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
373
  full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
374
  per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
375
+ phase_sin = np.divide(phase_sin, period_length_overlaps, where=period_length_overlaps!=0)[:length]
376
+ # negate sin to make the bottom of the plot the start of the jump
377
+ phase_sin = -phase_sin
378
+ phase_cos = np.divide(phase_cos, period_length_overlaps, where=period_length_overlaps!=0)[:length]
379
  event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
380
  # softmax of event type logits
381
  event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
 
486
 
487
 
488
 
489
+ # fig, axs = plt.subplots(5, 1, figsize=(14, 10)) # Added a plot for count
490
 
491
+ # # Ensure data exists before plotting
492
+ # axs[0].plot(periodLength, label='Period Length')
493
+ # axs[0].plot(periodLength_rope, label='Period Length (Rope)')
494
+ # axs[0].set_title(f"Stream 0 - Period Length")
495
+ # axs[0].legend()
496
 
497
+ # axs[1].plot(periodicity)
498
+ # axs[1].set_title("Stream 0 - Periodicity")
499
+ # axs[1].set_ylim(0, 1)
500
+ # axs[1].axhline(miss_threshold, color='r', linestyle=':', label=f'Miss Thresh ({miss_threshold})')
501
 
502
 
503
+ # axs[2].plot(full_marks, label='Raw Marks')
504
+ # marks_peaks_vis, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
505
+ # axs[2].plot(marks_peaks_vis, np.array(full_marks)[marks_peaks_vis], "x", label='Detected Peaks')
506
+ # axs[2].set_title("Stream 0 - Marks")
507
+ # axs[2].set_ylim(0, 1)
508
+ # axs[2].axhline(marks_threshold, color='r', linestyle=':', label=f'Mark Thresh ({marks_threshold})')
509
 
510
+ # # plot phase
511
+ # axs[3].plot(phase_sin, label='Phase Sin')
512
+ # axs[3].plot(phase_cos, label='Phase Cos')
513
+ # axs[3].set_title("Stream 0 - Phase")
514
+ # axs[3].set_ylim(-1, 1)
515
+ # axs[3].axhline(0, color='r', linestyle=':', label='Zero Line')
516
+ # axs[3].legend()
517
 
 
 
518
 
519
+ # axs[4].plot(count)
520
+ # axs[4].set_title("Stream 0 - Calculated Count")
521
 
522
+ # plt.tight_layout()
523
+
524
+ # plt.savefig('plot.png')
525
+ # plt.close()
526
+
527
 
528
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
529
  jumping_speed = np.copy(jumps_per_second)
 
536
  'jumping speed': jumping_speed,
537
  'jumps per second': jumps_per_second,
538
  'periodicity': periodicity,
539
+ 'phase sin': phase_sin,
540
+ 'phase cos': phase_cos,
541
  'miss': misses,
542
  'frame_type': frame_type,
543
  'event_type': per_frame_event_types,
 
599
  title='event type'
600
  ))
601
 
602
+
603
+ # -pi/2 phase offset to make the bottom of the plot the start of the jump
604
+ # phase_sin = np.sin(np.arctan2(phase_sin, phase_cos) - np.pi / 2)
605
+ # phase_cos = np.cos(np.arctan2(phase_sin, phase_cos) - np.pi / 2)
606
+
607
+ # plot phase spiral using plotly
608
+ fig_phase_spiral = px.scatter(x=phase_cos, y=phase_sin,
609
+ color=jumps_per_second,
610
+ color_continuous_scale='plasma',
611
+ title="Phase Spiral (speed)",
612
+ template="plotly_dark")
613
+ fig_phase_spiral.update_traces(marker=dict(size=4, opacity=0.5))
614
+ fig_phase_spiral.update_layout(
615
+ xaxis_title="Phase Cos",
616
+ yaxis_title="Phase Sin",
617
+ xaxis=dict(range=[-1, 1]),
618
+ yaxis=dict(range=[-1, 1]),
619
+ showlegend=False,
620
+ )
621
+ # label colorbar as time
622
+ fig_phase_spiral.update_coloraxes(colorbar=dict(
623
+ title="Jumps per second"))
624
+ # make axes equal
625
+ fig_phase_spiral.update_layout(
626
+ xaxis=dict(scaleanchor="y"),
627
+ yaxis=dict(constrain="domain"),
628
+ )
629
+ # overlay line plot of phase sin and cos
630
+ fig_phase_spiral.add_traces(px.line(x=phase_cos, y=phase_sin).data)
631
+ fig_phase_spiral.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))
632
+
633
+ # plot phase consistency (sin^2 + cos^2 = 1) as a line plot
634
+ # phase_consistency = phase_sin**2 + phase_cos**2
635
+ # #phase_consistency = medfilt(phase_consistency, 5)
636
+ # fig_phase = px.line(x=np.linspace(0, 1, len(phase_sin)), y=phase_consistency,
637
+ # title="Phase Consistency (sin^2 + cos^2)",
638
+ # labels={'x': 'Frame', 'y': 'Phase Consistency'},
639
+ # template="plotly_dark")
640
+
641
+ # plot phase spiral colored by mark_preds
642
+ fig_phase_spiral_marks = px.scatter(x=phase_cos, y=phase_sin,
643
+ color=full_marks,
644
+ color_continuous_scale='Jet',
645
+ title="Phase Spiral (marks)",
646
+ template="plotly_dark")
647
+ fig_phase_spiral_marks.update_traces(marker=dict(size=4, opacity=0.5))
648
+ fig_phase_spiral_marks.update_layout(
649
+ xaxis_title="Phase Cos",
650
+ yaxis_title="Phase Sin",
651
+ xaxis=dict(range=[-1, 1]),
652
+ yaxis=dict(range=[-1, 1]),
653
+ showlegend=False,
654
+ )
655
+ # label colorbar as time
656
+ fig_phase_spiral_marks.update_coloraxes(colorbar=dict(
657
+ title="Marks"))
658
+ # make axes equal
659
+ fig_phase_spiral_marks.update_layout(
660
+ xaxis=dict(scaleanchor="y"),
661
+ yaxis=dict(constrain="domain"),
662
+ )
663
+ # overlay line plot of phase sin and cos
664
+ fig_phase_spiral_marks.add_traces(px.line(x=phase_cos, y=phase_sin).data)
665
+ fig_phase_spiral_marks.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))
666
+
667
+
668
+
669
+
670
  hist = px.histogram(df,
671
  x="jumps per second",
672
  template="plotly_dark",
 
687
  except FileNotFoundError:
688
  pass
689
 
690
+ return in_video, count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist, bar
691
 
692
+ #css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
693
  with gr.Blocks() as demo:
694
  with gr.Row():
695
  in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
 
726
  periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
727
  with gr.Row():
728
  out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
729
+ with gr.Row():
730
+ with gr.Column():
731
+ out_phase_spiral = gr.Plot(label="Phase Spiral", elem_id='phase-spiral')
732
+ with gr.Column():
733
+ out_phase = gr.Plot(label="Phase Sin/Cos", elem_id='phase-spiral-marks')
734
  with gr.Row():
735
  with gr.Column():
736
  out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
 
741
  demo_inference = partial(inference, count_only_api=False, api_key=None)
742
 
743
  run_button.click(demo_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
744
+ outputs=[out_video, out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist])
745
  api_inference = partial(inference, api_call=True)
746
  api_dummy_button.click(api_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
747
  outputs=[period_length], api_name='inference')
 
753
  ]
754
  gr.Examples(examples,
755
  inputs=[in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
756
+ outputs=[out_video, out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist],
757
  fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
758
 
759