# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import torch
import yaml
from torch.utils.data import DataLoader
from wenet.dataset.dataset import Dataset
from wenet.utils.common import IGNORE_ID
from wenet.utils.file_utils import read_symbol_table
from wenet.utils.config import override_config
import onnxruntime as rt
import multiprocessing
import numpy as np
try:
    from swig_decoders import (
        map_batch,
        ctc_beam_search_decoder_batch,
        TrieVector,
        PathTrie,
    )
except ImportError:
    print(
        "Please install ctc decoders first by referring to\n"
        + "https://github.com/Slyne/ctc_decoder.git"
    )
    sys.exit(1)

def get_args():
    parser = argparse.ArgumentParser(description="recognize with your model")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--test_data", required=True, help="test data file")
    parser.add_argument(
        "--data_type",
        default="raw",
        choices=["raw", "shard"],
        help="train and cv data type",
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="gpu id for this rank, -1 for cpu"
    )
    parser.add_argument("--dict", required=True, help="dict file")
    parser.add_argument("--encoder_onnx", required=True, help="encoder onnx file")
    parser.add_argument("--decoder_onnx", required=True, help="decoder onnx file")
    parser.add_argument("--result_file", required=True, help="asr result file")
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size for decoding"
    )
    parser.add_argument(
        "--mode",
        choices=["ctc_greedy_search", "ctc_prefix_beam_search", "attention_rescoring"],
        default="attention_rescoring",
        help="decoding mode",
    )
    parser.add_argument(
        "--bpe_model", default=None, type=str, help="bpe model for english part"
    )
    parser.add_argument(
        "--override_config", action="append", default=[], help="override yaml config"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="whether the exported onnx model is fp16, default false",
    )
    args = parser.parse_args()
    print(args)
    return args

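# Example usage (all file paths below are illustrative placeholders):
#   python recognize_onnx_gpu.py --config train.yaml --test_data data.list \
#       --dict units.txt --encoder_onnx encoder.onnx --decoder_onnx decoder.onnx \
#       --result_file text --gpu 0 --mode attention_rescoring
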
def main():
    args = get_args()
    logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
    symbol_table = read_symbol_table(args.dict)
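    # Decoding-time dataset config: relax all filter thresholds and turn off
    # augmentation/shuffling so every test utterance passes through unchanged.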
    test_conf = copy.deepcopy(configs["dataset_conf"])
    test_conf["filter_conf"]["max_length"] = 102400
    test_conf["filter_conf"]["min_length"] = 0
    test_conf["filter_conf"]["token_max_length"] = 102400
    test_conf["filter_conf"]["token_min_length"] = 0
    test_conf["filter_conf"]["max_output_input_ratio"] = 102400
    test_conf["filter_conf"]["min_output_input_ratio"] = 0
    test_conf["speed_perturb"] = False
    test_conf["spec_aug"] = False
    test_conf["spec_sub"] = False
    test_conf["spec_trim"] = False
    test_conf["shuffle"] = False
    test_conf["sort"] = False
    test_conf["fbank_conf"]["dither"] = 0.0
    test_conf["batch_conf"]["batch_type"] = "static"
    test_conf["batch_conf"]["batch_size"] = args.batch_size

    test_dataset = Dataset(
        args.data_type,
        args.test_data,
        symbol_table,
        test_conf,
        args.bpe_model,
        partition=False,
    )
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    # Init asr model from configs
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        EP_list = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        EP_list = ["CPUExecutionProvider"]
    encoder_ort_session = rt.InferenceSession(args.encoder_onnx, providers=EP_list)
    decoder_ort_session = None
    if args.mode == "attention_rescoring":
        decoder_ort_session = rt.InferenceSession(args.decoder_onnx, providers=EP_list)

    # Load dict
    vocabulary = []
    char_dict = {}
    with open(args.dict, "r") as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])
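    # By WeNet convention the last dictionary entry is the shared <sos/eos>
    # symbol, so both ids point at the final vocabulary index.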
    eos = sos = len(char_dict) - 1

    with torch.no_grad(), open(args.result_file, "w") as fout:
        for _, batch in enumerate(test_data_loader):
            keys, feats, _, feats_lengths, _ = batch
            feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
            if args.fp16:
                feats = feats.astype(np.float16)
            ort_inputs = {
                encoder_ort_session.get_inputs()[0].name: feats,
                encoder_ort_session.get_inputs()[1].name: feats_lengths,
            }
            ort_outs = encoder_ort_session.run(None, ort_inputs)
            (
                encoder_out,
                encoder_out_lens,
                ctc_log_probs,
                beam_log_probs,
                beam_log_probs_idx,
            ) = ort_outs
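            # Assumed output layout of the GPU-exported encoder graph:
            #   encoder_out        (batch, frames, dim)   acoustic embeddings
            #   encoder_out_lens   (batch,)               valid frame counts
            #   ctc_log_probs      (batch, frames, vocab) full CTC posteriors
            #   beam_log_probs     (batch, frames, beam)  top-k CTC log-probs
            #   beam_log_probs_idx (batch, frames, beam)  matching token ids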
            beam_size = beam_log_probs.shape[-1]
            batch_size = beam_log_probs.shape[0]
            num_processes = min(multiprocessing.cpu_count(), batch_size)
if args.mode == "ctc_greedy_search":
if beam_size != 1:
log_probs_idx = beam_log_probs_idx[:, :, 0]
batch_sents = []
for idx, seq in enumerate(log_probs_idx):
batch_sents.append(seq[0 : encoder_out_lens[idx]].tolist())
hyps = map_batch(batch_sents, vocabulary, num_processes, True, 0)
            elif args.mode in ("ctc_prefix_beam_search", "attention_rescoring"):
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []  # only effective in streaming deployment
                batch_root = TrieVector()
                root_dict = {}
                for i in range(len(batch_len_list)):
                    num_sent = batch_len_list[i]
                    batch_log_probs_seq.append(batch_log_probs_seq_list[i][0:num_sent])
                    batch_log_probs_ids.append(batch_log_probs_idx_list[i][0:num_sent])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                score_hyps = ctc_beam_search_decoder_batch(
                    batch_log_probs_seq,
                    batch_log_probs_ids,
                    batch_root,
                    batch_start,
                    beam_size,
                    num_processes,
                    0,
                    -2,
                    0.99999,
                )
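                # score_hyps is assumed to hold, per utterance, a best-first list of
                # (ctc_score, token_id_tuple) candidates from the swig beam search.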
if args.mode == "ctc_prefix_beam_search":
hyps = []
for cand_hyps in score_hyps:
hyps.append(cand_hyps[0][1])
hyps = map_batch(hyps, vocabulary, num_processes, False, 0)
if args.mode == "attention_rescoring":
ctc_score, all_hyps = [], []
max_len = 0
for hyps in score_hyps:
cur_len = len(hyps)
if len(hyps) < beam_size:
hyps += (beam_size - cur_len) * [(-float("INF"), (0,))]
cur_ctc_score = []
for hyp in hyps:
cur_ctc_score.append(hyp[0])
all_hyps.append(list(hyp[1]))
if len(hyp[1]) > max_len:
max_len = len(hyp[1])
ctc_score.append(cur_ctc_score)
if args.fp16:
ctc_score = np.array(ctc_score, dtype=np.float16)
else:
ctc_score = np.array(ctc_score, dtype=np.float32)
hyps_pad_sos_eos = (
np.ones((batch_size, beam_size, max_len + 2), dtype=np.int64)
* IGNORE_ID
)
r_hyps_pad_sos_eos = (
np.ones((batch_size, beam_size, max_len + 2), dtype=np.int64)
* IGNORE_ID
)
hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
k = 0
for i in range(batch_size):
for j in range(beam_size):
cand = all_hyps[k]
l = len(cand) + 2
hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos]
r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [eos]
hyps_lens_sos[i][j] = len(cand) + 1
k += 1
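                # The decoder graph rescores every beam candidate (forward and,
                # when reverse_weight > 0, reversed) against the encoder output
                # and returns the index of the best hypothesis per utterance.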
                decoder_ort_inputs = {
                    decoder_ort_session.get_inputs()[0].name: encoder_out,
                    decoder_ort_session.get_inputs()[1].name: encoder_out_lens,
                    decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos,
                    decoder_ort_session.get_inputs()[3].name: hyps_lens_sos,
                    decoder_ort_session.get_inputs()[-1].name: ctc_score,
                }
                if reverse_weight > 0:
                    r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs()[4].name
                    decoder_ort_inputs[r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos
                best_index = decoder_ort_session.run(None, decoder_ort_inputs)[0]
                best_sents = []
                k = 0
                for idx in best_index:
                    cur_best_sent = all_hyps[k : k + beam_size][idx]
                    best_sents.append(cur_best_sent)
                    k += beam_size
                hyps = map_batch(best_sents, vocabulary, num_processes)
            for i, key in enumerate(keys):
                content = hyps[i]
                logging.info("{} {}".format(key, content))
                fout.write("{} {}\n".format(key, content))

if __name__ == "__main__":
    main()