Spaces:
Running
Running
| import argparse | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple, Union | |
| try: | |
| import tensorrt as trt | |
| except Exception: | |
| trt = None | |
| import warnings | |
| import numpy as np | |
| import torch | |
| warnings.filterwarnings(action='ignore', category=DeprecationWarning) | |
class EngineBuilder:
    """Build a serialized TensorRT engine from an ONNX checkpoint.

    The engine is written next to the checkpoint, under the same file name
    with an ``.engine`` suffix. That path is stored in ``self.weight`` once a
    build has been started.
    """

    def __init__(
            self,
            checkpoint: Union[str, Path],
            opt_shape: Union[Tuple, List] = (1, 3, 640, 640),
            device: Optional[Union[str, int, torch.device]] = None) -> None:
        """Validate the checkpoint path and normalize the device.

        Args:
            checkpoint: Path to an existing ``.onnx`` file.
            opt_shape: Input shape that the default optimization-profile
                shapes are derived from when no explicit scale is given.
            device: A ``torch.device``, a device string, or an integer that
                is interpreted as a CUDA index. ``None`` is stored unchanged.

        Raises:
            AssertionError: If the checkpoint does not exist or does not have
                the ``.onnx`` suffix.
        """
        checkpoint = Path(checkpoint)
        assert checkpoint.exists(), f'checkpoint not found: {checkpoint}'
        assert checkpoint.suffix == '.onnx', (
            f'checkpoint must be an .onnx file, got: {checkpoint}')
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.checkpoint = checkpoint
        # Kept as float so it can be multiplied by fractional scale factors.
        self.opt_shape = np.array(opt_shape, dtype=np.float32)
        self.device = device

    def _profile_shapes(
            self,
            scale: Optional[Union[List, Tuple, np.ndarray]] = None
    ) -> np.ndarray:
        """Return the three shapes passed to the optimization profile.

        Args:
            scale: ``None`` to derive the shapes from ``self.opt_shape``, or
                a list/tuple/array holding three explicit shapes.

        Returns:
            An ``int32`` array whose first dimension has length 3.

        Raises:
            AssertionError: If an explicit ``scale`` does not hold 3 shapes.
            NotImplementedError: If ``scale`` is of an unsupported type.
        """
        if scale is None:
            factors = np.array(
                [[1, 1, 0.5, 0.5], [1, 1, 1, 1], [4, 1, 1.5, 1.5]],
                dtype=np.float32)
            return (self.opt_shape * factors).astype(np.int32)
        if isinstance(scale, (list, tuple, np.ndarray)):
            shapes = np.array(scale, dtype=np.int32)
            assert shapes.shape[0] == 3, 'Input a wrong scale list'
            return shapes
        raise NotImplementedError

    def __build_engine(self,
                       scale: Optional[List[List]] = None,
                       fp16: bool = True,
                       with_profiling: bool = True) -> None:
        """Parse the ONNX file, configure the builder and write the engine.

        Args:
            scale: Explicit profile shapes, or ``None`` for the defaults.
                Only used when the first network input has a ``-1`` dim.
            fp16: Enable the FP16 builder flag when the platform reports
                fast FP16 support.
            with_profiling: Set the profiling verbosity to ``DETAILED``.

        Raises:
            ImportError: If the ``tensorrt`` module could not be imported.
            RuntimeError: If the ONNX file cannot be parsed or the engine
                build fails.
        """
        if trt is None:
            # The module-level import swallowed the failure; report it here
            # instead of failing with an AttributeError on ``None``.
            raise ImportError(
                'tensorrt could not be imported; it is required to build '
                'an engine')

        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        builder = trt.Builder(logger)
        config = builder.create_builder_config()

        # Let the builder use up to the whole memory of the target device.
        workspace = torch.cuda.get_device_properties(
            self.device).total_memory
        if hasattr(config, 'set_memory_pool_limit'):
            # Newer TensorRT API; ``max_workspace_size`` is deprecated there.
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
                                         workspace)
        else:
            config.max_workspace_size = workspace

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(self.checkpoint)):
            raise RuntimeError(
                f'failed to load ONNX file: {str(self.checkpoint)}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]

        # A -1 in the first input's shape marks the network as dynamic.
        profile = None
        dshape = -1 in network.get_input(0).shape
        if dshape:
            profile = builder.create_optimization_profile()
            scale = self._profile_shapes(scale)

        for inp in inputs:
            logger.log(
                trt.Logger.WARNING,
                f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
            if dshape:
                # The same three shapes are applied to every input.
                profile.set_shape(inp.name, *scale)
        for out in outputs:
            logger.log(
                trt.Logger.WARNING,
                f'output "{out.name}" with shape{out.shape} {out.dtype}')

        if fp16 and builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        self.weight = self.checkpoint.with_suffix('.engine')
        if dshape:
            config.add_optimization_profile(profile)
        if with_profiling:
            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

        if hasattr(builder, 'build_serialized_network'):
            # Preferred path: ``build_engine`` is deprecated/absent in newer
            # TensorRT releases.
            serialized = builder.build_serialized_network(network, config)
            if serialized is None:
                raise RuntimeError('failed to build the TensorRT engine')
            self.weight.write_bytes(serialized)
        else:
            engine = builder.build_engine(network, config)
            if engine is None:
                raise RuntimeError('failed to build the TensorRT engine')
            with engine:
                self.weight.write_bytes(engine.serialize())
        logger.log(
            trt.Logger.WARNING, f'Build tensorrt engine finish.\n'
            f'Save in {str(self.weight.absolute())}')

    def build(self,
              scale: Optional[List[List]] = None,
              fp16: bool = True,
              with_profiling: bool = True) -> None:
        """Build the engine and write it next to the checkpoint.

        Args:
            scale: Explicit profile shapes, or ``None`` for the defaults.
            fp16: Request FP16 mode when the platform supports it.
            with_profiling: Build with detailed profiling verbosity.
        """
        self.__build_engine(scale, fp16, with_profiling)
def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options of the engine builder script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing callers are
            unaffected.

    Returns:
        The parsed ``argparse.Namespace`` with ``checkpoint``, ``img_size``,
        ``device``, ``scales`` (still an unparsed string) and ``fp16``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='TensorRT builder device')
    parser.add_argument(
        '--scales',
        type=str,
        default='[[1,3,640,640],[1,3,640,640],[1,3,640,640]]',
        help='Input scales for build dynamic input shape engine')
    parser.add_argument(
        '--fp16', action='store_true', help='Build model with fp16 mode')
    args = parser.parse_args(argv)
    # A single value is used for both height and width.
    if len(args.img_size) == 1:
        args.img_size = args.img_size * 2
    return args
def main(args):
    """Build a TensorRT engine from the parsed command-line arguments.

    Args:
        args: Namespace providing ``checkpoint``, ``img_size``, ``scales``,
            ``device`` and ``fp16``.
    """
    # Prepend batch size 1 and 3 channels to the given image size.
    img_size = (1, 3) + tuple(args.img_size)
    # NOTE(review): eval() executes arbitrary Python taken from --scales.
    # Acceptable only while the command line is trusted; consider
    # ast.literal_eval if that ever changes.
    try:
        profile_scales = eval(args.scales)
    except Exception:
        # Best effort: fall back to the builder's default shapes.
        print('Input scales is not a python variable')
        print('Set scales default None')
        profile_scales = None
    engine_builder = EngineBuilder(args.checkpoint, img_size, args.device)
    engine_builder.build(profile_scales, fp16=args.fp16)
# Script entry point: parse the command line, then build the engine.
if __name__ == '__main__':
    args = parse_args()
    main(args)