zyingt's picture
Upload 685 files
0d80816
raw
history blame contribute delete
No virus
18.9 kB
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
#
from __future__ import print_function
import argparse
import os
import copy
import sys
import torch
import yaml
import numpy as np
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
try:
import onnx
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, QuantType
except ImportError:
print("Please install onnx and onnxruntime!")
sys.exit(1)
def get_args():
parser = argparse.ArgumentParser(description="export your script model")
parser.add_argument("--config", required=True, help="config file")
parser.add_argument("--checkpoint", required=True, help="checkpoint model")
parser.add_argument("--output_dir", required=True, help="output directory")
parser.add_argument(
"--chunk_size", required=True, type=int, help="decoding chunk size"
)
parser.add_argument(
"--num_decoding_left_chunks", required=True, type=int, help="cache chunks"
)
parser.add_argument(
"--reverse_weight",
default=0.5,
type=float,
help="reverse_weight in attention_rescoing",
)
args = parser.parse_args()
return args
def to_numpy(tensor):
if tensor.requires_grad:
return tensor.detach().cpu().numpy()
else:
return tensor.cpu().numpy()
def print_input_output_info(onnx_model, name, prefix="\t\t"):
input_names = [node.name for node in onnx_model.graph.input]
input_shapes = [
[d.dim_value for d in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.input
]
output_names = [node.name for node in onnx_model.graph.output]
output_shapes = [
[d.dim_value for d in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.output
]
print("{}{} inputs : {}".format(prefix, name, input_names))
print("{}{} input shapes : {}".format(prefix, name, input_shapes))
print("{}{} outputs: {}".format(prefix, name, output_names))
print("{}{} output shapes : {}".format(prefix, name, output_shapes))
def export_encoder(asr_model, args):
print("Stage-1: export encoder")
encoder = asr_model.encoder
encoder.forward = encoder.forward_chunk
encoder_outpath = os.path.join(args["output_dir"], "encoder.onnx")
print("\tStage-1.1: prepare inputs for encoder")
chunk = torch.randn((args["batch"], args["decoding_window"], args["feature_size"]))
offset = 0
# NOTE(xcsong): The uncertainty of `next_cache_start` only appears
# in the first few chunks, this is caused by dynamic att_cache shape, i,e
# (0, 0, 0, 0) for 1st chunk and (elayers, head, ?, d_k*2) for subsequent
# chunks. One way to ease the ONNX export is to keep `next_cache_start`
# as a fixed value. To do this, for the **first** chunk, if
# left_chunks > 0, we feed real cache & real mask to the model, otherwise
# fake cache & fake mask. In this way, we get:
# 1. 16/-1 mode: next_cache_start == 0 for all chunks
# 2. 16/4 mode: next_cache_start == chunk_size for all chunks
# 3. 16/0 mode: next_cache_start == chunk_size for all chunks
# 4. -1/-1 mode: next_cache_start == 0 for all chunks
# NO MORE DYNAMIC CHANGES!!
#
# NOTE(Mddct): We retain the current design for the convenience of supporting some
# inference frameworks without dynamic shapes. If you're interested in all-in-one
# model that supports different chunks please see:
# https://github.com/wenet-e2e/wenet/pull/1174
if args["left_chunks"] > 0: # 16/4
required_cache_size = args["chunk_size"] * args["left_chunks"]
offset = required_cache_size
# Real cache
att_cache = torch.zeros(
(
args["num_blocks"],
args["head"],
required_cache_size,
args["output_size"] // args["head"] * 2,
)
)
# Real mask
att_mask = torch.ones(
(args["batch"], 1, required_cache_size + args["chunk_size"]),
dtype=torch.bool,
)
att_mask[:, :, :required_cache_size] = 0
elif args["left_chunks"] <= 0: # 16/-1, -1/-1, 16/0
required_cache_size = -1 if args["left_chunks"] < 0 else 0
# Fake cache
att_cache = torch.zeros(
(
args["num_blocks"],
args["head"],
0,
args["output_size"] // args["head"] * 2,
)
)
# Fake mask
att_mask = torch.ones((0, 0, 0), dtype=torch.bool)
cnn_cache = torch.zeros(
(
args["num_blocks"],
args["batch"],
args["output_size"],
args["cnn_module_kernel"] - 1,
)
)
inputs = (chunk, offset, required_cache_size, att_cache, cnn_cache, att_mask)
print(
"\t\tchunk.size(): {}\n".format(chunk.size()),
"\t\toffset: {}\n".format(offset),
"\t\trequired_cache: {}\n".format(required_cache_size),
"\t\tatt_cache.size(): {}\n".format(att_cache.size()),
"\t\tcnn_cache.size(): {}\n".format(cnn_cache.size()),
"\t\tatt_mask.size(): {}\n".format(att_mask.size()),
)
print("\tStage-1.2: torch.onnx.export")
dynamic_axes = {
"chunk": {1: "T"},
"att_cache": {2: "T_CACHE"},
"att_mask": {2: "T_ADD_T_CACHE"},
"output": {1: "T"},
"r_att_cache": {2: "T_CACHE"},
}
# NOTE(xcsong): We keep dynamic axes even if in 16/4 mode, this is
# to avoid padding the last chunk (which usually contains less
# frames than required). For users who want static axes, just pop
# out specific axis.
# if args['chunk_size'] > 0: # 16/4, 16/-1, 16/0
# dynamic_axes.pop('chunk')
# dynamic_axes.pop('output')
# if args['left_chunks'] >= 0: # 16/4, 16/0
# # NOTE(xsong): since we feed real cache & real mask into the
# # model when left_chunks > 0, the shape of cache will never
# # be changed.
# dynamic_axes.pop('att_cache')
# dynamic_axes.pop('r_att_cache')
torch.onnx.export(
encoder,
inputs,
encoder_outpath,
opset_version=13,
export_params=True,
do_constant_folding=True,
input_names=[
"chunk",
"offset",
"required_cache_size",
"att_cache",
"cnn_cache",
"att_mask",
],
output_names=["output", "r_att_cache", "r_cnn_cache"],
dynamic_axes=dynamic_axes,
verbose=False,
)
onnx_encoder = onnx.load(encoder_outpath)
for k, v in args.items():
meta = onnx_encoder.metadata_props.add()
meta.key, meta.value = str(k), str(v)
onnx.checker.check_model(onnx_encoder)
onnx.helper.printable_graph(onnx_encoder.graph)
# NOTE(xcsong): to add those metadatas we need to reopen
# the file and resave it.
onnx.save(onnx_encoder, encoder_outpath)
print_input_output_info(onnx_encoder, "onnx_encoder")
# Dynamic quantization
model_fp32 = encoder_outpath
model_quant = os.path.join(args["output_dir"], "encoder.quant.onnx")
quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
print("\t\tExport onnx_encoder, done! see {}".format(encoder_outpath))
print("\tStage-1.3: check onnx_encoder and torch_encoder")
torch_output = []
torch_chunk = copy.deepcopy(chunk)
torch_offset = copy.deepcopy(offset)
torch_required_cache_size = copy.deepcopy(required_cache_size)
torch_att_cache = copy.deepcopy(att_cache)
torch_cnn_cache = copy.deepcopy(cnn_cache)
torch_att_mask = copy.deepcopy(att_mask)
for i in range(10):
print(
"\t\ttorch chunk-{}: {}, offset: {}, att_cache: {},"
" cnn_cache: {}, att_mask: {}".format(
i,
list(torch_chunk.size()),
torch_offset,
list(torch_att_cache.size()),
list(torch_cnn_cache.size()),
list(torch_att_mask.size()),
)
)
# NOTE(xsong): att_mask of the first few batches need changes if
# we use 16/4 mode.
if args["left_chunks"] > 0: # 16/4
torch_att_mask[:, :, -(args["chunk_size"] * (i + 1)) :] = 1
out, torch_att_cache, torch_cnn_cache = encoder(
torch_chunk,
torch_offset,
torch_required_cache_size,
torch_att_cache,
torch_cnn_cache,
torch_att_mask,
)
torch_output.append(out)
torch_offset += out.size(1)
torch_output = torch.cat(torch_output, dim=1)
onnx_output = []
onnx_chunk = to_numpy(chunk)
onnx_offset = np.array((offset)).astype(np.int64)
onnx_required_cache_size = np.array((required_cache_size)).astype(np.int64)
onnx_att_cache = to_numpy(att_cache)
onnx_cnn_cache = to_numpy(cnn_cache)
onnx_att_mask = to_numpy(att_mask)
ort_session = onnxruntime.InferenceSession(encoder_outpath)
input_names = [node.name for node in onnx_encoder.graph.input]
for i in range(10):
print(
"\t\tonnx chunk-{}: {}, offset: {}, att_cache: {},"
" cnn_cache: {}, att_mask: {}".format(
i,
onnx_chunk.shape,
onnx_offset,
onnx_att_cache.shape,
onnx_cnn_cache.shape,
onnx_att_mask.shape,
)
)
# NOTE(xsong): att_mask of the first few batches need changes if
# we use 16/4 mode.
if args["left_chunks"] > 0: # 16/4
onnx_att_mask[:, :, -(args["chunk_size"] * (i + 1)) :] = 1
ort_inputs = {
"chunk": onnx_chunk,
"offset": onnx_offset,
"required_cache_size": onnx_required_cache_size,
"att_cache": onnx_att_cache,
"cnn_cache": onnx_cnn_cache,
"att_mask": onnx_att_mask,
}
# NOTE(xcsong): If we use 16/-1, -1/-1 or 16/0 mode, `next_cache_start`
# will be hardcoded to 0 or chunk_size by ONNX, thus
# required_cache_size and att_mask are no more needed and they will
# be removed by ONNX automatically.
for k in list(ort_inputs):
if k not in input_names:
ort_inputs.pop(k)
ort_outs = ort_session.run(None, ort_inputs)
onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
onnx_output.append(ort_outs[0])
onnx_offset += ort_outs[0].shape[1]
onnx_output = np.concatenate(onnx_output, axis=1)
np.testing.assert_allclose(
to_numpy(torch_output), onnx_output, rtol=1e-03, atol=1e-05
)
meta = ort_session.get_modelmeta()
print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map))
print("\t\tCheck onnx_encoder, pass!")
def export_ctc(asr_model, args):
print("Stage-2: export ctc")
ctc = asr_model.ctc
ctc.forward = ctc.log_softmax
ctc_outpath = os.path.join(args["output_dir"], "ctc.onnx")
print("\tStage-2.1: prepare inputs for ctc")
hidden = torch.randn(
(
args["batch"],
args["chunk_size"] if args["chunk_size"] > 0 else 16,
args["output_size"],
)
)
print("\tStage-2.2: torch.onnx.export")
dynamic_axes = {"hidden": {1: "T"}, "probs": {1: "T"}}
torch.onnx.export(
ctc,
hidden,
ctc_outpath,
opset_version=13,
export_params=True,
do_constant_folding=True,
input_names=["hidden"],
output_names=["probs"],
dynamic_axes=dynamic_axes,
verbose=False,
)
onnx_ctc = onnx.load(ctc_outpath)
for k, v in args.items():
meta = onnx_ctc.metadata_props.add()
meta.key, meta.value = str(k), str(v)
onnx.checker.check_model(onnx_ctc)
onnx.helper.printable_graph(onnx_ctc.graph)
onnx.save(onnx_ctc, ctc_outpath)
print_input_output_info(onnx_ctc, "onnx_ctc")
# Dynamic quantization
model_fp32 = ctc_outpath
model_quant = os.path.join(args["output_dir"], "ctc.quant.onnx")
quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
print("\t\tExport onnx_ctc, done! see {}".format(ctc_outpath))
print("\tStage-2.3: check onnx_ctc and torch_ctc")
torch_output = ctc(hidden)
ort_session = onnxruntime.InferenceSession(ctc_outpath)
onnx_output = ort_session.run(None, {"hidden": to_numpy(hidden)})
np.testing.assert_allclose(
to_numpy(torch_output), onnx_output[0], rtol=1e-03, atol=1e-05
)
print("\t\tCheck onnx_ctc, pass!")
def export_decoder(asr_model, args):
print("Stage-3: export decoder")
decoder = asr_model
# NOTE(lzhin): parameters of encoder will be automatically removed
# since they are not used during rescoring.
decoder.forward = decoder.forward_attention_decoder
decoder_outpath = os.path.join(args["output_dir"], "decoder.onnx")
print("\tStage-3.1: prepare inputs for decoder")
# hardcode time->200 nbest->10 len->20, they are dynamic axes.
encoder_out = torch.randn((1, 200, args["output_size"]))
hyps = torch.randint(low=0, high=args["vocab_size"], size=[10, 20])
hyps[:, 0] = args["vocab_size"] - 1 # <sos>
hyps_lens = torch.randint(low=15, high=21, size=[10])
print("\tStage-3.2: torch.onnx.export")
dynamic_axes = {
"hyps": {0: "NBEST", 1: "L"},
"hyps_lens": {0: "NBEST"},
"encoder_out": {1: "T"},
"score": {0: "NBEST", 1: "L"},
"r_score": {0: "NBEST", 1: "L"},
}
inputs = (hyps, hyps_lens, encoder_out, args["reverse_weight"])
torch.onnx.export(
decoder,
inputs,
decoder_outpath,
opset_version=13,
export_params=True,
do_constant_folding=True,
input_names=["hyps", "hyps_lens", "encoder_out", "reverse_weight"],
output_names=["score", "r_score"],
dynamic_axes=dynamic_axes,
verbose=False,
)
onnx_decoder = onnx.load(decoder_outpath)
for k, v in args.items():
meta = onnx_decoder.metadata_props.add()
meta.key, meta.value = str(k), str(v)
onnx.checker.check_model(onnx_decoder)
onnx.helper.printable_graph(onnx_decoder.graph)
onnx.save(onnx_decoder, decoder_outpath)
print_input_output_info(onnx_decoder, "onnx_decoder")
model_fp32 = decoder_outpath
model_quant = os.path.join(args["output_dir"], "decoder.quant.onnx")
quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
print("\t\tExport onnx_decoder, done! see {}".format(decoder_outpath))
print("\tStage-3.3: check onnx_decoder and torch_decoder")
torch_score, torch_r_score = decoder(
hyps, hyps_lens, encoder_out, args["reverse_weight"]
)
ort_session = onnxruntime.InferenceSession(decoder_outpath)
input_names = [node.name for node in onnx_decoder.graph.input]
ort_inputs = {
"hyps": to_numpy(hyps),
"hyps_lens": to_numpy(hyps_lens),
"encoder_out": to_numpy(encoder_out),
"reverse_weight": np.array((args["reverse_weight"])),
}
for k in list(ort_inputs):
if k not in input_names:
ort_inputs.pop(k)
onnx_output = ort_session.run(None, ort_inputs)
np.testing.assert_allclose(
to_numpy(torch_score), onnx_output[0], rtol=1e-03, atol=1e-05
)
if args["is_bidirectional_decoder"] and args["reverse_weight"] > 0.0:
np.testing.assert_allclose(
to_numpy(torch_r_score), onnx_output[1], rtol=1e-03, atol=1e-05
)
print("\t\tCheck onnx_decoder, pass!")
def main():
torch.manual_seed(777)
args = get_args()
output_dir = args.output_dir
os.system("mkdir -p " + output_dir)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
with open(args.config, "r") as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
model = init_model(configs)
load_checkpoint(model, args.checkpoint)
model.eval()
print(model)
arguments = {}
arguments["output_dir"] = output_dir
arguments["batch"] = 1
arguments["chunk_size"] = args.chunk_size
arguments["left_chunks"] = args.num_decoding_left_chunks
arguments["reverse_weight"] = args.reverse_weight
arguments["output_size"] = configs["encoder_conf"]["output_size"]
arguments["num_blocks"] = configs["encoder_conf"]["num_blocks"]
arguments["cnn_module_kernel"] = configs["encoder_conf"].get("cnn_module_kernel", 1)
arguments["head"] = configs["encoder_conf"]["attention_heads"]
arguments["feature_size"] = configs["input_dim"]
arguments["vocab_size"] = configs["output_dim"]
# NOTE(xcsong): if chunk_size == -1, hardcode to 67
arguments["decoding_window"] = (
(args.chunk_size - 1) * model.encoder.embed.subsampling_rate
+ model.encoder.embed.right_context
+ 1
if args.chunk_size > 0
else 67
)
arguments["encoder"] = configs["encoder"]
arguments["decoder"] = configs["decoder"]
arguments["subsampling_rate"] = model.subsampling_rate()
arguments["right_context"] = model.right_context()
arguments["sos_symbol"] = model.sos_symbol()
arguments["eos_symbol"] = model.eos_symbol()
arguments["is_bidirectional_decoder"] = 1 if model.is_bidirectional_decoder() else 0
# NOTE(xcsong): Please note that -1/-1 means non-streaming model! It is
# not a [16/4 16/-1 16/0] all-in-one model and it should not be used in
# streaming mode (i.e., setting chunk_size=16 in `decoder_main`). If you
# want to use 16/-1 or any other streaming mode in `decoder_main`,
# please export onnx in the same config.
if arguments["left_chunks"] > 0:
assert arguments["chunk_size"] > 0 # -1/4 not supported
export_encoder(model, arguments)
export_ctc(model, arguments)
export_decoder(model, arguments)
if __name__ == "__main__":
main()