jebastin-nadar glenn-jocher committed on
Commit
9d75e42
1 Parent(s): 153873e

Refactor `Detect()` anchors for ONNX <> OpenCV DNN compatibility (#4833)

Browse files

* refactor anchors and anchor_grid in Detect Layer

* fix CI failures by adding compatibility

* fix tf failure

* fix different devices errors

* Cleanup

* fix anchors overwriting issue

* better refactoring

* Remove self.anchor_grid shape check (redundant with self.grid check)

Also PEP8 / 120 line width

* Convert _make_grid() from static to dynamic method

* Remove anchor_grid.to(device)

clone() should already clone to the same device as self.anchors

* fix different devices error

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

models/common.py CHANGED
@@ -295,6 +295,8 @@ class AutoShape(nn.Module):
295
  m = self.model.model[-1] # Detect()
296
  m.stride = fn(m.stride)
297
  m.grid = list(map(fn, m.grid))
 
 
298
  return self
299
 
300
  @torch.no_grad()
 
295
  m = self.model.model[-1] # Detect()
296
  m.stride = fn(m.stride)
297
  m.grid = list(map(fn, m.grid))
298
+ if isinstance(m.anchor_grid, list):
299
+ m.anchor_grid = list(map(fn, m.anchor_grid))
300
  return self
301
 
302
  @torch.no_grad()
models/experimental.py CHANGED
@@ -102,6 +102,10 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
102
  for m in model.modules():
103
  if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
104
  m.inplace = inplace # pytorch 1.7.0 compatibility
 
 
 
 
105
  elif type(m) is Conv:
106
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
107
 
 
102
  for m in model.modules():
103
  if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
104
  m.inplace = inplace # pytorch 1.7.0 compatibility
105
+ if type(m) is Detect:
106
+ if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
107
+ delattr(m, 'anchor_grid')
108
+ setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
109
  elif type(m) is Conv:
110
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
111
 
models/tf.py CHANGED
@@ -193,7 +193,7 @@ class TFDetect(keras.layers.Layer):
193
  self.na = len(anchors[0]) // 2 # number of anchors
194
  self.grid = [tf.zeros(1)] * self.nl # init grid
195
  self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
196
- self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
197
  [self.nl, 1, -1, 1, 2])
198
  self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
199
  self.training = False # set to False after building model
 
193
  self.na = len(anchors[0]) // 2 # number of anchors
194
  self.grid = [tf.zeros(1)] * self.nl # init grid
195
  self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
196
+ self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),
197
  [self.nl, 1, -1, 1, 2])
198
  self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
199
  self.training = False # set to False after building model
models/yolo.py CHANGED
@@ -44,9 +44,8 @@ class Detect(nn.Module):
44
  self.nl = len(anchors) # number of detection layers
45
  self.na = len(anchors[0]) // 2 # number of anchors
46
  self.grid = [torch.zeros(1)] * self.nl # init grid
47
- a = torch.tensor(anchors).float().view(self.nl, -1, 2)
48
- self.register_buffer('anchors', a) # shape(nl,na,2)
49
- self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
50
  self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
51
  self.inplace = inplace # use in-place ops (e.g. slice assignment)
52
 
@@ -59,7 +58,7 @@ class Detect(nn.Module):
59
 
60
  if not self.training: # inference
61
  if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
62
- self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
63
 
64
  y = x[i].sigmoid()
65
  if self.inplace:
@@ -67,16 +66,19 @@ class Detect(nn.Module):
67
  y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
68
  else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
69
  xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
70
- wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
71
  y = torch.cat((xy, wh, y[..., 4:]), -1)
72
  z.append(y.view(bs, -1, self.no))
73
 
74
  return x if self.training else (torch.cat(z, 1), x)
75
 
76
- @staticmethod
77
- def _make_grid(nx=20, ny=20):
78
- yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
79
- return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
 
 
 
80
 
81
 
82
  class Model(nn.Module):
@@ -239,6 +241,8 @@ class Model(nn.Module):
239
  if isinstance(m, Detect):
240
  m.stride = fn(m.stride)
