|
import tensorrt as trt |
|
from itertools import tee |
|
|
|
from polygraphy.backend.trt import ( |
|
network_from_onnx_path, |
|
engine_from_network, |
|
save_engine, |
|
Profile, |
|
) |
|
|
|
from polygraphy.backend.trt import CreateConfig |
|
from tensorrt import PreviewFeature, MemoryPoolType |
|
|
|
# Engine-build dimensions: a single sequence per batch, sequences of up to
# 2048 tokens.  The profile's "opt" point is placed halfway through the
# supported sequence range so the builder tunes for mid-length inputs.
batch_size, max_length = 1, 2048
opt_length = max_length // 2
|
|
|
|
|
# One optimization profile covering the three dynamic inputs of the model.
# For each input the sequence axis ranges over (1, opt_length, max_length);
# every other axis is fixed.  Shapes per input (min, opt, max order):
#   input_ids:      (batch, seq)
#   position_ids:   (batch, 2, seq)        - 2D rotary position encoding
#   attention_mask: (batch, 1, seq, seq)   - square causal mask
_seq_points = (1, opt_length, max_length)
_input_shapes = {
    "input_ids": [(batch_size, n) for n in _seq_points],
    "position_ids": [(batch_size, 2, n) for n in _seq_points],
    "attention_mask": [(batch_size, 1, n, n) for n in _seq_points],
}

_profile = Profile()
for _input_name, (_min_shape, _opt_shape, _max_shape) in _input_shapes.items():
    _profile.add(_input_name, min=_min_shape, opt=_opt_shape, max=_max_shape)

profiles = [_profile]
|
|
|
|
|
|
|
|
|
|
|
def get_network_definition(network_definition):
    """Pin fp32 precision on POW-elementwise layers feeding a REDUCE.

    Walks every pair of consecutive layers in the TensorRT network (the
    second element of the *network_definition* tuple).  Whenever an
    ELEMENTWISE layer whose first output is float32 is immediately followed
    by a REDUCE layer, the REDUCE layer is pinned to float32, and the
    elementwise layer is pinned too when its op is POW.  This keeps
    numerically sensitive pow/reduce sequences out of fp16 when the engine
    is built with precision_constraints="obey".

    Returns the same *network_definition* object, mutated in place.
    """
    network = network_definition[1]

    def consecutive_pairs(iterable):
        # Sliding window of width two: (a, b), (b, c), ...
        first, second = tee(iterable)
        next(second, None)
        return zip(first, second)

    for index, next_index in consecutive_pairs(range(network.num_layers)):
        layer = network.get_layer(index)
        following = network.get_layer(next_index)

        # Skip layers that produce any shape tensors.
        if not all(
            layer.get_output(out).is_execution_tensor
            for out in range(layer.num_outputs)
        ):
            continue

        if layer.get_output_type(0) != trt.float32:
            continue

        if (
            layer.type == trt.LayerType.ELEMENTWISE
            and following.type == trt.LayerType.REDUCE
        ):
            # Downcast so the elementwise-specific `.op` attribute resolves.
            layer.__class__ = trt.IElementWiseLayer
            if layer.op == trt.ElementWiseOperation.POW:
                layer.precision = trt.float32
                layer.set_output_type(0, trt.float32)

            # NOTE: the reduce layer is pinned for EVERY elementwise->reduce
            # pair, not only POW — this mirrors the original behavior.
            following.precision = trt.float32
            following.set_output_type(0, trt.float32)

    return network_definition
|
|
|
|
|
# Exported ONNX model (variant with past-key-values inputs).
input_fpath = "./model6b_onnx_pkv/model.onnx"

# Opt in to TensorRT 8.5's faster dynamic-shape kernel selection.
preview_features = [PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]

# Builder configuration: fp16 kernels, a 2 GiB workspace pool, the dynamic
# shape profile above, and strict obedience of any fp32 precision pins set
# on individual layers (see get_network_definition).
trt_inference_config = CreateConfig(
    fp16=True,
    memory_pool_limits={MemoryPoolType.WORKSPACE: 2 << 30},
    profiles=profiles,
    precision_constraints="obey",
    preview_features=preview_features,
)
|
|
|
|
|
# Parse the ONNX graph into an (unbuilt) TensorRT network definition tuple.
onnx_network = network_from_onnx_path(input_fpath)


# Patch layer precisions (fp32 pins on pow/reduce pairs) before building.
network_definition = get_network_definition(onnx_network)

print(network_definition)

print(trt_inference_config)


# Build the engine — this is the long step (kernel timing/autotuning).
trt_engine = engine_from_network(network_definition, trt_inference_config)

print(trt_engine)


# Serialize the engine to disk; the target directory must already exist.
output_fpath = "./model6b_trt_pkv/out.engine"

save_engine(trt_engine, output_fpath)
|
|