laurenok24 commited on
Commit
3e16b6a
1 Parent(s): 2af1b18

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +366 -0
app.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import cv2
4
+ import gradio as gr
5
+ print(gr.__version__)
6
+ from tempSegAndAllErrorsForAllFrames import getAllErrorsAndSegmentation
7
+ from models.detectron2.platform_detector_setup import get_platform_detector
8
+ from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation
9
+ from models.detectron2.diver_detector_setup import get_diver_detector
10
+ from models.pose_estimator.pose_estimator_model_setup import get_pose_model
11
+ from models.detectron2.splash_detector_setup import get_splash_detector
12
+ from scoring_functions import *
13
+ from generate_reports import *
14
+ from tempSegAndAllErrorsForAllFrames_newVids import getAllErrorsAndSegmentation_newVids, abstractSymbols
15
+
16
+ from jinja2 import Environment, FileSystemLoader
17
+ from PIL import Image, ImageDraw
18
+ from io import BytesIO
19
+ import base64
20
+
21
+ platform_detector = get_platform_detector()
22
+ splash_detector = get_splash_detector()
23
+ diver_detector = get_diver_detector()
24
+ pose_model = get_pose_model()
25
+ template_path = 'report_template_tables.html'
26
+ dive_data = {}
27
+
28
+ with open('./segmentation_error_data.pkl', 'rb') as f:
29
+ dive_data_precomputed = pickle.load(f)
30
+
31
+ import sys
32
+ import csv
33
+
34
+ csv.field_size_limit(sys.maxsize)
35
+
36
+ with open('FineDiving/Annotations/fine-grained_annotation_aqa.pkl', 'rb') as f:
37
+ dive_annotation_data = pickle.load(f)
38
+
39
+ def extract_frames(video_path):
40
+ cap = cv2.VideoCapture(video_path)
41
+ # Check if the video file is opened successfully
42
+ if not cap.isOpened():
43
+ print("Error: Couldn't open video file.")
44
+ exit()
45
+ # a variable to set how many frames you want to skip
46
+ frame_skip = 1
47
+ # a variable to keep track of the frame to be saved
48
+ frame_count = 0
49
+ frames = []
50
+ i = 0
51
+ while True:
52
+ ret, frame = cap.read()
53
+ if not ret:
54
+ break
55
+ if i > frame_skip - 1:
56
+ frame_count += 1
57
+ # print("frame.shape:", frame.shape)
58
+ # resize takes argument (width, height)
59
+ frame = cv2.resize(frame, (455, 256))
60
+ frames.append(frame)
61
+ i = 0
62
+ continue
63
+ # cv2.imwrite("./tempdata/{}.jpg".format(i), frame)
64
+ i += 1
65
+ cap.release()
66
+ print("frame_count", frame_count)
67
+ return frames
68
+
69
+ def get_key_from_videopath(video):
70
+ try:
71
+ video_name = video.split('/')[-1]
72
+ first_folder = video_name.split('_')[1]
73
+ second_folder = video_name.split('_')[2].split('.')[0]
74
+ return (first_folder, int(second_folder))
75
+ except:
76
+ return None
77
+
78
+ def get_abstracted_symbols_precomputed(video, key, progress=gr.Progress()):
79
+ progress(0, desc="Abstracting Symbols")
80
+ if video is None:
81
+ raise gr.Error("input a video!!")
82
+ local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}/".format(key[0], key[1])
83
+ directory = "file:///Users/lokamoto/Comprehensive_AQA/FineDiving/datasets/FINADiving_MTL_256s/{}/{}".format(key[0], key[1])
84
+ # dive_data = abstractSymbols(frames, progress=progress, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
85
+ # dive_data['frames'] = frames
86
+ global dive_data_precomputed
87
+ dive_data = dive_data_precomputed[key]
88
+ html_intermediate = generate_symbols_report_precomputed("intermediate_steps.html", dive_data, local_directory, progress=progress)
89
+ progress(0.95, desc="Abstracting Symbols")
90
+ return html_intermediate
91
+
92
+ def get_abstracted_symbols_calculated(video, progress=gr.Progress()):
93
+ progress(0, desc="Abstracting Symbols")
94
+ frames = extract_frames(video)
95
+ global dive_data
96
+ dive_data = abstractSymbols(frames, progress=progress, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
97
+ dive_data['frames'] = frames
98
+ html_intermediate = generate_symbols_report("intermediate_steps.html", dive_data, frames)
99
+ return html_intermediate
100
+
101
+ def get_abstracted_symbols(video, progress=gr.Progress()):
102
+ if video is None:
103
+ raise gr.Error("input a video!!")
104
+ key = get_key_from_videopath(video)
105
+ if key is None:
106
+ return get_abstracted_symbols_calculated(video, progress=progress)
107
+ else:
108
+ return get_abstracted_symbols_precomputed(video, key, progress=progress)
109
+
110
+ def get_score_report_precomputed(video, key, progress=gr.Progress(), diveNum=""):
111
+ progress(0, desc="Calculating Dive Errors")
112
+ if video is None:
113
+ raise gr.Error("input a video!!")
114
+ global dive_data_precomputed
115
+ dive_data = dive_data_precomputed[key]
116
+ local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}/".format(key[0], key[1])
117
+ directory = "file:///Users/lokamoto/Comprehensive_AQA/FineDiving/datasets/FINADiving_MTL_256s/{}/{}".format(key[0], key[1])
118
+
119
+ intermediate_scores_dict = get_all_report_scores(dive_data)
120
+ progress(0.75, desc="Generating Score Report")
121
+ print('getting html...')
122
+ html = generate_report(template_path, intermediate_scores_dict, directory, local_directory, progress=progress)
123
+ progress(0.9, desc="Generating Score Report")
124
+ html = (
125
+ "<div style='max-width:100%; max-height:360px; overflow:auto'>"
126
+ + html
127
+ + "</div>")
128
+ print("returning...")
129
+ return html
130
+
131
+ def get_score_report_calculated(video, progress=gr.Progress(), diveNum=""):
132
+ progress(0, desc="Calculating Dive Errors")
133
+ global dive_data
134
+ frames = extract_frames(video)
135
+ dive_data = getAllErrorsAndSegmentation_newVids(frames, dive_data, progress=progress, diveNum=diveNum, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
136
+ intermediate_scores_dict = get_all_report_scores(dive_data)
137
+ progress(0.75, desc="Generating Score Report")
138
+ print('getting html...')
139
+ html = generate_report_from_frames(template_path, intermediate_scores_dict, frames)
140
+ html = (
141
+ "<div style='max-width:100%; max-height:360px; overflow:auto'>"
142
+ + html
143
+ + "</div>")
144
+ print("returning...")
145
+ progress(8/8, desc="Generating Score Report")
146
+ return html
147
+
148
+ def get_score_report(video, progress=gr.Progress(), diveNum=""):
149
+ if video is None:
150
+ raise gr.Error("input a video!!")
151
+ key = get_key_from_videopath(video)
152
+ if key is None:
153
+ return get_score_report_calculated(video, progress=progress)
154
+ else:
155
+ return get_score_report_precomputed(video, key, progress=progress)
156
+
157
+
158
+ def get_html_from_video(video, diveNum=""):
159
+ if video is None:
160
+ raise gr.Error("input a video!!")
161
+ frames = extract_frames(video)
162
+ dive_data = abstractSymbols(frames, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
163
+ dive_data['frames'] = frames.copy()
164
+ html_intermediate = generate_symbols_report("intermediate_steps.html", dive_data, frames)
165
+ yield html_intermediate
166
+ dive_data = getAllErrorsAndSegmentation_newVids(frames, dive_data, diveNum=diveNum, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
167
+ intermediate_scores_dict = get_all_report_scores(dive_data)
168
+ print('getting html...')
169
+ html = generate_report_from_frames(template_path, intermediate_scores_dict, frames)
170
+ html = (
171
+ "<div style='max-width:100%; max-height:360px; overflow:auto'>"
172
+ + html_intermediate
173
+ + html
174
+ + "</div>")
175
+ print("returning...")
176
+ yield html
177
+
178
+ def get_html_from_finedivingkey(first_folder, second_folder):
179
+ board_side = "left" # change!!!
180
+ key = (first_folder, int(second_folder))
181
+ local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}".format(key[0], key[1])
182
+ directory = "file:///Users/lokamoto/Comprehensive_AQA/FineDiving/datasets/FINADiving_MTL_256s/{}/{}".format(key[0], key[1])
183
+ print("key:", key)
184
+ diveNum = dive_annotation_data[key][0]
185
+ pose_preds, takeoff, twist, som, entry, distance_from_board, position_tightness, feet_apart, over_under_rotation, splash, above_boards, on_boards, som_counts, twist_counts, board_end_coords, diver_boxes = getAllErrorsAndSegmentation(first_folder, second_folder, diveNum, board_side=board_side, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
186
+ dive_data['pose_pred'] = pose_preds
187
+ dive_data['takeoff'] = takeoff
188
+ dive_data['twist'] = twist
189
+ dive_data['som'] = som
190
+ dive_data['entry'] = entry
191
+ dive_data['distance_from_board'] = distance_from_board
192
+ dive_data['position_tightness'] = position_tightness
193
+ dive_data['feet_apart'] = feet_apart
194
+ dive_data['over_under_rotation'] = over_under_rotation
195
+ dive_data['splash'] = splash
196
+ dive_data['above_boards'] = above_boards
197
+ dive_data['on_boards'] = on_boards
198
+ dive_data['som_counts'] = som_counts
199
+ dive_data['twist_counts'] = twist_counts
200
+ dive_data['board_end_coords'] = board_end_coords
201
+ dive_data['diver_boxes'] = diver_boxes
202
+ dive_data['diveNum'] = diveNum
203
+ dive_data['board_side'] = board_side
204
+
205
+ intermediate_scores_dict = get_all_report_scores(dive_data)
206
+ html = generate_report(template_path, intermediate_scores_dict, directory, local_directory)
207
+ html = (
208
+ "<div style='max-width:100%; max-height:360px; overflow:auto'>"
209
+ + html
210
+ + "</div>")
211
+
212
+ return html
213
+
214
+ ## gradio where we input a video ###
215
+ def enable_get_score_btn(get_score_btn):
216
+ return gr.Button.update(interactive=True, variant="primary")
217
+
218
+ def disable_get_score_btn(get_score_btn):
219
+ return gr.Button.update(interactive=False, variant="secondary")
220
+
221
+ with gr.Blocks() as demo_new:
222
+ gr.Markdown(
223
+ """
224
+ # NS-AQA
225
+ This system takes in a diving video, and outputs a detailed report summarizing each component of the dive and how we evaluated it. We first abstract the necessary symbols, and then proceed to score the dive.\n
226
+ Paper: *insert link to paper* \n
227
+ Code: *insert github link*
228
+ """)
229
+
230
+ with gr.Row():
231
+ with gr.Column():
232
+ gr.Markdown(
233
+ """
234
+ ## Step 1: Abstract Symbols
235
+ We first abstract the necessary visual elements from the provided diving video. This includes the platform, splash, and the pose estimation of the diver.
236
+ """
237
+ )
238
+ video = gr.Video(label="Video", format="mp4", include_audio=False)
239
+ abstract_symbols_btn = gr.Button("Abstract Symbols", variant='primary')
240
+ symbol_output = gr.HTML(label="Output")
241
+ examples = gr.Examples(examples = [['01_10.mp4'], ['01_11.mp4'], ['01_16.mp4'], ['01_33.mp4'], ['01_140.mp4']], inputs=[video])
242
+
243
+ with gr.Row():
244
+ gr.Markdown(
245
+ """
246
+ ## Step 2: Calculate Logic-Based Errors and Generate Detailed Score Report
247
+ """
248
+ )
249
+ get_score_btn = gr.Button("Get Score", interactive=False, variant='secondary')
250
+ score_report = gr.HTML(label="Output")
251
+ # get_score_report_btn = gr.Button("Get Score Report")
252
+ # video.change(fn=enable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
253
+ video.change(fn=disable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
254
+ video.change(fn=enable_get_score_btn, inputs=abstract_symbols_btn, outputs=abstract_symbols_btn)
255
+ abstract_symbols_btn.click(fn=get_abstracted_symbols, inputs=video, outputs=symbol_output).success(fn=enable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
256
+ symbol_output.change(fn=disable_get_score_btn, inputs=abstract_symbols_btn, outputs=abstract_symbols_btn)
257
+ symbol_output.change(fn=enable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
258
+ get_score_btn.click(fn=get_score_report, inputs=[video], outputs=score_report)
259
+
260
+
261
+ #### demo precomputed ########
262
+ with gr.Blocks() as demo_precomputed:
263
+ gr.Markdown(
264
+ """
265
+ # Neuro-Symbolic Olympic Diving Judge
266
+ This system not only scores an Olympic dive, and outputs a detailed report summarizing each component of the dive and how we evaluated it. We first abstract the necessary symbols, and then proceed to score the dive.\n
267
+ Paper: *insert link to paper* \n
268
+ Code: *insert github link*
269
+ """)
270
+
271
+ gr.Markdown(
272
+ """
273
+ ## Step 1: Abstract Symbols
274
+ We first abstract the necessary visual elements from the provided diving video. This includes the platform, splash, and the pose estimation of the diver.
275
+ """
276
+ )
277
+ # with gr.Row():
278
+ gr.HTML(
279
+ """
280
+ <table>
281
+ <tr>
282
+ <td>
283
+ Platform
284
+ <img src='file/platform.png' height='90'>
285
+ </td>
286
+ <td>
287
+ The location of the platform, especially the position of its edge facing the pool, is crucial to determine when the diver leaves the platform, thus starting their dive.
288
+ The platform location is also important to assess how close the diver comes to its edge, which is relevant to scoring.
289
+ </td>
290
+ <td>
291
+ Pose Estimation of Diver
292
+ <img src='file/pose_estimation.png' height='70'>
293
+ </td>
294
+ <td>
295
+ The pose of the diver in the sequence of video frames is critical to understanding and assessing the dive.
296
+ We obtain 2D pose data with locations of various body parts, including the head, thorax, pelvis, shoulders, elbows, wrists, hips, knees, and ankles.
297
+ With this, we can recognize sub-actions being performed by the diver, such as a somersault, a twist, or an entry, and also assess the quality of that sub-action.
298
+ </td>
299
+ <td>
300
+ Splash
301
+ <img src='file/splash.png' height='90'>
302
+ </td>
303
+ <td>
304
+ Splash at entry into the pool is a conspicuous visual feature of a dive.
305
+ The size of the splash is an important element in traditional scoring of dives.
306
+ A large splash mars the end of a dive and also likely indicates a flaw in form at water entry.
307
+ </td>
308
+ </tr>
309
+ </table>
310
+ """
311
+ )
312
+ gr.Markdown(
313
+ """
314
+ 1. Select one of the example diving videos.
315
+ 2. Hit the **Abstract Symbols** button.
316
+ """
317
+ )
318
+
319
+ with gr.Row(variant='panel'):
320
+ with gr.Column():
321
+ video = gr.Video(label="Video", format="mp4", include_audio=False)
322
+ abstract_symbols_btn = gr.Button("Abstract Symbols", variant='primary')
323
+ symbol_output = gr.HTML(label="Output")
324
+ examples = gr.Examples(examples = [['01_10.mp4'], ['01_11.mp4'], ['01_16.mp4'], ['01_33.mp4'], ['01_76.mp4'], ['01_140.mp4']], inputs=[video])
325
+
326
+ gr.Markdown(
327
+ """
328
+ ## Step 2: Calculate Logic-Based Errors and Generate Detailed Score Report
329
+
330
+ Using the abstracted symbols, we calculate different "errors" of the dive.
331
+ These errors are: **feet apart; height off board; distance from board; somersault position tightness; knee straightness; twist position straightness; over/under rotation; straightness of body during entry; and splash size.**
332
+ Each error is scored on a scale of 0-10, and are then averaged to reach a final score for the dive.
333
+
334
+ We then programmatically generate a detailed performance report containing different aspects of the dive, their percentile scores, and visual evidence.
335
+ This report can be seen as a compact, but highly detailed representation of quality of the dive performed.
336
+ It can be helpful for a number of reasons including as a support to human judges and as an educational tool to teach coaches, athletes, and judges how to score.
337
+
338
+ 1. Click the **Get Score** button. The Score Report will be generated below. (Abstract Symbols first if you haven't already!)
339
+ """
340
+ )
341
+
342
+ # with gr.Row():
343
+ get_score_btn = gr.Button("Get Score", interactive=False)
344
+ score_report = gr.HTML(label="Report")
345
+ # get_score_report_btn = gr.Button("Get Score Report")
346
+ video.change(fn=disable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
347
+ video.change(fn=enable_get_score_btn, inputs=abstract_symbols_btn, outputs=abstract_symbols_btn)
348
+ abstract_symbols_btn.click(fn=get_abstracted_symbols, inputs=video, outputs=symbol_output).success(fn=enable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
349
+ symbol_output.change(fn=disable_get_score_btn, inputs=abstract_symbols_btn, outputs=abstract_symbols_btn)
350
+ symbol_output.change(fn=enable_get_score_btn, inputs=get_score_btn, outputs=get_score_btn)
351
+ get_score_btn.click(fn=get_score_report, inputs=video, outputs=score_report)
352
+
353
+
354
+ ############################################################################################################################################
355
+
356
+
357
+ demo_precomputed.queue()
358
+ demo_precomputed.launch(share=True)
359
+ ######### gradio where we input first and second folder ##
360
+ # demo = gr.Interface(
361
+ # fn=get_html_from_finedivingkey,
362
+ # inputs=["text", "text"],
363
+ # outputs=["html"],
364
+ # )
365
+
366
+ # demo.launch(share=True, enable_queue=True,)