File size: 1,225 Bytes
f3832f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import functools
import json

import numpy as np
import onnxruntime as rt

# Both paths are relative, so they resolve against the current working
# directory of the process, not against the location of this file.
model_path = 'model/model.onnx'
class_path = 'model/birds_name_mapping.json'

# Per-channel mean / standard deviation used by normalise_image, applied to
# pixel values that predict has already scaled to [0, 1] (it divides by 255).
# NOTE(review): these equal the commonly quoted CIFAR-10 statistics and the
# channel order is presumably RGB; confirm they match the preprocessing the
# ONNX model was trained with.
normalise_means = [0.4914, 0.4822, 0.4465]
normalise_stds = [0.2023, 0.1994, 0.2010]

def normalise_image(image, means=None, stds=None):
    """Return a per-channel standardised copy of a batched channels-first image.

    Each of the first ``len(means)`` channels along axis 1 is transformed as
    ``(x - mean) / std``. Any further channels are copied through unchanged,
    and the input array itself is never modified.

    Args:
        image: array-like of rank 4 with channels on axis 1; ``predict``
            passes a ``(1, C, H, W)`` batch already scaled to [0, 1].
        means: per-channel means; defaults to the module-level
            ``normalise_means``.
        stds: per-channel standard deviations; defaults to the module-level
            ``normalise_stds``.

    Returns:
        A new array. Floating-point inputs keep their dtype. Integer or
        boolean inputs are promoted to float64 so the result is not truncated.

    Raises:
        ValueError: if ``means`` and ``stds`` differ in length, if ``image``
            is not rank 4, or if it has fewer channels than there are means.
    """
    if means is None:
        means = normalise_means
    if stds is None:
        stds = normalise_stds
    if len(means) != len(stds):
        raise ValueError(
            f"means and stds must have the same length, got {len(means)} and {len(stds)}"
        )

    image = np.asarray(image)
    if image.ndim != 4:
        raise ValueError(f"expected a rank-4 (N, C, H, W) array, got shape {image.shape}")
    n_channels = len(means)
    if image.shape[1] < n_channels:
        raise ValueError(
            f"expected at least {n_channels} channels on axis 1, got {image.shape[1]}"
        )

    # Writing float results back into an integer array would silently
    # truncate them, so compute in float64 unless the input is already float.
    dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else np.float64
    out = np.array(image, dtype=dtype)  # np.array copies by default
    mean = np.asarray(means, dtype=dtype).reshape(1, -1, 1, 1)
    std = np.asarray(stds, dtype=dtype).reshape(1, -1, 1, 1)
    out[:, :n_channels] = (out[:, :n_channels] - mean) / std
    return out

def load_class_names(path=None):
    """Load the class-index to class-name mapping from a JSON file.

    Args:
        path: JSON file to read; defaults to the module-level ``class_path``.

    Returns:
        The decoded JSON object. ``predict`` indexes it with ``str(i)``, so it
        is expected to be a dict keyed by stringified class indices.
    """
    if path is None:
        path = class_path
    # JSON text is UTF-8 (RFC 8259); don't rely on the platform's locale
    # encoding, which is not UTF-8 on e.g. Windows and would garble names.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

@functools.lru_cache(maxsize=None)  # keyed only by model path, so effectively one entry
def _get_session(path):
    """Create the ONNX Runtime inference session for *path* once and reuse it."""
    return rt.InferenceSession(path)


def predict(inp_image):
    """Classify one image and return its five most probable classes.

    Args:
        inp_image: array-like image with channels on the last axis and pixel
            values in [0, 255] (they are divided by 255 below). A 2-D
            greyscale image is replicated to three channels, and a fourth
            (alpha) channel is dropped.
            NOTE(review): channel order is assumed to match training (likely RGB).

    Returns:
        dict mapping class name to softmax probability for the top five
        classes (all classes if the model has fewer than five), most probable
        first. The dict is also printed to stdout.
    """
    class_names = load_class_names()

    image = np.asarray(inp_image)
    if image.ndim == 2:
        # Greyscale: replicate into the three channels the normaliser expects.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 4:
        # RGBA: drop alpha so the model receives three colour channels.
        image = image[:, :, :3]

    image = image.transpose((2, 0, 1)) / 255.0  # channels-last -> channels-first, scale to [0, 1]
    image = np.expand_dims(image, axis=0)       # add batch axis -> (1, C, H, W)
    image = normalise_image(image).astype(np.float32)

    # Loading the model is expensive; reuse one session across calls.
    sess = _get_session(model_path)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    logits = sess.run([output_name], {input_name: image})[0]

    # Softmax with the row maximum subtracted first: mathematically identical,
    # but exp() can no longer overflow to inf and produce NaN probabilities.
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp = np.exp(shifted)
    prob = exp / np.sum(exp, axis=1, keepdims=True)

    top5 = np.argsort(prob[0])[-5:][::-1]  # indices of the 5 largest, descending
    class_probs = {class_names[str(i)]: float(prob[0][i]) for i in top5}
    print(class_probs)

    return class_probs