oskarastrom commited on
Commit
cef04ce
1 Parent(s): a77f5fd

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +6 -6
inference.py CHANGED
@@ -52,7 +52,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
52
 
53
  # Load hyperparameters
54
  if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
55
- if 'conf_thresh' not in hyperparams: hyperparams['conf_tresh'] = CONF_THRES
56
  if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
57
  if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
58
  if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
@@ -87,13 +87,13 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
87
  return
88
 
89
 
90
- outputs = do_suppression(inference, conf_thres=hyperparams['conf_tresh'], iou_thres=hyperparams['iou_tresh'], gp=gp)
91
 
92
  if hyperparams['use_associative_tracking']:
93
-
94
  do_confidence_boost(inference, outputs, gp=gp)
95
 
96
- outputs = do_suppression(inference, conf_thres=hyperparams['conf_tresh'], iou_thres=hyperparams['iou_tresh'], gp=gp)
97
 
98
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
99
 
@@ -288,13 +288,13 @@ def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, v
288
  pbar.update(1*batch_size)
289
 
290
 
291
- def boost_frame(safe_frame, base_frame, dt):
292
  safe_boxes = safe_frame[:, :4]
293
  boxes = xywh2xyxy(base_frame[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
294
  ious = box_iou(boxes, safe_boxes)
295
  score = torch.matmul(ious, safe_frame[:, 4])
296
  # score = iou(safe_box, base_box) * confidence(safe_box)
297
- base_frame[:, 4] *= 1 + (score)*math.exp(-dt*dt)
298
  return base_frame
299
 
300
  def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
 
52
 
53
  # Load hyperparameters
54
  if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
55
+ if 'conf_thresh' not in hyperparams: hyperparams['conf_thresh'] = CONF_THRES
56
  if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
57
  if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
58
  if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
 
87
  return
88
 
89
 
90
+ outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
91
 
92
  if hyperparams['use_associative_tracking']:
93
+
94
  do_confidence_boost(inference, outputs, gp=gp)
95
 
96
+ outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
97
 
98
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
99
 
 
288
  pbar.update(1*batch_size)
289
 
290
 
291
+ def boost_frame(safe_frame, base_frame, dt, decay=1):
292
  safe_boxes = safe_frame[:, :4]
293
  boxes = xywh2xyxy(base_frame[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
294
  ious = box_iou(boxes, safe_boxes)
295
  score = torch.matmul(ious, safe_frame[:, 4])
296
  # score = iou(safe_box, base_box) * confidence(safe_box)
297
+ base_frame[:, 4] *= 1 + (score)*math.exp(-decay*dt*dt)
298
  return base_frame
299
 
300
  def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):