import os

import onnxruntime
import torch
import torch.utils.data
import torchvision
from torch import nn
from torchvision.transforms.functional import InterpolationMode

import utils
|
|
def evaluate(
    criterion,
    data_loader,
    device,
    model=None,
    model_onnx_path=None,
    print_freq=100,
    log_suffix="",
):
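    """Evaluate a classification model on ``data_loader`` and report loss and
    top-1/top-5 accuracy.

    Either an in-memory PyTorch ``model`` or an ONNX model at ``model_onnx_path``
    must be provided; the ONNX path takes precedence when both are given.
    """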
|
    if model is None and not model_onnx_path:
        raise ValueError("either `model` or `model_onnx_path` must be provided")

    if model_onnx_path:
        # ONNX inference runs on CPU via onnxruntime.
        session = onnxruntime.InferenceSession(
            model_onnx_path, providers=["CPUExecutionProvider"]
        )
        input_name = session.get_inputs()[0].name
|
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            target = target.to(device, non_blocking=True)
            image = image.to(device)

            if model_onnx_path:
                # onnxruntime consumes host-side NumPy arrays.
                input_data = image.cpu().numpy()
                # None requests all model outputs; the first one holds the logits.
                output_data = session.run(None, {input_name: input_data})[0]
                output = torch.from_numpy(output_data).to(device)
            elif model:
                output = model(image)
|
            loss = criterion(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
|
    metric_logger.synchronize_between_processes()

    print(
        f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}"
    )
    return metric_logger.acc1.global_avg


def load_data(valdir):
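    """Build the validation dataset from an ImageFolder layout with the standard
    resize/center-crop/normalize preprocessing, plus a sequential sampler."""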
|
print("Loading data") |
|
interpolation = InterpolationMode("bilinear") |
|
|
|
preprocessing = torchvision.transforms.Compose( |
|
[ |
|
torchvision.transforms.Resize(256, interpolation=interpolation), |
|
torchvision.transforms.CenterCrop(224), |
|
torchvision.transforms.PILToTensor(), |
|
torchvision.transforms.ConvertImageDtype(torch.float), |
|
torchvision.transforms.Normalize( |
|
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) |
|
), |
|
] |
|
) |
|
|
|
dataset_test = torchvision.datasets.ImageFolder( |
|
valdir, |
|
preprocessing, |
|
) |
|
|
|
print("Creating data loaders") |
|
test_sampler = torch.utils.data.SequentialSampler(dataset_test) |
|
|
|
return dataset_test, test_sampler |
|
|
|
|
|
def main(args): |
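    """Create the validation data loader, optionally restore a PyTorch model from
    ``--model-ckpt``, and evaluate it (or the ONNX model given via ``--model-onnx``)."""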
|
    print(args)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    val_dir = os.path.join(args.data_path, "val")
    dataset_test, test_sampler = load_data(val_dir)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )

    print("Creating model")

    criterion = nn.CrossEntropyLoss()
|
    model = None
    if args.model_ckpt:
        # The checkpoint is expected to store the full nn.Module under "model_ckpt"
        # and, optionally, EMA weights under "model_ema".
        checkpoint = torch.load(args.model_ckpt, map_location="cpu")
        model = checkpoint["model_ckpt"]
        if "model_ema" in checkpoint:
            # Load EMA weights: keep only the wrapped model's parameters
            # (keys prefixed with "module.") and strip that prefix.
            state_dict = {}
            for key, value in checkpoint["model_ema"].items():
                if "module." not in key:
                    continue
                state_dict[key.replace("module.", "")] = value
            model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
|
    accuracy = evaluate(
        model=model,
        model_onnx_path=args.model_onnx,
        criterion=criterion,
        data_loader=data_loader_test,
        device=device,
        print_freq=args.print_freq,
    )
    print(f"Model accuracy is: {accuracy}")


def get_args_parser(add_help=True):
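    """Build the command-line argument parser for this evaluation script."""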
|
    import argparse

    parser = argparse.ArgumentParser(
        description="PyTorch Classification Evaluation", add_help=add_help
    )

    parser.add_argument(
        "--data-path", default="datasets/imagenet", type=str, help="dataset path"
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=32,
        type=int,
        help="evaluation batch size",
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=16,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 16)",
    )
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument(
        "--model-onnx", default="", type=str, help="path to the exported .onnx model"
    )
    parser.add_argument(
        "--model-ckpt", default="", type=str, help="path to the .pth checkpoint"
    )

    return parser
|
|
if __name__ == "__main__":
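    # Example invocations (script name and checkpoint paths are illustrative):
    #   python evaluate.py --data-path datasets/imagenet --model-ckpt checkpoint.pth
    #   python evaluate.py --data-path datasets/imagenet --model-onnx model.onnx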
|
    args = get_args_parser().parse_args()
    main(args)