dylanplummer committed on
Commit
c22a4b8
·
1 Parent(s): 4ca2ffd

update to mobilenetv3

Browse files
Files changed (1) hide show
  1. app.py +38 -39
app.py CHANGED
@@ -25,7 +25,7 @@ from hls_download import download_clips
25
 
26
  plt.style.use('dark_background')
27
 
28
- onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
29
  #onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_fp16.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
30
  # model_xml = hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.xml', repo_type='model', token=os.environ['DATASET_SECRET'])
31
  # hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.mapping', repo_type='model', token=os.environ['DATASET_SECRET'])
@@ -46,7 +46,7 @@ else:
46
  ort_sess = ort.InferenceSession(onnx_file)
47
 
48
  print('Warmup...')
49
- dummy_input = torch.randn(4, 64, 3, 288, 288)
50
  ort_sess.run(None, {'video': dummy_input.numpy()})
51
  print('Done!')
52
 
@@ -73,7 +73,7 @@ def create_transform(img_size):
73
 
74
 
75
  def inference(stream_url, start_time, end_time, count_only_api, api_key,
76
- img_size=288, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
77
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
78
  api_call=False,
79
  progress=gr.Progress()):
@@ -327,48 +327,47 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
327
  range_y=[0, 1])
328
 
329
  return x, count_msg, fig, hist, bar
330
-
331
 
332
- with gr.Blocks() as demo:
333
- # in_video = gr.PlayableVideo(label='Input Video', elem_id='input-video', format='mp4',
334
- # width=400, height=400, interactive=True, container=True,
335
- # max_length=150)
336
- with gr.Row():
337
- in_stream_url = gr.Textbox(label='Stream URL', elem_id='stream-url', visible=True)
338
- with gr.Column():
339
- in_stream_start = gr.Textbox(label='Start Time', elem_id='stream-start', visible=True)
340
- with gr.Column():
341
- in_stream_end = gr.Textbox(label='End Time', elem_id='stream-end', visible=True)
342
- with gr.Column(min_width=480):
343
- out_video = gr.PlayableVideo(label='Video Clip', elem_id='output-video', format='mp4', width=400, height=400)
344
-
345
- with gr.Row():
346
- run_button = gr.Button(value='Run', elem_id='run-button', scale=1)
347
- api_dummy_button = gr.Button(value='Run (No Viz)', elem_id='count-only', visible=False, scale=2)
348
- count_only = gr.Checkbox(label='Count Only', visible=False)
349
- api_token = gr.Textbox(label='API Key', elem_id='api-token', visible=False)
350
 
351
- with gr.Column(elem_id='output-video-container'):
352
- with gr.Row():
353
- with gr.Column():
354
- out_text = gr.Markdown(label='Predicted Count', elem_id='output-text')
355
- period_length = gr.Textbox(label='Period Length', elem_id='period-length', visible=False)
356
- periodicity = gr.Textbox(label='Periodicity', elem_id='periodicity', visible=False)
357
- with gr.Row():
358
- out_plot = gr.Plot(label='Jumping Speed', elem_id='output-plot')
359
  with gr.Row():
 
360
  with gr.Column():
361
- out_hist = gr.Plot(label='Speed Histogram', elem_id='output-hist')
362
  with gr.Column():
363
- out_event_type_dist = gr.Plot(label='Event Type Distribution', elem_id='output-event-type-dist')
364
-
 
 
 
 
 
 
 
365
 
366
- demo_inference = partial(inference, count_only_api=False, api_key=None)
367
-
368
- run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
369
- api_inference = partial(inference, api_call=True)
370
- api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
 
 
 
 
 
 
 
 
 
371
 
 
 
 
 
 
372
 
373
- if __name__ == '__main__':
374
  demo.queue(api_open=True, max_size=15).launch(share=False)
 
25
 
26
  plt.style.use('dark_background')
27
 
28
+ onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_mobilenetv3.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
29
  #onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_fp16.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
30
  # model_xml = hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.xml', repo_type='model', token=os.environ['DATASET_SECRET'])
31
  # hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.mapping', repo_type='model', token=os.environ['DATASET_SECRET'])
 
46
  ort_sess = ort.InferenceSession(onnx_file)
47
 
48
  print('Warmup...')
49
+ dummy_input = torch.randn(4, 64, 3, 224, 224)
50
  ort_sess.run(None, {'video': dummy_input.numpy()})
51
  print('Done!')
52
 
 
73
 
74
 
75
  def inference(stream_url, start_time, end_time, count_only_api, api_key,
76
+ img_size=224, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
77
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
78
  api_call=False,
79
  progress=gr.Progress()):
 
327
  range_y=[0, 1])
328
 
329
  return x, count_msg, fig, hist, bar
 
330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
+ if __name__ == '__main__':
333
+ with gr.Blocks() as demo:
334
+ # in_video = gr.PlayableVideo(label='Input Video', elem_id='input-video', format='mp4',
335
+ # width=400, height=400, interactive=True, container=True,
336
+ # max_length=150)
 
 
 
337
  with gr.Row():
338
+ in_stream_url = gr.Textbox(label='Stream URL', elem_id='stream-url', visible=True)
339
  with gr.Column():
340
+ in_stream_start = gr.Textbox(label='Start Time', elem_id='stream-start', visible=True)
341
  with gr.Column():
342
+ in_stream_end = gr.Textbox(label='End Time', elem_id='stream-end', visible=True)
343
+ with gr.Column(min_width=480):
344
+ out_video = gr.PlayableVideo(label='Video Clip', elem_id='output-video', format='mp4', width=400, height=400)
345
+
346
+ with gr.Row():
347
+ run_button = gr.Button(value='Run', elem_id='run-button', scale=1)
348
+ api_dummy_button = gr.Button(value='Run (No Viz)', elem_id='count-only', visible=False, scale=2)
349
+ count_only = gr.Checkbox(label='Count Only', visible=False)
350
+ api_token = gr.Textbox(label='API Key', elem_id='api-token', visible=False)
351
 
352
+ with gr.Column(elem_id='output-video-container'):
353
+ with gr.Row():
354
+ with gr.Column():
355
+ out_text = gr.Markdown(label='Predicted Count', elem_id='output-text')
356
+ period_length = gr.Textbox(label='Period Length', elem_id='period-length', visible=False)
357
+ periodicity = gr.Textbox(label='Periodicity', elem_id='periodicity', visible=False)
358
+ with gr.Row():
359
+ out_plot = gr.Plot(label='Jumping Speed', elem_id='output-plot')
360
+ with gr.Row():
361
+ with gr.Column():
362
+ out_hist = gr.Plot(label='Speed Histogram', elem_id='output-hist')
363
+ with gr.Column():
364
+ out_event_type_dist = gr.Plot(label='Event Type Distribution', elem_id='output-event-type-dist')
365
+
366
 
367
+ demo_inference = partial(inference, count_only_api=False, api_key=None)
368
+
369
+ run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
370
+ api_inference = partial(inference, api_call=True)
371
+ api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
372
 
 
373
  demo.queue(api_open=True, max_size=15).launch(share=False)