# Scraped page header: author vumichien, commit ef25c28 ("Update app.py").
# Kept as a comment so the file parses as Python.
import torch
import torch.nn.functional as F
from torch import optim
from torch.nn import Module
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import numpy as np
import onnxruntime
import gradio as gr
import json
def get_image(x):
    """Return the portion of ``x`` that precedes the first ``', '``.

    When ``x`` contains no ``', '`` separator the whole string comes back
    unchanged; an empty string yields an empty string.

    NOTE(review): not called by the code visible in this file; presumably a
    helper for comma-separated label strings — confirm before removing.
    """
    head, _sep, _rest = x.partition(', ')
    return head
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array living in CPU memory.

    A tensor that requires gradients is detached from the autograd graph
    first, because ``Tensor.numpy()`` refuses to convert such tensors.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
# Preprocessing pipeline for the classifier, built once at import time rather
# than on every request:
# - resize so the shorter side is 224;
# - take the central 224x224 crop;
# - convert to a float CHW tensor;
# - normalize with the ImageNet channel mean/std.
_PREPROCESS = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def transform_image(myarray):
    """Turn raw image pixel data into a normalized single-image batch.

    Args:
        myarray: pixel data accepted by ``PIL.Image.fromarray`` after a cast
            to ``uint8`` (presumably the HxWxC / HxW array produced by the
            Gradio image input — confirm).

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    # Cast to uint8 so PIL accepts the array; convert('RGB') coerces
    # grayscale or RGBA inputs to the 3 channels the pipeline produces.
    image = Image.fromarray(np.uint8(myarray)).convert('RGB')
    # Add the leading batch dimension the model input expects.
    return _PREPROCESS(image).unsqueeze(0)
# Class-index -> label-text lookup. Keys are class indices as strings
# (get_classification indexes it with str(i)).
# The context manager closes the handle even if json.load raises on a
# malformed file. JSON is UTF-8 by specification, so decode it as such rather
# than with the platform's locale encoding.
with open('imagenet_label.json', encoding='utf-8') as f:
    label_map = json.load(f)
# Load list of images for similarity: one entry per line, surrounding
# whitespace stripped. The context manager closes the file handle
# deterministically once the list has been built.
with open('img_list.txt', 'r', encoding='utf-8') as img_list_file:
    sub_test_list = [line.strip() for line in img_list_file]
# Load images embedding for similarity.
# NOTE(review): torch.load unpickles the file, so only load trusted files.
# Neither sub_test_list nor embeddings is used by the code visible in this
# file; they may be leftovers from a removed similarity feature — confirm.
embeddings = torch.load('embeddings.pt')
# Path to the exported ONNX classifier.
PATH = 'model_onnx.onnx'

# Session options: 8 threads both for work inside an operator and for running
# operators concurrently.
# NOTE(review): the ONNX Runtime docs say inter_op_num_threads only takes
# effect in parallel execution mode — confirm this setting is intended.
options = onnxruntime.SessionOptions()
options.intra_op_num_threads = options.inter_op_num_threads = 8

ort_session = onnxruntime.InferenceSession(PATH, sess_options=options)
# Name of the model's first input; used as the feed-dict key in
# get_classification.
input_name = ort_session.get_inputs()[0].name
# Top-5 classification of a single image.
def get_classification(img):
    """Classify ``img`` and return its five highest-scoring labels.

    Args:
        img: image data from the Gradio input; forwarded to
            ``transform_image``.

    Returns:
        Dict mapping label text to the model's output score for that class.
        Entries are in descending score order.

        If several class indices share the same label text, only the
        highest-scoring one is kept, so the dict may then hold fewer than
        five entries.

        NOTE(review): scores are passed through unchanged. If the ONNX graph
        does not end in a softmax, these are logits, not probabilities —
        confirm.
    """
    image_tensor = transform_image(img)
    ort_inputs = {input_name: to_numpy(image_tensor)}
    # run(None, ...) returns every model output. The first is indexed as
    # [batch, class]; our batch has a single image.
    scores = ort_session.run(None, ort_inputs)[0]
    top_indices = torch.topk(torch.from_numpy(scores), k=5).indices.squeeze(0).tolist()
    result = {}
    for i in top_indices:
        label = label_map[str(i)]
        # topk yields indices in descending score order, so the first time a
        # label appears carries its best score. Never let a lower-ranked class
        # with the same label text overwrite it.
        result.setdefault(label, scores[0, i].item())
    return result
# Gradio UI: an image input (shape=(200, 200), presumably resized by Gradio)
# feeding get_classification, with the result rendered as a label widget.
# NOTE(review): gr.inputs.* is Gradio's legacy component API, deprecated in
# 3.x and removed in 4.x. Pin the Gradio version or migrate.
iface = gr.Interface(
    fn=get_classification,
    inputs=gr.inputs.Image(shape=(200, 200)),
    outputs="label",
    title='Universal Image Classification',
    description="Imagenet classification from Mobilenetv3 converting to ONNX runtime",
    article='Author: <a href="https://huggingface.co/vumichien">Vu Minh Chien</a>.',
)
iface.launch()