Commit 8f17552 (parent 3a5c532), committed by glenn-jocher

ONNX export update

Files changed:
- models/onnx_export.py (+7 -5)
- models/yolo.py (+29 -1)
models/onnx_export.py
CHANGED
@@ -13,25 +13,27 @@ from models.common import *
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='./
-    parser.add_argument('--img-size', type=int, default=640, help='image size')
+    parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
+    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
     parser.add_argument('--batch-size', type=int, default=1, help='batch size')
     opt = parser.parse_args()
     print(opt)
 
     # Parameters
     f = opt.weights.replace('.pt', '.onnx')  # onnx filename
-    img = torch.zeros((opt.batch_size, 3, opt.img_size, opt.img_size))
+    img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # image size, (1, 3, 320, 192) iDetection
 
     # Load pytorch model
     google_utils.attempt_download(opt.weights)
     model = torch.load(opt.weights)['model']
     model.eval()
-
+    model.fuse()
 
     # Export to onnx
     model.model[-1].export = True  # set Detect() layer export=True
-
+    _ = model(img)  # dry run
+    torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
+                      output_names=['output'])  # output_names=['classes', 'boxes']
 
     # Check onnx model
     model = onnx.load(f)  # load onnx model
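With --img-size now taking one or two values (nargs='+'), the script can export rectangular input shapes, e.g. python models/onnx_export.py --weights yolov5s.pt --img-size 320 192 (run from the repo root with it on PYTHONPATH) for the (1, 3, 320, 192) iDetection case mentioned in the comment. As a quick sanity check of the exported file, something like the snippet below could be used; it is a minimal sketch, not part of the commit, and it assumes onnxruntime is installed. Only the 'images' input name and the default 640x640 dummy-input shape come from the script above.

# Sketch only: run the exported graph once with onnxruntime (assumed extra dependency).
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession('yolov5s.onnx')  # file written by onnx_export.py
img = np.zeros((1, 3, 640, 640), dtype=np.float32)      # same shape as the dummy export input
outputs = session.run(None, {'images': img})            # 'images' matches input_names in torch.onnx.export
print([o.shape for o in outputs])                       # inspect output tensor shapes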
models/yolo.py
CHANGED
@@ -20,7 +20,7 @@ class Detect(nn.Module):
         self.export = False  # onnx export
 
     def forward(self, x):
-        x = x.copy()  # for profiling
+        # x = x.copy()  # for profiling
         z = []  # inference output
         self.training |= self.export
         for i in range(self.nl):
@@ -38,6 +38,34 @@ class Detect(nn.Module):
 
         return x if self.training else (torch.cat(z, 1), x)
 
+    def forward_(self, x):
+        if hasattr(self, 'nx'):
+            z = []  # inference output
+            for (y, gi, agi, si, nyi, nxi) in zip(x, self.grid, self.ag, self.stride, self.ny, self.nx):
+                m = self.na * nxi * nyi
+                y = y.view(1, self.na, self.no, nyi, nxi).permute(0, 1, 3, 4, 2).contiguous().view(m, self.no).sigmoid()
+
+                xy = (y[..., 0:2] * 2. - 0.5 + gi) * si  # xy
+                wh = (y[..., 2:4] * 2) ** 2 * agi  # wh
+                p_cls = y[:, 4:5] if self.nc == 1 else y[:, 5:self.no] * y[:, 4:5]  # conf
+                z.append([p_cls, xy, wh])
+
+            z = [torch.cat(x, 0) for x in zip(*z)]
+            return z[0], torch.cat(z[1:3], 1)  # scores, boxes: 3780x80, 3780x4
+
+        else:  # dry run
+            self.nx = [0] * self.nl
+            self.ny = [0] * self.nl
+            self.ag = [0] * self.nl
+            for i in range(self.nl):
+                bs, _, ny, nx = x[i].shape
+                m = self.na * nx * ny
+                self.grid[i] = self._make_grid(nx, ny).repeat(1, self.na, 1, 1, 1).view(m, 2) / torch.tensor([[nx, ny]])
+                self.ag[i] = self.anchor_grid[i].repeat(1, 1, nx, ny, 1).view(m, 2)
+                self.nx[i] = nx
+                self.ny[i] = ny
+            return None
+
     @staticmethod
     def _make_grid(nx=20, ny=20):
         yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
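The new forward_() is an export-oriented path: the first call falls into the dry-run branch and caches the flattened grid, anchor grid and feature-map sizes per detection layer, and later calls decode each layer's raw output into per-class scores and xywh boxes in pixels. For context, the (scores, boxes) pair it returns would typically be turned into final detections roughly as in the sketch below; it is not part of the commit, and the confidence threshold and the use of torchvision.ops.nms are assumptions. Only the score/box layout comes from the code above.

# Sketch only: post-process the (scores, boxes) pair returned by Detect.forward_.
import torch
import torchvision

def postprocess(scores, boxes, conf_thres=0.4, iou_thres=0.5):
    # scores: (N, nc) per-class confidences, boxes: (N, 4) xywh in pixels
    conf, cls = scores.max(1)                        # best class score and index per box
    keep = conf > conf_thres                         # confidence filter
    conf, cls, boxes = conf[keep], cls[keep], boxes[keep]
    xy, wh = boxes[:, :2], boxes[:, 2:]
    xyxy = torch.cat((xy - wh / 2, xy + wh / 2), 1)  # xywh (center) -> xyxy corners
    i = torchvision.ops.nms(xyxy, conf, iou_thres)   # class-agnostic NMS for brevity
    return xyxy[i], conf[i], cls[i]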