241
  m.grid = list(map(fn, m.grid))
 
 
242
  return self
243
 
244
 
 
44
  self.nl = len(anchors) # number of detection layers
45
  self.na = len(anchors[0]) // 2 # number of anchors
46
  self.grid = [torch.zeros(1)] * self.nl # init grid
47
+ self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
48
+ self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
 
49
  self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
50
  self.inplace = inplace # use in-place ops (e.g. slice assignment)
51
 
 
58
 
59
  if not self.training: # inference
60
  if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
61
+ self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
62
 
63
  y = x[i].sigmoid()
64
  if self.inplace:
 
66
  y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
67
  else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
68
  xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
69
+ wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
70
  y = torch.cat((xy, wh, y[..., 4:]), -1)
71
  z.append(y.view(bs, -1, self.no))
72
 
73
  return x if self.training else (torch.cat(z, 1), x)
74
 
75
+ def _make_grid(self, nx=20, ny=20, i=0):
76
+ d = self.anchors[i].device
77
+ yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
78
+ grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
79
+ anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
80
+ .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
81
+ return grid, anchor_grid
82
 
83
 
84
  class Model(nn.Module):
 
241
  if isinstance(m, Detect):
242
  m.stride = fn(m.stride)
243
  m.grid = list(map(fn, m.grid))
244
+ if isinstance(m.anchor_grid, list):
245
+ m.anchor_grid = list(map(fn, m.anchor_grid))
246
  return self
247
 
248
 
utils/autoanchor.py CHANGED
@@ -15,13 +15,12 @@ from utils.general import colorstr
15
 
16
  def check_anchor_order(m):
17
  # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
18
- a = m.anchor_grid.prod(-1).view(-1) # anchor area
19
  da = a[-1] - a[0] # delta a
20
  ds = m.stride[-1] - m.stride[0] # delta s
21
  if da.sign() != ds.sign(): # same order
22
  print('Reversing anchor order')
23
  m.anchors[:] = m.anchors.flip(0)
24
- m.anchor_grid[:] = m.anchor_grid.flip(0)
25
 
26
 
27
  def check_anchors(dataset, model, thr=4.0, imgsz=640):
@@ -41,12 +40,12 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
41
  bpr = (best > 1. / thr).float().mean() # best possible recall
42
  return bpr, aat
43
 
44
- anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors
45
- bpr, aat = metric(anchors)
46
  print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
47
  if bpr < 0.98: # threshold to recompute
48
  print('. Attempting to improve anchors, please wait...')
49
- na = m.anchor_grid.numel() // 2 # number of anchors
50
  try:
51
  anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
52
  except Exception as e:
@@ -54,7 +53,6 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
54
  new_bpr = metric(anchors)[0]
55
  if new_bpr > bpr: # replace anchors
56
  anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
57
- m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference
58
  m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
59
  check_anchor_order(m)
60
  print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
 
15
 
16
  def check_anchor_order(m):
17
  # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
18
+ a = m.anchors.prod(-1).view(-1) # anchor area
19
  da = a[-1] - a[0] # delta a
20
  ds = m.stride[-1] - m.stride[0] # delta s
21
  if da.sign() != ds.sign(): # same order
22
  print('Reversing anchor order')
23
  m.anchors[:] = m.anchors.flip(0)
 
24
 
25
 
26
  def check_anchors(dataset, model, thr=4.0, imgsz=640):
 
40
  bpr = (best > 1. / thr).float().mean() # best possible recall
41
  return bpr, aat
42
 
43
+ anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors
44
+ bpr, aat = metric(anchors.cpu().view(-1, 2))
45
  print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
46
  if bpr < 0.98: # threshold to recompute
47
  print('. Attempting to improve anchors, please wait...')
48
+ na = m.anchors.numel() // 2 # number of anchors
49
  try:
50
  anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
51
  except Exception as e:
 
53
  new_bpr = metric(anchors)[0]
54
  if new_bpr > bpr: # replace anchors
55
  anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
 
56
  m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
57
  check_anchor_order(m)
58
  print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')