new eval script
Browse files
README.md
CHANGED
@@ -140,6 +140,87 @@ result = test_dataset.map(evaluate, batched=True, batch_size=8)
|
|
140 |
print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
|
141 |
```
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
**Test Result**: 15.80 %
|
144 |
|
145 |
|
|
|
140 |
print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
|
141 |
```
|
142 |
|
143 |
+
The model can also be evaluated in 10% chunks, which requires fewer resources (to be tested).
|
144 |
+
|
145 |
+
```
|
146 |
+
import torch
|
147 |
+
import torchaudio
|
148 |
+
from datasets import load_dataset, load_metric
|
149 |
+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
150 |
+
import re
|
151 |
+
import jiwer
|
152 |
+
lang_id = "de"

# The processor and the CTC model are both loaded from the same checkpoint.
processor = Wav2Vec2Processor.from_pretrained("marcel/wav2vec2-large-xlsr-53-german")
# Module.to() returns the module itself, so loading and moving to the GPU
# can be written as a single expression.
model = Wav2Vec2ForCTC.from_pretrained("marcel/wav2vec2-large-xlsr-53-german").to("cuda")
|
157 |
+
|
158 |
+
# Characters that are deleted from the reference transcripts before scoring.
# Raw string: the backslashes are meant for the regex engine, not for Python's
# string-literal parser (non-raw '\,' etc. are invalid escape sequences).
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\“\%\”\�\カ\æ\無\ན\カ\臣\ѹ\…\«\»\ð\ı\„\幺\א\ב\比\ш\ע\)\ứ\в\œ\ч\+\—\ш\‚\נ\м\ń\乡\$\=\ש\ф\支\(\°\и\к\̇]'

# Maps a replacement string to the regex matching the characters it replaces.
# Entries that look duplicated inside a class are kept verbatim on purpose:
# NOTE(review): they may differ in Unicode normalisation form - confirm before
# de-duplicating.
substitutions = {
    'e': r'[\ə\é\ě\ę\ê\ế\ế\ë\ė\е]',
    'o': r'[\ō\ô\ô\ó\ò\ø\ọ\ŏ\õ\ő\о]',
    'a': r'[\á\ā\ā\ă\ã\å\â\à\ą\а]',
    'c': r'[\č\ć\ç\с]',
    'l': r'[\ł]',
    'u': r'[\ú\ū\ứ\ů]',
    'und': r'[\&]',
    'r': r'[\ř]',
    'y': r'[\ý]',
    's': r'[\ś\š\ș\ş]',
    'i': r'[\ī\ǐ\í\ï\î\ï]',
    'z': r'[\ź\ž\ź\ż]',
    'n': r'[\ñ\ń\ņ]',
    'g': r'[\ğ]',
    'ss': r'[\ß]',
    't': r'[\ț\ť]',
    'd': r'[\ď\đ]',
    "'": r'[\ʿ\་\’\`\´\ʻ\`\‘]',
    'p': r'\р',
}
|
180 |
+
# Shared transform that converts 48 kHz input audio to 16 kHz.
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
|
181 |
+
|
182 |
+
# Preprocessing the datasets.
|
183 |
+
# We need to read the audio files as arrays
|
184 |
+
def speech_file_to_array_fn(batch):
    """Normalise one example's transcript and attach its resampled audio.

    Args:
        batch: a single dataset example providing ``"sentence"`` (the
            reference transcript) and ``"path"`` (the audio file to load).

    Returns:
        The same mapping with ``"sentence"`` replaced by its normalised form
        and ``"speech"`` set to the output of ``resampler`` as a NumPy array.
    """
    # Strip ignored characters first, then lower-case, then fold the remaining
    # special characters onto their replacements.
    sentence = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
    for replacement, pattern in substitutions.items():
        sentence = re.sub(pattern, replacement, sentence)
    batch["sentence"] = sentence

    # Decode the file once; the sampling rate reported by the loader is not
    # used because ``resampler`` is configured with fixed rates.
    speech_array, _ = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
# Running the model on the preprocessed data.
|
196 |
+
# We feed the speech arrays to the model in batches and decode its predictions
|
197 |
+
def evaluate(batch):
    """Transcribe a batch of speech arrays and store the decoded strings.

    Reads ``batch["speech"]``, runs the model on the GPU without gradient
    tracking, and writes the greedy (argmax) decoding to
    ``batch["pred_strings"]`` before returning the batch.
    """
    features = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    input_values = features.input_values.to("cuda")
    attention_mask = features.attention_mask.to("cuda")

    with torch.no_grad():
        outputs = model(input_values, attention_mask=attention_mask)

    # Pick the highest-scoring token at every position and decode to text.
    batch["pred_strings"] = processor.batch_decode(torch.argmax(outputs.logits, dim=-1))
    return batch
|
206 |
+
|
207 |
+
# Evaluate the test split in ten 10% slices so that only one slice has to be
# held in memory at a time. The raw hit/substitution/deletion/insertion counts
# are summed over all slices and the WER is computed once from the totals.
H, S, D, I = 0, 0, 0, 0
for i in range(10):
    split = f"test[{10 * i}%:{10 * (i + 1)}%]"
    print(split)
    test_dataset = load_dataset("common_voice", lang_id, split=split)
    test_dataset = test_dataset.map(speech_file_to_array_fn)
    result = test_dataset.map(evaluate, batched=True, batch_size=8)
    predictions = result["pred_strings"]
    targets = result["sentence"]
    chunk_metrics = jiwer.compute_measures(targets, predictions)
    H += chunk_metrics["hits"]
    S += chunk_metrics["substitutions"]
    D += chunk_metrics["deletions"]
    I += chunk_metrics["insertions"]

# WER = (S + D + I) / N, where N = H + S + D is the number of reference words.
WER = float(S + D + I) / float(H + S + D)
# The original called .mean() on the formatted string (AttributeError) and
# used "{:2f}", which sets a minimum width instead of two decimal places.
print("WER: {:.2f}".format(WER * 100))
|
222 |
+
```
|
223 |
+
|
224 |
**Test Result**: 15.80 %
|
225 |
|
226 |
|