Commit
•
a928d8f
1
Parent(s):
3db13c2
Update onnx_test.py (#2)
Browse files- Update onnx_test.py (675b74940c2cb69557dca75d986203e75716d3a3)
Co-authored-by: yixiong huo <yixionghuo@users.noreply.huggingface.co>
- onnx_test.py +5 -1
onnx_test.py
CHANGED
@@ -962,8 +962,12 @@ def test(data,
|
|
962 |
whwh = torch.Tensor([width, height, width, height]).to(device)
|
963 |
|
964 |
if onnx_runtime:
|
|
|
|
|
965 |
outputs = onnx_model.run(
|
966 |
-
None, {onnx_model.get_inputs()[0].name: imgs.cpu().numpy()})
|
|
|
|
|
967 |
outputs = [torch.tensor(item).to(device) for item in outputs]
|
968 |
inf_out, train_out = post_process(outputs)
|
969 |
|
|
|
962 |
whwh = torch.Tensor([width, height, width, height]).to(device)
|
963 |
|
964 |
if onnx_runtime:
|
965 |
+
# outputs = onnx_model.run(
|
966 |
+
# None, {onnx_model.get_inputs()[0].name: imgs.cpu().numpy()})
|
967 |
outputs = onnx_model.run(
|
968 |
+
None, {onnx_model.get_inputs()[0].name: np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1))})
|
969 |
+
outputs = [np.transpose(out, (0, 3, 1, 2)) for out in outputs]
|
970 |
+
|
971 |
outputs = [torch.tensor(item).to(device) for item in outputs]
|
972 |
inf_out, train_out = post_process(outputs)
|
973 |
|