# This module is from [WeNet](https://github.com/wenet-e2e/wenet). # ## Citations # ```bibtex # @inproceedings{yao2021wenet, # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, # booktitle={Proc. Interspeech}, # year={2021}, # address={Brno, Czech Republic }, # organization={IEEE} # } # @article{zhang2022wenet, # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, # journal={arXiv preprint arXiv:2203.15455}, # year={2022} # } # from __future__ import print_function import argparse import copy import logging import os import sys import torch import yaml from torch.utils.data import DataLoader from wenet.dataset.dataset import Dataset from wenet.paraformer.search.beam_search import build_beam_search from wenet.utils.checkpoint import load_checkpoint from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols from wenet.utils.config import override_config from wenet.utils.init_model import init_model def get_args(): parser = argparse.ArgumentParser(description="recognize with your model") parser.add_argument("--config", required=True, help="config file") parser.add_argument("--test_data", required=True, help="test data file") parser.add_argument( "--data_type", default="raw", choices=["raw", "shard"], help="train and cv data type", ) parser.add_argument( "--gpu", type=int, default=-1, help="gpu id for this rank, -1 for cpu" ) parser.add_argument("--checkpoint", required=True, help="checkpoint model") parser.add_argument("--dict", required=True, help="dict file") parser.add_argument( "--non_lang_syms", help="non-linguistic symbol file. One symbol per line." ) parser.add_argument( "--beam_size", type=int, default=10, help="beam size for search" ) parser.add_argument("--penalty", type=float, default=0.0, help="length penalty") parser.add_argument("--result_file", required=True, help="asr result file") parser.add_argument("--batch_size", type=int, default=16, help="asr result file") parser.add_argument( "--mode", choices=[ "attention", "ctc_greedy_search", "ctc_prefix_beam_search", "attention_rescoring", "rnnt_greedy_search", "rnnt_beam_search", "rnnt_beam_attn_rescoring", "ctc_beam_td_attn_rescoring", "hlg_onebest", "hlg_rescore", "paraformer_greedy_search", "paraformer_beam_search", ], default="attention", help="decoding mode", ) parser.add_argument( "--search_ctc_weight", type=float, default=1.0, help="ctc weight for nbest generation", ) parser.add_argument( "--search_transducer_weight", type=float, default=0.0, help="transducer weight for nbest generation", ) parser.add_argument( "--ctc_weight", type=float, default=0.0, help="ctc weight for rescoring weight in \ attention rescoring decode mode \ ctc weight for rescoring weight in \ transducer attention rescore decode mode", ) parser.add_argument( "--transducer_weight", type=float, default=0.0, help="transducer weight for rescoring weight in " "transducer attention rescore mode", ) parser.add_argument( "--attn_weight", type=float, default=0.0, help="attention weight for rescoring weight in " "transducer attention rescore mode", ) parser.add_argument( "--decoding_chunk_size", type=int, default=-1, help="""decoding chunk size, <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. 0: used for training, it's prohibited here""", ) parser.add_argument( "--num_decoding_left_chunks", type=int, default=-1, help="number of left chunks for decoding", ) parser.add_argument( "--simulate_streaming", action="store_true", help="simulate streaming inference" ) parser.add_argument( "--reverse_weight", type=float, default=0.0, help="""right to left weight for attention rescoring decode mode""", ) parser.add_argument( "--bpe_model", default=None, type=str, help="bpe model for english part" ) parser.add_argument( "--override_config", action="append", default=[], help="override yaml config" ) parser.add_argument( "--connect_symbol", default="", type=str, help="used to connect the output characters", ) parser.add_argument( "--word", default="", type=str, help="word file, only used for hlg decode" ) parser.add_argument( "--hlg", default="", type=str, help="hlg file, only used for hlg decode" ) parser.add_argument( "--lm_scale", type=float, default=0.0, help="lm scale for hlg attention rescore decode", ) parser.add_argument( "--decoder_scale", type=float, default=0.0, help="lm scale for hlg attention rescore decode", ) parser.add_argument( "--r_decoder_scale", type=float, default=0.0, help="lm scale for hlg attention rescore decode", ) args = parser.parse_args() print(args) return args def main(): args = get_args() logging.basicConfig( level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s" ) os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) if ( args.mode in [ "ctc_prefix_beam_search", "attention_rescoring", "paraformer_beam_search", ] and args.batch_size > 1 ): logging.fatal( "decoding mode {} must be running with batch_size == 1".format(args.mode) ) sys.exit(1) with open(args.config, "r") as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) if len(args.override_config) > 0: configs = override_config(configs, args.override_config) symbol_table = read_symbol_table(args.dict) test_conf = copy.deepcopy(configs["dataset_conf"]) test_conf["filter_conf"]["max_length"] = 102400 test_conf["filter_conf"]["min_length"] = 0 test_conf["filter_conf"]["token_max_length"] = 102400 test_conf["filter_conf"]["token_min_length"] = 0 test_conf["filter_conf"]["max_output_input_ratio"] = 102400 test_conf["filter_conf"]["min_output_input_ratio"] = 0 test_conf["speed_perturb"] = False test_conf["spec_aug"] = False test_conf["spec_sub"] = False test_conf["spec_trim"] = False test_conf["shuffle"] = False test_conf["sort"] = False if "fbank_conf" in test_conf: test_conf["fbank_conf"]["dither"] = 0.0 elif "mfcc_conf" in test_conf: test_conf["mfcc_conf"]["dither"] = 0.0 test_conf["batch_conf"]["batch_type"] = "static" test_conf["batch_conf"]["batch_size"] = args.batch_size non_lang_syms = read_non_lang_symbols(args.non_lang_syms) test_dataset = Dataset( args.data_type, args.test_data, symbol_table, test_conf, args.bpe_model, non_lang_syms, partition=False, ) test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) # Init asr model from configs model = init_model(configs) # Load dict char_dict = {v: k for k, v in symbol_table.items()} eos = len(char_dict) - 1 load_checkpoint(model, args.checkpoint) use_cuda = args.gpu >= 0 and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") model = model.to(device) model.eval() # Build BeamSearchCIF object if args.mode == "paraformer_beam_search": paraformer_beam_search = build_beam_search(model, args, device) else: paraformer_beam_search = None with torch.no_grad(), open(args.result_file, "w") as fout: for batch_idx, batch in enumerate(test_data_loader): keys, feats, target, feats_lengths, target_lengths = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) target_lengths = target_lengths.to(device) if args.mode == "attention": hyps, _ = model.recognize( feats, feats_lengths, beam_size=args.beam_size, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, ) hyps = [hyp.tolist() for hyp in hyps] elif args.mode == "ctc_greedy_search": hyps, _ = model.ctc_greedy_search( feats, feats_lengths, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, ) elif args.mode == "rnnt_greedy_search": assert feats.size(0) == 1 assert "predictor" in configs hyps = model.greedy_search( feats, feats_lengths, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, ) elif args.mode == "rnnt_beam_search": assert feats.size(0) == 1 assert "predictor" in configs hyps = model.beam_search( feats, feats_lengths, decoding_chunk_size=args.decoding_chunk_size, beam_size=args.beam_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, ctc_weight=args.search_ctc_weight, transducer_weight=args.search_transducer_weight, ) elif args.mode == "rnnt_beam_attn_rescoring": assert feats.size(0) == 1 assert "predictor" in configs hyps = model.transducer_attention_rescoring( feats, feats_lengths, decoding_chunk_size=args.decoding_chunk_size, beam_size=args.beam_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, ctc_weight=args.ctc_weight, transducer_weight=args.transducer_weight, attn_weight=args.attn_weight, reverse_weight=args.reverse_weight, search_ctc_weight=args.search_ctc_weight, search_transducer_weight=args.search_transducer_weight, ) elif args.mode == "ctc_beam_td_attn_rescoring": assert feats.size(0) == 1 assert "predictor" in configs hyps = model.transducer_attention_rescoring( feats, feats_lengths, decoding_chunk_size=args.decoding_chunk_size, beam_size=args.beam_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, ctc_weight=args.ctc_weight, transducer_weight=args.transducer_weight, attn_weight=args.attn_weight, reverse_weight=args.reverse_weight, search_ctc_weight=args.search_ctc_weight, search_transducer_weight=args.search_transducer_weight, beam_search_type="ctc", ) # ctc_prefix_beam_search and attention_rescoring only return one # result in List[int], change it to List[List[int]] for compatible # with other batch decoding mode elif args.mode == "ctc_prefix_beam_search": assert feats.size(0) == 1 hyp, _ = model.ctc_prefix_beam_search( feats, feats_lengths, args.beam_size, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, ) hyps = [hyp] elif args.mode == "attention_rescoring": assert feats.size(0) == 1 hyp, _ = model.attention_rescoring( feats, feats_lengths, args.beam_size, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, ctc_weight=args.ctc_weight, simulate_streaming=args.simulate_streaming, reverse_weight=args.reverse_weight, ) hyps = [hyp] elif args.mode == "hlg_onebest": hyps = model.hlg_onebest( feats, feats_lengths, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, hlg=args.hlg, word=args.word, symbol_table=symbol_table, ) elif args.mode == "hlg_rescore": hyps = model.hlg_rescore( feats, feats_lengths, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, lm_scale=args.lm_scale, decoder_scale=args.decoder_scale, r_decoder_scale=args.r_decoder_scale, hlg=args.hlg, word=args.word, symbol_table=symbol_table, ) elif args.mode == "paraformer_beam_search": hyps = model.paraformer_beam_search( feats, feats_lengths, beam_search=paraformer_beam_search, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, ) elif args.mode == "paraformer_greedy_search": hyps = model.paraformer_greedy_search( feats, feats_lengths, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming, ) for i, key in enumerate(keys): content = [] for w in hyps[i]: if w == eos: break content.append(char_dict[w]) logging.info("{} {}".format(key, args.connect_symbol.join(content))) fout.write("{} {}\n".format(key, args.connect_symbol.join(content))) if __name__ == "__main__": main()