glenn-jocher commited on
Commit
9fdb0fb
1 Parent(s): 8b6f582

AutoAnchor bug fix # 117

Browse files
Files changed (2) hide show
  1. train.py +3 -1
  2. utils/utils.py +8 -5
train.py CHANGED
@@ -200,7 +200,8 @@ def train(hyp):
200
  tb_writer.add_histogram('classes', c, 0)
201
 
202
  # Check anchors
203
- check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
204
 
205
  # Exponential moving average
206
  ema = torch_utils.ModelEMA(model)
@@ -374,6 +375,7 @@ if __name__ == '__main__':
374
  parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
375
  parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
376
  parser.add_argument('--notest', action='store_true', help='only test final epoch')
 
377
  parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
378
  parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
379
  parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
 
200
  tb_writer.add_histogram('classes', c, 0)
201
 
202
  # Check anchors
203
+ if not opt.noautoanchor:
204
+ check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
205
 
206
  # Exponential moving average
207
  ema = torch_utils.ModelEMA(model)
 
375
  parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
376
  parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
377
  parser.add_argument('--notest', action='store_true', help='only test final epoch')
378
+ parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
379
  parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
380
  parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
381
  parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
utils/utils.py CHANGED
@@ -56,7 +56,7 @@ def check_img_size(img_size, s=32):
56
  def check_anchors(dataset, model, thr=4.0, imgsz=640):
57
  # Check anchor fit to data, recompute if necessary
58
  print('\nAnalyzing anchors... ', end='')
59
- anchors = model.module.model[-1].anchor_grid if hasattr(model, 'module') else model.model[-1].anchor_grid
60
  shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
61
  wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
62
 
@@ -66,14 +66,17 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
66
  best = x.max(1)[0] # best_x
67
  return (best > 1. / thr).float().mean() #  best possible recall
68
 
69
- bpr = metric(anchors.clone().cpu().view(-1, 2))
70
  print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
71
  if bpr < 0.99: # threshold to recompute
72
  print('. Attempting to generate improved anchors, please wait...' % bpr)
73
- new_anchors = kmean_anchors(dataset, n=anchors.numel() // 2, img_size=imgsz, thr=thr, gen=1000, verbose=False)
 
74
  new_bpr = metric(new_anchors.reshape(-1, 2))
75
- if new_bpr > bpr:
76
- anchors[:] = torch.tensor(new_anchors).view_as(anchors).type_as(anchors)
 
 
77
  print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
78
  else:
79
  print('Original anchors better than new anchors. Proceeding with original anchors.')
 
56
  def check_anchors(dataset, model, thr=4.0, imgsz=640):
57
  # Check anchor fit to data, recompute if necessary
58
  print('\nAnalyzing anchors... ', end='')
59
+ m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
60
  shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
61
  wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
62
 
 
66
  best = x.max(1)[0] # best_x
67
  return (best > 1. / thr).float().mean() #  best possible recall
68
 
69
+ bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
70
  print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
71
  if bpr < 0.99: # threshold to recompute
72
  print('. Attempting to generate improved anchors, please wait...' % bpr)
73
+ na = m.anchor_grid.numel() // 2 # number of anchors
74
+ new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
75
  new_bpr = metric(new_anchors.reshape(-1, 2))
76
+ if new_bpr > bpr: # replace anchors
77
+ new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
78
+ m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
79
+ m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
80
  print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
81
  else:
82
  print('Original anchors better than new anchors. Proceeding with original anchors.')