|
from transformers import AutoProcessor |
|
from transformers import Wav2Vec2ProcessorWithLM |
|
|
|
from pyctcdecode import build_ctcdecoder |
|
|
|
from huggingface_hub import Repository |
|
|
|
import logging |
|
|
|
import fire |
|
|
|
|
|
# Configure root logging once at import time so INFO-level progress
# messages are visible when the script runs from the command line.
logging.basicConfig(level=logging.INFO)

# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
# NOTE(review): `exec` shadows the builtin of the same name. Renaming would
# break the `fire.Fire(exec)` entry point and any importers, so the name is
# kept; consider renaming to `main` in a coordinated change.
def exec(
    kenlm_model_path: str,
    model_name: str,
    lm_model_name: str = "",
):
    """Attach a KenLM language model to a Wav2Vec2 processor and save it.

    Loads the processor of ``model_name``, builds a pyctcdecode CTC decoder
    backed by the KenLM model at ``kenlm_model_path``, wraps both into a
    ``Wav2Vec2ProcessorWithLM``, and saves the result to ``lm_model_name``.

    Args:
        kenlm_model_path: Path to the KenLM language-model file
            (``.arpa`` or binary) passed to ``build_ctcdecoder``.
        model_name: Name or path of the pretrained model whose processor
            (tokenizer + feature extractor) is loaded.
        lm_model_name: Output directory for the LM-augmented processor.
            Defaults to ``model_name + "_lm"`` when empty.
    """
    if not lm_model_name:
        lm_model_name = model_name + "_lm"
    # Lazy %-style arguments: the message is only formatted if the record
    # is actually emitted.
    logger.info('writing on %s', lm_model_name)

    logger.info('loading processor of `%s`', model_name)
    processor = AutoProcessor.from_pretrained(model_name)
    logger.info('done loading `%s`', model_name)

    # pyctcdecode expects labels ordered by token id, so sort the tokenizer
    # vocabulary by its integer values and keep only the tokens. (The
    # original built a whole intermediate dict just to read its keys.)
    vocab_dict = processor.tokenizer.get_vocab()
    sorted_labels = [
        token for token, _ in sorted(vocab_dict.items(), key=lambda item: item[1])
    ]

    logger.info('building ctc decoder from %s', kenlm_model_path)
    decoder = build_ctcdecoder(
        labels=sorted_labels,
        kenlm_model_path=kenlm_model_path,
    )
    logger.info('done')

    processor_with_lm = Wav2Vec2ProcessorWithLM(
        feature_extractor=processor.feature_extractor,
        tokenizer=processor.tokenizer,
        decoder=decoder,
    )

    processor_with_lm.save_pretrained(lm_model_name)
|
|
|
|
|
if __name__ == "__main__": |
|
fire.Fire(exec) |
|
|