oskarastrom committed on
Commit
1a4c886
1 Parent(s): c756bdd

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +29 -21
inference.py CHANGED
@@ -59,14 +59,16 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
59
  # Detect boxes in frames
60
  inference, image_shapes, width, height = do_detection(dataloader, model, device, gp=gp)
61
 
 
 
62
  if config.associative_tracker == TrackerType.BYTETRACK:
63
 
64
  # Find low confidence detections
65
- low_outputs = do_suppression(inference, image_meter_width/width, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
66
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
67
 
68
  # Find high confidence detections
69
- high_outputs = do_suppression(inference, image_meter_width/width, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
70
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
71
 
72
  # Perform associative tracking (ByteTrack)
@@ -167,7 +169,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verb
167
 
168
  return inference, image_shapes, width, height
169
 
170
- def do_suppression(inference, pix2w, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, verbose=True):
171
  """
172
  Args:
173
  frames_dir: a directory containing frames to be evaluated
@@ -186,7 +188,7 @@ def do_suppression(inference, pix2w, gp=None, batch_size=BATCH_SIZE, conf_thres=
186
  if gp: gp(batch_i / len(inference), pbar.__str__())
187
 
188
  with torch.no_grad():
189
- output = non_max_suppression(inf_out, pix2w, conf_thres=conf_thres, iou_thres=iou_thres)
190
 
191
 
192
  outputs.append(output)
@@ -392,13 +394,32 @@ def json_dump_round_float(some_object, out_path, num_digits=4):
392
  with patch('json.encoder._make_iterencode', wraps=inner):
393
  return json.dump(some_object, open(out_path, 'w'), indent=2)
394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  def non_max_suppression(
396
  prediction,
397
- pix2w,
398
  conf_thres=0.25,
399
  iou_thres=0.45,
400
- max_det=300,
401
- max_length=0.8
402
  ):
403
  """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
404
 
@@ -420,19 +441,6 @@ def non_max_suppression(
420
  prediction = prediction.cpu()
421
  bs = prediction.shape[0] # batch size
422
  xc = prediction[..., 4] > conf_thres # candidates
423
- print("pix2w", pix2w)
424
- width = prediction[..., 2]*pix2w
425
- wc = width < max_length
426
-
427
- print(0.1, sum(sum(width < 0.1)))
428
- print(0.8, sum(sum(width < 0.8)))
429
- print(1.5, sum(sum(width < 1.5)))
430
- print(10, sum(sum(width < 10 )))
431
- print(10, sum(sum(width < 10 )))
432
- print(prediction[1,512,:])
433
- print("mean", torch.mean(width))
434
-
435
-
436
 
437
 
438
  # Settings
@@ -446,7 +454,7 @@ def non_max_suppression(
446
 
447
 
448
  # Keep boxes that pass confidence threshold
449
- x = x[xc[xi] * wc[xi]] # confidence
450
 
451
  # If none remain process next image
452
  if not x.shape[0]:
 
59
  # Detect boxes in frames
60
  inference, image_shapes, width, height = do_detection(dataloader, model, device, gp=gp)
61
 
62
+ inference = filter_detection_size(inference, image_meter_width, width, 1.5)
63
+
64
  if config.associative_tracker == TrackerType.BYTETRACK:
65
 
66
  # Find low confidence detections
67
+ low_outputs = do_suppression(inference, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
68
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
69
 
70
  # Find high confidence detections
71
+ high_outputs = do_suppression(inference, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
72
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
73
 
74
  # Perform associative tracking (ByteTrack)
 
169
 
170
  return inference, image_shapes, width, height
171
 
172
+ def do_suppression(inference, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, verbose=True):
173
  """
174
  Args:
175
  frames_dir: a directory containing frames to be evaluated
 
188
  if gp: gp(batch_i / len(inference), pbar.__str__())
189
 
190
  with torch.no_grad():
191
+ output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
192
 
193
 
194
  outputs.append(output)
 
394
  with patch('json.encoder._make_iterencode', wraps=inner):
395
  return json.dump(some_object, open(out_path, 'w'), indent=2)
396
 
397
+
398
def filter_detection_size(inference, image_meter_width, width, max_length):
    """Remove detections whose real-world width exceeds ``max_length`` meters.

    Args:
        inference: iterable of raw prediction tensors, one per dataloader
            batch; each tensor is (frames, boxes, attrs) with the box width
            in pixels at attribute index 2.
        image_meter_width: width of the imaged area, in meters.
        width: frame width, in pixels.
        max_length: maximum allowed detection width, in meters.

    Returns:
        A list (one entry per batch) of lists (one entry per frame) of
        prediction tensors containing only the rows that passed the filter.
    """
    # Meters-per-pixel is constant for the whole run; compute it once.
    # NOTE: the previous version rebound the parameter `width` to a tensor
    # inside the loop, corrupting this ratio for every batch after the first.
    pix2meter = image_meter_width / width

    outputs = []
    for prediction in inference:
        meter_widths = prediction[..., 2] * pix2meter  # (frames, boxes)
        keep = meter_widths < max_length               # boolean size mask

        # Filter each frame's boxes with its own mask row.
        batch = [prediction[fi][keep[fi]] for fi in range(prediction.shape[0])]
        outputs.append(batch)

    return outputs
417
+
418
  def non_max_suppression(
419
  prediction,
 
420
  conf_thres=0.25,
421
  iou_thres=0.45,
422
+ max_det=300
 
423
  ):
424
  """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
425
 
 
441
  prediction = prediction.cpu()
442
  bs = prediction.shape[0] # batch size
443
  xc = prediction[..., 4] > conf_thres # candidates
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
 
446
  # Settings
 
454
 
455
 
456
  # Keep boxes that pass confidence threshold
457
+ x = x[xc[xi]] # confidence
458
 
459
  # If none remain process next image
460
  if not x.shape[0]: