marinone94 committed on
Commit 841dcc3
1 Parent(s): 12706a1

update model card

Files changed (2)
  1. README.md +19 -9
  2. run_eval_whisper_streaming.py +164 -0
README.md CHANGED
@@ -8,11 +8,15 @@ tags:
 - whisper-event
 - generated_from_trainer
 datasets:
-- mozilla-foundation/common_voice_11_0
+- mozilla-foundation/common_voice_11_0 (sv-SE)
+- mozilla-foundation/common_voice_11_0 (da)
+- mozilla-foundation/common_voice_11_0 (nn-NO)
 - babelbox/babelbox_voice
 - NbAiLab/NST
 - NbAiLab/NPSC
-- google/fleurs
+- google/fleurs (sv_se)
+- google/fleurs (da_dk)
+- google/fleurs (nb_no)
 metrics:
 - wer
 model-index:
@@ -53,15 +57,9 @@ model-index:
     metrics:
     - name: Wer
       type: wer
-      value: 37.02
-
-
-
+      value: 37.02
 ---
 
-<!-- This model card has been generated automatically according to the information the Trainer had access to. You
-should probably proofread and complete it, then remove this comment. -->
-
 # Whisper Medium Nordic
 
 This model is a fine-tuned version of [openai/whisper-medium](https://huggingface.co/openai/whisper-medium) on the [mozilla-foundation/common_voice_11_0](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) (sv-SE, da, nn-NO), the [babelbox/babelbox_voice](https://huggingface.co/datasets/babelbox/babelbox_voice) (Swedish radio), the [NbAiLab/NST](https://huggingface.co/datasets/NbAiLab/NST) (Norwegian radio), the [NbAiLab/NPSC](https://huggingface.co/datasets/NbAiLab/NPSC) (Norwegian parliament) and the [google/fleurs](https://huggingface.co/datasets/google/fleurs) (sv_se, da_dk, nb_no) datasets. The goal is to leverage transfer learning across Nordic languages, which have strong similarities.
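
The mixing code itself is not part of this commit; as a rough sketch (column names and the interleaving strategy are assumptions, not taken from this repo), combining two of the Swedish sources with 🤗 Datasets could look like:

```python
from datasets import interleave_datasets, load_dataset

# Illustrative only: two of the Swedish sources named in the card; the other
# languages would follow the same pattern. Streaming avoids a full download.
cv_sv = load_dataset(
    "mozilla-foundation/common_voice_11_0", "sv-SE",
    split="train", streaming=True, use_auth_token=True,
)
fleurs_sv = load_dataset("google/fleurs", "sv_se", split="train", streaming=True)

# Transcript columns differ per dataset (Common Voice: "sentence",
# FLEURS: "transcription"); the eval script's get_text() handles the same mismatch.
cv_sv = cv_sv.rename_column("sentence", "text")
fleurs_sv = fleurs_sv.rename_column("transcription", "text")

# Alternate between sources until the smaller one runs out.
train = interleave_datasets([cv_sv, fleurs_sv], stopping_strategy="first_exhausted")
```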
@@ -122,3 +120,15 @@ The following hyperparameters were used during training:
 - Pytorch 1.13.1+cu117
 - Datasets 2.7.1.dev0
 - Tokenizers 0.13.2
+
+### WandB run
+https://wandb.ai/pn-aa/whisper/runs/xc70fbwv?workspace=user-emilio_marinone
+
+### Baseline model
+This model fine-tunes whisper-medium, and we can observe improvements when evaluating on the Common Voice 11 Swedish (sv-SE), Danish (da), and Norwegian (nn-NO) test splits.
+
+| Language | Whisper Medium (WER) | Whisper Medium Nordic (WER) |
+|:--------:|:--------------------:|:---------------------------:|
+|  sv-SE   |        14.93         |            11.31            |
+|    da    |        20.85         |            14.86            |
+|  nn-NO   |        50.82         |            37.02            |
run_eval_whisper_streaming.py ADDED
@@ -0,0 +1,164 @@
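+# Example invocation (hypothetical values; the model id is an assumption,
+# not stated in this file):
+#   python run_eval_whisper_streaming.py \
+#       --model_id marinone94/whisper-medium-nordic \
+#       --dataset mozilla-foundation/common_voice_11_0 \
+#       --config sv-SE --split test --device 0 --language sv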
+import argparse
+
+from transformers import pipeline
+from transformers.models.whisper.english_normalizer import BasicTextNormalizer
+from datasets import load_dataset, Audio
+import evaluate
+
+wer_metric = evaluate.load("wer")
+
+
+def is_target_text_in_range(ref):
+    if ref.strip() == "ignore time segment in scoring":
+        return False
+    else:
+        return ref.strip() != ""
+
+
+def get_text(sample):
+    if "text" in sample:
+        return sample["text"]
+    elif "sentence" in sample:
+        return sample["sentence"]
+    elif "normalized_text" in sample:
+        return sample["normalized_text"]
+    elif "transcript" in sample:
+        return sample["transcript"]
+    elif "transcription" in sample:
+        return sample["transcription"]
+    else:
+        raise ValueError(
+            "Expected transcript column of either 'text', 'sentence', 'normalized_text', "
+            f"'transcript' or 'transcription'. Got sample with keys: {', '.join(sample.keys())}. "
+            "Ensure a text column name is present in the dataset."
+        )
+
+
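+# Whisper's basic normaliser (lower-casing, punctuation stripping) is applied
+# to references and predictions alike, so the WER is case/punctuation-insensitive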
+whisper_norm = BasicTextNormalizer()
+
+
+def normalise(batch):
+    batch["norm_text"] = whisper_norm(get_text(batch))
+    return batch
+
+
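+# Generator feeding the pipeline: yields the raw audio dict (array, sampling_rate)
+# plus the normalised reference, which the pipeline passes through to its output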
+def data(dataset):
+    for item in dataset:
+        yield {**item["audio"], "reference": item["norm_text"]}
+
+
+def main(args):
+    batch_size = args.batch_size
+    whisper_asr = pipeline(
+        "automatic-speech-recognition", model=args.model_id, device=args.device
+    )
+    print("pipe loaded")
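+    # Force the decoder to transcribe in the requested language
+    # instead of auto-detecting it from the audio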
+    whisper_asr.model.config.forced_decoder_ids = (
+        whisper_asr.tokenizer.get_decoder_prompt_ids(
+            language=args.language, task="transcribe"
+        )
+    )
+
+    dataset = load_dataset(
+        args.dataset,
+        args.config,
+        split=args.split,
+        streaming=args.streaming,
+        use_auth_token=True,
+    )
+    print("ds init")
+
+    # Limit the number of evaluated samples, e.g. when debugging this script
+    if args.max_eval_samples:
+        dataset = dataset.take(args.max_eval_samples)
+
+    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+    dataset = dataset.map(normalise)
+    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
+
+    predictions = []
+    references = []
+    print("starting eval...")
+    # run streamed inference
+    for out in whisper_asr(data(dataset), batch_size=batch_size):
+        predictions.append(whisper_norm(out["text"]))
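+        # pass-through keys come back wrapped in a list, hence [0]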
+        references.append(out["reference"][0])
+    print("computing wer")
+    wer = wer_metric.compute(references=references, predictions=predictions)
+    wer = round(100 * wer, 2)
+
+    print("WER:", wer)
+    print("pushing metric to hub")
+    evaluate.push_to_hub(
+        model_id=args.model_id,
+        metric_value=wer,
+        metric_type="wer",
+        metric_name="WER",
+        dataset_name=args.dataset,
+        dataset_type=args.dataset,
+        dataset_split=args.split,
+        dataset_config=args.config,
+        task_type="automatic-speech-recognition",
+        task_name="Automatic Speech Recognition",
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--model_id",
+        type=str,
+        required=True,
+        help="Model identifier. Should be loadable with 🤗 Transformers",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="mozilla-foundation/common_voice_11_0",
+        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
+    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        default="test",
+        help="Split of the dataset. *E.g.* `'test'`",
+    )
+
+    parser.add_argument(
+        "--device",
+        type=int,
+        default=-1,
+        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=16,
+        help="Number of samples to go through each streamed batch.",
+    )
+    parser.add_argument(
+        "--max_eval_samples",
+        type=int,
+        default=None,
+        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
+    )
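+    # note: argparse's type=bool treats any non-empty string as True,
+    # so passing "--streaming False" still enables streaming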
+    parser.add_argument(
+        "--streaming",
+        type=bool,
+        default=True,
+        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
+    )
+    parser.add_argument(
+        "--language",
+        type=str,
+        required=True,
+        help="Two letter language code for the transcription language, e.g. use 'en' for English.",
+    )
+    args = parser.parse_args()
+
+    main(args)