zhengrongzhang
commited on
Commit
•
5c2432b
1
Parent(s):
ef66c2c
update infer code for NCHW->NHWC
Browse files- infer_onnx.py +3 -1
infer_onnx.py
CHANGED
@@ -139,8 +139,10 @@ if __name__ == '__main__':
|
|
139 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
140 |
provider_options = None
|
141 |
session = ort.InferenceSession(args.model, providers=providers, provider_options=provider_options)
|
142 |
-
ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
|
|
|
143 |
outputs = session.run(None, ort_inputs)
|
|
|
144 |
dets = postprocess(outputs, input_shape, ratio)
|
145 |
if dets is not None:
|
146 |
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
|
|
139 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
140 |
provider_options = None
|
141 |
session = ort.InferenceSession(args.model, providers=providers, provider_options=provider_options)
|
142 |
+
# ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
|
143 |
+
ort_inputs = {session.get_inputs()[0].name: np.transpose(img[None, :, :, :], (0, 2 ,3, 1))}
|
144 |
outputs = session.run(None, ort_inputs)
|
145 |
+
outputs = [np.transpose(out, (0, 3, 1, 2)) for out in outputs]
|
146 |
dets = postprocess(outputs, input_shape, ratio)
|
147 |
if dets is not None:
|
148 |
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|