# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic },
#   organization={IEEE}
# }

# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

import torch

from modules.wenet_extractor.transducer.joint import TransducerJoint
from modules.wenet_extractor.transducer.predictor import (
    ConvPredictor,
    EmbeddingPredictor,
    RNNPredictor,
)
from modules.wenet_extractor.transducer.transducer import Transducer
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.cmvn import GlobalCMVN
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import (
    BiTransformerDecoder,
    TransformerDecoder,
)
from modules.wenet_extractor.transformer.encoder import (
    ConformerEncoder,
    TransformerEncoder,
)
from modules.wenet_extractor.squeezeformer.encoder import SqueezeformerEncoder
from modules.wenet_extractor.efficient_conformer.encoder import (
    EfficientConformerEncoder,
)
from modules.wenet_extractor.paraformer.paraformer import Paraformer
from modules.wenet_extractor.cif.predictor import Predictor
from modules.wenet_extractor.utils.cmvn import load_cmvn


def _build_global_cmvn(configs: dict):
    """Return a ``GlobalCMVN`` built from ``configs["cmvn_file"]``, or ``None``.

    A missing ``cmvn_file`` key is treated the same as an explicit ``None``
    (no CMVN layer). When a file is given, ``configs["is_json_cmvn"]`` is
    forwarded to ``load_cmvn`` unchanged.
    """
    cmvn_file = configs.get("cmvn_file")
    if cmvn_file is None:
        return None
    mean, istd = load_cmvn(cmvn_file, configs["is_json_cmvn"])
    return GlobalCMVN(
        torch.from_numpy(mean).float(), torch.from_numpy(istd).float()
    )


def _build_encoder(configs: dict, input_dim: int, global_cmvn):
    """Instantiate the encoder selected by ``configs["encoder"]``.

    Recognised values are ``"conformer"`` (the default when the key is
    absent), ``"squeezeformer"`` and ``"efficientConformer"``; any other value
    falls back to ``TransformerEncoder``. ``configs["encoder_conf"]`` is
    forwarded as keyword arguments.
    """
    encoder_type = configs.get("encoder", "conformer")
    encoder_conf = configs["encoder_conf"]

    if encoder_type == "conformer":
        return ConformerEncoder(input_dim, global_cmvn=global_cmvn, **encoder_conf)
    if encoder_type == "squeezeformer":
        return SqueezeformerEncoder(
            input_dim, global_cmvn=global_cmvn, **encoder_conf
        )
    if encoder_type == "efficientConformer":
        # When present, the nested "efficient_conf" dict reaches the encoder
        # twice: as an ``efficient_conf=`` keyword (via ``**encoder_conf``) and
        # flattened into individual keywords. NOTE(review): presumably the
        # encoder's signature absorbs the former via **kwargs — confirm.
        efficient_conf = (
            encoder_conf["efficient_conf"] if "efficient_conf" in encoder_conf else {}
        )
        return EfficientConformerEncoder(
            input_dim, global_cmvn=global_cmvn, **encoder_conf, **efficient_conf
        )
    return TransformerEncoder(input_dim, global_cmvn=global_cmvn, **encoder_conf)


def _build_decoder(configs: dict, vocab_size: int, encoder_output_size: int):
    """Instantiate the attention decoder selected by ``configs["decoder"]``.

    ``"transformer"`` yields a ``TransformerDecoder``; any other value
    (including the default ``"bitransformer"``) yields a
    ``BiTransformerDecoder``, which requires
    ``0 < configs["model_conf"]["reverse_weight"] < 1`` and
    ``configs["decoder_conf"]["r_num_blocks"] > 0``.

    Raises:
        ValueError: if the bidirectional-decoder settings are out of range.
    """
    decoder_type = configs.get("decoder", "bitransformer")
    decoder_conf = configs["decoder_conf"]

    if decoder_type == "transformer":
        return TransformerDecoder(vocab_size, encoder_output_size, **decoder_conf)

    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    reverse_weight = configs["model_conf"]["reverse_weight"]
    if not (0.0 < reverse_weight < 1.0):
        raise ValueError(
            "bitransformer decoder requires 0 < model_conf.reverse_weight < 1, "
            f"got {reverse_weight!r}"
        )
    r_num_blocks = decoder_conf["r_num_blocks"]
    if not (r_num_blocks > 0):
        raise ValueError(
            "bitransformer decoder requires decoder_conf.r_num_blocks > 0, "
            f"got {r_num_blocks!r}"
        )
    return BiTransformerDecoder(vocab_size, encoder_output_size, **decoder_conf)


