# Author: savchenkoyana
# Commit 061bac4: add ViT x4.8 ONNX, small fixes in test.py, and allow
# measuring macs on ONNX
import os
import onnxruntime
import torch
import torch.utils.data
import torchvision
from torch import nn
from torchvision.transforms.functional import InterpolationMode
import utils
def evaluate(
    criterion,
    data_loader,
    device,
    model=None,
    model_onnx_path=None,
    print_freq=100,
    log_suffix="",
):
    """Run one evaluation pass and return the averaged top-1 accuracy.

    Exactly one backend is used per call: when ``model_onnx_path`` is set the
    batches go through ONNX Runtime (CPU execution provider); otherwise they
    go through the PyTorch ``model``. If both are supplied the ONNX model
    takes precedence.

    Args:
        criterion: Callable ``criterion(output, target)`` whose result
            supports ``.item()``.
        data_loader: Iterable yielding ``(image, target)`` tensor batches.
        device: Device that targets, PyTorch inputs and outputs are moved to.
        model: Optional PyTorch model called as ``model(image)``.
        model_onnx_path: Optional path to an ``.onnx`` file.
        print_freq: Logging interval forwarded to ``MetricLogger.log_every``.
        log_suffix: Text appended to the ``"Test: "`` log header.

    Returns:
        ``metric_logger.acc1.global_avg`` after synchronizing between
        processes.

    Raises:
        ValueError: If neither ``model`` nor ``model_onnx_path`` is provided.
    """
    use_onnx = bool(model_onnx_path)
    # Fail fast: without a backend the loop below would hit an unbound
    # ``output`` on the first batch.
    if not use_onnx and model is None:
        raise ValueError(
            "evaluate() requires either `model` or `model_onnx_path`"
        )

    if use_onnx:
        session = onnxruntime.InferenceSession(
            model_onnx_path, providers=["CPUExecutionProvider"]
        )
        # Only the first graph input is fed.
        input_name = session.get_inputs()[0].name

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = f"Test: {log_suffix}"

    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            target = target.to(device, non_blocking=True)
            if use_onnx:
                # ONNX Runtime consumes host-side numpy arrays, so the batch
                # is never copied to `device`; only the logits are.
                input_data = image.cpu().numpy()
                output_data = session.run([], {input_name: input_data})[0]
                output = torch.from_numpy(output_data).to(device)
            else:
                image = image.to(device)
                output = model(image)

            loss = criterion(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print(
        f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}"
    )
    return metric_logger.acc1.global_avg
def load_data(valdir):
    """Build the validation dataset and a sequential sampler over it.

    Images are read from ``valdir`` with ``torchvision.datasets.ImageFolder``
    and pre-processed by: bilinear resize to 256, center crop to 224,
    conversion to a float tensor, and normalization with the mean/std
    constants below.

    Args:
        valdir: Root directory passed to ``ImageFolder``.

    Returns:
        A ``(dataset, sampler)`` tuple, where ``sampler`` is a
        ``SequentialSampler`` over ``dataset``.
    """
    # Data loading code
    print("Loading data")
    transforms = torchvision.transforms
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
        transforms.CenterCrop(224),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    dataset_test = torchvision.datasets.ImageFolder(
        valdir, transform=transforms.Compose(steps)
    )

    print("Creating data loaders")
    return dataset_test, torch.utils.data.SequentialSampler(dataset_test)
def main(args):
    """Evaluate a checkpoint and/or ONNX model on the ``val`` split.

    Args:
        args: Parsed options providing ``data_path``, ``batch_size``,
            ``workers``, ``print_freq``, ``model_onnx`` and ``model_ckpt``
            (see ``get_args_parser``).

    Raises:
        ValueError: If neither ``args.model_ckpt`` nor ``args.model_onnx``
            is set, since there would be nothing to evaluate.
    """
    print(args)
    # Reject an empty configuration before the dataset is built.
    if not args.model_ckpt and not args.model_onnx:
        raise ValueError(
            "Nothing to evaluate: pass --model-ckpt and/or --model-onnx"
        )

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    val_dir = os.path.join(args.data_path, "val")
    dataset_test, test_sampler = load_data(val_dir)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )

    print("Creating model")
    criterion = nn.CrossEntropyLoss()
    model = None
    if args.model_ckpt:
        # NOTE(review): the checkpoint stores a whole model object under
        # "model_ckpt", so torch.load unpickles arbitrary code. Only load
        # trusted files. Recent PyTorch releases that default to
        # weights_only=True may need an explicit weights_only=False here;
        # confirm against the pinned torch version.
        checkpoint = torch.load(args.model_ckpt, map_location="cpu")
        model = checkpoint["model_ckpt"]
        if "model_ema" in checkpoint:
            # Keep only entries whose key contains "module." and strip that
            # marker; all other entries are skipped.
            state_dict = {
                key.replace("module.", ""): value
                for key, value in checkpoint["model_ema"].items()
                if "module." in key
            }
            model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()

    accuracy = evaluate(
        model=model,
        model_onnx_path=args.model_onnx,
        criterion=criterion,
        data_loader=data_loader_test,
        device=device,
        # Forward --print-freq so the CLI option actually takes effect.
        print_freq=args.print_freq,
    )
    print(f"Model accuracy is: {accuracy}")
def get_args_parser(add_help=True):
    """Create the command-line parser for this script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; controls whether
            the ``-h/--help`` option is added.

    Returns:
        An ``argparse.ArgumentParser`` defining ``--data-path``,
        ``-b/--batch-size``, ``-j/--workers``, ``--print-freq``,
        ``--model-onnx`` and ``--model-ckpt``.
    """
    import argparse

    # (flags, keyword options) in the order they are registered.
    option_specs = (
        (
            ("--data-path",),
            {"default": "datasets/imagenet", "type": str, "help": "dataset path"},
        ),
        (
            ("-b", "--batch-size"),
            {
                "default": 32,
                "type": int,
                "help": "images per gpu, the total batch size is $NGPU x batch_size",
            },
        ),
        (
            ("-j", "--workers"),
            {
                "default": 16,
                "type": int,
                "metavar": "N",
                "help": "number of data loading workers (default: 16)",
            },
        ),
        (
            ("--print-freq",),
            {"default": 10, "type": int, "help": "print frequency"},
        ),
        (
            ("--model-onnx",),
            {"default": "", "type": str, "help": "path of .onnx checkpoint"},
        ),
        (
            ("--model-ckpt",),
            {"default": "", "type": str, "help": "path of .pth checkpoint"},
        ),
    )

    parser = argparse.ArgumentParser(
        description="PyTorch Classification Training", add_help=add_help
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser
if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    main(get_args_parser().parse_args())