zhengrongzhang yixionghuo commited on
Commit
a928d8f
1 Parent(s): 3db13c2

Update onnx_test.py (#2)

Browse files

- Update onnx_test.py (675b74940c2cb69557dca75d986203e75716d3a3)


Co-authored-by: yixiong huo <yixionghuo@users.noreply.huggingface.co>

Files changed (1) hide show
  1. onnx_test.py +5 -1
onnx_test.py CHANGED
@@ -962,8 +962,12 @@ def test(data,
962
  whwh = torch.Tensor([width, height, width, height]).to(device)
963
 
964
  if onnx_runtime:
 
 
965
  outputs = onnx_model.run(
966
- None, {onnx_model.get_inputs()[0].name: imgs.cpu().numpy()})
 
 
967
  outputs = [torch.tensor(item).to(device) for item in outputs]
968
  inf_out, train_out = post_process(outputs)
969
 
 
962
  whwh = torch.Tensor([width, height, width, height]).to(device)
963
 
964
  if onnx_runtime:
965
+ # outputs = onnx_model.run(
966
+ # None, {onnx_model.get_inputs()[0].name: imgs.cpu().numpy()})
967
  outputs = onnx_model.run(
968
+ None, {onnx_model.get_inputs()[0].name: np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1))})
969
+ outputs = [np.transpose(out, (0, 3, 1, 2)) for out in outputs]
970
+
971
  outputs = [torch.tensor(item).to(device) for item in outputs]
972
  inf_out, train_out = post_process(outputs)
973