oskarastrom commited on
Commit
93b9874
1 Parent(s): 6bfcb22

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +15 -9
inference.py CHANGED
@@ -64,11 +64,11 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
64
  if config.associative_tracker == TrackerType.BYTETRACK:
65
 
66
  # Find low confidence detections
67
- low_outputs = do_suppression(inference, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
68
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
69
 
70
  # Find high confidence detections
71
- high_outputs = do_suppression(inference, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
72
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
73
 
74
  # Perform associative tracking (ByteTrack)
@@ -80,7 +80,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
80
  else:
81
 
82
  # Find confident detections
83
- outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
84
 
85
  if config.associative_tracker == TrackerType.CONF_BOOST:
86
 
@@ -88,7 +88,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
88
  do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
89
 
90
  # Find confident detections from boosted list
91
- outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
92
 
93
  # Format confident detections
94
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
@@ -169,7 +169,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verb
169
 
170
  return inference, image_shapes, width, height
171
 
172
- def do_suppression(inference, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, verbose=True):
173
  """
174
  Args:
175
  frames_dir: a directory containing frames to be evaluated
@@ -188,7 +188,7 @@ def do_suppression(inference, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_TH
188
  if gp: gp(batch_i / len(inference), pbar.__str__())
189
 
190
  with torch.no_grad():
191
- output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
192
 
193
 
194
  outputs.append(output)
@@ -422,13 +422,13 @@ def filter_detection_size(inference, image_meter_width, width, max_length):
422
  print(wc.shape)
423
  bs = batch.shape[0] # batches
424
 
425
- output = [torch.zeros((0, 6), device=batch.device)] * bs
426
  print("wc")
427
  print(batch.shape)
428
  for xi, x in enumerate(batch):
429
  x = x[wc[xi]] # confidence
430
  print(x.shape)
431
- output[xi] = x
432
 
433
  output = torch.tensor(output)
434
  print("output len", output.shape)
@@ -439,6 +439,9 @@ def filter_detection_size(inference, image_meter_width, width, max_length):
439
 
440
  def non_max_suppression(
441
  prediction,
 
 
 
442
  conf_thres=0.25,
443
  iou_thres=0.45,
444
  max_det=300
@@ -463,6 +466,9 @@ def non_max_suppression(
463
  prediction = prediction.cpu()
464
  bs = prediction.shape[0] # batch size
465
  xc = prediction[..., 4] > conf_thres # candidates
 
 
 
466
 
467
 
468
  # Settings
@@ -476,7 +482,7 @@ def non_max_suppression(
476
 
477
 
478
  # Keep boxes that pass confidence threshold
479
- x = x[xc[xi]] # confidence
480
 
481
  # If none remain process next image
482
  if not x.shape[0]:
 
64
  if config.associative_tracker == TrackerType.BYTETRACK:
65
 
66
  # Find low confidence detections
67
+ low_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
68
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
69
 
70
  # Find high confidence detections
71
+ high_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
72
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
73
 
74
  # Perform associative tracking (ByteTrack)
 
80
  else:
81
 
82
  # Find confident detections
83
+ outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
84
 
85
  if config.associative_tracker == TrackerType.CONF_BOOST:
86
 
 
88
  do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
89
 
90
  # Find confident detections from boosted list
91
+ outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
92
 
93
  # Format confident detections
94
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
 
169
 
170
  return inference, image_shapes, width, height
171
 
172
+ def do_suppression(inference, image_meter_width, image_pixel_width, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, max_length=1.5, verbose=True):
173
  """
174
  Args:
175
  frames_dir: a directory containing frames to be evaluated
 
188
  if gp: gp(batch_i / len(inference), pbar.__str__())
189
 
190
  with torch.no_grad():
191
+ output = non_max_suppression(inf_out, image_meter_width, image_pixel_width, conf_thres=conf_thres, iou_thres=iou_thres, max_length=max_length)
192
 
193
 
194
  outputs.append(output)
 
422
  print(wc.shape)
423
  bs = batch.shape[0] # batches
424
 
425
+ output = torch.zeros((bs, 0, 6), device=batch.device)
426
  print("wc")
427
  print(batch.shape)
428
  for xi, x in enumerate(batch):
429
  x = x[wc[xi]] # confidence
430
  print(x.shape)
431
+ output[xi, :, :] = x
432
 
433
  output = torch.tensor(output)
434
  print("output len", output.shape)
 
439
 
440
  def non_max_suppression(
441
  prediction,
442
+ image_meter_width,
443
+ image_pixel_width,
444
+ max_length=1.5,
445
  conf_thres=0.25,
446
  iou_thres=0.45,
447
  max_det=300
 
466
  prediction = prediction.cpu()
467
  bs = prediction.shape[0] # batch size
468
  xc = prediction[..., 4] > conf_thres # candidates
469
+ pix2width = image_meter_width/width
470
+ width = prediction[..., 2]*pix2width
471
+ wc = width < max_length
472
 
473
 
474
  # Settings
 
482
 
483
 
484
  # Keep boxes that pass confidence threshold
485
+ x = x[xc[xi] * wc[xi]] # confidence
486
 
487
  # If none remain process next image
488
  if not x.shape[0]: