Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mxnet as mx
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
from collections import namedtuple
|
5 |
+
from mxnet.gluon.data.vision import transforms
|
6 |
+
from mxnet.contrib.onnx.onnx2mx.import_model import import_model
|
7 |
+
import os
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
|
11 |
+
|
12 |
+
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
|
13 |
+
with open('synset.txt', 'r') as f:
|
14 |
+
labels = [l.rstrip() for l in f]
|
15 |
+
|
16 |
+
os.system("wget https://github.com/AK391/models/raw/main/vision/classification/resnet/model/resnet50-v1-12-int8.onnx")
|
17 |
+
|
18 |
+
# Enter path to the ONNX model file
|
19 |
+
|
20 |
+
sym, arg_params, aux_params = import_model('resnet50-v1-12-int8.onnx')
|
21 |
+
|
22 |
+
Batch = namedtuple('Batch', ['data'])
|
23 |
+
def get_image(path, show=False):
|
24 |
+
img = mx.image.imread(path)
|
25 |
+
if img is None:
|
26 |
+
return None
|
27 |
+
if show:
|
28 |
+
plt.imshow(img.asnumpy())
|
29 |
+
plt.axis('off')
|
30 |
+
return img
|
31 |
+
|
32 |
+
def preprocess(img):
|
33 |
+
transform_fn = transforms.Compose([
|
34 |
+
transforms.Resize(256),
|
35 |
+
transforms.CenterCrop(224),
|
36 |
+
transforms.ToTensor(),
|
37 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
38 |
+
])
|
39 |
+
img = transform_fn(img)
|
40 |
+
img = img.expand_dims(axis=0)
|
41 |
+
return img
|
42 |
+
|
43 |
+
def predict(path):
|
44 |
+
img = get_image(path, show=True)
|
45 |
+
img = preprocess(img)
|
46 |
+
mod.forward(Batch([img]))
|
47 |
+
# Take softmax to generate probabilities
|
48 |
+
scores = mx.ndarray.softmax(mod.get_outputs()[0]).asnumpy()
|
49 |
+
# print the top-5 inferences class
|
50 |
+
scores = np.squeeze(scores)
|
51 |
+
a = np.argsort(scores)[::-1]
|
52 |
+
results = {}
|
53 |
+
for i in a[0:5]:
|
54 |
+
results[labels[i]] = float(scores[i])
|
55 |
+
return results
|
56 |
+
|
57 |
+
# Determine and set context
|
58 |
+
if len(mx.test_utils.list_gpus())==0:
|
59 |
+
ctx = mx.cpu()
|
60 |
+
else:
|
61 |
+
ctx = mx.gpu(0)
|
62 |
+
# Load module
|
63 |
+
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
|
64 |
+
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))],
|
65 |
+
label_shapes=mod._label_shapes)
|
66 |
+
mod.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)
|
67 |
+
|
68 |
+
title="ResNet"
|
69 |
+
description="ResNet models perform image classification - they take images as input and classify the major object in the image into a set of pre-defined classes. They are trained on ImageNet dataset which contains images from 1000 classes. ResNet models provide very high accuracies with affordable model sizes. They are ideal for cases when high accuracy of classification is required."
|
70 |
+
|
71 |
+
examples=[['catonnx.jpg']]
|
72 |
+
gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True)
|