"""Evaluate an image classifier, given as a PyTorch checkpoint or an ONNX export.

The validation images are read from ``<data-path>/val`` as an ImageFolder and
top-1 / top-5 accuracy is reported.
"""

import argparse
import os

import onnxruntime
import torch
import torch.utils.data
import torchvision
from torch import nn
from torchvision.transforms.functional import InterpolationMode

import utils


def evaluate(
    criterion,
    data_loader,
    device,
    model=None,
    model_onnx_path=None,
    print_freq=100,
    log_suffix="",
):
    """Run one pass over ``data_loader`` and log loss, Acc@1 and Acc@5.

    Only one backend is used per call: when ``model_onnx_path`` is set, the
    ONNX file is executed with ONNX Runtime on the CPU execution provider and
    takes precedence over ``model``; otherwise ``model`` is called directly.

    Args:
        criterion: Loss function applied as ``criterion(output, target)``.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device that targets (and, for the PyTorch path, images and
            outputs) are moved to.
        model: PyTorch model; used only when ``model_onnx_path`` is empty.
        model_onnx_path: Path to an ``.onnx`` file; takes precedence over
            ``model`` when non-empty.
        print_freq: Logging interval forwarded to ``MetricLogger.log_every``.
        log_suffix: Text appended to the ``"Test: "`` log header.

    Returns:
        The global average of the ``acc1`` meter, i.e. whatever
        ``utils.accuracy`` reports for top-1 (presumably a percentage —
        confirm against ``utils``).

    Raises:
        ValueError: If neither ``model`` nor ``model_onnx_path`` is provided.
    """
    if not model_onnx_path and model is None:
        raise ValueError("evaluate() needs either `model` or `model_onnx_path`")

    session = None
    input_name = None
    if model_onnx_path:
        session = onnxruntime.InferenceSession(
            model_onnx_path, providers=["CPUExecutionProvider"]
        )
        # The batch is fed to the graph's first input; assumes a single-input model.
        input_name = session.get_inputs()[0].name

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = f"Test: {log_suffix}"

    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            target = target.to(device, non_blocking=True)
            if session is not None:
                # ONNX Runtime consumes host NumPy arrays, so convert on the CPU
                # directly instead of copying the batch to `device` and back.
                input_data = image.cpu().numpy()
                # An empty output-name list asks ONNX Runtime for every output.
                output_data = session.run([], {input_name: input_data})[0]
                output = torch.from_numpy(output_data).to(device)
            else:
                image = image.to(device, non_blocking=True)
                output = model(image)

            loss = criterion(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(
        f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}"
    )
    return metric_logger.acc1.global_avg


def load_data(valdir):
    """Build the validation dataset and a sequential sampler over it.

    Images are resized to 256 (bilinear), center-cropped to 224, converted to
    float tensors and normalized with the ImageNet-style mean/std below.

    Args:
        valdir: Root directory in ``torchvision.datasets.ImageFolder`` layout.

    Returns:
        ``(dataset_test, test_sampler)``.
    """
    # Data loading code
    print("Loading data")
    interpolation = InterpolationMode.BILINEAR
    preprocessing = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(256, interpolation=interpolation),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.PILToTensor(),
            torchvision.transforms.ConvertImageDtype(torch.float),
            torchvision.transforms.Normalize(
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
            ),
        ]
    )
    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        preprocessing,
    )

    print("Creating data loaders")
    # Fixed order keeps evaluation runs reproducible.
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    return dataset_test, test_sampler


def main(args):
    """Load the data and model(s) described by ``args`` and print accuracy.

    If both ``--model-onnx`` and ``--model-ckpt`` are given, the checkpoint is
    still loaded but ``evaluate`` uses the ONNX model.

    Raises:
        ValueError: If neither ``--model-ckpt`` nor ``--model-onnx`` is set.
    """
    print(args)
    # Fail fast, before the (potentially slow) dataset scan.
    if not args.model_ckpt and not args.model_onnx:
        raise ValueError("Either --model-ckpt or --model-onnx must be provided")

    # Disable cuDNN autotuning and request deterministic kernels so repeated
    # evaluations produce the same numbers.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    val_dir = os.path.join(args.data_path, "val")
    dataset_test, test_sampler = load_data(val_dir)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=device.type == "cuda",
    )

    print("Creating model")
    criterion = nn.CrossEntropyLoss()
    model = None
    if args.model_ckpt:
        # NOTE(review): torch.load unpickles arbitrary objects - only load
        # trusted checkpoints. The checkpoint stores a whole nn.Module under
        # "model_ckpt"; on torch >= 2.6 (weights_only=True by default) this may
        # require passing weights_only=False explicitly.
        checkpoint = torch.load(args.model_ckpt, map_location="cpu")
        model = checkpoint["model_ckpt"]
        if "model_ema" in checkpoint:
            # Keep only the wrapped-module entries of the EMA state dict and
            # strip the "module." prefix so they match the bare model.
            state_dict = {
                key.replace("module.", ""): value
                for key, value in checkpoint["model_ema"].items()
                if "module." in key
            }
            model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()

    accuracy = evaluate(
        model=model,
        model_onnx_path=args.model_onnx,
        criterion=criterion,
        data_loader=data_loader_test,
        device=device,
        print_freq=args.print_freq,
    )
    print(f"Model accuracy is: {accuracy}")


def get_args_parser(add_help=True):
    """Return the command-line parser for this evaluation script."""
    parser = argparse.ArgumentParser(
        description="PyTorch Classification Evaluation", add_help=add_help
    )

    parser.add_argument(
        "--data-path", default="datasets/imagenet", type=str, help="dataset path"
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=32,
        type=int,
        help="images per gpu, the total batch size is $NGPU x batch_size",
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=16,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 16)",
    )
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument(
        "--model-onnx", default="", type=str, help="path of .onnx checkpoint"
    )
    parser.add_argument(
        "--model-ckpt", default="", type=str, help="path of .pth checkpoint"
    )

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)