glenn-jocher
commited on
TorchScript single-output fix (#7261)
Browse files- export.py +12 -6
- models/common.py +4 -3
export.py
CHANGED
@@ -73,12 +73,18 @@ from utils.torch_utils import select_device
|
|
73 |
|
74 |
def export_formats():
|
75 |
# YOLOv5 export formats
|
76 |
-
x = [
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])
|
83 |
|
84 |
|
|
|
73 |
|
74 |
def export_formats():
|
75 |
# YOLOv5 export formats
|
76 |
+
x = [
|
77 |
+
['PyTorch', '-', '.pt', True],
|
78 |
+
['TorchScript', 'torchscript', '.torchscript', True],
|
79 |
+
['ONNX', 'onnx', '.onnx', True],
|
80 |
+
['OpenVINO', 'openvino', '_openvino_model', False],
|
81 |
+
['TensorRT', 'engine', '.engine', True],
|
82 |
+
['CoreML', 'coreml', '.mlmodel', False],
|
83 |
+
['TensorFlow SavedModel', 'saved_model', '_saved_model', True],
|
84 |
+
['TensorFlow GraphDef', 'pb', '.pb', True],
|
85 |
+
['TensorFlow Lite', 'tflite', '.tflite', False],
|
86 |
+
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
|
87 |
+
['TensorFlow.js', 'tfjs', '_web_model', False],]
|
88 |
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])
|
89 |
|
90 |
|
models/common.py
CHANGED
@@ -406,9 +406,10 @@ class DetectMultiBackend(nn.Module):
|
|
406 |
def forward(self, im, augment=False, visualize=False, val=False):
|
407 |
# YOLOv5 MultiBackend inference
|
408 |
b, ch, h, w = im.shape # batch, channel, height, width
|
409 |
-
if self.pt
|
410 |
-
y = self.model(im
|
411 |
-
|
|
|
412 |
elif self.dnn: # ONNX OpenCV DNN
|
413 |
im = im.cpu().numpy() # torch to numpy
|
414 |
self.net.setInput(im)
|
|
|
406 |
def forward(self, im, augment=False, visualize=False, val=False):
|
407 |
# YOLOv5 MultiBackend inference
|
408 |
b, ch, h, w = im.shape # batch, channel, height, width
|
409 |
+
if self.pt: # PyTorch
|
410 |
+
y = self.model(im, augment=augment, visualize=visualize)[0]
|
411 |
+
elif self.jit: # TorchScript
|
412 |
+
y = self.model(im)[0]
|
413 |
elif self.dnn: # ONNX OpenCV DNN
|
414 |
im = im.cpu().numpy() # torch to numpy
|
415 |
self.net.setInput(im)
|