stevengrove
initial commit
186701e
raw
history blame
No virus
6.89 kB
import warnings
from collections import namedtuple
from functools import partial
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import onnxruntime
# TensorRT is optional: if importing ``tensorrt`` raises any exception,
# ``trt`` stays None and TRTWrapper cannot be constructed; ORTWrapper does
# not depend on it.
trt = None
try:
    import tensorrt as trt
except Exception:
    pass
import torch
# NOTE(review): this silences DeprecationWarning for the whole process, not
# just for this module, as a side effect of importing it. Presumably it is
# meant to hide warnings from the TensorRT binding-index API used by
# TRTWrapper (get_binding_name, binding_is_input, ...) — confirm, and consider
# scoping it with warnings.catch_warnings() around those calls instead.
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class TRTWrapper(torch.nn.Module):
    """Run a serialized TensorRT engine as a ``torch.nn.Module``.

    The engine is deserialized from a ``.engine``/``.plan`` file and executed
    with ``execute_async_v2`` on a private CUDA stream. Tensors are handed to
    TensorRT as raw ``data_ptr()`` values without any device transfer, so
    inputs are expected to be CUDA tensors on ``device``.

    NOTE(review): this class uses the binding-index API (``get_binding_name``,
    ``binding_is_input``, ``set_binding_shape``, ...), which TensorRT 8.5
    deprecated and TensorRT 10 removed — presumably TensorRT 8.x is
    required; confirm.

    Args:
        weight: Path to the serialized engine; must exist and end in
            ``.engine`` or ``.plan``.
        device: CUDA device to run on: a ``torch.device``, a device string,
            an integer CUDA index, or ``None`` for the current CUDA device.

    Raises:
        ImportError: if ``tensorrt`` could not be imported.
        RuntimeError: if the engine cannot be deserialized.
    """

    # TensorRT dtype -> torch dtype. Class-level, so it is shared by every
    # instance; it is populated by ``__update_mapping`` at construction time
    # because ``trt`` is None when TensorRT failed to import.
    dtype_mapping = {}

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        super().__init__()
        if trt is None:
            raise ImportError(
                'TRTWrapper requires the `tensorrt` package, which could not '
                'be imported')
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix in ('.engine', '.plan'), (
            f'expected an existing .engine/.plan file, got {weight}')
        if device is None:
            # torch.empty(device=None) would allocate outputs on the default
            # (normally CPU) device, handing TensorRT host pointers; use the
            # current CUDA device instead.
            device = torch.device('cuda')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.stream = torch.cuda.Stream(device=device)
        self.__update_mapping()
        self.__init_engine()
        self.__init_bindings()

    def __update_mapping(self):
        """Register the TensorRT -> torch dtype pairs this wrapper supports."""
        self.dtype_mapping.update({
            trt.bool: torch.bool,
            trt.int8: torch.int8,
            trt.int32: torch.int32,
            trt.float16: torch.float16,
            trt.float32: torch.float32
        })

    def __init_engine(self):
        """Deserialize the engine and record its input/output binding layout.

        Raises:
            RuntimeError: if ``deserialize_cuda_engine`` returns None.
        """
        logger = trt.Logger(trt.Logger.ERROR)
        self.log = partial(logger.log, trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(logger, namespace='')
        self.logger = logger
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
        if model is None:
            # deserialize_cuda_engine signals failure by returning None;
            # fail here instead of with an AttributeError further down.
            raise RuntimeError(
                f'failed to deserialize TensorRT engine {self.weight}')
        context = model.create_execution_context()
        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
        num_inputs, num_outputs = 0, 0
        is_dynamic = False
        for i in range(model.num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
                # A -1 dimension on ANY input makes the engine dynamic
                # (checking only binding 0 misses multi-input engines).
                is_dynamic |= -1 in tuple(model.get_binding_shape(i))
            else:
                num_outputs += 1
        self.is_dynamic = is_dynamic
        self.model = model
        self.context = context
        # NOTE: inputs are assumed to occupy binding indices 0..num_inputs-1
        # and outputs the rest; forward() indexes self.bindings the same way.
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_bindings = num_inputs + num_outputs
        # Device pointers passed to execute_async_v2, indexed by binding.
        self.bindings: List[int] = [0] * self.num_bindings

    def __init_bindings(self):
        """Describe every binding and, for static engines, preallocate outputs.

        Raises:
            KeyError: if a binding dtype is missing from ``dtype_mapping``.
        """
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))

        def describe(index: int, name: str):
            # Build the (name, torch dtype, shape) record for one binding.
            assert self.model.get_binding_name(index) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(index)]
            shape = tuple(self.model.get_binding_shape(index))
            return Binding(name, dtype, shape)

        self.inputs_info = [
            describe(i, name) for i, name in enumerate(self.input_names)
        ]
        self.outputs_info = [
            describe(i, name)
            for i, name in enumerate(self.output_names, start=self.num_inputs)
        ]
        if not self.is_dynamic:
            # Static shapes: allocate the outputs once and reuse them on every
            # forward() call.
            self.output_tensor = [
                torch.empty(o.shape, dtype=o.dtype, device=self.device)
                for o in self.outputs_info
            ]

    def forward(self, *inputs):
        """Run the engine and return its outputs.

        Args:
            *inputs: One tensor per engine input, in binding order. They are
                made contiguous but not moved or cast, so they should already
                be CUDA tensors on ``self.device`` with the engine's dtypes
                (not checked).

        Returns:
            Tuple of output tensors in binding order. For static engines these
            are the preallocated buffers, overwritten by the next call — clone
            them if they must be kept.

        Raises:
            RuntimeError: if ``execute_async_v2`` reports failure.
        """
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[torch.Tensor] = [
            i.contiguous() for i in inputs
        ]
        for i, tensor in enumerate(contiguous_inputs):
            self.bindings[i] = tensor.data_ptr()
            if self.is_dynamic:
                self.context.set_binding_shape(i, tuple(tensor.shape))

        # create output tensors
        outputs: List[torch.Tensor] = []
        for i, info in enumerate(self.outputs_info):
            j = i + self.num_inputs
            if self.is_dynamic:
                # Output shapes are only known after input shapes are set.
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(
                    size=shape, dtype=info.dtype, device=self.device)
            else:
                output = self.output_tensor[i]
            outputs.append(output)
            self.bindings[j] = output.data_ptr()

        # Inputs (and freshly allocated outputs) may still be pending on the
        # caller's current stream; order the private stream after that work.
        self.stream.wait_stream(torch.cuda.current_stream(self.device))
        ok = self.context.execute_async_v2(self.bindings,
                                           self.stream.cuda_stream)
        self.stream.synchronize()
        if not ok:
            raise RuntimeError(
                f'TensorRT execution failed for engine {self.weight}')
        return tuple(outputs)
