In [1]:
# Environment setup: every line shells out to pip in quiet mode.
# Step 1: upgrade pip itself before installing anything else.
!pip install --quiet --root-user-action=ignore --upgrade pip
# Step 2: dataset/model tooling; transformers is requested at exactly 4.11.3 here.
!pip install --quiet --root-user-action=ignore "datasets>=1.18.3" "transformers==4.11.3" librosa jiwer huggingface_hub  
# Step 3: KenLM from the GitHub master archive, plus pyctcdecode.
!pip install --quiet --root-user-action=ignore https://github.com/kpu/kenlm/archive/master.zip pyctcdecode
# NOTE(review): this upgrades transformers again, so the ==4.11.3 pin requested
# in step 2 does not determine the final version — confirm which one is intended.
!pip install --quiet --root-user-action=ignore --upgrade transformers
# Step 5: audio augmentation packages.
# NOTE(review): no visible cell imports these — confirm they are still needed.
!pip install --quiet --root-user-action=ignore torch_audiomentations audiomentations  

In [2]:
from datasets import load_dataset, Audio, load_metric
from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM
import torchaudio.transforms as T
import torch
import unicodedata
import numpy as np
import re

# fetch the German Common Voice test split
testing_dataset = load_dataset("common_voice", "de", split="test")

# Build the translation table used by text_fix():
#   * every distinct character of the reference sentences is collected,
#   * those whose Unicode major category is P, S or Z (except 'ʻ' and '-')
#     are mapped to a plain space,
#   * apostrophe and 'ʻ' are deleted outright (third maketrans argument).
allchars = list(set(ch for sentence in testing_dataset['sentence'] for ch in sentence))
map_to_space = [
    ch for ch in allchars
    if unicodedata.category(ch)[0] in 'PSZ' and ch not in 'ʻ-'
]
replacements = str.maketrans(''.join(map_to_space), ' ' * len(map_to_space), '\'ʻ')

def text_fix(text):
    """Normalize a reference sentence so it can be compared to model output.

    Uses the module-level ``replacements`` translation table.

    Args:
        text: the raw sentence string.

    Returns:
        The normalized string: lowercase, sharp s written as 'ss', dashes and
        the characters covered by ``replacements`` turned into single spaces,
        a fixed set of unrepresentable characters collapsed to '?', and no
        leading, trailing or repeated spaces.
    """
    # make lowercase FIRST: the capital sharp s 'ẞ' lowercases to 'ß', so
    # lowercasing after the substitution below would let it slip through
    text = text.lower()
    # change ß to ss
    text = text.replace('ß','ss')
    # convert dash to space and remove double-space
    text = text.replace('-',' ').replace('  ',' ').replace('  ',' ')
    # remap the characters listed in the translation table to space
    text = text.translate(replacements).strip()
    # for easier comparison to Zimmermeister, replace unrepresentable characters with ?
    text = re.sub("[âşěýňעảנźțãòàǔł̇æồאắîשðșęūāñë生בøúıśžçćńřğ]+","?",text)
    # remove multiple spaces (again)
    text = ' '.join([w for w in text.split(' ') if w != ''])
    return text

# load the pretrained CTC model checkpoint and move it onto the GPU in one step
model = AutoModelForCTC.from_pretrained("fxtentacle/wav2vec2-xls-r-1b-tevr").to('cuda')
# load processor
class HajoProcessor(Wav2Vec2ProcessorWithLM):
    """Wav2Vec2ProcessorWithLM subclass that never reports missing alphabet tokens.

    NOTE(review): presumably this sidesteps a decoder-alphabet vs. tokenizer
    vocabulary check in the parent class — confirm against the transformers source.
    """

    @staticmethod
    def get_missing_alphabet_tokens(decoder, tokenizer):
        """Return an empty list regardless of the decoder and tokenizer passed in."""
        return list()


processor = HajoProcessor.from_pretrained("fxtentacle/wav2vec2-xls-r-1b-tevr")

# this function will be called for each WAV file
def predict_single_audio(batch, image=False):
    """Transcribe one dataset example and pair it with its normalized reference.

    Uses the module-level ``model``, ``processor`` and ``text_fix``.

    Args:
        batch: a single example; reads ``batch['audio']['array']``,
            ``batch['audio']['sampling_rate']`` and ``batch['sentence']``.
        image: unused; kept only so the existing signature stays valid.

    Returns:
        dict with ``'groundtruth'`` (``text_fix`` applied to the sentence) and
        ``'prediction'`` (the text decoded from the model logits).
    """
    target_rate = 16000  # sampling rate handed to the processor
    audio = batch['audio']['array']
    source_rate = batch['audio']['sampling_rate']
    # resample, if needed
    if source_rate != target_rate:
        resampler = T.Resample(orig_freq=source_rate, new_freq=target_rate)
        audio = resampler(torch.from_numpy(audio)).numpy()
    # normalize to zero mean / unit variance (epsilon guards against silent clips)
    audio = (audio - audio.mean()) / np.sqrt(audio.var() + 1e-7)
    # ask HF processor to prepare audio for model eval
    input_values = processor(audio, return_tensors="pt", sampling_rate=target_rate).input_values
    # run the model on whatever device it currently lives on, instead of
    # repeating a hard-coded device string that could drift from the model's
    with torch.no_grad():
        logits = model(input_values.to(model.device)).logits.cpu().numpy()[0]
    # ask HF processor to decode logits
    decoded = processor.decode(logits, beam_width=500)
    # return as dictionary
    return { 'groundtruth': text_fix(batch['sentence']), 'prediction': decoded.text }

# run the predictor over every example, keeping only the columns it returns
all_predictions = testing_dataset.map(
    predict_single_audio,
    remove_columns=testing_dataset.column_names,
)

Reusing dataset common_voice (/ai_data/cache/common_voice/de/6.1.0/a1dc74461f6c839bfe1e8cf1262fd4cf24297e3fbd4087a711bd090779023a5e)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  0%|          | 0/15588 [00:00<?, ?ex/s]

In [3]:
# show the first eight predictions next to their references
for idx in range(8):
    example = all_predictions[idx]
    print(idx,'PRED: ',example['prediction'])
    print(idx,'  GT: ',example['groundtruth'])

0 PRED:  mückenstiche sollte man nicht aufkratzen
0   GT:  mückenstiche sollte man nicht aufkratzen
1 PRED:  ist diese leitung sicher
1   GT:  ist diese leitung sicher
2 PRED:  die ratten verlassen das sinkende schiff
2   GT:  die ratten verlassen das sinkende schiff
3 PRED:  ich habe eine neue arbeit
3   GT:  ich habe eine neue arbeit
4 PRED:  was sieht kamera eins gerade
4   GT:  was sieht kamera eins gerade
5 PRED:  was für ein angeber dachte horst im stillen
5   GT:  was für ein angeber dachte horst im stillen
6 PRED:  rückgängig machen
6   GT:  rückgängig machen
7 PRED:  war die integration erfolgreich
7   GT:  war die integration erfolgreich


In [5]:
# report word and character error rate as percentages
predicted_texts = all_predictions['prediction']
reference_texts = all_predictions['groundtruth']
for label, metric_name in (('WER', "wer"), ('CER', "cer")):
    score = load_metric(metric_name).compute(predictions=predicted_texts, references=reference_texts)
    print(label, score*100.0, '%')

WER 3.6433399042523233 %
CER 1.5398893560981173 %