def _build_predictor(configs: dict, vocab_size: int):
    """Instantiate the transducer predictor selected by ``configs["predictor"]``.

    For the ``"embedding"`` and ``"conv"`` predictors,
    ``configs["predictor_conf"]["output_size"]`` is overwritten with
    ``embed_size`` *after* the predictor is constructed, so the joint network
    can read it later. This mutates the caller's ``configs`` dict.

    Raises:
        NotImplementedError: for any predictor type other than ``"rnn"``,
            ``"embedding"`` or ``"conv"``.
    """
    predictor_type = configs["predictor"]
    predictor_conf = configs["predictor_conf"]

    if predictor_type == "rnn":
        predictor = RNNPredictor(vocab_size, **predictor_conf)
    elif predictor_type == "embedding":
        predictor = EmbeddingPredictor(vocab_size, **predictor_conf)
        predictor_conf["output_size"] = predictor_conf["embed_size"]
    elif predictor_type == "conv":
        predictor = ConvPredictor(vocab_size, **predictor_conf)
        predictor_conf["output_size"] = predictor_conf["embed_size"]
    else:
        raise NotImplementedError("only rnn, embedding and conv type support now")
    return predictor


def init_model(configs: dict):
    """Build an ASR model from a WeNet-style configuration dict.

    The encoder, attention decoder and CTC head are always built. The model
    wrapper is then chosen by which top-level keys are present:

    * ``"predictor"`` -> ``Transducer`` (also builds a predictor and a
      ``TransducerJoint``; ``blank`` is fixed to 0),
    * otherwise ``"paraformer"`` -> ``Paraformer`` (with a CIF ``Predictor``
      built from ``configs["cif_predictor_conf"]``),
    * otherwise -> ``ASRModel`` (``lfmmi_dir`` defaults to ``""``).

    ``configs["model_conf"]`` is forwarded to the chosen wrapper as keyword
    arguments.

    Args:
        configs: configuration mapping. Keys read here: ``cmvn_file``
            (optional), ``is_json_cmvn``, ``input_dim``, ``output_dim``
            (used as the vocabulary size), ``encoder``, ``encoder_conf``,
            ``decoder``, ``decoder_conf``, ``model_conf`` and, depending on
            the model type, ``predictor``, ``predictor_conf``,
            ``joint_conf``, ``paraformer``, ``cif_predictor_conf``,
            ``lfmmi_dir``.

    Returns:
        The constructed ``Transducer``, ``Paraformer`` or ``ASRModel``.

    Side effects:
        On the transducer path, ``configs["joint_conf"]`` gains
        ``enc_output_size`` / ``pred_output_size``, and
        ``configs["predictor_conf"]`` may gain ``output_size``.

    Raises:
        ValueError: invalid bitransformer-decoder settings.
        NotImplementedError: unsupported transducer predictor type.
    """
    global_cmvn = _build_global_cmvn(configs)
    input_dim = configs["input_dim"]
    vocab_size = configs["output_dim"]

    encoder = _build_encoder(configs, input_dim, global_cmvn)
    encoder_output_size = encoder.output_size()
    decoder = _build_decoder(configs, vocab_size, encoder_output_size)
    ctc = CTC(vocab_size, encoder_output_size)

    # Init joint CTC/Attention or Transducer model
    if "predictor" in configs:
        predictor = _build_predictor(configs, vocab_size)
        joint_conf = configs["joint_conf"]
        # The joint's input sizes come from the configs (not from the built
        # modules); predictor_conf["output_size"] may have just been filled in
        # by _build_predictor.
        joint_conf["enc_output_size"] = configs["encoder_conf"]["output_size"]
        joint_conf["pred_output_size"] = configs["predictor_conf"]["output_size"]
        joint = TransducerJoint(vocab_size, **joint_conf)
        return Transducer(
            vocab_size=vocab_size,
            blank=0,
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            **configs["model_conf"],
        )

    if "paraformer" in configs:
        predictor = Predictor(**configs["cif_predictor_conf"])
        return Paraformer(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            predictor=predictor,
            **configs["model_conf"],
        )

    return ASRModel(
        vocab_size=vocab_size,
        encoder=encoder,
        decoder=decoder,
        ctc=ctc,
        lfmmi_dir=configs.get("lfmmi_dir", ""),
        **configs["model_conf"],
    )