Hodovo glenn-jocher commited on
Commit
e2a80c6
1 Parent(s): 31ee54c

Add support for FP16 (half) to export.py (#3010)

Browse files

* Added support for fp16 (half) to export.py

* minimize code additions

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (1) hide show
  1. models/export.py +4 -0
models/export.py CHANGED
@@ -28,6 +28,7 @@ if __name__ == '__main__':
28
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
29
  parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
30
  parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
 
31
  parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
32
  parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
33
  opt = parser.parse_args()
@@ -44,11 +45,14 @@ if __name__ == '__main__':
44
  # Checks
45
  gs = int(max(model.stride)) # grid size (max stride)
46
  opt.img_size = [check_img_size(x, gs) for x in opt.img_size] # verify img_size are gs-multiples
 
47
 
48
  # Input
49
  img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device) # image size(1,3,320,192) iDetection
50
 
51
  # Update model
 
 
52
  for k, m in model.named_modules():
53
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
54
  if isinstance(m, models.common.Conv): # assign export-friendly activations
 
28
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
29
  parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
30
  parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
31
+ parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
32
  parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
33
  parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
34
  opt = parser.parse_args()
 
45
  # Checks
46
  gs = int(max(model.stride)) # grid size (max stride)
47
  opt.img_size = [check_img_size(x, gs) for x in opt.img_size] # verify img_size are gs-multiples
48
+ assert not (opt.device.lower() == "cpu" and opt.half), '--half only compatible with GPU export, i.e. use --device 0'
49
 
50
  # Input
51
  img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device) # image size(1,3,320,192) iDetection
52
 
53
  # Update model
54
+ if opt.half:
55
+ img, model = img.half(), model.half() # to FP16
56
  for k, m in model.named_modules():
57
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
58
  if isinstance(m, models.common.Conv): # assign export-friendly activations