"""CLI: wrap a CTC model's processor with a KenLM decoder and save it.

Usage (via python-fire):
    python <this_script>.py --kenlm_model_path=LM.arpa --model_name=MODEL [--lm_model_name=OUT]
"""
import logging

import fire
from huggingface_hub import Repository
from pyctcdecode import build_ctcdecoder
from transformers import AutoProcessor
from transformers import Wav2Vec2ProcessorWithLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# NOTE: the name shadows the builtin ``exec``; it is kept because it is this
# script's public entry point (passed to ``fire.Fire`` below).
def exec(
    kenlm_model_path: str,
    model_name: str,
    lm_model_name: str = "",
):
    """Attach a KenLM-backed CTC decoder to ``model_name``'s processor and save it.

    Loads the processor of ``model_name``, builds a pyctcdecode decoder whose
    labels are the tokenizer vocabulary ordered by token id, wraps feature
    extractor, tokenizer and decoder in a ``Wav2Vec2ProcessorWithLM`` and
    writes it with ``save_pretrained``.

    Args:
        kenlm_model_path: Path handed to ``build_ctcdecoder`` as the KenLM model.
        model_name: Hub id or local path whose processor is loaded.
        lm_model_name: Output location for ``save_pretrained``. When empty,
            defaults to ``model_name`` (with any trailing path separators
            removed) followed by ``"_lm"``.
    """
    if not lm_model_name:
        # Strip trailing separators so "./model/" yields "./model_lm" rather
        # than "./model/_lm" (which would nest the output inside the source).
        lm_model_name = model_name.rstrip("/\\") + "_lm"
    logger.info("writing on %s", lm_model_name)

    logger.info("loading processor of `%s`", model_name)
    processor = AutoProcessor.from_pretrained(model_name)
    logger.info("done loading `%s`", model_name)

    # The decoder maps logit column i to labels[i], so the labels must be
    # ordered by token id.
    vocab_dict = processor.tokenizer.get_vocab()
    sorted_vocab_dict = {
        k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])
    }
    token_ids = list(sorted_vocab_dict.values())
    if token_ids != list(range(len(token_ids))):
        # Gaps or an offset start would silently misalign labels and logits.
        logger.warning(
            "tokenizer vocab ids are not contiguous from 0 (min=%s, max=%s, size=%s); "
            "CTC labels may be misaligned",
            min(token_ids, default=None),
            max(token_ids, default=None),
            len(token_ids),
        )

    logger.info("building ctc decoder from %s", kenlm_model_path)
    decoder = build_ctcdecoder(
        labels=list(sorted_vocab_dict.keys()),
        kenlm_model_path=kenlm_model_path,
    )
    logger.info("done")

    processor_with_lm = Wav2Vec2ProcessorWithLM(
        feature_extractor=processor.feature_extractor,
        tokenizer=processor.tokenizer,
        decoder=decoder,
    )
    # TODO: optionally push to the Hub, e.g.
    # Repository(local_dir=lm_model_name, clone_from=model_name).push_to_hub()
    processor_with_lm.save_pretrained(lm_model_name)


if __name__ == "__main__":
    fire.Fire(exec)