supawichwac committed on
Commit
35910ca
1 Parent(s): 55e4ec6

Saving train state of step 5000

.ipynb_checkpoints/run_distillation-checkpoint.py ADDED
@@ -0,0 +1,1696 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training the Whisper model for sequence to sequence speech recognition via teacher-student distillation.
18
+ """
19
+ # You can also adapt this script for your own distillation tasks. Pointers for this are left as comments.
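+ #
+ # Example invocation (illustrative only - the model, dataset and path names below are placeholders,
+ # not prescribed by this script):
+ #
+ #   accelerate launch run_distillation.py \
+ #     --model_name_or_path ./student-model-init \
+ #     --teacher_model_name_or_path openai/whisper-large-v3 \
+ #     --train_dataset_name my-org/pseudo-labelled-dataset \
+ #     --train_dataset_config_name default \
+ #     --train_split_name train \
+ #     --eval_split_name validation \
+ #     --output_dir ./distil-whisper-output \
+ #     --do_train --do_eval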
20
+
21
+ import logging
22
+ import os
23
+ import re
24
+ import shutil
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from functools import partial
29
+ from pathlib import Path
30
+ from typing import Any, Dict, List, Optional, Union
31
+
32
+ import datasets
33
+ import evaluate
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from datasets import (
41
+ DatasetDict,
42
+ IterableDataset,
43
+ IterableDatasetDict,
44
+ concatenate_datasets,
45
+ interleave_datasets,
46
+ load_dataset,
47
+ )
48
+ from huggingface_hub import create_repo, get_full_repo_name, upload_folder
49
+ from torch.utils.data import DataLoader
50
+ from tqdm import tqdm
51
+ from transformers import (
52
+ AddedToken,
53
+ HfArgumentParser,
54
+ Seq2SeqTrainingArguments,
55
+ WhisperConfig,
56
+ WhisperFeatureExtractor,
57
+ WhisperForConditionalGeneration,
58
+ WhisperProcessor,
59
+ WhisperTokenizerFast,
60
+ get_scheduler,
61
+ set_seed,
62
+ )
63
+ from transformers.modeling_outputs import BaseModelOutput
64
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
65
+ from transformers.utils import check_min_version
66
+ from transformers.utils.versions import require_version
67
+
68
+
69
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
70
+ check_min_version("4.34.0.dev0")
71
+
72
+ require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")
73
+
74
+ logger = get_logger(__name__)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ """
80
+ Arguments pertaining to which model/config/tokenizer we are going to distill from.
81
+ """
82
+
83
+ model_name_or_path: str = field(
84
+ metadata={"help": "Path to pretrained Whisper model or model identifier from huggingface.co/models"}
85
+ )
86
+ teacher_model_name_or_path: str = field(
87
+ metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
88
+ )
89
+ config_name: Optional[str] = field(
90
+ default=None,
91
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
92
+ )
93
+ tokenizer_name: Optional[str] = field(
94
+ default=None,
95
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
96
+ )
97
+ feature_extractor_name: Optional[str] = field(
98
+ default=None,
99
+ metadata={"help": "feature extractor name or path if not the same as model_name"},
100
+ )
101
+ cache_dir: Optional[str] = field(
102
+ default=None,
103
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
104
+ )
105
+ use_fast_tokenizer: bool = field(
106
+ default=True,
107
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
108
+ )
109
+ model_revision: str = field(
110
+ default="main",
111
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
112
+ )
113
+ subfolder: str = field(
114
+ default="",
115
+ metadata={
116
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
117
+ "specify the folder name here."
118
+ },
119
+ )
120
+ token: str = field(
121
+ default=None,
122
+ metadata={
123
+ "help": (
124
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
125
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
126
+ )
127
+ },
128
+ )
129
+ attn_implementation: Optional[str] = field(
130
+ default=None,
131
+ metadata={
132
+ "help": (
133
+ "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n"
134
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
135
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
136
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
137
+ )
138
+ },
139
+ )
140
+
141
+ def __post_init__(self):
142
+ if self.attn_implementation not in [None, "eager", "sdpa", "flash_attention_2"]:
143
+ raise ValueError(
144
+ f"Got `--attn_implementation={self.attn_implementation}`, which is an invalid attention type. Should be one of:\n"
145
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
146
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
147
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
148
+ )
149
+
150
+
151
+ @dataclass
152
+ class DataTrainingArguments:
153
+ """
154
+ Arguments pertaining to what data we are going to input our model for training and eval.
155
+ """
156
+
157
+ train_dataset_name: str = field(
158
+ default=None,
159
+ metadata={
160
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
161
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load LibriSpeech "
162
+ "and Common Voice, set `train_dataset_name='librispeech_asr+common_voice'`."
163
+ },
164
+ )
165
+ train_dataset_config_name: Optional[str] = field(
166
+ default=None,
167
+ metadata={
168
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
169
+ "multiple datasets by separating dataset configs by a '+' symbol. Note that the order of the configs should "
170
+ "match the order of the datasets."
171
+ },
172
+ )
173
+ train_dataset_samples: str = field(
174
+ default=None,
175
+ metadata={
176
+ "help": "Number of samples in each dataset when loading multiple datasets with streaming mode. "
177
+ "Not required when using one dataset or non-streaming mode. The sample values provide the sampling "
178
+ "probability for each dataset. Setting them equal to the number of sample values ensures that every "
179
+ "sample from every dataset is used once per epoch."
180
+ },
181
+ )
182
+ eval_dataset_name: str = field(
183
+ default=None,
184
+ metadata={
185
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training "
186
+ "dataset name if unspecified. Load multiple evaluation datasets by separating dataset "
187
+ "ids by a '+' symbol."
188
+ },
189
+ )
190
+ eval_dataset_config_name: Optional[str] = field(
191
+ default=None,
192
+ metadata={
193
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the "
194
+ "training dataset config name if unspecified."
195
+ },
196
+ )
197
+ dataset_cache_dir: Optional[str] = field(
198
+ default=None,
199
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
200
+ )
201
+ overwrite_cache: bool = field(
202
+ default=False,
203
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
204
+ )
205
+ preprocessing_num_workers: Optional[int] = field(
206
+ default=None,
207
+ metadata={"help": "The number of processes to use for the preprocessing if using non-streaming mode."},
208
+ )
209
+ preprocessing_batch_size: Optional[int] = field(
210
+ default=256,
211
+ metadata={"help": "Number of examples per batch provided to the `prepare_dataset` function."},
212
+ )
213
+ max_train_samples: Optional[int] = field(
214
+ default=None,
215
+ metadata={
216
+ "help": (
217
+ "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
218
+ )
219
+ },
220
+ )
221
+ max_eval_samples: Optional[int] = field(
222
+ default=None,
223
+ metadata={
224
+ "help": (
225
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
226
+ )
227
+ },
228
+ )
229
+ audio_column_name: str = field(
230
+ default="audio",
231
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
232
+ )
233
+ text_column_name: str = field(
234
+ default=None,
235
+ metadata={"help": "The name of the dataset column containing the text data in the training set."},
236
+ )
237
+ eval_text_column_name: str = field(
238
+ default="text",
239
+ metadata={"help": ("The name of the dataset column containing the text data in the evaluation set.")},
240
+ )
241
+ max_duration_in_seconds: float = field(
242
+ default=30.0,
243
+ metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
244
+ )
245
+ min_duration_in_seconds: float = field(
246
+ default=0.0,
247
+ metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
248
+ )
249
+ max_label_length: int = field(
250
+ default=448,
251
+ metadata={"help": "Truncate transcriptions that are longer `max_label_length` tokens."},
252
+ )
253
+ pad_target_to_multiple_of: Optional[int] = field(
254
+ default=None,
255
+ metadata={
256
+ "help": (
257
+ "If set will pad the target sequence to a multiple of the provided"
258
+ " value. This is important to avoid triggering recompilations on TPU."
259
+ " If unspecified, will default to padding the targets to max length."
260
+ )
261
+ },
262
+ )
263
+ preprocessing_only: bool = field(
264
+ default=False,
265
+ metadata={
266
+ "help": (
267
+ "Whether to only do data preprocessing and skip training. This is"
268
+ " especially useful when data preprocessing errors out in distributed"
269
+ " training due to timeout. In this case, one should run the"
270
+ " preprocessing in a non-distributed setup with"
271
+ " `preprocessing_only=True` so that the cached datasets can"
272
+ " consequently be loaded in distributed training"
273
+ )
274
+ },
275
+ )
276
+ train_split_name: str = field(
277
+ default="train",
278
+ metadata={
279
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
280
+ },
281
+ )
282
+ eval_split_name: str = field(
283
+ default="validation",
284
+ metadata={
285
+ "help": (
286
+ "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
287
+ )
288
+ },
289
+ )
290
+ streaming: bool = field(
291
+ default=True,
292
+ metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
293
+ )
294
+ wer_threshold: float = field(
295
+ default=None,
296
+ metadata={
297
+ "help": "Filter training data with Whisper transcriptions that have greater than `wer_threshold` "
298
+ "WER with the normalised transcriptions. This only takes effect if training on pseudo-labels targets."
299
+ "If `--use_pseudo_labels=False`, then no WER filtering is performed, since we train directly on the text"
300
+ "transcriptions."
301
+ },
302
+ )
303
+ use_pseudo_labels: bool = field(
304
+ default=True,
305
+ metadata={
306
+ "help": "Whether or not to use pseudo-label transcriptions as the targets. If True, the pseudo-labels "
307
+ "must be in the dataset column `whisper_transcript` from the previous pseudo-labelling step. This is "
308
+ "not currently yet configurable."
309
+ },
310
+ )
311
+ timestamp_probability: float = field(
312
+ default=0.2, metadata={"help": "Probability for training on timestamped tokens if the data contains it."}
313
+ )
314
+ condition_on_prev_probability: float = field(
315
+ default=0.2, metadata={"help": "Probability for conditioning on the previous text example."}
316
+ )
317
+ return_timestamps: bool = field(
318
+ default=False, metadata={"help": "Whether or not to predict timestamps in the generation step."}
319
+ )
320
+ language: str = field(
321
+ default=None,
322
+ metadata={
323
+ "help": (
324
+ "Language for multilingual distillation. This argument should be set for multilingual distillation "
325
+ "only. For English speech recognition, it should be left as `None`."
326
+ )
327
+ },
328
+ )
329
+ task: str = field(
330
+ default="transcribe",
331
+ metadata={
332
+ "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."
333
+ "This argument should be set for multilingual distillation only. For English speech recognition, it should be left as `None`."
334
+ },
335
+ )
336
+ wandb_project: str = field(
337
+ default="distil-whisper",
338
+ metadata={"help": "The name of the wandb project."},
339
+ )
340
+
341
+
342
+ @dataclass
343
+ class DistillationTrainingArguments(Seq2SeqTrainingArguments):
344
+ freeze_encoder: Optional[bool] = field(
345
+ default=False,
346
+ metadata={
347
+ "help": (
348
+ "Whether to freeze the entire encoder model. Only recommended when the entire encoder has been "
349
+ "copied from the teacher model."
350
+ )
351
+ },
352
+ )
353
+ freeze_embed_positions: Optional[bool] = field(
354
+ default=False,
355
+ metadata={"help": "Whether to freeze the decoder embedding positions."},
356
+ )
357
+ temperature: Optional[float] = field(
358
+ default=2.0, metadata={"help": "Temperature to anneal the logits when computing the softmax."}
359
+ )
360
+ kl_weight: Optional[float] = field(
361
+ default=1.0,
362
+ metadata={
363
+ "help": (
364
+ "Weighting assigned to the MSE loss in the KD formulation. MSE loss is "
365
+ "computed between the teacher-student hidden states and attentions."
366
+ )
367
+ },
368
+ )
369
+ dtype: Optional[str] = field(
370
+ default="float32",
371
+ metadata={
372
+ "help": (
373
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
374
+ "`float16` or `bfloat16` (both half-precision)."
375
+ )
376
+ },
377
+ )
378
+
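+ # Note: `temperature` and `kl_weight` parameterise the distillation objective computed later in the
+ # training loop. As a rough sketch (following the usual distil-whisper formulation), the KL term is
+ #     kl_loss = temperature**2 * KL(log_softmax(student_logits / temperature), softmax(teacher_logits / temperature))
+ # and the total loss adds `kl_weight * kl_loss` to the student's cross-entropy loss.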
379
+
380
+ @dataclass
381
+ class DataCollatorSpeechSeq2SeqWithPadding:
382
+ """
383
+ Data collator that will dynamically pad the inputs received.
384
+ Args:
385
+ processor ([`WhisperProcessor`])
386
+ The processor used for processing the data.
387
+ decoder_start_token_id (:obj: `int`)
388
+ The start-of-sequence token id of the decoder.
389
+ decoder_prev_token_id (:obj: `int`)
390
+ The start-of-prompt token id of the decoder
391
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
392
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
393
+ among:
394
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
395
+ sequence is provided).
396
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
397
+ maximum acceptable input length for the model if that argument is not provided.
398
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
399
+ different lengths).
400
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
401
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
402
+ See above for details.
403
+ max_target_length (:obj:`int`, `optional`):
404
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
405
+ """
406
+
407
+ processor: Any
408
+ decoder_start_token_id: int
409
+ decoder_prev_token_id: int
410
+ input_padding: Union[bool, str] = "max_length"
411
+ target_padding: Union[bool, str] = "max_length"
412
+ max_target_length: Optional[int] = None
413
+
414
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, torch.Tensor]:
415
+ # split inputs and labels since they have to be of different lengths and need
416
+ # different padding methods
417
+
418
+ # dataloader returns a list of features which we convert to a dict
419
+ input_features = {"input_features": [feature["input_features"] for feature in features]}
420
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
421
+
422
+ # reformat list to dict and set to pytorch format
423
+ batch = self.processor.feature_extractor.pad(
424
+ input_features,
425
+ padding=self.input_padding,
426
+ return_tensors="pt",
427
+ )
428
+
429
+ labels_batch = self.processor.tokenizer.pad(
430
+ label_features,
431
+ max_length=self.max_target_length,
432
+ padding=self.target_padding,
433
+ return_tensors="pt",
434
+ )
435
+
436
+ # shift labels to the right to get decoder input ids
437
+ labels = labels_batch["input_ids"]
438
+ decoder_input_ids = labels[:, :-1]
439
+ labels = labels[:, 1:]
440
+ labels_mask = labels_batch.attention_mask[:, 1:]
441
+
442
+ # replace padding with -100 to ignore correctly when computing the loss
443
+ labels = labels.masked_fill(labels_mask.ne(1), -100)
444
+
445
+ # replace initial prompt tokens with -100 to ignore correctly when computing the loss
446
+ bos_index = torch.argmax((labels == self.decoder_start_token_id).long(), dim=1)
447
+ bos_index = torch.where(bos_index > 0, bos_index + 1, bos_index)
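+ # when a prompt is present (bos_index > 0), shift the cut-off by one so the mask also covers the <|startoftranscript|> position itself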
448
+ prompt_mask = torch.arange(labels.shape[1]) < bos_index[:, None]
449
+ labels = torch.where(prompt_mask, -100, labels)
450
+
451
+ batch["labels"] = labels
452
+ batch["decoder_input_ids"] = decoder_input_ids
453
+
454
+ return batch
455
+
456
+
457
+ def log_metric(
458
+ accelerator,
459
+ metrics: Dict,
460
+ train_time: float,
461
+ step: int,
462
+ epoch: int,
463
+ learning_rate: float = None,
464
+ prefix: str = "train",
465
+ ):
466
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
467
+ log_metrics = {}
468
+ for k, v in metrics.items():
469
+ log_metrics[f"{prefix}/{k}"] = v
470
+ log_metrics[f"{prefix}/time"] = train_time
471
+ log_metrics[f"{prefix}/epoch"] = epoch
472
+ if learning_rate is not None:
473
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
474
+ accelerator.log(log_metrics, step=step)
475
+
476
+
477
+ def log_pred(
478
+ accelerator,
479
+ pred_str: List[str],
480
+ label_str: List[str],
481
+ norm_pred_str: List[str],
482
+ norm_label_str: List[str],
483
+ step: int,
484
+ prefix: str = "eval",
485
+ num_lines: int = 200000,
486
+ ):
487
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
488
+ if accelerator.is_main_process:
489
+ wandb_tracker = accelerator.get_tracker("wandb")
490
+ # pretty name for current step: step 50000 -> step 50k
491
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
492
+ prefix_pretty = prefix.replace("/", "-")
493
+
494
+ # convert str data to a wandb compatible format
495
+ str_data = [[label_str[i], pred_str[i], norm_label_str[i], norm_pred_str[i]] for i in range(len(pred_str))]
496
+ # log as a table with the appropriate headers
497
+ wandb_tracker.log_table(
498
+ table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
499
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
500
+ data=str_data[:num_lines],
501
+ step=step,
502
+ )
503
+
504
+ # log incorrect normalised predictions
505
+ str_data = np.asarray(str_data)
506
+ str_data_incorrect = str_data[str_data[:, -2] != str_data[:, -1]]
507
+ # log as a table with the appropriate headers
508
+ wandb_tracker.log_table(
509
+ table_name=f"incorrect_predictions/{prefix_pretty}-step-{cur_step_pretty}",
510
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
511
+ data=str_data_incorrect[:num_lines],
512
+ step=step,
513
+ )
514
+
515
+
516
+ def convert_dataset_str_to_list(
517
+ dataset_names,
518
+ dataset_config_names,
519
+ splits=None,
520
+ text_column_names=None,
521
+ dataset_samples=None,
522
+ default_split="train",
523
+ ) -> List[Dict]:
524
+ """
525
+ Given three lists of dataset names, configs and splits, this function groups the corresponding
526
+ names/configs/splits. Each dataset is assigned a unique dictionary with these metadata values, and the
527
+ function returns a list of dictionaries, one for each dataset.
528
+ """
529
+ if isinstance(dataset_names, str):
530
+ dataset_names = dataset_names.split("+")
531
+ dataset_config_names = dataset_config_names.split("+") if dataset_config_names is not None else None
532
+ splits = splits.split("+") if splits is not None else None
533
+ text_column_names = text_column_names.split("+") if text_column_names is not None else None
534
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
535
+
536
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
537
+ if dataset_config_names is not None and len(dataset_names) != len(dataset_config_names):
538
+ raise ValueError(
539
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
540
+ f" {len(dataset_config_names)} configs."
541
+ )
542
+
543
+ if splits is not None and len(splits) != len(dataset_names):
544
+ raise ValueError(
545
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
546
+ )
547
+
548
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
549
+ raise ValueError(
550
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
551
+ f" {len(text_column_names)} text column names."
552
+ )
553
+
554
+ if dataset_samples is not None:
555
+ if len(dataset_samples) != len(dataset_names):
556
+ raise ValueError(
557
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
558
+ f"{len(dataset_samples)} samples."
559
+ )
560
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
561
+ else:
562
+ dataset_samples = [None] * len(dataset_names)
563
+
564
+ dataset_config_names = (
565
+ dataset_config_names if dataset_config_names is not None else ["default" for _ in range(len(dataset_names))]
566
+ )
567
+ text_column_names = (
568
+ text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
569
+ )
570
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
571
+
572
+ dataset_names_dict = []
573
+ for i, ds_name in enumerate(dataset_names):
574
+ dataset_names_dict.append(
575
+ {
576
+ "name": ds_name,
577
+ "config": dataset_config_names[i],
578
+ "split": splits[i],
579
+ "text_column_name": text_column_names[i],
580
+ "samples": dataset_samples[i],
581
+ }
582
+ )
583
+ return dataset_names_dict
584
+
585
+
586
+ def load_multiple_datasets(
587
+ dataset_names: Union[List, str],
588
+ dataset_config_names: Union[List, str],
589
+ splits: Optional[Union[List, str]] = None,
590
+ text_column_names: Optional[List] = None,
591
+ sampling_rate: Optional[int] = 16000,
592
+ stopping_strategy: Optional[str] = "first_exhausted",
593
+ dataset_samples: Optional[Union[List, np.array]] = None,
594
+ streaming: Optional[bool] = True,
595
+ seed: Optional[int] = None,
596
+ accelerator: Optional[Accelerator] = None,
597
+ use_pseudo_labels: Optional[bool] = None,
598
+ **kwargs,
599
+ ) -> IterableDataset:
600
+ dataset_names_dict = convert_dataset_str_to_list(
601
+ dataset_names, dataset_config_names, splits, text_column_names, dataset_samples
602
+ )
603
+
604
+ if dataset_samples is not None:
605
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
606
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
607
+ else:
608
+ probabilities = None
609
+
610
+ all_datasets = []
611
+ # iterate over the datasets we want to interleave
612
+ for dataset_dict in tqdm(
613
+ dataset_names_dict,
614
+ desc="Combining datasets...",
615
+ disable=not accelerator.is_local_main_process if accelerator is not None else False,
616
+ ):
617
+ dataset = load_dataset(
618
+ dataset_dict["name"],
619
+ dataset_dict["config"],
620
+ split=dataset_dict["split"],
621
+ streaming=streaming,
622
+ **kwargs,
623
+ )
624
+ # resample to specified sampling rate
625
+ dataset = dataset.cast_column("audio", datasets.features.Audio(sampling_rate))
626
+ dataset_features = dataset.features.keys()
627
+ columns_to_keep = {"audio", "text"}
628
+
629
+ if dataset_dict["text_column_name"] not in dataset_features:
630
+ raise ValueError(
631
+ f"Text column name {dataset_dict['text_column_name']} not found in dataset"
632
+ f" '{dataset_dict['name']}'. Make sure to set `--text_column_name` to the"
633
+ f" correct text column - one of {', '.join(dataset_features)}."
634
+ )
635
+
636
+ # blanket renaming of all transcription columns to text
637
+ if dataset_dict["text_column_name"] != "text":
638
+ dataset = dataset.rename_column(dataset_dict["text_column_name"], "text")
639
+
640
+ if use_pseudo_labels:
641
+ if "whisper_transcript" not in dataset_features:
642
+ raise ValueError(
643
+ f"Pseudo-label column `whisper_transcript` not found in dataset {dataset_dict['name']}. Ensure"
644
+ "pseudo-labels are present in the dataset under this column name, or train directly on the text "
645
+ "labels by setting `--use_pseudo_labels=False` and defining the appropriate `--text_column_name`."
646
+ )
647
+ columns_to_keep.add("whisper_transcript")
648
+
649
+ if "condition_on_prev" in dataset_features:
650
+ columns_to_keep.add("condition_on_prev")
651
+
652
+ dataset_features = dataset.features.keys()
653
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
654
+ all_datasets.append(dataset)
655
+
656
+ if len(all_datasets) == 1:
657
+ # we have a single dataset so just return it as is
658
+ return all_datasets[0]
659
+
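+ # in streaming mode, interleave the datasets by sampling with the given probabilities; otherwise simply concatenate them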
660
+ if streaming:
661
+ interleaved_dataset = interleave_datasets(
662
+ all_datasets,
663
+ stopping_strategy=stopping_strategy,
664
+ probabilities=probabilities,
665
+ seed=seed,
666
+ )
667
+ else:
668
+ interleaved_dataset = concatenate_datasets(all_datasets)
669
+
670
+ return interleaved_dataset
671
+
672
+
673
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
674
+ """Helper function to sort saved checkpoints from oldest to newest."""
675
+ ordering_and_checkpoint_path = []
676
+
677
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
678
+
679
+ for path in glob_checkpoints:
680
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
681
+ if regex_match is not None and regex_match.groups() is not None:
682
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
683
+
684
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
685
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
686
+ return checkpoints_sorted
687
+
688
+
689
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
690
+ """Helper function to delete old checkpoints."""
691
+ if save_total_limit is None or save_total_limit <= 0:
692
+ return
693
+ # Check if we should delete older checkpoint(s)
694
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
695
+ if len(checkpoints_sorted) <= save_total_limit:
696
+ return
697
+
698
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
699
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
700
+ for checkpoint in checkpoints_to_be_deleted:
701
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
702
+ shutil.rmtree(checkpoint, ignore_errors=True)
703
+
704
+
705
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
706
+
707
+
708
+ def get_last_checkpoint(folder):
709
+ content = os.listdir(folder)
710
+ checkpoints = [
711
+ path
712
+ for path in content
713
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
714
+ ]
715
+ if len(checkpoints) == 0:
716
+ return
717
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
718
+
719
+
720
+ def get_parameter_names(model, forbidden_layer_types, forbidden_module=None):
721
+ """
722
+ Returns the names of the model parameters that are not inside a forbidden layer or forbidden module.
723
+ Can be used to get a subset of parameter names for decay masks, or to exclude parameters from an optimiser
724
+ (e.g. if the module is frozen).
725
+ """
726
+ result = []
727
+ for name, child in model.named_children():
728
+ result += [
729
+ f"{name}.{n}"
730
+ for n in get_parameter_names(child, forbidden_layer_types, forbidden_module)
731
+ if not (
732
+ isinstance(child, tuple(forbidden_layer_types))
733
+ or (child in tuple(forbidden_module) if forbidden_module is not None else False)
734
+ )
735
+ ]
736
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
737
+ result += list(model._parameters.keys())
738
+ return result
739
+
740
+
741
+ def main():
742
+ # 1. Parse input arguments
743
+ # We keep distinct sets of args, for cleaner separation of model/data/training related args
744
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
745
+
746
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
747
+ # If we pass only one argument to the script and it's the path to a json file,
748
+ # let's parse it to get our arguments.
749
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
750
+ else:
751
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
752
+
753
+
754
+
755
+ # 2. Initialize the accelerator
756
+ # We will let the accelerator handle device placement for us in this example
757
+ # We simply have to specify the training precision and any trackers being used
758
+ # We'll use the same dtype arguments as our JAX/Flax training script and convert
759
+ # it to accelerate format
760
+
761
+ if training_args.dtype == "float16":
762
+ mixed_precision = "fp16"
763
+ teacher_dtype = torch.float16
764
+ elif training_args.dtype == "bfloat16":
765
+ mixed_precision = "bf16"
766
+ teacher_dtype = torch.bfloat16
767
+ else:
768
+ mixed_precision = "no"
769
+ teacher_dtype = torch.float32
770
+
771
+ accelerator = Accelerator(
772
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
773
+ mixed_precision=mixed_precision,
774
+ log_with=training_args.report_to,
775
+ project_dir=training_args.output_dir,
776
+ )
777
+
778
+ accelerator.init_trackers(project_name=data_args.wandb_project)
779
+
780
+ # 3. Set-up basic logging
781
+ # Create one log on every process with the configuration for debugging
782
+ logging.basicConfig(
783
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
784
+ datefmt="%m/%d/%Y %H:%M:%S",
785
+ level=logging.INFO,
786
+ )
787
+ # Log a small summary on each process
788
+ logger.warning(
789
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
790
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
791
+ )
792
+
793
+ # Set the verbosity to info of the Transformers logger (on main process only)
794
+ if accelerator.is_local_main_process:
795
+ datasets.utils.logging.set_verbosity_warning()
796
+ transformers.utils.logging.set_verbosity_info()
797
+ else:
798
+ datasets.utils.logging.set_verbosity_error()
799
+ transformers.utils.logging.set_verbosity_error()
800
+ logger.info("Training/evaluation parameters %s", training_args)
801
+
802
+ # 4. Detecting last checkpoint and eventually continue from last checkpoint
803
+ last_checkpoint = None
804
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
805
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
806
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
807
+ raise ValueError(
808
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
809
+ "Use --overwrite_output_dir to overcome."
810
+ )
811
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
812
+ logger.info(
813
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
814
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
815
+ )
816
+
817
+ # 5. Handle the repository creation
818
+ if accelerator.is_main_process:
819
+ if training_args.push_to_hub:
820
+ if training_args.hub_model_id is None:
821
+ repo_name = get_full_repo_name(
822
+ Path(training_args.output_dir).absolute().name,
823
+ token=training_args.hub_token,
824
+ )
825
+ else:
826
+ repo_name = training_args.hub_model_id
827
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
828
+
829
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
830
+ if "wandb" not in gitignore:
831
+ gitignore.write("wandb\n")
832
+ elif training_args.output_dir is not None:
833
+ os.makedirs(training_args.output_dir, exist_ok=True)
834
+ accelerator.wait_for_everyone()
835
+
836
+ # 6. Load dataset - either streaming or non-streaming (offline)
837
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
838
+
839
+ # set seed for determinism
840
+ set_seed(training_args.seed)
841
+
842
+ if training_args.do_train:
843
+ raw_datasets["train"] = load_multiple_datasets(
844
+ data_args.train_dataset_name,
845
+ data_args.train_dataset_config_name,
846
+ splits=data_args.train_split_name,
847
+ text_column_names=data_args.text_column_name,
848
+ use_pseudo_labels=data_args.use_pseudo_labels,
849
+ streaming=data_args.streaming,
850
+ dataset_samples=data_args.train_dataset_samples,
851
+ seed=training_args.seed,
852
+ accelerator=accelerator,
853
+ cache_dir=data_args.dataset_cache_dir,
854
+ token=model_args.token,
855
+ )
856
+ raw_datasets_train_features = list(raw_datasets["train"].features.keys())
857
+
858
+ if training_args.do_eval:
859
+ dataset_names_dict = convert_dataset_str_to_list(
860
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
861
+ (
862
+ data_args.eval_dataset_config_name
863
+ if data_args.eval_dataset_config_name
864
+ else data_args.train_dataset_config_name
865
+ ),
866
+ splits=data_args.eval_split_name,
867
+ text_column_names=data_args.eval_text_column_name,
868
+ )
869
+ all_eval_splits = []
870
+ if len(dataset_names_dict) == 1:
871
+ # load a single eval set
872
+ dataset_dict = dataset_names_dict[0]
873
+ all_eval_splits.append("eval")
874
+ raw_datasets["eval"] = load_dataset(
875
+ dataset_dict["name"],
876
+ dataset_dict["config"],
877
+ split=dataset_dict["split"],
878
+ cache_dir=data_args.dataset_cache_dir,
879
+ token=model_args.token,
880
+ streaming=data_args.streaming,
881
+ )
882
+ if data_args.eval_text_column_name != "text":
883
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_text_column_name, "text")
884
+ else:
885
+ # load multiple eval sets
886
+ for dataset_dict in dataset_names_dict:
887
+ if dataset_dict["name"] == "esb/diagnostic-dataset":
888
+ # for the ESB diagnostic dataset, the dataset name is effectively the config
889
+ pretty_name = f"{dataset_dict['config']}-diagnostic/{dataset_dict['split']}"
890
+ else:
891
+ pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
892
+ all_eval_splits.append(pretty_name)
893
+ raw_datasets[pretty_name] = load_dataset(
894
+ dataset_dict["name"],
895
+ dataset_dict["config"],
896
+ split=dataset_dict["split"],
897
+ cache_dir=data_args.dataset_cache_dir,
898
+ token=model_args.token,
899
+ streaming=data_args.streaming,
900
+ )
901
+ # make column names consistent (text, audio)
902
+ if dataset_dict["text_column_name"] != "text":
903
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
904
+ dataset_dict["text_column_name"], "text"
905
+ )
906
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
907
+ set(raw_datasets[pretty_name].features.keys()) - {"audio", "text"}
908
+ )
909
+
910
+ if not training_args.do_train and not training_args.do_eval:
911
+ raise ValueError(
912
+ "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
913
+ )
914
+
915
+ # 7. Load pretrained model, tokenizer, and feature extractor
916
+ config = WhisperConfig.from_pretrained(
917
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
918
+ cache_dir=model_args.cache_dir,
919
+ revision=model_args.model_revision,
920
+ token=model_args.token,
921
+ )
922
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
923
+ (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
924
+ cache_dir=model_args.cache_dir,
925
+ revision=model_args.model_revision,
926
+ token=model_args.token,
927
+ )
928
+ tokenizer = WhisperTokenizerFast.from_pretrained(
929
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
930
+ cache_dir=model_args.cache_dir,
931
+ use_fast=model_args.use_fast_tokenizer,
932
+ revision=model_args.model_revision,
933
+ token=model_args.token,
934
+ )
935
+
936
+ # override timestamp tokens until tokenizer issues are fixed in transformers
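+ # the 1501 added tokens <|0.00|> ... <|30.00|> cover the 30 second context window in 0.02 second increments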
937
+ timestamps = [AddedToken("<|%.2f|>" % (i * 0.02), lstrip=False, rstrip=False) for i in range(1500 + 1)]
938
+ tokenizer.add_tokens(timestamps)
939
+
940
+ # The teacher model can safely be cast to the dtype of training since we don't
941
+ # update the params
942
+ teacher_model = WhisperForConditionalGeneration.from_pretrained(
943
+ model_args.teacher_model_name_or_path,
944
+ cache_dir=model_args.cache_dir,
945
+ token=model_args.token,
946
+ low_cpu_mem_usage=True,
947
+ torch_dtype=teacher_dtype,
948
+ attn_implementation=model_args.attn_implementation,
949
+ )
950
+
951
+ student_model = WhisperForConditionalGeneration.from_pretrained(
952
+ model_args.model_name_or_path,
953
+ config=config,
954
+ cache_dir=model_args.cache_dir,
955
+ revision=model_args.model_revision,
956
+ subfolder=model_args.subfolder,
957
+ token=model_args.token,
958
+ low_cpu_mem_usage=True,
959
+ attn_implementation=model_args.attn_implementation,
960
+ )
961
+
962
+ if student_model.config.decoder_start_token_id is None or teacher_model.config.decoder_start_token_id is None:
963
+ raise ValueError(
964
+ f"Make sure that `config.decoder_start_token_id` is correctly defined for both the "
965
+ f"student and teacher model. Got {student_model.config.decoder_start_token_id} for the "
966
+ f"student and {teacher_model.config.decoder_start_token_id} for the teacher."
967
+ )
968
+
969
+ # enable gradient checkpointing if necessary
970
+ if training_args.gradient_checkpointing:
971
+ student_model.gradient_checkpointing_enable()
972
+
973
+ def set_trainable_parameters(module, requires_grad=False):
974
+ for param in module.parameters():
975
+ param.requires_grad = requires_grad
976
+ module._requires_grad = requires_grad
977
+
978
+ # freeze student encoder if necessary
979
+ if training_args.freeze_encoder:
980
+ set_trainable_parameters(student_model.model.encoder, requires_grad=False)
981
+ student_model.model.encoder.gradient_checkpointing = False
982
+
983
+ if training_args.freeze_embed_positions:
984
+ # set_trainable_parameters(student_model.model.decoder.embed_tokens, requires_grad=False)
985
+ set_trainable_parameters(student_model.model.decoder.embed_positions, requires_grad=False)
986
+ if student_model.model.decoder.gradient_checkpointing:
987
+ logger.info(
988
+ "Disabling gradient checkpointing in the decoder since it's incompatible with `freeze_embed_positions`."
989
+ )
990
+
991
+ share_hidden_states = training_args.freeze_encoder and student_model.config.d_model == teacher_model.config.d_model
992
+ if share_hidden_states:
993
+ # tie the weights for the teacher encoder if we're freezing the student and it's the same as the teacher
994
+ teacher_model.model.encoder = student_model.model.encoder
995
+
996
+ if hasattr(teacher_model.generation_config, "is_multilingual") and teacher_model.generation_config.is_multilingual:
997
+ # We need to set the language and task ids for multilingual checkpoints
998
+ is_multilingual = True
999
+ tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task, predict_timestamps=False)
1000
+ student_model.generation_config.update(
1001
+ **{
1002
+ "language": data_args.language,
1003
+ "task": data_args.task,
1004
+ }
1005
+ )
1006
+ elif data_args.language is not None:
1007
+ raise ValueError(
1008
+ "Setting language token for an English-only checkpoint is not permitted. The language argument should "
1009
+ "only be set for multilingual checkpoints."
1010
+ )
1011
+ else:
1012
+ is_multilingual = False
1013
+
1014
+ # 8. Create a single speech processor - make sure all processes wait until data is saved
1015
+ if accelerator.is_main_process:
1016
+ feature_extractor.save_pretrained(training_args.output_dir)
1017
+ tokenizer.save_pretrained(training_args.output_dir)
1018
+ # save the config and generation config as well
1019
+ config.save_pretrained(training_args.output_dir)
1020
+ student_model.generation_config.save_pretrained(training_args.output_dir)
1021
+
1022
+ accelerator.wait_for_everyone()
1023
+ processor = WhisperProcessor.from_pretrained(training_args.output_dir)
1024
+
1025
+ # 9. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
1026
+ # so we just need to set the correct target sampling rate.
1027
+ sampling_rate = feature_extractor.sampling_rate
1028
+ raw_datasets = raw_datasets.cast_column(
1029
+ data_args.audio_column_name,
1030
+ datasets.features.Audio(sampling_rate=sampling_rate),
1031
+ )
1032
+
1033
+ # 10. Preprocessing the datasets: we need to read the audio files as arrays and tokenize the targets.
1034
+ # 10.1: Define the pre-processing constants
1035
+ max_input_length = int(data_args.max_duration_in_seconds * sampling_rate)
1036
+ min_input_length = int(data_args.min_duration_in_seconds * sampling_rate)
1037
+ max_label_length = (
1038
+ data_args.max_label_length if data_args.max_label_length is not None else student_model.config.max_length
1039
+ )
1040
+
1041
+ timestamp_probability = data_args.timestamp_probability
1042
+ condition_on_prev_probability = data_args.condition_on_prev_probability
1043
+ return_timestamps = data_args.return_timestamps if timestamp_probability > 0 else False
1044
+
1045
+ timestamp_ids = tokenizer.timestamp_ids()
1046
+ timestamp_begin = tokenizer.all_special_ids[-1]
1047
+ timestamp_position = 3 if is_multilingual else 1
1048
+
1049
+ decoder_start_token_id = student_model.config.decoder_start_token_id # <|startoftranscript|>
1050
+ decoder_prev_token_id = tokenizer.all_special_ids[-3] # <|startofprev|>
1051
+ prompt_cutoff_length = max_label_length // 2
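+ # prompts longer than half the maximum label length are truncated in `prepare_train_dataset` below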
1052
+
1053
+ num_workers = data_args.preprocessing_num_workers
1054
+ dataloader_num_workers = training_args.dataloader_num_workers
1055
+ prefetch_factor = training_args.dataloader_prefetch_factor
1056
+
1057
+ metric = evaluate.load("wer")
1058
+ normalizer = (
1059
+ BasicTextNormalizer()
1060
+ if data_args.language is not None
1061
+ else EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
1062
+ )
1063
+ wer_threshold = data_args.wer_threshold
1064
+ use_pseudo_labels = data_args.use_pseudo_labels
1065
+ train_text_column_name = "whisper_transcript" if use_pseudo_labels else "text"
1066
+
1067
+ # 10.2: filter based on maximum number of training/evaluation samples
1068
+ if training_args.do_train and data_args.max_train_samples is not None:
1069
+ raw_datasets["train"] = (
1070
+ raw_datasets["train"].take(data_args.max_train_samples)
1071
+ if data_args.streaming
1072
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
1073
+ )
1074
+
1075
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1076
+ for eval_split in all_eval_splits:
1077
+ raw_datasets[eval_split] = (
1078
+ raw_datasets[eval_split].take(data_args.max_eval_samples)
1079
+ if data_args.streaming
1080
+ else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1081
+ )
1082
+
1083
+ # 10.3: filter training data based on WER threshold -> this is KEY to good distillation performance
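+ # rationale: pseudo-labels whose WER against the normalised ground truth exceeds the threshold are treated as low-quality teacher transcriptions and dropped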
1084
+ def is_wer_in_range(ground_truth, whisper_transcript):
1085
+ norm_ground_truth = normalizer(ground_truth)
1086
+ if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
1087
+ # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
1088
+ return False
1089
+ elif len(norm_ground_truth) > 0 and whisper_transcript is not None:
1090
+ norm_whisper_transcript = normalizer(whisper_transcript)
1091
+ wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
1092
+ return wer < wer_threshold
1093
+ else:
1094
+ # filter automatically since we can't know the WER
1095
+ return False
1096
+
1097
+ filter_by_wer_threshold = partial(
1098
+ raw_datasets["train"].filter,
1099
+ function=is_wer_in_range,
1100
+ input_columns=["text", "whisper_transcript"],
1101
+ )
1102
+
1103
+ if wer_threshold is not None and use_pseudo_labels:
1104
+ with accelerator.main_process_first():
1105
+ raw_datasets["train"] = (
1106
+ filter_by_wer_threshold(num_proc=num_workers, desc="filtering train dataset by wer")
1107
+ if not data_args.streaming
1108
+ else filter_by_wer_threshold()
1109
+ )
1110
+
1111
+ # 10.4: pre-process training/evaluation datasets
1112
+ def prepare_train_dataset(batch):
1113
+ """
1114
+ Pre-process the raw dataset in a three stage process:
1115
+ 1. Convert the audio arrays to log-mel spectrogram inputs
1116
+ 2. Possibly filter the timestamp tokens from the token ids (depending on the timestamp probability)
1117
+ 3. Possibly add prompt tokens if conditioning on previous text (depending on the conditioning probability)
1118
+ """
1119
+ # process audio input
1120
+ audio = [sample["array"] for sample in batch["audio"]]
1121
+ inputs = feature_extractor(audio, sampling_rate=sampling_rate)
1122
+ batch["input_features"] = inputs.input_features
1123
+ batch["input_length"] = [len(sample) for sample in audio]
1124
+
1125
+ # process text targets - for training these are the Whisper-generated pseudo-labels
1126
+ input_str_batched = batch[train_text_column_name]
1127
+ condition_on_prev_batched = batch.get("condition_on_prev", len(input_str_batched) * [None])
1128
+
1129
+ all_token_ids = []
1130
+ all_token_ids_unprompted = []
1131
+ for prev_ids, input_str in zip(condition_on_prev_batched, input_str_batched):
1132
+ token_ids = tokenizer(input_str, add_special_tokens=not use_pseudo_labels).input_ids
1133
+
1134
+ # check whether we have timestamps in the PLs and filter if required
1135
+ has_timestamps = len(set(token_ids) & set(timestamp_ids)) > 0
1136
+ if has_timestamps:
1137
+ # sample from binomial distribution to get probability of training on timestamps
1138
+ predict_timestamps = bool(np.random.binomial(1, timestamp_probability))
1139
+ if not predict_timestamps:
1140
+ # filter timestamps and insert the <|notimestamps|> task token
1141
+ token_ids = [token for token in token_ids if token < timestamp_begin]
1142
+ token_ids.insert(timestamp_position, timestamp_begin)
1143
+
1144
+ all_token_ids_unprompted.append(token_ids)
1145
+ # check whether to condition on previous text - we do this with probability condition_on_prev_probability
1146
+ condition_on_prev = bool(np.random.binomial(1, condition_on_prev_probability))
1147
+ if not condition_on_prev:
1148
+ prev_ids = None
1149
+ elif "condition_on_prev" not in batch and len(all_token_ids_unprompted) > 1:
1150
+ # prompt ids are the penultimate token ids in the batch
1151
+ prev_ids = all_token_ids_unprompted[-2]
1152
+
1153
+ if prev_ids is not None:
1154
+ if has_timestamps and not predict_timestamps:
1155
+ # filter timestamp ids from prompt when not predicting timestamps
1156
+ prev_ids = [token for token in prev_ids if token < timestamp_begin]
1157
+
1158
+ # check that the length of the prompt does not exceed more than half the max label length (224)
1159
+ if len(prev_ids) > prompt_cutoff_length:
1160
+ prev_ids = prev_ids[-prompt_cutoff_length + 1 :]
1161
+ prev_ids = [decoder_prev_token_id] + prev_ids
1162
+
1163
+ # and that the total length of the labels does not exceed the max label length (448)
1164
+ if len(prev_ids + token_ids) > max_label_length:
1165
+ trim_length = len(prev_ids + token_ids) - max_label_length + 1
1166
+ prev_ids = prev_ids[trim_length:]
1167
+ prev_ids = [decoder_prev_token_id] + prev_ids
1168
+
1169
+ token_ids = prev_ids + token_ids
1170
+
1171
+ all_token_ids.append(token_ids)
1172
+
1173
+ batch["labels"] = all_token_ids
1174
+ return batch
1175
+
1176
+ def prepare_eval_dataset(batch):
1177
+ # process audio input
1178
+ sample = batch["audio"]
1179
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1180
+ batch["input_features"] = inputs.input_features[0]
1181
+ batch["input_length"] = len(sample["array"])
1182
+
1183
+ # process targets - for evaluation these are the ground-truth transcriptions
1184
+ input_str = batch["text"]
1185
+ batch["labels"] = tokenizer(input_str).input_ids
1186
+ return batch
1187
+
1188
+ vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1189
+ if training_args.do_train:
1190
+ # with streaming mode we can only have 1 worker, whereas with non-streaming
1191
+ # we can use `num_workers` (which is much faster)
1192
+ # We gate the pre-processing function accordingly
1193
+ map_fn_train = partial(
1194
+ raw_datasets["train"].map,
1195
+ function=prepare_train_dataset,
1196
+ remove_columns=raw_datasets_train_features,
1197
+ batched=True,
1198
+ batch_size=data_args.preprocessing_batch_size,
1199
+ )
1200
+ with accelerator.main_process_first():
1201
+ vectorized_datasets["train"] = (
1202
+ map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1203
+ if not data_args.streaming
1204
+ else map_fn_train()
1205
+ )
1206
+ if training_args.do_eval:
1207
+ for eval_split in all_eval_splits:
1208
+ raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1209
+ map_fn_eval = partial(
1210
+ raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
1211
+ )
1212
+ with accelerator.main_process_first():
1213
+ vectorized_datasets[eval_split] = (
1214
+ map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1215
+ if not data_args.streaming
1216
+ else map_fn_eval()
1217
+ )
1218
+
1219
+ # 10.5: Filter training data with inputs longer than `max_input_length`
1220
+ def is_audio_in_length_range(length):
1221
+ return min_input_length < length < max_input_length
1222
+
1223
+ filter_by_audio_fn = partial(
1224
+ vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
1225
+ )
1226
+ with accelerator.main_process_first():
1227
+ vectorized_datasets = (
1228
+ filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
1229
+ if not data_args.streaming
1230
+ else filter_by_audio_fn()
1231
+ )
1232
+
1233
+ # 10.6: Filter training data with labels longer than `max_label_length`
1234
+ def is_labels_in_length_range(labels):
1235
+ return 0 < len(labels) <= max_label_length
1236
+
1237
+ filter_by_labels_fn = partial(
1238
+ vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1239
+ )
1240
+ with accelerator.main_process_first():
1241
+ vectorized_datasets = (
1242
+ filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1243
+ if not data_args.streaming
1244
+ else filter_by_labels_fn()
1245
+ )
1246
+
1247
+ # Pre-processing complete!
1248
+ # For large datasets it is advised to run the preprocessing on a
1249
+ # single machine first with `--preprocessing_only` since there will most likely
1250
+ # be a timeout when running the script in distributed mode.
1251
+ # In a second step, `--preprocessing_only` can then be set to `False` to load the
1252
+ # cached dataset
1253
+ if data_args.preprocessing_only:
1254
+ if data_args.streaming:
1255
+ raise ValueError(
1256
+ "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1257
+ "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1258
+ "on the fly with streaming mode."
1259
+ )
1260
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1261
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1262
+ return
1263
+
1264
+ # 11. Define Evaluation Metrics
1265
+ def compute_metrics(preds, labels):
1266
+ # replace the -100 padding in the labels with the pad token id so they can be decoded
1267
+ for idx in range(len(labels)):
1268
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1269
+
1270
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
1271
+ # we do not want to group tokens when computing the metrics
1272
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1273
+ wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1274
+
1275
+ # normalize everything and re-compute the WER
1276
+ norm_pred_str = [normalizer(pred) for pred in pred_str]
1277
+ norm_label_str = [normalizer(label) for label in label_str]
1278
+ # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
1279
+ pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1280
+ label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1281
+ # filtering step to only evaluate the samples that correspond to non-zero normalized references:
1282
+ norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1283
+ norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1284
+
1285
+ wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1286
+ return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1287
+
1288
+ # 12. Define Training Schedule
1289
+ # Store some constants
1290
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1291
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1292
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1293
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1294
+
1295
+ if not data_args.streaming and training_args.max_steps < 0:
1296
+ num_epochs = int(training_args.num_train_epochs)
1297
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1298
+ total_train_steps = steps_per_epoch * num_epochs
1299
+ elif training_args.max_steps > 0:
1300
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1301
+ total_train_steps = int(training_args.max_steps)
1302
+ if not data_args.streaming:
1303
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1304
+ num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1305
+ else:
1306
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1307
+ num_epochs = sys.maxsize
1308
+ steps_per_epoch = total_train_steps
1309
+ else:
1310
+ raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1311
+
1312
+ if training_args.eval_steps is None:
1313
+ logger.info(
1314
+ f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1315
+ )
1316
+ eval_steps = steps_per_epoch
1317
+ else:
1318
+ eval_steps = training_args.eval_steps
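# Worked example of the schedule arithmetic above (illustrative numbers only):
#   100_000 train examples, per_device_train_batch_size=32, 2 processes, gradient_accumulation_steps=2
#     train_batch_size  = 32 * 2 = 64
#     steps_per_epoch   = 100_000 // (64 * 2) = 781 optimizer steps
#     num_train_epochs=5 -> total_train_steps = 781 * 5 = 3905
#   with --max_steps 5000 instead: total_train_steps = 5000 and num_epochs = ceil(5000 / 781) = 7
#   if eval_steps is unset, evaluation runs every steps_per_epoch (non-streaming) or only at the end (streaming)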
1319
+
1320
+ # 13. Define optimizer, LR scheduler, collator
1321
+ decay_parameters = get_parameter_names(
1322
+ student_model,
1323
+ [nn.LayerNorm],
1324
+ forbidden_module=[student_model.model.encoder] if training_args.freeze_encoder else None,
1325
+ )
1326
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
1327
+ optimizer_grouped_parameters = [
1328
+ {
1329
+ "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1330
+ "weight_decay": training_args.weight_decay,
1331
+ },
1332
+ {
1333
+ "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1334
+ "weight_decay": 0.0,
1335
+ },
1336
+ ]
1337
+ optimizer = torch.optim.AdamW(
1338
+ params=optimizer_grouped_parameters,
1339
+ lr=training_args.learning_rate,
1340
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1341
+ eps=training_args.adam_epsilon,
1342
+ )
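# A minimal sketch of the decay / no-decay split used above, on a toy module
# (the real script also excludes the frozen encoder via `forbidden_module`):
import torch
import torch.nn as nn

toy_model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
decay, no_decay = [], []
for name, param in toy_model.named_parameters():
    # biases and LayerNorm parameters are conventionally trained without weight decay
    if "bias" in name or name.startswith("1."):  # index "1" is the LayerNorm in this toy module
        no_decay.append(param)
    else:
        decay.append(param)

toy_optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.01}, {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-4,
)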
1343
+
1344
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1345
+ lr_scheduler = get_scheduler(
1346
+ name=training_args.lr_scheduler_type,
1347
+ optimizer=optimizer,
1348
+ num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1349
+ num_training_steps=total_train_steps * accelerator.num_processes,
1350
+ )
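# Worked example of the num_processes scaling above: with 4 processes, every optimizer step advances
# the prepared scheduler 4 ticks, so warmup_steps=500 and total_train_steps=5000 are registered as
# 500 * 4 = 2_000 warmup ticks and 5_000 * 4 = 20_000 total ticks, i.e. still 500 / 5_000 optimizer steps.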
1351
+
1352
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
1353
+ processor=processor,
1354
+ decoder_start_token_id=decoder_start_token_id,
1355
+ decoder_prev_token_id=decoder_prev_token_id,
1356
+ input_padding="longest",
1357
+ target_padding="max_length",
1358
+ max_target_length=max_label_length,
1359
+ )
1360
+
1361
+ # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1362
+ # so that we can still access the configs
1363
+ num_beams = (
1364
+ training_args.generation_num_beams
1365
+ if training_args.generation_num_beams is not None
1366
+ else getattr(student_model.generation_config, "num_beams", 1)
1367
+ )
1368
+
1369
+ gen_kwargs = {
1370
+ "max_length": max_label_length,
1371
+ "num_beams": num_beams,
1372
+ "return_timestamps": return_timestamps,
1373
+ }
1374
+ if is_multilingual:
1375
+ # forcing the language and task tokens helps multilingual models in their generations
1376
+ gen_kwargs.update(
1377
+ {
1378
+ "language": data_args.language,
1379
+ "task": data_args.task,
1380
+ }
1381
+ )
1382
+
1383
+ # 15. Prepare everything with accelerate
1384
+ student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1385
+ student_model, teacher_model, optimizer, lr_scheduler
1386
+ )
1387
+
1388
+ def kl_divergence(target_distribution, log_predicted_distribution, labels):
1389
+ kl_loss = nn.KLDivLoss(reduction="none")
1390
+ divergence = kl_loss(log_predicted_distribution, target_distribution)
1391
+ # mask out padded positions (where labels are set to -100) so they do not contribute to the divergence
1392
+ padding_mask = labels >= 0
1393
+ padding_mask = padding_mask.unsqueeze(-1)
1394
+ divergence = divergence * padding_mask
1395
+ # average over the non-padded tokens in the mini-batch
1396
+ divergence = divergence.sum() / padding_mask.sum()
1397
+ return divergence
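# A tiny self-contained check of the masked KL above, using dummy logits
# (batch=1, seq_len=2, vocab=3; the second position is padding with label -100):
import torch
import torch.nn as nn

teacher_logits = torch.tensor([[[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]])
student_logits = torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 2.0]]])
labels = torch.tensor([[5, -100]])  # only the first position contributes

target = nn.functional.softmax(teacher_logits, dim=-1)
log_pred = nn.functional.log_softmax(student_logits, dim=-1)
divergence = nn.KLDivLoss(reduction="none")(log_pred, target)
mask = (labels >= 0).unsqueeze(-1)
masked_kl = (divergence * mask).sum() / mask.sum()  # the KL of position 0 only; the padded position is excluded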
1398
+
1399
+ # Define gradient update step fn
1400
+ def train_step(
1401
+ batch,
1402
+ temperature=2.0,
1403
+ ):
1404
+ student_model.train()
1405
+ teacher_model.eval()
1406
+
1407
+ student_outputs = student_model(**batch)
1408
+ with torch.no_grad():
1409
+ if share_hidden_states:
1410
+ # if the student and teacher share the same frozen encoder then we don't have to recompute the
1411
+ # encoder hidden-states for the teacher model; we can simply re-use them from the student
1412
+ encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1413
+ teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1414
+ else:
1415
+ # do the full forward pass for the teacher model (encoder + decoder)
1416
+ teacher_outputs = teacher_model(**batch)
1417
+
1418
+ # CE (data) loss
1419
+ ce_loss = student_outputs.loss
1420
+ # rescale distribution by temperature to ensure gradients scale correctly
1421
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1422
+ # log softmax of student predictions for numerical stability
1423
+ student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1424
+ # KL-divergence loss (scaled by temperature)
1425
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1426
+
1427
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1428
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1429
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1430
+ return loss, metrics
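# Worked example of the loss weighting above, assuming kl_weight=1.0 and temperature=2.0
# (illustrative values; the real ones come from the training arguments):
#   ce_loss = 1.20, raw KL divergence = 0.05
#   kl_loss = 0.05 * 2.0**2 = 0.20   # the temperature**2 factor keeps the KL gradients on the same scale
#   loss    = 0.8 * 1.20 + 1.0 * 0.20 = 1.16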
1431
+
1432
+ # Define eval fn
1433
+ def eval_step(batch):
1434
+ student_model.eval()
1435
+ teacher_model.eval()
1436
+
1437
+ with torch.no_grad():
1438
+ student_outputs = student_model(**batch)
1439
+ if share_hidden_states:
1440
+ encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1441
+ teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1442
+ else:
1443
+ teacher_outputs = teacher_model(**batch)
1444
+
1445
+ # CE (data) loss
1446
+ ce_loss = student_outputs.loss
1447
+
1448
+ # log softmax / softmax for numerical stability
1449
+ student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1450
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1451
+ # temperature is always 1 for eval
1452
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1453
+
1454
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1455
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1456
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1457
+ return metrics
1458
+
1459
+ def generate_step(batch):
1460
+ student_model.eval()
1461
+ output_ids = accelerator.unwrap_model(student_model).generate(batch["input_features"], **gen_kwargs)
1462
+ output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1463
+ return output_ids
1464
+
1465
+ logger.info("***** Running training *****")
1466
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1467
+ if not data_args.streaming:
1468
+ logger.info(f" Num epochs = {num_epochs}")
1469
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1470
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1471
+ logger.info(
1472
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1473
+ )
1474
+ logger.info(f" Total optimization steps = {total_train_steps}")
1475
+
1476
+ # ======================== Training ================================
1477
+ train_time = 0
1478
+ train_start = time.time()
1479
+ steps_trained_progress_bar = tqdm(
1480
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1481
+ )
1482
+ continue_training = True
1483
+ epochs_trained = 0
1484
+ cur_step = 0
1485
+
1486
+ checkpoint = None
1487
+ if training_args.resume_from_checkpoint is not None:
1488
+ checkpoint = training_args.resume_from_checkpoint
1489
+ elif last_checkpoint is not None:
1490
+ checkpoint = last_checkpoint
1491
+
1492
+ if checkpoint is not None:
1493
+ accelerator.load_state(checkpoint)
1494
+ # Find num steps and epoch from saved state string pattern
1495
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1496
+ match = re.search(pattern, checkpoint)
1497
+ cur_step = int(match.group(1))
1498
+ epochs_trained = int(match.group(2))
1499
+
1500
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1501
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1502
+ logger.info(f" Continuing training from global step {cur_step}")
1503
+
1504
+ steps_trained_progress_bar.update(cur_step)
1505
+
1506
+ for epoch in range(0, epochs_trained):
1507
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1508
+
1509
+ if not data_args.streaming and training_args.max_steps < 0:
1510
+ # we know exactly the number of steps per epoch, so we can skip through the required number of batches
1511
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1512
+ else:
1513
+ # Currently we don't know how many steps we've taken in the current epoch
1514
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1515
+ # This is "good enough" for our purposes but not fully correct
1516
+ resume_step = None
1517
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1518
+ else:
1519
+ resume_step = None
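# A quick sketch of the checkpoint-name parsing above, e.g. for the state saved in this commit:
import re

match = re.search(r"checkpoint-(\d+)-epoch-(\d+)", "checkpoint-5000-epoch-0")
cur_step, epochs_trained = int(match.group(1)), int(match.group(2))  # 5000, 0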
1520
+
1521
+ for epoch in range(epochs_trained, num_epochs):
1522
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1523
+ train_dataloader = DataLoader(
1524
+ vectorized_datasets["train"],
1525
+ collate_fn=data_collator,
1526
+ batch_size=per_device_train_batch_size,
1527
+ num_workers=dataloader_num_workers,
1528
+ prefetch_factor=prefetch_factor,
1529
+ pin_memory=training_args.dataloader_pin_memory,
1530
+ )
1531
+ train_dataloader = accelerator.prepare(train_dataloader)
1532
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1533
+ train_dataloader.dataset.set_epoch(epoch)
1534
+
1535
+ if resume_step is not None:
1536
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1537
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1538
+ resume_step = None
1539
+
1540
+ for batch in train_dataloader:
1541
+ with accelerator.accumulate(student_model):
1542
+ loss, train_metric = train_step(batch, temperature=training_args.temperature)
1543
+ accelerator.backward(loss)
1544
+ if accelerator.sync_gradients:
1545
+ accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1546
+ optimizer.step()
1547
+ lr_scheduler.step()
1548
+ optimizer.zero_grad()
1549
+
1550
+ # Check if the accelerator has performed an optimization step behind the scenes
1551
+ if accelerator.sync_gradients:
1552
+ steps_trained_progress_bar.update(1)
1553
+ cur_step += 1
1554
+
1555
+ if cur_step % training_args.logging_steps == 0:
1556
+ steps_trained_progress_bar.write(
1557
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1558
+ f" {train_metric['loss']}, Learning Rate:"
1559
+ f" {lr_scheduler.get_last_lr()[0]})"
1560
+ )
1561
+ log_metric(
1562
+ accelerator,
1563
+ metrics=train_metric,
1564
+ learning_rate=lr_scheduler.get_last_lr()[0],
1565
+ train_time=train_time + time.time() - train_start,
1566
+ step=cur_step,
1567
+ epoch=epoch,
1568
+ prefix="train",
1569
+ )
1570
+
1571
+ # save checkpoint and weights after each save_steps and at the end of training
1572
+ if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1573
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1574
+ accelerator.save_state(output_dir=intermediate_dir)
1575
+ accelerator.wait_for_everyone()
1576
+ if accelerator.is_main_process:
1577
+ rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1578
+
1579
+ if training_args.push_to_hub:
1580
+ upload_folder(
1581
+ folder_path=training_args.output_dir,
1582
+ repo_id=repo_name,
1583
+ repo_type="model",
1584
+ commit_message=f"Saving train state of step {cur_step}",
1585
+ )
1586
+
1587
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1588
+ train_time += time.time() - train_start
1589
+ student_model.eval()
1590
+ # ======================== Evaluating ==============================
1591
+ for eval_split in all_eval_splits:
1592
+ eval_metrics = []
1593
+ eval_preds = []
1594
+ eval_labels = []
1595
+ eval_start = time.time()
1596
+
1597
+ validation_dataloader = DataLoader(
1598
+ vectorized_datasets[eval_split],
1599
+ collate_fn=data_collator,
1600
+ batch_size=per_device_eval_batch_size,
1601
+ drop_last=False,
1602
+ num_workers=dataloader_num_workers,
1603
+ prefetch_factor=prefetch_factor,
1604
+ pin_memory=training_args.dataloader_pin_memory,
1605
+ )
1606
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1607
+
1608
+ for batch in tqdm(
1609
+ validation_dataloader,
1610
+ desc=f"Evaluating {eval_split}...",
1611
+ position=2,
1612
+ disable=not accelerator.is_local_main_process,
1613
+ ):
1614
+ # Model forward
1615
+ eval_metric = eval_step(batch)
1616
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1617
+ eval_metrics.append(eval_metric)
1618
+
1619
+ # generation
1620
+ if training_args.predict_with_generate:
1621
+ generated_ids = generate_step(batch)
1622
+ # Gather all predictions and targets
1623
+ generated_ids, labels = accelerator.gather_for_metrics(
1624
+ (generated_ids, batch["labels"])
1625
+ )
1626
+ eval_preds.extend(generated_ids)
1627
+ eval_labels.extend(labels)
1628
+
1629
+ eval_time = time.time() - eval_start
1630
+ # normalize eval metrics
1631
+ eval_metrics = {
1632
+ key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1633
+ }
1634
+
1635
+ # compute WER metric
1636
+ wer_desc = ""
1637
+ if training_args.predict_with_generate:
1638
+ wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
1639
+ eval_preds, eval_labels
1640
+ )
1641
+ eval_metrics.update(wer_metric)
1642
+ wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1643
+ log_pred(
1644
+ accelerator,
1645
+ pred_str,
1646
+ label_str,
1647
+ norm_pred_str,
1648
+ norm_label_str,
1649
+ step=cur_step,
1650
+ prefix=eval_split,
1651
+ )
1652
+
1653
+ # Print metrics and update progress bar
1654
+ steps_trained_progress_bar.write(
1655
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1656
+ f" {wer_desc})"
1657
+ )
1658
+
1659
+ log_metric(
1660
+ accelerator,
1661
+ metrics=eval_metrics,
1662
+ train_time=eval_time,
1663
+ step=cur_step,
1664
+ epoch=epoch,
1665
+ prefix=eval_split,
1666
+ )
1667
+
1668
+ # flush the train metrics
1669
+ train_start = time.time()
1670
+
1671
+ # break condition
1672
+ if cur_step == total_train_steps:
1673
+
1674
+ # un-wrap student model for save
1675
+ student_model = accelerator.unwrap_model(student_model)
1676
+ student_model.save_pretrained(training_args.output_dir)
1677
+
1678
+ if training_args.push_to_hub:
1679
+ upload_folder(
1680
+ folder_path=training_args.output_dir,
1681
+ repo_id=repo_name,
1682
+ repo_type="model",
1683
+ commit_message=f"Saving final weights of step {cur_step}",
1684
+ )
1685
+
1686
+ continue_training = False
1687
+ break
1688
+
1689
+ if not continue_training:
1690
+ break
1691
+
1692
+ accelerator.end_training()
1693
+
1694
+
1695
+ if __name__ == "__main__":
1696
+ main()
checkpoint-5000-epoch-0/model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2024dd216781ceab76a526c0fb678957a3f9cdc79e7c25fd2e1fddf45a1e7ba6
3
  size 3025686376
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:842f59bab397b6a8c02278413ccfd5d6dac9b7eb61db391a22a97732d4f13e55
3
  size 3025686376
checkpoint-5000-epoch-0/optimizer.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cbc528db66588def938f87aea05f116a8224afac942acac5f95f7908aae665c2
3
  size 955539578
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c49615144cf408be66ec5684644e9cbaff970f8b799a32e4e5a523adb22fa90d
3
  size 955539578
distil-whisper/events.out.tfevents.1714722015.server02.764303.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b5b6f4d8b4b95639486918cbda0b7b686283c3fd88820c23b563c0dca074dc2
3
+ size 50898
distil-whisper/events.out.tfevents.1714724453.server02.769515.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1631a259decb801e2ec2a16c11026f4d8492b25f4aa45a735c5063acd5754a0c
3
+ size 88
distil-whisper/events.out.tfevents.1714724491.server02.769647.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07d9cf426d6be544bc32ed9055598a109c9e3cf8eaf025bcd62c952286b4fce0
3
+ size 62058