|
import os |
|
import io |
|
import pickle |
|
import cv2 |
|
import gradio as gr |
|
print(gr.__version__) |
|
from tempSegAndAllErrorsForAllFrames import getAllErrorsAndSegmentation |
|
from models.detectron2.platform_detector_setup import get_platform_detector |
|
from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation |
|
from models.detectron2.diver_detector_setup import get_diver_detector |
|
from models.pose_estimator.pose_estimator_model_setup import get_pose_model |
|
from models.detectron2.splash_detector_setup import get_splash_detector |
|
from scoring_functions import * |
|
from generate_reports import * |
|
from tempSegAndAllErrorsForAllFrames_newVids import getAllErrorsAndSegmentation_newVids, abstractSymbols |
|
|
|
from jinja2 import Environment, FileSystemLoader |
|
from PIL import Image, ImageDraw |
|
from io import BytesIO |
|
import base64 |
|
|
|
|
|
|
|
|
|
|
|
# HTML template handed to generate_report / generate_report_from_frames when
# rendering score reports.
template_path = 'report_template_tables.html'

# Per-dive state carried from the "Abstract Symbols" step to the
# "Generate Score Report" step; the *_calculated handlers rebind it and
# get_html_from_finedivingkey mutates it.
# NOTE(review): as a module-level global it is shared by every concurrent user.
dive_data = dict()
|
|
|
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that materializes pickled torch storages on the CPU.

    Lookups of the global ``torch.storage._load_from_bytes`` are redirected to
    ``torch.load(..., map_location='cpu')``; every other global resolves
    through the normal ``pickle.Unpickler.find_class``.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, remapping torch storage loading to CPU."""
        if module == 'torch.storage' and name == '_load_from_bytes':
            # Imported lazily: torch is not imported explicitly at the top of
            # this file (it could only arrive by accident via a star import),
            # and it is only needed when the pickle contains torch storages.
            import torch
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)
|
|
|
# Precomputed segmentation/error data for the bundled example dives, indexed by
# the keys get_key_from_videopath returns (see the *_precomputed handlers).
# CPU_Unpickler remaps any pickled torch storages to the CPU. The handle is
# opened in a ``with`` block so it is closed once loading finishes.
# NOTE(review): unpickling can execute arbitrary code; only load this trusted local file.
with open('./segmentation_error_data.pkl', 'rb') as _precomputed_file:
    dive_data_precomputed = CPU_Unpickler(_precomputed_file).load()
|
|
|
|
|
|
|
|
|
import sys |
|
import csv |
|
|
|
# Raise the csv module's per-field size cap as high as the platform allows.
# csv.field_size_limit stores the value in a C long, so sys.maxsize raises
# OverflowError where long is 32-bit (e.g. 64-bit Windows); halve until it fits.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 2
|
|
|
# FineDiving fine-grained annotations; get_html_from_finedivingkey indexes
# this by (folder, int) key and reads element [0] as the dive number.
with open('FineDiving/fine-grained_annotation_aqa.pkl', 'rb') as annotation_file:
    dive_annotation_data = pickle.load(annotation_file)
|
|
|
def extract_frames(video_path, frame_skip=1, size=(455, 256)):
    """Decode a video and return a subsampled list of resized frames.

    A frame is kept once ``frame_skip`` frames have been dropped since the
    previous kept frame. With the default ``frame_skip=1`` this keeps every
    other frame (0-based indices 1, 3, 5, ...); ``frame_skip=0`` keeps all.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.
        frame_skip: Number of frames dropped between consecutive kept frames.
        size: ``dsize`` passed to ``cv2.resize`` for each kept frame, i.e.
            ``(width, height)``.

    Returns:
        List of kept frames (as returned by ``cap.read()``) after resizing.

    Raises:
        gr.Error: If the video cannot be opened. This replaces ``exit()``,
            which raises SystemExit (not an ``Exception`` subclass) and so
            bypasses normal error reporting.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        print("Error: Couldn't open video file.")
        raise gr.Error("Error: Couldn't open video file.")

    frames = []
    dropped_since_keep = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if dropped_since_keep >= frame_skip:
                frames.append(cv2.resize(frame, size))
                dropped_since_keep = 0
            else:
                dropped_since_keep += 1
    finally:
        # Release the capture even if decoding or resizing raises.
        cap.release()

    print("frame_count", len(frames))
    return frames
