glenn-jocher commited on
Commit
a8751e5
1 Parent(s): 5f07782

model.fuse() fix for export.py (#827)

Browse files
Files changed (2) hide show
  1. models/export.py +3 -1
  2. models/yolo.py +1 -1
models/export.py CHANGED
@@ -28,6 +28,9 @@ if __name__ == '__main__':
28
  attempt_download(opt.weights)
29
  model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
30
  model.eval()
 
 
 
31
  model.model[-1].export = True # set Detect() layer export=True
32
  y = model(img) # dry run
33
 
@@ -47,7 +50,6 @@ if __name__ == '__main__':
47
 
48
  print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
49
  f = opt.weights.replace('.pt', '.onnx') # filename
50
- model.fuse() # only for ONNX
51
  torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
52
  output_names=['classes', 'boxes'] if y is None else ['output'])
53
 
 
28
  attempt_download(opt.weights)
29
  model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
30
  model.eval()
31
+ model.fuse()
32
+
33
+ # Update model
34
  model.model[-1].export = True # set Detect() layer export=True
35
  y = model(img) # dry run
36
 
 
50
 
51
  print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
52
  f = opt.weights.replace('.pt', '.onnx') # filename
 
53
  torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
54
  output_names=['classes', 'boxes'] if y is None else ['output'])
55
 
models/yolo.py CHANGED
@@ -163,7 +163,7 @@ class Model(nn.Module):
163
  if type(m) is Conv:
164
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
165
  m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
166
- m.bn = None # remove batchnorm
167
  m.forward = m.fuseforward # update forward
168
  self.info()
169
  return self
 
163
  if type(m) is Conv:
164
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
165
  m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
166
+ delattr(m, 'bn') # remove batchnorm
167
  m.forward = m.fuseforward # update forward
168
  self.info()
169
  return self