glenn-jocher commited on
Commit
a62a1c2
1 Parent(s): f1d67f4

export.py update

Browse files
Files changed (2) hide show
  1. detect.py +4 -4
  2. models/export.py +5 -5
detect.py CHANGED
@@ -156,9 +156,9 @@ if __name__ == '__main__':
156
  print(opt)
157
 
158
  with torch.no_grad():
159
- detect()
160
 
161
  # Update all models
162
- # for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
163
- # detect()
164
- # create_pretrained(opt.weights, opt.weights)
 
156
  print(opt)
157
 
158
  with torch.no_grad():
159
+ # detect()
160
 
161
  # Update all models
162
+ for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
163
+ detect()
164
+ create_pretrained(opt.weights, opt.weights)
models/export.py CHANGED
@@ -6,8 +6,6 @@ Usage:
6
 
7
  import argparse
8
 
9
- import onnx
10
-
11
  from models.common import *
12
  from utils import google_utils
13
 
@@ -21,7 +19,7 @@ if __name__ == '__main__':
21
  print(opt)
22
 
23
  # Input
24
- img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size, (1, 3, 320, 192) iDetection
25
 
26
  # Load PyTorch model
27
  google_utils.attempt_download(opt.weights)
@@ -30,7 +28,7 @@ if __name__ == '__main__':
30
  model.model[-1].export = True # set Detect() layer export=True
31
  _ = model(img) # dry run
32
 
33
- # Export to TorchScript
34
  try:
35
  f = opt.weights.replace('.pt', '.torchscript') # filename
36
  ts = torch.jit.trace(model, img)
@@ -39,8 +37,10 @@ if __name__ == '__main__':
39
  except Exception as e:
40
  print('TorchScript export failed: %s' % e)
41
 
42
- # Export to ONNX
43
  try:
 
 
44
  f = opt.weights.replace('.pt', '.onnx') # filename
45
  model.fuse() # only for ONNX
46
  torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
 
6
 
7
  import argparse
8
 
 
 
9
  from models.common import *
10
  from utils import google_utils
11
 
 
19
  print(opt)
20
 
21
  # Input
22
+ img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size(1,3,320,192) iDetection
23
 
24
  # Load PyTorch model
25
  google_utils.attempt_download(opt.weights)
 
28
  model.model[-1].export = True # set Detect() layer export=True
29
  _ = model(img) # dry run
30
 
31
+ # TorchScript export
32
  try:
33
  f = opt.weights.replace('.pt', '.torchscript') # filename
34
  ts = torch.jit.trace(model, img)
 
37
  except Exception as e:
38
  print('TorchScript export failed: %s' % e)
39
 
40
+ # ONNX export
41
  try:
42
+ import onnx
43
+
44
  f = opt.weights.replace('.pt', '.onnx') # filename
45
  model.fuse() # only for ONNX
46
  torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],