emilios committed on
Commit
52869cd
1 Parent(s): e318e7d

Upload 2 files

Files changed (2)
  1. run_inter_1gpu.sh +38 -0
  2. run_interleave.py +696 -0
run_inter_1gpu.sh ADDED
@@ -0,0 +1,38 @@
+ python run_interleave.py --model_name_or_path="emilios/whisper-medium-el" \
+ --dataset_name="mozilla-foundation/common_voice_11_0" \
+ --dataset_config_name="el" \
+ --language="greek" \
+ --train_split_name="train+validation" \
+ --eval_split_name="test" \
+ --model_index_name="Whisper Medium El - Greek One" \
+ --max_steps="5000" \
+ --output_dir="./" \
+ --per_device_train_batch_size="12" \
+ --gradient_accumulation_steps="2" \
+ --per_device_eval_batch_size="8" \
+ --logging_steps="25" \
+ --learning_rate="1e-5" \
+ --warmup_steps="500" \
+ --evaluation_strategy="steps" \
+ --eval_steps="1000" \
+ --save_strategy="steps" \
+ --save_steps="1000" \
+ --generation_max_length="225" \
+ --length_column_name="input_length" \
+ --max_duration_in_seconds="30" \
+ --text_column_name="sentence" \
+ --freeze_feature_encoder="False" \
+ --report_to="tensorboard" \
+ --metric_for_best_model="wer" \
+ --greater_is_better="False" \
+ --load_best_model_at_end \
+ --gradient_checkpointing \
+ --fp16 \
+ --overwrite_output_dir \
+ --do_train \
+ --do_eval \
+ --predict_with_generate \
+ --do_normalize_eval \
+ --streaming \
+ --use_auth_token \
+ --push_to_hub
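+ # Note: with --per_device_train_batch_size=12 and --gradient_accumulation_steps=2, the
+ # effective batch size on this single GPU is 12 * 2 = 24 examples per optimizer step.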
run_interleave.py ADDED
@@ -0,0 +1,696 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Fine-tuning the library models for sequence to sequence speech recognition
+ with 🤗 Datasets' streaming mode.
+ """
+ # You can also adapt this script for your own sequence to sequence speech
+ # recognition task. Pointers for this are left as comments.
+
+ import logging
+ import os
+ import sys
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Union
+
+ import datasets
+ import torch
+
+ # NOTE: the distributed/multiprocessing imports below are not used anywhere in this script
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ import torch.nn.functional as F
+ from torch.nn.parallel import DistributedDataParallel
+ from tqdm import tqdm
+
+
+ from datasets import Audio, DatasetDict, IterableDataset, IterableDatasetDict, interleave_datasets, load_dataset
+
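+ # Datasets interleaved for training: Common Voice 11 (Greek) and FLEURS (Greek),
+ # together with the name of the transcript column in each.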
+ dataset_names = ["mozilla-foundation/common_voice_11_0", "google/fleurs"]
+ dataset_config_names = ["el", "el_gr"]
+ text_column_names = ["sentence", "transcription"]
+
+
+ from torch.utils.data import IterableDataset
+
+ import evaluate
+ import transformers
+ from transformers import (
+     AutoConfig,
+     AutoFeatureExtractor,
+     AutoModelForSpeechSeq2Seq,
+     AutoProcessor,
+     AutoTokenizer,
+     HfArgumentParser,
+     Seq2SeqTrainer,
+     Seq2SeqTrainingArguments,
+     TrainerCallback,
+     set_seed,
+ )
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
+ from transformers.trainer_pt_utils import IterableDatasetShard
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
+ from transformers.utils import check_min_version, send_example_telemetry
+ from transformers.utils.versions import require_version
+
+
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
+ check_min_version("4.25.0.dev0")
+
+ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+     """
+
+     model_name_or_path: str = field(
+         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+     )
+     feature_extractor_name: Optional[str] = field(
+         default=None, metadata={"help": "Feature extractor name or path if not the same as model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None,
+         metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
+     )
+     model_revision: str = field(
+         default="main",
+         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+     )
+     use_auth_token: bool = field(
+         default=False,
+         metadata={
+             "help": (
+                 "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+                 "with private models)."
+             )
+         },
+     )
+     freeze_feature_encoder: bool = field(
+         default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
+     )
+     freeze_encoder: bool = field(
+         default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
+     )
+     forced_decoder_ids: List[List[int]] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "A list of pairs of integers which indicates a mapping from generation indices to token indices "
+                 "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
+                 "will always be a token of index 123."
+             )
+         },
+     )
+     suppress_tokens: List[int] = field(
+         default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
+     )
+     model_index_name: str = field(default=None, metadata={"help": "Pretty name for the model card."})
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     dataset_name: str = field(
+         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     text_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "The name of the column in the datasets containing the full texts."},
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "For debugging purposes or quicker training, truncate the number of training examples to this "
+                 "value if set."
+             )
+         },
+     )
+     max_eval_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+                 "value if set."
+             )
+         },
+     )
+     audio_column_name: str = field(
+         default="audio",
+         metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
+     )
+     text_column_name: str = field(
+         default="text",
+         metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
+     )
+     max_duration_in_seconds: float = field(
+         default=20.0,
+         metadata={
+             "help": (
+                 "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
+                 " `max_duration_in_seconds`"
+             )
+         },
+     )
+     min_duration_in_seconds: float = field(
+         default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
+     )
+     train_split_name: str = field(
+         default="train",
+         metadata={
+             "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
+         },
+     )
+     eval_split_name: str = field(
+         default="test",
+         metadata={
+             "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
+         },
+     )
+     do_lower_case: bool = field(
+         default=False,
+         metadata={"help": "Whether the target text should be lower cased."},
+     )
+     do_remove_punctuation: bool = field(
+         default=False,
+         metadata={"help": "Whether the target text should be stripped of punctuation."},
+     )
+     do_normalize_eval: bool = field(
+         default=True,
+         metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
+     )
+     language: str = field(
+         default=None,
+         metadata={
+             "help": (
+                 "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
+                 "only. For English speech recognition, it should be set to `None`."
+             )
+         },
+     )
+     task: str = field(
+         default="transcribe",
+         metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
+     )
+     shuffle_buffer_size: Optional[int] = field(
+         default=500,
+         metadata={
+             "help": (
+                 "The number of streamed examples to download before shuffling them. The larger the buffer, "
+                 "the closer it is to real offline shuffling."
+             )
+         },
+     )
+     streaming: bool = field(
+         default=True,
+         metadata={"help": "Whether to use streaming mode to load and pre-process the data."},
+     )
+
+
+ @dataclass
+ class DataCollatorSpeechSeq2SeqWithPadding:
+     """
+     Data collator that will dynamically pad the inputs received.
+     Args:
+         processor ([`WhisperProcessor`])
+             The processor used for processing the data.
+         decoder_start_token_id (`int`)
+             The begin-of-sentence token id of the decoder.
+     """
+
+     processor: Any
+     decoder_start_token_id: int
+
+     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+         # split inputs and labels since they have to be of different lengths and need
+         # different padding methods
+         model_input_name = self.processor.model_input_names[0]
+         input_features = [{model_input_name: feature[model_input_name]} for feature in features]
+         label_features = [{"input_ids": feature["labels"]} for feature in features]
+
+         batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
+
+         labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
+
+         # replace padding with -100 to ignore loss correctly
+         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
+
+         # if bos token is appended in previous tokenization step,
+         # cut bos token here as it's appended later anyway
+         if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
+             labels = labels[:, 1:]
+
+         batch["labels"] = labels
+
+         return batch
+
+
+ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
+     """
+     Utility function to load a dataset in streaming mode. For datasets with multiple splits,
+     each split is loaded individually and then the splits are combined by taking alternating examples from
+     each (interleaving).
+     """
+     if "+" in split:
+         # load multiple splits separated by the `+` symbol with streaming mode
+         dataset_splits = [
+             load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
+             for split_name in split.split("+")
+         ]
+         # interleave multiple splits to form one dataset
+         interleaved_dataset = interleave_datasets(dataset_splits)
+         return interleaved_dataset
+     else:
+         # load a single split with the requested streaming mode
+         dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
+         return dataset
+
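+ # Illustrative usage of load_maybe_streaming_dataset (a sketch; main() below calls this helper
+ # for the eval split): a composite split such as "train+validation" is loaded as separate
+ # streaming splits and interleaved into one iterable dataset, e.g.
+ #   cv_train = load_maybe_streaming_dataset(
+ #       "mozilla-foundation/common_voice_11_0", "el", split="train+validation", use_auth_token=True
+ #   )
+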
+ def load_multiple_streaming_datasets(
+     dataset_names: List,
+     dataset_config_names: List,
+     splits: Optional[List] = None,
+     text_column_names: Optional[List] = None,
+     sampling_rate: Optional[int] = 16000,
+     stopping_strategy: Optional[str] = "all_exhausted",
+     **kwargs
+ ) -> IterableDataset:
+
+     if len(dataset_names) != len(dataset_config_names):
+         raise ValueError(
+             f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
+             f" {len(dataset_config_names)} configs."
+         )
+
+     if splits is not None and len(splits) != len(dataset_names):
+         raise ValueError(
+             f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
+         )
+
+     if text_column_names is not None and len(text_column_names) != len(dataset_names):
+         raise ValueError(
+             f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
+             f" {len(text_column_names)} text column names."
+         )
+
+     splits = splits if splits is not None else ["train" for i in range(len(dataset_names))]
+     text_column_names = (
+         text_column_names if text_column_names is not None else ["text" for i in range(len(dataset_names))]
+     )
+
+     all_datasets = []
+     # iterate over the datasets we want to interleave
+     for i, dataset_name in enumerate(dataset_names):
+         dataset = load_dataset(dataset_name, dataset_config_names[i], split=splits[i], streaming=True, **kwargs)
+         # resample to specified sampling rate
+         dataset = dataset.cast_column("audio", Audio(sampling_rate))
+         # normalise columns to ["audio", "sentence"]
+         if text_column_names[i] != "sentence":
+             dataset = dataset.rename_column(text_column_names[i], "sentence")
+         dataset = dataset.remove_columns(set(dataset.features.keys()) - set(["audio", "sentence"]))
+         all_datasets.append(dataset)
+
+     interleaved_dataset = interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
+     return interleaved_dataset
+
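+ # Note: the module-level call below builds an interleaved dataset at import time, but the
+ # resulting `ds` is not used by main(), which constructs its own training set in step 4.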
+ ds = load_multiple_streaming_datasets(dataset_names, dataset_config_names=dataset_config_names, text_column_names=text_column_names, use_auth_token=True)
+
+ def main():
+     # 1. Parse input arguments
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
+
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+     # information sent is the one passed as arguments along with your Python/PyTorch versions.
+     send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
+
+     # 2. Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+     log_level = training_args.get_process_log_level()
+     logger.setLevel(log_level)
+     datasets.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.enable_default_handler()
+     transformers.utils.logging.enable_explicit_format()
+
+     logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+
+     # Log on each process the small summary:
+     logger.warning(
+         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+         f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+     )
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     if is_main_process(training_args.local_rank):
+         transformers.utils.logging.set_verbosity_info()
+     logger.info("Training/evaluation parameters %s", training_args)
+
+     # 3. Detecting last checkpoint and eventually continue from last checkpoint
+     last_checkpoint = None
+     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+         last_checkpoint = get_last_checkpoint(training_args.output_dir)
+         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+             raise ValueError(
+                 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                 "Use --overwrite_output_dir to overcome."
+             )
+         elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+             logger.info(
+                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+             )
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     # 4. Load dataset
+     raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
+     # """
+     # if training_args.do_train:
+     #     raw_datasets["train"] = load_maybe_streaming_dataset(
+     #         data_args.dataset_name,
+     #         data_args.dataset_config_name,
+     #         split=data_args.train_split_name,
+     #         use_auth_token=True if model_args.use_auth_token else None,
+     #         streaming=data_args.streaming,
+     #     )
+     # """
+     if training_args.do_train:
+         raw_datasets["train"] = load_multiple_streaming_datasets(dataset_names, dataset_config_names=dataset_config_names, text_column_names=text_column_names, use_auth_token=True)
+
+     if training_args.do_eval:
+         raw_datasets["eval"] = load_maybe_streaming_dataset(
+             data_args.dataset_name,
+             data_args.dataset_config_name,
+             split=data_args.eval_split_name,
+             use_auth_token=True if model_args.use_auth_token else None,
+             streaming=data_args.streaming,
+         )
+
+     raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
+
+     if data_args.audio_column_name not in raw_datasets_features:
+         raise ValueError(
+             f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
+             "Make sure to set `--audio_column_name` to the correct audio column - one of "
+             f"{', '.join(raw_datasets_features)}."
+         )
+
+     if data_args.text_column_name not in raw_datasets_features:
+         raise ValueError(
+             f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
+             "Make sure to set `--text_column_name` to the correct text column - one of "
+             f"{', '.join(raw_datasets_features)}."
+         )
+
+     # 5. Load pretrained model, tokenizer, and feature extractor
+     #
+     # Distributed training:
+     # The .from_pretrained methods guarantee that only one local process can concurrently download model & vocab.
+     config = AutoConfig.from_pretrained(
+         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+
+     config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
+
+     if training_args.gradient_checkpointing:
+         config.update({"use_cache": False})
+
+     feature_extractor = AutoFeatureExtractor.from_pretrained(
+         model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+         cache_dir=model_args.cache_dir,
+         use_fast=model_args.use_fast_tokenizer,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+     model = AutoModelForSpeechSeq2Seq.from_pretrained(
+         model_args.model_name_or_path,
+         config=config,
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+
+     if model.config.decoder_start_token_id is None:
+         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
+
+     if model_args.freeze_feature_encoder:
+         model.freeze_feature_encoder()
+
+     if model_args.freeze_encoder:
+         model.freeze_encoder()
+
+     if data_args.language is not None:
+         # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+         tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
+
+     # 6. Resample speech dataset if necessary
+     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
+     if dataset_sampling_rate != feature_extractor.sampling_rate:
+         raw_datasets = raw_datasets.cast_column(
+             data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
+         )
+
+     # 7. Preprocessing the datasets.
+     # We need to read the audio files as arrays and tokenize the targets.
+     max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
+     min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
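+     # For example, with the accompanying run_inter_1gpu.sh (--max_duration_in_seconds=30) and
+     # Whisper's 16 kHz feature extractor, max_input_length = 30 * 16000 = 480000 samples; the
+     # `input_length` column computed below is measured in raw audio samples.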
+     audio_column_name = data_args.audio_column_name
+     text_column_name = data_args.text_column_name
+     model_input_name = feature_extractor.model_input_names[0]
+     do_lower_case = data_args.do_lower_case
+     do_remove_punctuation = data_args.do_remove_punctuation
+     normalizer = BasicTextNormalizer()  # 'official' text normalizer from OpenAI
+
+     if data_args.max_train_samples is not None:
+         raw_datasets["train"] = (
+             raw_datasets["train"].take(data_args.max_train_samples)
+             if data_args.streaming
+             else raw_datasets["train"].select(range(data_args.max_train_samples))
+         )
+
+     if data_args.max_eval_samples is not None:
+         raw_datasets["eval"] = (
+             raw_datasets["eval"].take(data_args.max_eval_samples)
+             if data_args.streaming
+             else raw_datasets["eval"].select(range(data_args.max_eval_samples))
+         )
+
+     def prepare_dataset(batch):
+         # process audio
+         sample = batch[audio_column_name]
+         inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
+         # process audio length
+         batch[model_input_name] = inputs.get(model_input_name)[0]
+         batch["input_length"] = len(sample["array"])
+
+         # process targets
+         input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
+         if do_remove_punctuation:
+             input_str = normalizer(input_str).strip()
+         batch["labels"] = tokenizer(input_str).input_ids
+         return batch
+
+     with training_args.main_process_first(desc="dataset map pre-processing"):
+         vectorized_datasets = raw_datasets.map(
+             prepare_dataset,
+             remove_columns=raw_datasets_features,
+         ).with_format("torch")
+
+     if training_args.do_train and data_args.streaming:
+         # manually shuffle if streaming (done by the trainer for non-streaming)
+         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
+             buffer_size=data_args.shuffle_buffer_size,
+             seed=training_args.seed,
+         )
+
+     # filter training data that is shorter than min_input_length or longer than
+     # max_input_length
+     def is_audio_in_length_range(length):
+         return min_input_length < length < max_input_length
+
+     if training_args.do_train:
+         vectorized_datasets["train"] = vectorized_datasets["train"].filter(
+             is_audio_in_length_range,
+             input_columns=["input_length"],
+         )
+
+     # 8. Load Metric
+     metric = evaluate.load("wer")
+     do_normalize_eval = data_args.do_normalize_eval
+
+     def compute_metrics(pred):
+         pred_ids = pred.predictions
+
+         pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
+
+         pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+         # we do not want to group tokens when computing the metrics
+         label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
+
+         if do_normalize_eval:
+             pred_str = [normalizer(pred) for pred in pred_str]
+             label_str = [normalizer(label) for label in label_str]
+             # filtering step to only evaluate the samples that correspond to non-empty references:
+             pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
+             label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
+
+         wer = 100 * metric.compute(predictions=pred_str, references=label_str)
+
+         return {"wer": wer}
+
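+     # For reference: the WER returned above is a percentage, i.e.
+     # 100 * (substitutions + insertions + deletions) / number of reference words.
+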
+     # 9. Create a single speech processor
+     if is_main_process(training_args.local_rank):
+         # save feature extractor, tokenizer and config
+         feature_extractor.save_pretrained(training_args.output_dir)
+         tokenizer.save_pretrained(training_args.output_dir)
+         config.save_pretrained(training_args.output_dir)
+
+     processor = AutoProcessor.from_pretrained(training_args.output_dir)
+
+     # 10. Define data collator
+     data_collator = DataCollatorSpeechSeq2SeqWithPadding(
+         processor=processor,
+         decoder_start_token_id=model.config.decoder_start_token_id,
+     )
+
+     # 11. Configure Trainer
+     # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
+     # Only required for streaming: Trainer automatically shuffles non-streaming datasets
+     class ShuffleCallback(TrainerCallback):
+         def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
+             if isinstance(train_dataloader.dataset, IterableDatasetShard):
+                 pass  # set_epoch() is handled by the Trainer
+             elif isinstance(train_dataloader.dataset, IterableDataset):
+                 train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
+
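+     # set_epoch() above changes the effective seed of the streaming shuffle buffer so that
+     # each new epoch iterates over the training data in a different order.
+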
+     # Initialize Trainer
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
+         eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
+         tokenizer=feature_extractor,
+         data_collator=data_collator,
+         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
+         callbacks=[ShuffleCallback()] if data_args.streaming else None,
+     )
+
+     # 12. Training
+     if training_args.do_train:
+         checkpoint = None
+         if training_args.resume_from_checkpoint is not None:
+             checkpoint = training_args.resume_from_checkpoint
+         elif last_checkpoint is not None:
+             checkpoint = last_checkpoint
+         train_result = trainer.train(resume_from_checkpoint=checkpoint)
+         trainer.save_model()  # Saves the feature extractor too for easy upload
+
+         metrics = train_result.metrics
+         if data_args.max_train_samples:
+             metrics["train_samples"] = data_args.max_train_samples
+         trainer.log_metrics("train", metrics)
+         trainer.save_metrics("train", metrics)
+         trainer.save_state()
+
+     # 13. Evaluation
+     results = {}
+     if training_args.do_eval:
+         logger.info("*** Evaluate ***")
+         metrics = trainer.evaluate(
+             metric_key_prefix="eval",
+             max_length=training_args.generation_max_length,
+             num_beams=training_args.generation_num_beams,
+         )
+         if data_args.max_eval_samples:
+             metrics["eval_samples"] = data_args.max_eval_samples
+
+         trainer.log_metrics("eval", metrics)
+         trainer.save_metrics("eval", metrics)
+
+     # 14. Write Training Stats
+     kwargs = {
+         "finetuned_from": model_args.model_name_or_path,
+         "tasks": "automatic-speech-recognition",
+         "tags": "whisper-event",
+     }
+     if data_args.dataset_name is not None:
+         kwargs["dataset_tags"] = data_args.dataset_name
+         if data_args.dataset_config_name is not None:
+             kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
+         else:
+             kwargs["dataset"] = data_args.dataset_name
+         if "common_voice" in data_args.dataset_name:
+             kwargs["language"] = data_args.dataset_config_name[:2]
+         if model_args.model_index_name is not None:
+             kwargs["model_name"] = model_args.model_index_name
+
+     if training_args.push_to_hub:
+         trainer.push_to_hub(**kwargs)
+     else:
+         trainer.create_model_card(**kwargs)
+
+     return results
+
+
+ if __name__ == "__main__":
+     main()