oskarastrom commited on
Commit
2a572c2
1 Parent(s): 99331d9

Config file

Browse files
Files changed (7) hide show
  1. InferenceConfig.py +73 -0
  2. app.py +17 -17
  3. aris.py +1 -0
  4. inference.py +15 -28
  5. lib/fish_eye/tracker.py +1 -0
  6. main.py +3 -3
  7. multipage_pdf.pdf +0 -0
InferenceConfig.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ ### Configuration options
4
+ WEIGHTS = 'models/v5m_896_300best.pt'
5
+ # will need to configure these based on GPU hardware
6
+ BATCH_SIZE = 32
7
+
8
+ CONF_THRES = 0.05 # detection
9
+ NMS_IOU = 0.2 # NMS IOU
10
+ MAX_AGE = 14 # time until missing fish get's new id
11
+ MIN_HITS = 16 # minimum number of frames with a specific fish for it to count
12
+ MIN_LENGTH = 0.3 # minimum fish length, in meters
13
+ IOU_THRES = 0.01 # IOU threshold for tracking
14
+ MIN_TRAVEL = -1 # Minimum distance a track has to travel
15
+
16
+ class TrackerType(Enum):
17
+ NONE = 0
18
+ CONF_BOOST = 1
19
+ BYTETRACK = 2
20
+
21
+ class InferenceConfig:
22
+ def __init__(self,
23
+ weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
24
+ min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL):
25
+ self.weights = weights
26
+ self.conf_thresh = conf_thresh
27
+ self.nms_iou = nms_iou
28
+ self.min_hits = min_hits
29
+ self.max_age = max_age
30
+ self.min_length = min_length
31
+ self.min_travel = min_travel
32
+
33
+ self.associative_tracker = TrackerType.NONE
34
+ self.boost_power = 1
35
+ self.boost_decay = 1
36
+ self.byte_low_conf = 1
37
+ self.byte_high_conf = 1
38
+
39
+ def enable_conf_boost(self, power, decay):
40
+ self.associative_tracker = TrackerType.CONF_BOOST
41
+ self.boost_power = power
42
+ self.boost_decay = decay
43
+
44
+ def enable_byte_track(self, low, high):
45
+ self.associative_tracker = TrackerType.BYTETRACK
46
+ self.byte_low_conf = low
47
+ self.byte_high_conf = high
48
+
49
+ def to_dict(self):
50
+ dict = {
51
+ 'weights': self.weights,
52
+ 'nms_iou': self.nms_iou,
53
+ 'min_hits': self.min_hits,
54
+ 'max_age': self.max_age,
55
+ 'min_length': self.min_length,
56
+ 'min_travel': self.min_travel,
57
+ }
58
+
59
+ # Add tracker specific parameters
60
+ if (self.associative_tracker == TrackerType.BYTETRACK):
61
+ dict['tracker'] = "ByteTrack"
62
+ dict['byte_low_conf'] = self.byte_low_conf
63
+ dict['byte_high_conf'] = self.byte_high_conf
64
+ elif (self.associative_tracker == TrackerType.CONF_BOOST):
65
+ dict['tracker'] = "Confidence Boost"
66
+ dict['conf_thresh'] = self.conf_thresh
67
+ dict['boost_power'] = self.boost_power
68
+ dict['boost_decay'] = self.boost_decay
69
+ elif (self.associative_tracker == TrackerType.NONE):
70
+ dict['tracker'] = "None"
71
+ dict['conf_thresh'] = self.conf_thresh
72
+
73
+ return dict
app.py CHANGED
@@ -13,6 +13,7 @@ from gradio_scripts.upload_ui import Upload_Gradio, models
13
  from gradio_scripts.result_ui import Result_Gradio, update_result, table_headers, info_headers, js_update_tab_labels
14
  from dataloader import create_dataloader_aris
15
  from aris import BEAM_WIDTH_DIR
 
16
 
17
  WEBAPP_VERSION = "1.0"
18
 