class ORTWrapper(torch.nn.Module):
    """Run an ONNX model with ONNX Runtime behind a ``torch.nn.Module`` API.

    Inputs are copied to host memory as NumPy arrays, passed to
    ``InferenceSession.run``, and the outputs are converted back to torch
    tensors on ``device``.

    Args:
        weight: Path to the ``.onnx`` model; must exist.
        device: Device for the returned tensors: a ``torch.device``, a device
            string, an integer CUDA index, or ``None`` for CPU. A CUDA device
            also puts ``CUDAExecutionProvider`` (on that device index) ahead
            of ``CPUExecutionProvider``.
    """

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix == '.onnx', (
            f'expected an existing .onnx file, got {weight}')
        if device is None:
            # ``self.device.type`` is read when building the session, so None
            # must be resolved to a real device; run on CPU.
            device = torch.device('cpu')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.__init_session()
        self.__init_bindings()

    def __init_session(self):
        """Create the ONNX Runtime session, preferring CUDA on CUDA devices."""
        providers = ['CPUExecutionProvider']
        if self.device.type == 'cuda':
            if self.device.index is None:
                providers.insert(0, 'CUDAExecutionProvider')
            else:
                # Run on the requested GPU rather than ORT's default device 0.
                providers.insert(0, ('CUDAExecutionProvider',
                                     {'device_id': self.device.index}))
        self.session = onnxruntime.InferenceSession(
            str(self.weight), providers=providers)

    def __init_bindings(self):
        """Record name/type/shape of every model input and output."""
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = [
            Binding(t.name, t.type, tuple(t.shape))
            for t in self.session.get_inputs()
        ]
        outputs_info = [
            Binding(t.name, t.type, tuple(t.shape))
            for t in self.session.get_outputs()
        ]
        # ONNX Runtime reports symbolic/unknown dimensions as non-int values,
        # so any such dimension on any input marks the model as dynamic.
        self.is_dynamic = any(
            not isinstance(dim, int)
            for info in inputs_info for dim in info.shape)
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        self.input_names = [b.name for b in inputs_info]
        self.output_names = [b.name for b in outputs_info]
        self.num_inputs = len(inputs_info)
        self.num_outputs = len(outputs_info)

    def forward(self, *inputs):
        """Run the model and return its outputs as tensors on ``self.device``.

        Args:
            *inputs: One tensor per model input, in ``session.get_inputs()``
                order. For static models each shape must match exactly.

        Returns:
            Tuple of output tensors in ``session.get_outputs()`` order.
        """
        assert len(inputs) == self.num_inputs
        # detach(): Tensor.numpy() refuses tensors that require grad.
        # cpu() before contiguous() avoids an extra device-side copy.
        contiguous_inputs: List[np.ndarray] = [
            i.detach().cpu().contiguous().numpy() for i in inputs
        ]
        if not self.is_dynamic:
            # make sure input shape is right for static input shape
            for array, info in zip(contiguous_inputs, self.inputs_info):
                assert array.shape == info.shape, (
                    f'input {info.name!r}: expected shape {info.shape}, '
                    f'got {array.shape}')
        feed = dict(zip(self.input_names, contiguous_inputs))
        outputs = self.session.run(self.output_names, feed)
        return tuple(torch.from_numpy(o).to(self.device) for o in outputs)