Commit ·
c22a4b8
1
Parent(s): 4ca2ffd
update to mobilenetv3
Browse files
app.py
CHANGED
|
@@ -25,7 +25,7 @@ from hls_download import download_clips
|
|
| 25 |
|
| 26 |
plt.style.use('dark_background')
|
| 27 |
|
| 28 |
-
onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='
|
| 29 |
#onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_fp16.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
|
| 30 |
# model_xml = hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.xml', repo_type='model', token=os.environ['DATASET_SECRET'])
|
| 31 |
# hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.mapping', repo_type='model', token=os.environ['DATASET_SECRET'])
|
|
@@ -46,7 +46,7 @@ else:
|
|
| 46 |
ort_sess = ort.InferenceSession(onnx_file)
|
| 47 |
|
| 48 |
print('Warmup...')
|
| 49 |
-
dummy_input = torch.randn(4, 64, 3,
|
| 50 |
ort_sess.run(None, {'video': dummy_input.numpy()})
|
| 51 |
print('Done!')
|
| 52 |
|
|
@@ -73,7 +73,7 @@ def create_transform(img_size):
|
|
| 73 |
|
| 74 |
|
| 75 |
def inference(stream_url, start_time, end_time, count_only_api, api_key,
|
| 76 |
-
img_size=
|
| 77 |
miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
|
| 78 |
api_call=False,
|
| 79 |
progress=gr.Progress()):
|
|
@@ -327,48 +327,47 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
|
|
| 327 |
range_y=[0, 1])
|
| 328 |
|
| 329 |
return x, count_msg, fig, hist, bar
|
| 330 |
-
|
| 331 |
|
| 332 |
-
with gr.Blocks() as demo:
|
| 333 |
-
# in_video = gr.PlayableVideo(label='Input Video', elem_id='input-video', format='mp4',
|
| 334 |
-
# width=400, height=400, interactive=True, container=True,
|
| 335 |
-
# max_length=150)
|
| 336 |
-
with gr.Row():
|
| 337 |
-
in_stream_url = gr.Textbox(label='Stream URL', elem_id='stream-url', visible=True)
|
| 338 |
-
with gr.Column():
|
| 339 |
-
in_stream_start = gr.Textbox(label='Start Time', elem_id='stream-start', visible=True)
|
| 340 |
-
with gr.Column():
|
| 341 |
-
in_stream_end = gr.Textbox(label='End Time', elem_id='stream-end', visible=True)
|
| 342 |
-
with gr.Column(min_width=480):
|
| 343 |
-
out_video = gr.PlayableVideo(label='Video Clip', elem_id='output-video', format='mp4', width=400, height=400)
|
| 344 |
-
|
| 345 |
-
with gr.Row():
|
| 346 |
-
run_button = gr.Button(value='Run', elem_id='run-button', scale=1)
|
| 347 |
-
api_dummy_button = gr.Button(value='Run (No Viz)', elem_id='count-only', visible=False, scale=2)
|
| 348 |
-
count_only = gr.Checkbox(label='Count Only', visible=False)
|
| 349 |
-
api_token = gr.Textbox(label='API Key', elem_id='api-token', visible=False)
|
| 350 |
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
periodicity = gr.Textbox(label='Periodicity', elem_id='periodicity', visible=False)
|
| 357 |
-
with gr.Row():
|
| 358 |
-
out_plot = gr.Plot(label='Jumping Speed', elem_id='output-plot')
|
| 359 |
with gr.Row():
|
|
|
|
| 360 |
with gr.Column():
|
| 361 |
-
|
| 362 |
with gr.Column():
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
-
if __name__ == '__main__':
|
| 374 |
demo.queue(api_open=True, max_size=15).launch(share=False)
|
|
|
|
| 25 |
|
| 26 |
plt.style.use('dark_background')
|
| 27 |
|
| 28 |
+
onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_mobilenetv3.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
|
| 29 |
#onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_fp16.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
|
| 30 |
# model_xml = hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.xml', repo_type='model', token=os.environ['DATASET_SECRET'])
|
| 31 |
# hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.mapping', repo_type='model', token=os.environ['DATASET_SECRET'])
|
|
|
|
| 46 |
ort_sess = ort.InferenceSession(onnx_file)
|
| 47 |
|
| 48 |
print('Warmup...')
|
| 49 |
+
dummy_input = torch.randn(4, 64, 3, 224, 224)
|
| 50 |
ort_sess.run(None, {'video': dummy_input.numpy()})
|
| 51 |
print('Done!')
|
| 52 |
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
def inference(stream_url, start_time, end_time, count_only_api, api_key,
|
| 76 |
+
img_size=224, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
|
| 77 |
miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
|
| 78 |
api_call=False,
|
| 79 |
progress=gr.Progress()):
|
|
|
|
| 327 |
range_y=[0, 1])
|
| 328 |
|
| 329 |
return x, count_msg, fig, hist, bar
|
|
|
|
| 330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
+
if __name__ == '__main__':
|
| 333 |
+
with gr.Blocks() as demo:
|
| 334 |
+
# in_video = gr.PlayableVideo(label='Input Video', elem_id='input-video', format='mp4',
|
| 335 |
+
# width=400, height=400, interactive=True, container=True,
|
| 336 |
+
# max_length=150)
|
|
|
|
|
|
|
|
|
|
| 337 |
with gr.Row():
|
| 338 |
+
in_stream_url = gr.Textbox(label='Stream URL', elem_id='stream-url', visible=True)
|
| 339 |
with gr.Column():
|
| 340 |
+
in_stream_start = gr.Textbox(label='Start Time', elem_id='stream-start', visible=True)
|
| 341 |
with gr.Column():
|
| 342 |
+
in_stream_end = gr.Textbox(label='End Time', elem_id='stream-end', visible=True)
|
| 343 |
+
with gr.Column(min_width=480):
|
| 344 |
+
out_video = gr.PlayableVideo(label='Video Clip', elem_id='output-video', format='mp4', width=400, height=400)
|
| 345 |
+
|
| 346 |
+
with gr.Row():
|
| 347 |
+
run_button = gr.Button(value='Run', elem_id='run-button', scale=1)
|
| 348 |
+
api_dummy_button = gr.Button(value='Run (No Viz)', elem_id='count-only', visible=False, scale=2)
|
| 349 |
+
count_only = gr.Checkbox(label='Count Only', visible=False)
|
| 350 |
+
api_token = gr.Textbox(label='API Key', elem_id='api-token', visible=False)
|
| 351 |
|
| 352 |
+
with gr.Column(elem_id='output-video-container'):
|
| 353 |
+
with gr.Row():
|
| 354 |
+
with gr.Column():
|
| 355 |
+
out_text = gr.Markdown(label='Predicted Count', elem_id='output-text')
|
| 356 |
+
period_length = gr.Textbox(label='Period Length', elem_id='period-length', visible=False)
|
| 357 |
+
periodicity = gr.Textbox(label='Periodicity', elem_id='periodicity', visible=False)
|
| 358 |
+
with gr.Row():
|
| 359 |
+
out_plot = gr.Plot(label='Jumping Speed', elem_id='output-plot')
|
| 360 |
+
with gr.Row():
|
| 361 |
+
with gr.Column():
|
| 362 |
+
out_hist = gr.Plot(label='Speed Histogram', elem_id='output-hist')
|
| 363 |
+
with gr.Column():
|
| 364 |
+
out_event_type_dist = gr.Plot(label='Event Type Distribution', elem_id='output-event-type-dist')
|
| 365 |
+
|
| 366 |
|
| 367 |
+
demo_inference = partial(inference, count_only_api=False, api_key=None)
|
| 368 |
+
|
| 369 |
+
run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
|
| 370 |
+
api_inference = partial(inference, api_call=True)
|
| 371 |
+
api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
|
| 372 |
|
|
|
|
| 373 |
demo.queue(api_open=True, max_size=15).launch(share=False)
|