@@ -23,7 +24,7 @@ state = {
23
  'total': 1,
24
  'annotation_index': -1,
25
  'frame_index': 0,
26
- 'hyperparams': {}
27
  }
28
  result = {}
29
 
@@ -36,26 +37,25 @@ def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_ag
36
  state['files'] = file_list
37
  state['total'] = len(file_list)
38
  state['version'] = WEBAPP_VERSION
39
- state['hyperparams'] = {
40
- 'model': models[model_id] if model_id in models else models['master'],
41
- 'conf_thresh': conf_thresh,
42
- 'iou_thresh': iou_thresh,
43
- 'min_hits': min_hits,
44
- 'max_age': max_age,
45
- 'min_length': min_length,
46
- 'min_travel': min_travel,
47
- 'associative_tracker': associative_tracker,
48
- }
 
49
  if (associative_tracker == "Confidence Boost"):
50
- state['hyperparams']['boost_power'] = boost_power
51
- state['hyperparams']['boost_decay'] = boost_decay
52
  elif (associative_tracker == "ByteTrack"):
53
- state['hyperparams']['byte_low_conf'] = byte_low_conf
54
- state['hyperparams']['byte_high_conf'] = byte_high_conf
55
 
56
  print(" ")
57
  print("Running with:")
58
- print(state['hyperparams'])
59
  print(" ")
60
 
61
  # Update loading_space to start inference on first file
@@ -169,7 +169,7 @@ def infer_next(_, progress=gr.Progress()):
169
  # Do inference
170
  json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
171
  file_path,
172
- hyperparams = state['hyperparams'],
173
  gradio_progress = set_progress
174
  )
175
 
 
13
  from gradio_scripts.result_ui import Result_Gradio, update_result, table_headers, info_headers, js_update_tab_labels
14
  from dataloader import create_dataloader_aris
15
  from aris import BEAM_WIDTH_DIR
16
+ from InferenceConfig import InferenceConfig
17
 
18
  WEBAPP_VERSION = "1.0"
19
 
 
24
  'total': 1,
25
  'annotation_index': -1,
26
  'frame_index': 0,
27
+ 'config': None
28
  }
29
  result = {}
30
 
 
37
  state['files'] = file_list
38
  state['total'] = len(file_list)
39
  state['version'] = WEBAPP_VERSION
40
+ state['config'] = InferenceConfig(
41
+ weights = models[model_id] if model_id in models else models['master'],
42
+ conf_thresh = conf_thresh,
43
+ nms_iou = iou_thresh,
44
+ min_hits = min_hits,
45
+ max_age = max_age,
46
+ min_length = min_length,
47
+ min_travel = min_travel,
48
+ )
49
+
50
+ # Enable tracker if specified
51
  if (associative_tracker == "Confidence Boost"):
52
+ state['config'].enable_conf_boost(boost_power, boost_decay)
 
53
  elif (associative_tracker == "ByteTrack"):
54
+ state['config'].enable_byte_track(byte_low_conf, byte_high_conf)
 
55
 
56
  print(" ")
57
  print("Running with:")
58
+ print(state['config'].to_dict())
59
  print(" ")
60
 
61
  # Update loading_space to start inference on first file
 
169
  # Do inference
170
  json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
171
  file_path,
172
+ config = state['config'],
173
  gradio_progress = set_progress
174
  )
175
 
aris.py CHANGED
@@ -395,6 +395,7 @@ def add_metadata_to_result(aris_fp, json_data, beam_width_dir=BEAM_WIDTH_DIR):
395
  fish_entry['FRAME_NUM'] = entry['frame_num']
396
  fish_entry['START_FRAME'] = entry['start_frame_index']
397
  fish_entry['END_FRAME'] = entry['end_frame_index']
 
398
  fish_entry['TRAVEL'] = entry['travel_dist']
399
  fish_entry['DIR'] = upstream_motion_map[entry['direction']]
400
  fish_entry['R'] = bin_num * pixel_meter_size + frame.windowstart
 
395
  fish_entry['FRAME_NUM'] = entry['frame_num']
