oskarastrom committed
Commit 17fa97d
1 Parent(s): b02b8a0

Hyperparameter settings

Files changed (5):
  1. app.py +29 -9
  2. gradio_scripts/upload_ui.py +15 -1
  3. inference.py +4 -4
  4. main.py +2 -5
  5. visualizer.py +1 -1
app.py CHANGED
@@ -27,16 +27,24 @@ result = {}


 # Called when an Aris file is uploaded for inference
-def on_aris_input(file_list, model_id):
+def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age):

-    print(model_id)
-    print(models[model_id] if model_id in models else models['master'])
-
     # Reset Result
     reset_state(result, state)
     state['files'] = file_list
     state['total'] = len(file_list)
-    state['model'] = models[model_id] if model_id in models else models['master']
+    state['hyperparams'] = {
+        'model': models[model_id] if model_id in models else models['master'],
+        'conf_thresh': conf_thresh,
+        'iou_thresh': iou_thresh,
+        'min_hits': min_hits,
+        'max_age': max_age,
+    }
+
+    print(" ")
+    print("Running with:")
+    print(state['hyperparams'])
+    print(" ")

     # Update loading_space to start inference on first file
     return {
@@ -130,8 +138,10 @@ def infer_next(_, progress=gr.Progress()):
     file_name = file_info[0].split("/")[-1]
     bytes = file_info[1]
     valid, file_path, dir_name = save_data(bytes, file_name)
-    print(dir_name)
-    print(file_path)
+
+    print("Directory: ", dir_name)
+    print("Aris input: ", file_path)
+    print(" ")

     # Check that the file was valid
     if not valid:
@@ -143,8 +153,18 @@ def infer_next(_, progress=gr.Progress()):
     # Send uploaded file to AWS
     upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)

+    hyperparams = state['hyperparams']
+
     # Do inference
-    json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(file_path, weights=state['model'], gradio_progress=set_progress)
+    json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
+        file_path,
+        weights=hyperparams['model'],
+        conf_thresh=hyperparams['conf_thresh'],
+        iou_thresh=hyperparams['iou_thresh'],
+        min_hits=hyperparams['min_hits'],
+        max_age=hyperparams['max_age'],
+        gradio_progress=set_progress
+    )

     # Store result for that file
     result['json_result'].append(json_result)
@@ -370,7 +390,7 @@ with demo:
     inference_comps = [inference_handler, master_tabs, components['cancelBtn'], components['skipBtn']]

     # When a file is uploaded to the input, tell the inference_handler to start inference
-    input.upload(on_aris_input, [input, components['model_select']], inference_comps)
+    input.upload(on_aris_input, [input] + components['hyperparams'], inference_comps)

     # When inference handler updates, tell result_handler to show the new result
     # Also, add inference_handler as the output in order to have it display the progress
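
Gradio binds an event's inputs list to the callback positionally, so the values of `[input] + components['hyperparams']` arrive as `file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age` in exactly that order. A minimal self-contained sketch of the same wiring pattern, with illustrative component choices rather than the app's real ones:

import gradio as gr

# Each component's value is passed to the callback positionally, in list order.
def on_input(files, model_id, conf_thresh, iou_thresh, min_hits, max_age):
    return f"model={model_id} conf={conf_thresh} iou={iou_thresh} hits={min_hits} age={max_age}"

with gr.Blocks() as demo:
    inp = gr.File(file_count="multiple")
    hyperparams = [
        gr.Dropdown(label="Model", value="master", choices=["master"]),
        gr.Slider(0, 1, value=0.05, label="Confidence Threshold"),
        gr.Slider(0, 1, value=0.2, label="NMS IoU"),
        gr.Slider(0, 100, value=16, label="Min Hits"),
        gr.Slider(0, 100, value=14, label="Max Age"),
    ]
    out = gr.Textbox()
    # Mirrors app.py: the uploaded files plus the hyperparameter components.
    inp.upload(on_input, [inp] + hyperparams, out)
# demo.launch()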
gradio_scripts/upload_ui.py CHANGED
@@ -15,7 +15,21 @@ def Upload_Gradio(gradio_components):

     gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")

-    gradio_components['model_select'] = gr.Dropdown(value="master", choices=list(models.keys()))
+    with gr.Accordion("Advanced Settings", open=False):
+        settings = []
+        settings.append(gr.Dropdown(label="Model", value="master", choices=list(models.keys())))
+
+        gr.Markdown("Detection Parameters")
+        with gr.Row():
+            settings.append(gr.Slider(0, 1, value=0.05, label="Confidence Threshold", info="Confidence cutoff for detection boxes"))
+            settings.append(gr.Slider(0, 1, value=0.2, label="NMS IoU", info="IoU threshold for non-max suppression"))
+
+        gr.Markdown("Tracking Parameters")
+        with gr.Row():
+            settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
+            settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
+
+    gradio_components['hyperparams'] = settings

     #Input field for aris submission
     gradio_components['input'] = File(file_types=[".aris", ".ddf"], type="binary", label="ARIS Input", file_count="multiple")
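
Because `gradio_components['hyperparams']` is consumed positionally by `on_aris_input` in app.py, the append order here (model, confidence threshold, NMS IoU, min hits, max age) is a silent contract between the two files. If that contract ever needs guarding, a hypothetical check like the following (not part of this commit) would catch accidental reordering:

# Hypothetical guard, not in the commit: fail fast if the slider order drifts
# from the argument order expected by on_aris_input in app.py.
EXPECTED_LABELS = ["Model", "Confidence Threshold", "NMS IoU", "Min Hits", "Max Age"]

def check_hyperparam_order(settings):
    labels = [getattr(component, "label", None) for component in settings]
    assert labels == EXPECTED_LABELS, f"hyperparameter order changed: {labels}"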
inference.py CHANGED
@@ -48,7 +48,7 @@ def norm(bbox, w, h):
     bb[3] /= h
     return bb

-def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS):
+def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU, min_hits=MIN_HITS, max_age=MAX_AGE):

     model, device = setup_model(weights)

@@ -78,15 +78,15 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
         return


-    outputs = do_suppression(inference, gp=gp)
+    outputs = do_suppression(inference, conf_thres=conf_thresh, iou_thres=nms_iou, gp=gp)

     #do_confidence_boost(inference, outputs, gp=gp)

-    #new_outputs = do_suppression(inference, gp=gp)
+    #new_outputs = do_suppression(inference, conf_thres=conf_thresh, iou_thres=nms_iou, gp=gp)

     all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)

-    results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gp)
+    results = do_tracking(all_preds, image_meter_width, image_meter_height, min_hits=min_hits, max_age=max_age, gp=gp)

     return results

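The diff doesn't show `do_suppression` itself, but its two new knobs follow standard YOLO-style post-processing: discard boxes scoring below the confidence threshold, then run non-max suppression at the given IoU cutoff. A rough stand-in using torchvision, assuming xyxy boxes (this sketches the idea, not the project's actual implementation):

import torch
from torchvision.ops import nms

def suppress(boxes, scores, conf_thresh=0.05, iou_thresh=0.2):
    """boxes: (N, 4) float tensor in xyxy format; scores: (N,) float tensor."""
    # 1) Confidence cutoff: drop low-scoring detections.
    keep = scores >= conf_thresh
    boxes, scores = boxes[keep], scores[keep]
    # 2) NMS: among boxes overlapping above iou_thresh, keep the highest score.
    kept = nms(boxes, scores, iou_thresh)
    return boxes[kept], scores[kept]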
main.py CHANGED
@@ -7,9 +7,7 @@ from dataloader import create_dataloader_aris
 from inference import do_full_inference, json_dump_round_float
 from visualizer import generate_video_batches

-WEIGHTS = 'models/v5m_896_300best.pt'
-
-def predict_task(filepath, weights=WEIGHTS, gradio_progress=None):
+def predict_task(filepath, weights, conf_thresh, iou_thresh, min_hits, max_age, gradio_progress=None):
     """
     Main processing task to be run in gradio
     - Writes aris frames to dirname(filepath)/frames/{i}.jpg
@@ -25,7 +23,6 @@ def predict_task(filepath, weights=WEIGHTS, gradio_progress=None):
     if (gradio_progress): gradio_progress(0, "In task...")
     print("Cuda available in task?", torch.cuda.is_available())

-    print(filepath)
     dirname = os.path.dirname(filepath)
     filename = os.path.basename(filepath).replace(".aris","").replace(".ddf","")
     results_filepath = os.path.join(dirname, f"{filename}_results.json")
@@ -48,7 +45,7 @@ def predict_task(filepath, weights=WEIGHTS, gradio_progress=None):
     frame_rate = dataset.didson.info['framerate']

     # run detection + tracking
-    results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, weights=weights)
+    results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, weights=weights, conf_thresh=conf_thresh, nms_iou=iou_thresh, min_hits=min_hits, max_age=max_age)

     # re-index results if desired - this should be done before writing the file
     results = prep_for_mm(results)
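
With both the module-level `WEIGHTS` constant and the keyword defaults removed, `predict_task` now requires the full hyperparameter set from every caller; a call that previously read `predict_task(path)` would now raise a TypeError. An illustrative call using the UI defaults and the old weights path (the input path below is a placeholder):

# Hypothetical call site; the input path is a placeholder.
json_result, json_fp, zip_fp, video_fp, marking_fp = predict_task(
    "uploads/example.aris",
    weights="models/v5m_896_300best.pt",
    conf_thresh=0.05,
    iou_thresh=0.2,
    min_hits=16,
    max_age=14,
)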
visualizer.py CHANGED
@@ -110,7 +110,7 @@ def get_video_frames(frames, preds, frame_rate, image_meter_width=None, image_me
     cv2.putText(frame, f'Left count: {clip_pr_counts[FONT_THICKNESS]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*2), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
     cv2.putText(frame, f'Other fish: {clip_pr_counts[2]}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
     # cv2.putText(frame, f'Upstream: {preds["upstream_direction"]}', (0, h-1-LINE_HEIGHT*1), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
-    cv2.putText(frame, f'Frame: {i}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*0), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
+    cv2.putText(frame, f'Frame: {start_frame+i}', (BORDER_PAD, h-BORDER_PAD-LINE_HEIGHT*0), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)

     vid_frames.append(frame)
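
This one-token fix matters because `generate_video_batches` appears to render frames in chunks: the loop index `i` restarts at zero for each batch, so the on-screen label needs the batch's global offset. A minimal sketch of the indexing, with an illustrative batch size:

# Illustrative only: i resets per batch; start_frame carries the global offset.
frames = list(range(2500))   # stand-in for decoded sonar frames
BATCH = 1000                 # illustrative batch size
for start_frame in range(0, len(frames), BATCH):
    for i, frame in enumerate(frames[start_frame:start_frame + BATCH]):
        label = f"Frame: {start_frame + i}"   # global index, not the per-batch i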