# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
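"""Wrap onnxruntime.InferenceSession with execution-provider selection (CPU /
CUDA / DirectML) driven by the user config, falling back to CPU with guidance
logs when a requested provider is unavailable."""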

import os
import platform
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import numpy as np
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    get_available_providers,
    get_device,
)

from .logger import get_logger


class EP(Enum):
    CPU_EP = "CPUExecutionProvider"
    CUDA_EP = "CUDAExecutionProvider"
    DIRECTML_EP = "DmlExecutionProvider"


class OrtInferSession:
    def __init__(self, config: Dict[str, Any]):
        self.logger = get_logger("OrtInferSession")

        model_path = config.get("model_path", None)
        self._verify_model(model_path)

        self.cfg_use_cuda = config.get("use_cuda", None)
        self.cfg_use_dml = config.get("use_dml", None)

        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        self.session = InferenceSession(
            model_path,
            sess_options=sess_opt,
            providers=EP_list,
        )
        self._verify_providers()

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        # os.cpu_count() can return None on some platforms; fall back to 1
        # so the range checks below stay valid.
        cpu_nums = os.cpu_count() or 1

        intra_op_num_threads = config.get("intra_op_num_threads", -1)
        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = config.get("inter_op_num_threads", -1)
        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
            sess_opt.inter_op_num_threads = inter_op_num_threads

        return sess_opt

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        # CPU is always present as the fallback provider.
        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

        cuda_provider_opts = {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        }
        self.use_cuda = self._check_cuda()
        if self.use_cuda:
            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))

        self.use_directml = self._check_dml()
        if self.use_directml:
            self.logger.info(
                "Windows 10 or above detected, trying to use DirectML as the primary provider"
            )
            directml_options = (
                cuda_provider_opts if self.use_cuda else cpu_provider_opts
            )
            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
        return EP_list

    def _check_cuda(self) -> bool:
        if not self.cfg_use_cuda:
            return False

        cur_device = get_device()
        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.CUDA_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
        self.logger.info(
            "(For reference only) If you want to use GPU acceleration, you must do:"
        )
        self.logger.info(
            "First, uninstall all onnxruntime packages in the current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
        )
        self.logger.info(
            "\tNote the onnxruntime-gpu version must match your CUDA and cuDNN versions."
        )
        self.logger.info(
            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
        )
        self.logger.info(
            "Third, ensure %s is in the available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
            EP.CUDA_EP.value,
        )
        return False

    def _check_dml(self) -> bool:
        if not self.cfg_use_dml:
            return False

        cur_os = platform.system()
        if cur_os != "Windows":
            self.logger.warning(
                "DirectML is only supported on Windows. The current OS is %s. Use %s inference by default.",
                cur_os,
                self.had_providers[0],
            )
            return False

        cur_windows_version = int(platform.release().split(".")[0])
        if cur_windows_version < 10:
            self.logger.warning(
                "DirectML is only supported on Windows 10 and above. The current Windows version is %s. Use %s inference by default.",
                cur_windows_version,
                self.had_providers[0],
            )
            return False

        if EP.DIRECTML_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.DIRECTML_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("If you want to use DirectML acceleration, you must do:")
        self.logger.info(
            "First, uninstall all onnxruntime packages in the current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`."
        )
        self.logger.info(
            "Third, ensure %s is in the available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
            EP.DIRECTML_EP.value,
        )
        return False

    def _verify_providers(self):
        session_providers = self.session.get_providers()
        first_provider = session_providers[0]

        if self.use_cuda and first_provider != EP.CUDA_EP.value:
            self.logger.warning(
                "%s is not available in the current environment; inference is automatically shifted to %s.",
                EP.CUDA_EP.value,
                first_provider,
            )

        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
            self.logger.warning(
                "%s is not available in the current environment; inference is automatically shifted to %s.",
                EP.DIRECTML_EP.value,
                first_provider,
            )

    def __call__(self, input_content: np.ndarray) -> List[np.ndarray]:
        input_dict = dict(zip(self.get_input_names(), [input_content]))
        try:
            # session.run returns a list with one array per model output.
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in meta_dict

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")


class ONNXRuntimeError(Exception):
    pass
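

# ---------------------------------------------------------------------------
# Usage sketch (comment only, since the relative `.logger` import means this
# module must be imported from its package rather than run directly). The
# model path is hypothetical; substitute a real ONNX model. The config keys
# mirror the ones read above:
#
#     config = {
#         "model_path": "models/your_model.onnx",  # hypothetical path
#         "use_cuda": False,           # set True to try the CUDA EP
#         "use_dml": False,            # set True on Windows 10+ for DirectML
#         "intra_op_num_threads": -1,  # -1 keeps onnxruntime's default
#         "inter_op_num_threads": -1,
#     }
#     session = OrtInferSession(config)
#     outputs = session(np.zeros((1, 3, 48, 320), dtype=np.float32))
#     # outputs is a list with one np.ndarray per model output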