Spaces:
Runtime error
Runtime error
oskarastrom
commited on
Commit
•
2a572c2
1
Parent(s):
99331d9
Config file
Browse files- InferenceConfig.py +73 -0
- app.py +17 -17
- aris.py +1 -0
- inference.py +15 -28
- lib/fish_eye/tracker.py +1 -0
- main.py +3 -3
- multipage_pdf.pdf +0 -0
InferenceConfig.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
### Configuration options
|
4 |
+
WEIGHTS = 'models/v5m_896_300best.pt'
|
5 |
+
# will need to configure these based on GPU hardware
|
6 |
+
BATCH_SIZE = 32
|
7 |
+
|
8 |
+
CONF_THRES = 0.05 # detection
|
9 |
+
NMS_IOU = 0.2 # NMS IOU
|
10 |
+
MAX_AGE = 14 # time until missing fish get's new id
|
11 |
+
MIN_HITS = 16 # minimum number of frames with a specific fish for it to count
|
12 |
+
MIN_LENGTH = 0.3 # minimum fish length, in meters
|
13 |
+
IOU_THRES = 0.01 # IOU threshold for tracking
|
14 |
+
MIN_TRAVEL = -1 # Minimum distance a track has to travel
|
15 |
+
|
16 |
+
class TrackerType(Enum):
|
17 |
+
NONE = 0
|
18 |
+
CONF_BOOST = 1
|
19 |
+
BYTETRACK = 2
|
20 |
+
|
21 |
+
class InferenceConfig:
|
22 |
+
def __init__(self,
|
23 |
+
weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
|
24 |
+
min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL):
|
25 |
+
self.weights = weights
|
26 |
+
self.conf_thresh = conf_thresh
|
27 |
+
self.nms_iou = nms_iou
|
28 |
+
self.min_hits = min_hits
|
29 |
+
self.max_age = max_age
|
30 |
+
self.min_length = min_length
|
31 |
+
self.min_travel = min_travel
|
32 |
+
|
33 |
+
self.associative_tracker = TrackerType.NONE
|
34 |
+
self.boost_power = 1
|
35 |
+
self.boost_decay = 1
|
36 |
+
self.byte_low_conf = 1
|
37 |
+
self.byte_high_conf = 1
|
38 |
+
|
39 |
+
def enable_conf_boost(self, power, decay):
|
40 |
+
self.associative_tracker = TrackerType.CONF_BOOST
|
41 |
+
self.boost_power = power
|
42 |
+
self.boost_decay = decay
|
43 |
+
|
44 |
+
def enable_byte_track(self, low, high):
|
45 |
+
self.associative_tracker = TrackerType.BYTETRACK
|
46 |
+
self.byte_low_conf = low
|
47 |
+
self.byte_high_conf = high
|
48 |
+
|
49 |
+
def to_dict(self):
|
50 |
+
dict = {
|
51 |
+
'weights': self.weights,
|
52 |
+
'nms_iou': self.nms_iou,
|
53 |
+
'min_hits': self.min_hits,
|
54 |
+
'max_age': self.max_age,
|
55 |
+
'min_length': self.min_length,
|
56 |
+
'min_travel': self.min_travel,
|
57 |
+
}
|
58 |
+
|
59 |
+
# Add tracker specific parameters
|
60 |
+
if (self.associative_tracker == TrackerType.BYTETRACK):
|
61 |
+
dict['tracker'] = "ByteTrack"
|
62 |
+
dict['byte_low_conf'] = self.byte_low_conf
|
63 |
+
dict['byte_high_conf'] = self.byte_high_conf
|
64 |
+
elif (self.associative_tracker == TrackerType.CONF_BOOST):
|
65 |
+
dict['tracker'] = "Confidence Boost"
|
66 |
+
dict['conf_thresh'] = self.conf_thresh
|
67 |
+
dict['boost_power'] = self.boost_power
|
68 |
+
dict['boost_decay'] = self.boost_decay
|
69 |
+
elif (self.associative_tracker == TrackerType.NONE):
|
70 |
+
dict['tracker'] = "None"
|
71 |
+
dict['conf_thresh'] = self.conf_thresh
|
72 |
+
|
73 |
+
return dict
|
app.py
CHANGED
@@ -13,6 +13,7 @@ from gradio_scripts.upload_ui import Upload_Gradio, models
|
|
13 |
from gradio_scripts.result_ui import Result_Gradio, update_result, table_headers, info_headers, js_update_tab_labels
|
14 |
from dataloader import create_dataloader_aris
|
15 |
from aris import BEAM_WIDTH_DIR
|
|
|
16 |
|
17 |
WEBAPP_VERSION = "1.0"
|
18 |
|
@@ -23,7 +24,7 @@ state = {
|
|
23 |
'total': 1,
|
24 |
'annotation_index': -1,
|
25 |
'frame_index': 0,
|
26 |
-
'
|
27 |
}
|
28 |
result = {}
|
29 |
|
@@ -36,26 +37,25 @@ def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_ag
|
|
36 |
state['files'] = file_list
|
37 |
state['total'] = len(file_list)
|
38 |
state['version'] = WEBAPP_VERSION
|
39 |
-
state['
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
49 |
if (associative_tracker == "Confidence Boost"):
|
50 |
-
state['
|
51 |
-
state['hyperparams']['boost_decay'] = boost_decay
|
52 |
elif (associative_tracker == "ByteTrack"):
|
53 |
-
state['
|
54 |
-
state['hyperparams']['byte_high_conf'] = byte_high_conf
|
55 |
|
56 |
print(" ")
|
57 |
print("Running with:")
|
58 |
-
print(state['
|
59 |
print(" ")
|
60 |
|
61 |
# Update loading_space to start inference on first file
|
@@ -169,7 +169,7 @@ def infer_next(_, progress=gr.Progress()):
|
|
169 |
# Do inference
|
170 |
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
|
171 |
file_path,
|
172 |
-
|
173 |
gradio_progress = set_progress
|
174 |
)
|
175 |
|
|
|
13 |
from gradio_scripts.result_ui import Result_Gradio, update_result, table_headers, info_headers, js_update_tab_labels
|
14 |
from dataloader import create_dataloader_aris
|
15 |
from aris import BEAM_WIDTH_DIR
|
16 |
+
from InferenceConfig import InferenceConfig
|
17 |
|
18 |
WEBAPP_VERSION = "1.0"
|
19 |
|
|
|
24 |
'total': 1,
|
25 |
'annotation_index': -1,
|
26 |
'frame_index': 0,
|
27 |
+
'config': None
|
28 |
}
|
29 |
result = {}
|
30 |
|
|
|
37 |
state['files'] = file_list
|
38 |
state['total'] = len(file_list)
|
39 |
state['version'] = WEBAPP_VERSION
|
40 |
+
state['config'] = InferenceConfig(
|
41 |
+
weights = models[model_id] if model_id in models else models['master'],
|
42 |
+
conf_thresh = conf_thresh,
|
43 |
+
nms_iou = iou_thresh,
|
44 |
+
min_hits = min_hits,
|
45 |
+
max_age = max_age,
|
46 |
+
min_length = min_length,
|
47 |
+
min_travel = min_travel,
|
48 |
+
)
|
49 |
+
|
50 |
+
# Enable tracker if specified
|
51 |
if (associative_tracker == "Confidence Boost"):
|
52 |
+
state['config'].enable_conf_boost(boost_power, boost_decay)
|
|
|
53 |
elif (associative_tracker == "ByteTrack"):
|
54 |
+
state['config'].enable_byte_track(byte_low_conf, byte_high_conf)
|
|
|
55 |
|
56 |
print(" ")
|
57 |
print("Running with:")
|
58 |
+
print(state['config'].to_dict())
|
59 |
print(" ")
|
60 |
|
61 |
# Update loading_space to start inference on first file
|
|
|
169 |
# Do inference
|
170 |
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
|
171 |
file_path,
|
172 |
+
config = state['config'],
|
173 |
gradio_progress = set_progress
|
174 |
)
|
175 |
|
aris.py
CHANGED
@@ -395,6 +395,7 @@ def add_metadata_to_result(aris_fp, json_data, beam_width_dir=BEAM_WIDTH_DIR):
|
|
395 |
fish_entry['FRAME_NUM'] = entry['frame_num']
|
396 |
fish_entry['START_FRAME'] = entry['start_frame_index']
|
397 |
fish_entry['END_FRAME'] = entry['end_frame_index']
|
|
|
398 |
fish_entry['TRAVEL'] = entry['travel_dist']
|
399 |
fish_entry['DIR'] = upstream_motion_map[entry['direction']]
|
400 |
fish_entry['R'] = bin_num * pixel_meter_size + frame.windowstart
|
|
|
395 |
fish_entry['FRAME_NUM'] = entry['frame_num']
|
396 |
fish_entry['START_FRAME'] = entry['start_frame_index']
|
397 |
fish_entry['END_FRAME'] = entry['end_frame_index']
|
398 |
+
fish_entry['NBR_FRAMES'] = entry['end_frame_index'] + 1 - entry['start_frame_index']
|
399 |
fish_entry['TRAVEL'] = entry['travel_dist']
|
400 |
fish_entry['DIR'] = upstream_motion_map[entry['direction']]
|
401 |
fish_entry['R'] = bin_num * pixel_meter_size + frame.windowstart
|
inference.py
CHANGED
@@ -17,6 +17,7 @@ from lib.yolov5.utils.metrics import box_iou
|
|
17 |
import torch
|
18 |
import torchvision
|
19 |
|
|
|
20 |
from lib.fish_eye.tracker import Tracker
|
21 |
from lib.fish_eye.associative import Associate
|
22 |
|
@@ -50,18 +51,9 @@ def norm(bbox, w, h):
|
|
50 |
bb[3] /= h
|
51 |
return bb
|
52 |
|
53 |
-
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None,
|
54 |
-
|
55 |
-
# Load hyperparameters
|
56 |
-
if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
|
57 |
-
if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
|
58 |
-
if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
|
59 |
-
if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
|
60 |
-
if 'min_length' not in hyperparams: hyperparams['min_length'] = MIN_LENGTH
|
61 |
-
if 'min_travel' not in hyperparams: hyperparams['min_travel'] = MIN_TRAVEL
|
62 |
-
if 'associative_tracker' not in hyperparams: hyperparams['associative_tracker'] = "None"
|
63 |
|
64 |
-
model, device = setup_model(
|
65 |
|
66 |
load = False
|
67 |
save = False
|
@@ -89,43 +81,38 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
89 |
return
|
90 |
|
91 |
|
92 |
-
outputs = do_suppression(inference, conf_thres=
|
93 |
|
94 |
-
if
|
95 |
-
if 'byte_low_conf' not in hyperparams: hyperparams['byte_low_conf'] = 0.1
|
96 |
-
if 'byte_high_conf' not in hyperparams: hyperparams['byte_high_conf'] = 0.3
|
97 |
|
98 |
-
low_outputs = do_suppression(inference, conf_thres=
|
99 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
100 |
|
101 |
-
high_outputs = do_suppression(inference, conf_thres=
|
102 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
103 |
|
104 |
results = do_associative_tracking(
|
105 |
low_preds, high_preds, image_meter_width, image_meter_height,
|
106 |
-
reverse=False, min_length=
|
107 |
-
max_age=
|
108 |
gp=gp)
|
109 |
else:
|
110 |
|
111 |
-
if 'conf_thresh' not in hyperparams: hyperparams['conf_thresh'] = CONF_THRES
|
112 |
|
113 |
-
outputs = do_suppression(inference, conf_thres=
|
114 |
|
115 |
-
if
|
116 |
-
if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
|
117 |
-
if 'boost_decay' not in hyperparams: hyperparams['boost_decay'] = 1
|
118 |
|
119 |
-
do_confidence_boost(inference, outputs, boost_power=
|
120 |
|
121 |
-
outputs = do_suppression(inference, conf_thres=
|
122 |
|
123 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
124 |
|
125 |
results = do_tracking(
|
126 |
all_preds, image_meter_width, image_meter_height,
|
127 |
-
min_length=
|
128 |
-
max_age=
|
129 |
gp=gp)
|
130 |
|
131 |
return results
|
|
|
17 |
import torch
|
18 |
import torchvision
|
19 |
|
20 |
+
from InferenceConfig import InferenceConfig, TrackerType
|
21 |
from lib.fish_eye.tracker import Tracker
|
22 |
from lib.fish_eye.associative import Associate
|
23 |
|
|
|
51 |
bb[3] /= h
|
52 |
return bb
|
53 |
|
54 |
+
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, config=InferenceConfig()):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
model, device = setup_model(config.weights)
|
57 |
|
58 |
load = False
|
59 |
save = False
|
|
|
81 |
return
|
82 |
|
83 |
|
84 |
+
outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
85 |
|
86 |
+
if config.associative_tracker == TrackerType.BYTETRACK:
|
|
|
|
|
87 |
|
88 |
+
low_outputs = do_suppression(inference, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
|
89 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
90 |
|
91 |
+
high_outputs = do_suppression(inference, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
|
92 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
93 |
|
94 |
results = do_associative_tracking(
|
95 |
low_preds, high_preds, image_meter_width, image_meter_height,
|
96 |
+
reverse=False, min_length=config.min_length, min_travel=config.min_travel,
|
97 |
+
max_age=config.max_age, min_hits=config.min_hits,
|
98 |
gp=gp)
|
99 |
else:
|
100 |
|
|
|
101 |
|
102 |
+
outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
103 |
|
104 |
+
if config.associative_tracker == TrackerType.CONF_BOOST:
|
|
|
|
|
105 |
|
106 |
+
do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
|
107 |
|
108 |
+
outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
109 |
|
110 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
111 |
|
112 |
results = do_tracking(
|
113 |
all_preds, image_meter_width, image_meter_height,
|
114 |
+
min_length=config.min_length, min_travel=config.min_travel,
|
115 |
+
max_age=config.max_age, min_hits=config.min_hits,
|
116 |
gp=gp)
|
117 |
|
118 |
return results
|
lib/fish_eye/tracker.py
CHANGED
@@ -123,6 +123,7 @@ class Tracker:
|
|
123 |
|
124 |
fish_entry['travel_dist'] = Tracker.get_travel_distance(start_bbox, end_bbox, json_data['image_meter_width'], json_data['image_meter_height'])
|
125 |
|
|
|
126 |
fish_entry['start_frame_index'] = boxes[0][1]
|
127 |
fish_entry['end_frame_index'] = boxes[-1][1]
|
128 |
fish_entry['color'] = Tracker.selectColor(track_id)
|
|
|
123 |
|
124 |
fish_entry['travel_dist'] = Tracker.get_travel_distance(start_bbox, end_bbox, json_data['image_meter_width'], json_data['image_meter_height'])
|
125 |
|
126 |
+
print(boxes[0])
|
127 |
fish_entry['start_frame_index'] = boxes[0][1]
|
128 |
fish_entry['end_frame_index'] = boxes[-1][1]
|
129 |
fish_entry['color'] = Tracker.selectColor(track_id)
|
main.py
CHANGED
@@ -7,7 +7,7 @@ from dataloader import create_dataloader_aris
|
|
7 |
from inference import do_full_inference, json_dump_round_float
|
8 |
from visualizer import generate_video_batches
|
9 |
|
10 |
-
def predict_task(filepath,
|
11 |
"""
|
12 |
Main processing task to be run in gradio
|
13 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
@@ -45,12 +45,12 @@ def predict_task(filepath, hyperparams, gradio_progress=None):
|
|
45 |
frame_rate = dataset.didson.info['framerate']
|
46 |
|
47 |
# run detection + tracking
|
48 |
-
results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress,
|
49 |
|
50 |
# re-index results if desired - this should be done before writing the file
|
51 |
results = prep_for_mm(results)
|
52 |
results = add_metadata_to_result(filepath, results)
|
53 |
-
results['metadata']['hyperparameters'] =
|
54 |
|
55 |
# write output to disk
|
56 |
json_dump_round_float(results, results_filepath)
|
|
|
7 |
from inference import do_full_inference, json_dump_round_float
|
8 |
from visualizer import generate_video_batches
|
9 |
|
10 |
+
def predict_task(filepath, config, gradio_progress=None):
|
11 |
"""
|
12 |
Main processing task to be run in gradio
|
13 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
|
|
45 |
frame_rate = dataset.didson.info['framerate']
|
46 |
|
47 |
# run detection + tracking
|
48 |
+
results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, config=config)
|
49 |
|
50 |
# re-index results if desired - this should be done before writing the file
|
51 |
results = prep_for_mm(results)
|
52 |
results = add_metadata_to_result(filepath, results)
|
53 |
+
results['metadata']['hyperparameters'] = config.to_dict()
|
54 |
|
55 |
# write output to disk
|
56 |
json_dump_round_float(results, results_filepath)
|
multipage_pdf.pdf
CHANGED
Binary files a/multipage_pdf.pdf and b/multipage_pdf.pdf differ
|
|