Spaces:
Runtime error
Runtime error
Commit
•
17fa97d
1
Parent(s):
b02b8a0
Hyperparameter settings
Browse files- app.py +29 -9
- gradio_scripts/upload_ui.py +15 -1
- inference.py +4 -4
- main.py +2 -5
- visualizer.py +1 -1
app.py
CHANGED
@@ -27,16 +27,24 @@ result = {}
|
|
27 |
|
28 |
|
29 |
# Called when an Aris file is uploaded for inference
|
30 |
-
def on_aris_input(file_list, model_id):
|
31 |
-
|
32 |
-
print(model_id)
|
33 |
-
print(models[model_id] if model_id in models else models['master'])
|
34 |
|
35 |
# Reset Result
|
36 |
reset_state(result, state)
|
37 |
state['files'] = file_list
|
38 |
state['total'] = len(file_list)
|
39 |
-
state['
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
# Update loading_space to start inference on first file
|
42 |
return {
|
@@ -130,8 +138,10 @@ def infer_next(_, progress=gr.Progress()):
|
|
130 |
file_name = file_info[0].split("/")[-1]
|
131 |
bytes = file_info[1]
|
132 |
valid, file_path, dir_name = save_data(bytes, file_name)
|
133 |
-
|
134 |
-
print(
|
|
|
|
|
135 |
|
136 |
# Check that the file was valid
|
137 |
if not valid:
|
@@ -143,8 +153,18 @@ def infer_next(_, progress=gr.Progress()):
|
|
143 |
# Send uploaded file to AWS
|
144 |
upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
|
145 |
|
|
|
|
|
146 |
# Do inference
|
147 |
-
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
# Store result for that file
|
150 |
result['json_result'].append(json_result)
|
@@ -370,7 +390,7 @@ with demo:
|
|
370 |
inference_comps = [inference_handler, master_tabs, components['cancelBtn'], components['skipBtn']]
|
371 |
|
372 |
# When a file is uploaded to the input, tell the inference_handler to start inference
|
373 |
-
input.upload(on_aris_input, [input
|
374 |
|
375 |
# When inference handler updates, tell result_handler to show the new result
|
376 |
# Also, add inference_handler as the output in order to have it display the progress
|
|
|
27 |
|
28 |
|
29 |
# Called when an Aris file is uploaded for inference
|
30 |
+
def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age):
|
|
|
|
|
|
|
31 |
|
32 |
# Reset Result
|
33 |
reset_state(result, state)
|
34 |
state['files'] = file_list
|
35 |
state['total'] = len(file_list)
|
36 |
+
state['hyperparams'] = {
|
37 |
+
'model': models[model_id] if model_id in models else models['master'],
|
38 |
+
'conf_thresh': conf_thresh,
|
39 |
+
'iou_thresh': iou_thresh,
|
40 |
+
'min_hits': min_hits,
|
41 |
+
'max_age': max_age,
|
42 |
+
}
|
43 |
+
|
44 |
+
print(" ")
|
45 |
+
print("Running with:")
|
46 |
+
print(state['hyperparams'])
|
47 |
+
print(" ")
|
48 |
|
49 |
# Update loading_space to start inference on first file
|
50 |
return {
|
|
|
138 |
file_name = file_info[0].split("/")[-1]
|
139 |
bytes = file_info[1]
|
140 |
valid, file_path, dir_name = save_data(bytes, file_name)
|
141 |
+
|
142 |
+
print("Directory: ", dir_name)
|
143 |
+
print("Aris input: ", file_path)
|
144 |
+
print(" ")
|
145 |
|
146 |
# Check that the file was valid
|
147 |
if not valid:
|
|
|
153 |
# Send uploaded file to AWS
|
154 |
upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
|
155 |
|
156 |
+
hyperparams = state['hyperparams']
|
157 |
+
|
158 |
# Do inference
|
159 |
+
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
|
160 |
+
file_path,
|
161 |
+
weights = hyperparams['model'],
|
162 |
+
conf_thresh = hyperparams['conf_thresh'],
|
163 |
+
iou_thresh = hyperparams['iou_thresh'],
|
164 |
+
min_hits = hyperparams['min_hits'],
|
165 |
+
max_age = hyperparams['max_age'],
|
166 |
+
gradio_progress=set_progress
|
167 |
+
)
|
168 |
|
169 |
# Store result for that file
|
170 |
result['json_result'].append(json_result)
|
|
|
390 |
inference_comps = [inference_handler, master_tabs, components['cancelBtn'], components['skipBtn']]
|
391 |
|
392 |
# When a file is uploaded to the input, tell the inference_handler to start inference
|
393 |
+
input.upload(on_aris_input, [input] + components['hyperparams'], inference_comps)
|
394 |
|
395 |
# When inference handler updates, tell result_handler to show the new result
|
396 |
# Also, add inference_handler as the output in order to have it display the progress
|
gradio_scripts/upload_ui.py
CHANGED
@@ -15,7 +15,21 @@ def Upload_Gradio(gradio_components):
|
|
15 |
|
16 |
gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
|
17 |
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
#Input field for aris submission
|
21 |
gradio_components['input'] = File(file_types=[".aris", ".ddf"], type="binary", label="ARIS Input", file_count="multiple")
|
|
|
15 |
|
16 |
gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
|
17 |
|
18 |
+
with gr.Accordion("Advanced Settings", open=False):
|
19 |
+
settings = []
|
20 |
+
settings.append(gr.Dropdown(label="Model", value="master", choices=list(models.keys())))
|
21 |
+
|
22 |
+
gr.Markdown("Detection Parameters")
|
23 |
+
with gr.Row():
|
24 |
+
settings.append(gr.Slider(0, 1, value=0.05, label="Confidence Threshold", info="Confidence cutoff for detection boxes"))
|
25 |
+
settings.append(gr.Slider(0, 1, value=0.2, label="NMS IoU", info="IoU threshold for non-max suppression"))
|
26 |
+
|
27 |
+
gr.Markdown("Tracking Parameters")
|
28 |
+
with gr.Row():
|
29 |
+
settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
|
30 |
+
settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
|
31 |
+
|
32 |
+
gradio_components['hyperparams'] = settings
|
33 |
|
34 |
#Input field for aris submission
|
35 |
gradio_components['input'] = File(file_types=[".aris", ".ddf"], type="binary", label="ARIS Input", file_count="multiple")
|
inference.py
CHANGED
@@ -48,7 +48,7 @@ def norm(bbox, w, h):
|
|
48 |
bb[3] /= h
|
49 |
return bb
|
50 |
|
51 |
-
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS):
|
52 |
|
53 |
model, device = setup_model(weights)
|
54 |
|
@@ -78,15 +78,15 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
78 |
return
|
79 |
|
80 |
|
81 |
-
outputs = do_suppression(inference, gp=gp)
|
82 |
|
83 |
#do_confidence_boost(inference, outputs, gp=gp)
|
84 |
|
85 |
-
#new_outputs = do_suppression(inference, gp=gp)
|
86 |
|
87 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
88 |
|
89 |
-
results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gp)
|
90 |
|
91 |
return results
|
92 |
|
|
|
48 |
bb[3] /= h
|
49 |
return bb
|
50 |
|
51 |
+
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU, min_hits=MIN_HITS, max_age=MAX_AGE):
|
52 |
|
53 |
model, device = setup_model(weights)
|
54 |
|
|
|
78 |
return
|
79 |
|
80 |
|
81 |
+
outputs = do_suppression(inference, conf_thres=conf_thresh, iou_thres=nms_iou, gp=gp)
|
82 |
|
83 |
#do_confidence_boost(inference, outputs, gp=gp)
|
84 |
|
85 |
+
#new_outputs = do_suppression(inference, conf_thres=conf_thresh, iou_thres=nms_iou, gp=gp)
|
86 |
|
87 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
88 |
|
89 |
+
results = do_tracking(all_preds, image_meter_width, image_meter_height, min_hits=min_hits, max_age=max_age, gp=gp)
|
90 |
|
91 |
return results
|
92 |
|
main.py
CHANGED
@@ -7,9 +7,7 @@ from dataloader import create_dataloader_aris
|
|
7 |
from inference import do_full_inference, json_dump_round_float
|
8 |
from visualizer import generate_video_batches
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
def predict_task(filepath, weights=WEIGHTS, gradio_progress=None):
|
13 |
"""
|
14 |
Main processing task to be run in gradio
|
15 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
@@ -25,7 +23,6 @@ def predict_task(filepath, weights=WEIGHTS, gradio_progress=None):
|
|
25 |
if (gradio_progress): gradio_progress(0, "In task...")
|
26 |
print("Cuda available in task?", torch.cuda.is_available())
|
27 |
|
28 |
-
print(filepath)
|
29 |
dirname = os.path.dirname(filepath)
|
30 |
filename = os.path.basename(filepath).replace(".aris","").replace(".ddf","")
|
31 |
results_filepath = os.path.join(dirname, f"{filename}_results.json")
|
@@ -48,7 +45,7 @@ def predict_task(filepath, weights=WEIGHTS, gradio_progress=None):
|
|
48 |
frame_rate = dataset.didson.info['framerate']
|
49 |
|
50 |
# run detection + tracking
|
51 |
-
results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, weights=weights)
|
52 |
|
53 |
# re-index results if desired - this should be done before writing the file
|
54 |
results = prep_for_mm(results)
|
|
|
7 |
from inference import do_full_inference, json_dump_round_float
|
8 |
from visualizer import generate_video_batches
|
9 |
|
10 |
+
def predict_task(filepath, weights, conf_thresh, iou_thresh, min_hits, max_age, gradio_progress=None):
|
|
|
|
|
11 |
"""
|
12 |
Main processing task to be run in gradio
|
13 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
|
|
23 |
if (gradio_progress): gradio_progress(0, "In task...")
|
24 |
print("Cuda available in task?", torch.cuda.is_available())
|
25 |
|
|
|
26 |
dirname = os.path.dirname(filepath)
|
27 |
filename = os.path.basename(filepath).replace(".aris","").replace(".ddf","")
|
28 |
results_filepath = os.path.join(dirname, f"{filename}_results.json")
|
|
|
45 |
frame_rate = dataset.didson.info['framerate']
|
46 |
|
47 |
# run detection + tracking
|
48 |
+
results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, weights=weights, conf_thresh=conf_thresh, nms_iou=iou_thresh, min_hits=min_hits, max_age=max_age)
|
49 |
|
50 |
# re-index results if desired - this should be done before writing the file
|
51 |
results = prep_for_mm(results)
|
visualizer.py
CHANGED
@@ -110,7 +110,7 @@ def get_video_frames(frames, preds, frame_rate, image_meter_width=None, image_me
|
|
110 |
cv2.putText(frame, f'Left count: {clip_pr_counts[FONT_THICKNESS]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*2), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
111 |
cv2.putText(frame, f'Other fish: {clip_pr_counts[2]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
112 |
# cv2.putText(frame, f'Upstream: {preds["upstream_direction"]}', (0, h-1-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
113 |
-
cv2.putText(frame, f'Frame: {i}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*0), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
114 |
|
115 |
vid_frames.append(frame)
|
116 |
|
|
|
110 |
cv2.putText(frame, f'Left count: {clip_pr_counts[FONT_THICKNESS]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*2), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
111 |
cv2.putText(frame, f'Other fish: {clip_pr_counts[2]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
112 |
# cv2.putText(frame, f'Upstream: {preds["upstream_direction"]}', (0, h-1-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
113 |
+
cv2.putText(frame, f'Frame: {start_frame+i}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*0), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
|
114 |
|
115 |
vid_frames.append(frame)
|
116 |
|