zhengrongzhang commited on
Commit
ef66c2c
1 Parent(s): 632a5be

update eval code for NCHW->NHWC

Browse files
Files changed (1) hide show
  1. eval_onnx.py +3 -1
eval_onnx.py CHANGED
@@ -93,7 +93,9 @@ class COCOEvaluator:
93
  is_time_record = cur_iter < len(self.dataloader) - 1
94
  if is_time_record:
95
  start = time.time()
96
- outputs = ort_sess.run(None, {input_name: imgs.numpy()})
 
 
97
  outputs = [torch.Tensor(out) for out in outputs]
98
  outputs = head_postprocess(outputs)
99
  if is_time_record:
 
93
  is_time_record = cur_iter < len(self.dataloader) - 1
94
  if is_time_record:
95
  start = time.time()
96
+ # outputs = ort_sess.run(None, {input_name: imgs.numpy()})
97
+ outputs = ort_sess.run(None, {input_name: np.transpose(imgs.numpy(), (0, 2, 3, 1))})
98
+ outputs = [np.transpose(out, (0, 3, 1, 2)) for out in outputs]
99
  outputs = [torch.Tensor(out) for out in outputs]
100
  outputs = head_postprocess(outputs)
101
  if is_time_record: