cahya committed on
Commit
133f88a
1 Parent(s): d2fdc8c

add run_evaluation.py

Browse files
Files changed (2) hide show
  1. eval_kenlm.py +1 -0
  2. run_evaluation.py +120 -0
eval_kenlm.py CHANGED
@@ -106,6 +106,7 @@ def main(args):
106
  set_seed(42) # set the random seed to have reproducible result.
107
  processor = Wav2Vec2Processor.from_pretrained(args.model_id)
108
  model = Wav2Vec2ForCTC.from_pretrained(args.model_id)
 
109
  kenlm = KenLM(processor.tokenizer, "language_model/5gram.bin", unigrams="language_model/unigrams.txt")
110
 
111
  # map function to decode audio
106
  set_seed(42) # set the random seed to have reproducible result.
107
  processor = Wav2Vec2Processor.from_pretrained(args.model_id)
108
  model = Wav2Vec2ForCTC.from_pretrained(args.model_id)
109
+ model.to(args.device)
110
  kenlm = KenLM(processor.tokenizer, "language_model/5gram.bin", unigrams="language_model/unigrams.txt")
111
 
112
  # map function to decode audio
run_evaluation.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from datasets import load_dataset, load_metric
4
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, set_seed
5
+ import argparse
6
+ from pyctcdecode import build_ctcdecoder
7
+ from multiprocessing import Pool
8
+ import re
9
+
10
+ """
11
+ This is the script to run the prediction on test set of Turkish speech dataset.
12
+ Usage:
13
+ python run_evaluation.y -m <wav2vec2 model_name> -d <Zindi dataset directory> -o <output file name> \
14
+ -b <optional batch size, default=16>
15
+ """
16
+
17
+
18
class KenLM:
    """Beam-search CTC decoder backed by a KenLM language model.

    Builds a ``pyctcdecode`` decoder from the tokenizer's vocabulary (ordered
    by token id so it lines up with the model's logit columns) and a KenLM
    model file.
    """

    def __init__(self, tokenizer, model_name, num_workers=8, beam_width=128):
        """Create the decoder.

        Args:
            tokenizer: HuggingFace tokenizer providing ``get_vocab()``.
            model_name: Path to the KenLM model file.
            num_workers: Size of the multiprocessing pool used for decoding.
            beam_width: Beam width used by the CTC beam search.
        """
        self.num_workers = num_workers
        self.beam_width = beam_width
        vocab_dict = tokenizer.get_vocab()
        # Sort tokens by their ids so position i of the vocabulary matches
        # logit column i.
        self.vocabulary = [token for token, _ in sorted(vocab_dict.items(), key=lambda item: item[1])]
        # Drop the last two entries — presumably tokenizer special tokens that
        # are not part of the CTC alphabet (TODO confirm against the tokenizer).
        self.vocabulary = self.vocabulary[:-2]
        self.decoder = build_ctcdecoder(self.vocabulary, model_name)

    @staticmethod
    def lm_postprocess(text):
        """Blank out single-character tokens and trim the result.

        NOTE: surviving words keep the doubled inner spaces left behind by the
        removed tokens; only leading/trailing whitespace is stripped.
        """
        return ' '.join([x if len(x) > 1 else "" for x in text.split()]).strip()

    def decode(self, logits):
        """Decode a batch of CTC logits into post-processed transcripts."""
        probs = logits.cpu().numpy()
        with Pool(self.num_workers) as pool:
            # BUG FIX: pass the configured beam width; previously
            # ``self.beam_width`` was stored but never used, so pyctcdecode
            # silently applied its own default instead.
            text = self.decoder.decode_batch(pool, probs, beam_width=self.beam_width)
        return [KenLM.lm_postprocess(x) for x in text]
38
+
39
# Punctuation stripped from reference transcripts before WER computation.
# Raw string avoids invalid-escape-sequence warnings; inside a regex character
# class, the escaped forms match the same literal characters as before.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\“\‘\”\'\`…\’»«]'
40
+
41
def main():
    """Evaluate a wav2vec2 CTC model on a dataset's test split and print the WER.

    Command-line options select the model, the dataset (name/config/directory),
    the batch size, the fraction of the test split to use, and an optional
    KenLM language model for beam-search decoding.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_name", type=str, required=True,
                        help="The wav2vec2 model name")
    parser.add_argument("-n", "--name", type=str, required=True,
                        help="The name of dataset")
    parser.add_argument("-c", "--config_name", type=str, required=True,
                        help="The config name of the dataset")
    parser.add_argument("-d", "--data_dir", type=str, required=False, default=None,
                        help="The directory contains the dataset")
    parser.add_argument("-b", "--batch_size", type=int, required=False, default=16,
                        help="Batch size")
    parser.add_argument("-k", "--kenlm", type=str, required=False, default=False,
                        help="Path to KenLM model")
    parser.add_argument("--num_workers", type=int, required=False, default=8,
                        help="KenLM's number of workers")
    parser.add_argument("-w", "--beam_width", type=int, required=False, default=128,
                        help="KenLM's beam width")
    parser.add_argument("--test_pct", type=int, required=False, default=100,
                        help="Percentage of the test set")
    parser.add_argument("--cpu", required=False, action='store_true',
                        help="Force to use CPU")
    args = parser.parse_args()

    set_seed(42)  # fixed seed for reproducible results
    processor = Wav2Vec2Processor.from_pretrained(args.model_name)
    model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
    kenlm = None
    if args.kenlm:
        # BUG FIX: forward --num_workers and --beam_width to the decoder;
        # previously both options were parsed but silently ignored.
        kenlm = KenLM(processor.tokenizer, args.kenlm,
                      num_workers=args.num_workers, beam_width=args.beam_width)

    # Preprocessing: read each audio file as a float array resampled to 16 kHz.
    def speech_file_to_array_fn(batch):
        if "audio" in batch:
            speech_array = torch.tensor(batch["audio"]["array"])
            # NOTE(review): assumes pre-decoded "audio" entries are sampled at
            # 48 kHz — confirm against the dataset's audio feature.
            resampler = torchaudio.transforms.Resample(48_000, 16_000)
        else:
            speech_array, sampling_rate = torchaudio.load(batch["path"])
            resampler = torchaudio.transforms.Resample(sampling_rate, 16_000)
        batch["speech"] = resampler(speech_array).squeeze().numpy()
        return batch

    def remove_special_characters(batch):
        # Normalized reference transcript: strip punctuation, lowercase, trim.
        batch["norm_text"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).lower().strip()
        return batch

    lg_test = load_dataset(args.name, args.config_name, data_dir=args.data_dir,
                           split=f"test[:{args.test_pct}%]", use_auth_token=True)
    if args.cpu:
        device = "cpu"
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lg_test = lg_test.map(speech_file_to_array_fn)
    lg_test = lg_test.map(remove_special_characters)
    model = model.to(device)
    wer = load_metric("wer")

    def evaluate(batch):
        inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

        with torch.no_grad():
            logits = model(inputs.input_values.to(device)).logits

        if args.kenlm:
            # Beam-search decoding rescored with the KenLM language model.
            batch["pred_strings"] = kenlm.decode(logits)
        else:
            # Plain greedy CTC decoding.
            predicted_ids = torch.argmax(logits, dim=-1)
            batch["pred_strings"] = processor.batch_decode(predicted_ids)
        return batch

    result = lg_test.map(evaluate, batched=True, batch_size=args.batch_size)
    WER = 100 * wer.compute(predictions=result["pred_strings"], references=result["norm_text"])
    print(f"WER: {WER:.2f}%")


if __name__ == "__main__":
    main()