Update readme for more info
Browse files
README.md
CHANGED
@@ -37,10 +37,12 @@ model-index:
|
|
37 |
|
38 |
Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) in Persian (Farsi) using [Common Voice](https://huggingface.co/datasets/common_voice). When using this model, make sure that your speech input is sampled at 16kHz.
|
39 |
|
40 |
-
##
|
41 |
The model can be used directly (without a language model) as follows:
|
42 |
|
|
|
43 |
```bash
|
|
|
44 |
!pip install git+https://github.com/huggingface/datasets.git
|
45 |
!pip install git+https://github.com/huggingface/transformers.git
|
46 |
!pip install torchaudio
|
@@ -49,176 +51,88 @@ The model can be used directly (without a language model) as follows:
|
|
49 |
!pip install hazm
|
50 |
```
|
51 |
|
52 |
-
|
53 |
-
import torch
|
54 |
-
import torchaudio
|
55 |
-
from datasets import load_dataset, load_metric
|
56 |
-
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
import numpy as np
|
62 |
|
63 |
import hazm
|
64 |
-
|
65 |
-
import random
|
66 |
-
import os
|
67 |
-
import string
|
68 |
-
import six
|
69 |
import re
|
|
|
70 |
|
71 |
-
import IPython.display as ipd
|
72 |
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
|
|
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
# Normalizing the texts
|
83 |
-
_normalizer = hazm.Normalizer()
|
84 |
-
def multiple_replace(mapping, text):
|
85 |
-
pattern = "|".join(map(re.escape, mapping.keys()))
|
86 |
-
return re.sub(pattern, lambda m: mapping[m.group()], str(text))
|
87 |
-
|
88 |
-
def convert_weirdos(input_str):
|
89 |
-
# character
|
90 |
-
mapping = {
|
91 |
-
'ك': 'ک',
|
92 |
-
'دِ': 'د',
|
93 |
-
'بِ': 'ب',
|
94 |
-
'زِ': 'ز',
|
95 |
-
'ذِ': 'ذ',
|
96 |
-
'شِ': 'ش',
|
97 |
-
'سِ': 'س',
|
98 |
-
'ى': 'ی',
|
99 |
-
'ي': 'ی',
|
100 |
-
'أ': 'ا',
|
101 |
-
'ؤ': 'و',
|
102 |
-
"ے": "ی",
|
103 |
-
"ۀ": "ه",
|
104 |
-
"ﭘ": "پ",
|
105 |
-
"ﮐ": "ک",
|
106 |
-
"ﯽ": "ی",
|
107 |
-
"ﺎ": "ا",
|
108 |
-
"ﺑ": "ب",
|
109 |
-
"ﺘ": "ت",
|
110 |
-
"ﺧ": "خ",
|
111 |
-
"ﺩ": "د",
|
112 |
-
"ﺱ": "س",
|
113 |
-
"ﻀ": "ض",
|
114 |
-
"ﻌ": "ع",
|
115 |
-
"ﻟ": "ل",
|
116 |
-
"ﻡ": "م",
|
117 |
-
"ﻢ": "م",
|
118 |
-
"ﻪ": "ه",
|
119 |
-
"ﻮ": "و",
|
120 |
-
"ئ": "ی",
|
121 |
-
'ﺍ': "ا",
|
122 |
-
'ة': "ه",
|
123 |
-
'ﯾ': "ی",
|
124 |
-
'ﯿ': "ی",
|
125 |
-
'ﺒ': "ب",
|
126 |
-
'ﺖ': "ت",
|
127 |
-
'ﺪ': "د",
|
128 |
-
'ﺮ': "ر",
|
129 |
-
'ﺴ': "س",
|
130 |
-
'ﺷ': "ش",
|
131 |
-
'ﺸ': "ش",
|
132 |
-
'ﻋ': "ع",
|
133 |
-
'ﻤ': "م",
|
134 |
-
'ﻥ': "ن",
|
135 |
-
'ﻧ': "ن",
|
136 |
-
'ﻭ': "و",
|
137 |
-
'ﺭ': "ر",
|
138 |
-
"ﮔ": "گ",
|
139 |
-
}
|
140 |
-
|
141 |
-
# notation
|
142 |
-
mapping.update(**{
|
143 |
-
"#": " ",
|
144 |
-
"!": " ",
|
145 |
-
"؟": " ",
|
146 |
-
"?": " ",
|
147 |
-
"«": " ",
|
148 |
-
"»": " ",
|
149 |
-
"ء": " ",
|
150 |
-
"،": " ",
|
151 |
-
"(": " ",
|
152 |
-
")": " ",
|
153 |
-
"؛": " ",
|
154 |
-
"'ٔ": " ",
|
155 |
-
"٬": " ",
|
156 |
-
'ٔ': " ",
|
157 |
-
",": " ",
|
158 |
-
"?": " ",
|
159 |
-
".": " ",
|
160 |
-
"!": " ",
|
161 |
-
"-": " ",
|
162 |
-
";": " ",
|
163 |
-
":": " ",
|
164 |
-
'"': " ",
|
165 |
-
"“": " ",
|
166 |
-
"%": " ",
|
167 |
-
"‘": " ",
|
168 |
-
"”": " ",
|
169 |
-
"�": " ",
|
170 |
-
"–": " ",
|
171 |
-
"…": " ",
|
172 |
-
"_": " ",
|
173 |
-
})
|
174 |
-
|
175 |
-
return multiple_replace(mapping, input_str)
|
176 |
-
|
177 |
-
|
178 |
-
PERSIAN_ALPHA = "\u0621-\u0628\u062A-\u063A\u0641-\u0642\u0644-\u0648\u064E-\u0651\u0655\u067E\u0686\u0698\u06A9\u06AF\u06BE\u06CC" # noqa: E501
|
179 |
-
PERSIAN_DIGIT = "\u06F0-\u06F9"
|
180 |
-
|
181 |
-
COMMON_ARABIC_ALPHA = "\u0629\u0643\u0649-\u064B\u064D\u06D5"
|
182 |
-
COMMON_ARABIC_DIGIT = "\u0660-\u0669"
|
183 |
-
|
184 |
-
ZWNJ = "\u200c"
|
185 |
-
|
186 |
-
ENGLISH = "a-z0-9\&"
|
187 |
-
PERSIAN = PERSIAN_ALPHA + PERSIAN_DIGIT + COMMON_ARABIC_ALPHA + COMMON_ARABIC_DIGIT + ZWNJ
|
188 |
-
|
189 |
-
|
190 |
-
def normalizer(text, min_ratio=1.1):
|
191 |
-
text = text.lower()
|
192 |
-
text = _normalizer.normalize(text)
|
193 |
-
text = text.replace("\u200c", " ")
|
194 |
-
text = text.replace("\u200d", " ")
|
195 |
-
text = text.replace("\u200e", " ")
|
196 |
-
text = text.replace("\u200f", " ")
|
197 |
-
text = text.replace("\ufeff", " ")
|
198 |
-
text = convert_weirdos(text)
|
199 |
|
200 |
-
|
201 |
-
|
202 |
|
203 |
-
|
204 |
-
|
|
|
205 |
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
210 |
|
211 |
-
|
212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
|
214 |
-
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�]'
|
215 |
-
def remove_special_characters(batch):
|
216 |
-
text = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
|
217 |
-
text = normalizer(text)
|
218 |
-
batch["sentence"] = text
|
219 |
-
return batch
|
220 |
|
221 |
-
# We need to read the audio files as arrays
|
222 |
def speech_file_to_array_fn(batch):
|
223 |
speech_array, sampling_rate = torchaudio.load(batch["path"])
|
224 |
speech_array = speech_array.squeeze().numpy()
|
@@ -227,6 +141,7 @@ def speech_file_to_array_fn(batch):
|
|
227 |
batch["speech"] = speech_array
|
228 |
return batch
|
229 |
|
|
|
230 |
def predict(batch):
|
231 |
features = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
|
232 |
|
@@ -240,15 +155,24 @@ def predict(batch):
|
|
240 |
|
241 |
batch["predicted"] = processor.batch_decode(pred_ids)[0]
|
242 |
return batch
|
243 |
-
|
244 |
-
dataset = dataset.map(remove_special_characters)
|
245 |
-
dataset = dataset.map(speech_file_to_array_fn, remove_columns=list(set(dataset.column_names) - set(['sentence', 'path'])))
|
246 |
-
result = dataset.map(predict)
|
247 |
```
|
248 |
|
249 |
## Prediction
|
250 |
|
251 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
max_items = np.random.randint(0, len(result), 20).tolist()
|
253 |
for i in max_items:
|
254 |
reference, predicted = result["sentence"][i], result["predicted"][i]
|
@@ -257,6 +181,7 @@ for i in max_items:
|
|
257 |
print('---')
|
258 |
```
|
259 |
|
|
|
260 |
```text
|
261 |
reference: اطلاعات مسری است
|
262 |
predicted: اطلاعات مسری است
|
@@ -321,9 +246,22 @@ predicted: من سفر کردم را دوست دارم
|
|
321 |
|
322 |
## Evaluation
|
323 |
|
|
|
|
|
|
|
|
|
|
|
324 |
```python
|
325 |
-
|
326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
|
328 |
wer = load_metric("wer")
|
329 |
cer = load_metric("./cer")
|
@@ -332,9 +270,11 @@ print("WER: {:.2f}".format(100 * wer.compute(predictions=result["predicted"], re
|
|
332 |
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["sentence"])))
|
333 |
```
|
334 |
|
335 |
-
**
|
336 |
-
|
337 |
-
|
|
|
|
|
338 |
|
339 |
|
340 |
## Training
|
|
|
37 |
|
38 |
Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) in Persian (Farsi) using [Common Voice](https://huggingface.co/datasets/common_voice). When using this model, make sure that your speech input is sampled at 16kHz.
|
39 |
|
40 |
+
## How To Use
|
41 |
The model can be used directly (without a language model) as follows:
|
42 |
|
43 |
+
### Requirements
|
44 |
```bash
|
45 |
+
# required packages
|
46 |
!pip install git+https://github.com/huggingface/datasets.git
|
47 |
!pip install git+https://github.com/huggingface/transformers.git
|
48 |
!pip install torchaudio
|
|
|
51 |
!pip install hazm
|
52 |
```
|
53 |
|
54 |
+
### Preprocessing
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
```python
|
57 |
+
# preprocessing the datasets.
|
58 |
+
# normalizing the texts
|
|
|
59 |
|
60 |
import hazm
|
|
|
|
|
|
|
|
|
|
|
61 |
import re
|
62 |
+
import string
|
63 |
|
|
|
64 |
|
65 |
+
_normalizer = hazm.Normalizer()
|
66 |
+
chars_to_ignore = [
|
67 |
+
",", "?", ".", "!", "-", ";", ":", '""', "%", "'", '"', "�",
|
68 |
+
"#", "!", "؟", "?", "«", "»", "ء", "،", "(", ")", "؛", "'ٔ", "٬",'ٔ', ",", "?",
|
69 |
+
".", "!", "-", ";", ":",'"',"“", "%", "‘", "”", "�", "–", "…", "_", "”", '“', '„'
|
70 |
+
]
|
71 |
+
|
72 |
+
# For Farsi-only transcripts, uncomment to also drop ASCII letters and digits
|
73 |
+
# chars_to_ignore = chars_to_ignore + list(string.ascii_lowercase + string.digits)
|
74 |
+
|
75 |
+
chars_to_mapping = {
|
76 |
+
'ك': 'ک', 'دِ': 'د', 'بِ': 'ب', 'زِ': 'ز', 'ذِ': 'ذ', 'شِ': 'ش', 'سِ': 'س', 'ى': 'ی',
|
77 |
+
'ي': 'ی', 'أ': 'ا', 'ؤ': 'و', "ے": "ی", "ۀ": "ه", "ﭘ": "پ", "ﮐ": "ک", "ﯽ": "ی",
|
78 |
+
"ﺎ": "ا", "ﺑ": "ب", "ﺘ": "ت", "ﺧ": "خ", "ﺩ": "د", "ﺱ": "س", "ﻀ": "ض", "ﻌ": "ع",
|
79 |
+
"ﻟ": "ل", "ﻡ": "م", "ﻢ": "م", "ﻪ": "ه", "ﻮ": "و", "ئ": "ی", 'ﺍ': "ا", 'ة': "ه",
|
80 |
+
'ﯾ': "ی", 'ﯿ': "ی", 'ﺒ': "ب", 'ﺖ': "ت", 'ﺪ': "د", 'ﺮ': "ر", 'ﺴ': "س", 'ﺷ': "ش",
|
81 |
+
'ﺸ': "ش", 'ﻋ': "ع", 'ﻤ': "م", 'ﻥ': "ن", 'ﻧ': "ن", 'ﻭ': "و", 'ﺭ': "ر", "ﮔ": "گ",
|
82 |
+
"\u200c": " ", "\u200d": " ", "\u200e": " ", "\u200f": " ", "\ufeff": " ",
|
83 |
+
}
|
84 |
+
|
85 |
+
def multiple_replace(text, chars_to_mapping):
    """Replace every occurrence of each mapping key in ``text`` with its value.

    Args:
        text: Input to rewrite; it is coerced with ``str()`` before matching.
        chars_to_mapping: Dict mapping source substrings (one or more
            characters) to their replacement strings.

    Returns:
        The rewritten string. With an empty mapping the text is returned
        unchanged.
    """
    text = str(text)
    if not chars_to_mapping:
        # An empty alternation matches the empty string at every position and
        # the lookup in the replacement callback would raise KeyError('').
        return text

    # Try longer keys first so a multi-character key is never shadowed by a
    # shorter key that happens to be its prefix (regex alternation is ordered).
    keys = sorted(chars_to_mapping, key=len, reverse=True)
    pattern = "|".join(map(re.escape, keys))
    return re.sub(pattern, lambda m: chars_to_mapping[m.group()], text)
