sanchit-gandhi (HF staff) committed
Commit bec9d50
1 Parent(s): 57fbcab

Add model weights

.gitattributes CHANGED
@@ -30,3 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
30
  *.zip filter=lfs diff=lfs merge=lfs -text
31
  *.zst filter=lfs diff=lfs merge=lfs -text
32
  *tfevents* filter=lfs diff=lfs merge=lfs -text
33
+ *.whisper filter=lfs diff=lfs merge=lfs -text
medium.en.whisper ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a667fb55dfa9d15928c4169d324d45f59e371fafcd41661c8e54da370c8a415
3
+ size 3055771163
run_speech_recognition_whisper.py ADDED
@@ -0,0 +1,926 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning OpenAI Whisper models for speech recognition.
18
+ """
19
+ # You can also adapt this script for your own sequence-to-sequence task. Pointers for this are left as comments.
20
+ # flake8: noqa: E501
21
+ import logging
22
+ import os
23
+ import re
24
+
25
+ import torchaudio
26
+ import whisper
27
+ import sys
28
+ from dataclasses import dataclass, field
29
+
30
+ from typing import Optional, Dict, Union, List
31
+
32
+ import numpy as np
33
+ import torch
34
+
35
+ import datasets
36
+ from datasets import DatasetDict, load_dataset
37
+ import transformers
38
+ from torch import nn
39
+ from transformers import (
40
+ HfArgumentParser,
41
+ Seq2SeqTrainingArguments,
42
+ set_seed,
43
+ Seq2SeqTrainer,
44
+ )
45
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
46
+ from transformers.utils import check_min_version
47
+ from transformers.utils.versions import require_version
48
+
49
+ import wandb
50
+
51
+ from whisper.normalizers.english import EnglishTextNormalizer
52
+
53
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
54
+ check_min_version("4.17.0.dev0")
55
+
56
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
57
+
58
+ logger = logging.getLogger(__name__)
59
+
60
+
61
+ @dataclass
62
+ class ModelArguments:
63
+ """
64
+ Arguments pertaining to which model/tokenizer we are going to fine-tune from.
65
+ """
66
+ model_name_or_path: Optional[str] = field(
67
+ default=None,
68
+ metadata={"help": "Path to pretrained model or model identifier from OpenAI Whisper NGC."}
69
+ )
70
+ cache_dir: Optional[str] = field(
71
+ default=None,
72
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co or OpenAI Whisper NGC."},
73
+ )
74
+ use_auth_token: bool = field(
75
+ default=False,
76
+ metadata={
77
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
78
+ "with private models)."
79
+ },
80
+ )
81
+ manifest_path: str = field(
82
+ default="data",
83
+ metadata={
84
+ "help": "Manifest path."
85
+ },
86
+ )
87
+ tokenizer_path: str = field(
88
+ default="tokenizers",
89
+ metadata={
90
+ "help": "Tokenizer path."
91
+ },
92
+ )
93
+ freeze_encoder: bool = field(
94
+ default=False,
95
+ metadata={"help": "Freeze the acoustic encoder of the model. Recommend when fine-tuning on small datasets."}
96
+ )
97
+ num_beams: int = field(
98
+ default=1,
99
+ metadata={"help": "Number of beams for evaluation."},
100
+ )
101
+ length_penalty: float = field(
102
+ default=1.0,
103
+ metadata={"help": "Length penalty for evaluation."},
104
+ )
105
+ use_adam8bit: bool = field(
106
+ default=False,
107
+ metadata={"help": "Whether to use bitsandbytes 8bit AdamW optimiser."}
108
+ )
109
+ dropout_rate: float = field(
110
+ default=0.0,
111
+ metadata={"help": "The dropout ratio for all dropout layers (default=0)."}
112
+ )
113
+
114
+
115
+ @dataclass
116
+ class DataTrainingArguments:
117
+ """
118
+ Arguments pertaining to what data we are going to input our model for training and eval.
119
+ """
120
+
121
+ dataset_name: str = field(
122
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
123
+ )
124
+ dataset_config_name: Optional[str] = field(
125
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
126
+ )
127
+ text_column: Optional[str] = field(
128
+ default=None,
129
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
130
+ )
131
+ dataset_cache_dir: Optional[str] = field(
132
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
133
+ )
134
+ overwrite_cache: bool = field(
135
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
136
+ )
137
+ preprocessing_num_workers: Optional[int] = field(
138
+ default=None,
139
+ metadata={"help": "The number of processes to use for the preprocessing."},
140
+ )
141
+ max_train_samples: Optional[int] = field(
142
+ default=None,
143
+ metadata={
144
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
145
+ "value if set."
146
+ },
147
+ )
148
+ max_eval_samples: Optional[int] = field(
149
+ default=None,
150
+ metadata={
151
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
152
+ "value if set."
153
+ },
154
+ )
155
+ max_predict_samples: Optional[int] = field(
156
+ default=None,
157
+ metadata={
158
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
159
+ "value if set."
160
+ },
161
+ )
162
+ audio_column_name: str = field(
163
+ default="audio",
164
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
165
+ )
166
+ text_column_name: str = field(
167
+ default="text",
168
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
169
+ )
170
+ max_duration_in_seconds: float = field(
171
+ default=20.0,
172
+ metadata={
173
+ "help": "Truncate training audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
174
+ },
175
+ )
176
+ min_duration_in_seconds: float = field(
177
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
178
+ )
179
+ max_eval_duration_in_seconds: float = field(
180
+ default=None,
181
+ metadata={
182
+ "help": "Truncate eval/test audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
183
+ },
184
+ )
185
+ max_target_length: Optional[int] = field(
186
+ default=128,
187
+ metadata={
188
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
189
+ "than this will be truncated, sequences shorter will be padded."
190
+ },
191
+ )
192
+ min_target_length: Optional[int] = field(
193
+ default=0,
194
+ metadata={
195
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
196
+ "than this will be filtered."
197
+ },
198
+ )
199
+ preprocessing_only: bool = field(
200
+ default=False,
201
+ metadata={
202
+ "help": "Whether to only do data preprocessing and skip training. "
203
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
204
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
205
+ "so that the cached datasets can consequently be loaded in distributed training"
206
+ },
207
+ )
208
+ train_split_name: str = field(
209
+ default="train",
210
+ metadata={
211
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
212
+ },
213
+ )
214
+ eval_split_name: str = field(
215
+ default="validation",
216
+ metadata={
217
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
218
+ },
219
+ )
220
+ test_split_name: str = field(
221
+ default="test",
222
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
223
+ )
224
+ do_lower_case: bool = field(
225
+ default=True,
226
+ metadata={"help": "Whether the target text should be lower cased."},
227
+ )
228
+ wandb_project: str = field(
229
+ default="speech-recognition-whisper",
230
+ metadata={"help": "The name of the wandb project."},
231
+ )
232
+ ignore_verifications: bool = field(
233
+ default=False,
234
+ metadata={
235
+ "help": "Ignore the verifications of the downloaded/processed dataset information in `load_dataset` (checksums/size/splits/...)."
236
+ }
237
+ )
238
+ torchaudio_resampler: bool = field(
239
+ default=False,
240
+ metadata={
241
+ "help": "Whether to use torchaudio to resample. If `False` (default) will use the default datataset backed."
242
+ }
243
+ )
244
+
245
+
246
+ def write_wandb_pred(pred_str, label_str, prefix="eval"):
247
+ # convert str data to a wandb compatible format
248
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
249
+ # we'll log all predictions for the last epoch
250
+ wandb.log(
251
+ {
252
+ f"{prefix}/predictions": wandb.Table(
253
+ columns=["label_str", "pred_str"], data=str_data
254
+ )
255
+ },
256
+ )
257
+
258
+
259
+ def to_pad_to_mel(array):
260
+ """Static function which:
261
+ 1. Pads/trims a list of audio arrays to a max length of 30s
262
+ 2. Computes log-mel filter coefficients from padded/trimmed audio sequences
263
+ Inputs:
264
+ array: list of audio arrays
265
+ Returns:
266
+ input_ids: torch.tensor of log-mel filter bank coefficients
267
+ """
268
+ padded_input = whisper.pad_or_trim(np.asarray(array, dtype=np.float32))
269
+ input_ids = whisper.log_mel_spectrogram(padded_input)
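+ # illustrative shape check: a 5s clip at 16kHz (80,000 samples) is padded to 30s (480,000 samples),
+ # which whisper.log_mel_spectrogram converts to an (80, 3000) log-mel tensor (80 mel bins, 10ms hop)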
270
+ return input_ids
271
+
272
+
273
+ def to_mel_to_pad(array):
274
+ """Static function which:
275
+ 1. Computes log-mel filter coefficients from the raw audio sequence
276
+ 2. Pads/trims the resulting mel spectrogram to a max length of 3000 frames (30s)
277
+ Inputs:
278
+ array: list of audio arrays
279
+ Returns:
280
+ input_ids: torch.tensor of log-mel filter bank coefficients
281
+ """
282
+ mels = whisper.log_mel_spectrogram(np.asarray(array, dtype=np.float32))
283
+ input_ids = whisper.pad_or_trim(mels, 3000)
284
+ return input_ids
285
+
286
+
287
+ @dataclass
288
+ class WhisperDataCollatorWithPadding:
289
+ """
290
+ Data collator that dynamically pads the audio inputs received. An EOS token is appended to the label sequences.
291
+ They are then dynamically padded to max length.
292
+ Args:
293
+ eos_token_id (`int`)
294
+ The end-of-sentence token for the Whisper tokenizer. Ensure this is set so that sequences terminate before the
295
+ generation max length.
296
+ """
297
+
298
+ eos_token_id: int
299
+ time_stamp_token_id: int
300
+
301
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
302
+ """
303
+ Since Whisper models don't have an HF processor defined (feature extractor + tokenizer), we'll pad by hand...
304
+ """
305
+ # split inputs and labels since they have to be of different lengths
306
+ # and need different padding methods
307
+ input_ids = [feature["input_ids"] for feature in features]
308
+ labels = [feature["labels"] for feature in features]
309
+
310
+ # first, pad the audio inputs to max_len
311
+ input_ids = torch.concat([to_pad_to_mel(input_val)[None, :] for input_val in input_ids])
312
+
313
+ # next, append the eos token to our sequence of labels
314
+ labels = [lab + [self.eos_token_id] for lab in labels]
315
+ # finally, pad the target labels to max_len
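+ # padded positions are set to -100, the conventional 'ignore' value for the loss (they are restored to the
+ # EOS id in compute_metrics below); e.g. [[7, 8, 50256], [5, 50256]] -> [[7, 8, 50256], [5, 50256, -100]]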
316
+ label_lengths = [len(lab) for lab in labels]
317
+ max_label_len = max(label_lengths)
318
+ labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
319
+
320
+ batch = {"labels": labels}
321
+ batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}
322
+
323
+ batch["input_ids"] = input_ids
324
+
325
+ return batch
326
+
327
+
328
+ def main():
329
+ # See all possible arguments in src/transformers/training_args.py
330
+ # or by passing the --help flag to this script.
331
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
332
+
333
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
334
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
335
+ # If we pass only one argument to the script and it's the path to a json file,
336
+ # let's parse it to get our arguments.
337
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
338
+ else:
339
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
340
+
341
+ # Set wandb project ID before instantiating the Trainer
342
+ os.environ["WANDB_PROJECT"] = data_args.wandb_project
343
+ report_to_wandb = "wandb" in training_args.report_to
344
+
345
+ sample_rate = 16_000
346
+
347
+ # Detecting last checkpoint.
348
+ last_checkpoint = None
349
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
350
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
351
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
352
+ raise ValueError(
353
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
354
+ "Use --overwrite_output_dir to overcome."
355
+ )
356
+ elif last_checkpoint is not None:
357
+ logger.info(
358
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
359
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
360
+ )
361
+
362
+ # Setup logging
363
+ logging.basicConfig(
364
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
365
+ datefmt="%m/%d/%Y %H:%M:%S",
366
+ handlers=[logging.StreamHandler(sys.stdout)],
367
+ )
368
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
369
+
370
+ # Log on each process the small summary:
371
+ logger.warning(
372
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
373
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
374
+ )
375
+ # Set the verbosity to info of the Transformers logger (on main process only):
376
+ if is_main_process(training_args.local_rank):
377
+ transformers.utils.logging.set_verbosity_info()
378
+ logger.info("Training/evaluation parameters %s", training_args)
379
+
380
+ # Set seed before initializing model.
381
+ set_seed(training_args.seed)
382
+
383
+ # load the model
384
+ if os.path.isfile(model_args.model_name_or_path):
385
+ model = whisper.Whisper.load_trained(model_args.model_name_or_path)
386
+ else:
387
+ model = whisper.load_model(model_args.model_name_or_path, dropout_rate=model_args.dropout_rate)
388
+
389
+ # set the dropout for the MLP layers -> we do this here because the MLP layers are written as an nn.Sequential,
390
+ # so changing the modelling code itself would give mismatches in the state dict
391
+
392
+ if not model_args.freeze_encoder:
393
+ # only apply dropout when training the encoder
394
+ for block_idx in range(len(model.encoder.blocks)):
395
+ mlp_layer = model.encoder.blocks[block_idx].mlp
396
+ # going very verbose to explain what we're doing here!
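+ # the original Whisper MLP is Sequential(Linear, GELU, Linear); we rebuild it here (after the pretrained
+ # weights have been loaded) as Sequential(Linear, GELU, Dropout, Linear, Dropout), re-using the existing Linear modules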
397
+ fc1 = mlp_layer[0]
398
+ act_fn = mlp_layer[1]
399
+ dropout = nn.Dropout(p=model_args.dropout_rate)
400
+ fc2 = mlp_layer[2]
401
+ model.encoder.blocks[block_idx].mlp = nn.Sequential(fc1, act_fn, dropout, fc2, dropout)
402
+
403
+ """for block_idx in range(len(model.decoder.blocks)):
404
+ mlp_layer = model.decoder.blocks[block_idx].mlp
405
+ fc1 = mlp_layer[0]
406
+ act_fn = mlp_layer[1]
407
+ dropout = nn.Dropout(p=model_args.dropout_rate)
408
+ fc2 = mlp_layer[2]
409
+ model.decoder.blocks[block_idx].mlp = nn.Sequential(fc1, act_fn, dropout, fc2, dropout)"""
410
+
411
+ # load the tokenizer
412
+ whisper_tok = whisper.tokenizer.get_tokenizer(False, task="transcribe", language="en")
413
+ tokenizer = whisper_tok.tokenizer
414
+ tokenizer.pad_token = tokenizer.eos_token
415
+
416
+ # 4. Load dataset
417
+ raw_datasets = DatasetDict()
418
+
419
+ if training_args.do_train:
420
+ raw_datasets["train"] = load_dataset(
421
+ data_args.dataset_name,
422
+ data_args.dataset_config_name,
423
+ split=data_args.train_split_name,
424
+ cache_dir=data_args.dataset_cache_dir,
425
+ use_auth_token=True if model_args.use_auth_token else None,
426
+ )
427
+
428
+ if training_args.do_eval:
429
+ raw_datasets["eval"] = load_dataset(
430
+ data_args.dataset_name,
431
+ data_args.dataset_config_name,
432
+ split=data_args.eval_split_name,
433
+ cache_dir=data_args.dataset_cache_dir,
434
+ use_auth_token=True if model_args.use_auth_token else None,
435
+ )
436
+
437
+ if training_args.do_predict:
438
+ test_split = data_args.test_split_name.split("+")
439
+ for split in test_split:
440
+ raw_datasets[split] = load_dataset(
441
+ data_args.dataset_name,
442
+ data_args.dataset_config_name,
443
+ split=split,
444
+ cache_dir=data_args.dataset_cache_dir,
445
+ use_auth_token=True if model_args.use_auth_token else None,
446
+ )
447
+
448
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
449
+ raise ValueError(
450
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
451
+ "training, evaluation or prediction has to be done."
452
+ )
453
+
454
+ # if not training, there is no need to run multiple epochs
455
+ if not training_args.do_train:
456
+ training_args.num_train_epochs = 1
457
+
458
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
459
+ raise ValueError(
460
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
461
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
462
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
463
+ )
464
+
465
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
466
+ raise ValueError(
467
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
468
+ "Make sure to set `--text_column_name` to the correct text column - one of "
469
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
470
+ )
471
+
472
+ # 6. Resample speech dataset: always resampled to 16kHz (either via torchaudio or the datasets Audio backend)
473
+ if data_args.torchaudio_resampler:
474
+ # TODO: remove hardcoding of orig sr
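+ # the hardcoded 8kHz source rate matches telephone-speech corpora such as Switchboard
+ # (cf. run_switchboard.sh, which passes --torchaudio_resampler)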
475
+ resampler = torchaudio.transforms.Resample(8_000, sample_rate)
476
+ else:
477
+ raw_datasets = raw_datasets.cast_column(
478
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=sample_rate)
479
+ )
480
+ resampler = None
481
+
482
+ # 7. Preprocessing the datasets.
483
+ # We need to read the audio files as arrays and tokenize the targets.
484
+ max_input_length = int(data_args.max_duration_in_seconds * sample_rate)
485
+ min_input_length = max(int(data_args.min_duration_in_seconds * sample_rate), 1)  # at least 1 sample, so the zero-length audio fallback below gets filtered out
486
+ max_eval_input_length = int(data_args.max_eval_duration_in_seconds * sample_rate) if data_args.max_eval_duration_in_seconds else None
487
+ max_target_length = data_args.max_target_length
488
+ min_target_length = data_args.min_target_length
489
+ audio_column_name = data_args.audio_column_name
490
+ num_workers = data_args.preprocessing_num_workers
491
+ text_column_name = data_args.text_column_name
492
+ do_lower_case = data_args.do_lower_case
493
+ dataset_name = data_args.dataset_name
494
+
495
+ # Define tokens to ignore/replace
496
+ tedlium_contractions = [" 's", " 't", " 're", " 've", " 'm", " 'll", " 'd", " 'clock", " 'all"]
497
+ gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
498
+ gigaspeech_disfluencies = ["<other>", "<sil>"]
499
+ swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "[vocalized-noise]", "<a_aside>", "<b_aside>", "<e_aside>",
500
+ "[laughter-", "_1", "[laugh]", "[sigh]", "[cough]", "[mn]", "[breath]", "[lipsmack]",
501
+ "[sneeze]", "[skip]", "[pause]", "(%hesitation)", "(%HESITATION)"]
502
+ swb_punctuations = ["{", "}", "[", "]-", "]", "((", "))", "(", ")"]
503
+ earnings_disfluencies = ["<noise>", "<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<silence>"]
504
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
505
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", ""]
506
+ ignore_segments = ignore_segments + gigaspeech_disfluencies + swb_disfluencies + earnings_disfluencies
507
+
508
+ if training_args.do_train and data_args.max_train_samples is not None:
509
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
510
+
511
+ if training_args.do_eval and data_args.max_eval_samples is not None:
512
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
513
+
514
+ if training_args.do_predict and data_args.max_predict_samples is not None:
515
+ for split in test_split:
516
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_predict_samples))
517
+
518
+ # filter data where the targets are ignored in scoring
519
+ def is_target_labels(input_str):
520
+ return input_str.lower() not in ignore_segments
521
+
522
+ raw_datasets = raw_datasets.filter(
523
+ is_target_labels,
524
+ num_proc=num_workers,
525
+ input_columns=[text_column_name],
526
+ desc="filtering data where the targets are ignored in scoring",
527
+ )
528
+
529
+ def prepare_dataset(batch):
530
+ # pre-process audio
531
+ try:
532
+ sample = batch[audio_column_name]
533
+ except ValueError:
534
+ # E22: some samples are empty (no audio). Reading the empty audio array will trigger
535
+ # a soundfile ValueError. For now, we'll manually set these arrays to a zero array.
536
+ # They will be filtered in the subsequent filtering stage and so are
537
+ # explicitly ignored during training.
538
+ sample = {"array": np.array([0.]), "sampling_rate": sample_rate}
539
+
540
+ if resampler is not None:
541
+ speech_tensor = torch.FloatTensor(sample["array"])
542
+ speech_tensor = speech_tensor.squeeze()
543
+ speech_tensor = resampler(speech_tensor)
544
+ sample["array"] = speech_tensor.numpy()
545
+ sample["sampling_rate"] = resampler.new_freq
546
+
547
+ # For training Whisper we perform the audio preprocessing in the WhisperDataCollator
548
+ # => we only need to supply it with the raw audio values
549
+ batch["input_ids"] = sample["array"]
550
+ batch["input_lengths"] = len(batch["input_ids"])
551
+
552
+ # 'Error correction' of targets
553
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
554
+
555
+ # LibriSpeech ASR
556
+ if dataset_name == "librispeech_asr":
557
+ pass # no error correction necessary
558
+
559
+ # VoxPopuli
560
+ if dataset_name == "google/xtreme_s":
561
+ pass # no error correction necessary
562
+
563
+ # Common Voice 9
564
+ if dataset_name == "mozilla-foundation/common_voice_9_0":
565
+ if input_str.startswith('"') and input_str.endswith('"'):
566
+ # we can remove trailing quotation marks as they do not affect the transcription
567
+ input_str = input_str[1:-1]
568
+ # replace double quotation marks with single
569
+ input_str = input_str.replace('""', '"')
570
+
571
+ # TED-LIUM (Release 3)
572
+ if dataset_name == "LIUM/tedlium":
573
+ # delete the <unk> token from the text
574
+ input_str = input_str.replace("<unk>", "")
575
+ # replace spaced apostrophes with un-spaced (it 's -> it's)
576
+ for contraction in tedlium_contractions:
577
+ input_str = input_str.replace(contraction, contraction[1:])
578
+
579
+ # GigaSpeech
580
+ if dataset_name == "speechcolab/gigaspeech":
581
+ for disfluency in gigaspeech_disfluencies:
582
+ input_str = input_str.replace(disfluency, "")
583
+ # convert spelled out punctuation to symbolic form
584
+ for punctuation, replacement in gigaspeech_punctuation.items():
585
+ input_str = input_str.replace(punctuation, replacement)
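+ # e.g. "how are you <questionmark> fine <period>" -> "how are you? fine."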
586
+
587
+ # SWB: hide the path to the private HF dataset
588
+ if "switchboard" in dataset_name:
589
+ # In one conversation people speak some German phrases that are tagged as
590
+ # <german (( ja wohl )) > -- we remove these
591
+ input_str = re.sub("<[^>]*>", "", input_str)
592
+
593
+ # Remove junk tokens
594
+ for disfluency in swb_disfluencies:
595
+ input_str = input_str.replace(disfluency, "")
596
+
597
+ # normalise acronyms (Fisher: u_.c_.l_.a., SWBD: u c l a)
598
+ input_str = input_str.replace("_.", " ")
599
+
600
+ # Replace partially pronounced words (square brackets + hyphen): westmin[ster]- to westmin- or -[go]ing to -ing
601
+ # Replace anomalous words (square brackets + forward slash): [lemguini/linguini] to linguini
602
+ # Replace the combo of the two: [lem[guini]-/linguini] to lem-
603
+ # Example: we [ah/are] -[go]ing to westmin[ster]- for [lem[guini]-/linguini]
604
+ # Target: we ah -ing to westmin- for lem-
605
+ # Treat anomalous words first then destroy the content of all square brackets (partially pronounced words)
606
+
607
+ # First treat partially pronounced anomalous words by removing correct word: [lem[guini]-/linguini] to [lem[guini]-
608
+ input_str = re.sub(r"\-\/.*?\]", "-", input_str)
609
+
610
+ # Now replace anomalous words with their correct transcriptions: [lemguini/linguini] to linguini
611
+ split_str = input_str.split("/")
612
+ if len(split_str) > 1:
613
+ input_str = " ".join(
614
+ [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
615
+
616
+ # Remove the trailing brackets on the start/end of words
617
+ processed_str = []
618
+ for word in input_str.split():
619
+ if word[0] == "[":
620
+ processed_str.append(word[1:])
621
+ elif word[-1] == "]":
622
+ processed_str.append(word[:-1])
623
+ else:
624
+ processed_str.append(word)
625
+
626
+ # Stick the processed words back together
627
+ input_str = " ".join(processed_str)
628
+
629
+ # Now we can remove all words in square brackets: -[go]ing to -ing
630
+ input_str = re.sub(r"\-\[(.*?)\]", "-", input_str)
631
+
632
+ # westmin[ster]- to westmin-
633
+ input_str = re.sub(r"\[(.*?)\]\-", "-", input_str)
634
+
635
+ # tech[n]ology to tech-ology
636
+ input_str = re.sub(r"\[(.*?)\]", "-", input_str)
637
+
638
+ # partially pronounced words are now done!
639
+ # remove erroneous punctuations (curly braces, trailing square brackets, etc.)
640
+ for punctuation in swb_punctuations:
641
+ input_str = input_str.replace(punctuation, "")
642
+
643
+ # Earnings 22: still figuring out best segmenting method. Thus, dataset name subject to change
644
+ if "earnings22" in dataset_name:
645
+ # Remove the 100ms offset at the end of the sample
646
+ sampling_rate = sample["sampling_rate"]
647
+ offset = int(100 * (10 ** -3) * sampling_rate)
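+ # e.g. 100ms at a 16kHz sampling rate = 1,600 samples trimmed from the end of the array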
648
+ batch["input_ids"] = sample["array"][:-offset]
649
+ batch["input_lengths"] = len(batch["input_ids"])
650
+ # Remove junk tokens
651
+ for disfluency in earnings_disfluencies:
652
+ input_str = input_str.replace(disfluency, "")
653
+
654
+ # SPGISpeech
655
+ if dataset_name == "kensho/spgispeech":
656
+ pass # no error correction necessary
657
+
658
+ # JIWER compliance (for WER/CER calc.)
659
+ # remove multiple spaces
660
+ input_str = re.sub(r"\s\s+", " ", input_str)
661
+ # strip trailing spaces
662
+ input_str = input_str.strip()
663
+
664
+ # Finally, we tokenize the processed text
665
+ batch["labels"] = tokenizer(input_str).input_ids
666
+ return batch
667
+
668
+ vectorized_datasets = raw_datasets.map(
669
+ prepare_dataset,
670
+ remove_columns=next(iter(raw_datasets.values())).column_names,
671
+ num_proc=num_workers,
672
+ desc="preprocess train dataset",
673
+ )
674
+
675
+ # filter training data with inputs longer than max_input_length
676
+ def is_audio_in_length_range(input_length):
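+ # lengths are raw sample counts, so the default 20s cap corresponds to 320,000 samples at 16kHz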
677
+ return min_input_length < input_length < max_input_length
678
+
679
+ if training_args.do_train:
680
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
681
+ is_audio_in_length_range,
682
+ num_proc=num_workers,
683
+ input_columns=["input_lengths"],
684
+ )
685
+
686
+ if max_eval_input_length is not None:
687
+ # filter eval/test data with inputs longer than max_eval_input_length
688
+ def is_eval_audio_in_length_range(input_length):
689
+ return min_input_length < input_length < max_eval_input_length
690
+
691
+ if training_args.do_eval:
692
+ vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
693
+ is_eval_audio_in_length_range,
694
+ num_proc=num_workers,
695
+ input_columns=["input_lengths"],
696
+ )
697
+
698
+ if training_args.do_predict:
699
+ for split in test_split:
700
+ vectorized_datasets[split] = vectorized_datasets[split].filter(
701
+ is_eval_audio_in_length_range,
702
+ num_proc=num_workers,
703
+ input_columns=["input_lengths"],
704
+ )
705
+
706
+ # filter training data with targets shorter than min_target_length or longer than max_target_length
707
+ def is_labels_in_length_range(labels):
708
+ return min_target_length < len(labels) < max_target_length
709
+
710
+ if training_args.do_train:
711
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
712
+ is_labels_in_length_range,
713
+ num_proc=num_workers,
714
+ input_columns=["labels"],
715
+ )
716
+
717
+ # filter data with empty target sentences
718
+ def is_labels_greater_than_min(labels):
719
+ return len(labels) > 0
720
+
721
+ vectorized_datasets = vectorized_datasets.filter(
722
+ is_labels_greater_than_min,
723
+ num_proc=num_workers,
724
+ input_columns=["labels"],
725
+ )
726
+
727
+ # for large datasets it is advised to run the preprocessing on a
728
+ # single machine first with `args.preprocessing_only` since there will most likely
729
+ # be a timeout when running the script in distributed mode.
730
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
731
+ # cached dataset
732
+ if data_args.preprocessing_only:
733
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
734
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
735
+ return
736
+
737
+ if model_args.freeze_encoder:
738
+ model.freeze_encoder()
739
+ logging.info("Model encoder has been frozen")
740
+
741
+ # 8. Load Metric
742
+ #metric_wer = evaluate.load("wer")
743
+ #metric_cer = evaluate.load("cer")
744
+ metric_wer = datasets.load_metric("wer")
745
+ metric_cer = datasets.load_metric("cer")
746
+
747
+ normalizer = EnglishTextNormalizer()
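+ # Whisper's English text normalizer; only needed for the commented-out normalized WER/CER metrics below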
748
+
749
+ def compute_metrics(pred):
750
+ pred_ids = pred.predictions
751
+ pred.label_ids[pred.label_ids == -100] = tokenizer.eos_token_id
752
+
753
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
754
+ pred_str = [x.lstrip().strip() for x in pred_str]
755
+
756
+ # we do not want to group tokens when computing the metrics
757
+ label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
758
+
759
+ wer = metric_wer.compute(predictions=pred_str, references=label_str)
760
+ cer = metric_cer.compute(predictions=pred_str, references=label_str)
761
+
762
+ #normalized_pred_str = [normalizer(str(input_str)) for input_str in pred_str]
763
+ #normalized_label_str = [normalizer(str(input_str)) for input_str in label_str]
764
+
765
+ #wer_norm = metric_wer.compute(predictions=normalized_pred_str, references=normalized_label_str)
766
+ #cer_norm = metric_cer.compute(predictions=normalized_pred_str, references=normalized_label_str)
767
+
768
+ #return {"wer": wer, "cer": cer, "wer_norm": wer_norm, "cer_norm": cer_norm}
769
+ return {"wer": wer, "cer": cer}
770
+
771
+ def compute_metrics_and_predictions(pred):
772
+ pred_ids = pred.predictions
773
+ pred.label_ids[pred.label_ids == -100] = tokenizer.eos_token_id
774
+
775
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
776
+ pred_str = [x.lstrip().strip() for x in pred_str]
777
+
778
+ # we do not want to group tokens when computing the metrics
779
+ label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
780
+
781
+ wer = metric_wer.compute(predictions=pred_str, references=label_str)
782
+ cer = metric_cer.compute(predictions=pred_str, references=label_str)
783
+
784
+ #normalized_pred_str = [normalizer(str(input_str)) for input_str in pred_str]
785
+ #normalized_label_str = [normalizer(str(input_str)) for input_str in label_str]
786
+
787
+ #wer_norm = metric_wer.compute(predictions=normalized_pred_str, references=normalized_label_str)
788
+ #cer_norm = metric_cer.compute(predictions=normalized_pred_str, references=normalized_label_str)
789
+
790
+ #return {"wer": wer, "cer": cer, "wer_norm": wer_norm, "cer_norm": cer_norm, "pred_str": pred_str, "label_str": label_str}
791
+ return {"wer": wer, "cer": cer, "pred_str": pred_str, "label_str": label_str}
792
+
793
+ class WhisperTrainer(Seq2SeqTrainer):
794
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
795
+ # If we are executing this function, we are the process zero, so we don't check for that.
796
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
797
+ os.makedirs(output_dir, exist_ok=True)
798
+ logger.info(f"Saving model checkpoint to {output_dir}")
799
+ # Save the trained model using the Whisper `save_to()` method.
800
+ # It can then be reloaded with `whisper.Whisper.load_trained()`, as done when loading the model above.
801
+ self.model.save_to(save_path=os.path.join(output_dir, model_args.model_name_or_path + ".whisper"))
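+ # e.g. with --model_name_or_path="medium.en" this writes <output_dir>/medium.en.whisper,
+ # which is the checkpoint added in this commit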
802
+ # Good practice: save your training arguments together with the trained model
803
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
804
+
805
+ # Define data collator
806
+ eos = tokenizer.eos_token_id
807
+ t_stamp = tokenizer("<|notimestamps|>").input_ids[0]
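+ # note: the collator stores time_stamp_token_id but does not currently use it when building the batch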
808
+ whisper_data_collator = WhisperDataCollatorWithPadding(eos_token_id=eos, time_stamp_token_id=t_stamp)
809
+
810
+ # make sure model uses 50257 as BOS
811
+ bos = tokenizer("<|startoftranscript|>").input_ids[0]
812
+ model.config.decoder_start_token_id = bos
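+ # for the English-only Whisper tokenizer, <|startoftranscript|> maps to id 50257 and <|endoftext|>
+ # (re-used above as the EOS/PAD token) to 50256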
813
+
814
+ # Initialize Trainer
815
+ trainer = WhisperTrainer(
816
+ model=model,
817
+ args=training_args,
818
+ compute_metrics=compute_metrics,
819
+ train_dataset=vectorized_datasets['train'] if training_args.do_train else None,
820
+ eval_dataset=vectorized_datasets['eval'] if training_args.do_eval else None,
821
+ data_collator=whisper_data_collator,
822
+ )
823
+
824
+ # 9. Finally, we can start training
825
+
826
+ # Training
827
+ if training_args.do_train:
828
+
829
+ # use the last checkpoint if it exists
830
+ if last_checkpoint is not None:
831
+ checkpoint = last_checkpoint
832
+ elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
833
+ checkpoint = model_args.model_name_or_path
834
+ else:
835
+ checkpoint = None
836
+
837
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
838
+ trainer.save_model()
839
+
840
+ metrics = train_result.metrics
841
+ max_train_samples = (
842
+ data_args.max_train_samples
843
+ if data_args.max_train_samples is not None
844
+ else len(vectorized_datasets["train"])
845
+ )
846
+ metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
847
+
848
+ trainer.log_metrics("train", metrics)
849
+ trainer.save_metrics("train", metrics)
850
+ trainer.save_state()
851
+
852
+ # Change decoding strategy for final eval/predict
853
+ # if training_args.do_eval or training_args.do_predict:
854
+ # trainer.model.num_beams = 2
855
+
856
+ trainer.compute_metrics = compute_metrics_and_predictions
857
+
858
+ results = {}
859
+ if training_args.do_eval:
860
+ if not training_args.do_train and report_to_wandb:
861
+ # manually init wandb
862
+ wandb.init(project=data_args.wandb_project, name=training_args.run_name)
863
+ # Have to run this as a predict step, otherwise trainer will try to log the pred/label strings to wandb
864
+ eval_results = trainer.predict(vectorized_datasets["eval"], metric_key_prefix="eval", num_beams=model_args.num_beams, length_penalty=model_args.length_penalty)
865
+ metrics = eval_results.metrics
866
+ max_eval_samples = (
867
+ data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
868
+ )
869
+ metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
870
+ pred_str = metrics.pop("eval_pred_str", None)
871
+ label_str = metrics.pop("eval_label_str", None)
872
+
873
+ trainer.log_metrics("eval", metrics)
874
+ trainer.save_metrics("eval", metrics)
875
+
876
+ if report_to_wandb:
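+ # rename the metric keys from e.g. "eval_wer" to "eval/wer" so wandb groups them under an "eval" section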
877
+ metrics = {os.path.join("eval", k[len("eval") + 1:]): v for k, v in metrics.items()}
878
+ wandb.log(metrics)
879
+ write_wandb_pred(pred_str, label_str, prefix="eval")
880
+
881
+ if training_args.do_predict:
882
+ if not training_args.do_train and not training_args.do_eval and report_to_wandb:
883
+ # manually init wandb
884
+ wandb.init(project=data_args.wandb_project, name=training_args.run_name)
885
+ for split in test_split:
886
+ predict_results = trainer.predict(
887
+ vectorized_datasets[split], metric_key_prefix=split, num_beams=model_args.num_beams, length_penalty=model_args.length_penalty)
888
+ metrics = predict_results.metrics
889
+ max_predict_samples = (
890
+ data_args.max_predict_samples if data_args.max_predict_samples is not None else len(vectorized_datasets[split])
891
+ )
892
+ metrics[f"{split}_samples"] = min(max_predict_samples, len(vectorized_datasets[split]))
893
+ pred_str = metrics.pop(f"{split}_pred_str", None)
894
+ label_str = metrics.pop(f"{split}_label_str", None)
895
+
896
+ trainer.log_metrics(split, metrics)
897
+ trainer.save_metrics(split, metrics)
898
+
899
+ if report_to_wandb:
900
+ metrics = {os.path.join(split, k[len(split)+1:]): v for k, v in metrics.items()}
901
+ wandb.log(metrics)
902
+ write_wandb_pred(pred_str, label_str, prefix=split)
903
+
904
+ # Write model card and (optionally) push to hub
905
+ config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
906
+ kwargs = {
907
+ "finetuned_from": model_args.model_name_or_path,
908
+ "tasks": "speech-recognition",
909
+ "tags": ["automatic-speech-recognition", data_args.dataset_name],
910
+ "dataset_args": (
911
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
912
+ f" {data_args.eval_split_name}"
913
+ ),
914
+ "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
915
+ }
916
+ if "common_voice" in data_args.dataset_name:
917
+ kwargs["language"] = config_name
918
+
919
+ if training_args.push_to_hub:
920
+ trainer.push_to_hub(**kwargs)
921
+
922
+ return results
923
+
924
+
925
+ if __name__ == "__main__":
926
+ main()
run_switchboard.sh ADDED
@@ -0,0 +1,39 @@
1
+ #!/usr/bin/env bash
2
+ CUDA_VISIBLE_DEVICES=1 python run_speech_recognition_whisper.py \
3
+ --model_name_or_path="medium.en" \
4
+ --dataset_name="ldc/switchboard" \
5
+ --dataset_config_name="all" \
6
+ --train_split_name="train.switchboard" \
7
+ --eval_split_name="validation.switchboard" \
8
+ --test_split_name="test.switchboard+test.callhome" \
9
+ --text_column_name="text" \
10
+ --max_steps="5000" \
11
+ --output_dir="./" \
12
+ --run_name="whisper-switchboard-5k" \
13
+ --wandb_project="whisper" \
14
+ --per_device_train_batch_size="64" \
15
+ --per_device_eval_batch_size="16" \
16
+ --logging_steps="25" \
17
+ --learning_rate="1e-4" \
18
+ --warmup_steps="500" \
19
+ --report_to="wandb" \
20
+ --preprocessing_num_workers="16" \
21
+ --evaluation_strategy="steps" \
22
+ --eval_steps="1000" \
23
+ --save_strategy="steps" \
24
+ --save_steps="1000" \
25
+ --generation_max_length="224" \
26
+ --length_column_name="input_lengths" \
27
+ --do_lower_case="True" \
28
+ --push_to_hub="False" \
29
+ --gradient_checkpointing \
30
+ --group_by_length \
31
+ --freeze_encoder \
32
+ --fp16 \
33
+ --overwrite_output_dir \
34
+ --do_train \
35
+ --do_eval \
36
+ --do_predict \
37
+ --predict_with_generate \
38
+ --use_auth_token \
39
+ --torchaudio_resampler