akhaliq HF staff commited on
Commit
ef1897c
1 Parent(s): 825e175

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mxnet as mx
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from collections import namedtuple
5
+ from mxnet.gluon.data.vision import transforms
6
+ import os
7
+ import gradio as gr
8
+
9
+ from PIL import Image
10
+ import imageio
11
+ import onnxruntime as ort
12
+
13
+ def get_image(path):
14
+ '''
15
+ Using path to image, return the RGB load image
16
+ '''
17
+ img = imageio.imread(path, pilmode='RGB')
18
+ return img
19
+
20
+ # Pre-processing function for ImageNet models using numpy
21
+ def preprocess(img):
22
+ '''
23
+ Preprocessing required on the images for inference with mxnet gluon
24
+ The function takes loaded image and returns processed tensor
25
+ '''
26
+ img = np.array(Image.fromarray(img).resize((224, 224))).astype(np.float32)
27
+ img[:, :, 0] -= 123.68
28
+ img[:, :, 1] -= 116.779
29
+ img[:, :, 2] -= 103.939
30
+ img[:,:,[0,1,2]] = img[:,:,[2,1,0]]
31
+ img = img.transpose((2, 0, 1))
32
+ img = np.expand_dims(img, axis=0)
33
+
34
+ return img
35
+
36
+ mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
37
+
38
+ mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
39
+ with open('synset.txt', 'r') as f:
40
+ labels = [l.rstrip() for l in f]
41
+
42
+ os.system("wget https://github.com/AK391/models/raw/main/vision/classification/caffenet/model/caffenet-12.onnx")
43
+
44
+ ort_session = ort.InferenceSession("caffenet-12.onnx")
45
+
46
+
47
+ def predict(path):
48
+ img_batch = preprocess(get_image(path))
49
+
50
+ outputs = ort_session.run(
51
+ None,
52
+ {"data_0": img_batch.astype(np.float32)},
53
+ )
54
+
55
+ a = np.argsort(-outputs[0].flatten())
56
+ results = {}
57
+ for i in a[0:5]:
58
+ results[labels[i]]=float(outputs[0][0][i])
59
+ return results
60
+
61
+
62
+ title="GoogleNet"
63
+ description="GoogLeNet is the name of a convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2014."
64
+
65
+ examples=[['catonnx.jpg']]
66
+ gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)