glenn-jocher commited on
Commit
eb97b2e
1 Parent(s): d97d31e

NMS fast mode

Browse files
Files changed (4) hide show
  1. detect.py +1 -1
  2. test.py +2 -2
  3. train.py +11 -11
  4. utils/utils.py +12 -8
detect.py CHANGED
@@ -76,7 +76,7 @@ def detect(save_img=False):
76
 
77
  # Apply NMS
78
  pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
79
- multi_label=False, classes=opt.classes, agnostic=opt.agnostic_nms)
80
 
81
  # Apply Classifier
82
  if classify:
 
76
 
77
  # Apply NMS
78
  pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
79
+ fast=True, classes=opt.classes, agnostic=opt.agnostic_nms)
80
 
81
  # Apply Classifier
82
  if classify:
test.py CHANGED
@@ -19,7 +19,7 @@ def test(data,
19
  augment=False,
20
  model=None,
21
  dataloader=None,
22
- multi_label=True,
23
  verbose=False): # 0 fast, 1 accurate
24
  # Initialize/load model and set device
25
  if model is None:
@@ -92,7 +92,7 @@ def test(data,
92
 
93
  # Run NMS
94
  t = torch_utils.time_synchronized()
95
- output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, multi_label=multi_label)
96
  t1 += torch_utils.time_synchronized() - t
97
 
98
  # Statistics per image
 
19
  augment=False,
20
  model=None,
21
  dataloader=None,
22
+ fast=False,
23
  verbose=False): # 0 fast, 1 accurate
24
  # Initialize/load model and set device
25
  if model is None:
 
92
 
93
  # Run NMS
94
  t = torch_utils.time_synchronized()
95
+ output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, fast=fast)
96
  t1 += torch_utils.time_synchronized() - t
97
 
98
  # Statistics per image
train.py CHANGED
@@ -293,13 +293,13 @@ def train(hyp):
293
  final_epoch = epoch + 1 == epochs
294
  if not opt.notest or final_epoch: # Calculate mAP
295
  results, maps, times = test.test(opt.data,
296
- batch_size=batch_size,
297
- imgsz=imgsz_test,
298
- save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
299
- model=ema.ema,
300
- single_cls=opt.single_cls,
301
- dataloader=testloader,
302
- multi_label=ni > n_burn)
303
 
304
  # Write
305
  with open(results_file, 'a') as f:
@@ -325,10 +325,10 @@ def train(hyp):
325
  if save:
326
  with open(results_file, 'r') as f: # create checkpoint
327
  ckpt = {'epoch': epoch,
328
- 'best_fitness': best_fitness,
329
- 'training_results': f.read(),
330
- 'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
331
- 'optimizer': None if final_epoch else optimizer.state_dict()}
332
 
333
  # Save last, best and delete
334
  torch.save(ckpt, last)
 
293
  final_epoch = epoch + 1 == epochs
294
  if not opt.notest or final_epoch: # Calculate mAP
295
  results, maps, times = test.test(opt.data,
296
+ batch_size=batch_size,
297
+ imgsz=imgsz_test,
298
+ save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
299
+ model=ema.ema,
300
+ single_cls=opt.single_cls,
301
+ dataloader=testloader,
302
+ fast=ni > n_burn)
303
 
304
  # Write
305
  with open(results_file, 'a') as f:
 
325
  if save:
326
  with open(results_file, 'r') as f: # create checkpoint
327
  ckpt = {'epoch': epoch,
328
+ 'best_fitness': best_fitness,
329
+ 'training_results': f.read(),
330
+ 'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
331
+ 'optimizer': None if final_epoch else optimizer.state_dict()}
332
 
333
  # Save last, best and delete
334
  torch.save(ckpt, last)
utils/utils.py CHANGED
@@ -19,7 +19,7 @@ import torchvision
19
  from scipy.signal import butter, filtfilt
20
  from tqdm import tqdm
21
 
22
- from . import torch_utils, google_utils # torch_utils, google_utils
23
 
24
  # Set printoptions
25
  torch.set_printoptions(linewidth=320, precision=5, profile='long')
@@ -460,29 +460,33 @@ def build_targets(p, targets, model):
460
 
461
  return tcls, tbox, indices, anch
462
 
463
-
464
- def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
465
  """
466
  Performs Non-Maximum Suppression on inference results
467
  Returns detections with shape:
468
  nx6 (x1, y1, x2, y2, conf, cls)
469
  """
 
470
 
471
  # Settings
472
- merge = True # merge for best mAP
473
  min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
474
  max_det = 300 # maximum number of detections per image
475
  time_limit = 10.0 # seconds to quit after
476
- redundant = conf_thres == 0.001 # require redundant detections
 
 
 
 
 
 
 
477
 
478
  t = time.time()
479
- nc = prediction[0].shape[1] - 5 # number of classes
480
- multi_label &= nc > 1 # multiple labels per box
481
  output = [None] * prediction.shape[0]
482
  for xi, x in enumerate(prediction): # image index, image inference
483
  # Apply constraints
 
484
  x = x[x[:, 4] > conf_thres] # confidence
485
- # x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # width-height
486
 
487
  # If none remain process next image
488
  if not x.shape[0]:
 
19
  from scipy.signal import butter, filtfilt
20
  from tqdm import tqdm
21
 
22
+ from . import torch_utils, google_utils #  torch_utils, google_utils
23
 
24
  # Set printoptions
25
  torch.set_printoptions(linewidth=320, precision=5, profile='long')
 
460
 
461
  return tcls, tbox, indices, anch
462
 
463
+ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
 
464
  """
465
  Performs Non-Maximum Suppression on inference results
466
  Returns detections with shape:
467
  nx6 (x1, y1, x2, y2, conf, cls)
468
  """
469
+ nc = prediction[0].shape[1] - 5 # number of classes
470
 
471
  # Settings
 
472
  min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
473
  max_det = 300 # maximum number of detections per image
474
  time_limit = 10.0 # seconds to quit after
475
+ redundant = True # require redundant detections
476
+ fast |= conf_thres > 0.001 # fast mode
477
+ if fast:
478
+ merge = False
479
+ multi_label = False
480
+ else:
481
+ merge = True # merge for best mAP (adds 0.5ms/img)
482
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
483
 
484
  t = time.time()
 
 
485
  output = [None] * prediction.shape[0]
486
  for xi, x in enumerate(prediction): # image index, image inference
487
  # Apply constraints
488
+ # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
489
  x = x[x[:, 4] > conf_thres] # confidence
 
490
 
491
  # If none remain process next image
492
  if not x.shape[0]: