lucio committed on
Commit 3dd85cc
1 Parent(s): 210b2ec

Update README.md

Files changed (1)
  1. README.md +26 -10
README.md CHANGED
@@ -23,12 +23,12 @@ model-index:
     metrics:
     - name: Test WER
       type: wer
-      value: 47.99
+      value: 40.59
 ---
 
 # Wav2Vec2-Large-XLSR-53-rw
 
-Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Kinyarwanda using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset, using the validation set for training, and taking 12% of the test data for validation.
+Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Kinyarwanda using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset, using about 20% of the training data (limited to utterances without downvotes and shorter than 9.5 seconds), and validated on 2048 utterances from the validation set.
 When using this model, make sure that your speech input is sampled at 16kHz.
 
 ## Usage
@@ -68,6 +68,8 @@ print("Prediction:", processor.batch_decode(predicted_ids))
 print("Reference:", test_dataset["sentence"][:2])
 ```
 
+Prediction: ['yaherukaga gukora igitaramo y iki mu jyiwa na mul mumbiliki', 'ini rero ntibizashoboka ka nibo nkunrabibzi']
+Reference: ['Yaherukaga gukora igitaramo nk’iki mu Mujyi wa Namur mu Bubiligi.', 'Ibi rero, ntibizashoboka, kandi nawe arabizi.']
 
 ## Evaluation
 
@@ -81,6 +83,7 @@ import torchaudio
 from datasets import load_dataset, load_metric
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 import re
+import unidecode
 
 test_dataset = load_dataset("common_voice", "rw", split="test")
 wer = load_metric("wer")
@@ -89,18 +92,30 @@ processor = Wav2Vec2Processor.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarw
 model = Wav2Vec2ForCTC.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarwanda")
 model.to("cuda")
 
-chars_to_ignore_regex = '[\\[\\],?.!;:%\\'"‘’“”(){}‟ˮ´ʺ″«»/…‽�–-]'
+chars_to_ignore_regex = r'[!"#$%&()*+,./:;<=>?@\[\]\\_{}|~£¤¨©ª«¬®¯°·¸»¼½¾ðʺ˜˝ˮ‐–—―‚“”„‟•…″‽₋€™−√�]'
+
+def remove_special_characters(batch):
+    batch["text"] = re.sub(r'[ʻʽʼ‘’´`]', r"'", batch["sentence"])  # normalize apostrophes
+    batch["text"] = re.sub(chars_to_ignore_regex, "", batch["text"]).lower().strip()  # remove all other punctuation
+    batch["text"] = re.sub(r"(-| ?' ?| +)", " ", batch["text"])  # treat dash and apostrophe as word boundary
+    batch["text"] = unidecode.unidecode(batch["text"])  # strip accents
+    return batch
+
+## Audio pre-processing
 resampler = torchaudio.transforms.Resample(48_000, 16_000)
 
-# Preprocessing the datasets.
-# We need to read the audio files as arrays
 def speech_file_to_array_fn(batch):
-    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
     speech_array, sampling_rate = torchaudio.load(batch["path"])
     batch["speech"] = resampler(speech_array).squeeze().numpy()
+    batch["sampling_rate"] = 16_000
     return batch
 
-test_dataset = test_dataset.map(speech_file_to_array_fn, remove_columns=['path'])
+def cv_prepare(batch):
+    batch = remove_special_characters(batch)
+    batch = speech_file_to_array_fn(batch)
+    return batch
+
+test_dataset = test_dataset.map(cv_prepare)
 
 # Preprocessing the datasets.
 # We need to read the audio files as arrays
@@ -134,10 +149,11 @@ def chunked_wer(targets, predictions, chunk_size=None):
 print("WER: {:.2f}".format(100 * chunked_wer(result["sentence"], result["pred_strings"], chunk_size=4000)))
 ```
 
-**Test Result**: 47.99 %
+**Test Result**: 40.59 %
+
 
 ## Training
 
-The Common Voice `validation` dataset was used for training, with 12% of the test dataset used for validation, trained on 1 V100 GPU for 48 hours (20 epochs).
+Blocks of examples from the Common Voice training dataset (about 100k examples in total, 20% of the available data) were used to train for 30k global steps on 1 V100 GPU provided by OVHcloud. For validation, 2048 examples of the validation dataset were used.
 
-The script used for training was just the `run_finetuning.py` script provided in OVHcloud's `databuzzword/hf-wav2vec` image.
+The [script used for training](https://github.com/serapio/transformers/blob/feature/xlsr-finetune/examples/research_projects/wav2vec2/run_common_voice.py) is adapted from the [example script provided in the transformers repo](https://github.com/huggingface/transformers/blob/master/examples/research_projects/wav2vec2/run_common_voice.py).
 
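For a concrete sense of the new normalization, here is the `remove_special_characters` logic applied to the first reference sentence quoted in the diff (a minimal standalone sketch; the `batch` dict stands in for one Common Voice row):

```python
import re
import unidecode

# The new punctuation set from the diff above.
chars_to_ignore_regex = r'[!"#$%&()*+,./:;<=>?@\[\]\\_{}|~£¤¨©ª«¬®¯°·¸»¼½¾ðʺ˜˝ˮ‐–—―‚“”„‟•…″‽₋€™−√�]'

# One Common Voice row, using the first reference sentence from the diff.
batch = {"sentence": "Yaherukaga gukora igitaramo nk’iki mu Mujyi wa Namur mu Bubiligi."}

text = re.sub(r"[ʻʽʼ‘’´`]", r"'", batch["sentence"])            # curly apostrophe -> '
text = re.sub(chars_to_ignore_regex, "", text).lower().strip()  # drops the final period
text = re.sub(r"(-| ?' ?| +)", " ", text)                       # "nk'iki" -> "nk iki"
text = unidecode.unidecode(text)                                # no accented letters here, unchanged

print(text)  # yaherukaga gukora igitaramo nk iki mu mujyi wa namur mu bubiligi
```

Both the dash and the apostrophe are rewritten as spaces, so `nk’iki` is scored as the two words `nk iki`.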
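The evaluation hunks elide the middle of the script (context resumes at the final `print`). Model cards from this XLSR sprint typically fill that gap with a batched inference pass mapped over the prepared dataset; a minimal sketch along those lines, assuming the `processor`, `model`, and `test_dataset` objects defined above (the `evaluate` name and `batch_size` value are illustrative):

```python
import torch

# Sketch of the elided inference step; `processor`, `model`, and `test_dataset`
# are the objects defined earlier in the evaluation script.
def evaluate(batch):
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda"),
                       attention_mask=inputs.attention_mask.to("cuda")).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch

result = test_dataset.map(evaluate, batched=True, batch_size=8)
```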
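Likewise, `chunked_wer`, visible only in the last hunk header, is defined outside the diff context. In comparable XLSR model cards it accumulates `jiwer` error counts chunk by chunk so that WER over the full test list doesn't exhaust memory; a sketch under that assumption (the `jiwer` package, and its `compute_measures` helper from pre-3.0 releases, do not appear in the diff's imports):

```python
import jiwer

def chunked_wer(targets, predictions, chunk_size=None):
    # Plain WER when no chunking is requested.
    if chunk_size is None:
        return jiwer.wer(targets, predictions)
    hits = subs = dels = ins = 0
    for start in range(0, len(targets), chunk_size):
        # Accumulate hit/substitution/deletion/insertion counts per chunk.
        chunk = jiwer.compute_measures(targets[start:start + chunk_size],
                                       predictions[start:start + chunk_size])
        hits += chunk["hits"]
        subs += chunk["substitutions"]
        dels += chunk["deletions"]
        ins += chunk["insertions"]
    # WER = (S + D + I) / (H + S + D): errors over reference word count.
    return float(subs + dels + ins) / float(hits + subs + dels)
```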