Spaces:
Running
Running
import onnxruntime | |
import numpy as np | |
class OnnxModel:
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    By default the execution providers are chosen automatically: CUDA with a
    CPU fallback when ``onnxruntime.get_device()`` reports ``'GPU'``,
    otherwise CPU only.

    Args:
        model_path: Model passed straight to ``onnxruntime.InferenceSession``.
        providers: Optional explicit list of execution providers. ``None``
            (the default) keeps the automatic selection described above.
    """

    def __init__(self, model_path, providers=None):
        sess_options = onnxruntime.SessionOptions()
        # Graph optimization level and intra-op thread count are left at
        # onnxruntime's defaults; set them on ``sess_options`` here if tuning
        # is needed (e.g. ORT_ENABLE_EXTENDED, intra_op_num_threads).
        if providers is None:
            onnx_gpu = onnxruntime.get_device() == 'GPU'
            providers = (['CUDAExecutionProvider', 'CPUExecutionProvider']
                         if onnx_gpu else ['CPUExecutionProvider'])
        self.sess = onnxruntime.InferenceSession(
            model_path, sess_options, providers=providers)
        # Cache the I/O names once; they are used on every forward() call.
        self._input_names = [item.name for item in self.sess.get_inputs()]
        self._output_names = [item.name for item in self.sess.get_outputs()]

    def input_names(self):
        """Return the input names, in the order reported by the session."""
        return self._input_names

    def output_names(self):
        """Return the output names, in the order reported by the session."""
        return self._output_names

    def forward(self, inputs):
        """Run the model on ``inputs``.

        Args:
            inputs: A single input value, or a list/tuple of values matched
                positionally to ``input_names()``. Values beyond the number
                of model inputs are dropped (``zip`` truncates).

        Returns:
            The single output when a non-list/tuple input was given and the
            model has exactly one output; otherwise the full list returned by
            ``InferenceSession.run``.
        """
        single_input = not isinstance(inputs, (tuple, list))
        if single_input:
            inputs = [inputs]
        # ``input_names``/``output_names`` are methods, so read the cached
        # lists directly; iterating a bound method raises TypeError.
        input_feed = dict(zip(self._input_names, inputs))
        outputs = self.sess.run(self._output_names, input_feed)
        if single_input and len(self._output_names) == 1:
            return outputs[0]
        return outputs
def check_image_dtype_and_shape(image):
    """Validate that ``image`` is a supported numpy image array.

    An image is accepted when it is an ``np.ndarray`` with dtype ``uint8`` or
    ``uint16`` and is either 2-D, or 3-D with 1, 3 or 4 entries on its last
    axis. The function returns ``None`` and only raises on invalid input.

    Raises:
        TypeError: ``image`` is not an ``np.ndarray``.
        ValueError: unsupported dtype, number of dimensions, or number of
            channels.

    Both exception types subclass ``Exception``, so handlers written as
    ``except Exception`` keep working.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError('image is not np.ndarray!')
    # ``image.dtype`` is a ``np.dtype`` instance, never an instance of the
    # scalar classes, so compare its scalar ``type`` instead of isinstance().
    if image.dtype.type not in (np.uint8, np.uint16):
        raise ValueError(f'Unsupported image dtype, only support uint8 and uint16, got {image.dtype}!')
    if image.ndim not in {2, 3}:
        raise ValueError(f'Unsupported image dimension number, only support 2 and 3, got {image.ndim}!')
    if image.ndim == 3:
        num_channels = image.shape[-1]
        if num_channels not in {1, 3, 4}:
            raise ValueError(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')