metadata
language: de
datasets:
- common_voice
metrics:
- wer
- cer
tags:
- audio
- automatic-speech-recognition
- speech
license: apache-2.0
model-index:
- name: wav2vec2-xls-r-1b-5gram-german with LM by Florian Zimmermeister @A\\Ware
  results:
  - task:
      name: Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Common Voice de
      type: common_voice
      args: de
    metrics:
    - name: Test WER
      type: wer
      value: 4.382541642219636
    - name: Test CER
      type: cer
      value: 1.6235493024026488
Evaluation
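The model can be evaluated on the German test split of Common Voice 8.0 (`mozilla-foundation/common_voice_8_0`) with the script below. Note that the script reports WER and CER as the average of per-utterance scores, not as scores computed over the whole corpus at once.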
import torch
from transformers import AutoModelForCTC, AutoProcessor
from unidecode import unidecode
import re
from datasets import load_dataset, load_metric
import datasets
# Running totals updated during evaluation: number of utterances scored and
# the summed per-utterance WER / CER values.
counter = wer_counter = cer_counter = 0

# Run on the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# German umlauts paired with the spaced ASCII placeholders that clean_text()
# substitutes before transliteration and swaps back afterwards.
special_chars = [
    ["Ä", " AE "],
    ["Ö", " OE "],
    ["Ü", " UE "],
    ["ä", " ae "],
    ["ö", " oe "],
    ["ü", " ue "],
]
def clean_text(sentence):
    """Normalize a reference transcript before scoring.

    Umlauts are first swapped for the spaced ASCII placeholders listed in
    ``special_chars`` so that ``unidecode`` does not transliterate them, while
    every other non-ASCII character is transliterated. The placeholders are
    then turned back into umlauts. Finally, every character that is not an
    ASCII letter, digit, umlaut, space, or one of ``,.!?`` becomes a space.

    Args:
        sentence: raw transcript text.

    Returns:
        The normalized string (case is preserved; callers lowercase it).
    """
    # Shield umlauts from transliteration.
    for umlaut, placeholder in special_chars:
        sentence = sentence.replace(umlaut, placeholder)
    ascii_text = unidecode(sentence)
    # Put the shielded umlauts back.
    for umlaut, placeholder in special_chars:
        ascii_text = ascii_text.replace(placeholder, umlaut)
    return re.sub("[^a-zA-Z0-9öäüÖÄÜ ,.!?]", " ", ascii_text)
def main(model_id):
    """Evaluate a CTC speech model on the German Common Voice 8.0 test split.

    Loads the model and processor from ``model_id``, transcribes every test
    utterance, and prints running and final WER/CER in percent. The scores
    are the mean of per-utterance WER/CER values, accumulated in the
    module-level ``counter``, ``wer_counter`` and ``cer_counter`` globals.

    Args:
        model_id: Hugging Face Hub id (or local path) of the model. The
            processor is loaded from the same id.
    """
    global counter, wer_counter, cer_counter
    # Start from clean accumulators so repeated calls do not mix results.
    counter = 0
    wer_counter = 0
    cer_counter = 0

    print("load model")
    model = AutoModelForCTC.from_pretrained(model_id).to(device)
    print("load processor")
    # The processor lives in the same repo as the model.
    processor = AutoProcessor.from_pretrained(model_id)
    print("load metrics")
    wer = load_metric("wer")
    cer = load_metric("cer")

    ds = load_dataset("mozilla-foundation/common_voice_8_0", "de")
    ds = ds["test"]
    # Decode/resample audio to the 16 kHz rate passed to the processor below.
    ds = ds.cast_column(
        "audio", datasets.features.Audio(sampling_rate=16_000)
    )

    def calculate_metrics(batch):
        """Score one utterance and add its WER/CER to the running totals."""
        global counter, wer_counter, cer_counter
        resampled_audio = batch["audio"]["array"]
        input_values = processor(
            resampled_audio, return_tensors="pt", sampling_rate=16_000
        ).input_values
        with torch.no_grad():
            logits = model(input_values.to(device)).logits.cpu().numpy()[0]
        # NOTE(review): `.text` on the decode result assumes AutoProcessor
        # resolves to an LM-enabled processor (e.g. Wav2Vec2ProcessorWithLM)
        # — confirm for other model ids.
        decoded = processor.decode(logits)
        pred = decoded.text.lower()
        ref = clean_text(batch["sentence"]).lower()

        wer_result = wer.compute(predictions=[pred], references=[ref])
        cer_result = cer.compute(predictions=[pred], references=[ref])
        counter += 1
        wer_counter += wer_result
        cer_counter += cer_result

        # Progress report every 100 utterances.
        if counter % 100 == 0:
            print(f"WER: {(wer_counter/counter)*100} | CER: {(cer_counter/counter)*100}")
        return batch

    # map() is used only to drive calculate_metrics over every example.
    ds.map(calculate_metrics, remove_columns=ds.column_names)

    if counter == 0:
        # Nothing was scored; avoid dividing by zero.
        print("no samples evaluated")
        return
    print(f"WER: {(wer_counter/counter)*100} | CER: {(cer_counter/counter)*100}")
if __name__ == "__main__":
    # Hub id of the model (and its processor) to evaluate.
    model_id = "flozi00/wav2vec2-xls-r-1b-5gram-german"
    main(model_id)