Spaces:
Runtime error
Runtime error
oskarastrom
commited on
Commit
•
711b619
1
Parent(s):
31f0f25
Boost parameters
Browse files- inference.py +5 -5
- scripts/track_detection.py +3 -1
- scripts/track_eval.py +4 -0
inference.py
CHANGED
@@ -237,7 +237,7 @@ def format_predictions(image_shapes, outputs, width, height, gp=None, batch_size
|
|
237 |
|
238 |
return all_preds, real_width, real_height
|
239 |
|
240 |
-
def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, verbose=True):
|
241 |
"""
|
242 |
Args:
|
243 |
frames_dir: a directory containing frames to be evaluated
|
@@ -275,7 +275,7 @@ def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, v
|
|
275 |
next_frame = inference[batch_i + 1][0]
|
276 |
|
277 |
if next_frame != None:
|
278 |
-
boost_frame(safe_frame, next_frame, 1)
|
279 |
|
280 |
prev_frame = None
|
281 |
if i-1 >= 0:
|
@@ -284,18 +284,18 @@ def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, v
|
|
284 |
prev_frame = inference[batch_i - 1][len(inference[batch_i - 1]) - 1]
|
285 |
|
286 |
if prev_frame != None:
|
287 |
-
boost_frame(safe_frame, prev_frame, -1)
|
288 |
|
289 |
pbar.update(1*batch_size)
|
290 |
|
291 |
|
292 |
-
def boost_frame(safe_frame, base_frame, dt, decay=1):
|
293 |
safe_boxes = safe_frame[:, :4]
|
294 |
boxes = xywh2xyxy(base_frame[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
295 |
ious = box_iou(boxes, safe_boxes)
|
296 |
score = torch.matmul(ious, safe_frame[:, 4])
|
297 |
# score = iou(safe_box, base_box) * confidence(safe_box)
|
298 |
-
base_frame[:, 4] *= 1 + (score)*math.exp(-decay*dt*dt)
|
299 |
return base_frame
|
300 |
|
301 |
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
|
|
|
237 |
|
238 |
return all_preds, real_width, real_height
|
239 |
|
240 |
+
def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, conf_power=1, conf_decay=1, verbose=True):
|
241 |
"""
|
242 |
Args:
|
243 |
frames_dir: a directory containing frames to be evaluated
|
|
|
275 |
next_frame = inference[batch_i + 1][0]
|
276 |
|
277 |
if next_frame != None:
|
278 |
+
boost_frame(safe_frame, next_frame, 1, decay=conf_decay)
|
279 |
|
280 |
prev_frame = None
|
281 |
if i-1 >= 0:
|
|
|
284 |
prev_frame = inference[batch_i - 1][len(inference[batch_i - 1]) - 1]
|
285 |
|
286 |
if prev_frame != None:
|
287 |
+
boost_frame(safe_frame, prev_frame, -1, power=conf_power, decay=conf_decay)
|
288 |
|
289 |
pbar.update(1*batch_size)
|
290 |
|
291 |
|
292 |
+
def boost_frame(safe_frame, base_frame, dt, power=1, decay=1):
|
293 |
safe_boxes = safe_frame[:, :4]
|
294 |
boxes = xywh2xyxy(base_frame[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
295 |
ious = box_iou(boxes, safe_boxes)
|
296 |
score = torch.matmul(ious, safe_frame[:, 4])
|
297 |
# score = iou(safe_box, base_box) * confidence(safe_box)
|
298 |
+
base_frame[:, 4] *= 1 + power*(score)*math.exp(-decay*dt*dt)
|
299 |
return base_frame
|
300 |
|
301 |
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
|
scripts/track_detection.py
CHANGED
@@ -35,6 +35,8 @@ def main(args, config={}, verbose=True):
|
|
35 |
if "iou_threshold" not in config: config['iou_threshold'] = 0.01
|
36 |
if "min_hits" not in config: config['min_hits'] = 11
|
37 |
if "use_associative" not in config: config['use_associative'] = False
|
|
|
|
|
38 |
|
39 |
print(config)
|
40 |
|
@@ -102,7 +104,7 @@ def track(in_loc_dir, out_loc_dir, metadata_path, seq, config, verbose):
|
|
102 |
|
103 |
if config['use_associative']:
|
104 |
|
105 |
-
do_confidence_boost(inference, outputs, verbose=verbose)
|
106 |
|
107 |
outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)
|
108 |
|
|
|
35 |
if "iou_threshold" not in config: config['iou_threshold'] = 0.01
|
36 |
if "min_hits" not in config: config['min_hits'] = 11
|
37 |
if "use_associative" not in config: config['use_associative'] = False
|
38 |
+
if "boost_power" not in config: config['boost_power'] = 1
|
39 |
+
if "boost_decay" not in config: config['boost_decay'] = 1
|
40 |
|
41 |
print(config)
|
42 |
|
|
|
104 |
|
105 |
if config['use_associative']:
|
106 |
|
107 |
+
do_confidence_boost(inference, outputs, conf_power=config['boost_power'], conf_decay=config['boost_decay'], verbose=verbose)
|
108 |
|
109 |
outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)
|
110 |
|
scripts/track_eval.py
CHANGED
@@ -25,6 +25,8 @@ def main(args):
|
|
25 |
'max_age': int(args.max_age),
|
26 |
'iou_threshold': float(args.iou_threshold),
|
27 |
'min_hits': int(args.min_hits),
|
|
|
|
|
28 |
'use_associative': args.use_associative
|
29 |
}
|
30 |
|
@@ -42,6 +44,8 @@ def argument_parser():
|
|
42 |
parser.add_argument("--max_age", default=20, help="Config object. Required.")
|
43 |
parser.add_argument("--iou_threshold", default=0.01, help="Config object. Required.")
|
44 |
parser.add_argument("--min_hits", default=11, help="Config object. Required.")
|
|
|
|
|
45 |
parser.add_argument("--use_associative", action='store_true', help="Config object. Required.")
|
46 |
return parser
|
47 |
|
|
|
25 |
'max_age': int(args.max_age),
|
26 |
'iou_threshold': float(args.iou_threshold),
|
27 |
'min_hits': int(args.min_hits),
|
28 |
+
'boost_power': int(args.boost_power),
|
29 |
+
'boost_decay': int(args.boost_decay),
|
30 |
'use_associative': args.use_associative
|
31 |
}
|
32 |
|
|
|
44 |
parser.add_argument("--max_age", default=20, help="Config object. Required.")
|
45 |
parser.add_argument("--iou_threshold", default=0.01, help="Config object. Required.")
|
46 |
parser.add_argument("--min_hits", default=11, help="Config object. Required.")
|
47 |
+
parser.add_argument("--boost_power", default=1, help="Config object. Required.")
|
48 |
+
parser.add_argument("--boost_decay", default=1, help="Config object. Required.")
|
49 |
parser.add_argument("--use_associative", action='store_true', help="Config object. Required.")
|
50 |
return parser
|
51 |
|