insecta / insectid /base.py
admin
sync
9d1ee0a
raw
history blame
No virus
2.22 kB
import onnxruntime
import numpy as np
class OnnxModel:
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    The session is created once in the constructor; the names of the model's
    inputs and outputs are cached and exposed through read-only properties.
    """

    def __init__(self, model_path):
        """Create the inference session for the model at *model_path*.

        Session options are left at their defaults. The CUDA execution
        provider is requested (with CPU as fallback) only when onnxruntime
        reports ``'GPU'`` as its device; otherwise only the CPU provider is used.
        """
        options = onnxruntime.SessionOptions()
        if onnxruntime.get_device() == 'GPU':
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']
        self.sess = onnxruntime.InferenceSession(model_path, options, providers=providers)
        self._input_names = [node.name for node in self.sess.get_inputs()]
        self._output_names = [node.name for node in self.sess.get_outputs()]

    @property
    def input_names(self):
        """Names of the model inputs, in the order reported by the session."""
        return self._input_names

    @property
    def output_names(self):
        """Names of the model outputs, in the order reported by the session."""
        return self._output_names

    def forward(self, inputs):
        """Run the model and return its outputs.

        *inputs* is either a tuple/list of values, paired positionally with
        ``input_names``, or a single value, which is treated as a one-element
        list. The result is the list returned by ``sess.run``; when a single
        (non-sequence) input was given and the model has exactly one output,
        that output is returned directly instead of a one-element list.
        """
        single_input = not isinstance(inputs, (tuple, list))
        if single_input:
            inputs = [inputs]
        # zip stops at the shorter sequence, so surplus inputs or names are not fed.
        feed = dict(zip(self.input_names, inputs))
        results = self.sess.run(self.output_names, feed)
        if single_input and len(self.output_names) == 1:
            return results[0]
        return results
def check_image_dtype_and_shape(image):
    """Validate that *image* is a uint8/uint16 NumPy image array.

    Accepted arrays are 2-D, or 3-D with a last axis of length 1, 3 or 4.
    Returns ``None`` when every check passes.

    Raises:
        TypeError: if *image* is not an ``np.ndarray`` or its dtype is
            neither ``uint8`` nor ``uint16``.
        ValueError: if the number of dimensions or the size of the last
            axis is unsupported.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError('image is not np.ndarray!')
    # Compare the dtype itself: an np.dtype object is never an *instance* of
    # np.uint8/np.uint16, so an isinstance() test can never reject anything.
    if image.dtype not in (np.uint8, np.uint16):
        raise TypeError(f'Unsupported image dtype, only support uint8 and uint16, got {image.dtype}!')
    if image.ndim not in {2, 3}:
        raise ValueError(f'Unsupported image dimension number, only support 2 and 3, got {image.ndim}!')
    if image.ndim == 3:
        num_channels = image.shape[-1]
        if num_channels not in {1, 3, 4}:
            raise ValueError(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')