glenn-jocher commited on
Commit
fc171e2
1 Parent(s): 1f1917e

check_anchor_order() update

Browse files
Files changed (2) hide show
  1. models/yolo.py +2 -1
  2. utils/utils.py +13 -1
models/yolo.py CHANGED
@@ -61,8 +61,9 @@ class Model(nn.Module):
61
 
62
  # Build strides, anchors
63
  m = self.model[-1] # Detect()
64
- m.stride = torch.tensor([64 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 64, 64))]) # forward
65
  m.anchors /= m.stride.view(-1, 1, 1)
 
66
  self.stride = m.stride
67
 
68
  # Init weights, biases
 
61
 
62
  # Build strides, anchors
63
  m = self.model[-1] # Detect()
64
+ m.stride = torch.tensor([128 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 128, 128))]) # forward
65
  m.anchors /= m.stride.view(-1, 1, 1)
66
+ check_anchor_order(m)
67
  self.stride = m.stride
68
 
69
  # Init weights, biases
utils/utils.py CHANGED
@@ -58,7 +58,8 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
58
  print('\nAnalyzing anchors... ', end='')
59
  m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
60
  shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
61
- wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
 
62
 
63
  def metric(k): # compute metric
64
  r = wh[:, None] / k[None]
@@ -77,12 +78,23 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
77
  new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
78
  m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
79
  m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
 
80
  print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
81
  else:
82
  print('Original anchors better than new anchors. Proceeding with original anchors.')
83
  print('') # newline
84
 
85
 
 
 
 
 
 
 
 
 
 
 
86
  def check_file(file):
87
  # Searches for file if not found locally
88
  if os.path.isfile(file):
 
58
  print('\nAnalyzing anchors... ', end='')
59
  m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
60
  shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
61
+ scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
62
+ wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
63
 
64
  def metric(k): # compute metric
65
  r = wh[:, None] / k[None]
 
78
  new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
79
  m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
80
  m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
81
+ check_anchor_order(m)
82
  print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
83
  else:
84
  print('Original anchors better than new anchors. Proceeding with original anchors.')
85
  print('') # newline
86
 
87
 
88
+ def check_anchor_order(m):
89
+ # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
90
+ a = m.anchor_grid.prod(-1).view(-1) # anchor area
91
+ da = a[-1] - a[0] # delta a
92
+ ds = m.stride[-1] - m.stride[0] # delta s
93
+ if da.sign() != ds.sign(): # same order
94
+ m.anchors[:] = m.anchors.flip(0)
95
+ m.anchor_grid[:] = m.anchor_grid.flip(0)
96
+
97
+
98
  def check_file(file):
99
  # Searches for file if not found locally
100
  if os.path.isfile(file):