anantoj committed on
Commit
8baddd5
•
1 Parent(s): 077ffa0

Evaluation results

eval.py ADDED
@@ -0,0 +1,168 @@
+ import argparse
+ import re
+ from typing import Dict
+
+ import torch
+ from datasets import Audio, Dataset, load_dataset, load_metric
+
+ from transformers import AutoFeatureExtractor, pipeline
+
+
+ def log_results(result: Dataset, args: Dict[str, str]):
+     """DO NOT CHANGE. This function computes and logs the result metrics."""
+
+     log_outputs = args.log_outputs
+     dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])
+
+     # load metrics
+     wer = load_metric("wer")
+     cer = load_metric("cer")
+
+     # compute metrics
+     wer_result = wer.compute(
+         references=result["target"], predictions=result["prediction"]
+     )
+     cer_result = cer.compute(
+         references=result["target"], predictions=result["prediction"]
+     )
+
+     # print & log results
+     result_str = f"WER: {wer_result}\n" f"CER: {cer_result}"
+     print(result_str)
+
+     with open(f"{dataset_id}_eval_results.txt", "w") as f:
+         f.write(result_str)
+
+     # log all results in a text file; possibly interesting for analysis
+     if log_outputs:  # `store_true` yields a bool, so test truthiness, not `is not None`
+         pred_file = f"log_{dataset_id}_predictions.txt"
+         target_file = f"log_{dataset_id}_targets.txt"
+
+         with open(pred_file, "w") as p, open(target_file, "w") as t:
+
+             # mapping function to write output
+             def write_to_file(batch, i):
+                 p.write(f"{i}" + "\n")
+                 p.write(batch["prediction"] + "\n")
+                 t.write(f"{i}" + "\n")
+                 t.write(batch["target"] + "\n")
+
+             result.map(write_to_file, with_indices=True)
+
+
+ def normalize_text(text: str) -> str:
+     """DO ADAPT FOR YOUR USE CASE. This function normalizes the target text."""
+
+     chars_to_ignore_regex = '[,?.!\-\;\:"“%‘”�—’…–]'  # noqa: W605 IMPORTANT: this should correspond to the chars that were ignored during training
+
+     text = re.sub(chars_to_ignore_regex, "", text.lower())
+
+     # In addition, we can normalize the target text, e.g. removing newline characters etc.
+     # note that order is important here!
+     token_sequences_to_ignore = ["\n\n", "\n", "   ", "  "]
+
+     for t in token_sequences_to_ignore:
+         text = " ".join(text.split(t))
+
+     return text
+
+
+ def main(args):
+     # load dataset
+     dataset = load_dataset(
+         args.dataset, args.config, split=args.split, use_auth_token=True
+     )
+
+     # for testing: only process the first few examples
+     # dataset = dataset.select(range(10))
+
+     # load processor
+     feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
+     sampling_rate = feature_extractor.sampling_rate
+
+     # resample audio
+     dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
+
+     # load eval pipeline
+     if args.device is None:
+         args.device = 0 if torch.cuda.is_available() else -1
+     asr = pipeline(
+         "automatic-speech-recognition", model=args.model_id, device=args.device
+     )
+
+     # map function to decode audio
+     def map_to_pred(batch):
+         prediction = asr(
+             batch["audio"]["array"],
+             chunk_length_s=args.chunk_length_s,
+             stride_length_s=args.stride_length_s,
+         )
+
+         batch["prediction"] = prediction["text"]
+         batch["target"] = normalize_text(batch[args.text_column_name])
+         return batch
+
+     # run inference on all examples
+     result = dataset.map(map_to_pred, remove_columns=dataset.column_names)
+
+     # compute and log results
+     # do not change function below
+     log_results(result, args)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "--model_id",
+         type=str,
+         required=True,
+         help="Model identifier. Should be loadable with 🤗 Transformers",
+     )
+     parser.add_argument(
+         "--dataset",
+         type=str,
+         required=True,
+         help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
+     )
+     parser.add_argument(
+         "--config",
+         type=str,
+         required=True,
+         help="Config of the dataset. *E.g.* `'en'` for Common Voice",
+     )
+     parser.add_argument(
+         "--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`"
+     )
+     parser.add_argument(
+         "--text_column_name",
+         type=str,
+         default="text",
+         help="The name of the dataset column containing the text data. Defaults to 'text'",
+     )
+     parser.add_argument(
+         "--chunk_length_s",
+         type=float,
+         default=None,
+         help="Chunk length in seconds. Defaults to None (no chunking).",
+     )
+     parser.add_argument(
+         "--stride_length_s",
+         type=float,
+         default=None,
+         help="Stride of the audio chunks in seconds. Defaults to None (pipeline default).",
+     )
+     parser.add_argument(
+         "--log_outputs",
+         action="store_true",
+         help="If set, write predictions and targets to log files for analysis.",
+     )
+     parser.add_argument(
+         "--device",
+         type=int,
+         default=None,
+         help="The device to run the pipeline on. -1 for CPU, 0 for the first GPU, and so on. Defaults to GPU 0 if available, else CPU.",
+     )
+     args = parser.parse_args()
+
+     main(args)
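
As a quick sanity check of the metric calls in `log_results` above, a minimal sketch (it assumes an older `datasets` release that still ships `load_metric`, as this script does, with `jiwer` installed; the example strings are made up, not data from this repo):

# Sanity check for the WER/CER computation used in eval.py.
from datasets import load_metric

wer = load_metric("wer")  # both metrics need the `jiwer` package
cer = load_metric("cer")

references = ["안녕하세요 반갑습니다"]   # hypothetical reference
predictions = ["안녕하세요 반갑습니다"]  # identical prediction

print(wer.compute(references=references, predictions=predictions))  # 0.0, exact match
print(cer.compute(references=references, predictions=predictions))  # 0.0, exact match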
log_speech-recognition-community-v2_dev_data_ko_validation_predictions.txt ADDED
The diff for this file is too large to render. See raw diff
 
log_speech-recognition-community-v2_dev_data_ko_validation_targets.txt ADDED
The diff for this file is too large to render. See raw diff
 
nohup.out CHANGED
The diff for this file is too large to render. See raw diff
 
run_eval_dev.sh ADDED
@@ -0,0 +1,10 @@
+ python eval.py \
+     --model_id="anantoj/wav2vec2-xls-r-1b-korean" \
+     --dataset="speech-recognition-community-v2/dev_data" \
+     --config="ko" \
+     --split="validation" \
+     --text_column_name="sentence" \
+     --chunk_length_s="10" \
+     --stride_length_s="2" \
+     --log_outputs \
+     --device="0"
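
The same settings can be exercised directly with the 🤗 Transformers pipeline; a minimal sketch ("sample.wav" is a placeholder path, and `device=0` assumes a GPU is available):

from transformers import pipeline

# Long-form decoding with 10 s chunks and a 2 s stride, matching the flags above.
asr = pipeline(
    "automatic-speech-recognition",
    model="anantoj/wav2vec2-xls-r-1b-korean",
    device=0,  # assumption: a GPU is present; use -1 for CPU
)
output = asr("sample.wav", chunk_length_s=10.0, stride_length_s=2.0)  # placeholder audio file
print(output["text"])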
speech-recognition-community-v2_dev_data_ko_validation_eval_results.txt ADDED
@@ -0,0 +1,2 @@
+ WER: 0.820667072799269
+ CER: 0.4212443152869247
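
The CER here is roughly half the WER, a gap that is common for Korean, where inconsistent word spacing inflates word-level error counts while leaving most characters correct. A toy illustration with made-up strings (same `load_metric` assumption as above):

from datasets import load_metric

wer = load_metric("wer")
cer = load_metric("cer")

# One missing space makes both reference words wrong at the word level,
# but only a single character-level edit is involved.
print(wer.compute(references=["머신 러닝"], predictions=["머신러닝"]))  # 1.0
print(cer.compute(references=["머신 러닝"], predictions=["머신러닝"]))  # far lower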