|
import time |
|
from urllib.request import urlopen |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import onnxruntime as ort |
|
import torch |
|
from PIL import Image |
|
|
|
from imagenet_classes import IMAGENET2012_CLASSES |
|
|
|
|
|
def read_image(image: Image.Image):
    """Turn a PIL image into a float32, channel-first batch of one.

    Args:
        image: Source picture; it is converted to RGB first, so any PIL mode
            is accepted.

    Returns:
        A float32 NumPy array with axes ordered (batch, channel, height,
        width) and a batch size of 1.
    """
    rgb = image.convert("RGB")
    # Pixel values are only cast to float32, not rescaled or normalized.
    # NOTE(review): presumably the ONNX graph does its own preprocessing — confirm.
    height_width_channel = np.array(rgb).astype(np.float32)
    # Move the channel axis to the front: (H, W, C) -> (C, H, W).
    channel_first = np.transpose(height_width_channel, (2, 0, 1))
    # Prepend the batch axis.
    return channel_first[np.newaxis, ...]
|
|
|
|
|
# Execution providers handed to ONNX Runtime; only the CPU provider is listed.
providers = ["CPUExecutionProvider"]

# The model is loaded once at import time and shared by every request.
session = ort.InferenceSession(
    "merged_model_compose.onnx",
    providers=providers,
)

# Names of the first graph input and output, looked up once so that
# predict() can refer to them on every call.
input_name, output_name = (
    session.get_inputs()[0].name,
    session.get_outputs()[0].name,
)
|
|
|
|
|
def predict(img):
    """Classify an image and report its five most probable classes.

    Args:
        img: PIL image coming from the Gradio image input, or ``None`` when
            the form is submitted without an image.

    Returns:
        A dict mapping class name to probability (a Python float) for the
        five highest-scoring classes, or ``None`` when ``img`` is ``None``.
    """
    # Without an image there is nothing to classify; returning None avoids
    # an AttributeError inside read_image().
    if img is None:
        return None

    raw_outputs = session.run([output_name], {input_name: read_image(img)})
    scores = torch.from_numpy(raw_outputs[0])

    # Softmax over dim 1 turns the scores into probabilities before the
    # top five are picked.
    # NOTE(review): assumes the output is (batch, num_classes) — confirm.
    top5_probabilities, top5_class_indices = torch.topk(scores.softmax(dim=1), k=5)

    # Class names are looked up by position in IMAGENET2012_CLASSES.
    im_classes = list(IMAGENET2012_CLASSES.values())
    # .tolist() yields plain ints, so the list is not indexed with tensors.
    class_names = [im_classes[i] for i in top5_class_indices[0].tolist()]

    # Only the first batch entry is reported.
    return {
        name: float(prob)
        for name, prob in zip(class_names, top5_probabilities[0].tolist())
    }
|
|
|
|
|
|
|
# Image file offered as a ready-made example in the UI.
example_image = "beignets-task-guide.png"

# Input and output widgets: the image arrives in predict() as a PIL image,
# and the label widget shows up to five classes.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=5)

# Web UI that wires the widgets to predict().
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Image Classification with ONNX using EVA02 model",
    description="Blog post: https://dicksonneoh.com/portfolio/supercharge_your_pytorch_image_models/",
    examples=[example_image],
)

# Start the server only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()
|
|