|
|
|
def get_key_from_videopath(video):
    """Parse a FineDiving ``(folder, index)`` key out of a video file name.

    The base name is split on ``_`` and fields 1 and 2 are used, so a name
    like ``<prefix>_01_10.mp4`` yields ``('01', 10)``.
    NOTE(review): the leading prefix field is presumably the one Gradio adds
    when it re-encodes the example videos (e.g. ``muted_01_10.mp4``); confirm.

    Args:
        video: Path of the video file.

    Returns:
        ``(first_folder, int(second_folder))``, or ``None`` when the name does
        not follow that pattern or ``video`` is not a usable path.
    """
    try:
        # basename also handles the native separator on Windows.
        parts = os.path.basename(video).split('_')
        first_folder = parts[1]
        second_folder = parts[2].split('.')[0]
        return (first_folder, int(second_folder))
    except (AttributeError, IndexError, TypeError, ValueError):
        # Only parse failures mean "not a FineDiving name"; anything else propagates.
        return None
|
|
|
def get_abstracted_symbols_precomputed(video, key, progress=gr.Progress()):
    """Render the symbol-abstraction report for a dive with precomputed data.

    Args:
        video: Path of the selected video; only checked for ``None`` here.
        key: Lookup key into ``dive_data_precomputed`` (as returned by
            ``get_key_from_videopath``); also names the frame directory.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        HTML produced by ``generate_symbols_report_precomputed``.

    Raises:
        gr.Error: If no video is given or no precomputed data exists for ``key``.
    """
    progress(0, desc="Abstracting Symbols")
    if video is None:
        raise gr.Error("input a video!!")

    local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}/".format(key[0], key[1])

    # Read-only access to the module-level cache, so no ``global`` is needed.
    try:
        dive_data = dive_data_precomputed[key]
    except KeyError:
        raise gr.Error(
            "No precomputed data for dive {}/{}.".format(key[0], key[1])
        ) from None

    html_intermediate = generate_symbols_report_precomputed(
        "intermediate_steps.html", dive_data, local_directory, progress=progress)
    progress(0.95, desc="Abstracting Symbols")
    return html_intermediate
|
|
|
def get_abstracted_symbols_calculated(video, progress=gr.Progress()):
    """Run the detectors on an arbitrary video and render the symbol report.

    Side effect: rebinds the module-level ``dive_data`` to the result of
    ``abstractSymbols`` (with the extracted frames stored under ``'frames'``);
    ``get_score_report_calculated`` later reads it.

    NOTE(review): platform_detector / splash_detector / diver_detector /
    pose_model are not defined anywhere in this file as shown; confirm they
    are created (e.g. via the imported get_*_detector helpers) before this
    path can run.
    """
    global dive_data
    progress(0, desc="Abstracting Symbols")
    video_frames = extract_frames(video)
    dive_data = abstractSymbols(
        video_frames,
        progress=progress,
        platform_detector=platform_detector,
        splash_detector=splash_detector,
        diver_detector=diver_detector,
        pose_model=pose_model,
    )
    dive_data['frames'] = video_frames
    return generate_symbols_report("intermediate_steps.html", dive_data, video_frames)
|
|
|
def get_abstracted_symbols(video, progress=gr.Progress()):
    """Gradio handler for the "Abstract Symbols" button.

    Uses the precomputed data path when the file name yields a FineDiving key
    (see ``get_key_from_videopath``), otherwise runs the detectors live.

    Raises:
        gr.Error: If no video has been selected.
    """
    if video is None:
        raise gr.Error("Click on an example diving video first!")
    dive_key = get_key_from_videopath(video)
    if dive_key is not None:
        return get_abstracted_symbols_precomputed(video, dive_key, progress=progress)
    return get_abstracted_symbols_calculated(video, progress=progress)
