# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```
from __future__ import print_function

import argparse
import os

import torch
import yaml

from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model


def get_args():
    parser = argparse.ArgumentParser(description="export your script model")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--checkpoint", required=True, help="checkpoint model")
    parser.add_argument("--output_file", default=None, help="output file")
    parser.add_argument(
        "--output_quant_file", default=None, help="output quantized model file"
    )
    args = parser.parse_args()
    return args
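

# A minimal invocation sketch. The file names below are assumptions;
# substitute the config and checkpoint produced by your own WeNet
# training run:
#
#   python export_jit.py --config train.yaml --checkpoint final.pt \
#       --output_file final.zip --output_quant_file final_quant.zip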


def main():
    args = get_args()
    # No GPU is needed for model export.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    model = init_model(configs)
    print(model)
    load_checkpoint(model, args.checkpoint)

    # Export the model as a TorchScript (JIT) module that can be loaded
    # and run without the Python class definitions.
    if args.output_file:
        script_model = torch.jit.script(model)
        script_model.save(args.output_file)
        print("Export model successfully, see {}".format(args.output_file))

    # Export a dynamically quantized TorchScript module: the weights of
    # all torch.nn.Linear layers are stored as int8, which shrinks the
    # model and typically speeds up CPU inference.
    if args.output_quant_file:
        quantized_model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        print(quantized_model)
        script_quant_model = torch.jit.script(quantized_model)
        script_quant_model.save(args.output_quant_file)
        print(
            "Export quantized model successfully, "
            "see {}".format(args.output_quant_file)
        )


if __name__ == "__main__":
    main()
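

# A minimal loading sketch. The exported TorchScript archive can be
# restored in a plain PyTorch runtime with no WeNet source present;
# the path below is an assumption (use whatever was passed to
# --output_file):
#
#   import torch
#   model = torch.jit.load("final.zip")  # hypothetical output path
#   model.eval()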