Update app.py
Browse files
app.py
CHANGED
|
@@ -30,11 +30,14 @@ from hls_download import download_clips
|
|
| 30 |
#plt.style.use('dark_background')
|
| 31 |
|
| 32 |
LOCAL = False
|
| 33 |
-
IMG_SIZE =
|
| 34 |
CACHE_API_CALLS = True
|
| 35 |
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
|
| 36 |
|
| 37 |
-
onnx_file = hf_hub_download(repo_id="
|
|
|
|
|
|
|
|
|
|
| 38 |
if torch.cuda.is_available():
|
| 39 |
print("Using CUDA")
|
| 40 |
providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
|
|
@@ -234,8 +237,7 @@ def detect_relay_beeps(video_path, event_start, relay_length=30, n_jumpers=4, be
|
|
| 234 |
|
| 235 |
|
| 236 |
def inference(in_video, stream_url, start_time, end_time, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay,
|
| 237 |
-
count_only_api, api_key,
|
| 238 |
-
img_size=256, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
|
| 239 |
miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
|
| 240 |
api_call=False,
|
| 241 |
progress=gr.Progress()):
|
|
@@ -269,7 +271,7 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 269 |
seconds = length / fps
|
| 270 |
all_frames = []
|
| 271 |
frame_i = 0
|
| 272 |
-
resize_amount = max((
|
| 273 |
while cap.isOpened():
|
| 274 |
frame_i += 1
|
| 275 |
|
|
@@ -286,18 +288,21 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 286 |
frame = cv2.resize(frame, (0, 0), fx=resize_amount, fy=resize_amount, interpolation=cv2.INTER_CUBIC)
|
| 287 |
frame_center_x = frame.shape[1] // 2
|
| 288 |
frame_center_y = frame.shape[0] // 2
|
| 289 |
-
crop_x = frame_center_x -
|
| 290 |
-
crop_y = frame_center_y -
|
| 291 |
-
frame = frame[crop_y:crop_y+
|
| 292 |
all_frames.append(frame)
|
| 293 |
|
| 294 |
cap.release()
|
| 295 |
|
| 296 |
length = len(all_frames)
|
| 297 |
period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
|
|
|
|
| 298 |
periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 299 |
full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 300 |
event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 7))
|
|
|
|
|
|
|
| 301 |
period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 302 |
event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
|
| 303 |
for _ in range(seq_len + stride_length): # pad full sequence
|
|
@@ -309,7 +314,7 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 309 |
for i in range(0, length + stride_length - stride_pad, stride_length):
|
| 310 |
batch = all_frames[i:i + seq_len]
|
| 311 |
Xlist = []
|
| 312 |
-
preprocess_tasks = [(idx, executor.submit(preprocess_image, img,
|
| 313 |
for idx, future in sorted(preprocess_tasks, key=lambda x: x[0]):
|
| 314 |
Xlist.append(future.result())
|
| 315 |
|
|
@@ -342,23 +347,35 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 342 |
y2_out = outputs[1]
|
| 343 |
y3_out = outputs[2]
|
| 344 |
y4_out = outputs[3]
|
| 345 |
-
|
|
|
|
|
|
|
| 346 |
periodLength = y1.squeeze()
|
| 347 |
periodicity = y2.squeeze()
|
| 348 |
marks = y3.squeeze()
|
| 349 |
event_type = y4.squeeze()
|
| 350 |
-
|
|
|
|
|
|
|
|
|
|
| 351 |
periodicities[idx:idx+seq_len] += periodicity
|
| 352 |
full_marks[idx:idx+seq_len] += marks
|
| 353 |
event_type_logits[idx:idx+seq_len] += event_type
|
|
|
|
|
|
|
| 354 |
period_length_overlaps[idx:idx+seq_len] += 1
|
| 355 |
event_type_logit_overlaps[idx:idx+seq_len] += 1
|
| 356 |
del y1_out, y2_out, y3_out, y4_out # free up memory
|
| 357 |
|
| 358 |
periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
|
|
|
| 359 |
periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 360 |
full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 361 |
per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
|
| 363 |
# softmax of event type logits
|
| 364 |
event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
|
|
@@ -469,33 +486,44 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 469 |
|
| 470 |
|
| 471 |
|
| 472 |
-
fig, axs = plt.subplots(
|
| 473 |
|
| 474 |
-
# Ensure data exists before plotting
|
| 475 |
-
axs[0].plot(periodLength)
|
| 476 |
-
axs[0].
|
|
|
|
|
|
|
| 477 |
|
| 478 |
-
axs[1].plot(periodicity)
|
| 479 |
-
axs[1].set_title("Stream 0 - Periodicity")
|
| 480 |
-
axs[1].set_ylim(0, 1)
|
| 481 |
-
axs[1].axhline(miss_threshold, color='r', linestyle=':', label=f'Miss Thresh ({miss_threshold})')
|
| 482 |
|
| 483 |
|
| 484 |
-
axs[2].plot(full_marks, label='Raw Marks')
|
| 485 |
-
marks_peaks_vis, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
|
| 486 |
-
axs[2].plot(marks_peaks_vis, np.array(full_marks)[marks_peaks_vis], "x", label='Detected Peaks')
|
| 487 |
-
axs[2].set_title("Stream 0 - Marks")
|
| 488 |
-
axs[2].set_ylim(0, 1)
|
| 489 |
-
axs[2].axhline(marks_threshold, color='r', linestyle=':', label=f'Mark Thresh ({marks_threshold})')
|
| 490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
|
| 492 |
-
axs[3].plot(count)
|
| 493 |
-
axs[3].set_title("Stream 0 - Calculated Count")
|
| 494 |
|
| 495 |
-
|
|
|
|
| 496 |
|
| 497 |
-
plt.
|
| 498 |
-
|
|
|
|
|
|
|
|
|
|
| 499 |
|
| 500 |
jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
|
| 501 |
jumping_speed = np.copy(jumps_per_second)
|
|
@@ -508,6 +536,8 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 508 |
'jumping speed': jumping_speed,
|
| 509 |
'jumps per second': jumps_per_second,
|
| 510 |
'periodicity': periodicity,
|
|
|
|
|
|
|
| 511 |
'miss': misses,
|
| 512 |
'frame_type': frame_type,
|
| 513 |
'event_type': per_frame_event_types,
|
|
@@ -569,6 +599,74 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 569 |
title='event type'
|
| 570 |
))
|
| 571 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
hist = px.histogram(df,
|
| 573 |
x="jumps per second",
|
| 574 |
template="plotly_dark",
|
|
@@ -589,9 +687,9 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 589 |
except FileNotFoundError:
|
| 590 |
pass
|
| 591 |
|
| 592 |
-
return in_video, count_msg, fig, hist, bar
|
| 593 |
|
| 594 |
-
|
| 595 |
with gr.Blocks() as demo:
|
| 596 |
with gr.Row():
|
| 597 |
in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
|
|
@@ -628,6 +726,11 @@ with gr.Blocks() as demo:
|
|
| 628 |
periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
|
| 629 |
with gr.Row():
|
| 630 |
out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
with gr.Row():
|
| 632 |
with gr.Column():
|
| 633 |
out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
|
|
@@ -638,7 +741,7 @@ with gr.Blocks() as demo:
|
|
| 638 |
demo_inference = partial(inference, count_only_api=False, api_key=None)
|
| 639 |
|
| 640 |
run_button.click(demo_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
|
| 641 |
-
outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
|
| 642 |
api_inference = partial(inference, api_call=True)
|
| 643 |
api_dummy_button.click(api_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
|
| 644 |
outputs=[period_length], api_name='inference')
|
|
@@ -650,7 +753,7 @@ with gr.Blocks() as demo:
|
|
| 650 |
]
|
| 651 |
gr.Examples(examples,
|
| 652 |
inputs=[in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
|
| 653 |
-
outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist],
|
| 654 |
fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
|
| 655 |
|
| 656 |
|
|
|
|
| 30 |
#plt.style.use('dark_background')
|
| 31 |
|
| 32 |
LOCAL = False
|
| 33 |
+
IMG_SIZE = 192
|
| 34 |
CACHE_API_CALLS = True
|
| 35 |
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
|
| 36 |
|
| 37 |
+
onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename="nextjump_192.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
|
| 38 |
+
|
| 39 |
+
#onnx_file = 'nextjump.onnx'
|
| 40 |
+
|
| 41 |
if torch.cuda.is_available():
|
| 42 |
print("Using CUDA")
|
| 43 |
providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
|
|
|
|
| 237 |
|
| 238 |
|
| 239 |
def inference(in_video, stream_url, start_time, end_time, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay,
|
| 240 |
+
count_only_api, api_key, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
|
|
|
|
| 241 |
miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
|
| 242 |
api_call=False,
|
| 243 |
progress=gr.Progress()):
|
|
|
|
| 271 |
seconds = length / fps
|
| 272 |
all_frames = []
|
| 273 |
frame_i = 0
|
| 274 |
+
resize_amount = max((IMG_SIZE + 64) / frame_width, (IMG_SIZE + 64) / frame_height)
|
| 275 |
while cap.isOpened():
|
| 276 |
frame_i += 1
|
| 277 |
|
|
|
|
| 288 |
frame = cv2.resize(frame, (0, 0), fx=resize_amount, fy=resize_amount, interpolation=cv2.INTER_CUBIC)
|
| 289 |
frame_center_x = frame.shape[1] // 2
|
| 290 |
frame_center_y = frame.shape[0] // 2
|
| 291 |
+
crop_x = frame_center_x - IMG_SIZE // 2
|
| 292 |
+
crop_y = frame_center_y - IMG_SIZE // 2
|
| 293 |
+
frame = frame[crop_y:crop_y+IMG_SIZE, crop_x:crop_x+IMG_SIZE]
|
| 294 |
all_frames.append(frame)
|
| 295 |
|
| 296 |
cap.release()
|
| 297 |
|
| 298 |
length = len(all_frames)
|
| 299 |
period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 300 |
+
period_lengths_rope = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 301 |
periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 302 |
full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 303 |
event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 7))
|
| 304 |
+
phase_sin = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 305 |
+
phase_cos = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 306 |
period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
|
| 307 |
event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
|
| 308 |
for _ in range(seq_len + stride_length): # pad full sequence
|
|
|
|
| 314 |
for i in range(0, length + stride_length - stride_pad, stride_length):
|
| 315 |
batch = all_frames[i:i + seq_len]
|
| 316 |
Xlist = []
|
| 317 |
+
preprocess_tasks = [(idx, executor.submit(preprocess_image, img, IMG_SIZE)) for idx, img in enumerate(batch)]
|
| 318 |
for idx, future in sorted(preprocess_tasks, key=lambda x: x[0]):
|
| 319 |
Xlist.append(future.result())
|
| 320 |
|
|
|
|
| 347 |
y2_out = outputs[1]
|
| 348 |
y3_out = outputs[2]
|
| 349 |
y4_out = outputs[3]
|
| 350 |
+
y5_out = outputs[4]
|
| 351 |
+
y6_out = outputs[5]
|
| 352 |
+
for y1, y2, y3, y4, y5, y6, idx in zip(y1_out, y2_out, y3_out, y4_out, y5_out, y6_out, idx_list):
|
| 353 |
periodLength = y1.squeeze()
|
| 354 |
periodicity = y2.squeeze()
|
| 355 |
marks = y3.squeeze()
|
| 356 |
event_type = y4.squeeze()
|
| 357 |
+
foot_type = y5.squeeze()
|
| 358 |
+
phase = y6.squeeze()
|
| 359 |
+
period_lengths[idx:idx+seq_len] += periodLength[:, 0]
|
| 360 |
+
period_lengths_rope[idx:idx+seq_len] += periodLength[:, 1]
|
| 361 |
periodicities[idx:idx+seq_len] += periodicity
|
| 362 |
full_marks[idx:idx+seq_len] += marks
|
| 363 |
event_type_logits[idx:idx+seq_len] += event_type
|
| 364 |
+
phase_sin[idx:idx+seq_len] += phase[:, 1]
|
| 365 |
+
phase_cos[idx:idx+seq_len] += phase[:, 0]
|
| 366 |
period_length_overlaps[idx:idx+seq_len] += 1
|
| 367 |
event_type_logit_overlaps[idx:idx+seq_len] += 1
|
| 368 |
del y1_out, y2_out, y3_out, y4_out # free up memory
|
| 369 |
|
| 370 |
periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 371 |
+
periodLength_rope = np.divide(period_lengths_rope, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 372 |
periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 373 |
full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 374 |
per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
|
| 375 |
+
phase_sin = np.divide(phase_sin, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 376 |
+
# negate sin to make the bottom of the plot the start of the jump
|
| 377 |
+
phase_sin = -phase_sin
|
| 378 |
+
phase_cos = np.divide(phase_cos, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 379 |
event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
|
| 380 |
# softmax of event type logits
|
| 381 |
event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
|
|
|
|
| 486 |
|
| 487 |
|
| 488 |
|
| 489 |
+
# fig, axs = plt.subplots(5, 1, figsize=(14, 10)) # Added a plot for count
|
| 490 |
|
| 491 |
+
# # Ensure data exists before plotting
|
| 492 |
+
# axs[0].plot(periodLength, label='Period Length')
|
| 493 |
+
# axs[0].plot(periodLength_rope, label='Period Length (Rope)')
|
| 494 |
+
# axs[0].set_title(f"Stream 0 - Period Length")
|
| 495 |
+
# axs[0].legend()
|
| 496 |
|
| 497 |
+
# axs[1].plot(periodicity)
|
| 498 |
+
# axs[1].set_title("Stream 0 - Periodicity")
|
| 499 |
+
# axs[1].set_ylim(0, 1)
|
| 500 |
+
# axs[1].axhline(miss_threshold, color='r', linestyle=':', label=f'Miss Thresh ({miss_threshold})')
|
| 501 |
|
| 502 |
|
| 503 |
+
# axs[2].plot(full_marks, label='Raw Marks')
|
| 504 |
+
# marks_peaks_vis, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
|
| 505 |
+
# axs[2].plot(marks_peaks_vis, np.array(full_marks)[marks_peaks_vis], "x", label='Detected Peaks')
|
| 506 |
+
# axs[2].set_title("Stream 0 - Marks")
|
| 507 |
+
# axs[2].set_ylim(0, 1)
|
| 508 |
+
# axs[2].axhline(marks_threshold, color='r', linestyle=':', label=f'Mark Thresh ({marks_threshold})')
|
| 509 |
|
| 510 |
+
# # plot phase
|
| 511 |
+
# axs[3].plot(phase_sin, label='Phase Sin')
|
| 512 |
+
# axs[3].plot(phase_cos, label='Phase Cos')
|
| 513 |
+
# axs[3].set_title("Stream 0 - Phase")
|
| 514 |
+
# axs[3].set_ylim(-1, 1)
|
| 515 |
+
# axs[3].axhline(0, color='r', linestyle=':', label='Zero Line')
|
| 516 |
+
# axs[3].legend()
|
| 517 |
|
|
|
|
|
|
|
| 518 |
|
| 519 |
+
# axs[4].plot(count)
|
| 520 |
+
# axs[4].set_title("Stream 0 - Calculated Count")
|
| 521 |
|
| 522 |
+
# plt.tight_layout()
|
| 523 |
+
|
| 524 |
+
# plt.savefig('plot.png')
|
| 525 |
+
# plt.close()
|
| 526 |
+
|
| 527 |
|
| 528 |
jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
|
| 529 |
jumping_speed = np.copy(jumps_per_second)
|
|
|
|
| 536 |
'jumping speed': jumping_speed,
|
| 537 |
'jumps per second': jumps_per_second,
|
| 538 |
'periodicity': periodicity,
|
| 539 |
+
'phase sin': phase_sin,
|
| 540 |
+
'phase cos': phase_cos,
|
| 541 |
'miss': misses,
|
| 542 |
'frame_type': frame_type,
|
| 543 |
'event_type': per_frame_event_types,
|
|
|
|
| 599 |
title='event type'
|
| 600 |
))
|
| 601 |
|
| 602 |
+
|
| 603 |
+
# -pi/2 phase offset to make the bottom of the plot the start of the jump
|
| 604 |
+
# phase_sin = np.sin(np.arctan2(phase_sin, phase_cos) - np.pi / 2)
|
| 605 |
+
# phase_cos = np.cos(np.arctan2(phase_sin, phase_cos) - np.pi / 2)
|
| 606 |
+
|
| 607 |
+
# plot phase spiral using plotly
|
| 608 |
+
fig_phase_spiral = px.scatter(x=phase_cos, y=phase_sin,
|
| 609 |
+
color=jumps_per_second,
|
| 610 |
+
color_continuous_scale='plasma',
|
| 611 |
+
title="Phase Spiral (speed)",
|
| 612 |
+
template="plotly_dark")
|
| 613 |
+
fig_phase_spiral.update_traces(marker=dict(size=4, opacity=0.5))
|
| 614 |
+
fig_phase_spiral.update_layout(
|
| 615 |
+
xaxis_title="Phase Cos",
|
| 616 |
+
yaxis_title="Phase Sin",
|
| 617 |
+
xaxis=dict(range=[-1, 1]),
|
| 618 |
+
yaxis=dict(range=[-1, 1]),
|
| 619 |
+
showlegend=False,
|
| 620 |
+
)
|
| 621 |
+
# label colorbar as time
|
| 622 |
+
fig_phase_spiral.update_coloraxes(colorbar=dict(
|
| 623 |
+
title="Jumps per second"))
|
| 624 |
+
# make axes equal
|
| 625 |
+
fig_phase_spiral.update_layout(
|
| 626 |
+
xaxis=dict(scaleanchor="y"),
|
| 627 |
+
yaxis=dict(constrain="domain"),
|
| 628 |
+
)
|
| 629 |
+
# overlay line plot of phase sin and cos
|
| 630 |
+
fig_phase_spiral.add_traces(px.line(x=phase_cos, y=phase_sin).data)
|
| 631 |
+
fig_phase_spiral.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))
|
| 632 |
+
|
| 633 |
+
# plot phase consistency (sin^2 + cos^2 = 1) as a line plot
|
| 634 |
+
# phase_consistency = phase_sin**2 + phase_cos**2
|
| 635 |
+
# #phase_consistency = medfilt(phase_consistency, 5)
|
| 636 |
+
# fig_phase = px.line(x=np.linspace(0, 1, len(phase_sin)), y=phase_consistency,
|
| 637 |
+
# title="Phase Consistency (sin^2 + cos^2)",
|
| 638 |
+
# labels={'x': 'Frame', 'y': 'Phase Consistency'},
|
| 639 |
+
# template="plotly_dark")
|
| 640 |
+
|
| 641 |
+
# plot phase spiral colored by mark_preds
|
| 642 |
+
fig_phase_spiral_marks = px.scatter(x=phase_cos, y=phase_sin,
|
| 643 |
+
color=full_marks,
|
| 644 |
+
color_continuous_scale='Jet',
|
| 645 |
+
title="Phase Spiral (marks)",
|
| 646 |
+
template="plotly_dark")
|
| 647 |
+
fig_phase_spiral_marks.update_traces(marker=dict(size=4, opacity=0.5))
|
| 648 |
+
fig_phase_spiral_marks.update_layout(
|
| 649 |
+
xaxis_title="Phase Cos",
|
| 650 |
+
yaxis_title="Phase Sin",
|
| 651 |
+
xaxis=dict(range=[-1, 1]),
|
| 652 |
+
yaxis=dict(range=[-1, 1]),
|
| 653 |
+
showlegend=False,
|
| 654 |
+
)
|
| 655 |
+
# label colorbar as time
|
| 656 |
+
fig_phase_spiral_marks.update_coloraxes(colorbar=dict(
|
| 657 |
+
title="Marks"))
|
| 658 |
+
# make axes equal
|
| 659 |
+
fig_phase_spiral_marks.update_layout(
|
| 660 |
+
xaxis=dict(scaleanchor="y"),
|
| 661 |
+
yaxis=dict(constrain="domain"),
|
| 662 |
+
)
|
| 663 |
+
# overlay line plot of phase sin and cos
|
| 664 |
+
fig_phase_spiral_marks.add_traces(px.line(x=phase_cos, y=phase_sin).data)
|
| 665 |
+
fig_phase_spiral_marks.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
|
| 670 |
hist = px.histogram(df,
|
| 671 |
x="jumps per second",
|
| 672 |
template="plotly_dark",
|
|
|
|
| 687 |
except FileNotFoundError:
|
| 688 |
pass
|
| 689 |
|
| 690 |
+
return in_video, count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist, bar
|
| 691 |
|
| 692 |
+
#css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
|
| 693 |
with gr.Blocks() as demo:
|
| 694 |
with gr.Row():
|
| 695 |
in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
|
|
|
|
| 726 |
periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
|
| 727 |
with gr.Row():
|
| 728 |
out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
|
| 729 |
+
with gr.Row():
|
| 730 |
+
with gr.Column():
|
| 731 |
+
out_phase_spiral = gr.Plot(label="Phase Spiral", elem_id='phase-spiral')
|
| 732 |
+
with gr.Column():
|
| 733 |
+
out_phase = gr.Plot(label="Phase Sin/Cos", elem_id='phase-spiral-marks')
|
| 734 |
with gr.Row():
|
| 735 |
with gr.Column():
|
| 736 |
out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
|
|
|
|
| 741 |
demo_inference = partial(inference, count_only_api=False, api_key=None)
|
| 742 |
|
| 743 |
run_button.click(demo_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
|
| 744 |
+
outputs=[out_video, out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist])
|
| 745 |
api_inference = partial(inference, api_call=True)
|
| 746 |
api_dummy_button.click(api_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
|
| 747 |
outputs=[period_length], api_name='inference')
|
|
|
|
| 753 |
]
|
| 754 |
gr.Examples(examples,
|
| 755 |
inputs=[in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
|
| 756 |
+
outputs=[out_video, out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist],
|
| 757 |
fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
|
| 758 |
|
| 759 |
|