# Source page: amd / Image Classification / ONNX / RyzenAI
# File: efficientnet-es / infer_onnx.py
# Author: zhengrongzhang — commit 0b5f4ac ("init model"), 2.25 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright 2023 Advanced Micro Devices, Inc. on behalf of itself and its subsidiaries and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Megvii, Inc. and its affiliates.
import onnxruntime
import argparse
from PIL import Image
import torchvision.transforms as transforms
# Command-line interface. Parsing happens at import time because
# read_image() and main() read the module-level ``args`` object directly.
parser = argparse.ArgumentParser(
    description="Classify a single image with an ONNX model using ONNX Runtime.",
)
parser.add_argument(
    '--onnx_path',
    type=str,
    default="EfficientNet_int.onnx",
    required=False,
    help="Path of the ONNX model to run (default: %(default)s).",
)
parser.add_argument(
    '--image_path',
    type=str,
    required=True,
    help="Path of the input image to classify.",
)
parser.add_argument(
    "--ipu",
    action="store_true",
    help="Use IPU for inference.",
)
parser.add_argument(
    "--provider_config",
    type=str,
    default="vaip_config.json",
    # Only consumed when --ipu is given (passed as the VitisAI "config_file").
    help="Path of the config file for setting provider_options.",
)
args = parser.parse_args()
def read_image(image_path=None):
    """Load an image and preprocess it into a batched NumPy array for the model.

    The image is converted to 3-channel RGB, turned into a float tensor with
    ``ToTensor``, resized to 224x224, and normalized with the standard
    ImageNet mean/std values below. A leading batch axis of size 1 is added.

    Args:
        image_path: Path of the image file. When omitted (``None``), the
            ``--image_path`` command-line argument (``args.image_path``) is used,
            which preserves the original zero-argument call.

    Returns:
        A float32 NumPy array of shape (1, 3, 224, 224).
    """
    if image_path is None:
        image_path = args.image_path
    # Use a context manager so the underlying file handle is closed; convert()
    # loads the pixels and returns an independent copy, so it stays valid
    # after the file is closed. Forcing RGB guarantees 3 channels, which
    # Normalize (3 means / 3 stds) requires; grayscale, palette, or RGBA
    # inputs would otherwise fail there.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Resize is applied to the tensor produced by ToTensor (same order as the
    # original pipeline, so numerical results for RGB inputs are unchanged).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        normalize,
    ])
    img_tensor = transform(image).unsqueeze(0)
    return img_tensor.numpy()
def main():
    """Run one inference pass and print the predicted class id.

    Execution providers are chosen from ``--ipu``:
      * with ``--ipu``: only ``VitisAIExecutionProvider``, configured with the
        file given by ``--provider_config``;
      * otherwise: ``CUDAExecutionProvider`` then ``CPUExecutionProvider``,
        with no provider options.

    The preprocessed image is fed to the model's first input, and the argmax
    of the first row of the first output (presumably the class scores — confirm
    against the exported model) is printed.
    """
    if not args.ipu:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    else:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]

    session = onnxruntime.InferenceSession(
        args.onnx_path,
        providers=providers,
        provider_options=provider_options,
    )

    # The input name is looked up before the image is read, matching the
    # original evaluation order.
    input_name = session.get_inputs()[0].name
    feeds = {input_name: read_image()}

    first_output = session.run(None, feeds)[0]
    print("class id =", first_output[0].argmax())


if __name__ == "__main__":
    # Script entry point.
    main()