|
|
|
def get_score_report_precomputed(video, key, progress=gr.Progress(), diveNum=""):
    """Render the score report for a dive whose errors were precomputed.

    Args:
        video: Path of the selected video; only checked for ``None`` here.
        key: Lookup key into ``dive_data_precomputed`` (as returned by
            ``get_key_from_videopath``); also names the frame directory.
        progress: Gradio progress tracker (injected by Gradio).
        diveNum: Not used by this variant; accepted so the signature matches
            ``get_score_report_calculated``.

    Returns:
        Score report HTML wrapped in a scrollable ``<div>``.

    Raises:
        gr.Error: If no video is given or no precomputed data exists for ``key``.
    """
    from pathlib import Path  # local import: only used to build the file:// URI

    progress(0, desc="Calculating Dive Errors")
    if video is None:
        raise gr.Error("input a video!!")

    # Read-only access to the module-level cache, so no ``global`` is needed.
    try:
        dive_data = dive_data_precomputed[key]
    except KeyError:
        raise gr.Error(
            "No precomputed data for dive {}/{}.".format(key[0], key[1])
        ) from None

    local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}/".format(key[0], key[1])
    # file:// URI of the same directory, resolved against the current working
    # directory (which local_directory is already relative to) instead of a
    # path hard-coded to one developer's machine.
    directory = Path(os.path.abspath(local_directory)).as_uri()

    intermediate_scores_dict = get_all_report_scores(dive_data)
    progress(0.75, desc="Generating Score Report")
    print('getting html...')
    html = generate_report(template_path, intermediate_scores_dict, directory,
                           local_directory, progress=progress)
    progress(0.9, desc="Generating Score Report")
    html = ("<div style='max-width:100%; max-height:360px; overflow:auto'>"
            + html
            + "</div>")
    print("returning...")
    return html
|
|
|
def get_score_report_calculated(video, progress=gr.Progress(), diveNum=""):
    """Compute dive errors for an arbitrary video and render the score report.

    Reads the module-level ``dive_data`` left by the symbol-abstraction step,
    passes it to ``getAllErrorsAndSegmentation_newVids`` and rebinds
    ``dive_data`` to the returned value.

    Returns:
        Score report HTML wrapped in a scrollable ``<div>``.
    """
    global dive_data
    progress(0, desc="Calculating Dive Errors")
    video_frames = extract_frames(video)
    dive_data = getAllErrorsAndSegmentation_newVids(
        video_frames,
        dive_data,
        progress=progress,
        diveNum=diveNum,
        platform_detector=platform_detector,
        splash_detector=splash_detector,
        diver_detector=diver_detector,
        pose_model=pose_model,
    )
    scores = get_all_report_scores(dive_data)
    progress(0.75, desc="Generating Score Report")

    print('getting html...')
    report_html = generate_report_from_frames(template_path, scores, video_frames)
    wrapped_html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + report_html + "</div>"
    print("returning...")
    progress(1.0, desc="Generating Score Report")
    return wrapped_html
|
|
|
def get_score_report(video, progress=gr.Progress(), diveNum=""):
    """Gradio handler for the "Generate Score Report" button.

    Uses the precomputed path when the file name yields a FineDiving key
    (see ``get_key_from_videopath``), otherwise computes errors live.

    Args:
        video: Path of the selected video.
        progress: Gradio progress tracker (injected by Gradio).
        diveNum: Dive number; forwarded to whichever implementation runs.

    Returns:
        Score report HTML.

    Raises:
        gr.Error: If no video has been selected.
    """
    if video is None:
        raise gr.Error("input a video!!")
    key = get_key_from_videopath(video)
    if key is None:
        return get_score_report_calculated(video, progress=progress, diveNum=diveNum)
    return get_score_report_precomputed(video, key, progress=progress, diveNum=diveNum)
|
|
|
|
|
def get_html_from_video(video, diveNum=""):
    """Run the whole live pipeline on a video, yielding HTML twice.

    First yields the symbol-abstraction report. It then yields that same
    report followed by the score report, wrapped in a scrollable ``<div>``.
    Uses a local dive-data dict; the module-level ``dive_data`` is untouched.

    Raises:
        gr.Error: If no video is given (raised on first iteration).
    """
    if video is None:
        raise gr.Error("input a video!!")

    frames = extract_frames(video)
    symbols = abstractSymbols(
        frames,
        platform_detector=platform_detector,
        splash_detector=splash_detector,
        diver_detector=diver_detector,
        pose_model=pose_model,
    )
    symbols['frames'] = list(frames)
    symbols_html = generate_symbols_report("intermediate_steps.html", symbols, frames)
    yield symbols_html

    symbols = getAllErrorsAndSegmentation_newVids(
        frames,
        symbols,
        diveNum=diveNum,
        platform_detector=platform_detector,
        splash_detector=splash_detector,
        diver_detector=diver_detector,
        pose_model=pose_model,
    )
    scores = get_all_report_scores(symbols)
    print('getting html...')
    score_html = generate_report_from_frames(template_path, scores, frames)
    combined_html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + symbols_html + score_html + "</div>"
    print("returning...")
    yield combined_html
