cahya commited on
Commit
f551c27
1 Parent(s): 133f88a

add run_evaluation.py

Browse files
Files changed (1) hide show
  1. run_evaluation.py +5 -3
run_evaluation.py CHANGED
@@ -16,13 +16,13 @@ python run_evaluation.y -m <wav2vec2 model_name> -d <Zindi dataset directory> -o
16
 
17
 
18
  class KenLM:
19
- def __init__(self, tokenizer, model_name, num_workers=8, beam_width=128):
20
  self.num_workers = num_workers
21
  self.beam_width = beam_width
22
  vocab_dict = tokenizer.get_vocab()
23
  self.vocabulary = [x[0] for x in sorted(vocab_dict.items(), key=lambda x: x[1], reverse=False)]
24
  self.vocabulary = self.vocabulary[:-2]
25
- self.decoder = build_ctcdecoder(self.vocabulary, model_name)
26
 
27
  @staticmethod
28
  def lm_postprocess(text):
@@ -52,6 +52,8 @@ def main():
52
  help="Batch size")
53
  parser.add_argument("-k", "--kenlm", type=str, required=False, default=False,
54
  help="Path to KenLM model")
 
 
55
  parser.add_argument("--num_workers", type=int, required=False, default=8,
56
  help="KenLM's number of workers")
57
  parser.add_argument("-w", "--beam_width", type=int, required=False, default=128,
@@ -67,7 +69,7 @@ def main():
67
  model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
68
  kenlm = None
69
  if args.kenlm:
70
- kenlm = KenLM(processor.tokenizer, args.kenlm)
71
 
72
  # Preprocessing the datasets.
73
  # We need to read the audio files as arrays
 
16
 
17
 
18
  class KenLM:
19
+ def __init__(self, tokenizer, model_name, unigrams=None, num_workers=8, beam_width=128):
20
  self.num_workers = num_workers
21
  self.beam_width = beam_width
22
  vocab_dict = tokenizer.get_vocab()
23
  self.vocabulary = [x[0] for x in sorted(vocab_dict.items(), key=lambda x: x[1], reverse=False)]
24
  self.vocabulary = self.vocabulary[:-2]
25
+ self.decoder = build_ctcdecoder(self.vocabulary, model_name, unigrams=unigrams)
26
 
27
  @staticmethod
28
  def lm_postprocess(text):
 
52
  help="Batch size")
53
  parser.add_argument("-k", "--kenlm", type=str, required=False, default=False,
54
  help="Path to KenLM model")
55
+ parser.add_argument("-u", "--unigrams", type=str, required=False, default=False,
56
+ help="Path to unigrams file")
57
  parser.add_argument("--num_workers", type=int, required=False, default=8,
58
  help="KenLM's number of workers")
59
  parser.add_argument("-w", "--beam_width", type=int, required=False, default=128,
 
69
  model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
70
  kenlm = None
71
  if args.kenlm:
72
+ kenlm = KenLM(processor.tokenizer, args.kenlm, args.unigrams)
73
 
74
  # Preprocessing the datasets.
75
  # We need to read the audio files as arrays