|
88 |
+
|
89 |
+
def remove_special_characters(text, chars_to_ignore_regex):
    """Delete ignored characters from ``text``, lowercase it, append a space.

    Args:
        text: Sentence to clean.
        chars_to_ignore_regex: Regular expression whose matches are removed.

    Returns:
        The cleaned, lowercased sentence followed by one trailing space.
    """
    stripped = re.compile(chars_to_ignore_regex).sub("", text)
    return f"{stripped.lower()} "
|
92 |
|
93 |
+
def normalizer(batch, chars_to_ignore, chars_to_mapping):
    """Normalize the ``"sentence"`` field of one dataset example.

    The sentence is lowercased and stripped, passed through the module-level
    hazm normalizer, rewritten with ``chars_to_mapping`` and finally cleaned
    of every character listed in ``chars_to_ignore``.

    Args:
        batch: Example dict holding the raw text under ``"sentence"``.
        chars_to_ignore: Iterable of strings whose characters are removed.
        chars_to_mapping: Dict of substring replacements applied before the
            removal step.

    Returns:
        The same ``batch`` with ``"sentence"`` replaced by the normalized text.
    """
    # Escape every entry before building the character class. Joined raw, the
    # "-" entry turns '!', '-', ';' into the range "!-;" (U+0021..U+003B),
    # which would also delete digits and symbols such as "$", "&", "*", "+"
    # and "/" that are not in the list.
    escaped = "".join(re.escape(chars) for chars in chars_to_ignore)
    # "[]" is not a valid pattern; "(?!)" never matches, so an empty ignore
    # list still goes through the same lowercase/trailing-space step.
    chars_to_ignore_regex = f"[{escaped}]" if escaped else r"(?!)"

    text = batch["sentence"].lower().strip()
    text = _normalizer.normalize(text)
    text = multiple_replace(text, chars_to_mapping)
    text = remove_special_characters(text, chars_to_ignore_regex)

    batch["sentence"] = text
    return batch
