from loguru import logger

import tensorrt as trt
import torch
from torch2trt import torch2trt

from yolox.exp import get_exp

import argparse
import os
import shutil
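
# Example invocation (model name and paths below are illustrative):
#   python tools/trt.py -n yolox-s -c yolox_s.pth
#   python tools/trt.py -f exps/default/yolox_s.py -c yolox_s.pth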
def make_parser():
    parser = argparse.ArgumentParser("YOLOX TensorRT deploy")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    return parser


@logger.catch
def main():
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt
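
    # load the checkpoint on CPU so loading works regardless of how
    # (and on which devices) the checkpoint was saved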
    ckpt = torch.load(ckpt_file, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")
    model.eval()
    model.cuda()
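    # skip box decoding inside the network: torch2trt then only has to
    # convert the plain head outputs, and callers decode the raw
    # predictions after TensorRT inference (e.g. via the head's
    # decode_outputs in the Python demo)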
    model.head.decode_in_inference = False
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
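    # build the TensorRT engine through torch2trt: fp16_mode=True compiles
    # an FP16 engine, and max_workspace_size gives the builder 4 GiB
    # (1 << 32 bytes) of scratch memory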
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32),
    )
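    # the saved state dict can be reloaded in Python with
    # torch2trt.TRTModule().load_state_dict(torch.load(...))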
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")
    engine_file = os.path.join(file_name, "model_trt.engine")
    engine_file_demo = os.path.join("deploy", "TensorRT", "cpp", "model_trt.engine")
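    # serialize the underlying TensorRT engine so it can be deserialized
    # by the TensorRT runtime without torch2trt (e.g. from C++)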
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())
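
    # note: engine_file_demo is a relative path, so this copy assumes the
    # script is run from the repository root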
    shutil.copyfile(engine_file, engine_file_demo)

    logger.info("Converted TensorRT model engine file is saved for C++ inference.")


if __name__ == "__main__":
    main()