jakepoz commited on
Commit
685091a
1 Parent(s): 37e13f8

Adding torchscript export

Browse files
Files changed (1) hide show
  1. models/torchscript_export.py +38 -0
models/torchscript_export.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Exports a pytorch *.pt model to *.torchscript format
2
+
3
+ Usage:
4
+ $ export PYTHONPATH="$PWD" && python models/torchscript_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
5
+ """
6
+
7
+ import argparse
8
+
9
+
10
+ from models.common import *
11
+ from utils import google_utils
12
+
13
+ if __name__ == '__main__':
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
16
+ parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
17
+ parser.add_argument('--batch-size', type=int, default=1, help='batch size')
18
+ opt = parser.parse_args()
19
+ print(opt)
20
+
21
+ # Parameters
22
+ f = opt.weights.replace('.pt', '.torchscript') # onnx filename
23
+ img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size, (1, 3, 320, 192) iDetection
24
+
25
+ # Load pytorch model
26
+ google_utils.attempt_download(opt.weights)
27
+ model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
28
+ model.eval()
29
+
30
+ # Don't fuse layers, it won't work with torchscript exports
31
+ #model.fuse()
32
+
33
+ # Export to jit/torchscript
34
+ model.model[-1].export = True # set Detect() layer export=True
35
+ _ = model(img) # dry run
36
+
37
+ traced_script_module = torch.jit.trace(model, img)
38
+ traced_script_module.save(f)