bofenghuang committed
Commit 1b3874c
1 Parent(s): 347f330

add eval script

Files changed (2):
  1. README.md +27 -16
  2. eval.py +182 -0
README.md CHANGED
@@ -11,7 +11,7 @@ tags:
 datasets:
 - mozilla-foundation/common_voice_9_0
 model-index:
-- name: wav2vec2-xls-r-1b-ft-cv9-fr
+- name: Fine-tuned Wav2Vec2 XLS-R 1B model for ASR in French
   results:
   - task:
       name: Automatic Speech Recognition
@@ -55,28 +55,14 @@ model-index:
       value: 11.09
 ---
 
-<!-- This model card has been generated automatically according to the information the Trainer had access to. You
-should probably proofread and complete it, then remove this comment. -->
 
-# wav2vec2-xls-r-1b-ft-cv9-fr
+# Fine-tuned Wav2Vec2 XLS-R 1B model for ASR in French
 
 This model is a fine-tuned version of [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) on the MOZILLA-FOUNDATION/COMMON_VOICE_9_0 - FR dataset.
 It achieves the following results on the evaluation set:
 - Loss: 0.1430
 - Wer: 0.1245
 
-## Model description
-
-More information needed
-
-## Intended uses & limitations
-
-More information needed
-
-## Training and evaluation data
-
-More information needed
-
 ## Training procedure
 
 ### Training hyperparameters
@@ -171,6 +157,31 @@ The following hyperparameters were used during training:
 | 0.1052 | 9.84 | 35500 | 0.1428 | 0.1247 |
 | 0.1044 | 9.98 | 36000 | 0.1430 | 0.1245 |
 
+## Evaluation
+
+1. To evaluate on `mozilla-foundation/common_voice_9_0`
+
+```bash
+python eval.py \
+    --model_id "bhuang/wav2vec2-xls-r-1b-french" \
+    --dataset "mozilla-foundation/common_voice_9_0" \
+    --config "fr" \
+    --split "test" \
+    --log_outputs
+```
+
+2. To evaluate on `speech-recognition-community-v2/dev_data`
+
+```bash
+python eval.py \
+    --model_id "bhuang/wav2vec2-xls-r-1b-french" \
+    --dataset "speech-recognition-community-v2/dev_data" \
+    --config "fr" \
+    --split "validation" \
+    --chunk_length_s 5.0 \
+    --stride_length_s 1.0 \
+    --log_outputs
+```
+
 ### Framework versions
 
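Note that `eval.py` loads datasets with `use_auth_token=True` and the Common Voice releases are gated on the Hub, so both commands assume you have accepted the dataset's terms and are logged in locally, e.g. via the `huggingface_hub` CLI:

```bash
# one-time setup: store your Hub access token locally
huggingface-cli login
```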
eval.py ADDED
@@ -0,0 +1,182 @@
+#!/usr/bin/env python
+
+import argparse
+import re
+from typing import Dict
+
+import torch
+from datasets import Audio, Dataset, load_dataset, load_metric
+
+from transformers import (
+    AutoConfig,
+    AutoFeatureExtractor,
+    AutoModelForCTC,
+    AutoTokenizer,
+    Wav2Vec2Processor,
+    Wav2Vec2ProcessorWithLM,
+    pipeline,
+)
+
+
+def log_results(result: Dataset, args: Dict[str, str]):
+    """DO NOT CHANGE. This function computes and logs the result metrics."""
+
+    log_outputs = args.log_outputs
+    dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])
+
+    # load metrics
+    wer = load_metric("wer")
+    cer = load_metric("cer")
+
+    # compute metrics
+    wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
+    cer_result = cer.compute(references=result["target"], predictions=result["prediction"])
+
+    # print & log results
+    result_str = f"WER: {wer_result}\nCER: {cer_result}"
+    print(result_str)
+
+    with open(f"{dataset_id}_eval_results.txt", "w") as f:
+        f.write(result_str)
+
+    # log all results to text files, possibly useful for analysis
+    if log_outputs is not None:
+        pred_file = f"log_{dataset_id}_predictions.txt"
+        target_file = f"log_{dataset_id}_targets.txt"
+
+        with open(pred_file, "w") as p, open(target_file, "w") as t:
+
+            # mapping function to write output
+            def write_to_file(batch, i):
+                p.write(f"{i}\n")
+                p.write(batch["prediction"] + "\n")
+                t.write(f"{i}\n")
+                t.write(batch["target"] + "\n")
+
+            result.map(write_to_file, with_indices=True)
+
+
+def normalize_text(text: str, invalid_chars_regex: str) -> str:
+    """DO ADAPT FOR YOUR USE CASE. This function normalizes the target text."""
+
+    text = text.lower()
+    text = re.sub(r"’", "'", text)
+    text = re.sub(invalid_chars_regex, " ", text)
+    text = re.sub(r"\s+", " ", text).strip()
+
+    return text
+
+
+def main(args):
+    # load dataset
+    dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
+
+    # for testing: only process the first few examples
+    # dataset = dataset.select(range(10))
+
+    # load processor, with the LM decoder unless greedy decoding is requested
+    if args.greedy:
+        processor = Wav2Vec2Processor.from_pretrained(args.model_id)
+        decoder = None
+    else:
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
+        decoder = processor.decoder
+
+    feature_extractor = processor.feature_extractor
+    tokenizer = processor.tokenizer
+    sampling_rate = feature_extractor.sampling_rate
+
+    # resample audio to the model's sampling rate
+    dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
+
+    # load eval pipeline
+    if args.device is None:
+        args.device = 0 if torch.cuda.is_available() else -1
+
+    config = AutoConfig.from_pretrained(args.model_id)
+    model = AutoModelForCTC.from_pretrained(args.model_id)
+
+    # asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
+    asr = pipeline(
+        "automatic-speech-recognition",
+        config=config,
+        model=model,
+        tokenizer=tokenizer,
+        feature_extractor=feature_extractor,
+        decoder=decoder,
+        device=args.device,
+    )
+
+    # build normalizer config from the model's vocabulary
+    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
+    tokens = [x for x in tokenizer.convert_ids_to_tokens(range(0, tokenizer.vocab_size))]
+    special_tokens = [
+        tokenizer.pad_token,
+        tokenizer.word_delimiter_token,
+        tokenizer.unk_token,
+        tokenizer.bos_token,
+        tokenizer.eos_token,
+    ]
+    non_special_tokens = [x for x in tokens if x not in special_tokens]
+    invalid_chars_regex = rf"[^\s{re.escape(''.join(set(non_special_tokens)))}]"
+
+    # normalize_to_lower = False
+    # for token in non_special_tokens:
+    #     if token.isalpha() and token.islower():
+    #         normalize_to_lower = True
+    #         break
+
+    # map function to decode audio
+    def map_to_pred(batch):
+        prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)
+
+        batch["prediction"] = prediction["text"]
+        batch["target"] = normalize_text(batch["sentence"], invalid_chars_regex)
+        return batch
+
+    # run inference on all examples
+    result = dataset.map(map_to_pred, remove_columns=dataset.column_names)
+
+    # filter out empty targets
+    result = result.filter(lambda example: example["target"] != "")
+
+    # compute and log results
+    # do not change the function below
+    log_results(result, args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers")
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        required=True,
+        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
+    )
+    parser.add_argument("--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice")
+    parser.add_argument("--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`")
+    parser.add_argument(
+        "--chunk_length_s",
+        type=float,
+        default=None,
+        help="Chunk length in seconds. Defaults to None. For long audio files a good value would be 5.0 seconds.",
+    )
+    parser.add_argument(
+        "--stride_length_s",
+        type=float,
+        default=None,
+        help="Stride of the audio chunks. Defaults to None. For long audio files a good value would be 1.0 seconds.",
+    )
+    parser.add_argument("--log_outputs", action="store_true", help="If defined, write outputs to log files for analysis.")
+    parser.add_argument("--greedy", action="store_true", help="If defined, the LM will be ignored during inference.")
+    parser.add_argument(
+        "--device",
+        type=int,
+        default=None,
+        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
+    )
+    args = parser.parse_args()
+
+    main(args)
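
By default the script decodes with the language model bundled in the model repo (`Wav2Vec2ProcessorWithLM`); `--greedy` falls back to plain CTC decoding with `Wav2Vec2Processor`, and `--device` overrides the automatic choice of GPU 0 when CUDA is available, else CPU. For instance, one possible LM-free run forced onto the CPU:

```bash
# example invocation: greedy CTC decoding (no LM) on CPU
python eval.py \
    --model_id "bhuang/wav2vec2-xls-r-1b-french" \
    --dataset "mozilla-foundation/common_voice_9_0" \
    --config "fr" \
    --split "test" \
    --greedy \
    --device -1 \
    --log_outputs
```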