|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
import megengine as mge |
|
import numpy as np |
|
from megengine import jit |
|
|
|
from build import build_and_load |
|
|
|
|
|
def make_parser():
    """Build the command-line parser for the dump script.

    Options:
        -n/--name:   model name passed to ``build_and_load`` (default "yolox-s").
        -c/--ckpt:   checkpoint path to load (default None).
        --dump_path: where the serialized graph is written (default "model.mge").

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(prog="YOLOX Demo Dump")
    cli.add_argument(
        "-n", "--name",
        default="yolox-s",
        type=str,
        help="model name",
    )
    cli.add_argument(
        "-c", "--ckpt",
        type=str,
        default=None,
        help="ckpt for eval",
    )
    cli.add_argument(
        "--dump_path",
        help="path to save the dumped model",
        default="model.mge",
    )
    return cli
|
|
|
|
|
def dump_static_graph(model, graph_name="model.mge", input_shape=(1, 3, 640, 640)):
    """Trace ``model`` with a dummy input and serialize the static graph.

    Side effects: the caller's ``model`` is mutated in place. It is switched
    to eval mode and ``model.head.decode_in_inference`` is set to False
    before tracing.

    Args:
        model: network to dump. It must provide ``eval()`` and a ``head``
            attribute carrying a ``decode_in_inference`` flag.
        graph_name: output path passed to ``trace.dump``.
        input_shape: shape of the random dummy input used for tracing. The
            default is the previously hard-coded ``(1, 3, 640, 640)``.
            NOTE(review): presumably NCHW; confirm against the model.
    """
    model.eval()
    model.head.decode_in_inference = False

    # Dummy input for the trace. np.random.random yields float64, so cast
    # explicitly to float32 instead of relying on an implicit conversion
    # inside mge.Tensor.
    data = mge.Tensor(np.random.random(input_shape).astype(np.float32))

    # capture_as_const=True is required so the traced function can be
    # dumped afterwards.
    @jit.trace(capture_as_const=True)
    def pred_func(data):
        outputs = model(data)
        return outputs

    # Run once to record the trace, then write the captured graph to disk.
    pred_func(data)
    pred_func.dump(
        graph_name,
        arg_names=["data"],
        optimize_for_inference=True,
        enable_fuse_conv_bias_nonlinearity=True,
    )
|
|
|
|
|
def main(args):
    """Build and load the requested model, then dump it as a static graph.

    Args:
        args: parsed CLI namespace exposing ``ckpt``, ``name`` and
            ``dump_path`` (as produced by ``make_parser``).
    """
    ckpt_path, model_name, out_path = args.ckpt, args.name, args.dump_path
    network = build_and_load(ckpt_path, name=model_name)
    dump_static_graph(network, out_path)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the dump.
    parser = make_parser()
    args = parser.parse_args()
    main(args)
|
|