396
  fish_entry['START_FRAME'] = entry['start_frame_index']
397
  fish_entry['END_FRAME'] = entry['end_frame_index']
398
+ fish_entry['NBR_FRAMES'] = entry['end_frame_index'] + 1 - entry['start_frame_index']
399
  fish_entry['TRAVEL'] = entry['travel_dist']
400
  fish_entry['DIR'] = upstream_motion_map[entry['direction']]
401
  fish_entry['R'] = bin_num * pixel_meter_size + frame.windowstart
inference.py CHANGED
@@ -17,6 +17,7 @@ from lib.yolov5.utils.metrics import box_iou
17
  import torch
18
  import torchvision
19
 
 
20
  from lib.fish_eye.tracker import Tracker
21
  from lib.fish_eye.associative import Associate
22
 
@@ -50,18 +51,9 @@ def norm(bbox, w, h):
50
  bb[3] /= h
51
  return bb
52
 
53
- def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, hyperparams={}):
54
-
55
- # Load hyperparameters
56
- if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
57
- if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
58
- if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
59
- if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
60
- if 'min_length' not in hyperparams: hyperparams['min_length'] = MIN_LENGTH
61
- if 'min_travel' not in hyperparams: hyperparams['min_travel'] = MIN_TRAVEL
62
- if 'associative_tracker' not in hyperparams: hyperparams['associative_tracker'] = "None"
63
 
64
- model, device = setup_model(hyperparams['model'])
65
 
66
  load = False
67
  save = False
@@ -89,43 +81,38 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
89
  return
90
 
91
 
92
- outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
93
 
94
- if hyperparams['associative_tracker'] == "ByteTrack":
95
- if 'byte_low_conf' not in hyperparams: hyperparams['byte_low_conf'] = 0.1
96
- if 'byte_high_conf' not in hyperparams: hyperparams['byte_high_conf'] = 0.3
97
 
98
- low_outputs = do_suppression(inference, conf_thres=hyperparams['byte_low_conf'], iou_thres=hyperparams['iou_thresh'], gp=gp)
99
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
100
 
101
- high_outputs = do_suppression(inference, conf_thres=hyperparams['byte_high_conf'], iou_thres=hyperparams['iou_thresh'], gp=gp)
102
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
103
 
104
  results = do_associative_tracking(
105
  low_preds, high_preds, image_meter_width, image_meter_height,
106
- reverse=False, min_length=hyperparams['min_length'], min_travel=hyperparams['min_travel'],
107
- max_age=hyperparams['max_age'], min_hits=hyperparams['min_hits'],
108
  gp=gp)
109
  else:
110
 
111
- if 'conf_thresh' not in hyperparams: hyperparams['conf_thresh'] = CONF_THRES
112
 
113
- outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
114
 
115
- if hyperparams['associative_tracker'] == "Confidence Boost":
116
- if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
117
- if 'boost_decay' not in hyperparams: hyperparams['boost_decay'] = 1
118
 
119
- do_confidence_boost(inference, outputs, boost_power=hyperparams['boost_power'], boost_decay=hyperparams['boost_decay'], gp=gp)
120
 
121
- outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
122
 
123
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
124
 
125
  results = do_tracking(
126
  all_preds, image_meter_width, image_meter_height,
127
- min_length=hyperparams['min_length'], min_travel=hyperparams['min_travel'],
128
- max_age=hyperparams['max_age'], iou_thres=hyperparams['iou_threshold'], min_hits=hyperparams['min_hits'],
129
  gp=gp)
130
 
131
  return results
 
17
  import torch
18
  import torchvision
19
 
20
+ from InferenceConfig import InferenceConfig, TrackerType
21
  from lib.fish_eye.tracker import Tracker
22
  from lib.fish_eye.associative import Associate
23
 
 
51
  bb[3] /= h
52
  return bb
53
 
54
+ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, config=InferenceConfig()):
 
 
 
 
 
 
 
 
 
55
 
56
+ model, device = setup_model(config.weights)
57
 
