Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,33 @@ from mxnet.contrib.onnx.onnx2mx.import_model import import_model
|
|
7 |
import os
|
8 |
import gradio as gr
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
|
11 |
|
12 |
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
|
@@ -15,56 +42,23 @@ with open('synset.txt', 'r') as f:
|
|
15 |
|
16 |
os.system("wget https://github.com/AK391/models/raw/main/vision/classification/alexnet/model/bvlcalexnet-7.onnx")
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
sym, arg_params, aux_params = import_model('bvlcalexnet-7.onnx')
|
21 |
|
22 |
-
Batch = namedtuple('Batch', ['data'])
|
23 |
-
def get_image(path, show=False):
|
24 |
-
img = mx.image.imread(path)
|
25 |
-
if img is None:
|
26 |
-
return None
|
27 |
-
if show:
|
28 |
-
plt.imshow(img.asnumpy())
|
29 |
-
plt.axis('off')
|
30 |
-
return img
|
31 |
-
|
32 |
-
def preprocess(img):
|
33 |
-
transform_fn = transforms.Compose([
|
34 |
-
transforms.Resize(256),
|
35 |
-
transforms.CenterCrop(224),
|
36 |
-
transforms.ToTensor(),
|
37 |
-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
38 |
-
])
|
39 |
-
img = transform_fn(img)
|
40 |
-
img = img.expand_dims(axis=0)
|
41 |
-
return img
|
42 |
|
43 |
def predict(path):
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
a = np.argsort(
|
52 |
results = {}
|
53 |
for i in a[0:5]:
|
54 |
-
results[
|
55 |
return results
|
56 |
-
|
57 |
-
# Determine and set context
|
58 |
-
if len(mx.test_utils.list_gpus())==0:
|
59 |
-
ctx = mx.cpu()
|
60 |
-
else:
|
61 |
-
ctx = mx.gpu(0)
|
62 |
-
# Load module
|
63 |
-
|
64 |
-
mod = mx.mod.Module(symbol=sym, context=ctx,data_names=['data_0'],label_names=None)
|
65 |
-
mod.bind(for_training=False, data_shapes=[('data_0', (1,3,224,224))],
|
66 |
-
label_shapes=mod._label_shapes)
|
67 |
-
mod.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)
|
68 |
|
69 |
title="AlexNet"
|
70 |
description="AlexNet is the name of a convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2012."
|
|
|
7 |
import os
|
8 |
import gradio as gr
|
9 |
|
10 |
+
from PIL import Image
|
11 |
+
import imageio
|
12 |
+
import onnxruntime as ort
|
13 |
+
|
14 |
+
def get_image(path):
|
15 |
+
'''
|
16 |
+
Using path to image, return the RGB load image
|
17 |
+
'''
|
18 |
+
img = imageio.imread(path, pilmode='RGB')
|
19 |
+
return img
|
20 |
+
|
21 |
+
# Pre-processing function for ImageNet models using numpy
|
22 |
+
def preprocess(img):
|
23 |
+
'''
|
24 |
+
Preprocessing required on the images for inference with mxnet gluon
|
25 |
+
The function takes loaded image and returns processed tensor
|
26 |
+
'''
|
27 |
+
img = np.array(Image.fromarray(img).resize((224, 224))).astype(np.float32)
|
28 |
+
img[:, :, 0] -= 123.68
|
29 |
+
img[:, :, 1] -= 116.779
|
30 |
+
img[:, :, 2] -= 103.939
|
31 |
+
img[:,:,[0,1,2]] = img[:,:,[2,1,0]]
|
32 |
+
img = img.transpose((2, 0, 1))
|
33 |
+
img = np.expand_dims(img, axis=0)
|
34 |
+
|
35 |
+
return img
|
36 |
+
|
37 |
mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
|
38 |
|
39 |
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
|
|
|
42 |
|
43 |
os.system("wget https://github.com/AK391/models/raw/main/vision/classification/alexnet/model/bvlcalexnet-7.onnx")
|
44 |
|
45 |
+
ort_session = ort.InferenceSession("bvlcalexnet-7.onnx")
|
|
|
|
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
def predict(path):
|
49 |
+
img_batch = preprocess(get_image(path))
|
50 |
+
|
51 |
+
outputs = ort_session.run(
|
52 |
+
None,
|
53 |
+
{"data_0": img_batch.astype(np.float32)},
|
54 |
+
)
|
55 |
+
|
56 |
+
a = np.argsort(-outputs[0].flatten())
|
57 |
results = {}
|
58 |
for i in a[0:5]:
|
59 |
+
results[label[i]]=outputs[0][0][i]
|
60 |
return results
|
61 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
title="AlexNet"
|
64 |
description="AlexNet is the name of a convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2012."
|