AlexNet / app.py
akhaliq's picture
akhaliq HF staff
Update app.py
09692c2
raw history blame
No virus
2.44 kB
import mxnet as mx
import matplotlib.pyplot as plt
import numpy as np
from collections import namedtuple
from mxnet.gluon.data.vision import transforms
from mxnet.contrib.onnx.onnx2mx.import_model import import_model
import os
import gradio as gr
# Download a sample image and the ImageNet class-name list (one label per line).
# mx.test_utils.download returns the local file path and, by default, skips the
# download when the file already exists.
mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
synset_path = mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
with open(synset_path, 'r', encoding='utf-8') as f:
    labels = [line.rstrip() for line in f]

# Fetch the BVLC AlexNet ONNX model with the same downloader used above instead
# of shelling out to `wget` through os.system. That removes the dependency on an
# external binary, does not silently ignore a failed download (os.system's exit
# status was never checked), and does not pile up "bvlcalexnet-7.onnx.1" copies
# on restarts.
model_path = mx.test_utils.download(
    'https://github.com/AK391/models/raw/main/vision/classification/alexnet/model/bvlcalexnet-7.onnx'
)
# Convert the ONNX graph into an MXNet symbol plus its arg/aux parameter dicts.
sym, arg_params, aux_params = import_model(model_path)
# Lightweight batch container; predict() wraps the input tensor in it for mod.forward.
Batch = namedtuple('Batch', ['data'])
def get_image(path, show=False):
    """Read an image file with MXNet and optionally draw it with matplotlib.

    Args:
        path: Filesystem path of the image to load.
        show: If true, plot the image on the current matplotlib axes with the
            axis decorations turned off.

    Returns:
        The array produced by ``mx.image.imread(path)``, or ``None`` when that
        call yields ``None`` (in which case nothing is plotted).
    """
    image = mx.image.imread(path)
    if image is not None and show:
        plt.imshow(image.asnumpy())
        plt.axis('off')
    return image
def preprocess(img):
    """Prepare a decoded image as a single-item input batch for the model.

    Steps: Resize(256) -> CenterCrop(224) -> ToTensor -> per-channel
    mean/std normalization, then a leading batch axis is added. The result is
    intended to match the (1, 3, 224, 224) 'data_0' shape the module is bound with.

    Args:
        img: Image array as returned by ``get_image``.

    Returns:
        The transformed image with an extra axis inserted at position 0.
    """
    # Per-channel normalization constants (the commonly used ImageNet statistics).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    # NOTE(review): with an int size, gluon's Resize may scale both edges to 256
    # rather than only the shorter side; confirm this is the intended preprocessing.
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    batch = pipeline(img).expand_dims(axis=0)
    return batch
def predict(path):
    """Classify one image file and return its five highest-scoring labels.

    Args:
        path: Filesystem path of the image (the Gradio input below is an
            ``Image(type='filepath')`` component).

    Returns:
        dict mapping label string -> float score for the top 5 classes,
        inserted from highest to lowest score.

    Raises:
        ValueError: if no path is given or the image could not be read.
    """
    # Guard against a missing input (presumably what Gradio sends when the form
    # is submitted without an image) instead of failing deep inside MXNet.
    if path is None:
        raise ValueError("No image provided.")
    # Do not ask get_image to plot. Nothing ever calls plt.show(), and every
    # plt.imshow call adds another image artist (holding the whole image array)
    # to the global current axes, so memory would grow with each request.
    img = get_image(path)
    if img is None:
        raise ValueError(f"Could not read image: {path!r}")
    batch = preprocess(img)
    mod.forward(Batch([batch]))
    # Softmax over the first output to produce probabilities.
    # NOTE(review): if the ONNX graph already ends in a softmax layer, this second
    # softmax keeps the ranking but flattens the reported scores; confirm the
    # model's output type.
    scores = np.squeeze(mx.ndarray.softmax(mod.get_outputs()[0]).asnumpy())
    # Indices of the 5 largest scores, highest first.
    top5 = np.argsort(scores)[::-1][:5]
    return {labels[i]: float(scores[i]) for i in top5}
# Use the first GPU when MXNet reports any, otherwise fall back to the CPU.
ctx = mx.gpu(0) if mx.test_utils.list_gpus() else mx.cpu()

# Wrap the imported symbol in an inference-only Module. The single data input is
# 'data_0', bound to a batch of one 3x224x224 tensor, and no label inputs exist.
mod = mx.mod.Module(
    symbol=sym,
    context=ctx,
    data_names=['data_0'],
    label_names=None,
)
mod.bind(
    for_training=False,
    data_shapes=[('data_0', (1, 3, 224, 224))],
    label_shapes=mod._label_shapes,  # presumably None, since label_names=None
)
# Load the weights produced by import_model, tolerating missing or extra entries.
# NOTE(review): allow_missing=True means a parameter absent from arg_params gets
# no imported weights and no error is raised; consider False so a bad import
# fails loudly.
mod.set_params(
    arg_params,
    aux_params,
    allow_missing=True,
    allow_extra=True,
)
# Gradio UI: one image upload (passed to predict as a file path) rendered
# against a "label" output that displays predict's {label: score} dict.
title = "AlexNet"
description = "AlexNet is the name of a convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2012."
# NOTE(review): 'catonnx.jpg' is not downloaded anywhere in this file; presumably
# it ships alongside app.py.
examples = [['catonnx.jpg']]

demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='filepath'),
    outputs="label",
    title=title,
    description=description,
    examples=examples,
)
demo.launch(enable_queue=True)