---
language: ar
datasets:
- arabic_speech_corpus
- mozilla-foundation/common_voice_6_1
metrics:
- wer
tags:
- audio
- automatic-speech-recognition
- speech
- xlsr-fine-tuning-week
license: apache-2.0
model-index:
- name: muzamil47-wav2vec2-large-xlsr-53-arabic
  results:
  - task:
      name: Automatic Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Common Voice 6.1 (Arabic)
      type: mozilla-foundation/common_voice_6_1
      config: ar
    metrics:
    - name: Test WER
      type: wer
      value: 53.54
---

# Wav2Vec2-Large-XLSR-53-Arabic

Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Arabic using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset. When using this model, make sure that your speech input is sampled at 16 kHz.
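
If your audio is at a different sampling rate, resample it to 16 kHz first. A minimal sketch using `librosa` (the file name is a placeholder):

```python
import librosa

# librosa resamples on load when sr is given; "audio.wav" is a placeholder path.
speech, sampling_rate = librosa.load("audio.wav", sr=16_000)
```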

## Usage

The model can be used directly (without a language model) as follows:

```python
import librosa
import torch
from lang_trans.arabic import buckwalter
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

asr_model = "muzamil47/wav2vec2-large-xlsr-53-arabic-demo"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = Wav2Vec2Processor.from_pretrained(asr_model)
model = Wav2Vec2ForCTC.from_pretrained(asr_model).to(device)


def load_file_to_data(file, srate=16_000):
    # Load the audio file, resampling it to the 16 kHz the model expects.
    batch = {}
    speech, sampling_rate = librosa.load(file, sr=srate)
    batch["speech"] = speech
    batch["sampling_rate"] = sampling_rate
    return batch


def predict(data):
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], return_tensors="pt", padding=True)
    input_values = features.input_values.to(device)
    # Not every feature extractor returns an attention mask.
    attention_mask = features.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    with torch.no_grad():
        predicted = torch.argmax(model(input_values, attention_mask=attention_mask).logits, dim=-1)
    data["predicted"] = processor.tokenizer.decode(predicted[0])
    print(data["predicted"])  # raw model output in Buckwalter transliteration
    print("predicted:", buckwalter.untrans(data["predicted"]))
    return data


predict(load_file_to_data("common_voice_ar_19058307.mp3"))
```
**Output Result** (`reference` is the ground-truth transcript, shown for comparison):
```shell
reference: هل يمكنني التحدث مع المسؤول هنا
predicted: هل يمكنني التحدث مع المسؤول هنا
```
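
The model's vocabulary is Buckwalter transliteration rather than Arabic script, which is why the snippet above converts the decoded string with `buckwalter.untrans`. A minimal round-trip sketch, assuming the `lang-trans` package's `trans` helper mirrors the `untrans` used above:

```python
from lang_trans.arabic import buckwalter

ascii_form = buckwalter.trans("مرحبا")        # Arabic script -> Buckwalter ASCII
arabic_form = buckwalter.untrans(ascii_form)  # Buckwalter ASCII -> Arabic script
print(ascii_form, arabic_form)
```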

## Evaluation

The model can be evaluated as follows on the Arabic test data of Common Voice (here, only the first ten samples):

```python
import torch
import torchaudio
from datasets import load_dataset
from lang_trans.arabic import buckwalter
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

asr_model = "muzamil47/wav2vec2-large-xlsr-53-arabic-demo"

dataset = load_dataset("common_voice", "ar", split="test[:10]")

resamplers = {  # all three sampling rates occur in the test split
    48000: torchaudio.transforms.Resample(48000, 16000),
    44100: torchaudio.transforms.Resample(44100, 16000),
    32000: torchaudio.transforms.Resample(32000, 16000),
}


def prepare_example(example):
    speech, sampling_rate = torchaudio.load(example["path"])
    example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
    return example


dataset = dataset.map(prepare_example)
processor = Wav2Vec2Processor.from_pretrained(asr_model)
model = Wav2Vec2ForCTC.from_pretrained(asr_model).eval()


def predict(batch):
    inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        predicted = torch.argmax(model(inputs.input_values).logits, dim=-1)
    predicted[predicted == -100] = processor.tokenizer.pad_token_id  # see fine-tuning script
    batch["predicted"] = processor.tokenizer.batch_decode(predicted)
    return batch


dataset = dataset.map(predict, batched=True, batch_size=1, remove_columns=["speech"])

for reference, predicted in zip(dataset["sentence"], dataset["predicted"]):
    print("reference:", reference)
    print("predicted:", buckwalter.untrans(predicted))
    print("--")
```
**Output Results**:
```shell
reference: ما أطول عودك!
predicted: ما اطول عودك

reference: ماتت عمتي منذ سنتين.
predicted: ما تتعمتي منذو سنتين

reference: الألمانية ليست لغة سهلة.
predicted: الالمانية ليست لغة سهلة

reference: طلبت منه أن يبعث الكتاب إلينا.
predicted: طلبت منه ان يبعث الكتاب الينا

reference: .السيد إيتو رجل متعلم
predicted: السيد ايتو رجل متعلم

reference: الحمد لله.
predicted: الحمذ لللا

reference: في الوقت نفسه بدأت الرماح والسهام تقع بين الغزاة
predicted: في الوقت نفسه ابدات الرماح و السهام تقع بين الغزاء

reference: لا أريد أن أكون ثقيلَ الظِّل ، أريد أن أكون رائعًا! !
predicted: لا اريد ان اكون ثقيل الظل اريد ان اكون رائع

reference: خذ مظلة معك في حال أمطرت.
predicted: خذ مظلة معك في حال امطرت

reference: .ركب توم السيارة
predicted: ركب توم السيارة
```

As the samples above show, the predictions drop punctuation and diacritics and normalize alif/hamza variants, so the evaluation below applies the same normalization to the reference sentences before scoring.

To compute the model's word error rate (**WER**) on the full Arabic test data of Common Voice:

```python
import re

import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import set_seed, Wav2Vec2ForCTC, Wav2Vec2Processor

set_seed(42)

test_dataset = load_dataset("common_voice", "ar", split="test")

processor = Wav2Vec2Processor.from_pretrained("muzamil47/wav2vec2-large-xlsr-53-arabic-demo")
model = Wav2Vec2ForCTC.from_pretrained("muzamil47/wav2vec2-large-xlsr-53-arabic-demo")
model.to("cuda")

chars_to_ignore_regex = '[\,\؟\.\!\-\;\\:\'\"\☭\«\»\؛\—\ـ\_\،\“\%\‘\”\�]'

resampler = torchaudio.transforms.Resample(48_000, 16_000)


# Preprocessing the datasets: read the audio files as arrays and normalize
# the reference sentences (strip punctuation, Latin letters, diacritics,
# and alif/hamza variants) to match the model's output.
def speech_file_to_array_fn(batch):
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
    batch["sentence"] = re.sub('[a-z]', '', batch["sentence"])
    batch["sentence"] = re.sub("[إأٱآا]", "ا", batch["sentence"])
    noise = re.compile(""" ّ | # Tashdid
                           َ | # Fatha
                           ً | # Tanwin Fath
                           ُ | # Damma
                           ٌ | # Tanwin Damm
                           ِ | # Kasra
                           ٍ | # Tanwin Kasr
                           ْ | # Sukun
                           ـ   # Tatwil/Kashida
                       """, re.VERBOSE)
    batch["sentence"] = re.sub(noise, '', batch["sentence"])
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch


test_dataset = test_dataset.map(speech_file_to_array_fn)


def evaluate(batch):
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch


result = test_dataset.map(evaluate, batched=True, batch_size=8)

wer = load_metric("wer")
print("WER: {:.2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
```

**Test Result**: 53.54% WER
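
Recent versions of the `datasets` library have removed `load_metric`; if yours has, the same WER computation works through the `evaluate` library instead (a minimal sketch, assuming `evaluate` and its `jiwer` dependency are installed):

```python
import evaluate  # pip install evaluate jiwer

# Drop-in replacement for datasets.load_metric("wer").
wer = evaluate.load("wer")
print("WER: {:.2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
```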

## Training

The Common Voice `train` and `validation` splits were used for training; the sketch below shows how to load them.
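
A minimal sketch of loading those splits with `datasets` (the combined split expression is an assumption about how they were concatenated):

```python
from datasets import load_dataset

# Load the Arabic train and validation splits used for fine-tuning.
train_data = load_dataset("common_voice", "ar", split="train+validation")
```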

The script used for training can be found [here](https://huggingface.co/kmfoda/wav2vec2-large-xlsr-arabic/tree/main).