glenn-jocher commited on
Commit
26bfd44
1 Parent(s): 7a2a118

Adjust NMS time limit warning to batch size (#7156)

Browse files
Files changed (1) hide show
  1. utils/general.py +7 -4
utils/general.py CHANGED
@@ -709,6 +709,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
709
  list of detections, on (n,6) tensor per image [xyxy, conf, cls]
710
  """
711
 
 
712
  nc = prediction.shape[2] - 5 # number of classes
713
  xc = prediction[..., 4] > conf_thres # candidates
714
 
@@ -719,13 +720,13 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
719
  # Settings
720
  min_wh, max_wh = 2, 7680 # (pixels) minimum and maximum box width and height
721
  max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
722
- time_limit = 10.0 # seconds to quit after
723
  redundant = True # require redundant detections
724
  multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
725
  merge = False # use merge-NMS
726
 
727
- t = time.time()
728
- output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
729
  for xi, x in enumerate(prediction): # image index, image inference
730
  # Apply constraints
731
  x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
@@ -789,7 +790,9 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
789
 
790
  output[xi] = x[i]
791
  if (time.time() - t) > time_limit:
792
- LOGGER.warning(f'WARNING: NMS time limit {time_limit}s exceeded')
 
 
793
  break # time limit exceeded
794
 
795
  return output
 
709
  list of detections, on (n,6) tensor per image [xyxy, conf, cls]
710
  """
711
 
712
+ bs = prediction.shape[0] # batch size
713
  nc = prediction.shape[2] - 5 # number of classes
714
  xc = prediction[..., 4] > conf_thres # candidates
715
 
 
720
  # Settings
721
  min_wh, max_wh = 2, 7680 # (pixels) minimum and maximum box width and height
722
  max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
723
+ time_limit = 0.030 * bs # seconds to quit after
724
  redundant = True # require redundant detections
725
  multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
726
  merge = False # use merge-NMS
727
 
728
+ t, warn_time = time.time(), True
729
+ output = [torch.zeros((0, 6), device=prediction.device)] * bs
730
  for xi, x in enumerate(prediction): # image index, image inference
731
  # Apply constraints
732
  x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
 
790
 
791
  output[xi] = x[i]
792
  if (time.time() - t) > time_limit:
793
+ if warn_time:
794
+ LOGGER.warning(f'WARNING: NMS time limit {time_limit:3f}s exceeded')
795
+ warn_time = False
796
  break # time limit exceeded
797
 
798
  return output