|
|
|
|
|
from typing import Tuple |
|
|
|
import argparse |
|
import onnxruntime |
|
import os |
|
import sys |
|
import time |
|
import torch |
|
import torchvision.datasets as datasets |
|
import torchvision.transforms as transforms |
|
|
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
|
|
# Command-line interface for the ONNX ImageNet evaluation script.
# Parsed once at import time; the resulting `args` namespace is read by
# val_imagenet() below.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--onnx_model", default="model.onnx", help="Input onnx model")
parser.add_argument(
    "--data_dir",
    default="/workspace/dataset/imagenet",
    help="Directory of dataset")
parser.add_argument(
    "--batch_size", default=1, type=int, help="Evaluation batch size")
parser.add_argument(
    "--ipu",
    action="store_true",
    help="Use IPU for inference.",
)
parser.add_argument(
    "--provider_config",
    type=str,
    default="vaip_config.json",
    # Typo fix: "seting" -> "setting".
    help="Path of the config file for setting provider_options.",
)
parser.add_argument(
    '--data_format',
    type=str,
    choices=["nchw", "nhwc"],
    default="nchw",
    help="Input tensor layout expected by the ONNX model.")
args = parser.parse_args()
|
|
|
class AverageMeter(object):
    """Tracks the latest value and a running, count-weighted average.

    Attributes:
        name: Label used when rendering the meter as a string.
        fmt: Format spec applied to `val` and `avg` in __str__.
        val: Most recently recorded value.
        sum: Weighted sum of all recorded values.
        count: Total weight (number of observations) recorded.
        avg: Running average, i.e. sum / count.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build e.g. '{name} {val:6.2f} ({avg:6.2f})' from the stored spec.
        template = ''.join(
            ('{name} {val', self.fmt, '} ({avg', self.fmt, '})'))
        return template.format(**self.__dict__)
|
|
|
def accuracy(output: torch.Tensor,
             target: torch.Tensor,
             topk: Tuple[int] = (1,)) -> Tuple[float]:
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: Prediction of the model.
        target: Ground truth labels.
        topk: Topk accuracy to compute.

    Returns:
        Accuracy results according to 'topk'.
    """

    with torch.no_grad():
        k_max = max(topk)
        n_samples = target.size(0)

        # Take the k_max highest-scoring class indices per sample, then
        # transpose so rows are ranks and columns are samples.
        _, predictions = output.topk(k_max, 1, True, True)
        predictions = predictions.t()

        # hits[r][s] is True when rank-r prediction for sample s is correct.
        hits = predictions.eq(target.view(1, -1).expand_as(predictions))

        # For each requested k, count correct samples within the top k ranks
        # and convert to a percentage of the batch.
        return [
            hits[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            .mul_(100.0 / n_samples)
            for k in topk
        ]
|
|
|
def prepare_data_loader(data_dir: str,
                        batch_size: int = 100,
                        workers: int = 8) -> torch.utils.data.DataLoader:
    """Returns a validation data loader of ImageNet by given `data_dir`.

    Args:
        data_dir: Directory where images stores. There must be a subdirectory named
          'validation' that stores the validation set of ImageNet.
        batch_size: Batch size of data loader.
        workers: How many subprocesses to use for data loading.

    Returns:
        An object of torch.utils.data.DataLoader.
    """

    val_dir = os.path.join(data_dir, 'validation')

    # Standard ImageNet evaluation pipeline: 256 resize, 224 center crop,
    # then per-channel normalization with the usual ImageNet statistics.
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_dataset = datasets.ImageFolder(val_dir, preprocessing)

    # No shuffling: evaluation order is irrelevant and determinism is useful.
    return torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True)
|
|
|
def val_imagenet():
    """Validate an ONNX classification model on the ImageNet validation set.

    Builds an onnxruntime inference session (VitisAI provider when --ipu is
    given, otherwise CUDA with CPU fallback), streams the validation loader
    through it, and accumulates top-1 / top-5 accuracy.

    Returns:
        Tuple of (top1_avg, top5_avg) accuracy over the whole validation set.
    """
    print(f'Current onnx model: {args.onnx_model}')

    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    ort_session = onnxruntime.InferenceSession(
        args.onnx_model, providers=providers, provider_options=provider_options)

    val_loader = prepare_data_loader(args.data_dir, args.batch_size)

    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # Hoisted out of the loop: the model's input name is loop-invariant, so
    # there is no need to call get_inputs() on every batch.
    input_name = ort_session.get_inputs()[0].name

    start_time = time.time()
    val_loader = tqdm(val_loader, file=sys.stdout)
    with torch.no_grad():
        for images, targets in val_loader:
            # DataLoader yields NCHW tensors; permute to NHWC when the ONNX
            # model was exported with channels-last inputs.
            if args.data_format == "nchw":
                inputs = images.numpy()
            else:
                inputs = images.permute((0, 2, 3, 1)).numpy()

            outputs = ort_session.run(None, {input_name: inputs})
            outputs = torch.from_numpy(outputs[0])

            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            top1.update(acc1, images.size(0))
            top5.update(acc5, images.size(0))

    current_time = time.time()
    print('Test Top1 {:.2f}%\tTop5 {:.2f}%\tTime {:.2f}s\n'.format(
        float(top1.avg), float(top5.avg), (current_time - start_time)))

    return top1.avg, top5.avg
|
|
|
# Script entry point: run ImageNet validation using the CLI args parsed above.
if __name__ == '__main__':

    val_imagenet()
|
|