|
103 |
+
```
|
104 |
|
105 |
+
### Loading The Data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
+
```python
|
108 |
+
from datasets import load_dataset
|
109 |
|
110 |
+
dataset = load_dataset("common_voice", "fa", split="test[:1%]")
|
111 |
+
print(dataset)
|
112 |
+
```
|
113 |
|
114 |
+
**Output:**
|
115 |
+
```text
|
116 |
+
>>>
|
117 |
+
Dataset({
|
118 |
+
features: ['client_id', 'path', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
|
119 |
+
num_rows: 52
|
120 |
+
})
|
121 |
+
```
|
122 |
|
123 |
+
### Model
|
124 |
|
125 |
+
```python
|
126 |
+
import librosa
|
127 |
+
import torch
|
128 |
+
import torchaudio
|
129 |
+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
130 |
+
|
131 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
132 |
+
processor = Wav2Vec2Processor.from_pretrained("m3hrdadfi/wav2vec2-large-xlsr-persian")
|
133 |
+
model = Wav2Vec2ForCTC.from_pretrained("m3hrdadfi/wav2vec2-large-xlsr-persian").to(device)
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
|
|
136 |
def speech_file_to_array_fn(batch):
|
137 |
speech_array, sampling_rate = torchaudio.load(batch["path"])
|
138 |
speech_array = speech_array.squeeze().numpy()
|
|
|
141 |
batch["speech"] = speech_array
|
142 |
return batch
|
143 |
|
144 |
+
|
145 |
def predict(batch):
|
146 |
features = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
|
147 |
|
|
|
155 |
|
156 |
batch["predicted"] = processor.batch_decode(pred_ids)[0]
|
157 |
return batch
|
|
|
|
|
|
|
|
|
158 |
```
|
159 |
|
160 |
## Prediction
|
161 |
|
162 |
```python
|
163 |
+
import IPython.display as ipd
import numpy as np  # required by the np.random.randint sampling further down

# Load a 1% slice of the Persian Common Voice test split and normalize its
# transcripts, keeping only the columns the rest of the snippet reads.
dataset = load_dataset("common_voice", "fa", split="test[:1%]")
dataset = dataset.map(
    normalizer,
    fn_kwargs={"chars_to_ignore": chars_to_ignore, "chars_to_mapping": chars_to_mapping},
    remove_columns=list(set(dataset.column_names) - set(['sentence', 'path']))
)

# Decode each audio file into an array, then transcribe it with the model.
dataset = dataset.map(speech_file_to_array_fn)
result = dataset.map(predict)
|
175 |
+
|
176 |
max_items = np.random.randint(0, len(result), 20).tolist()
|
177 |
for i in max_items:
|
178 |
reference, predicted = result["sentence"][i], result["predicted"][i]
|
|
|
181 |
print('---')
|
182 |
```
|
183 |
|
184 |
+
**Output:**
|
185 |
```text
|
186 |
reference: اطلاعات مسری است
|
187 |
predicted: اطلاعات مسری است
|
|
|
246 |
|
247 |
## Evaluation
|
248 |
|
249 |
+
```bash
|
250 |
+
mkdir cer
|
251 |
+
wget -O cer/cer.py https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese/raw/main/cer.py
|
252 |
+
```
|
253 |
+
|
254 |
```python
|
255 |
+
from datasets import load_metric
|
256 |
+
|
257 |
+
dataset = load_dataset("common_voice", "fa", split="test")
|
258 |
+
dataset = dataset.map(
|
259 |
+
normalizer,
|
260 |
+
fn_kwargs={"chars_to_ignore": chars_to_ignore, "chars_to_mapping": chars_to_mapping},
|
261 |
+
remove_columns=list(set(dataset.column_names) - set(['sentence', 'path']))
|
262 |
+
)
|
263 |
+
dataset = dataset.map(speech_file_to_array_fn)
|
264 |
+
result = dataset.map(predict)
|
265 |
|
266 |
wer = load_metric("wer")
|
267 |
cer = load_metric("./cer")
|
|
|
270 |
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["sentence"])))
|
271 |
```
|
272 |
|
273 |
+
**Output:**
|
274 |
+
```text
|
275 |
+
WER: 32.09%
|
276 |
+
CER: 8.23%
|
277 |
+
```
|
278 |
|
279 |
|
280 |
## Training
|