glenn-jocher commited on
Commit
8f17552
1 Parent(s): 3a5c532

ONNX export update

Browse files
Files changed (2) hide show
  1. models/onnx_export.py +7 -5
  2. models/yolo.py +29 -1
models/onnx_export.py CHANGED
@@ -13,25 +13,27 @@ from models.common import *
13
 
14
  if __name__ == '__main__':
15
  parser = argparse.ArgumentParser()
16
- parser.add_argument('--weights', type=str, default='./weights/yolov5s.pt', help='weights path')
17
- parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
18
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
19
  opt = parser.parse_args()
20
  print(opt)
21
 
22
  # Parameters
23
  f = opt.weights.replace('.pt', '.onnx') # onnx filename
24
- img = torch.zeros((opt.batch_size, 3, opt.img_size, opt.img_size)) # image size, (1, 3, 320, 192) iDetection
25
 
26
  # Load pytorch model
27
  google_utils.attempt_download(opt.weights)
28
  model = torch.load(opt.weights)['model']
29
  model.eval()
30
- # model.fuse()
31
 
32
  # Export to onnx
33
  model.model[-1].export = True # set Detect() layer export=True
34
- torch.onnx.export(model, img, f, verbose=False, opset_version=11)
 
 
35
 
36
  # Check onnx model
37
  model = onnx.load(f) # load onnx model
 
13
 
14
  if __name__ == '__main__':
15
  parser = argparse.ArgumentParser()
16
+ parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
17
+ parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
18
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
19
  opt = parser.parse_args()
20
  print(opt)
21
 
22
  # Parameters
23
  f = opt.weights.replace('.pt', '.onnx') # onnx filename
24
+ img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size, (1, 3, 320, 192) iDetection
25
 
26
  # Load pytorch model
27
  google_utils.attempt_download(opt.weights)
28
  model = torch.load(opt.weights)['model']
29
  model.eval()
30
+ model.fuse()
31
 
32
  # Export to onnx
33
  model.model[-1].export = True # set Detect() layer export=True
34
+ _ = model(img) # dry run
35
+ torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
36
+ output_names=['output']) # output_names=['classes', 'boxes']
37
 
38
  # Check onnx model
39
  model = onnx.load(f) # load onnx model
models/yolo.py CHANGED
@@ -20,7 +20,7 @@ class Detect(nn.Module):
20
  self.export = False # onnx export
21
 
22
  def forward(self, x):
23
- x = x.copy() # for profiling
24
  z = [] # inference output
25
  self.training |= self.export
26
  for i in range(self.nl):
@@ -38,6 +38,34 @@ class Detect(nn.Module):
38
 
39
  return x if self.training else (torch.cat(z, 1), x)
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  @staticmethod
42
  def _make_grid(nx=20, ny=20):
43
  yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
 
20
  self.export = False # onnx export
21
 
22
  def forward(self, x):
23
+ # x = x.copy() # for profiling
24
  z = [] # inference output
25
  self.training |= self.export
26
  for i in range(self.nl):
 
38
 
39
  return x if self.training else (torch.cat(z, 1), x)
40
 
41
+ def forward_(self, x):
42
+ if hasattr(self, 'nx'):
43
+ z = [] # inference output
44
+ for (y, gi, agi, si, nyi, nxi) in zip(x, self.grid, self.ag, self.stride, self.ny, self.nx):
45
+ m = self.na * nxi * nyi
46
+ y = y.view(1, self.na, self.no, nyi, nxi).permute(0, 1, 3, 4, 2).contiguous().view(m, self.no).sigmoid()
47
+
48
+ xy = (y[..., 0:2] * 2. - 0.5 + gi) * si # xy
49
+ wh = (y[..., 2:4] * 2) ** 2 * agi # wh
50
+ p_cls = y[:, 4:5] if self.nc == 1 else y[:, 5:self.no] * y[:, 4:5] # conf
51
+ z.append([p_cls, xy, wh])
52
+
53
+ z = [torch.cat(x, 0) for x in zip(*z)]
54
+ return z[0], torch.cat(z[1:3], 1) # scores, boxes: 3780x80, 3780x4
55
+
56
+ else: # dry run
57
+ self.nx = [0] * self.nl
58
+ self.ny = [0] * self.nl
59
+ self.ag = [0] * self.nl
60
+ for i in range(self.nl):
61
+ bs, _, ny, nx = x[i].shape
62
+ m = self.na * nx * ny
63
+ self.grid[i] = self._make_grid(nx, ny).repeat(1, self.na, 1, 1, 1).view(m, 2) / torch.tensor([[nx, ny]])
64
+ self.ag[i] = self.anchor_grid[i].repeat(1, 1, nx, ny, 1).view(m, 2)
65
+ self.nx[i] = nx
66
+ self.ny[i] = ny
67
+ return None
68
+
69
  @staticmethod
70
  def _make_grid(nx=20, ny=20):
71
  yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])