ResNet / app.py
akhaliq's picture
akhaliq HF staff
Update app.py
7aaf95a
import mxnet as mx
import matplotlib.pyplot as plt
import numpy as np
from collections import namedtuple
from mxnet.gluon.data.vision import transforms
from mxnet.contrib.onnx.onnx2mx.import_model import import_model
import os
import gradio as gr
# Fetch the sample image and the class-label list into the working directory.
mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
# One label per line; predict() looks labels up by line index.
with open('synset.txt', 'r') as f:
    labels = [line.rstrip() for line in f]
# Fetch the ONNX model with the same helper used above instead of shelling out
# to wget: the wget exit status was never checked, and a restart with the file
# already present would store a second copy under a ".1" suffix.
mx.test_utils.download(
    'https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx',
    fname='resnet18-v2-7.onnx',
)
# Import the ONNX file as an MXNet symbol plus its two parameter dictionaries.
sym, arg_params, aux_params = import_model('resnet18-v2-7.onnx')
# Lightweight container handed to Module.forward; its only field is `data`.
Batch = namedtuple('Batch', ['data'])
def get_image(path, show=False):
    """Load the image stored at ``path`` and optionally plot it.

    Args:
        path: Location of the image file, passed to ``mx.image.imread``.
        show: When true, draw the image with matplotlib and hide the axes.

    Returns:
        The value returned by ``mx.image.imread``, or ``None`` when that
        call yields ``None``.
    """
    image = mx.image.imread(path)
    # Plot only when an image was actually read and the caller asked for it.
    if image is not None and show:
        plt.imshow(image.asnumpy())
        plt.axis('off')
    return image
def preprocess(img):
    """Turn a loaded image into a single-item batch for the network.

    The image is resized (256), center-cropped (224), converted to a tensor
    and normalized with fixed per-channel constants; a leading axis is then
    added so the result can be fed as a batch of one.

    Args:
        img: Image as returned by ``get_image``.

    Returns:
        The transformed image with an extra axis inserted at position 0.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(img).expand_dims(axis=0)
def predict(path):
    """Classify the image at ``path`` and return its five best labels.

    Args:
        path: Filesystem path of the image to classify.

    Returns:
        dict mapping each of the (at most) five highest-scoring labels to its
        softmax probability as a Python float, inserted from most to least
        probable.

    Raises:
        ValueError: If ``get_image`` could not load the image.
    """
    # Do not ask get_image to plot: this runs once per request inside a
    # server, where drawing on the global pyplot figure shows nothing, keeps
    # piling up artists, and is not safe across worker threads.
    img = get_image(path)
    if img is None:
        raise ValueError('Could not read image: {!r}'.format(path))
    batch = preprocess(img)
    mod.forward(Batch([batch]))
    # Softmax converts the raw network outputs into probabilities; squeeze
    # drops the batch axis so scores can be indexed by class id.
    scores = np.squeeze(mx.ndarray.softmax(mod.get_outputs()[0]).asnumpy())
    # argsort is ascending, so reverse it and keep the first five entries.
    top_ids = np.argsort(scores)[::-1][:5]
    return {labels[i]: float(scores[i]) for i in top_ids}
# Pick the compute context: first GPU when any is listed, otherwise the CPU.
gpu_count = len(mx.test_utils.list_gpus())
ctx = mx.gpu(0) if gpu_count != 0 else mx.cpu()
# Build the inference module around the imported symbol (no label inputs).
mod = mx.mod.Module(
    symbol=sym,
    context=ctx,
    data_names=['data'],
    label_names=None,
)
# Bind for inference only, with a single 3x224x224 input named 'data'.
input_shapes = [('data', (1, 3, 224, 224))]
mod.bind(
    for_training=False,
    data_shapes=input_shapes,
    label_shapes=mod._label_shapes,
)
# Load the imported weights; missing and extra parameter names are tolerated.
mod.set_params(
    arg_params,
    aux_params,
    allow_missing=True,
    allow_extra=True,
)
# Text shown at the top of the demo page.
title = "ResNet"
description = "ResNet models perform image classification - they take images as input and classify the major object in the image into a set of pre-defined classes. They are trained on ImageNet dataset which contains images from 1000 classes. ResNet models provide very high accuracies with affordable model sizes. They are ideal for cases when high accuracy of classification is required."
# Example inputs offered in the UI; one row per example, one column per input.
examples = [['catonnx.jpg']]
# Wire predict() to an image-file input and a label output, then start serving.
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='filepath'),
    outputs="label",
    title=title,
    description=description,
    examples=examples,
).launch(enable_queue=True)