flozi00's picture
Update README.md
1aa754b
|
raw
history blame
2.95 kB
metadata
language: de
datasets:
  - common_voice
metrics:
  - wer
  - cer
tags:
  - audio
  - automatic-speech-recognition
  - speech
license: apache-2.0
model-index:
  - name: wav2vec2-xls-r-1b-5gram-german with LM by Florian Zimmermeister @A\\Ware
    results:
      - task:
          name: Speech Recognition
          type: automatic-speech-recognition
        dataset:
          name: Common Voice de
          type: common_voice
          args: de
        metrics:
          - name: Test WER
            type: wer
            value: 4.382541642219636
          - name: Test CER
            type: cer
            value: 1.6235493024026488

Evaluation

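The model can be evaluated on the German test split of Common Voice 8.0 (`mozilla-foundation/common_voice_8_0`) as follows. The script reports the mean of the per-utterance WER and CER scores.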

import torch
from transformers import AutoModelForCTC, AutoProcessor
from unidecode import unidecode
import re
from datasets import load_dataset, load_metric
import datasets

# Running totals updated by main(): number of scored utterances and the summed
# per-utterance WER / CER values.
counter = wer_counter = cer_counter = 0
# Run inference on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


# German umlauts, each paired with a space-padded ASCII spelling; clean_text()
# uses this table to keep umlauts intact across unidecode().
special_chars = [
    [umlaut, f" {ascii_form} "]
    for umlaut, ascii_form in zip("ÄÖÜäöü", ("AE", "OE", "UE", "ae", "oe", "ue"))
]
def clean_text(sentence):
    """Normalize a reference transcript before scoring.

    Every character except the umlauts listed in ``special_chars`` is
    transliterated to ASCII with ``unidecode`` (e.g. "ß" -> "ss",
    "é" -> "e"). Then any character that is not an ASCII letter, a digit,
    an umlaut, a space or one of ``,.!?`` is replaced by a space.

    Umlauts are skipped character by character. The alternative is to swap
    them for padded placeholders such as " ae " and swap those back after
    transliteration. That round trip would also convert any " ae "/" AE "
    sequence already in the text, or produced by unidecode
    (" æ " -> " ae "), into an umlaut.

    Args:
        sentence: raw transcript string.

    Returns:
        The cleaned string; letter case is preserved.
    """
    keep = {umlaut for umlaut, _ in special_chars}
    sentence = "".join(ch if ch in keep else unidecode(ch) for ch in sentence)
    return re.sub(r"[^a-zA-Z0-9öäüÖÄÜ ,.!?]", " ", sentence)

def _format_scores(n, wer_total, cer_total):
    """Return a ``WER: ... | CER: ...`` line with the mean scores in percent.

    ``wer_total`` and ``cer_total`` are sums of per-utterance scores over ``n``
    utterances. When ``n == 0`` there is nothing to average, so a placeholder
    line is returned instead of raising ZeroDivisionError.
    """
    if n == 0:
        return "WER: n/a | CER: n/a (no samples evaluated)"
    return f"WER: {(wer_total/n)*100} | CER: {(cer_total/n)*100}"


def main(model_id):
    """Evaluate ``model_id`` on the German Common Voice 8.0 test split.

    Loads the CTC model and its processor from ``model_id`` and transcribes
    every test utterance. It prints the running mean of per-utterance WER and
    CER every 100 utterances, and once more at the end. This is an average of
    per-utterance scores, not corpus-level WER/CER.

    The module-level ``counter``, ``wer_counter`` and ``cer_counter`` are reset
    at the start and hold the final totals after the call.

    Args:
        model_id: repository id passed to ``from_pretrained`` for both the
            model and the processor.
    """
    global counter, wer_counter, cer_counter
    # Start from zero so a second call does not mix in a previous run's totals.
    counter = wer_counter = cer_counter = 0

    print("load model")
    model = AutoModelForCTC.from_pretrained(model_id).to(device)
    print("load processor")
    # The processor is loaded from the same repository as the model.
    processor = AutoProcessor.from_pretrained(model_id)

    print("load metrics")
    wer = load_metric("wer")
    cer = load_metric("cer")

    ds = load_dataset("mozilla-foundation/common_voice_8_0", "de")
    ds = ds["test"]

    # Decode the audio at 16 kHz, the sampling rate given to the processor below.
    ds = ds.cast_column(
        "audio", datasets.features.Audio(sampling_rate=16_000)
    )

    # Plain loop instead of ``ds.map``. The work here is purely for its side
    # effects, and ``map`` may return a cached result without calling the
    # function at all, which would leave the counters at zero.
    for sample in ds:
        input_values = processor(
            sample["audio"]["array"], return_tensors="pt", sampling_rate=16_000
        ).input_values

        with torch.no_grad():
            # [0]: take the single utterance out of the batch.
            logits = model(input_values.to(device)).logits.cpu().numpy()[0]

        # decode() on raw logits returning an object with ``.text`` presumably
        # relies on the processor bundling an LM decoder. TODO: confirm.
        pred = processor.decode(logits).text.lower()
        ref = clean_text(sample["sentence"]).lower()

        counter += 1
        wer_counter += wer.compute(predictions=[pred], references=[ref])
        cer_counter += cer.compute(predictions=[pred], references=[ref])

        # Progress report every 100 utterances.
        if counter % 100 == 0:
            print(_format_scores(counter, wer_counter, cer_counter))

    print(_format_scores(counter, wer_counter, cer_counter))

model_id = "flozi00/wav2vec2-xls-r-1b-5gram-german"

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main(model_id)