Commit
•
2e9787f
1
Parent(s):
452b074
Update infer_onnx.py (#4)
Browse files- Update infer_onnx.py (1332ee4c972acd15c2c7450db495dee19bb23a9f)
Co-authored-by: hang yang <hangyang-amd@users.noreply.huggingface.co>
- infer_onnx.py +3 -1
infer_onnx.py
CHANGED
@@ -36,7 +36,7 @@ parser.add_argument(
|
|
36 |
default="vaip_config.json",
|
37 |
help="Path of the config file for seting provider_options.",
|
38 |
)
|
39 |
-
|
40 |
args = parser.parse_args()
|
41 |
|
42 |
|
@@ -51,6 +51,8 @@ def read_image():
|
|
51 |
normalize,
|
52 |
])
|
53 |
img_tensor = transform(image).unsqueeze(0)
|
|
|
|
|
54 |
return img_tensor.numpy()
|
55 |
|
56 |
|
|
|
36 |
default="vaip_config.json",
|
37 |
help="Path of the config file for seting provider_options.",
|
38 |
)
|
39 |
+
parser.add_argument('--data_format', type=str, choices=["nchw", "nhwc"], default="nchw")
|
40 |
args = parser.parse_args()
|
41 |
|
42 |
|
|
|
51 |
normalize,
|
52 |
])
|
53 |
img_tensor = transform(image).unsqueeze(0)
|
54 |
+
if args.data_format == "nhwc":
|
55 |
+
img_tensor = transform(image).unsqueeze(0).permute((0, 2, 3, 1))
|
56 |
return img_tensor.numpy()
|
57 |
|
58 |
|