fxtentacle commited on
Commit
0fb6288
1 Parent(s): 6ac607d

Upload HF Eval Script.ipynb

Browse files
Files changed (1) hide show
  1. HF Eval Script.ipynb +189 -0
HF Eval Script.ipynb ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "be1c7379",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "!pip install --quiet --root-user-action=ignore --upgrade pip\n",
11
+ "!pip install --quiet --root-user-action=ignore \"datasets>=1.18.3\" \"transformers==4.11.3\" librosa jiwer huggingface_hub \n",
12
+ "!pip install --quiet --root-user-action=ignore https://github.com/kpu/kenlm/archive/master.zip pyctcdecode\n",
13
+ "!pip install --quiet --root-user-action=ignore --upgrade transformers\n",
14
+ "!pip install --quiet --root-user-action=ignore torch_audiomentations audiomentations "
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "id": "8892305a",
21
+ "metadata": {},
22
+ "outputs": [
23
+ {
24
+ "name": "stderr",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "Reusing dataset common_voice (/ai_data/cache/common_voice/de/6.1.0/a1dc74461f6c839bfe1e8cf1262fd4cf24297e3fbd4087a711bd090779023a5e)\n",
28
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
29
+ ]
30
+ },
31
+ {
32
+ "data": {
33
+ "application/vnd.jupyter.widget-view+json": {
34
+ "model_id": "efc316d6eedb4dcab341dfd0fe8cc926",
35
+ "version_major": 2,
36
+ "version_minor": 0
37
+ },
38
+ "text/plain": [
39
+ " 0%| | 0/15588 [00:00<?, ?ex/s]"
40
+ ]
41
+ },
42
+ "metadata": {},
43
+ "output_type": "display_data"
44
+ }
45
+ ],
46
+ "source": [
47
+ "from datasets import load_dataset, Audio, load_metric\n",
48
+ "from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM\n",
49
+ "import torchaudio.transforms as T\n",
50
+ "import torch\n",
51
+ "import unicodedata\n",
52
+ "import numpy as np\n",
53
+ "import re\n",
54
+ "\n",
55
+ "# load testing dataset \n",
56
+ "testing_dataset = load_dataset(\"common_voice\", \"de\", split=\"test\")\n",
57
+ "\n",
58
+ "# replace invisible characters with space\n",
59
+ "allchars = list(set([c for t in testing_dataset['sentence'] for c in list(t)]))\n",
60
+ "map_to_space = [c for c in allchars if unicodedata.category(c)[0] in 'PSZ' and c not in 'ʻ-']\n",
61
+ "replacements = ''.maketrans(''.join(map_to_space), ''.join(' ' for i in range(len(map_to_space))), '\\'ʻ')\n",
62
+ "\n",
63
+ "def text_fix(text):\n",
64
+ " # change ß to ss\n",
65
+ " text = text.replace('ß','ss')\n",
66
+ " # convert dash to space and remove double-space\n",
67
+ " text = text.replace('-',' ').replace(' ',' ').replace(' ',' ')\n",
68
+ " # make lowercase\n",
69
+ " text = text.lower()\n",
70
+ " # remap all invisible characters to space\n",
71
+ " text = text.translate(replacements).strip()\n",
72
+ " # for easier comparison to Zimmermeister, replace unrepresentable characters with ?\n",
73
+ " text = re.sub(\"[âşěýňעảנźțãòàǔł̇æồאắîשðșęūāñë生בøúıśžçćńřğ]+\",\"?\",text)\n",
74
+ " # remove multiple spaces (again)\n",
75
+ " text = ' '.join([w for w in text.split(' ') if w != ''])\n",
76
+ " return text\n",
77
+ "\n",
78
+ "# load model\n",
79
+ "model = AutoModelForCTC.from_pretrained(\"fxtentacle/wav2vec2-xls-r-1b-tevr\")\n",
80
+ "model.to('cuda')\n",
81
+ "# load processor\n",
82
+ "class HajoProcessor(Wav2Vec2ProcessorWithLM):\n",
83
+ " @staticmethod\n",
84
+ " def get_missing_alphabet_tokens(decoder, tokenizer):\n",
85
+ " return []\n",
86
+ "processor = HajoProcessor.from_pretrained(\"fxtentacle/wav2vec2-xls-r-1b-tevr\")\n",
87
+ "\n",
88
+ "# this function will be called for each WAV file\n",
89
+ "def predict_single_audio(batch, image=False): \n",
90
+ " audio = batch['audio']['array']\n",
91
+ " # resample, if needed\n",
92
+ " if batch['audio']['sampling_rate'] != 16000:\n",
93
+ " audio = T.Resample(orig_freq=batch['audio']['sampling_rate'], new_freq=16000)(torch.from_numpy(audio)).numpy()\n",
94
+ " # normalize\n",
95
+ " audio = (audio - audio.mean()) / np.sqrt(audio.var() + 1e-7)\n",
96
+ " # ask HF processor to prepare audio for GPU eval\n",
97
+ " input_values = processor(audio, return_tensors=\"pt\", sampling_rate=16_000).input_values\n",
98
+ " # call model on GPU\n",
99
+ " with torch.no_grad():\n",
100
+ " logits = model(input_values.to('cuda')).logits.cpu().numpy()[0]\n",
101
+ " # ask HF processor to decode logits\n",
102
+ " decoded = processor.decode(logits, beam_width=500)\n",
103
+ " # return as dictionary\n",
104
+ " return { 'groundtruth': text_fix(batch['sentence']), 'prediction': decoded.text }\n",
105
+ "\n",
106
+ "# process all audio files\n",
107
+ "all_predictions = testing_dataset.map(predict_single_audio, remove_columns=testing_dataset.column_names)"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": 3,
113
+ "id": "38f5481f",
114
+ "metadata": {},
115
+ "outputs": [
116
+ {
117
+ "name": "stdout",
118
+ "output_type": "stream",
119
+ "text": [
120
+ "0 PRED: mückenstiche sollte man nicht aufkratzen\n",
121
+ "0 GT: mückenstiche sollte man nicht aufkratzen\n",
122
+ "1 PRED: ist diese leitung sicher\n",
123
+ "1 GT: ist diese leitung sicher\n",
124
+ "2 PRED: die ratten verlassen das sinkende schiff\n",
125
+ "2 GT: die ratten verlassen das sinkende schiff\n",
126
+ "3 PRED: ich habe eine neue arbeit\n",
127
+ "3 GT: ich habe eine neue arbeit\n",
128
+ "4 PRED: was sieht kamera eins gerade\n",
129
+ "4 GT: was sieht kamera eins gerade\n",
130
+ "5 PRED: was für ein angeber dachte horst im stillen\n",
131
+ "5 GT: was für ein angeber dachte horst im stillen\n",
132
+ "6 PRED: rückgängig machen\n",
133
+ "6 GT: rückgängig machen\n",
134
+ "7 PRED: war die integration erfolgreich\n",
135
+ "7 GT: war die integration erfolgreich\n"
136
+ ]
137
+ }
138
+ ],
139
+ "source": [
140
+ "# log example results\n",
141
+ "for i in range(8):\n",
142
+ " print(i,'PRED: ',all_predictions[i]['prediction'])\n",
143
+ " print(i,' GT: ',all_predictions[i]['groundtruth'])"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 5,
149
+ "id": "cbabe801",
150
+ "metadata": {},
151
+ "outputs": [
152
+ {
153
+ "name": "stdout",
154
+ "output_type": "stream",
155
+ "text": [
156
+ "WER 3.6433399042523233 %\n",
157
+ "CER 1.5398893560981173 %\n"
158
+ ]
159
+ }
160
+ ],
161
+ "source": [
162
+ "# print results\n",
163
+ "print('WER', load_metric(\"wer\").compute(predictions=all_predictions['prediction'], references=all_predictions['groundtruth'])*100.0, '%')\n",
164
+ "print('CER', load_metric(\"cer\").compute(predictions=all_predictions['prediction'], references=all_predictions['groundtruth'])*100.0, '%')"
165
+ ]
166
+ }
167
+ ],
168
+ "metadata": {
169
+ "kernelspec": {
170
+ "display_name": "Python 3 (ipykernel)",
171
+ "language": "python",
172
+ "name": "python3"
173
+ },
174
+ "language_info": {
175
+ "codemirror_mode": {
176
+ "name": "ipython",
177
+ "version": 3
178
+ },
179
+ "file_extension": ".py",
180
+ "mimetype": "text/x-python",
181
+ "name": "python",
182
+ "nbconvert_exporter": "python",
183
+ "pygments_lexer": "ipython3",
184
+ "version": "3.7.5"
185
+ }
186
+ },
187
+ "nbformat": 4,
188
+ "nbformat_minor": 5
189
+ }