spuun committed on
Commit
da14404
1 Parent(s): 72d3376

Create onnx_.py

Files changed (1)
  1. onnx_.py +59 -0
onnx_.py ADDED
@@ -0,0 +1,59 @@
+ import logging
+ import os
+ import shutil
+ from functools import lru_cache
+ from typing import Optional
+
+ from hbutils.system import pip_install
+
+
+ def _ensure_onnxruntime():
+     try:
+         import onnxruntime
+     except (ImportError, ModuleNotFoundError):
+         logging.warning('Onnx runtime not installed, preparing to install ...')
+         if shutil.which('nvidia-smi'):
+             logging.info('Installing onnxruntime-gpu ...')
+             pip_install(['onnxruntime-gpu'], silent=True)
+         else:
+             logging.info('Installing onnxruntime (cpu) ...')
+             pip_install(['onnxruntime'], silent=True)
+
+
+ _ensure_onnxruntime()
+ from onnxruntime import get_available_providers, get_all_providers, InferenceSession, SessionOptions, \
+     GraphOptimizationLevel
+
+ alias = {
+     'gpu': "CUDAExecutionProvider",
+     "trt": "TensorrtExecutionProvider",
+ }
+
+
+ def get_onnx_provider(provider: Optional[str] = None):
+     if not provider:
+         if "CUDAExecutionProvider" in get_available_providers():
+             return "CUDAExecutionProvider"
+         else:
+             return "CPUExecutionProvider"
+     elif provider.lower() in alias:
+         return alias[provider.lower()]
+     else:
+         for p in get_all_providers():
+             if provider.lower() == p.lower() or f'{provider}ExecutionProvider'.lower() == p.lower():
+                 return p
+
+         raise ValueError(f'One of the {get_all_providers()!r} expected, '
+                          f'but unsupported provider {provider!r} found.')
+
+
+ @lru_cache()
+ def _open_onnx_model(ckpt: str, provider: str = None) -> InferenceSession:
+     options = SessionOptions()
+     options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+     provider = provider or get_onnx_provider()
+     if provider == "CPUExecutionProvider":
+         options.intra_op_num_threads = os.cpu_count()
+
+     logging.info(f'Model {ckpt!r} loaded with provider {provider!r}')
+     return InferenceSession(ckpt, options, [provider])
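A minimal usage sketch of the added module, assuming it is importable as onnx_ and that a checkpoint exists at 'model.onnx' (the path and the dummy input shape are hypothetical placeholders; only get_onnx_provider and _open_onnx_model come from this commit):

import numpy as np
from onnx_ import get_onnx_provider, _open_onnx_model

provider = get_onnx_provider()                      # auto-detects CUDA, falls back to CPU
session = _open_onnx_model('model.onnx', provider)  # cached InferenceSession
input_name = session.get_inputs()[0].name           # first input of the (hypothetical) model
result = session.run(None, {input_name: np.zeros((1, 3, 224, 224), dtype=np.float32)})

Because _open_onnx_model is wrapped in lru_cache, repeated calls with the same (ckpt, provider) pair reuse a single session instead of reloading the model.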