glenn-jocher
commited on
Commit
·
a8751e5
1
Parent(s):
5f07782
model.fuse() fix for export.py (#827)
Browse files- models/export.py +3 -1
- models/yolo.py +1 -1
models/export.py
CHANGED
@@ -28,6 +28,9 @@ if __name__ == '__main__':
|
|
28 |
attempt_download(opt.weights)
|
29 |
model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
|
30 |
model.eval()
|
|
|
|
|
|
|
31 |
model.model[-1].export = True # set Detect() layer export=True
|
32 |
y = model(img) # dry run
|
33 |
|
@@ -47,7 +50,6 @@ if __name__ == '__main__':
|
|
47 |
|
48 |
print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
|
49 |
f = opt.weights.replace('.pt', '.onnx') # filename
|
50 |
-
model.fuse() # only for ONNX
|
51 |
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
|
52 |
output_names=['classes', 'boxes'] if y is None else ['output'])
|
53 |
|
|
|
28 |
attempt_download(opt.weights)
|
29 |
model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
|
30 |
model.eval()
|
31 |
+
model.fuse()
|
32 |
+
|
33 |
+
# Update model
|
34 |
model.model[-1].export = True # set Detect() layer export=True
|
35 |
y = model(img) # dry run
|
36 |
|
|
|
50 |
|
51 |
print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
|
52 |
f = opt.weights.replace('.pt', '.onnx') # filename
|
|
|
53 |
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
|
54 |
output_names=['classes', 'boxes'] if y is None else ['output'])
|
55 |
|
models/yolo.py
CHANGED
@@ -163,7 +163,7 @@ class Model(nn.Module):
|
|
163 |
if type(m) is Conv:
|
164 |
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
165 |
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
166 |
-
m
|
167 |
m.forward = m.fuseforward # update forward
|
168 |
self.info()
|
169 |
return self
|
|
|
163 |
if type(m) is Conv:
|
164 |
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
165 |
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
166 |
+
delattr(m, 'bn') # remove batchnorm
|
167 |
m.forward = m.fuseforward # update forward
|
168 |
self.info()
|
169 |
return self
|