maxidl commited on
Commit
45a4f3a
1 Parent(s): 4246f93

add chunked wer to eval script

Browse files
Files changed (1) hide show
  1. README.md +24 -3
README.md CHANGED
@@ -114,8 +114,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
114
  """
115
  Evaluation on the full test set:
116
  - takes ~20mins (RTX 3090).
117
- - requires ~170GB RAM to compute the WER. A potential solution to this is computing it in chunks.
118
- See https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/5 on how to implement this.
119
  """
120
  test_dataset = load_dataset("common_voice", "de", split="test") # use "test[:1%]" for 1% sample
121
  wer = load_metric("wer")
@@ -151,8 +150,30 @@ def evaluate(batch):
151
 
152
  result = test_dataset.map(evaluate, batched=True, batch_size=8) # batch_size=8 -> requires ~14.5GB GPU memory
153
 
154
- print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
 
155
  # WER: 12.615308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  ```
157
 
158
  **Test Result**: 12.62 %
 
114
  """
115
  Evaluation on the full test set:
116
  - takes ~20mins (RTX 3090).
117
+ - requires ~170GB RAM to compute the WER. Below, we use a chunked implementation of WER to avoid large RAM consumption.
 
118
  """
119
  test_dataset = load_dataset("common_voice", "de", split="test") # use "test[:1%]" for 1% sample
120
  wer = load_metric("wer")
 
150
 
151
  result = test_dataset.map(evaluate, batched=True, batch_size=8) # batch_size=8 -> requires ~14.5GB GPU memory
152
 
153
+ # non-chunked version:
154
+ # print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
155
  # WER: 12.615308
156
+
157
+ # Chunked version, see https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/5:
158
+ import jiwer
159
+
160
def chunked_wer(targets, predictions, chunk_size=None):
    """Compute the word error rate of `predictions` against `targets`.

    When `chunk_size` is given, the edit-distance counts (hits,
    substitutions, deletions, insertions) are accumulated over
    fixed-size chunks and the WER is derived from the totals at the
    end. This keeps peak RAM bounded, whereas a single jiwer call on
    the full test set can require ~170GB.
    """
    if chunk_size is None:
        return jiwer.wer(targets, predictions)

    hits = substitutions = deletions = insertions = 0
    for offset in range(0, len(targets), chunk_size):
        measures = jiwer.compute_measures(
            targets[offset:offset + chunk_size],
            predictions[offset:offset + chunk_size],
        )
        hits += measures["hits"]
        substitutions += measures["substitutions"]
        deletions += measures["deletions"]
        insertions += measures["insertions"]

    # WER = (S + D + I) / (H + S + D); the denominator is the total
    # number of reference words.
    errors = substitutions + deletions + insertions
    return float(errors) / float(hits + substitutions + deletions)
174
+
175
+ print("Total (chunk_size=1000), WER: {:2f}".format(100 * chunked_wer(result["sentence"], result["pred_strings"], chunk_size=1000)))
176
+ # Total (chunk_size=1000), WER: 12.768981
177
  ```
178
 
179
  **Test Result**: 12.62 %