File size: 587 Bytes
56b9061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b533667
56b9061
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import functools

import numpy as np
import onnxruntime as rt

# Path to the exported ONNX classifier. It is relative, so it resolves
# against the process's current working directory, not this file's location.
onnx_path = "model/model.onnx"

@functools.lru_cache(maxsize=None)
def _get_session(path):
    """Load an ONNX Runtime inference session for *path*, memoized per path.

    Constructing ``rt.InferenceSession`` loads and parses the model file, so
    it is done once per distinct path rather than on every ``predict`` call.
    NOTE(review): sharing one session across concurrent callers assumes
    ``InferenceSession.run`` is thread-safe in the ORT version in use — confirm.
    """
    return rt.InferenceSession(path)


def _softmax(logits):
    """Return the softmax of *logits* along axis 1.

    The per-row maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged, but keeps ``np.exp`` from overflowing to
    ``inf`` (which would make the quotient ``nan``) when logits are large.
    """
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)


def predict(img):
    """Classify one 256x256 single-channel image with the ONNX model.

    Args:
        img: Any object ``np.array`` can turn into exactly 256 * 256 values
            (presumably a grayscale image; pixel values assumed to be in
            [0, 255] — TODO confirm against callers).

    Returns:
        dict mapping ``'No Substructure'`` and ``'Substructure'`` to the
        softmax probabilities of entries 0 and 1 of the model's first output
        for the single item in the batch.

    Raises:
        ValueError: if *img* does not contain exactly 256 * 256 elements
            (raised by ``reshape``).
    """
    # Reuse a cached session instead of reloading the model on every call.
    # onnx_path is read at call time, so reassigning it still takes effect.
    session = _get_session(onnx_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # A batch of one single-channel 256x256 image, scaled from [0, 255] to
    # [0, 1]. Presumably the model was exported for NCHW float32 input —
    # confirm against the export script.
    batch = np.array(img).astype(np.float32)
    batch = batch.reshape(1, 1, 256, 256)
    batch = batch / 255.0

    logits = session.run([output_name], {input_name: batch})[0]
    probs = _softmax(logits)

    return {
        'No Substructure': float(probs[0][0]),
        'Substructure': float(probs[0][1]),
    }