patrickvonplaten commited on
Commit
9e495b5
1 Parent(s): a228ddf
Files changed (1) hide show
  1. run_speech_recognition_ctc.py +633 -0
run_speech_recognition_ctc.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition"""
17
+
18
+ import functools
19
+ import json
20
+ import logging
21
+ import os
22
+ import re
23
+ import sys
24
+ import bitsandbytes as bnb
25
+ from dataclasses import dataclass, field
26
+ from typing import Dict, List, Optional, Union
27
+
28
+ import datasets
29
+ import numpy as np
30
+ import torch
31
+ from datasets import DatasetDict, load_dataset, load_metric
32
+
33
+ import transformers
34
+ from transformers import (
35
+ AutoConfig,
36
+ AutoFeatureExtractor,
37
+ AutoModelForCTC,
38
+ AutoTokenizer,
39
+ HfArgumentParser,
40
+ Trainer,
41
+ TrainingArguments,
42
+ Wav2Vec2Processor,
43
+ set_seed,
44
+ )
45
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
46
+ from transformers.utils import check_min_version
47
+ from transformers.utils.versions import require_version
48
+
49
+
50
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
51
+ check_min_version("4.13.0.dev0")
52
+
53
+ require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
54
+
55
+
56
+ logger = logging.getLogger(__name__)
57
+
58
+
59
+ def list_field(default=None, metadata=None):
60
+ return field(default_factory=lambda: default, metadata=metadata)
61
+
62
+
63
+ @dataclass
64
+ class ModelArguments:
65
+ """
66
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
67
+ """
68
+
69
+ model_name_or_path: str = field(
70
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
71
+ )
72
+ cache_dir: Optional[str] = field(
73
+ default=None,
74
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
75
+ )
76
+ freeze_feature_extractor: Optional[bool] = field(
77
+ default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
78
+ )
79
+ attention_dropout: Optional[float] = field(
80
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
81
+ )
82
+ activation_dropout: Optional[float] = field(
83
+ default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
84
+ )
85
+ feat_proj_dropout: Optional[float] = field(
86
+ default=0.0, metadata={"help": "The dropout ratio for the projected features."}
87
+ )
88
+ hidden_dropout: Optional[float] = field(
89
+ default=0.0,
90
+ metadata={
91
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
92
+ },
93
+ )
94
+ final_dropout: Optional[float] = field(
95
+ default=0.0,
96
+ metadata={"help": "The dropout probability for the final projection layer."},
97
+ )
98
+ mask_time_prob: Optional[float] = field(
99
+ default=0.05,
100
+ metadata={
101
+ "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
102
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
103
+ "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
104
+ },
105
+ )
106
+ layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
107
+ ctc_loss_reduction: Optional[str] = field(
108
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
109
+ )
110
+
111
+
112
+ @dataclass
113
+ class DataTrainingArguments:
114
+ """
115
+ Arguments pertaining to what data we are going to input our model for training and eval.
116
+
117
+ Using `HfArgumentParser` we can turn this class
118
+ into argparse arguments to be able to specify them on
119
+ the command line.
120
+ """
121
+
122
+ dataset_name: str = field(
123
+ metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
124
+ )
125
+ dataset_config_name: Optional[str] = field(
126
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
127
+ )
128
+ train_split_name: Optional[str] = field(
129
+ default="train+validation",
130
+ metadata={
131
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
132
+ },
133
+ )
134
+ eval_split_name: Optional[str] = field(
135
+ default="test",
136
+ metadata={
137
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
138
+ },
139
+ )
140
+ audio_column_name: Optional[str] = field(
141
+ default="audio",
142
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
143
+ )
144
+ text_column_name: Optional[str] = field(
145
+ default="text",
146
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
147
+ )
148
+ overwrite_cache: bool = field(
149
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
150
+ )
151
+ preprocessing_num_workers: Optional[int] = field(
152
+ default=None,
153
+ metadata={"help": "The number of processes to use for the preprocessing."},
154
+ )
155
+ max_train_samples: Optional[int] = field(
156
+ default=None,
157
+ metadata={
158
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
159
+ "value if set."
160
+ },
161
+ )
162
+ max_eval_samples: Optional[int] = field(
163
+ default=None,
164
+ metadata={
165
+ "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
166
+ "value if set."
167
+ },
168
+ )
169
+ chars_to_ignore: Optional[List[str]] = list_field(
170
+ default=None,
171
+ metadata={"help": "A list of characters to remove from the transcripts."},
172
+ )
173
+ max_duration_in_seconds: Optional[float] = field(
174
+ default=20.0,
175
+ metadata={
176
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
177
+ },
178
+ )
179
+ min_duration_in_seconds: Optional[float] = field(
180
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
181
+ )
182
+ preprocessing_only: Optional[bool] = field(
183
+ default=False,
184
+ metadata={
185
+ "help": "Whether to only do data preprocessing and skip training. "
186
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
187
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
188
+ "so that the cached datasets can consequently be loaded in distributed training"
189
+ },
190
+ )
191
+ use_auth_token: Optional[bool] = field(
192
+ default=False,
193
+ metadata={
194
+ "help": "If :obj:`True`, will use the token generated when running"
195
+ ":obj:`transformers-cli logiin as HTTP bearer authorization for remote files."
196
+ },
197
+ )
198
+
199
+
200
+ @dataclass
201
+ class DataCollatorCTCWithPadding:
202
+ """
203
+ Data collator that will dynamically pad the inputs received.
204
+ Args:
205
+ processor (:class:`~transformers.Wav2Vec2Processor`)
206
+ The processor used for proccessing the data.
207
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
208
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
209
+ among:
210
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
211
+ sequence if provided).
212
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
213
+ maximum acceptable input length for the model if that argument is not provided.
214
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
215
+ different lengths).
216
+ max_length (:obj:`int`, `optional`):
217
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
218
+ max_length_labels (:obj:`int`, `optional`):
219
+ Maximum length of the ``labels`` returned list and optionally padding length (see above).
220
+ pad_to_multiple_of (:obj:`int`, `optional`):
221
+ If set will pad the sequence to a multiple of the provided value.
222
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
223
+ 7.5 (Volta).
224
+ """
225
+
226
+ processor: Wav2Vec2Processor
227
+ padding: Union[bool, str] = "longest"
228
+ pad_to_multiple_of: Optional[int] = None
229
+ pad_to_multiple_of_labels: Optional[int] = None
230
+
231
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
232
+ # split inputs and labels since they have to be of different lenghts and need
233
+ # different padding methods
234
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
235
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
236
+
237
+ batch = self.processor.pad(
238
+ input_features,
239
+ padding=self.padding,
240
+ pad_to_multiple_of=self.pad_to_multiple_of,
241
+ return_tensors="pt",
242
+ )
243
+
244
+ with self.processor.as_target_processor():
245
+ labels_batch = self.processor.pad(
246
+ label_features,
247
+ padding=self.padding,
248
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
249
+ return_tensors="pt",
250
+ )
251
+
252
+ # replace padding with -100 to ignore loss correctly
253
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
254
+
255
+ batch["labels"] = labels
256
+
257
+ return batch
258
+
259
+
260
+ def create_vocabulary_from_data(datasets: DatasetDict):
261
+ # Given training and test labels create vocabulary
262
+ def extract_all_chars(batch):
263
+ all_text = " ".join(batch["target_text"])
264
+ vocab = list(set(all_text))
265
+ return {"vocab": [vocab], "all_text": [all_text]}
266
+
267
+ vocabs = datasets.map(
268
+ extract_all_chars,
269
+ batched=True,
270
+ batch_size=-1,
271
+ keep_in_memory=True,
272
+ remove_columns=datasets["train"].column_names,
273
+ )
274
+
275
+ # take union of all unique characters in each dataset
276
+ vocab_set = functools.reduce(
277
+ lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
278
+ )
279
+
280
+ vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
281
+
282
+ # replace white space with delimiter token
283
+ vocab_dict["|"] = vocab_dict[" "]
284
+ del vocab_dict[" "]
285
+
286
+ # add unk and pad token
287
+ vocab_dict["[UNK]"] = len(vocab_dict)
288
+ vocab_dict["[PAD]"] = len(vocab_dict)
289
+
290
+ return vocab_dict
291
+
292
+
293
+ def main():
294
+ # See all possible arguments in src/transformers/training_args.py
295
+ # or by passing the --help flag to this script.
296
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
297
+
298
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
299
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
300
+ # If we pass only one argument to the script and it's the path to a json file,
301
+ # let's parse it to get our arguments.
302
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
303
+ else:
304
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
305
+
306
+ # Detecting last checkpoint.
307
+ last_checkpoint = None
308
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
309
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
310
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
311
+ raise ValueError(
312
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
313
+ "Use --overwrite_output_dir to overcome."
314
+ )
315
+ elif last_checkpoint is not None:
316
+ logger.info(
317
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
318
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
319
+ )
320
+
321
+ # Setup logging
322
+ logging.basicConfig(
323
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
324
+ datefmt="%m/%d/%Y %H:%M:%S",
325
+ handlers=[logging.StreamHandler(sys.stdout)],
326
+ )
327
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
328
+
329
+ # Log on each process the small summary:
330
+ logger.warning(
331
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
332
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
333
+ )
334
+ # Set the verbosity to info of the Transformers logger (on main process only):
335
+ if is_main_process(training_args.local_rank):
336
+ transformers.utils.logging.set_verbosity_info()
337
+ logger.info("Training/evaluation parameters %s", training_args)
338
+
339
+ # Set seed before initializing model.
340
+ set_seed(training_args.seed)
341
+
342
+ # 1. First, let's load the dataset
343
+ raw_datasets = DatasetDict()
344
+ raw_datasets["train"] = load_dataset(
345
+ data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
346
+ )
347
+ raw_datasets["eval"] = load_dataset(
348
+ data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name
349
+ )
350
+
351
+ if data_args.audio_column_name not in raw_datasets["train"].column_names:
352
+ raise ValueError(
353
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
354
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
355
+ f"{', '.join(raw_datasets['train'].column_names)}."
356
+ )
357
+
358
+ if data_args.text_column_name not in raw_datasets["train"].column_names:
359
+ raise ValueError(
360
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
361
+ "Make sure to set `--text_column_name` to the correct text column - one of "
362
+ f"{', '.join(raw_datasets['train'].column_names)}."
363
+ )
364
+
365
+ # prepare dataset
366
+ if data_args.max_train_samples is not None:
367
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
368
+
369
+ if data_args.max_eval_samples is not None:
370
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
371
+
372
+ # 2. We remove some special characters from the datasets
373
+ # that make training complicated and do not help in transcribing the speech
374
+ # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
375
+ # that could be easily picked up by the model
376
+
377
+ chars_to_ignore_regex = (
378
+ f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
379
+ )
380
+
381
+ def remove_special_characters(batch):
382
+ if chars_to_ignore_regex is not None:
383
+ batch["target_text"] = re.sub(chars_to_ignore_regex, "", batch[data_args.text_column_name]).lower() + " "
384
+ else:
385
+ batch["target_text"] = batch[data_args.text_column_name].lower() + " "
386
+ return batch
387
+
388
+ with training_args.main_process_first(desc="dataset map special characters removal"):
389
+ raw_datasets = raw_datasets.map(
390
+ remove_special_characters,
391
+ remove_columns=[data_args.text_column_name],
392
+ desc="remove special characters from datasets",
393
+ )
394
+
395
+ # 3. Next, we create the vocabulary of the model by extracting all unique characters from
396
+ # the training and evaluation datasets
397
+ # We need to make sure that only first rank saves vocabulary
398
+ # make sure all processes wait until vocab is created
399
+ vocab_file = os.path.join(training_args.output_dir, "vocab.json")
400
+
401
+ with training_args.main_process_first():
402
+ if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
403
+ os.remove(vocab_file)
404
+
405
+ with training_args.main_process_first(desc="dataset map vocabulary creation"):
406
+ if not os.path.isfile(vocab_file):
407
+ os.makedirs(training_args.output_dir, exist_ok=True)
408
+ vocab_dict = create_vocabulary_from_data(raw_datasets)
409
+
410
+ # save vocab dict to be loaded into tokenizer
411
+ with open(vocab_file, "w") as file:
412
+ json.dump(vocab_dict, file)
413
+
414
+ # 4. Now we can instantiate the configuration, feature extractor, tokenizer and model
415
+ # Note for distributed training, the .from_pretrained methods guarantee that only
416
+ # one local process can concurrently download model & vocab.
417
+
418
+ # load config
419
+ config = AutoConfig.from_pretrained(
420
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
421
+ )
422
+
423
+ # tokenizer is defined by `tokenizer_class` if present in config else by `model_type`
424
+ config_for_tokenizer = config if config.tokenizer_class is not None else None
425
+ tokenizer_type = config.model_type if config.tokenizer_class is None else None
426
+
427
+ # load feature_extractor, tokenizer and create processor
428
+ tokenizer = AutoTokenizer.from_pretrained(
429
+ training_args.output_dir,
430
+ config=config_for_tokenizer,
431
+ tokenizer_type=tokenizer_type,
432
+ unk_token="[UNK]",
433
+ pad_token="[PAD]",
434
+ word_delimiter_token="|",
435
+ use_auth_token=data_args.use_auth_token,
436
+ )
437
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
438
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
439
+ )
440
+ processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
441
+
442
+ # adapt config
443
+ config.update(
444
+ {
445
+ "feat_proj_dropout": model_args.feat_proj_dropout,
446
+ "attention_dropout": model_args.attention_dropout,
447
+ "hidden_dropout": model_args.hidden_dropout,
448
+ "final_dropout": model_args.final_dropout,
449
+ "mask_time_prob": model_args.mask_time_prob,
450
+ "gradient_checkpointing": training_args.gradient_checkpointing,
451
+ "layerdrop": model_args.layerdrop,
452
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
453
+ "pad_token_id": processor.tokenizer.pad_token_id,
454
+ "vocab_size": len(processor.tokenizer),
455
+ "activation_dropout": model_args.activation_dropout,
456
+ }
457
+ )
458
+
459
+ # create model
460
+ model = AutoModelForCTC.from_pretrained(
461
+ model_args.model_name_or_path,
462
+ cache_dir=model_args.cache_dir,
463
+ config=config,
464
+ use_auth_token=data_args.use_auth_token,
465
+ )
466
+
467
+ # freeze encoder
468
+ if model_args.freeze_feature_extractor:
469
+ model.freeze_feature_extractor()
470
+
471
+ # 5. Now we preprocess the datasets including loading the audio, resampling and normalization
472
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
473
+ # so that we just need to set the correct target sampling rate and normalize the input
474
+ # via the `feature_extractor`
475
+
476
+ # make sure that dataset decodes audio with correct sampling rate
477
+ raw_datasets = raw_datasets.cast_column(
478
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
479
+ )
480
+
481
+ # derive max & min input length for sample rate & max duration
482
+ max_input_length = data_args.max_duration_in_seconds * processor.feature_extractor.sampling_rate
483
+ min_input_length = data_args.min_duration_in_seconds * processor.feature_extractor.sampling_rate
484
+
485
+ # Preprocessing the datasets.
486
+ # We need to read the audio files as arrays and tokenize the targets.
487
+ def prepare_dataset(batch):
488
+ # load audio
489
+ sample = batch[data_args.audio_column_name]
490
+
491
+ batch["input_values"] = processor(
492
+ sample["array"], sampling_rate=sample["sampling_rate"], truncate=True, max_length=max_input_length
493
+ ).input_values[0]
494
+ batch["input_length"] = len(batch["input_values"])
495
+
496
+ # Setup the processor for targets
497
+ with processor.as_target_processor():
498
+ batch["labels"] = processor(batch["target_text"]).input_ids
499
+ return batch
500
+
501
+ with training_args.main_process_first(desc="dataset map preprocessing"):
502
+ vectorized_datasets = raw_datasets.map(
503
+ prepare_dataset,
504
+ remove_columns=raw_datasets["train"].column_names,
505
+ num_proc=data_args.preprocessing_num_workers,
506
+ desc="preprocess datasets",
507
+ )
508
+
509
+ if min_input_length > 0.0:
510
+ # filter data that is shorter than min_input_length
511
+ vectorized_datasets = vectorized_datasets.filter(
512
+ lambda x: x > min_input_length,
513
+ num_proc=data_args.preprocessing_num_workers,
514
+ input_columns=["input_length"],
515
+ )
516
+
517
+ vectorized_datasets = vectorized_datasets.remove_columns("input_length")
518
+
519
+ # 6. Next, we can prepare the training.
520
+ # Let's use word error rate (WER) as our evaluation metric,
521
+ # instantiate a data collator and the trainer
522
+
523
+ # Define Metric during training
524
+ wer_metric = load_metric("wer")
525
+
526
+ # for large datasets it is advised to run the preprocessing on a
527
+ # single machine first with ``args.preprocessing_only`` since there will mostly likely
528
+ # be a timeout when running the script in distributed mode.
529
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
530
+ # cached dataset
531
+ if data_args.preprocessing_only:
532
+ logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
533
+ return
534
+
535
+ def compute_metrics(pred):
536
+ pred_logits = pred.predictions
537
+ pred_ids = np.argmax(pred_logits, axis=-1)
538
+
539
+ pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
540
+
541
+ pred_str = processor.batch_decode(pred_ids)
542
+ # we do not want to group tokens when computing the metrics
543
+ label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
544
+
545
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
546
+
547
+ return {"wer": wer}
548
+
549
+ # Instantiate custom data collator
550
+ data_collator = DataCollatorCTCWithPadding(processor=processor)
551
+
552
+ # create Adam8bit optimizer
553
+ optimizer = bnb.optim.Adam8bit(model.parameters(), lr=training_args.learning_rate, betas=(training_args.adam_beta1, training_args.adam_beta2))
554
+
555
+ # Initialize Trainer
556
+ trainer = Trainer(
557
+ model=model,
558
+ data_collator=data_collator,
559
+ args=training_args,
560
+ compute_metrics=compute_metrics,
561
+ train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
562
+ eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
563
+ tokenizer=processor.feature_extractor,
564
+ optimizers=(optimizer, None), # None is replaced by default learning rate schedule
565
+ )
566
+
567
+ # 7. Finally, we can start training
568
+
569
+ # Training
570
+ if training_args.do_train:
571
+
572
+ # use last checkpoint if exist
573
+ if last_checkpoint is not None:
574
+ checkpoint = last_checkpoint
575
+ elif os.path.isdir(model_args.model_name_or_path):
576
+ checkpoint = model_args.model_name_or_path
577
+ else:
578
+ checkpoint = None
579
+
580
+ # Save the feature_extractor and the tokenizer
581
+ if is_main_process(training_args.local_rank):
582
+ processor.save_pretrained(training_args.output_dir)
583
+
584
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
585
+ trainer.save_model()
586
+
587
+ metrics = train_result.metrics
588
+ max_train_samples = (
589
+ data_args.max_train_samples
590
+ if data_args.max_train_samples is not None
591
+ else len(vectorized_datasets["train"])
592
+ )
593
+ metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
594
+
595
+ trainer.log_metrics("train", metrics)
596
+ trainer.save_metrics("train", metrics)
597
+ trainer.save_state()
598
+
599
+ # Evaluation
600
+ results = {}
601
+ if training_args.do_eval:
602
+ logger.info("*** Evaluate ***")
603
+ metrics = trainer.evaluate()
604
+ max_eval_samples = (
605
+ data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
606
+ )
607
+ metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
608
+
609
+ trainer.log_metrics("eval", metrics)
610
+ trainer.save_metrics("eval", metrics)
611
+
612
+ # Write model card and (optionally) push to hub
613
+ config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
614
+ kwargs = {
615
+ "finetuned_from": model_args.model_name_or_path,
616
+ "tasks": "speech-recognition",
617
+ "tags": ["automatic-speech-recognition", data_args.dataset_name],
618
+ "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
619
+ "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
620
+ }
621
+ if "common_voice" in data_args.dataset_name:
622
+ kwargs["language"] = config_name
623
+
624
+ if training_args.push_to_hub:
625
+ trainer.push_to_hub(**kwargs)
626
+ else:
627
+ trainer.create_model_card(**kwargs)
628
+
629
+ return results
630
+
631
+
632
+ if __name__ == "__main__":
633
+ main()