versae commited on
Commit
6162805
1 Parent(s): 3af58b3

Adding KenLM script

Browse files
Files changed (1) hide show
  1. add_kenlm.py +37 -0
add_kenlm.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from transformers import AutoProcessor
3
+ from transformers import Wav2Vec2ProcessorWithLM
4
+ from pyctcdecode import build_ctcdecoder
5
+
6
+
7
+ def main(args):
8
+ processor = AutoProcessor.from_pretrained(args.model_name_or_path)
9
+ vocab_dict = processor.tokenizer.get_vocab()
10
+ sorted_vocab_dict = {
11
+ k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])
12
+ }
13
+ decoder = build_ctcdecoder(
14
+ labels=list(sorted_vocab_dict.keys()),
15
+ kenlm_model_path=args.kenlm_model_path,
16
+ )
17
+ processor_with_lm = Wav2Vec2ProcessorWithLM(
18
+ feature_extractor=processor.feature_extractor,
19
+ tokenizer=processor.tokenizer,
20
+ decoder=decoder,
21
+ )
22
+ processor_with_lm.save_pretrained(args.model_name_or_path)
23
+ print(
24
+ f"Run: ~/bin/build_binary language_model/*.arpa language_model/5gram.bin -T $(pwd) && rm language_model/*.arpa")
25
+
26
+
27
+ def parse_args():
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument('--model_name_or_path', default="./", help='Model name or path. Defaults to ./')
30
+ parser.add_argument('--kenlm_model_path', required=True, help='Path to KenLM arpa file.')
31
+ args = parser.parse_args()
32
+ return args
33
+
34
+
35
+ if __name__ == "__main__":
36
+ main(parse_args())
37
+