m3hrdadfi committed on
Commit
93accd5
1 Parent(s): 614aca0

Initial model

Browse files
README.md ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: fa
3
+ datasets:
4
+ - common_voice
5
+ tags:
6
+ - audio
7
+ - automatic-speech-recognition
8
+ - speech
9
+ - xlsr-fine-tuning-week
10
+ license: apache-2.0
11
+ widget:
12
+ - label: Common Voice sample 687
13
+ src: https://huggingface.co/m3hrdadfi/wav2vec2-large-xlsr-persian/resolve/main/sample687.flac
14
+ - label: Common Voice sample 1671
15
+ src: https://huggingface.co/m3hrdadfi/wav2vec2-large-xlsr-persian/resolve/main/sample1671.flac
16
+ model-index:
17
+ - name: XLSR Wav2Vec2 Persian (Farsi) by Mehrdad Farahani
18
+ results:
19
+ - task:
20
+ name: Speech Recognition
21
+ type: automatic-speech-recognition
22
+ dataset:
23
+ name: Common Voice fa
24
+ type: common_voice
25
+ args: fa
26
+ metrics:
27
+ - name: Test WER
28
+ type: wer
29
+ value: 32.09
30
+ - name: Test CER
31
+ type: cer
32
+ value: 8.23
33
+
34
+ ---
35
+
36
+ # Wav2Vec2-Large-XLSR-53-Persian
37
+ Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Persian (Farsi) using [Common Voice](https://huggingface.co/datasets/common_voice). When using this model, make sure that your speech input is sampled at 16kHz.
38
+
39
+ ## Usage
40
+ The model can be used directly (without a language model) as follows:
41
+
42
+ ```bash
43
+ !pip install git+https://github.com/huggingface/datasets.git
44
+ !pip install git+https://github.com/huggingface/transformers.git
45
+ !pip install torchaudio
46
+ !pip install librosa
47
+ !pip install jiwer
48
+ !pip install hazm
49
+ ```
50
+
51
+ ```python
52
+ import torch
53
+ import torchaudio
54
+ from datasets import load_dataset, load_metric
55
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
56
+
57
+ import librosa
58
+
59
+ import pandas as pd
60
+ import numpy as np
61
+
62
+ import hazm
63
+
64
+ import random
65
+ import os
66
+ import string
67
+ import six
68
+ import re
69
+
70
+ import IPython.display as ipd
71
+
72
+ # Loading the datasets
73
+ dataset = load_dataset("common_voice", "fa", split="test[:2%]")
74
+
75
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
+ processor = Wav2Vec2Processor.from_pretrained("m3hrdadfi/wav2vec2-large-xlsr-persian")
77
+ model = Wav2Vec2ForCTC.from_pretrained("m3hrdadfi/wav2vec2-large-xlsr-persian").to(device)
78
+
79
+
80
+ # Preprocessing the datasets.
81
+ # Normalizing the texts
82
+ _normalizer = hazm.Normalizer()
83
+ def multiple_replace(mapping, text):
84
+ pattern = "|".join(map(re.escape, mapping.keys()))
85
+ return re.sub(pattern, lambda m: mapping[m.group()], str(text))
86
+
87
+ def convert_weirdos(input_str):
88
+ # character
89
+ mapping = {
90
+ 'ك': 'ک',
91
+ 'دِ': 'د',
92
+ 'بِ': 'ب',
93
+ 'زِ': 'ز',
94
+ 'ذِ': 'ذ',
95
+ 'شِ': 'ش',
96
+ 'سِ': 'س',
97
+ 'ى': 'ی',
98
+ 'ي': 'ی',
99
+ 'أ': 'ا',
100
+ 'ؤ': 'و',
101
+ "ے": "ی",
102
+ "ۀ": "ه",
103
+ "ﭘ": "پ",
104
+ "ﮐ": "ک",
105
+ "ﯽ": "ی",
106
+ "ﺎ": "ا",
107
+ "ﺑ": "ب",
108
+ "ﺘ": "ت",
109
+ "ﺧ": "خ",
110
+ "ﺩ": "د",
111
+ "ﺱ": "س",
112
+ "ﻀ": "ض",
113
+ "ﻌ": "ع",
114
+ "ﻟ": "ل",
115
+ "ﻡ": "م",
116
+ "ﻢ": "م",
117
+ "ﻪ": "ه",
118
+ "ﻮ": "و",
119
+ "ئ": "ی",
120
+ 'ﺍ': "ا",
121
+ 'ة': "ه",
122
+ 'ﯾ': "ی",
123
+ 'ﯿ': "ی",
124
+ 'ﺒ': "ب",
125
+ 'ﺖ': "ت",
126
+ 'ﺪ': "د",
127
+ 'ﺮ': "ر",
128
+ 'ﺴ': "س",
129
+ 'ﺷ': "ش",
130
+ 'ﺸ': "ش",
131
+ 'ﻋ': "ع",
132
+ 'ﻤ': "م",
133
+ 'ﻥ': "ن",
134
+ 'ﻧ': "ن",
135
+ 'ﻭ': "و",
136
+ 'ﺭ': "ر",
137
+ "ﮔ": "گ",
138
+ }
139
+
140
+ # notation
141
+ mapping.update(**{
142
+ "#": " ",
143
+ "!": " ",
144
+ "؟": " ",
145
+ "?": " ",
146
+ "«": " ",
147
+ "»": " ",
148
+ "ء": " ",
149
+ "،": " ",
150
+ "(": " ",
151
+ ")": " ",
152
+ "؛": " ",
153
+ "'ٔ": " ",
154
+ "٬": " ",
155
+ 'ٔ': " ",
156
+ ",": " ",
157
+ "?": " ",
158
+ ".": " ",
159
+ "!": " ",
160
+ "-": " ",
161
+ ";": " ",
162
+ ":": " ",
163
+ '"': " ",
164
+ "“": " ",
165
+ "%": " ",
166
+ "‘": " ",
167
+ "”": " ",
168
+ "�": " ",
169
+ "–": " ",
170
+ "…": " ",
171
+ "_": " ",
172
+ })
173
+
174
+ return multiple_replace(mapping, input_str)
175
+
176
+
177
+ PERSIAN_ALPHA = "\u0621-\u0628\u062A-\u063A\u0641-\u0642\u0644-\u0648\u064E-\u0651\u0655\u067E\u0686\u0698\u06A9\u06AF\u06BE\u06CC" # noqa: E501
178
+ PERSIAN_DIGIT = "\u06F0-\u06F9"
179
+
180
+ COMMON_ARABIC_ALPHA = "\u0629\u0643\u0649-\u064B\u064D\u06D5"
181
+ COMMON_ARABIC_DIGIT = "\u0660-\u0669"
182
+
183
+ ZWNJ = "\u200c"
184
+
185
+ ENGLISH = "a-z0-9\&"
186
+ PERSIAN = PERSIAN_ALPHA + PERSIAN_DIGIT + COMMON_ARABIC_ALPHA + COMMON_ARABIC_DIGIT + ZWNJ
187
+
188
+
189
+ def normalizer(text, min_ratio=1.1):
190
+ text = text.lower()
191
+ text = _normalizer.normalize(text)
192
+ text = text.replace("\u200c", " ")
193
+ text = text.replace("\u200d", " ")
194
+ text = text.replace("\u200e", " ")
195
+ text = text.replace("\u200f", " ")
196
+ text = text.replace("\ufeff", " ")
197
+ text = convert_weirdos(text)
198
+
199
+ words = [word.replace("آ", "ا") if "آ" in word and not word.startswith("آ") else word for word in text.split()]
200
+ text = " ".join(words)
201
+
202
+ if not text or not len(text) > 2:
203
+ return None
204
+
205
+ en_text = re.sub(r"[^" + ENGLISH + "+]", " ", six.ensure_str(text))
206
+ en_text = re.sub(r"\s+", " ", en_text)
207
+ if len(en_text) > 1:
208
+ return None
209
+
210
+ return text
211
+
212
+
213
+ chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�]'
214
+ def remove_special_characters(batch):
215
+ text = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
216
+ text = normalizer(text)
217
+ batch["sentence"] = text
218
+ return batch
219
+
220
+ # We need to read the audio files as arrays
221
+ def speech_file_to_array_fn(batch):
222
+ speech_array, sampling_rate = torchaudio.load(batch["path"])
223
+ speech_array = speech_array.squeeze().numpy()
224
+ speech_array = librosa.resample(np.asarray(speech_array), sampling_rate, 16_000)
225
+
226
+ batch["speech"] = speech_array
227
+ return batch
228
+
229
+ def predict(batch):
230
+ features = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
231
+
232
+ input_values = features.input_values.to(device)
233
+ attention_mask = features.attention_mask.to(device)
234
+
235
+ with torch.no_grad():
236
+ logits = model(input_values, attention_mask=attention_mask).logits
237
+
238
+ pred_ids = torch.argmax(logits, dim=-1)
239
+
240
+ batch["predicted"] = processor.batch_decode(pred_ids)[0]
241
+ return batch
242
+
243
+ dataset = dataset.map(remove_special_characters)
244
+ dataset = dataset.map(speech_file_to_array_fn, remove_columns=list(set(dataset.column_names) - set(['sentence', 'path'])))
245
+ result = dataset.map(predict)
246
+ ```
247
+
248
+ ## Prediction
249
+
250
+ ```python
251
+ max_items = np.random.randint(0, len(result), 20).tolist()
252
+ for i in max_items:
253
+ reference, predicted = result["sentence"][i], result["predicted"][i]
254
+ print("reference:", reference)
255
+ print("predicted:", predicted)
256
+ print('---')
257
+ ```
258
+
259
+ ```text
260
+ reference: اطلاعات مسری است
261
+ predicted: اطلاعات مسری است
262
+ ---
263
+ reference: نه منظورم اینه که وقتی که ساکته چه کاریه خودمونه بندازیم زحمت
264
+ predicted: نه منظورم اینه که وقتی که ساکت چی کاریه خودمونو بندازیم زحمت
265
+ ---
266
+ reference: من آب پرتقال می خورم لطفا
267
+ predicted: من آپ ارتغال می خورم لطفا
268
+ ---
269
+ reference: وقت آن رسیده آنها را که قدم پیش میگذارند بزرگ بداریم
270
+ predicted: وقت آ رسیده آنها را که قدم پیش میگذارند بزرگ بداریم
271
+ ---
272
+ reference: سیم باتری دارید
273
+ predicted: سیم باتری دارید
274
+ ---
275
+ reference: این بهتره تا اینکه به بهونه درس و مشق هر روز بره خونه شون
276
+ predicted: این بهتره تا اینکه به بهمونه درسومش خرروز بره خونه اشون
277
+ ---
278
+ reference: ژاکت تنگ است
279
+ predicted: ژاکت تنگ است
280
+ ---
281
+ reference: آت و اشغال های خیابان
282
+ predicted: آت و اشغال های خیابان
283
+ ---
284
+ reference: من به این روند اعتراض دارم
285
+ predicted: من به این لوند تراج دارم
286
+ ---
287
+ reference: کرایه این مکان چند است
288
+ predicted: کرایه این مکان چند است
289
+ ---
290
+ reference: ولی این فرصت این سهم جوانی اعطا نشده است
291
+ predicted: ولی این فرصت این سحم جوانی اتان نشده است
292
+ ---
293
+ reference: متوجه فاجعهای محیطی میشوم
294
+ predicted: متوجه فاجایهای محیطی میشوم
295
+ ---
296
+ reference: ترافیک شدیدیم بود و دیدن نور ماشینا و چراغا و لامپهای مراکز تجاری حس خوبی بهم میدادن
297
+ predicted: ترافیک شدید ی هم بودا دیدن نور ماشینا و چراغ لامپهای مراکز تجاری حس خولی بهم میدادن
298
+ ---
299
+ reference: این مورد عمل ها مربوط به تخصص شما می شود
300
+ predicted: این مورد عملها مربوط به تخصص شما میشود
301
+ ---
302
+ reference: انرژی خیلی کمی دارم
303
+ predicted: انرژی خیلی کمی دارم
304
+ ---
305
+ reference: زیادی خوبی کردنم تهش داستانه
306
+ predicted: زیادی خوبی کردنم ترش داستانه
307
+ ---
308
+ reference: بردهای که پادشاه شود
309
+ predicted: برده ای که پاده شاه شود
310
+ ---
311
+ reference: یونسکو
312
+ predicted: یونسکو
313
+ ---
314
+ reference: شما اخراج هستید
315
+ predicted: شما اخراج هستید
316
+ ---
317
+ reference: من سفر کردن را دوست دارم
318
+ predicted: من سفر کردم را دوست دارم
319
+ ```
320
+
321
+ ## Evaluation
322
+
323
+ ```python
324
+ !mkdir cer
325
+ !wget -O cer/cer.py https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese/raw/main/cer.py
326
+
327
+ wer = load_metric("wer")
328
+ cer = load_metric("./cer")
329
+
330
+ print("WER: {:2f}".format(100 * wer.compute(predictions=result["predicted"], references=result["sentence"])))
331
+ print("CER: {:2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["sentence"])))
332
+ ```
333
+
334
+ **Test Result**:
335
+ - WER: 32.09%
336
+ - CER: 8.23%
337
+
338
+
339
+ ## Training
340
+ The Common Voice `train` and `validation` splits were used for training.
341
+ The script used for training can be found [here](https://colab.research.google.com/github/m3hrdadfi/notebooks/blob/main/Fine_Tune_XLSR_Wav2Vec2_on_Persian_ASR_with_%F0%9F%A4%97_Transformers_ipynb.ipynb)
config.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "apply_spec_augment": true,
4
+ "architectures": [
5
+ "Wav2Vec2ForCTC"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "bos_token_id": 1,
9
+ "conv_bias": true,
10
+ "conv_dim": [
11
+ 512,
12
+ 512,
13
+ 512,
14
+ 512,
15
+ 512,
16
+ 512,
17
+ 512
18
+ ],
19
+ "conv_kernel": [
20
+ 10,
21
+ 3,
22
+ 3,
23
+ 3,
24
+ 3,
25
+ 2,
26
+ 2
27
+ ],
28
+ "conv_stride": [
29
+ 5,
30
+ 2,
31
+ 2,
32
+ 2,
33
+ 2,
34
+ 2,
35
+ 2
36
+ ],
37
+ "ctc_loss_reduction": "mean",
38
+ "ctc_zero_infinity": false,
39
+ "do_stable_layer_norm": true,
40
+ "eos_token_id": 2,
41
+ "feat_extract_activation": "gelu",
42
+ "feat_extract_dropout": 0.0,
43
+ "feat_extract_norm": "layer",
44
+ "feat_proj_dropout": 0.0,
45
+ "final_dropout": 0.0,
46
+ "gradient_checkpointing": true,
47
+ "hidden_act": "gelu",
48
+ "hidden_dropout": 0.1,
49
+ "hidden_size": 1024,
50
+ "initializer_range": 0.02,
51
+ "intermediate_size": 4096,
52
+ "layer_norm_eps": 1e-05,
53
+ "layerdrop": 0.1,
54
+ "mask_channel_length": 10,
55
+ "mask_channel_min_space": 1,
56
+ "mask_channel_other": 0.0,
57
+ "mask_channel_prob": 0.0,
58
+ "mask_channel_selection": "static",
59
+ "mask_feature_length": 10,
60
+ "mask_feature_prob": 0.0,
61
+ "mask_time_length": 10,
62
+ "mask_time_min_space": 1,
63
+ "mask_time_other": 0.0,
64
+ "mask_time_prob": 0.05,
65
+ "mask_time_selection": "static",
66
+ "model_type": "wav2vec2",
67
+ "num_attention_heads": 16,
68
+ "num_conv_pos_embedding_groups": 16,
69
+ "num_conv_pos_embeddings": 128,
70
+ "num_feat_extract_layers": 7,
71
+ "num_hidden_layers": 24,
72
+ "pad_token_id": 35,
73
+ "transformers_version": "4.5.0.dev0",
74
+ "vocab_size": 36
75
+ }
preprocessor_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_size": 1,
4
+ "padding_side": "right",
5
+ "padding_value": 0.0,
6
+ "return_attention_mask": true,
7
+ "sampling_rate": 16000
8
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c9497f2383df9550e1f3310265224e7bc7e994c0ec844aac51fca5d8e9483b4
3
+ size 1262081431
sample1671.flac ADDED
Binary file (169 kB). View file
sample687.flac ADDED
Binary file (103 kB). View file
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]"}
test_predicted.csv ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|"}
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbcb369f506a36d7e8b81d8831323d570a47695ce0b9cb55afbf5dff6f84f5ff
3
+ size 2351
vocab.json ADDED
@@ -0,0 +1 @@
 
1
+ {"ت": 0, "گ": 1, "ب": 2, "ژ": 3, "ع": 4, "ذ": 5, "چ": 6, "ج": 7, "خ": 8, "ا": 9, "د": 10, "ن": 11, "ح": 12, "آ": 13, "غ": 14, "م": 15, "ص": 16, "ر": 17, "پ": 18, "ظ": 19, "ض": 20, "ه": 21, "ق": 23, "ک": 24, "ش": 25, "ط": 26, "ف": 27, "ی": 28, "ز": 29, "و": 30, "ل": 31, "س": 32, "ث": 33, "|": 22, "[UNK]": 34, "[PAD]": 35}