#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import ctypes
import logging
import os
import pdb
import platform
import sys

import numpy as np
import tensorrt as trt
# Configure logging once at import time; the "EngineBuilder" logger is shared
# by the helpers and the EngineBuilder class below.
logging.basicConfig(level=logging.INFO)
logging.getLogger("EngineBuilder").setLevel(logging.INFO)
log = logging.getLogger("EngineBuilder")
def load_plugins(logger: trt.Logger):
    """Load the custom grid_sample_3d plugin and register TensorRT plugins.

    :param logger: TensorRT logger forwarded to plugin initialization.
    """
    # Load the platform-specific plugin shared library. RTLD_GLOBAL exposes
    # its symbols so TensorRT's plugin registry can find the custom op.
    if platform.system().lower() == 'linux':
        ctypes.CDLL("./checkpoints/liveportrait_onnx/libgrid_sample_3d_plugin.so", mode=ctypes.RTLD_GLOBAL)
    else:
        ctypes.CDLL("./checkpoints/liveportrait_onnx/grid_sample_3d_plugin.dll", mode=ctypes.RTLD_GLOBAL, winmode=0)
    # Initialize TensorRT's built-in plugin library (empty plugin namespace).
    trt.init_libnvinfer_plugins(logger, "")
class EngineBuilder:
    """
    Parses an ONNX graph and builds a TensorRT engine from it.
    """

    def __init__(self, verbose=False):
        """
        :param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger.
        """
        self.trt_logger = trt.Logger(trt.Logger.INFO)
        if verbose:
            self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE

        trt.init_libnvinfer_plugins(self.trt_logger, namespace="")

        self.builder = trt.Builder(self.trt_logger)
        self.config = self.builder.create_builder_config()
        # NOTE(review): max_workspace_size is deprecated since TensorRT 8.4 in
        # favor of config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, ...);
        # kept as-is for compatibility with the TRT version this targets.
        self.config.max_workspace_size = 12 * (2 ** 30)  # 12 GB

        # A profile is required for networks with dynamic shapes; the
        # per-model shape hints below are kept for reference.
        profile = self.builder.create_optimization_profile()
        # for face_2dpose_106.onnx
        # profile.set_shape("data", (1, 3, 192, 192), (1, 3, 192, 192), (1, 3, 192, 192))
        # for retinaface_det.onnx
        # profile.set_shape("input.1", (1, 3, 512, 512), (1, 3, 512, 512), (1, 3, 512, 512))
        self.config.add_optimization_profile(profile)
        # Enforce strict type constraints (deprecated in newer TRT, but kept).
        self.config.set_flag(trt.BuilderFlag.STRICT_TYPES)

        self.batch_size = None
        self.network = None
        self.parser = None

        # Load the custom grid_sample_3d plugin library.
        load_plugins(self.trt_logger)

    def create_network(self, onnx_path):
        """
        Parse the ONNX graph and create the corresponding TensorRT network definition.

        :param onnx_path: The path to the ONNX graph to load.
        """
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

        self.network = self.builder.create_network(network_flags)
        self.parser = trt.OnnxParser(self.network, self.trt_logger)

        onnx_path = os.path.realpath(onnx_path)
        with open(onnx_path, "rb") as f:
            if not self.parser.parse(f.read()):
                log.error("Failed to load ONNX file: {}".format(onnx_path))
                for error in range(self.parser.num_errors):
                    log.error(self.parser.get_error(error))
                sys.exit(1)

        inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
        outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]

        log.info("Network Description")
        for tensor in inputs:
            # Record the (last seen) input's leading dimension as batch size.
            self.batch_size = tensor.shape[0]
            log.info("Input '{}' with shape {} and dtype {}".format(tensor.name, tensor.shape, tensor.dtype))
        for tensor in outputs:
            log.info("Output '{}' with shape {} and dtype {}".format(tensor.name, tensor.shape, tensor.dtype))
        # assert self.batch_size > 0

        # Ignored by explicit-batch networks, but older TRT APIs expect it.
        self.builder.max_batch_size = 1

    def create_engine(
        self,
        engine_path,
        precision
    ):
        """
        Build the TensorRT engine and serialize it to disk.

        :param engine_path: The path where to serialize the engine to.
        :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
        """
        engine_path = os.path.realpath(engine_path)
        engine_dir = os.path.dirname(engine_path)
        os.makedirs(engine_dir, exist_ok=True)
        log.info("Building {} Engine in {}".format(precision, engine_path))

        if precision == "fp16":
            if not self.builder.platform_has_fast_fp16:
                log.warning("FP16 is not supported natively on this platform/device")
            else:
                self.config.set_flag(trt.BuilderFlag.FP16)

        # build_engine returns a context-managed engine; serialize it to disk.
        with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f:
            log.info("Serializing engine to file: {:}".format(engine_path))
            f.write(engine.serialize())
def main(args):
    """Build and serialize a TensorRT engine from the parsed CLI arguments.

    :param args: Namespace with ``onnx``, ``engine``, ``precision`` and
        ``verbose`` attributes (see the argparse setup below).
    """
    builder = EngineBuilder(args.verbose)
    builder.create_network(args.onnx)
    builder.create_engine(
        args.engine,
        args.precision
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--onnx", required=True, help="The input ONNX model file to load")
    parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
    parser.add_argument(
        "-p",
        "--precision",
        default="fp16",
        choices=["fp32", "fp16", "int8"],
        help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
    args = parser.parse_args()
    if args.engine is None:
        # Default the engine path to the ONNX path with a .trt extension.
        args.engine = args.onnx.replace(".onnx", ".trt")
    main(args)