import argparse
import os
import shutil

from loguru import logger

import tensorrt as trt
import torch
from torch2trt import torch2trt

from yolox.exp import get_exp

def make_parser():
    parser = argparse.ArgumentParser("YOLOX TensorRT deploy")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    return parser
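
# Example invocations (script name and paths are illustrative); the model can be
# selected either by name (-n) or by an experiment description file (-f):
#   python trt.py -n yolox-s -c /path/to/yolox_s.pth
#   python trt.py -f /path/to/your_exp.py -c /path/to/your_ckpt.pth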

def main():
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    ckpt = torch.load(ckpt_file, map_location="cpu")
    # load the model state dict
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")

    model.eval()
    model.cuda()
    # keep decoding out of the TensorRT graph; the engine outputs raw head predictions
    model.head.decode_in_inference = False
    # dummy input at the experiment's test resolution, used to trace the network
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,  # build the engine with FP16 precision
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32),  # 4 GiB builder workspace
    )
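    # Optional sanity check (not part of the original script): run both models on
    # the same dummy input and compare the outputs, e.g.
    #   y = model(x)
    #   y_trt = model_trt(x)
    #   print(torch.max(torch.abs(y - y_trt)))
    # This assumes the head returns a single tensor when decode_in_inference is False.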
    # save the torch2trt wrapper's state dict so the engine can be reloaded from Python
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")
    # also serialize the raw TensorRT engine and copy it next to the C++ demo
    engine_file = os.path.join(file_name, "model_trt.engine")
    engine_file_demo = os.path.join("deploy", "TensorRT", "cpp", "model_trt.engine")
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())

    shutil.copyfile(engine_file, engine_file_demo)

    logger.info("Converted TensorRT model engine file is saved for C++ inference.")

if __name__ == "__main__":
    main()