amd
/

Image Classification
ONNX
RyzenAI
zhengrongzhang hangyang-amd commited on
Commit
2e9787f
1 Parent(s): 452b074

Update infer_onnx.py (#4)

Browse files

- Update infer_onnx.py (1332ee4c972acd15c2c7450db495dee19bb23a9f)


Co-authored-by: hang yang <hangyang-amd@users.noreply.huggingface.co>

Files changed (1) hide show
  1. infer_onnx.py +3 -1
infer_onnx.py CHANGED
@@ -36,7 +36,7 @@ parser.add_argument(
36
  default="vaip_config.json",
37
  help="Path of the config file for seting provider_options.",
38
  )
39
-
40
  args = parser.parse_args()
41
 
42
 
@@ -51,6 +51,8 @@ def read_image():
51
  normalize,
52
  ])
53
  img_tensor = transform(image).unsqueeze(0)
 
 
54
  return img_tensor.numpy()
55
 
56
 
 
36
  default="vaip_config.json",
37
  help="Path of the config file for seting provider_options.",
38
  )
39
+ parser.add_argument('--data_format', type=str, choices=["nchw", "nhwc"], default="nchw")
40
  args = parser.parse_args()
41
 
42
 
 
51
  normalize,
52
  ])
53
  img_tensor = transform(image).unsqueeze(0)
54
+ if args.data_format == "nhwc":
55
+ img_tensor = transform(image).unsqueeze(0).permute((0, 2, 3, 1))
56
  return img_tensor.numpy()
57
 
58