dmusingu commited on
Commit
ace64ac
1 Parent(s): ad21e49

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +57 -0
README.md CHANGED
@@ -117,3 +117,60 @@ print("Reference:", test_dataset["sentence"][:2])
117
  ```
118
 
119
  ### Evaluation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  ```
118
 
119
  ### Evaluation
120
+
121
+ The model can be evaluated as follows on the Luganda test dataset.
122
+
123
+ ```python
124
+ import torch
125
+ import torchaudio
126
+ from datasets import load_dataset, load_metric
127
+ from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
128
+ import re
129
+
130
+ test_dataset = load_dataset("common_voice", "lg", split="test")
131
+ wer = load_metric("wer")
132
+
133
+ model = AutoModelForCTC.from_pretrained("dmusingu/w2v-bert-2.0-luganda-CV-train-validation-7.0").to('cuda')
134
+ processor = Wav2Vec2BertProcessor.from_pretrained("dmusingu/w2v-bert-2.0-luganda-CV-train-validation-7.0")
135
+
136
+ chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\»\«]'
137
+
138
+ test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16_000))
139
+
140
+ def remove_special_characters(batch):
141
+ # remove special characters
142
+ batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
143
+
144
+ return batch
145
+
146
+ test_dataset = test_dataset.map(remove_special_characters)
147
+
148
+ def prepare_dataset(batch):
149
+ audio = batch["audio"]
150
+ batch["input_features"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
151
+ batch["input_length"] = len(batch["input_features"])
152
+
153
+ batch["labels"] = processor(text=batch["sentence"]).input_ids
154
+ return batch
155
+
156
+ test_dataset = test_dataset.map(prepare_dataset, remove_columns=test_dataset.column_names)
157
+
158
+ # Evaluation is carried out with a batch size of 1
159
+ def map_to_result(batch):
160
+ with torch.no_grad():
161
+ input_values = torch.tensor(batch["input_features"], device="cuda").unsqueeze(0)
162
+ logits = model(input_values).logits
163
+
164
+ pred_ids = torch.argmax(logits, dim=-1)
165
+ batch["pred_str"] = processor.batch_decode(pred_ids)[0]
166
+ batch["text"] = processor.decode(batch["labels"], group_tokens=False)
167
+
168
+ return batch
169
+
170
+ results = test_dataset.map(map_to_result)
171
+
172
+ print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))
173
+ ```
174
+
175
+ ### Test Result: 19.4%
176
+