# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

from __future__ import print_function | |
import argparse | |
import os | |
import torch | |
import yaml | |
from wenet.utils.checkpoint import load_checkpoint | |
from wenet.utils.init_model import init_model | |
def get_args():
    """Parse the command-line options of the export script.

    Returns:
        argparse.Namespace: holds ``config`` and ``checkpoint`` (both
        required) plus ``output_file`` and ``output_quant_file`` (both
        optional, ``None`` when not given).
    """
    cli = argparse.ArgumentParser(description="export your script model")

    # Mandatory inputs: the model configuration and the trained weights.
    for flag, description in (
        ("--config", "config file"),
        ("--checkpoint", "checkpoint model"),
    ):
        cli.add_argument(flag, required=True, help=description)

    # Optional export targets; each one left unset stays None.
    for flag, description in (
        ("--output_file", "output file"),
        ("--output_quant_file", "output quantized model file"),
    ):
        cli.add_argument(flag, default=None, help=description)

    return cli.parse_args()
def main():
    """Export a checkpoint as a TorchScript model, optionally quantized.

    Reads the YAML file given by ``--config``, builds the model with
    ``init_model``, loads ``--checkpoint`` into it, and then:

    * if ``--output_file`` is set, scripts the model with
      ``torch.jit.script`` and saves it to that path;
    * if ``--output_quant_file`` is set, dynamically quantizes the
      ``torch.nn.Linear`` layers to ``torch.qint8``, scripts the result and
      saves it to that path.

    When neither output option is given, nothing is written.
    """
    args = get_args()
    # No need gpu for model export
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # Read the config as UTF-8 explicitly so the result does not depend on
    # the platform's default text encoding.
    # NOTE(review): yaml.load with FullLoader is meant for trusted files
    # only. yaml.safe_load would be stricter, but may reject tags that
    # existing configs rely on -- confirm before switching.
    with open(args.config, "r", encoding="utf-8") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    model = init_model(configs)
    print(model)
    load_checkpoint(model, args.checkpoint)

    # A scripted module keeps the ``training`` flag of the module it was
    # created from. Switch to inference mode before exporting so that
    # train-only behaviour (e.g. dropout) is disabled in the saved models.
    model.eval()

    # Export jit torch script model
    if args.output_file:
        script_model = torch.jit.script(model)
        script_model.save(args.output_file)
        print("Export model successfully, see {}".format(args.output_file))

    # Export quantized jit torch script model
    if args.output_quant_file:
        quantized_model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        print(quantized_model)
        script_quant_model = torch.jit.script(quantized_model)
        script_quant_model.save(args.output_quant_file)
        print(
            "Export quantized model successfully, "
            "see {}".format(args.output_quant_file)
        )
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()