58
  load = False
59
  save = False
 
81
  return
82
 
83
 
84
+ outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
85
 
86
+ if config.associative_tracker == TrackerType.BYTETRACK:
 
 
87
 
88
+ low_outputs = do_suppression(inference, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
89
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
90
 
91
+ high_outputs = do_suppression(inference, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
92
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
93
 
94
  results = do_associative_tracking(
95
  low_preds, high_preds, image_meter_width, image_meter_height,
96
+ reverse=False, min_length=config.min_length, min_travel=config.min_travel,
97
+ max_age=config.max_age, min_hits=config.min_hits,
98
  gp=gp)
99
  else:
100
 
 
101
 
102
+ outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
103
 
104
+ if config.associative_tracker == TrackerType.CONF_BOOST:
 
 
105
 
106
+ do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
107
 
108
+ outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
109
 
110
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
111
 
112
  results = do_tracking(
113
  all_preds, image_meter_width, image_meter_height,
114
+ min_length=config.min_length, min_travel=config.min_travel,
115
+ max_age=config.max_age, min_hits=config.min_hits,
116
  gp=gp)
117
 
118
  return results
lib/fish_eye/tracker.py CHANGED
@@ -123,6 +123,7 @@ class Tracker:
123
 
124
  fish_entry['travel_dist'] = Tracker.get_travel_distance(start_bbox, end_bbox, json_data['image_meter_width'], json_data['image_meter_height'])
125
 
 
126
  fish_entry['start_frame_index'] = boxes[0][1]
127
  fish_entry['end_frame_index'] = boxes[-1][1]
128
  fish_entry['color'] = Tracker.selectColor(track_id)
 
123
 
124
  fish_entry['travel_dist'] = Tracker.get_travel_distance(start_bbox, end_bbox, json_data['image_meter_width'], json_data['image_meter_height'])
125
 
126
+ print(boxes[0])
127
  fish_entry['start_frame_index'] = boxes[0][1]
128
  fish_entry['end_frame_index'] = boxes[-1][1]
129
  fish_entry['color'] = Tracker.selectColor(track_id)
main.py CHANGED
@@ -7,7 +7,7 @@ from dataloader import create_dataloader_aris
7
  from inference import do_full_inference, json_dump_round_float
8
  from visualizer import generate_video_batches
9
 
10
- def predict_task(filepath, hyperparams, gradio_progress=None):
11
  """
12
  Main processing task to be run in gradio
13
  - Writes aris frames to dirname(filepath)/frames/{i}.jpg
@@ -45,12 +45,12 @@ def predict_task(filepath, hyperparams, gradio_progress=None):
45
  frame_rate = dataset.didson.info['framerate']
46
 
47
  # run detection + tracking
48
- results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, hyperparams=hyperparams)
49
 
50
  # re-index results if desired - this should be done before writing the file
51
  results = prep_for_mm(results)
52
  results = add_metadata_to_result(filepath, results)
53
- results['metadata']['hyperparameters'] = hyperparams
54
 
55
  # write output to disk
56
  json_dump_round_float(results, results_filepath)
 
7
  from inference import do_full_inference, json_dump_round_float
8
  from visualizer import generate_video_batches
9
 
10
+ def predict_task(filepath, config, gradio_progress=None):
11
  """
12
  Main processing task to be run in gradio
13
  - Writes aris frames to dirname(filepath)/frames/{i}.jpg
 
45
  frame_rate = dataset.didson.info['framerate']
46
 
47
  # run detection + tracking
48
+ results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, config=config)
49
 
50
  # re-index results if desired - this should be done before writing the file
51
  results = prep_for_mm(results)
52
  results = add_metadata_to_result(filepath, results)
53
+ results['metadata']['hyperparameters'] = config.to_dict()
54
 
55
  # write output to disk
56
  json_dump_round_float(results, results_filepath)
multipage_pdf.pdf CHANGED
Binary files a/multipage_pdf.pdf and b/multipage_pdf.pdf differ