Update eval_onnx.py
#2
by
zihengg
- opened
- eval_onnx.py +12 -2
eval_onnx.py
CHANGED
@@ -514,18 +514,28 @@ if __name__ == '__main__':
|
|
514 |
data_loader = data.getEvalDataloader()
|
515 |
# Load MoveNet model using ONNX runtime
|
516 |
model = rt.InferenceSession(MODEL_DIR, providers=providers, provider_options=provider_options)
|
517 |
-
|
518 |
correct = 0
|
519 |
total = 0
|
520 |
# Loop through the data loader for evaluation
|
521 |
for batch_idx, (imgs, labels, kps_mask, img_names) in enumerate(data_loader):
|
|
|
522 |
if batch_idx%100 == 0:
|
523 |
print('Finish ',batch_idx)
|
|
|
524 |
imgs = imgs.detach().cpu().numpy()
|
525 |
-
|
|
|
|
|
|
|
|
|
|
|
526 |
pre = movenetDecode(output, kps_mask,mode='output',img_size=IMG_SIZE)
|
527 |
gt = movenetDecode(labels, kps_mask,mode='label',img_size=IMG_SIZE)
|
|
|
|
|
528 |
acc = myAcc(pre, gt)
|
|
|
529 |
correct += sum(acc)
|
530 |
total += len(acc)
|
531 |
# Compute and print accuracy based on evaluated data
|
|
|
514 |
data_loader = data.getEvalDataloader()
|
515 |
# Load MoveNet model using ONNX runtime
|
516 |
model = rt.InferenceSession(MODEL_DIR, providers=providers, provider_options=provider_options)
|
517 |
+
|
518 |
correct = 0
|
519 |
total = 0
|
520 |
# Loop through the data loader for evaluation
|
521 |
for batch_idx, (imgs, labels, kps_mask, img_names) in enumerate(data_loader):
|
522 |
+
|
523 |
if batch_idx%100 == 0:
|
524 |
print('Finish ',batch_idx)
|
525 |
+
|
526 |
imgs = imgs.detach().cpu().numpy()
|
527 |
+
imgs = imgs.transpose((0,2,3,1))
|
528 |
+
output = model.run(['1548_transpose','1607_transpose','1665_transpose','1723_transpose'],{'blob.1':imgs})
|
529 |
+
output[0] = output[0].transpose((0,3,1,2))
|
530 |
+
output[1] = output[1].transpose((0,3,1,2))
|
531 |
+
output[2] = output[2].transpose((0,3,1,2))
|
532 |
+
output[3] = output[3].transpose((0,3,1,2))
|
533 |
pre = movenetDecode(output, kps_mask,mode='output',img_size=IMG_SIZE)
|
534 |
gt = movenetDecode(labels, kps_mask,mode='label',img_size=IMG_SIZE)
|
535 |
+
|
536 |
+
#n
|
537 |
acc = myAcc(pre, gt)
|
538 |
+
|
539 |
correct += sum(acc)
|
540 |
total += len(acc)
|
541 |
# Compute and print accuracy based on evaluated data
|