File size: 3,988 Bytes
89cf463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# -- coding: utf-8 --
# @Time : 2021/11/29
# @Author : ykk648
# @Project : https://github.com/ykk648/AI_power
# -*-coding: utf-8 -*-

import onnxruntime
import numpy as np
from cv2box import MyFpsCounter


def init_session(onnx_path, provider='gpu'):
    """Create an ``onnxruntime.InferenceSession`` for ``onnx_path``.

    :param onnx_path: model location handed straight to ``InferenceSession``.
    :param provider: backend selector. ``'gpu'`` selects the CUDA provider on
        device 0; ``'trt'``, ``'trt16'`` and ``'trt8'`` select the TensorRT
        provider with engine caching enabled and, respectively, fp16 disabled,
        fp16 enabled, or int8 enabled. Any other value selects the CPU
        provider.
    :return: the constructed inference session.
    """
    # Ordered (selector, provider spec) pairs; matched with ``==`` below so
    # that any unrecognised selector falls through to the CPU provider.
    known_backends = (
        ('gpu', ("CUDAExecutionProvider", {'device_id': 0})),
        ('trt', ('TensorrtExecutionProvider',
                 {'trt_engine_cache_enable': True, 'trt_fp16_enable': False})),
        ('trt16', ('TensorrtExecutionProvider',
                   {'trt_engine_cache_enable': True, 'trt_fp16_enable': True})),
        ('trt8', ('TensorrtExecutionProvider',
                  {'trt_engine_cache_enable': True, 'trt_int8_enable': True})),
    )
    chosen = next(
        (spec for selector, spec in known_backends if provider == selector),
        "CPUExecutionProvider",
    )

    # Session-level log severity is set to 3 for this session only.
    options = onnxruntime.SessionOptions()
    options.log_severity_level = 3
    return onnxruntime.InferenceSession(onnx_path, sess_options=options, providers=[chosen])


class OnnxModelPickable:  # This is a wrapper to make the current InferenceSession class pickable.
    """Picklable wrapper around an ONNX Runtime inference session.

    Only ``onnx_path`` and ``provider`` are stored in the pickled state; the
    session itself is rebuilt with ``init_session`` when the object is
    unpickled, or lazily on the first ``forward`` call for an instance that
    was constructed directly.
    """

    def __init__(self, onnx_path, provider='gpu'):
        """
        :param onnx_path: model location passed to ``init_session``.
        :param provider: backend selector passed to ``init_session``.
        """
        self.onnx_path = onnx_path
        self.provider = provider
        # The session is not built here; ``forward`` creates it on first use.
        self.onnx_session = None

    def _ensure_session(self):
        """Return the inference session, creating it if it does not exist yet."""
        # ``getattr`` also covers instances created without ``__init__``.
        if getattr(self, 'onnx_session', None) is None:
            self.onnx_session = init_session(self.onnx_path, self.provider)
        return self.onnx_session

    def get_input_feed(self, input_name, image_tensor):
        """
        Build the feed dict, mapping every input name to the same tensor.

        input_feed={self.input_name: image_tensor}
        :param input_name: iterable of model input names.
        :param image_tensor: value assigned to each input name.
        :return: dict of ``{name: image_tensor}``.
        """
        return {name: image_tensor for name in input_name}

    def get_output_info(self, onnx_session):
        """
        Collect the names and shapes of the session's outputs.

        output_name = onnx_session.get_outputs()[0].name
        :param onnx_session: object exposing ``get_outputs()``.
        :return: ``(output_name, output_shape)`` as two parallel lists.
        """
        nodes = list(onnx_session.get_outputs())
        output_name = [node.name for node in nodes]
        output_shape = [node.shape for node in nodes]
        return output_name, output_shape

    def get_input_info(self, onnx_session):
        """
        Collect the names and shapes of the session's inputs.

        input_name = onnx_session.get_inputs()[0].name
        :param onnx_session: object exposing ``get_inputs()``.
        :return: ``(input_name, input_shape)`` as two parallel lists.
        """
        nodes = list(onnx_session.get_inputs())
        input_name = [node.name for node in nodes]
        input_shape = [node.shape for node in nodes]
        return input_name, input_shape

    def forward(self, image_tensor, trans=False):
        """
        Run the model on ``image_tensor`` and return all of its outputs.

        image_tensor = image.transpose(2, 0, 1)
        image_tensor = image_tensor[np.newaxis, :]
        onnx_session.run([output_name], {input_name: x})
        :param image_tensor: array fed to every model input.
        :param trans: when true, move the last of three axes to the front and
            prepend a new leading axis before inference
            (presumably HWC -> NCHW; confirm against callers).
        :return: the list returned by ``onnx_session.run``.
        """
        # The input dtype must match what the model expects. Equivalent ways
        # to call the session:
        # scores, boxes = self.onnx_session.run(None, {self.input_name: image_tensor})
        # scores, boxes = self.onnx_session.run(self.output_name, input_feed={self.input_name: image_tensor})
        if trans:
            image_tensor = image_tensor.transpose(2, 0, 1)
            image_tensor = image_tensor[np.newaxis, :]
        image_tensor = np.ascontiguousarray(image_tensor)
        # Directly constructed instances have no session yet; build it now so
        # that forward() works without a pickle round trip.
        session = self._ensure_session()
        input_name, _ = self.get_input_info(session)
        output_name, _ = self.get_output_info(session)
        input_feed = self.get_input_feed(input_name, image_tensor)
        return session.run(output_name, input_feed=input_feed)

    def __getstate__(self):
        """Return the picklable state: model path and provider only."""
        return {
            'onnx_path': self.onnx_path,
            'provider': self.provider,
        }

    def __setstate__(self, values):
        """Restore path and provider, then rebuild the inference session."""
        self.onnx_path = values['onnx_path']
        self.provider = values['provider']
        self.onnx_session = init_session(self.onnx_path, self.provider)