jonatasgrosman committed on
Commit 1b7d8af
1 Parent(s): 58f7139

add evaluation

README.md CHANGED
@@ -22,16 +22,36 @@ model-index:
     metrics:
     - name: Test WER
       type: wer
-      value: 9.88
+      value: 16.52
     - name: Test CER
       type: cer
-      value: 2.32
+      value: 4.62
     - name: Test WER (+LM)
       type: wer
-      value: 7.07
+      value: 12.46
     - name: Test CER (+LM)
       type: cer
-      value: 1.88
+      value: 3.98
+  - task:
+      name: Automatic Speech Recognition
+      type: automatic-speech-recognition
+    dataset:
+      name: Robust Speech Event - Dev Data
+      type: speech-recognition-community-v2/dev_data
+      args: sv
+    metrics:
+    - name: Test WER
+      type: wer
+      value: 23.96
+    - name: Test CER
+      type: cer
+      value: 8.88
+    - name: Test WER (+LM)
+      type: wer
+      value: 15.88
+    - name: Test CER (+LM)
+      type: cer
+      value: 7.42
 ---
 
 # XLS-R-1B-RUSSIAN
eval.py CHANGED
@@ -1,12 +1,11 @@
 #!/usr/bin/env python3
 from datasets import load_dataset, load_metric, Audio, Dataset
-from transformers import pipeline, AutoFeatureExtractor, AutoTokenizer
+from transformers import pipeline, AutoFeatureExtractor, AutoTokenizer, AutoConfig, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
 import re
 import torch
 import argparse
 from typing import Dict
 
-
 def log_results(result: Dataset, args: Dict[str, str]):
     """ DO NOT CHANGE. This function computes and logs the result metrics. """
 
@@ -68,17 +67,30 @@ def main(args):
     # dataset = dataset.select(range(10))
 
     # load processor
-    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
-    sampling_rate = feature_extractor.sampling_rate
+    if args.greedy:
+        processor = Wav2Vec2Processor.from_pretrained(args.model_id)
+        decoder = None
+    else:
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
+        decoder = processor.decoder
+
+    feature_extractor = processor.feature_extractor
+    tokenizer = processor.tokenizer
 
     # resample audio
-    dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
+    dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
 
     # load eval pipeline
     if args.device is None:
         args.device = 0 if torch.cuda.is_available() else -1
-    asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
-
+
+    config = AutoConfig.from_pretrained(args.model_id)
+    model = AutoModelForCTC.from_pretrained(args.model_id)
+
+    #asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
+    asr = pipeline("automatic-speech-recognition", config=config, model=model, tokenizer=tokenizer,
+                   feature_extractor=feature_extractor, decoder=decoder, device=args.device)
+
     # build normalizer config
     tokenizer = AutoTokenizer.from_pretrained(args.model_id)
     tokens = [x for x in tokenizer.convert_ids_to_tokens(range(0, tokenizer.vocab_size))]
@@ -106,6 +118,9 @@ def main(args):
     # run inference on all examples
     result = dataset.map(map_to_pred, remove_columns=dataset.column_names)
 
+    # filtering out empty targets
+    result = result.filter(lambda example: example["target"] != "")
+
     # compute and log_results
     # do not change function below
     log_results(result, args)
@@ -135,6 +150,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--log_outputs", action='store_true', help="If defined, write outputs to log file for analysis."
     )
+    parser.add_argument(
+        "--greedy", action='store_true', help="If defined, the LM will be ignored during inference."
+    )
     parser.add_argument(
         "--device",
         type=int,
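The eval.py change above swaps the model_id-only pipeline for an explicitly assembled one, so the KenLM-backed decoder from Wav2Vec2ProcessorWithLM can be plugged in unless --greedy is passed. A minimal standalone sketch of that logic (not part of the commit; the audio path is a placeholder, and the chunking values mirror full_eval.sh below):

```python
# Sketch of the two decoding paths eval.py now supports: greedy CTC vs. LM beam search.
import torch
from transformers import (AutoModelForCTC, Wav2Vec2Processor,
                          Wav2Vec2ProcessorWithLM, pipeline)

MODEL_ID = "jonatasgrosman/wav2vec2-xls-r-1b-russian"
use_greedy = False  # mirrors the new --greedy flag

if use_greedy:
    # plain CTC argmax decoding, no language model
    processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
    decoder = None
else:
    # beam-search decoding boosted by the repo's bundled language model
    processor = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_ID)
    decoder = processor.decoder

model = AutoModelForCTC.from_pretrained(MODEL_ID)
asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=decoder,
    device=0 if torch.cuda.is_available() else -1,
)

# "sample.wav" is a placeholder 16 kHz mono recording
print(asr("sample.wav", chunk_length_s=5.0, stride_length_s=1.0)["text"])
```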
full_eval.sh ADDED
@@ -0,0 +1,15 @@
+# CV 8 - TEST
+
+python eval.py --model_id jonatasgrosman/wav2vec2-xls-r-1b-russian --dataset mozilla-foundation/common_voice_8_0 --config ru --split test --log_outputs --greedy
+mv log_mozilla-foundation_common_voice_8_0_ru_test_predictions.txt log_mozilla-foundation_common_voice_8_0_ru_test_predictions_greedy.txt
+mv mozilla-foundation_common_voice_8_0_ru_test_eval_results.txt mozilla-foundation_common_voice_8_0_ru_test_eval_results_greedy.txt
+
+python eval.py --model_id jonatasgrosman/wav2vec2-xls-r-1b-russian --dataset mozilla-foundation/common_voice_8_0 --config ru --split test --log_outputs
+
+# HF EVENT - DEV
+
+python eval.py --model_id jonatasgrosman/wav2vec2-xls-r-1b-russian --dataset speech-recognition-community-v2/dev_data --config ru --split validation --chunk_length_s 5.0 --stride_length_s 1.0 --log_outputs --greedy
+mv log_speech-recognition-community-v2_dev_data_ru_validation_predictions.txt log_speech-recognition-community-v2_dev_data_ru_validation_predictions_greedy.txt
+mv speech-recognition-community-v2_dev_data_ru_validation_eval_results.txt speech-recognition-community-v2_dev_data_ru_validation_eval_results_greedy.txt
+
+python eval.py --model_id jonatasgrosman/wav2vec2-xls-r-1b-russian --dataset speech-recognition-community-v2/dev_data --config ru --split validation --chunk_length_s 5.0 --stride_length_s 1.0 --log_outputs
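The *_eval_results.txt files below store WER/CER as fractions; the README values are the same numbers times 100 (e.g. 0.1246… becomes 12.46). As a rough cross-check, not part of the commit, the metrics could be recomputed from the logged files, assuming each log holds one normalized utterance per line in matching order:

```python
# Hypothetical re-check of the committed results. File names come from the commit,
# but the one-line-per-utterance layout of the logs is an assumption.
from datasets import load_metric

preds_file = "log_mozilla-foundation_common_voice_8_0_ru_test_predictions.txt"
targets_file = "log_mozilla-foundation_common_voice_8_0_ru_test_targets.txt"

with open(preds_file, encoding="utf-8") as f:
    predictions = [line.strip() for line in f if line.strip()]
with open(targets_file, encoding="utf-8") as f:
    references = [line.strip() for line in f if line.strip()]

wer = load_metric("wer").compute(predictions=predictions, references=references)
cer = load_metric("cer").compute(predictions=predictions, references=references)
print(f"WER: {wer}")  # should land near the committed 0.1246 for the +LM run
print(f"CER: {cer}")  # should land near the committed 0.0399 for the +LM run
```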
log_mozilla-foundation_common_voice_8_0_ru_test_predictions.txt CHANGED
The diff for this file is too large to render. See raw diff
log_mozilla-foundation_common_voice_8_0_ru_test_predictions_greedy.txt CHANGED
The diff for this file is too large to render. See raw diff
log_mozilla-foundation_common_voice_8_0_ru_test_targets.txt CHANGED
The diff for this file is too large to render. See raw diff
log_speech-recognition-community-v2_dev_data_ru_validation_predictions.txt ADDED
The diff for this file is too large to render. See raw diff
log_speech-recognition-community-v2_dev_data_ru_validation_predictions_greedy.txt ADDED
The diff for this file is too large to render. See raw diff
log_speech-recognition-community-v2_dev_data_ru_validation_targets.txt ADDED
The diff for this file is too large to render. See raw diff
mozilla-foundation_common_voice_8_0_ru_test_eval_results.txt CHANGED
@@ -1,2 +1,2 @@
-WER: 0.07074739140985198
-CER: 0.01881101267576011
+WER: 0.12460567823343849
+CER: 0.039853492631028165
mozilla-foundation_common_voice_8_0_ru_test_eval_results_greedy.txt CHANGED
@@ -1,2 +1,2 @@
-WER: 0.09881096821159913
-CER: 0.02324829818558626
+WER: 0.16526328561028877
+CER: 0.04621616676669796
speech-recognition-community-v2_dev_data_ru_validation_eval_results.txt ADDED
@@ -0,0 +1,2 @@
+WER: 0.15887177213786988
+CER: 0.07426400420951243
speech-recognition-community-v2_dev_data_ru_validation_eval_results_greedy.txt ADDED
@@ -0,0 +1,2 @@
+WER: 0.23960988584727919
+CER: 0.08884568391199266