bakrianoo committed on
Commit 65b4bd2
1 Parent(s): 657acba

Update WER score + Publish Benchmark Results

Files changed (1)
  1. README.md +144 -11
README.md CHANGED
@@ -23,7 +23,7 @@ model-index:
   metrics:
   - name: Test WER
     type: wer
-     value: 23.70
 ---
 
 # Sinai Voice Arabic Speech Recognition Model
@@ -31,12 +31,136 @@ model-index:
 Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53)
 on Arabic using the [Common Voice](https://huggingface.co/datasets/common_voice)
 
-
- ## Usage
 
 Please install:
 - [PyTorch](https://pytorch.org/)
- - `$ pip3 install jiwer lang_trans torchaudio datasets transformers`
 
 The model can be used directly (without a language model) as follows:
 ```python
@@ -51,10 +175,15 @@ resamplers = { # all three sampling rates exist in test split
     44100: torchaudio.transforms.Resample(44100, 16000),
     32000: torchaudio.transforms.Resample(32000, 16000),
 }
 def prepare_example(example):
     speech, sampling_rate = torchaudio.load(example["path"])
-     example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
     return example
 dataset = dataset.map(prepare_example)
 processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
 model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").eval()
@@ -103,9 +232,8 @@ predicted: أين المشكل
 reference: وَلِلَّهِ يَسْجُدُ مَا فِي السَّمَاوَاتِ وَمَا فِي الْأَرْضِ مِنْ دَابَّةٍ وَالْمَلَائِكَةُ وَهُمْ لَا يَسْتَكْبِرُونَ
 predicted: ولله يسجد ما في السماوات وما في الأرض من دابة والملائكة وهم لا يستكبرون
 ```
- ## Evaluation
 
- CLONED from [elgeish/wav2vec2-large-xlsr-53-arabic](https://huggingface.co/elgeish/wav2vec2-large-xlsr-53-arabic)
 
 The model can be evaluated as follows on the Arabic test data of Common Voice:
 ```python
@@ -122,10 +250,15 @@ resamplers = { # all three sampling rates exist in test split
     44100: torchaudio.transforms.Resample(44100, 16000),
     32000: torchaudio.transforms.Resample(32000, 16000),
 }
 def prepare_example(example):
     speech, sampling_rate = torchaudio.load(example["path"])
-     example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
     return example
 test_split = test_split.map(prepare_example)
 processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
 model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").to("cuda").eval()
@@ -141,8 +274,8 @@ test_split = test_split.map(predict, batched=True, batch_size=16, remove_columns
 transformation = jiwer.Compose([
     # normalize some diacritics, remove punctuation, and replace Persian letters with Arabic ones
     jiwer.SubstituteRegexes({
-         r'[auiFNKo\~_،؟»\?;:\-,\.؛«!"]': "", "\u06D6": "",
-         r"[\|\{]": "A", "p": "h", "ک": "k", "ی": "y"}),
     # default transformation below
     jiwer.RemoveMultipleSpaces(),
     jiwer.Strip(),
@@ -158,7 +291,7 @@ metrics = jiwer.compute_measures(
 )
 print(f"WER: {metrics['wer']:.2%}")
 ```
- **Test Result**: 23.70%
 
 
 ## Other Arabic Voice recognition Models
 
   metrics:
   - name: Test WER
     type: wer
+     value: 23.80
 ---
 
 # Sinai Voice Arabic Speech Recognition Model
 
 Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53)
 on Arabic using the [Common Voice](https://huggingface.co/datasets/common_voice)
 
+ Most of the evaluation code in this documentation is inspired by [elgeish/wav2vec2-large-xlsr-53-arabic](https://huggingface.co/elgeish/wav2vec2-large-xlsr-53-arabic)
 
 Please install:
 - [PyTorch](https://pytorch.org/)
+ - `$ pip3 install jiwer lang_trans torchaudio datasets transformers pandas tqdm`
+
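+ Before running the snippets below, a quick optional sanity check that the packages above import cleanly (a minimal sketch, not part of the model's pipeline):
+
+ ```python
+ # optional environment check: verify imports and CUDA availability
+ import torch, torchaudio, transformers, datasets, jiwer, pandas
+ from lang_trans.arabic import buckwalter
+
+ print(torch.__version__, torchaudio.__version__, transformers.__version__)
+ print("CUDA available:", torch.cuda.is_available())
+ ```
+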
40
+ ## Benchmark
41
+
42
+ We evaluated the model against different Arabic-STT Wav2Vec models.
43
+
44
+ | | model | using_transliation | WER |
45
+ |---:|:--------------------------------------|:---------------------|---------:|
46
+ | 0 | bakrianoo/sinai-voice-ar-stt | True | 0.238001 |
47
+ | 1 | elgeish/wav2vec2-large-xlsr-53-arabic | True | 0.266527 |
48
+ | 2 | othrif/wav2vec2-large-xlsr-arabic | True | 0.298122 |
49
+ | 3 | bakrianoo/sinai-voice-ar-stt | False | 0.448987 |
50
+ | 4 | othrif/wav2vec2-large-xlsr-arabic | False | 0.464004 |
51
+ | 5 | anas/wav2vec2-large-xlsr-arabic | True | 0.506191 |
52
+ | 6 | anas/wav2vec2-large-xlsr-arabic | False | 0.622288 |
53
+
54
+
55
+ <details>
56
+ <summary>We used the following <b>CODE</b> to generate the above results</summary>
57
+
58
+ ```python
59
+ import jiwer
60
+ import torch
61
+ from tqdm.auto import tqdm
62
+ import torchaudio
63
+ from datasets import load_dataset
64
+ from lang_trans.arabic import buckwalter
65
+ from transformers import set_seed, Wav2Vec2ForCTC, Wav2Vec2Processor
66
+ import pandas as pd
67
+
68
+ # load test dataset
69
+ set_seed(42)
70
+ test_split = load_dataset("common_voice", "ar", split="test")
71
+
72
+ # init sample rate resamplers
73
+ resamplers = { # all three sampling rates exist in test split
74
+ 48000: torchaudio.transforms.Resample(48000, 16000),
75
+ 44100: torchaudio.transforms.Resample(44100, 16000),
76
+ 32000: torchaudio.transforms.Resample(32000, 16000),
77
+ }
78
+
79
+ # WER composer
80
+ transformation = jiwer.Compose([
81
+ # normalize some diacritics, remove punctuation, and replace Persian letters with Arabic ones
82
+ jiwer.SubstituteRegexes({
83
+ r'[auiFNKo\~_،؟»\?;:\-,\.؛«!"]': "", "\u06D6": "",
84
+ r"[\|\{]": "A", "p": "h", "ک": "k", "ی": "y"}),
85
+ # default transformation below
86
+ jiwer.RemoveMultipleSpaces(),
87
+ jiwer.Strip(),
88
+ jiwer.SentencesToListOfWords(),
89
+ jiwer.RemoveEmptyStrings(),
90
+ ])
91
+
92
+ def prepare_example(example):
93
+ speech, sampling_rate = torchaudio.load(example["path"])
94
+ if sampling_rate in resamplers:
95
+ example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
96
+ else:
97
+ example["speech"] = resamplers[4800](speech).squeeze().numpy()
98
+ return example
99
+
100
+ def predict(batch):
101
+ inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
102
+ with torch.no_grad():
103
+ predicted = torch.argmax(model(inputs.input_values.to("cuda")).logits, dim=-1)
104
+ predicted[predicted == -100] = processor.tokenizer.pad_token_id # see fine-tuning script
105
+ batch["predicted"] = processor.batch_decode(predicted)
106
+ return batch
107
+
108
+ # prepare the test dataset
109
+ test_split = test_split.map(prepare_example)
110
+
111
+ stt_models = {
112
+ "elgeish/wav2vec2-large-xlsr-53-arabic",
113
+ "othrif/wav2vec2-large-xlsr-arabic",
114
+ "anas/wav2vec2-large-xlsr-arabic",
115
+ "bakrianoo/sinai-voice-ar-stt"
116
+ }
117
+
118
+ stt_results = []
119
+
120
+ for model_path in tqdm(stt_models):
121
+ processor = Wav2Vec2Processor.from_pretrained(model_path)
122
+ model = Wav2Vec2ForCTC.from_pretrained(model_path).to("cuda").eval()
123
+
124
+ test_split_preds = test_split.map(predict, batched=True, batch_size=56, remove_columns=["speech"])
125
+
126
+ orig_metrics = jiwer.compute_measures(
127
+ truth=[s for s in test_split_preds["sentence"]],
128
+ hypothesis=[s for s in test_split_preds["predicted"]],
129
+ truth_transform=transformation,
130
+ hypothesis_transform=transformation,
131
+ )
132
+
133
+ trans_metrics = jiwer.compute_measures(
134
+ truth=[buckwalter.trans(s) for s in test_split_preds["sentence"]], # Buckwalter transliteration
135
+ hypothesis=[buckwalter.trans(s) for s in test_split_preds["predicted"]], # Buckwalter transliteration
136
+ truth_transform=transformation,
137
+ hypothesis_transform=transformation,
138
+ )
139
+
140
+ stt_results.append({
141
+ "model": model_path,
142
+ "using_transliation": True,
143
+ "WER": trans_metrics["wer"]
144
+ })
145
+
146
+ stt_results.append({
147
+ "model": model_path,
148
+ "using_transliation": False,
149
+ "WER": orig_metrics["wer"]
150
+ })
151
+
152
+ del model
153
+ del processor
154
+
155
+ stt_results_df = pd.DataFrame(stt_results)
156
+ stt_results_df = stt_results_df.sort_values('WER', axis=0, ascending=True)
157
+ stt_results_df.head(n=50)
158
+
159
+ ```
160
+ </details>
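+
+ To see why transliteration lowers WER in the table above, the same jiwer pipeline can be exercised on a toy pair. A minimal sketch, assuming the imports and the `transformation` from the script above are in scope:
+
+ ```python
+ # toy pair: vowelized reference vs. unvowelized prediction
+ ref, hyp = "وَلِلَّهِ يَسْجُدُ", "ولله يسجد"
+
+ # raw Arabic: the harakat in the reference make every word a mismatch
+ raw = jiwer.compute_measures(truth=[ref], hypothesis=[hyp],
+                              truth_transform=transformation,
+                              hypothesis_transform=transformation)
+
+ # Buckwalter first: the transformation then strips the transliterated
+ # diacritics (a, u, i, o, ~, ...), so the pair should match
+ trans = jiwer.compute_measures(truth=[buckwalter.trans(ref)],
+                                hypothesis=[buckwalter.trans(hyp)],
+                                truth_transform=transformation,
+                                hypothesis_transform=transformation)
+
+ print(f"raw WER: {raw['wer']:.2%}  transliterated WER: {trans['wer']:.2%}")
+ ```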
+
+
+ ## Usage
 
 The model can be used directly (without a language model) as follows:
 ```python
 
     44100: torchaudio.transforms.Resample(44100, 16000),
     32000: torchaudio.transforms.Resample(32000, 16000),
 }
+
 def prepare_example(example):
     speech, sampling_rate = torchaudio.load(example["path"])
+     if sampling_rate in resamplers:
+         example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
+     else:
+         # fall back to the 48 kHz resampler for any unexpected sampling rate
+         example["speech"] = resamplers[48000](speech).squeeze().numpy()
     return example
+
 dataset = dataset.map(prepare_example)
 processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
 model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").eval()
 
 reference: وَلِلَّهِ يَسْجُدُ مَا فِي السَّمَاوَاتِ وَمَا فِي الْأَرْضِ مِنْ دَابَّةٍ وَالْمَلَائِكَةُ وَهُمْ لَا يَسْتَكْبِرُونَ
 predicted: ولله يسجد ما في السماوات وما في الأرض من دابة والملائكة وهم لا يستكبرون
 ```
 
+ ## Evaluation
 
 The model can be evaluated as follows on the Arabic test data of Common Voice:
 ```python
 
     44100: torchaudio.transforms.Resample(44100, 16000),
     32000: torchaudio.transforms.Resample(32000, 16000),
 }
+
 def prepare_example(example):
     speech, sampling_rate = torchaudio.load(example["path"])
+     if sampling_rate in resamplers:
+         example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
+     else:
+         # fall back to the 48 kHz resampler for any unexpected sampling rate
+         example["speech"] = resamplers[48000](speech).squeeze().numpy()
     return example
+
 test_split = test_split.map(prepare_example)
 processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
 model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").to("cuda").eval()
 
 transformation = jiwer.Compose([
     # normalize some diacritics, remove punctuation, and replace Persian letters with Arabic ones
     jiwer.SubstituteRegexes({
+         r'[auiFNKo\~_،؟»\?;:\-,\.؛«!"]': "", "\u06D6": "",
+         r"[\|\{]": "A", "p": "h", "ک": "k", "ی": "y"}),
     # default transformation below
     jiwer.RemoveMultipleSpaces(),
     jiwer.Strip(),
 
 )
 print(f"WER: {metrics['wer']:.2%}")
 ```
+ **Test Result**: 23.80%
 
 
 ## Other Arabic Voice recognition Models