|
|
|
def get_html_from_finedivingkey(first_folder, second_folder, board_side="left"):
    """Score a FineDiving dataset dive and render its report as HTML.

    Looks up the dive number in ``dive_annotation_data``, then runs
    ``getAllErrorsAndSegmentation`` on the dataset dive. Every returned value,
    plus ``diveNum`` and ``board_side``, is stored in the module-level
    ``dive_data`` dict before scoring.

    Args:
        first_folder: First FineDiving key component (used as-is).
        second_folder: Second key component; ``int(second_folder)`` forms the key.
        board_side: Forwarded to ``getAllErrorsAndSegmentation`` and recorded
            in ``dive_data``; defaults to ``"left"``.

    Returns:
        Score report HTML wrapped in a scrollable ``<div>``.
    """
    from pathlib import Path  # local import: only used to build the file:// URI

    key = (first_folder, int(second_folder))
    # Trailing slash matches the other generate_report call site
    # (get_score_report_precomputed), which passes the directory with one.
    local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}/".format(key[0], key[1])
    # file:// URI resolved against the working directory rather than a path
    # hard-coded to one developer's machine.
    directory = Path(os.path.abspath(local_directory)).as_uri()
    print("key:", key)

    diveNum = dive_annotation_data[key][0]
    (pose_preds, takeoff, twist, som, entry, distance_from_board,
     position_tightness, feet_apart, over_under_rotation, splash,
     above_boards, on_boards, som_counts, twist_counts, board_end_coords,
     diver_boxes) = getAllErrorsAndSegmentation(
        first_folder, second_folder, diveNum, board_side=board_side,
        platform_detector=platform_detector, splash_detector=splash_detector,
        diver_detector=diver_detector, pose_model=pose_model)

    # Mutates (does not rebind) the module-level dict, so no ``global`` is needed.
    # NOTE(review): being module-level, this state is shared by every caller.
    dive_data.update({
        'pose_pred': pose_preds,
        'takeoff': takeoff,
        'twist': twist,
        'som': som,
        'entry': entry,
        'distance_from_board': distance_from_board,
        'position_tightness': position_tightness,
        'feet_apart': feet_apart,
        'over_under_rotation': over_under_rotation,
        'splash': splash,
        'above_boards': above_boards,
        'on_boards': on_boards,
        'som_counts': som_counts,
        'twist_counts': twist_counts,
        'board_end_coords': board_end_coords,
        'diver_boxes': diver_boxes,
        'diveNum': diveNum,
        'board_side': board_side,
    })

    intermediate_scores_dict = get_all_report_scores(dive_data)
    html = generate_report(template_path, intermediate_scores_dict, directory, local_directory)
    return ("<div style='max-width:100%; max-height:360px; overflow:auto'>"
            + html
            + "</div>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def enable_get_score_btn(get_score_btn):
    """Return a Button update that makes the target clickable and primary-styled.

    The incoming component value is ignored. Despite the name, the event
    wiring also applies this to the "Abstract Symbols" button.
    """
    enabled = gr.Button(variant="primary", interactive=True)
    return enabled
|
|
|
def disable_get_score_btn(get_score_btn):
    """Return a Button update that greys out the target (non-interactive, secondary).

    The incoming component value is ignored. Despite the name, the event
    wiring also applies this to the "Abstract Symbols" button.
    """
    disabled = gr.Button(variant="secondary", interactive=False)
    return disabled
|
|
|
|
|
# Gradio UI for the two-step demo: (1) abstract symbols from an example video,
# (2) generate the score report. Components appear on the page in the order
# they are created inside this context, so statement order defines the layout.
with gr.Blocks() as demo_precomputed:
    # Header: title, authors and links (author/paper/code are still placeholders).
    gr.Markdown(
        """
# Neuro-Symbolic Olympic Diving Judge
Authors: ...
This system not only scores an Olympic dive, but outputs a detailed report summarizing each component of the dive and how we evaluated it. We first abstract the necessary symbols, and then proceed to score the dive.\n
Paper: *insert link to paper* \n
Code: *insert github link*
""")

    # Step 1 heading and description.
    gr.Markdown(
        """
## Step 1: Neural Symbol Abstraction
We first abstract the necessary visual elements from the provided diving video. This includes the platform, splash, and the pose estimation of the diver.
"""
    )

    # Table explaining the three abstracted symbols.
    # NOTE(review): the images use relative 'file/...' URLs, which rely on
    # Gradio serving files from the working directory (allowed_paths at
    # launch); confirm they load.
    gr.HTML(
        """
<table>
<tr>
<td>
Platform
<img src='file/platform.png' height='90'>
</td>
<td>
The location of the platform is crucial to determine when the diver leaves the platform, thus starting their dive.
It is also important to assess how close the diver comes to its edge, which is relevant to scoring.
</td>
<td>
Pose Estimation of Diver
<img src='file/pose_estimation.png' height='70'>
</td>
<td>
The pose of the diver in the sequence of video frames is critical to understanding and assessing the dive.
We obtain 2D pose data with locations of various body parts to recognize sub-actions being performed by the diver, such as a somersault, a twist, or an entry, and also assess the quality of that sub-action.
</td>
<td>
Splash
<img src='file/splash.png' height='90'>
</td>
<td>
Splash at entry into the pool is a conspicuous visual feature of a dive.
The size of the splash is an important element in traditional scoring of dives.
A large splash mars the end of a dive and also likely indicates a flaw in form at water entry.
</td>
</tr>
</table>
"""
    )
    # Step 1 usage instructions.
    gr.Markdown(
        """
1. Select one of the example diving videos.
2. Hit the **Abstract Symbols** button. The symbols abstracted will appear to the right of the diving video.
"""
    )
    # Video + examples in a column, with the symbol report beside it.
    # NOTE(review): symbol_output is placed next to the video column to match
    # the "appear to the right" instruction above; confirm the intended nesting.
    with gr.Row(variant='panel'):
        with gr.Column():
            # interactive=False: the user cannot upload or edit the video
            # directly; it is set by clicking one of the examples below.
            video = gr.Video(label="Video", format="mp4", include_audio=False, sources=["upload"], interactive=False)
            # NOTE(review): get_key_from_videopath reads '_'-separated fields 1
            # and 2, so these names only map to precomputed keys if the
            # processed file name gains a prefix field (e.g. from Gradio); confirm.
            examples = gr.Examples(examples = [['01_10.mp4'], ['01_11.mp4'], ['01_16.mp4'], ['01_33.mp4'], ['01_76.mp4'], ['01_140.mp4']], inputs=[video], label="Click on one of the following diving videos")
        symbol_output = gr.HTML(label="Output")
    abstract_symbols_btn = gr.Button("Abstract Symbols", variant='secondary')
    # Step 2 heading, description and instructions.
    gr.Markdown(
        """
## Step 2: Calculate Logic-Based Errors and Generate Detailed Score Report

Using the abstracted symbols, we calculate different "errors" of the dive.
These errors are: **feet apart; height off board; distance from board; somersault position tightness; knee straightness; twist position straightness; over/under rotation; straightness of body during entry; and splash size.**
Each error is scored on a scale of 0-10, and are then averaged to reach a final score for the dive.

We then programmatically generate a detailed performance report containing different aspects of the dive, their percentile scores, and visual evidence.
It can be helpful for a number of reasons including as a support to human judges and as an educational tool to teach coaches, athletes, and judges how to score.

1. Click the **Generate Score Report** button. The Score Report will be generated below. (Abstract Symbols first if you haven't already!)
"""
    )

    # Starts disabled; enabled only after symbols have been abstracted.
    get_score_btn = gr.Button("Generate Score Report", interactive=False)
    score_report = gr.HTML(label="Report")

    # Event wiring (enable/disable_get_score_btn ignore their input and return
    # a Button update for whichever button is the output):
    # - a new video disables the score button and re-enables Abstract Symbols;
    video.change(fn=disable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
    video.change(fn=enable_get_score_btn, inputs=abstract_symbols_btn, outputs=abstract_symbols_btn)
    # - Abstract Symbols renders the symbol report, then enables the score button on success;
    abstract_symbols_btn.click(fn=get_abstracted_symbols, inputs=video, outputs=symbol_output).success(fn=enable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
    # - any change to the symbol report disables Abstract Symbols and enables the score button
    #   (NOTE(review): the latter duplicates the .success() handler above);
    symbol_output.change(fn=disable_get_score_btn, inputs=abstract_symbols_btn, outputs=abstract_symbols_btn)
    symbol_output.change(fn=enable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
    # - Generate Score Report renders the report (diveNum keeps its "" default).
    get_score_btn.click(fn=get_score_report, inputs=video, outputs=score_report)
|
|
|
|
|
|
|
|
|
|
|
# Enable Gradio's request queue, then start the server. queue() returns the
# Blocks object itself, so the calls can be chained.
# NOTE(review): share=True publishes a public link, and allowed_paths=["."]
# lets that link serve any file under the working directory (including the
# .pkl data); confirm this exposure is intended.
demo_precomputed.queue().launch(share=True, allowed_paths=["."])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|