import warnings
from collections import namedtuple
from functools import partial
from pathlib import Path
from typing import List, Union

import numpy as np
import onnxruntime
import torch

# TensorRT is optional here: ORTWrapper runs without it, TRTWrapper needs it.
try:
    import tensorrt as trt
except Exception:
    trt = None

warnings.filterwarnings(action='ignore', category=DeprecationWarning)

class TRTWrapper(torch.nn.Module):
    """Execute a serialized TensorRT engine (.engine / .plan) as a torch module."""
    dtype_mapping = {}

    def __init__(self, weight: Union[str, Path],
                 device: Union[str, int, torch.device]):
        super().__init__()
        assert trt is not None, 'TRTWrapper requires tensorrt to be installed'
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix in ('.engine', '.plan')
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.stream = torch.cuda.Stream(device=device)
        self.__update_mapping()
        self.__init_engine()
        self.__init_bindings()

    def __update_mapping(self):
        # Map TensorRT dtypes onto their torch equivalents for output allocation.
        self.dtype_mapping.update({
            trt.bool: torch.bool,
            trt.int8: torch.int8,
            trt.int32: torch.int32,
            trt.float16: torch.float16,
            trt.float32: torch.float32
        })

    def __init_engine(self):
        logger = trt.Logger(trt.Logger.ERROR)
        self.log = partial(logger.log, trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(logger, namespace='')
        self.logger = logger
        # Deserialize the engine, then create an execution context for it.
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
        context = model.create_execution_context()
        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
        num_inputs, num_outputs = 0, 0
        for i in range(model.num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
            else:
                num_outputs += 1
        # The engine is dynamic if any dimension of the first input is -1.
        self.is_dynamic = -1 in model.get_binding_shape(0)
        self.model = model
        self.context = context
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_bindings = num_inputs + num_outputs
        self.bindings: List[int] = [0] * self.num_bindings

    def __init_bindings(self):
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = []
        outputs_info = []
        for i, name in enumerate(self.input_names):
            assert self.model.get_binding_name(i) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            inputs_info.append(Binding(name, dtype, shape))
        for i, name in enumerate(self.output_names):
            i += self.num_inputs
            assert self.model.get_binding_name(i) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            outputs_info.append(Binding(name, dtype, shape))
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        if not self.is_dynamic:
            # Static shapes are known up front, so output tensors can be
            # allocated once and reused across forward calls.
            self.output_tensor = [
                torch.empty(o.shape, dtype=o.dtype, device=self.device)
                for o in outputs_info
            ]

    def forward(self, *inputs):
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[torch.Tensor] = [
            i.contiguous() for i in inputs
        ]
        for i in range(self.num_inputs):
            self.bindings[i] = contiguous_inputs[i].data_ptr()
            if self.is_dynamic:
                self.context.set_binding_shape(
                    i, tuple(contiguous_inputs[i].shape))
        # Create (or reuse) output tensors and record their device pointers.
        outputs: List[torch.Tensor] = []
        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.is_dynamic:
                # Output shapes depend on the input shapes, so query the
                # context and allocate a fresh tensor for this call.
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(
                    size=shape,
                    dtype=self.outputs_info[i].dtype,
                    device=self.device)
            else:
                output = self.output_tensor[i]
            outputs.append(output)
            self.bindings[j] = output.data_ptr()
        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        self.stream.synchronize()
        return tuple(outputs)
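
# A minimal usage sketch for TRTWrapper, kept as a comment because engines are
# GPU-specific artifacts; the path and input shape below are hypothetical
# placeholders, not part of the original module:
#
#   model = TRTWrapper('model.engine', device=0)
#   x = torch.rand(1, 3, 640, 640, device=model.device)
#   outs = model(x)  # tuple of CUDA tensors, one per engine output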

class ORTWrapper(torch.nn.Module):
    """Execute an ONNX model through onnxruntime as a torch module."""

    def __init__(self, weight: Union[str, Path],
                 device: Union[str, int, torch.device]):
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix == '.onnx'
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.__init_session()
        self.__init_bindings()

    def __init_session(self):
        # Prefer CUDA execution when the target device is a GPU.
        providers = ['CPUExecutionProvider']
        if 'cuda' in self.device.type:
            providers.insert(0, 'CUDAExecutionProvider')
        session = onnxruntime.InferenceSession(
            str(self.weight), providers=providers)
        self.session = session

    def __init_bindings(self):
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = []
        outputs_info = []
        self.is_dynamic = False
        for tensor in self.session.get_inputs():
            # Any symbolic (non-integer) dimension marks the model as dynamic.
            if any(not isinstance(dim, int) for dim in tensor.shape):
                self.is_dynamic = True
            inputs_info.append(
                Binding(tensor.name, tensor.type, tuple(tensor.shape)))
        for tensor in self.session.get_outputs():
            outputs_info.append(
                Binding(tensor.name, tensor.type, tuple(tensor.shape)))
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        self.num_inputs = len(inputs_info)

    def forward(self, *inputs):
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[np.ndarray] = [
            i.contiguous().cpu().numpy() for i in inputs
        ]
        if not self.is_dynamic:
            # Static models only accept the exact shapes they were exported with.
            for i in range(self.num_inputs):
                assert contiguous_inputs[i].shape == self.inputs_info[i].shape
        outputs = self.session.run([o.name for o in self.outputs_info], {
            j.name: contiguous_inputs[i]
            for i, j in enumerate(self.inputs_info)
        })
        return tuple(torch.from_numpy(o).to(self.device) for o in outputs)
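

if __name__ == '__main__':
    # A minimal smoke test, assuming a hypothetical static-shape, float-input
    # 'model.onnx' next to this file; the path and the generated shapes are
    # placeholders, not part of the original module.
    wrapper = ORTWrapper('model.onnx', device='cpu')
    dummy_inputs = [torch.rand(*info.shape) for info in wrapper.inputs_info]
    for out in wrapper(*dummy_inputs):
        print(out.shape, out.dtype)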