zhengrongzhang commited on
Commit
5c2432b
1 Parent(s): ef66c2c

update infer code for NCHW->NHWC

Browse files
Files changed (1) hide show
  1. infer_onnx.py +3 -1
infer_onnx.py CHANGED
@@ -139,8 +139,10 @@ if __name__ == '__main__':
139
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
140
  provider_options = None
141
  session = ort.InferenceSession(args.model, providers=providers, provider_options=provider_options)
142
- ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
 
143
  outputs = session.run(None, ort_inputs)
 
144
  dets = postprocess(outputs, input_shape, ratio)
145
  if dets is not None:
146
  final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
 
139
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
140
  provider_options = None
141
  session = ort.InferenceSession(args.model, providers=providers, provider_options=provider_options)
142
+ # ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
143
+ ort_inputs = {session.get_inputs()[0].name: np.transpose(img[None, :, :, :], (0, 2 ,3, 1))}
144
  outputs = session.run(None, ort_inputs)
145
+ outputs = [np.transpose(out, (0, 3, 1, 2)) for out in outputs]
146
  dets = postprocess(outputs, input_shape, ratio)
147
  if dets is not None:
148
  final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]