File size: 1,980 Bytes
37aeb5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import onnxruntime
import torch

# ONNX Runtime execution providers, listed in the order ONNX Runtime should
# prefer them: TensorRT first, then CUDA. Both are pinned to device 0.
providers = [
    (
        'TensorrtExecutionProvider',
        dict(
            device_id=0,
            trt_max_workspace_size=8 << 30,  # 8 GiB
            trt_fp16_enable=True,
            trt_engine_cache_enable=True,
        ),
    ),
    (
        'CUDAExecutionProvider',
        dict(
            device_id=0,
            arena_extend_strategy='kSameAsRequested',
            gpu_mem_limit=8 << 30,  # 8 GiB
            cudnn_conv_algo_search='HEURISTIC',
        ),
    ),
]

def load_onnx(file_path: str):
    """Create an ONNX Runtime inference session for an ``.onnx`` model file.

    The session uses a default ``SessionOptions`` and the module-level
    ``providers`` list.

    Args:
        file_path: Path to the model file; must end with ``.onnx``. An
            ``os.PathLike`` (e.g. ``pathlib.Path``) is accepted as well as ``str``.

    Returns:
        The created ``onnxruntime.InferenceSession``.

    Raises:
        ValueError: If ``file_path`` does not end with ``.onnx``.
    """
    import os

    path = os.fspath(file_path)
    # Raise instead of assert: asserts are stripped when Python runs with -O.
    if not path.endswith(".onnx"):
        raise ValueError(f"Expected a path ending in '.onnx', got {path!r}")
    sess_opt = onnxruntime.SessionOptions()
    # The keyword is ``sess_options``. The previous ``sess_opt=`` was swallowed
    # by InferenceSession's **kwargs, so the options object was silently ignored.
    ort_session = onnxruntime.InferenceSession(path, sess_options=sess_opt, providers=providers)
    return ort_session


def load_onnx_caller(file_path: str, single_output=False):
    """Load an ONNX model and return a function that runs it.

    The returned ``caller(*args)`` maps its positional arguments, in order, onto
    the model's declared inputs and returns the model's outputs.

    If the first argument is a ``torch.Tensor``, every argument must be a
    ``torch.Tensor`` with the same dtype and device. Each tensor is detached,
    moved to CPU, cast to ``float32`` and converted to a NumPy array before being
    fed to ONNX Runtime. Every output is then converted back to a
    ``torch.Tensor`` with the inputs' dtype and device. Otherwise the arguments
    are passed to ONNX Runtime unchanged and its outputs are returned as-is.

    Args:
        file_path: Path to the ``.onnx`` model, forwarded to ``load_onnx``.
        single_output: If True, ``caller`` returns only the first output
            instead of the full list of outputs.

    Returns:
        The ``caller`` function described above. ``caller`` raises
        ``ValueError`` when given more arguments than the model has inputs.
        With torch inputs, it raises ``TypeError`` if they are not all tensors
        and ``ValueError`` if their dtypes or devices differ.
    """
    ort_session = load_onnx(file_path)
    # Input names are fixed for a session: look them up once, not per argument per call.
    input_names = [node.name for node in ort_session.get_inputs()]

    def caller(*args):
        if len(args) > len(input_names):
            raise ValueError(
                f"Model takes {len(input_names)} inputs, but {len(args)} were given"
            )
        # Guard against an empty call: args[0] would raise IndexError.
        torch_input = bool(args) and isinstance(args[0], torch.Tensor)
        if torch_input:
            torch_input_dtype = args[0].dtype
            torch_input_device = args[0].device
            # check all are torch.Tensor and have same dtype and device
            if not all(isinstance(arg, torch.Tensor) for arg in args):
                raise TypeError("All inputs should be torch.Tensor, if first input is torch.Tensor")
            if not all(arg.dtype == torch_input_dtype for arg in args):
                raise ValueError("All inputs should have same dtype, if first input is torch.Tensor")
            if not all(arg.device == torch_input_device for arg in args):
                raise ValueError("All inputs should have same device, if first input is torch.Tensor")
            # detach() first: Tensor.numpy() refuses tensors that require grad.
            # NOTE(review): inputs are always cast to float32; a model expecting
            # integer inputs (e.g. token ids) would be fed floats — confirm intent.
            args = [arg.detach().cpu().float().numpy() for arg in args]

        ort_inputs = dict(zip(input_names, args))
        ort_outs = ort_session.run(None, ort_inputs)

        if torch_input:
            # Outputs take the *input* dtype, so non-float outputs are cast too.
            ort_outs = [
                torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device)
                for ort_out in ort_outs
            ]

        if single_output:
            return ort_outs[0]
        return ort_outs

    return caller