pere committed on
Commit
e9502fe
1 Parent(s): c531109

updated with XLA hook

Browse files
Files changed (1) hide show
  1. run_whisper.py +139 -130
run_whisper.py CHANGED
@@ -50,138 +50,147 @@ class DataCollatorSpeechSeq2SeqWithPadding:
50
 
51
  return batch
52
 
 
 
 
 
 
53
 
54
- # Metrics
55
- def compute_metrics(pred):
56
- pred_ids = pred.predictions
57
- label_ids = pred.label_ids
58
 
59
- # replace -100 with the pad_token_id
60
- label_ids[label_ids == -100] = tokenizer.pad_token_id
 
61
 
62
- # we do not want to group tokens when computing the metrics
63
- pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
64
- label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
65
 
66
- wer = 100 * metric.compute(predictions=pred_str, references=label_str)
67
 
68
- return {"wer": wer}
69
-
70
- # Prepare dataset
71
-
72
-
73
- def prepare_dataset(batch):
74
- # load and resample audio data from 48 to 16kHz
75
- audio = batch["audio"]
76
-
77
- # compute log-Mel input features from input audio array
78
- batch["input_features"] = feature_extractor(
79
- audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
80
-
81
- # encode target text to label ids
82
- batch["labels"] = tokenizer(batch["sentence"]).input_ids
83
- return batch
84
-
85
-
86
- # Whisper Training Script
87
-
88
- # Map the source and target columns
89
- # Whisper expects these columns to be named "audio" and "sentence". Change source/target below if the dataset uses other column names
90
- source = "audio"
91
- target = "sentence"
92
-
93
-
94
- # Load a sample dataset
95
- speech_data = DatasetDict()
96
-
97
- # Examples
98
- # speech_data["train"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="train", use_auth_token=True)
99
- # speech_data["test"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="test", use_auth_token=True)
100
- # speech_data["train"] = load_dataset("NbAiLab/LIA_speech", split="train", use_auth_token=True)
101
- #speech_data["test"] = load_dataset("NbAiLab/LIA_speech", split="test", use_auth_token=True)
102
-
103
- # The smallest dataset I found
104
- speech_data["train"] = load_dataset(
105
- "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True)
106
- speech_data["test"] = load_dataset(
107
- "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True)
108
-
109
-
110
- # Rename columns
111
- if "audio" not in speech_data.column_names["train"]:
112
- speech_data = speech_data.rename_column(source, "audio")
113
-
114
- if "sentence" not in speech_data.column_names["train"]:
115
- speech_data = speech_data.rename_column(target, "sentence")
116
-
117
- # Remove columns that are not needed - not really sure if this is necessary
118
- remove_list = [i for i in speech_data.column_names["train"]
119
- if i not in ["audio", "sentence"]]
120
-
121
- speech_data = speech_data.remove_columns(remove_list)
122
-
123
- # Initialise
124
- feature_extractor = WhisperFeatureExtractor.from_pretrained(
125
- "openai/whisper-small")
126
- tokenizer = WhisperTokenizer.from_pretrained(
127
- "openai/whisper-small", language="Norwegian", task="transcribe")
128
- processor = WhisperProcessor.from_pretrained(
129
- "openai/whisper-small", language="Norwegian", task="transcribe")
130
- data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
131
-
132
- # Prepare data
133
- speech_data = speech_data.cast_column("audio", Audio(sampling_rate=16000))
134
- speech_data = speech_data.map(
135
- prepare_dataset, remove_columns=speech_data.column_names["train"], num_proc=1)
136
-
137
- # Metrics
138
- metric = evaluate.load("wer")
139
-
140
- # Initialise a Pretrained model
141
- # We need to set use_cache=False here if we want to use gradient checkpointing
142
- model = WhisperForConditionalGeneration.from_pretrained(
143
- "openai/whisper-small", use_cache=False)
144
-
145
- # Overriding generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)):
146
- model.config.forced_decoder_ids = None
147
- model.config.suppress_tokens = []
148
-
149
- # Training arguments
150
- training_args = Seq2SeqTrainingArguments(
151
- output_dir="./whisper-small-no-test", # change to a repo name of your choice
152
- # A batch size of at least 16 is reasonable. The small value here is just for the test on Ficino
153
- per_device_train_batch_size=4,
154
- gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
155
- learning_rate=1e-5,
156
- warmup_steps=500,
157
- max_steps=1000, # Changed from 4000
158
- gradient_checkpointing=True,
159
- fp16=True,
160
- group_by_length=True,
161
- evaluation_strategy="steps",
162
- per_device_eval_batch_size=8,
163
- predict_with_generate=True,
164
- generation_max_length=225,
165
- save_steps=500,
166
- eval_steps=500,
167
- logging_steps=25,
168
- report_to=["tensorboard"],
169
- load_best_model_at_end=True,
170
- metric_for_best_model="wer",
171
- greater_is_better=False,
172
- push_to_hub=True,
173
- )
174
-
175
- trainer = Seq2SeqTrainer(
176
- args=training_args,
177
- model=model,
178
- train_dataset=speech_data["train"],
179
- eval_dataset=speech_data["test"],
180
- data_collator=data_collator,
181
- compute_metrics=compute_metrics,
182
- tokenizer=processor.feature_extractor,
183
- )
184
-
185
-
186
- # Start training
187
- trainer.train()
 
 
 
 
 
 
 
 
50
 
51
  return batch
52
 
53
+ def main():
54
+ # Metrics
55
+ def compute_metrics(pred):
56
+ pred_ids = pred.predictions
57
+ label_ids = pred.label_ids
58
 
59
+ # replace -100 with the pad_token_id
60
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
 
 
61
 
62
+ # we do not want to group tokens when computing the metrics
63
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
64
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
65
 
66
+ wer = 100 * metric.compute(predictions=pred_str, references=label_str)
 
 
67
 
68
+ return {"wer": wer}
69
 
70
+ # Prepare dataset
71
+
72
+
73
+ def prepare_dataset(batch):
74
+ # load and resample audio data from 48 to 16kHz
75
+ audio = batch["audio"]
76
+
77
+ # compute log-Mel input features from input audio array
78
+ batch["input_features"] = feature_extractor(
79
+ audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
80
+
81
+ # encode target text to label ids
82
+ batch["labels"] = tokenizer(batch["sentence"]).input_ids
83
+ return batch
84
+
85
+
86
+ # Whisper Training Script
87
+
88
+ # Map the source and target columns
89
+ # Whisper expects these columns to be named "audio" and "sentence". Change source/target below if the dataset uses other column names
90
+ source = "audio"
91
+ target = "sentence"
92
+
93
+
94
+ # Load a sample dataset
95
+ speech_data = DatasetDict()
96
+
97
+ # Examples
98
+ # speech_data["train"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="train", use_auth_token=True)
99
+ # speech_data["test"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="test", use_auth_token=True)
100
+ # speech_data["train"] = load_dataset("NbAiLab/LIA_speech", split="train", use_auth_token=True)
101
+ #speech_data["test"] = load_dataset("NbAiLab/LIA_speech", split="test", use_auth_token=True)
102
+
103
+ # The smallest dataset I found
104
+ speech_data["train"] = load_dataset(
105
+ "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True)
106
+ speech_data["test"] = load_dataset(
107
+ "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True)
108
+
109
+
110
+ # Rename columns
111
+ if "audio" not in speech_data.column_names["train"]:
112
+ speech_data = speech_data.rename_column(source, "audio")
113
+
114
+ if "sentence" not in speech_data.column_names["train"]:
115
+ speech_data = speech_data.rename_column(target, "sentence")
116
+
117
+ # Remove columns that are not needed - not really sure if this is necessary
118
+ remove_list = [i for i in speech_data.column_names["train"]
119
+ if i not in ["audio", "sentence"]]
120
+
121
+ speech_data = speech_data.remove_columns(remove_list)
122
+
123
+ # Initialise
124
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
125
+ "openai/whisper-small")
126
+ tokenizer = WhisperTokenizer.from_pretrained(
127
+ "openai/whisper-small", language="Norwegian", task="transcribe")
128
+ processor = WhisperProcessor.from_pretrained(
129
+ "openai/whisper-small", language="Norwegian", task="transcribe")
130
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
131
+
132
+ # Prepare data
133
+ speech_data = speech_data.cast_column("audio", Audio(sampling_rate=16000))
134
+ speech_data = speech_data.map(
135
+ prepare_dataset, remove_columns=speech_data.column_names["train"], num_proc=1)
136
+
137
+ # Metrics
138
+ metric = evaluate.load("wer")
139
+
140
+ # Initialise a Pretrained model
141
+ # We need to set use_cache=False here if we want to use gradient checkpointing
142
+ model = WhisperForConditionalGeneration.from_pretrained(
143
+ "openai/whisper-small", use_cache=False)
144
+
145
+ # Overriding generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)):
146
+ model.config.forced_decoder_ids = None
147
+ model.config.suppress_tokens = []
148
+
149
+ # Training arguments
150
+ training_args = Seq2SeqTrainingArguments(
151
+ output_dir="../whisper-test", # change to a repo name of your choice
152
+ # A batch size of at least 16 is reasonable. The small value here is just for the test on Ficino
153
+ per_device_train_batch_size=4,
154
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
155
+ learning_rate=1e-5,
156
+ warmup_steps=500,
157
+ max_steps=1000, # Changed from 4000
158
+ gradient_checkpointing=True,
159
+ fp16=True,
160
+ group_by_length=True,
161
+ evaluation_strategy="steps",
162
+ per_device_eval_batch_size=8,
163
+ predict_with_generate=True,
164
+ generation_max_length=225,
165
+ save_steps=500,
166
+ eval_steps=500,
167
+ logging_steps=25,
168
+ report_to=["tensorboard"],
169
+ load_best_model_at_end=True,
170
+ metric_for_best_model="wer",
171
+ greater_is_better=False,
172
+ push_to_hub=True,
173
+ )
174
+
175
+ trainer = Seq2SeqTrainer(
176
+ args=training_args,
177
+ model=model,
178
+ train_dataset=speech_data["train"],
179
+ eval_dataset=speech_data["test"],
180
+ data_collator=data_collator,
181
+ compute_metrics=compute_metrics,
182
+ tokenizer=processor.feature_extractor,
183
+ )
184
+
185
+
186
+ # Start training
187
+ trainer.train()
188
+
189
+
190
+ def _mp_fn(index):
191
+ # For xla_spawn (TPUs)
192
+ main()
193
+
194
+
195
+ if __name__ == "__main__":
196
+ main()