File size: 4,805 Bytes
0b5f4ac 806b785 0b5f4ac 806b785 0b5f4ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
#!/usr/bin/env python
import argparse
import os
import sys
import time
from typing import List, Tuple

import onnxruntime
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
# Command-line interface. The arguments are parsed at import time and the
# functions below read the resulting module-level ``args``.
parser = argparse.ArgumentParser(
    description="Evaluate an ONNX classification model on the ImageNet validation set.")
parser.add_argument(
    "--onnx_model", default="model.onnx", help="Input onnx model")
parser.add_argument(
    "--data_dir",
    default="/workspace/dataset/imagenet",
    help="Directory of dataset")
parser.add_argument(
    "--batch_size", default=1, type=int, help="Evaluation batch size")
parser.add_argument(
    "--ipu",
    action="store_true",
    help="Use IPU for inference.",
)
parser.add_argument(
    "--provider_config",
    type=str,
    default="vaip_config.json",
    help="Path of the config file for setting provider_options.",
)
parser.add_argument(
    '--data_format',
    type=str,
    choices=["nchw", "nhwc"],
    default="nchw",
    help="Layout of the model input: nchw (channels first) or nhwc (channels last).")
args = parser.parse_args()
class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric.

    ``fmt`` is a format-spec suffix (e.g. ``':6.2f'``) used when rendering
    both the latest value and the average in ``str(meter)``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear every accumulated statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Weighted mean over all samples seen since the last reset.
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
def accuracy(output: torch.Tensor,
             target: torch.Tensor,
             topk: Tuple[int, ...] = (1,)) -> List[torch.Tensor]:
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: Prediction scores of the model, shaped (batch, num_classes);
            classes are ranked along dimension 1.
        target: Ground-truth class indices, one per row of ``output``.
        topk: The values of k to evaluate. A k larger than the number of
            classes is treated as "all classes".

    Returns:
        A list with one single-element tensor per entry of ``topk`` (same
        order), holding the batch accuracy in percent.
    """
    with torch.no_grad():
        # Clamp so Tensor.topk does not raise when k exceeds the class count;
        # the top-C predictions already cover every class.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        # After the transpose, row i holds every sample's i-th best class.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        scale = 100.0 / batch_size
        # correct[:k] is a non-contiguous slice of the transposed tensor,
        # so reshape (not view) is used to flatten it.
        return [
            correct[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]
def prepare_data_loader(data_dir: str,
                        batch_size: int = 100,
                        workers: int = 8) -> torch.utils.data.DataLoader:
    """Build an unshuffled data loader over the ImageNet validation split.

    Args:
        data_dir: Root directory; images are read from its ``validation``
            subdirectory via ``datasets.ImageFolder``.
        batch_size: Number of samples per batch.
        workers: Number of loader subprocesses.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, targets)`` batches
        in a fixed order.
    """
    # Standard resize -> center-crop -> tensor -> per-channel normalization.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset = datasets.ImageFolder(os.path.join(data_dir, 'validation'),
                                   transform=preprocess)
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=workers,
                                       pin_memory=True)
def val_imagenet():
    """Validate ONNX model on ImageNet dataset.

    Reads the module-level ``args`` (parsed CLI flags) to choose the model
    file, execution providers, data directory, batch size and input layout.

    Returns:
        A ``(top1, top5)`` pair: the running averages (in percent) accumulated
        by the two ``AverageMeter`` instances.
    """
    print(f'Current onnx model: {args.onnx_model}')
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        # Providers are listed in priority order: CUDA first, CPU as fallback.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    ort_session = onnxruntime.InferenceSession(
        args.onnx_model, providers=providers, provider_options=provider_options)
    # The model's input name is fixed, so look it up once rather than per batch.
    input_name = ort_session.get_inputs()[0].name
    channels_last = args.data_format == "nhwc"

    val_loader = prepare_data_loader(args.data_dir, args.batch_size)
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # The measured time covers data loading and inference together.
    start_time = time.perf_counter()
    with torch.no_grad():
        for images, targets in tqdm(val_loader, file=sys.stdout):
            if channels_last:
                # permute() returns a strided view; make it contiguous so
                # ONNX Runtime receives a C-contiguous NHWC buffer.
                model_input = images.permute(0, 2, 3, 1).contiguous()
            else:
                model_input = images
            outputs = ort_session.run(None, {input_name: model_input.numpy()})
            logits = torch.from_numpy(outputs[0])
            acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
            top1.update(acc1, images.size(0))
            top5.update(acc5, images.size(0))
    elapsed = time.perf_counter() - start_time
    print('Test Top1 {:.2f}%\tTop5 {:.2f}%\tTime {:.2f}s\n'.format(
        float(top1.avg), float(top5.avg), elapsed))
    return top1.avg, top5.avg
if __name__ == "__main__":
    # Script entry point: run the evaluation configured by the CLI flags.
    val_imagenet()
|