jstoone commited on
Commit
e640b8d
1 Parent(s): f383cae

✨ Add evalutaion scripts.

Browse files
Files changed (2) hide show
  1. run_eval.sh +7 -0
  2. run_eval_whisper_streaming.py +150 -0
run_eval.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ python run_eval_whisper_streaming.py \
2
+ --model_id="jstoone/whisper-medium-da-cv11" \
3
+ --dataset="google/fleurs" \
4
+ --config="da_dk" \
5
+ --language="da" \
6
+ --device=0
7
+
run_eval_whisper_streaming.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from transformers import pipeline
4
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
5
+ from datasets import load_dataset, Audio
6
+ import evaluate
7
+
8
+ wer_metric = evaluate.load("wer")
9
+
10
+
11
+ def is_target_text_in_range(ref):
12
+ if ref.strip() == "ignore time segment in scoring":
13
+ return False
14
+ else:
15
+ return ref.strip() != ""
16
+
17
+
18
+ def get_text(sample):
19
+ if "text" in sample:
20
+ return sample["text"]
21
+ elif "sentence" in sample:
22
+ return sample["sentence"]
23
+ elif "normalized_text" in sample:
24
+ return sample["normalized_text"]
25
+ elif "transcript" in sample:
26
+ return sample["transcript"]
27
+ elif "transcription" in sample:
28
+ return sample["transcription"]
29
+ else:
30
+ raise ValueError(
31
+ f"Expected transcript column of either 'text', 'sentence', 'normalized_text' or 'transcript'. Got sample of "
32
+ ".join{sample.keys()}. Ensure a text column name is present in the dataset."
33
+ )
34
+
35
+
36
+ whisper_norm = BasicTextNormalizer()
37
+
38
+
39
+ def normalise(batch):
40
+ batch["norm_text"] = whisper_norm(get_text(batch))
41
+ return batch
42
+
43
+
44
+ def data(dataset):
45
+ for i, item in enumerate(dataset):
46
+ yield {**item["audio"], "reference": item["norm_text"]}
47
+
48
+
49
+ def main(args):
50
+ batch_size = args.batch_size
51
+ whisper_asr = pipeline(
52
+ "automatic-speech-recognition", model=args.model_id, device=args.device
53
+ )
54
+
55
+ whisper_asr.model.config.forced_decoder_ids = (
56
+ whisper_asr.tokenizer.get_decoder_prompt_ids(
57
+ language=args.language, task="transcribe"
58
+ )
59
+ )
60
+
61
+ dataset = load_dataset(
62
+ args.dataset,
63
+ args.config,
64
+ split=args.split,
65
+ streaming=args.streaming,
66
+ use_auth_token=True,
67
+ )
68
+
69
+ # Only uncomment for debugging
70
+ dataset = dataset.take(args.max_eval_samples)
71
+
72
+ dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
73
+ dataset = dataset.map(normalise)
74
+ dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
75
+
76
+ predictions = []
77
+ references = []
78
+
79
+ # run streamed inference
80
+ for out in whisper_asr(data(dataset), batch_size=batch_size):
81
+ predictions.append(whisper_norm(out["text"]))
82
+ references.append(out["reference"][0])
83
+
84
+ wer = wer_metric.compute(references=references, predictions=predictions)
85
+ wer = round(100 * wer, 2)
86
+
87
+ print("WER:", wer)
88
+
89
+
90
+ if __name__ == "__main__":
91
+ parser = argparse.ArgumentParser()
92
+
93
+ parser.add_argument(
94
+ "--model_id",
95
+ type=str,
96
+ required=True,
97
+ help="Model identifier. Should be loadable with 🤗 Transformers",
98
+ )
99
+ parser.add_argument(
100
+ "--dataset",
101
+ type=str,
102
+ default="mozilla-foundation/common_voice_11_0",
103
+ help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
104
+ )
105
+ parser.add_argument(
106
+ "--config",
107
+ type=str,
108
+ required=True,
109
+ help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
110
+ )
111
+ parser.add_argument(
112
+ "--split",
113
+ type=str,
114
+ default="test",
115
+ help="Split of the dataset. *E.g.* `'test'`",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--device",
120
+ type=int,
121
+ default=-1,
122
+ help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
123
+ )
124
+ parser.add_argument(
125
+ "--batch_size",
126
+ type=int,
127
+ default=16,
128
+ help="Number of samples to go through each streamed batch.",
129
+ )
130
+ parser.add_argument(
131
+ "--max_eval_samples",
132
+ type=int,
133
+ default=None,
134
+ help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
135
+ )
136
+ parser.add_argument(
137
+ "--streaming",
138
+ type=bool,
139
+ default=True,
140
+ help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
141
+ )
142
+ parser.add_argument(
143
+ "--language",
144
+ type=str,
145
+ required=True,
146
+ help="Two letter language code for the transcription language, e.g. use 'en' for English.",
147
+ )
148
+ args = parser.parse_args()
149
+
150
+ main(args)