---
language: de
datasets:
- common_voice
metrics:
- wer
- cer
tags:
- audio
- automatic-speech-recognition
- speech
- hf-asr-leaderboard
license: apache-2.0
model-index:
- name: wav2vec2-xls-r-1b-5gram-german with LM by Florian Zimmermeister @A\\Ware
  results:
  - task:
      name: Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Common Voice de
      type: common_voice
      args: de
    metrics:
    - name: Test WER
      type: wer
      value: 4.382541642219636
    - name: Test CER
      type: cer
      value: 1.6235493024026488
  - task:
      name: Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Common Voice 8 de
      type: mozilla-foundation/common_voice_8_0
      args: de
    metrics:
    - name: Test WER
      type: wer
      value: 4.382541642219636
    - name: Test CER
      type: cer
      value: 1.6235493024026488
---
## Evaluation
The model can be evaluated on the German test split of Common Voice 8.0 as follows.
```python
import re

import datasets
import torch
from datasets import load_dataset, load_metric
from transformers import AutoModelForCTC, AutoProcessor
from unidecode import unidecode

# Running totals used to report the mean per-utterance WER and CER.
counter = 0
wer_counter = 0
cer_counter = 0
device = "cuda" if torch.cuda.is_available() else "cpu"

# Umlauts are swapped for placeholders so that unidecode does not strip them.
special_chars = [["Ä", " AE "], ["Ö", " OE "], ["Ü", " UE "], ["ä", " ae "], ["ö", " oe "], ["ü", " ue "]]


def clean_text(sentence):
    # Protect umlauts, transliterate everything else to ASCII, then restore them.
    for special in special_chars:
        sentence = sentence.replace(special[0], special[1])
    sentence = unidecode(sentence)
    for special in special_chars:
        sentence = sentence.replace(special[1], special[0])
    # Replace every remaining unsupported character with a space.
    sentence = re.sub("[^a-zA-Z0-9öäüÖÄÜ ,.!?]", " ", sentence)
    return sentence


def main(model_id):
    print("load model")
    model = AutoModelForCTC.from_pretrained(model_id).to(device)
    print("load processor")
    # Assumption: the processor (including the 5-gram LM decoder) is stored
    # in the same repository as the model.
    processor = AutoProcessor.from_pretrained(model_id)
    print("load metrics")
    # `load_metric` is deprecated in recent `datasets` releases;
    # `evaluate.load("wer")` / `evaluate.load("cer")` is the replacement.
    wer = load_metric("wer")
    cer = load_metric("cer")
    ds = load_dataset("mozilla-foundation/common_voice_8_0", "de")
    ds = ds["test"]
    # Resample the audio to the 16 kHz rate used in the processor call below.
    ds = ds.cast_column(
        "audio", datasets.features.Audio(sampling_rate=16_000)
    )

    def calculate_metrics(batch):
        global counter, wer_counter, cer_counter
        resampled_audio = batch["audio"]["array"]
        input_values = processor(resampled_audio, return_tensors="pt", sampling_rate=16_000).input_values
        with torch.no_grad():
            logits = model(input_values.to(device)).logits.cpu().numpy()[0]
        # Decode the logits with the language model.
        decoded = processor.decode(logits)
        pred = decoded.text.lower()
        ref = clean_text(batch["sentence"]).lower()
        wer_result = wer.compute(predictions=[pred], references=[ref])
        cer_result = cer.compute(predictions=[pred], references=[ref])
        counter += 1
        wer_counter += wer_result
        cer_counter += cer_result
        # Print the running averages every 100 utterances.
        if counter % 100 == 0:
            print(f"WER: {(wer_counter/counter)*100} | CER: {(cer_counter/counter)*100}")
        return batch

    ds.map(calculate_metrics, remove_columns=ds.column_names)
    print(f"WER: {(wer_counter/counter)*100} | CER: {(cer_counter/counter)*100}")


model_id = "flozi00/wav2vec2-xls-r-1b-5gram-german"
main(model_id)
```
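The script above sums one WER and one CER value per utterance and divides by the number of utterances, so it reports a mean of per-utterance scores. This can differ from a corpus-level score, where all errors are pooled and divided by the total number of reference words. The following is a minimal pure-Python sketch of that difference. The helper names `edit_distance`, `mean_utterance_wer` and `corpus_wer` and the two example sentences are hypothetical and not part of the model repository, and the sketch assumes that no reference is empty.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance (substitutions, insertions, deletions) between token lists."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def mean_utterance_wer(refs: List[str], hyps: List[str]) -> float:
    """Average of per-utterance WER values, as accumulated in the script above."""
    scores = [
        edit_distance(r.split(), h.split()) / len(r.split())
        for r, h in zip(refs, hyps)
    ]
    return sum(scores) / len(scores)


def corpus_wer(refs: List[str], hyps: List[str]) -> float:
    """Total word errors divided by the total number of reference words."""
    errors = sum(edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    words = sum(len(r.split()) for r in refs)
    return errors / words


# Hypothetical example data: one short utterance with an error, one long clean one.
refs = ["guten morgen", "das ist ein sehr langer satz ohne fehler"]
hyps = ["guten abend", "das ist ein sehr langer satz ohne fehler"]

print(mean_utterance_wer(refs, hyps))  # 0.25 (mean of 1/2 and 0/8)
print(corpus_wer(refs, hyps))          # 0.1  (1 error over 10 reference words)
```

Short utterances weigh more heavily in the per-utterance mean, so it helps to check which aggregation was used when comparing numbers with other reports.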