from loguru import logger
import tensorrt as trt
import torch
from torch2trt import torch2trt
from yolox.exp import get_exp

import argparse
import os
import shutil


def make_parser():
    """Build the CLI parser for the YOLOX → TensorRT conversion script.

    Returns:
        argparse.ArgumentParser: parser exposing -expn/-n/-f/-c options.
    """
    # Fix: description previously said "ncnn" (copy-paste from the ncnn
    # deploy script) — this tool converts to TensorRT.
    parser = argparse.ArgumentParser("YOLOX TensorRT deploy")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="pls input your experiment description file",  # typo "expriment" fixed
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    return parser


@logger.catch
def main():
    """Convert a trained YOLOX checkpoint into a TensorRT engine.

    Loads the experiment config and checkpoint, runs torch2trt in FP16
    mode, then saves both a torch2trt state dict (``model_trt.pth``) and
    a serialized engine (``model_trt.engine``), copying the latter next
    to the C++ demo for inference.

    Side effects: creates the experiment output directory, writes files
    under it, and copies the engine into ``deploy/TensorRT/cpp/``
    (relative to the current working directory). Requires a CUDA device.
    """
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    # Default to the best checkpoint produced by training in this experiment.
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # from trusted sources.
    ckpt = torch.load(ckpt_file, map_location="cpu")
    # load the model state dict
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")

    model.eval()
    model.cuda()
    # Output decoding is handled by the downstream (C++) runtime, so it is
    # excluded from the traced/converted graph.
    model.head.decode_in_inference = False

    # Dummy input at the experiment's test resolution drives the conversion.
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32),  # 4 GiB builder workspace
    )
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")

    engine_file = os.path.join(file_name, "model_trt.engine")
    engine_file_demo = os.path.join("deploy", "TensorRT", "cpp", "model_trt.engine")
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())

    # Robustness fix: ensure the demo directory exists so copyfile does not
    # raise FileNotFoundError when the script runs outside the repo root.
    os.makedirs(os.path.dirname(engine_file_demo), exist_ok=True)
    shutil.copyfile(engine_file, engine_file_demo)

    logger.info("Converted TensorRT model engine file is saved for C++ inference.")


if __name__ == "__main__":
    main()