sanchit-gandhi (HF staff) committed on
Commit
68b83c4
1 Parent(s): dfb7e3e

Add training scripts

Files changed (7)
  1. create_student_model.py +215 -0
  2. run.sh +41 -0
  3. run_all.sh +67 -0
  4. run_distillation.py +1683 -0
  5. run_init.sh +8 -0
  6. run_labelling.sh +28 -0
  7. run_pseudo_labelling.py +1015 -0
create_student_model.py ADDED
@@ -0,0 +1,215 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Initialise a student Whisper model from a pre-trained teacher model for
18
+ teacher-student distillation.
19
+ """
20
+
21
+ import argparse
22
+ import copy
23
+ import logging
24
+
25
+ import numpy as np
26
+ import torch
27
+ from transformers import GenerationConfig, WhisperForConditionalGeneration, WhisperProcessor
28
+
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ def parse_args():
34
+ parser = argparse.ArgumentParser(
35
+ description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
36
+ )
37
+ parser.add_argument(
38
+ "--teacher_checkpoint",
39
+ type=str,
40
+ required=True,
41
+ help="The HF Hub ID of the teacher checkpoint.",
42
+ )
43
+ parser.add_argument(
44
+ "--subfolder",
45
+ type=str,
46
+ default="",
47
+ help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
48
+ "can specify the folder name here.",
49
+ )
50
+ parser.add_argument(
51
+ "--encoder_layers",
52
+ type=int,
53
+ default=None,
54
+ help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
55
+ )
56
+ parser.add_argument(
57
+ "--decoder_layers",
58
+ type=int,
59
+ default=2,
60
+ help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
61
+ )
62
+ parser.add_argument(
63
+ "--save_dir",
64
+ type=str,
65
+ required=True,
66
+ help="Where to save the student weights and processor.",
67
+ )
68
+ parser.add_argument(
69
+ "--push_to_hub",
70
+ action="store_true",
71
+ required=False,
72
+ default=False,
73
+ help="Whether to push the student weights and processor to the Hub.",
74
+ )
75
+ parser.add_argument(
76
+ "--cache_dir",
77
+ type=str,
78
+ default=None,
79
+ help="Where to store the pretrained models downloaded from huggingface.co",
80
+ )
81
+
82
+ args = parser.parse_args()
83
+ return args
84
+
85
+
86
+ def init_student_model_from_teacher(
87
+ teacher_checkpoint,
88
+ encoder_layers=None,
89
+ decoder_layers=2,
90
+ save_dir=None,
91
+ push_to_hub=None,
92
+ cache_dir=None,
93
+ subfolder="",
94
+ ):
95
+ teacher_model = WhisperForConditionalGeneration.from_pretrained(
96
+ teacher_checkpoint,
97
+ cache_dir=cache_dir,
98
+ subfolder=subfolder,
99
+ low_cpu_mem_usage=True,
100
+ )
101
+ processor = WhisperProcessor.from_pretrained(teacher_checkpoint)
102
+ generation_config = GenerationConfig.from_pretrained(teacher_checkpoint)
103
+ generation_config.forced_decoder_ids = None
104
+
105
+ teacher_config = teacher_model.config
106
+ teacher_encoder_layers = teacher_config.encoder_layers
107
+ teacher_decoder_layers = teacher_config.decoder_layers
108
+
109
+ student_config = copy.deepcopy(teacher_config)
110
+ student_config.update(
111
+ {
112
+ "encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
113
+ "decoder_layers": decoder_layers,
114
+ }
115
+ )
116
+
117
+ encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int)
118
+ encoder_mapping[-1] = teacher_encoder_layers - 1
119
+
120
+ encoder_map = {}
121
+ for student_layer, teacher_layer in enumerate(encoder_mapping):
122
+ encoder_map[teacher_layer] = student_layer
123
+
124
+ decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int)
125
+ decoder_mapping[-1] = teacher_decoder_layers - 1
126
+
127
+ decoder_map = {}
128
+ for student_layer, teacher_layer in enumerate(decoder_mapping):
129
+ decoder_map[teacher_layer] = student_layer
130
+
131
+ # init the student params from the teacher model
132
+ student_model = WhisperForConditionalGeneration(student_config)
133
+ missing_keys, unexpected_keys = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
134
+ if len(missing_keys) > 0:
135
+ raise RuntimeError(
136
+ "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
137
+ f"Missing key(s) in state_dict: {missing_keys}"
138
+ )
139
+ if decoder_layers == teacher_decoder_layers:
140
+ decoder_keys = [key for key in unexpected_keys if "model.decoder.layers" in key]
141
+ if len(decoder_keys) > 0:
142
+ raise RuntimeError(
143
+ "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
144
+ f"Unexpected key(s) in state_dict: {decoder_keys}"
145
+ )
146
+ if encoder_layers == teacher_encoder_layers:
147
+ encoder_keys = [key for key in unexpected_keys if "model.encoder.layers" in key]
148
+ if len(encoder_keys) > 0:
149
+ raise RuntimeError(
150
+ "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
151
+ f"Unexpected key(s) in state_dict: {encoder_keys}"
152
+ )
153
+
154
+ for layer in range(teacher_decoder_layers):
155
+ if layer in decoder_map:
156
+ # re-introduce pre-defined layers from the teacher
157
+ student_model.model.decoder.layers[decoder_map[layer]].load_state_dict(
158
+ teacher_model.model.decoder.layers[layer].state_dict()
159
+ )
160
+
161
+ if encoder_layers is not None:
162
+ for layer in range(teacher_encoder_layers):
163
+ if layer in encoder_map:
164
+ # re-introduce pre-defined layers from the teacher
165
+ student_model.model.encoder.layers[encoder_map[layer]].load_state_dict(
166
+ teacher_model.model.encoder.layers[layer].state_dict()
167
+ )
168
+
169
+ # remove the teacher params and model
170
+ del teacher_model
171
+
172
+ # save the converted weights and model
173
+ if save_dir is not None:
174
+ student_model.save_pretrained(save_dir)
175
+ # we also need to correctly save the processor and generation config
176
+ processor.save_pretrained(save_dir)
177
+ generation_config.save_pretrained(save_dir)
178
+
179
+ # check we can do a forward pass with the saved model - first load the weights and processor
180
+ logger.info("Checking we can load the saved model...")
181
+ student_model = WhisperForConditionalGeneration.from_pretrained(
182
+ save_dir,
183
+ low_cpu_mem_usage=True,
184
+ )
185
+ processor = WhisperProcessor.from_pretrained(save_dir)
186
+
187
+ # define some random inputs
188
+ input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="pt").input_features
189
+ decoder_start_token_id = student_model.config.decoder_start_token_id
190
+ decoder_input_ids = torch.ones((input_features.shape[0], 1), dtype=torch.long) * decoder_start_token_id
191
+
192
+ # do a forward pass - outputs will be gibberish for the initialised model so we can't check them
193
+ # but we can make sure the model runs as expected
194
+ logger.info("Checking we can run the converted model forward...")
195
+ _ = student_model(input_features, decoder_input_ids=decoder_input_ids).logits
196
+ logger.info("Conversion successful!")
197
+
198
+ if push_to_hub:
199
+ student_model.push_to_hub(save_dir)
200
+ processor.push_to_hub(save_dir)
201
+ generation_config.push_to_hub(save_dir)
202
+
203
+
204
+ if __name__ == "__main__":
205
+ args = parse_args()
206
+
207
+ init_student_model_from_teacher(
208
+ teacher_checkpoint=args.teacher_checkpoint,
209
+ encoder_layers=args.encoder_layers,
210
+ decoder_layers=args.decoder_layers,
211
+ save_dir=args.save_dir,
212
+ push_to_hub=args.push_to_hub,
213
+ cache_dir=args.cache_dir,
214
+ subfolder=args.subfolder,
215
+ )
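For reference, a hypothetical invocation of this script (the actual command used in this repo lives in run_init.sh, which is not shown here; the save path below is an assumption chosen to match the ./distil-large-v3-init checkpoint referenced by run.sh):

python create_student_model.py \
  --teacher_checkpoint "openai/whisper-large-v3" \
  --decoder_layers 2 \
  --save_dir "./distil-large-v3-init"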
run.sh ADDED
@@ -0,0 +1,41 @@
1
+ #!/usr/bin/env bash
2
+
3
+ accelerate launch run_distillation.py \
4
+ --model_name_or_path "./distil-large-v3-init" \
5
+ --teacher_model_name_or_path "openai/whisper-large-v3" \
6
+ --train_dataset_name "../common_voice_16_1_de_pseudo_labelled" \
7
+ --train_split_name "train" \
8
+ --text_column_name "sentence" \
9
+ --eval_dataset_name "../common_voice_16_1_de_pseudo_labelled" \
10
+ --eval_split_name "validation" \
11
+ --eval_text_column_name "sentence" \
12
+ --eval_steps 5000 \
13
+ --save_steps 5000 \
14
+ --warmup_steps 500 \
15
+ --learning_rate 0.0001 \
16
+ --timestamp_probability 0.2 \
17
+ --condition_on_prev_probability 0.2 \
18
+ --language "de" \
19
+ --task "transcribe" \
20
+ --logging_steps 25 \
21
+ --save_total_limit 1 \
22
+ --max_steps 50000 \
23
+ --wer_threshold 10 \
24
+ --per_device_train_batch_size 32 \
25
+ --per_device_eval_batch_size 32 \
26
+ --dataloader_num_workers 8 \
27
+ --preprocessing_num_workers 8 \
28
+ --ddp_timeout 7200 \
29
+ --dtype "bfloat16" \
30
+ --attn_implementation "flash_attention_2" \
31
+ --output_dir "./" \
32
+ --do_train \
33
+ --do_eval \
34
+ --gradient_checkpointing \
35
+ --overwrite_output_dir \
36
+ --predict_with_generate \
37
+ --freeze_encoder \
38
+ --freeze_embed_positions \
39
+ --streaming False \
40
+ --push_to_hub
41
+
run_all.sh ADDED
@@ -0,0 +1,67 @@
1
+ #!/usr/bin/env bash
2
+
3
+ accelerate launch run_pseudo_labelling.py \
4
+ --model_name_or_path "openai/whisper-large-v3" \
5
+ --dataset_name "mozilla-foundation/common_voice_16_1" \
6
+ --dataset_config_name "de" \
7
+ --dataset_split_name "train+validation+test" \
8
+ --text_column_name "sentence" \
9
+ --id_column_name "path" \
10
+ --output_dir "../common_voice_16_1_de_pseudo_labelled" \
11
+ --wandb_project "distil-whisper-labelling" \
12
+ --per_device_eval_batch_size 64 \
13
+ --dtype "bfloat16" \
14
+ --attn_implementation "flash_attention_2" \
15
+ --logging_steps 500 \
16
+ --max_label_length 256 \
17
+ --concatenate_audio \
18
+ --preprocessing_batch_size 500 \
19
+ --preprocessing_num_workers 8 \
20
+ --dataloader_num_workers 8 \
21
+ --report_to "wandb" \
22
+ --language "de" \
23
+ --task "transcribe" \
24
+ --return_timestamps \
25
+ --streaming False \
26
+ --generation_num_beams 1 \
27
+ --push_to_hub
28
+
29
+ accelerate launch run_distillation.py \
30
+ --model_name_or_path "./distil-large-v3-init" \
31
+ --teacher_model_name_or_path "openai/whisper-large-v3" \
32
+ --train_dataset_name "../common_voice_16_1_de_pseudo_labelled" \
33
+ --train_split_name "train" \
34
+ --text_column_name "sentence" \
35
+ --eval_dataset_name "mozilla-foundation/common_voice_16_1" \
36
+ --eval_split_name "validation" \
37
+ --eval_text_column_name "sentence" \
38
+ --eval_steps 5000 \
39
+ --save_steps 5000 \
40
+ --warmup_steps 500 \
41
+ --learning_rate 0.0001 \
42
+ --timestamp_probability 0.2 \
43
+ --condition_on_prev_probability 0.2 \
44
+ --language "de" \
45
+ --task "transcribe" \
46
+ --logging_steps 25 \
47
+ --save_total_limit 1 \
48
+ --max_steps 50000 \
49
+ --wer_threshold 10 \
50
+ --per_device_train_batch_size 32 \
51
+ --per_device_eval_batch_size 32 \
52
+ --dataloader_num_workers 8 \
53
+ --preprocessing_num_workers 8 \
54
+ --ddp_timeout 7200 \
55
+ --dtype "bfloat16" \
56
+ --attn_implementation "flash_attention_2" \
57
+ --output_dir "./" \
58
+ --do_train \
59
+ --do_eval \
60
+ --gradient_checkpointing \
61
+ --overwrite_output_dir \
62
+ --predict_with_generate \
63
+ --freeze_encoder \
64
+ --freeze_embed_positions \
65
+ --streaming False \
66
+ --push_to_hub
67
+
run_distillation.py ADDED
@@ -0,0 +1,1683 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training the Whisper model for sequence to sequence speech recognition via teacher-student distillation.
18
+ """
19
+ # You can also adapt this script for your own distillation tasks. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import os
23
+ import re
24
+ import shutil
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from functools import partial
29
+ from pathlib import Path
30
+ from typing import Any, Dict, List, Optional, Union
31
+
32
+ import datasets
33
+ import evaluate
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from datasets import (
41
+ DatasetDict,
42
+ IterableDataset,
43
+ IterableDatasetDict,
44
+ concatenate_datasets,
45
+ interleave_datasets,
46
+ load_dataset,
47
+ )
48
+ from huggingface_hub import Repository, create_repo, get_full_repo_name, upload_folder
49
+ from torch.utils.data import DataLoader
50
+ from tqdm import tqdm
51
+ from transformers import (
52
+ AddedToken,
53
+ HfArgumentParser,
54
+ Seq2SeqTrainingArguments,
55
+ WhisperConfig,
56
+ WhisperFeatureExtractor,
57
+ WhisperForConditionalGeneration,
58
+ WhisperProcessor,
59
+ WhisperTokenizerFast,
60
+ get_scheduler,
61
+ set_seed,
62
+ )
63
+ from transformers.modeling_outputs import BaseModelOutput
64
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
65
+ from transformers.utils import check_min_version
66
+ from transformers.utils.versions import require_version
67
+
68
+
69
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
70
+ check_min_version("4.34.0.dev0")
71
+
72
+ require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")
73
+
74
+ logger = get_logger(__name__)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ """
80
+ Arguments pertaining to which model/config/tokenizer we are going to distill from.
81
+ """
82
+
83
+ model_name_or_path: str = field(
84
+ metadata={"help": "Path to pretrained Whisper model or model identifier from huggingface.co/models"}
85
+ )
86
+ teacher_model_name_or_path: str = field(
87
+ metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
88
+ )
89
+ config_name: Optional[str] = field(
90
+ default=None,
91
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
92
+ )
93
+ tokenizer_name: Optional[str] = field(
94
+ default=None,
95
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
96
+ )
97
+ feature_extractor_name: Optional[str] = field(
98
+ default=None,
99
+ metadata={"help": "feature extractor name or path if not the same as model_name"},
100
+ )
101
+ cache_dir: Optional[str] = field(
102
+ default=None,
103
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
104
+ )
105
+ use_fast_tokenizer: bool = field(
106
+ default=True,
107
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
108
+ )
109
+ model_revision: str = field(
110
+ default="main",
111
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
112
+ )
113
+ subfolder: str = field(
114
+ default="",
115
+ metadata={
116
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
117
+ "specify the folder name here."
118
+ },
119
+ )
120
+ token: str = field(
121
+ default=None,
122
+ metadata={
123
+ "help": (
124
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
125
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
126
+ )
127
+ },
128
+ )
129
+ attn_implementation: Optional[str] = field(
130
+ default=None,
131
+ metadata={
132
+ "help": (
133
+ "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n"
134
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
135
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
136
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
137
+ )
138
+ },
139
+ )
140
+ def __post_init__(self):
141
+ if self.attn_implementation not in [None, "eager", "sdpa", "flash_attention_2"]:
142
+ raise ValueError(
143
+ f"Got `--attn_implementation={self.attn_implementation}`, which is an invalid attention type. Should be one of:\n"
144
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
145
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
146
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
147
+ )
148
+
149
+
150
+ @dataclass
151
+ class DataTrainingArguments:
152
+ """
153
+ Arguments pertaining to what data we are going to input our model for training and eval.
154
+ """
155
+
156
+ train_dataset_name: str = field(
157
+ default=None,
158
+ metadata={
159
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
160
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load LibriSpeech "
161
+ "and Common Voice, set `train_dataset_name='librispeech_asr+common_voice'`."
162
+ },
163
+ )
164
+ train_dataset_config_name: Optional[str] = field(
165
+ default=None,
166
+ metadata={
167
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
168
+ "multiple datasets by separating dataset configs by a '+' symbol. Note that the order of the configs should "
169
+ "match the order of the datasets."
170
+ },
171
+ )
172
+ train_dataset_samples: str = field(
173
+ default=None,
174
+ metadata={
175
+ "help": "Number of samples in each dataset when loading multiple datasets with streaming mode. "
176
+ "Not required when using one dataset or non-streaming mode. The sample values provide the sampling "
177
+ "probability for each dataset. Setting them equal to the number of sample values ensures that every "
178
+ "sample from every dataset is used once per epoch."
179
+ },
180
+ )
181
+ eval_dataset_name: str = field(
182
+ default=None,
183
+ metadata={
184
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training "
185
+ "dataset name if unspecified. Load multiple evaluation datasets by separating dataset "
186
+ "ids by a '+' symbol."
187
+ },
188
+ )
189
+ eval_dataset_config_name: Optional[str] = field(
190
+ default=None,
191
+ metadata={
192
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the "
193
+ "training dataset config name if unspecified."
194
+ },
195
+ )
196
+ dataset_cache_dir: Optional[str] = field(
197
+ default=None,
198
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
199
+ )
200
+ overwrite_cache: bool = field(
201
+ default=False,
202
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
203
+ )
204
+ preprocessing_num_workers: Optional[int] = field(
205
+ default=None,
206
+ metadata={"help": "The number of processes to use for the preprocessing if using non-streaming mode."},
207
+ )
208
+ preprocessing_batch_size: Optional[int] = field(
209
+ default=256,
210
+ metadata={"help": "Number of examples per batch provided to the `prepare_dataset` function."},
211
+ )
212
+ max_train_samples: Optional[int] = field(
213
+ default=None,
214
+ metadata={
215
+ "help": (
216
+ "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
217
+ )
218
+ },
219
+ )
220
+ max_eval_samples: Optional[int] = field(
221
+ default=None,
222
+ metadata={
223
+ "help": (
224
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
225
+ )
226
+ },
227
+ )
228
+ audio_column_name: str = field(
229
+ default="audio",
230
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
231
+ )
232
+ text_column_name: str = field(
233
+ default=None,
234
+ metadata={"help": "The name of the dataset column containing the text data in the training set."},
235
+ )
236
+ eval_text_column_name: str = field(
237
+ default="text",
238
+ metadata={"help": ("The name of the dataset column containing the text data in the evaluation set.")},
239
+ )
240
+ max_duration_in_seconds: float = field(
241
+ default=30.0,
242
+ metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
243
+ )
244
+ min_duration_in_seconds: float = field(
245
+ default=0.0,
246
+ metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
247
+ )
248
+ max_label_length: int = field(
249
+ default=448,
250
+ metadata={"help": "Truncate transcriptions that are longer `max_label_length` tokens."},
251
+ )
252
+ pad_target_to_multiple_of: Optional[int] = field(
253
+ default=None,
254
+ metadata={
255
+ "help": (
256
+ "If set will pad the target sequence to a multiple of the provided"
257
+ " value. This is important to avoid triggering recompilations on TPU."
258
+ " If unspecified, will default to padding the targets to max length."
259
+ )
260
+ },
261
+ )
262
+ preprocessing_only: bool = field(
263
+ default=False,
264
+ metadata={
265
+ "help": (
266
+ "Whether to only do data preprocessing and skip training. This is"
267
+ " especially useful when data preprocessing errors out in distributed"
268
+ " training due to timeout. In this case, one should run the"
269
+ " preprocessing in a non-distributed setup with"
270
+ " `preprocessing_only=True` so that the cached datasets can"
271
+ " consequently be loaded in distributed training"
272
+ )
273
+ },
274
+ )
275
+ train_split_name: str = field(
276
+ default="train",
277
+ metadata={
278
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
279
+ },
280
+ )
281
+ eval_split_name: str = field(
282
+ default="validation",
283
+ metadata={
284
+ "help": (
285
+ "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
286
+ )
287
+ },
288
+ )
289
+ streaming: bool = field(
290
+ default=True,
291
+ metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
292
+ )
293
+ wer_threshold: float = field(
294
+ default=None,
295
+ metadata={
296
+ "help": "Filter training data with Whisper transcriptions that have greater than `wer_threshold` "
297
+ "WER with the normalised transcriptions. This only takes effect if training on pseudo-labels targets."
298
+ "If `--use_pseudo_labels=False`, then no WER filtering is performed, since we train directly on the text"
299
+ "transcriptions."
300
+ },
301
+ )
302
+ use_pseudo_labels: bool = field(
303
+ default=True,
304
+ metadata={
305
+ "help": "Whether or not to use pseudo-label transcriptions as the targets. If True, the pseudo-labels "
306
+ "must be in the dataset column `whisper_transcript` from the previous pseudo-labelling step. This is "
307
+ "not currently yet configurable."
308
+ },
309
+ )
310
+ timestamp_probability: float = field(
311
+ default=0.2, metadata={"help": "Probability for training on timestamped tokens if the data contains it."}
312
+ )
313
+ condition_on_prev_probability: float = field(
314
+ default=0.2, metadata={"help": "Probability for conditioning on the previous text example."}
315
+ )
316
+ return_timestamps: bool = field(
317
+ default=False, metadata={"help": "Whether or not to predict timestamps in the generation step."}
318
+ )
319
+ language: str = field(
320
+ default=None,
321
+ metadata={
322
+ "help": (
323
+ "Language for multilingual distillation. This argument should be set for multilingual distillation "
324
+ "only. For English speech recognition, it should be left as `None`."
325
+ )
326
+ },
327
+ )
328
+ task: str = field(
329
+ default="transcribe",
330
+ metadata={
331
+ "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."
332
+ "This argument should be set for multilingual distillation only. For English speech recognition, it should be left as `None`."
333
+ },
334
+ )
335
+ wandb_project: str = field(
336
+ default="distil-whisper",
337
+ metadata={"help": "The name of the wandb project."},
338
+ )
339
+
340
+
341
+ @dataclass
342
+ class DistillationTrainingArguments(Seq2SeqTrainingArguments):
343
+ freeze_encoder: Optional[bool] = field(
344
+ default=False,
345
+ metadata={
346
+ "help": (
347
+ "Whether to freeze the entire encoder model. Only recommended when the entire encoder has been "
348
+ "copied from the teacher model."
349
+ )
350
+ },
351
+ )
352
+ freeze_embed_positions: Optional[bool] = field(
353
+ default=False,
354
+ metadata={"help": "Whether to freeze the decoder embedding positions."},
355
+ )
356
+ temperature: Optional[float] = field(
357
+ default=2.0, metadata={"help": "Temperature to anneal the logits when computing the softmax."}
358
+ )
359
+ kl_weight: Optional[float] = field(
360
+ default=1.0,
361
+ metadata={
362
+ "help": (
363
+ "Weighting assigned to the MSE loss in the KD formulation. MSE loss is "
364
+ "computed between the teacher-student hidden states and attentions."
365
+ )
366
+ },
367
+ )
368
+ dtype: Optional[str] = field(
369
+ default="float32",
370
+ metadata={
371
+ "help": (
372
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
373
+ "`float16` or `bfloat16` (both half-precision)."
374
+ )
375
+ },
376
+ )
377
+
378
+
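# Illustrative sketch (not part of the committed file): the `temperature` and `kl_weight`
# arguments above are typically combined with the cross-entropy loss roughly as follows;
# the exact formulation used later in this script may differ.
#
#     teacher_probs = (teacher_logits / temperature).softmax(dim=-1)
#     student_log_probs = (student_logits / temperature).log_softmax(dim=-1)
#     kl_loss = nn.functional.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2
#     loss = ce_loss + kl_weight * kl_loss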
379
+ @dataclass
380
+ class DataCollatorSpeechSeq2SeqWithPadding:
381
+ """
382
+ Data collator that will dynamically pad the inputs received.
383
+ Args:
384
+ processor ([`WhisperProcessor`])
385
+ The processor used for processing the data.
386
+ decoder_start_token_id (:obj: `int`)
387
+ The start-of-sequence token id of the decoder.
388
+ decoder_prev_token_id (:obj: `int`)
389
+ The start-of-prompt token id of the decoder
390
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
391
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
392
+ among:
393
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
394
+ sequence is provided).
395
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
396
+ maximum acceptable input length for the model if that argument is not provided.
397
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
398
+ different lengths).
399
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
400
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
401
+ See above for details.
402
+ max_target_length (:obj:`int`, `optional`):
403
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
404
+ """
405
+
406
+ processor: Any
407
+ decoder_start_token_id: int
408
+ decoder_prev_token_id: int
409
+ input_padding: Union[bool, str] = "max_length"
410
+ target_padding: Union[bool, str] = "max_length"
411
+ max_target_length: Optional[int] = None
412
+
413
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
414
+ # split inputs and labels since they have to be of different lengths and need
415
+ # different padding methods
416
+
417
+ # dataloader returns a list of features which we convert to a dict
418
+ input_features = {"input_features": [feature["input_features"] for feature in features]}
419
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
420
+
421
+ # reformat list to dict and set to pytorch format
422
+ batch = self.processor.feature_extractor.pad(
423
+ input_features,
424
+ padding=self.input_padding,
425
+ return_tensors="pt",
426
+ )
427
+
428
+ labels_batch = self.processor.tokenizer.pad(
429
+ label_features,
430
+ max_length=self.max_target_length,
431
+ padding=self.target_padding,
432
+ return_tensors="pt",
433
+ )
434
+
435
+ # shift labels to the right to get decoder input ids
436
+ labels = labels_batch["input_ids"]
437
+ decoder_input_ids = labels[:, :-1]
438
+ labels = labels[:, 1:]
439
+ labels_mask = labels_batch.attention_mask[:, 1:]
440
+
441
+ # replace padding with -100 to ignore correctly when computing the loss
442
+ labels = labels.masked_fill(labels_mask.ne(1), -100)
443
+
444
+ # replace initial prompt tokens with -100 to ignore correctly when computing the loss
445
+ bos_index = torch.argmax((labels == self.decoder_start_token_id).long(), dim=1)
446
+ bos_index = torch.where(bos_index > 0, bos_index + 1, bos_index)
447
+ prompt_mask = torch.arange(labels.shape[1]) < bos_index[:, None]
448
+ labels = torch.where(prompt_mask, -100, labels)
449
+
450
+ batch["labels"] = labels
451
+ batch["decoder_input_ids"] = decoder_input_ids
452
+
453
+ return batch
454
+
455
+
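# Worked example of the label handling in the collator above (token names are placeholders):
#
#     padded labels:     <|startofprev|> p1 p2 <|startoftranscript|> t1 t2 <eos> <pad>
#     decoder_input_ids: labels[:, :-1]  (decoder input, right-shifted by one position)
#     labels:            labels[:, 1:]   (targets, aligned one step ahead)
#
# Padding positions, and every position up to and including <|startoftranscript|> in the shifted
# labels, are replaced with -100, so the loss is only computed on the tokens that follow
# <|startoftranscript|>.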
456
+ def log_metric(
457
+ accelerator,
458
+ metrics: Dict,
459
+ train_time: float,
460
+ step: int,
461
+ epoch: int,
462
+ learning_rate: float = None,
463
+ prefix: str = "train",
464
+ ):
465
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
466
+ log_metrics = {}
467
+ for k, v in metrics.items():
468
+ log_metrics[f"{prefix}/{k}"] = v
469
+ log_metrics[f"{prefix}/time"] = train_time
470
+ log_metrics[f"{prefix}/epoch"] = epoch
471
+ if learning_rate is not None:
472
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
473
+ accelerator.log(log_metrics, step=step)
474
+
475
+
476
+ def log_pred(
477
+ accelerator,
478
+ pred_str: List[str],
479
+ label_str: List[str],
480
+ norm_pred_str: List[str],
481
+ norm_label_str: List[str],
482
+ step: int,
483
+ prefix: str = "eval",
484
+ num_lines: int = 200000,
485
+ ):
486
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
487
+ if accelerator.is_main_process:
488
+ wandb_tracker = accelerator.get_tracker("wandb")
489
+ # pretty name for current step: step 50000 -> step 50k
490
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
491
+ prefix_pretty = prefix.replace("/", "-")
492
+
493
+ # convert str data to a wandb compatible format
494
+ str_data = [[label_str[i], pred_str[i], norm_label_str[i], norm_pred_str[i]] for i in range(len(pred_str))]
495
+ # log as a table with the appropriate headers
496
+ wandb_tracker.log_table(
497
+ table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
498
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
499
+ data=str_data[:num_lines],
500
+ step=step,
501
+ )
502
+
503
+ # log incorrect normalised predictions
504
+ str_data = np.asarray(str_data)
505
+ str_data_incorrect = str_data[str_data[:, -2] != str_data[:, -1]]
506
+ # log as a table with the appropriate headers
507
+ wandb_tracker.log_table(
508
+ table_name=f"incorrect_predictions/{prefix_pretty}-step-{cur_step_pretty}",
509
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
510
+ data=str_data_incorrect[:num_lines],
511
+ step=step,
512
+ )
513
+
514
+
515
+ def convert_dataset_str_to_list(
516
+ dataset_names,
517
+ dataset_config_names,
518
+ splits=None,
519
+ text_column_names=None,
520
+ dataset_samples=None,
521
+ default_split="train",
522
+ ) -> List[Dict]:
523
+ """
524
+ Given three lists of dataset names, configs and splits, this function groups the corresponding
525
+ names/configs/splits. Each dataset is assigned a unique dictionary with these metadata values, and the
526
+ function returns a list of dictionaries, one for each dataset.
527
+ """
528
+ if isinstance(dataset_names, str):
529
+ dataset_names = dataset_names.split("+")
530
+ dataset_config_names = dataset_config_names.split("+") if dataset_config_names is not None else None
531
+ splits = splits.split("+") if splits is not None else None
532
+ text_column_names = text_column_names.split("+") if text_column_names is not None else None
533
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
534
+
535
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
536
+ if dataset_config_names is not None and len(dataset_names) != len(dataset_config_names):
537
+ raise ValueError(
538
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
539
+ f" {len(dataset_config_names)} configs."
540
+ )
541
+
542
+ if splits is not None and len(splits) != len(dataset_names):
543
+ raise ValueError(
544
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
545
+ )
546
+
547
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
548
+ raise ValueError(
549
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
550
+ f" {len(text_column_names)} text column names."
551
+ )
552
+
553
+ if dataset_samples is not None:
554
+ if len(dataset_samples) != len(dataset_names):
555
+ raise ValueError(
556
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
557
+ f"{len(dataset_samples)} samples."
558
+ )
559
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
560
+ else:
561
+ dataset_samples = [None] * len(dataset_names)
562
+
563
+ dataset_config_names = (
564
+ dataset_config_names if dataset_config_names is not None else ["default" for _ in range(len(dataset_names))]
565
+ )
566
+ text_column_names = (
567
+ text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
568
+ )
569
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
570
+
571
+ dataset_names_dict = []
572
+ for i, ds_name in enumerate(dataset_names):
573
+ dataset_names_dict.append(
574
+ {
575
+ "name": ds_name,
576
+ "config": dataset_config_names[i],
577
+ "split": splits[i],
578
+ "text_column_name": text_column_names[i],
579
+ "samples": dataset_samples[i],
580
+ }
581
+ )
582
+ return dataset_names_dict
583
+
584
+
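# Hypothetical example of the '+'-separated convention parsed above:
#
#     convert_dataset_str_to_list(
#         dataset_names="librispeech_asr+common_voice",
#         dataset_config_names="clean+en",
#         splits="train.360+train",
#         text_column_names="text+sentence",
#         dataset_samples="100+360",
#     )
#     # -> [{"name": "librispeech_asr", "config": "clean", "split": "train.360",
#     #      "text_column_name": "text", "samples": 100.0},
#     #     {"name": "common_voice", "config": "en", "split": "train",
#     #      "text_column_name": "sentence", "samples": 360.0}]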
585
+ def load_multiple_datasets(
586
+ dataset_names: Union[List, str],
587
+ dataset_config_names: Union[List, str],
588
+ splits: Optional[Union[List, str]] = None,
589
+ text_column_names: Optional[List] = None,
590
+ sampling_rate: Optional[int] = 16000,
591
+ stopping_strategy: Optional[str] = "first_exhausted",
592
+ dataset_samples: Optional[Union[List, np.ndarray]] = None,
593
+ streaming: Optional[bool] = True,
594
+ seed: Optional[int] = None,
595
+ accelerator: Optional[Accelerator] = None,
596
+ use_pseudo_labels: Optional[bool] = None,
597
+ **kwargs,
598
+ ) -> IterableDataset:
599
+ dataset_names_dict = convert_dataset_str_to_list(
600
+ dataset_names, dataset_config_names, splits, text_column_names, dataset_samples
601
+ )
602
+
603
+ if dataset_samples is not None:
604
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
605
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
606
+ else:
607
+ probabilities = None
608
+
609
+ all_datasets = []
610
+ # iterate over the datasets we want to interleave
611
+ for dataset_dict in tqdm(
612
+ dataset_names_dict,
613
+ desc="Combining datasets...",
614
+ disable=not accelerator.is_local_main_process if accelerator is not None else False,
615
+ ):
616
+ dataset = load_dataset(
617
+ dataset_dict["name"],
618
+ dataset_dict["config"],
619
+ split=dataset_dict["split"],
620
+ streaming=streaming,
621
+ **kwargs,
622
+ )
623
+ # resample to specified sampling rate
624
+ dataset = dataset.cast_column("audio", datasets.features.Audio(sampling_rate))
625
+ dataset_features = dataset.features.keys()
626
+ columns_to_keep = {"audio", "text"}
627
+
628
+ if dataset_dict["text_column_name"] not in dataset_features:
629
+ raise ValueError(
630
+ f"Text column name {dataset_dict['text_column_name']} not found in dataset"
631
+ f" '{dataset_dict['name']}'. Make sure to set `--text_column_name` to the"
632
+ f" correct text column - one of {', '.join(dataset_features)}."
633
+ )
634
+
635
+ # blanket renaming of all transcription columns to text
636
+ if dataset_dict["text_column_name"] != "text":
637
+ dataset = dataset.rename_column(dataset_dict["text_column_name"], "text")
638
+
639
+ if use_pseudo_labels:
640
+ if "whisper_transcript" not in dataset_features:
641
+ raise ValueError(
642
+ f"Pseudo-label column `whisper_transcript` not found in dataset {dataset_dict['name']}. Ensure"
643
+ "pseudo-labels are present in the dataset under this column name, or train directly on the text "
644
+ "labels by setting `--use_pseudo_labels=False` and defining the appropriate `--text_column_name`."
645
+ )
646
+ columns_to_keep.add("whisper_transcript")
647
+
648
+ if "condition_on_prev" in dataset_features:
649
+ columns_to_keep.add("condition_on_prev")
650
+
651
+ dataset_features = dataset.features.keys()
652
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
653
+ all_datasets.append(dataset)
654
+
655
+ if len(all_datasets) == 1:
656
+ # we have a single dataset so just return it as is
657
+ return all_datasets[0]
658
+
659
+ if streaming:
660
+ interleaved_dataset = interleave_datasets(
661
+ all_datasets,
662
+ stopping_strategy=stopping_strategy,
663
+ probabilities=probabilities,
664
+ seed=seed,
665
+ )
666
+ else:
667
+ interleaved_dataset = concatenate_datasets(all_datasets)
668
+
669
+ return interleaved_dataset
670
+
671
+
672
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
673
+ """Helper function to sort saved checkpoints from oldest to newest."""
674
+ ordering_and_checkpoint_path = []
675
+
676
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
677
+
678
+ for path in glob_checkpoints:
679
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
680
+ if regex_match is not None and regex_match.groups() is not None:
681
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
682
+
683
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
684
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
685
+ return checkpoints_sorted
686
+
687
+
688
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
689
+ """Helper function to delete old checkpoints."""
690
+ if save_total_limit is None or save_total_limit <= 0:
691
+ return
692
+ # Check if we should delete older checkpoint(s)
693
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
694
+ if len(checkpoints_sorted) <= save_total_limit:
695
+ return
696
+
697
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
698
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
699
+ for checkpoint in checkpoints_to_be_deleted:
700
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
701
+ shutil.rmtree(checkpoint, ignore_errors=True)
702
+
703
+
704
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
705
+
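# Example of a directory name matched by the pattern above (step/epoch values are hypothetical):
#     "checkpoint-15000-epoch-2"  ->  groups ("15000", "2")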
706
+
707
+ def get_last_checkpoint(folder):
708
+ content = os.listdir(folder)
709
+ checkpoints = [
710
+ path
711
+ for path in content
712
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
713
+ ]
714
+ if len(checkpoints) == 0:
715
+ return
716
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
717
+
718
+
719
+ def get_parameter_names(model, forbidden_layer_types, forbidden_module=None):
720
+ """
721
+ Returns the names of the model parameters that are not inside a forbidden layer or forbidden module.
722
+ Can be used to get a subset of parameter names for decay masks, or to exclude parameters from an optimiser
723
+ (e.g. if the module is frozen).
724
+ """
725
+ result = []
726
+ for name, child in model.named_children():
727
+ result += [
728
+ f"{name}.{n}"
729
+ for n in get_parameter_names(child, forbidden_layer_types, forbidden_module)
730
+ if not (
731
+ isinstance(child, tuple(forbidden_layer_types))
732
+ or (child in tuple(forbidden_module) if forbidden_module is not None else False)
733
+ )
734
+ ]
735
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
736
+ result += list(model._parameters.keys())
737
+ return result
738
+
739
+
740
+ def main():
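# Sketch of how this helper is typically used to build weight-decay parameter groups
# (assumed usage; the actual optimiser setup appears further down this script):
#
#     decay_parameters = get_parameter_names(student_model, [nn.LayerNorm])
#     decay_parameters = [name for name in decay_parameters if "bias" not in name]
#     optimizer_grouped_parameters = [
#         {"params": [p for n, p in student_model.named_parameters() if n in decay_parameters],
#          "weight_decay": training_args.weight_decay},
#         {"params": [p for n, p in student_model.named_parameters() if n not in decay_parameters],
#          "weight_decay": 0.0},
#     ]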
741
+ # 1. Parse input arguments
742
+ # We keep distinct sets of args, for cleaner separation of model/data/training related args
743
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
744
+
745
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
746
+ # If we pass only one argument to the script and it's the path to a json file,
747
+ # let's parse it to get our arguments.
748
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
749
+ else:
750
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
751
+
752
+ # 2. Initialize the accelerator
753
+ # We will let the accelerator handle device placement for us in this example
754
+ # We simply have to specify the training precision and any trackers being used
755
+ # We'll use the same dtype arguments as our JAX/Flax training script and convert
756
+ # it to accelerate format
757
+ if training_args.dtype == "float16":
758
+ mixed_precision = "fp16"
759
+ teacher_dtype = torch.float16
760
+ elif training_args.dtype == "bfloat16":
761
+ mixed_precision = "bf16"
762
+ teacher_dtype = torch.bfloat16
763
+ else:
764
+ mixed_precision = "no"
765
+ teacher_dtype = torch.float32
766
+
767
+ accelerator = Accelerator(
768
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
769
+ mixed_precision=mixed_precision,
770
+ log_with=training_args.report_to,
771
+ project_dir=training_args.output_dir,
772
+ )
773
+
774
+ accelerator.init_trackers(project_name=data_args.wandb_project)
775
+
776
+ # 3. Set-up basic logging
777
+ # Create one log on every process with the configuration for debugging
778
+ logging.basicConfig(
779
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
780
+ datefmt="%m/%d/%Y %H:%M:%S",
781
+ level=logging.INFO,
782
+ )
783
+ # Log a small summary on each process
784
+ logger.warning(
785
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
786
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
787
+ )
788
+
789
+ # Set the verbosity to info of the Transformers logger (on main process only)
790
+ if accelerator.is_local_main_process:
791
+ datasets.utils.logging.set_verbosity_warning()
792
+ transformers.utils.logging.set_verbosity_info()
793
+ else:
794
+ datasets.utils.logging.set_verbosity_error()
795
+ transformers.utils.logging.set_verbosity_error()
796
+ logger.info("Training/evaluation parameters %s", training_args)
797
+
798
+ # 4. Detecting last checkpoint and eventually continue from last checkpoint
799
+ last_checkpoint = None
800
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
801
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
802
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
803
+ raise ValueError(
804
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
805
+ "Use --overwrite_output_dir to overcome."
806
+ )
807
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
808
+ logger.info(
809
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
810
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
811
+ )
812
+
813
+ # 5. Handle the repository creation
814
+ if accelerator.is_main_process:
815
+ if training_args.push_to_hub:
816
+ if training_args.hub_model_id is None:
817
+ repo_name = get_full_repo_name(
818
+ Path(training_args.output_dir).absolute().name,
819
+ token=training_args.hub_token,
820
+ )
821
+ else:
822
+ repo_name = training_args.hub_model_id
823
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
824
+
825
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
826
+ if "wandb" not in gitignore:
827
+ gitignore.write("wandb\n")
828
+ elif training_args.output_dir is not None:
829
+ os.makedirs(training_args.output_dir, exist_ok=True)
830
+ accelerator.wait_for_everyone()
831
+
832
+ # 6. Load dataset - either streaming or non-streaming (offline)
833
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
834
+
835
+ # set seed for determinism
836
+ set_seed(training_args.seed)
837
+
838
+ if training_args.do_train:
839
+ raw_datasets["train"] = load_multiple_datasets(
840
+ data_args.train_dataset_name,
841
+ data_args.train_dataset_config_name,
842
+ splits=data_args.train_split_name,
843
+ text_column_names=data_args.text_column_name,
844
+ use_pseudo_labels=data_args.use_pseudo_labels,
845
+ streaming=data_args.streaming,
846
+ dataset_samples=data_args.train_dataset_samples,
847
+ seed=training_args.seed,
848
+ accelerator=accelerator,
849
+ cache_dir=data_args.dataset_cache_dir,
850
+ token=model_args.token,
851
+ )
852
+ raw_datasets_train_features = list(raw_datasets["train"].features.keys())
853
+
854
+ if training_args.do_eval:
855
+ dataset_names_dict = convert_dataset_str_to_list(
856
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
857
+ data_args.eval_dataset_config_name
858
+ if data_args.eval_dataset_config_name
859
+ else data_args.train_dataset_config_name,
860
+ splits=data_args.eval_split_name,
861
+ text_column_names=data_args.eval_text_column_name,
862
+ )
863
+ all_eval_splits = []
864
+ if len(dataset_names_dict) == 1:
865
+ # load a single eval set
866
+ dataset_dict = dataset_names_dict[0]
867
+ all_eval_splits.append("eval")
868
+ raw_datasets["eval"] = load_dataset(
869
+ dataset_dict["name"],
870
+ dataset_dict["config"],
871
+ split=dataset_dict["split"],
872
+ cache_dir=data_args.dataset_cache_dir,
873
+ token=model_args.token,
874
+ streaming=data_args.streaming,
875
+ )
876
+ if data_args.eval_text_column_name != "text":
877
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_text_column_name, "text")
878
+ else:
879
+ # load multiple eval sets
880
+ for dataset_dict in dataset_names_dict:
881
+ if dataset_dict["name"] == "esb/diagnostic-dataset":
882
+ # for the ESB diagnostic dataset, the dataset name is effectively the config
883
+ pretty_name = f"{dataset_dict['config']}-diagnostic/{dataset_dict['split']}"
884
+ else:
885
+ pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
886
+ all_eval_splits.append(pretty_name)
887
+ raw_datasets[pretty_name] = load_dataset(
888
+ dataset_dict["name"],
889
+ dataset_dict["config"],
890
+ split=dataset_dict["split"],
891
+ cache_dir=data_args.dataset_cache_dir,
892
+ token=model_args.token,
893
+ streaming=data_args.streaming,
894
+ )
895
+ # make column names consistent (text, audio)
896
+ if dataset_dict["text_column_name"] != "text":
897
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
898
+ dataset_dict["text_column_name"], "text"
899
+ )
900
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
901
+ set(raw_datasets[pretty_name].features.keys()) - {"audio", "text"}
902
+ )
903
+
904
+ if not training_args.do_train and not training_args.do_eval:
905
+ raise ValueError(
906
+ "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
907
+ )
908
+
909
+ # 7. Load pretrained model, tokenizer, and feature extractor
910
+ config = WhisperConfig.from_pretrained(
911
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
912
+ cache_dir=model_args.cache_dir,
913
+ revision=model_args.model_revision,
914
+ token=model_args.token,
915
+ )
916
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
917
+ (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
918
+ cache_dir=model_args.cache_dir,
919
+ revision=model_args.model_revision,
920
+ token=model_args.token,
921
+ )
922
+ tokenizer = WhisperTokenizerFast.from_pretrained(
923
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
924
+ cache_dir=model_args.cache_dir,
925
+ use_fast=model_args.use_fast_tokenizer,
926
+ revision=model_args.model_revision,
927
+ token=model_args.token,
928
+ )
929
+
930
+ # override timestamp tokens until tokenizer issues are fixed in transformers
931
+ timestamps = [AddedToken("<|%.2f|>" % (i * 0.02), lstrip=False, rstrip=False) for i in range(1500 + 1)]
932
+ tokenizer.add_tokens(timestamps)
933
+
934
+ # The teacher model can safely be cast to the dtype of training since we don't
935
+ # update the params
936
+ teacher_model = WhisperForConditionalGeneration.from_pretrained(
937
+ model_args.teacher_model_name_or_path,
938
+ cache_dir=model_args.cache_dir,
939
+ token=model_args.token,
940
+ low_cpu_mem_usage=True,
941
+ torch_dtype=teacher_dtype,
942
+ attn_implementation=model_args.attn_implementation,
943
+ )
944
+
945
+ student_model = WhisperForConditionalGeneration.from_pretrained(
946
+ model_args.model_name_or_path,
947
+ config=config,
948
+ cache_dir=model_args.cache_dir,
949
+ revision=model_args.model_revision,
950
+ subfolder=model_args.subfolder,
951
+ token=model_args.token,
952
+ low_cpu_mem_usage=True,
953
+ attn_implementation=model_args.attn_implementation,
954
+ )
955
+
956
+ if student_model.config.decoder_start_token_id is None or teacher_model.config.decoder_start_token_id is None:
957
+ raise ValueError(
958
+ f"Make sure that `config.decoder_start_token_id` is correctly defined for both the "
959
+ f"student and teacher model. Got {student_model.config.decoder_start_token_id} for the "
960
+ f"student and {teacher_model.config.decoder_start_token_id} for the teacher."
961
+ )
962
+
963
+ # enable gradient checkpointing if necessary
964
+ if training_args.gradient_checkpointing:
965
+ student_model.gradient_checkpointing_enable()
966
+
967
+ def set_trainable_parameters(module, requires_grad=False):
968
+ for param in module.parameters():
969
+ param.requires_grad = requires_grad
970
+ module._requires_grad = requires_grad
971
+
972
+ # freeze student encoder if necessary
973
+ if training_args.freeze_encoder:
974
+ set_trainable_parameters(student_model.model.encoder, requires_grad=False)
975
+ student_model.model.encoder.gradient_checkpointing = False
976
+
977
+ if training_args.freeze_embed_positions:
978
+ # set_trainable_parameters(student_model.model.decoder.embed_tokens, requires_grad=False)
979
+ set_trainable_parameters(student_model.model.decoder.embed_positions, requires_grad=False)
980
+ if student_model.model.decoder.gradient_checkpointing:
981
+ logger.info(
982
+ "Disabling gradient checkpointing in the decoder since it's incompatible with `freeze_embed_positions`."
983
+ )
984
+
985
+ share_hidden_states = training_args.freeze_encoder and student_model.config.d_model == teacher_model.config.d_model
986
+ if share_hidden_states:
987
+ # tie the weights for the teacher encoder if we're freezing the student and it's the same as the teacher
988
+ teacher_model.model.encoder = student_model.model.encoder
989
+
990
+ if hasattr(teacher_model.generation_config, "is_multilingual") and teacher_model.generation_config.is_multilingual:
991
+ # We need to set the language and task ids for multilingual checkpoints
992
+ is_multilingual = True
993
+ tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task, predict_timestamps=False)
994
+ student_model.generation_config.update(
995
+ **{
996
+ "language": data_args.language,
997
+ "task": data_args.task,
998
+ }
999
+ )
1000
+ elif data_args.language is not None:
1001
+ raise ValueError(
1002
+ "Setting language token for an English-only checkpoint is not permitted. The language argument should "
1003
+ "only be set for multilingual checkpoints."
1004
+ )
1005
+ else:
1006
+ is_multilingual = False
1007
+
1008
+ # 8. Create a single speech processor - make sure all processes wait until data is saved
1009
+ if accelerator.is_main_process:
1010
+ feature_extractor.save_pretrained(training_args.output_dir)
1011
+ tokenizer.save_pretrained(training_args.output_dir)
1012
+ # save the config and generation config as well
1013
+ config.save_pretrained(training_args.output_dir)
1014
+ student_model.generation_config.save_pretrained(training_args.output_dir)
1015
+
1016
+ accelerator.wait_for_everyone()
1017
+ processor = WhisperProcessor.from_pretrained(training_args.output_dir)
1018
+
1019
+ # 9. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
1020
+ # so we just need to set the correct target sampling rate.
1021
+ sampling_rate = feature_extractor.sampling_rate
1022
+ raw_datasets = raw_datasets.cast_column(
1023
+ data_args.audio_column_name,
1024
+ datasets.features.Audio(sampling_rate=sampling_rate),
1025
+ )
1026
+
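As a rough illustration of the resampling step above (not part of the original script), casting the audio column tells `datasets` to decode and resample each file lazily on access; the dataset name below is a hypothetical example:

    from datasets import Audio, load_dataset

    ds = load_dataset("mozilla-foundation/common_voice_16_1", "hi", split="validation")  # hypothetical dataset
    ds = ds.cast_column("audio", Audio(sampling_rate=16_000))  # Whisper expects 16 kHz audio
    print(ds[0]["audio"]["sampling_rate"])  # 16000, resampled on the fly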
1027
+ # 10. Preprocessing the datasets: we need to read the audio files as arrays and tokenize the targets.
1028
+ # 10.1: Define the pre-processing constants
1029
+ max_input_length = int(data_args.max_duration_in_seconds * sampling_rate)
1030
+ min_input_length = int(data_args.min_duration_in_seconds * sampling_rate)
1031
+ max_label_length = (
1032
+ data_args.max_label_length if data_args.max_label_length is not None else student_model.config.max_length
1033
+ )
1034
+
1035
+ timestamp_probability = data_args.timestamp_probability
1036
+ condition_on_prev_probability = data_args.condition_on_prev_probability
1037
+ return_timestamps = data_args.return_timestamps if timestamp_probability > 0 else False
1038
+
1039
+ timestamp_ids = tokenizer.timestamp_ids()
1040
+ timestamp_begin = tokenizer.all_special_ids[-1]
1041
+ timestamp_position = 3 if is_multilingual else 1
1042
+
1043
+ decoder_start_token_id = student_model.config.decoder_start_token_id # <|startoftranscript|>
1044
+ decoder_prev_token_id = tokenizer.all_special_ids[-3] # <|startofprev|>
1045
+ prompt_cutoff_length = max_label_length // 2
1046
+
1047
+ num_workers = data_args.preprocessing_num_workers
1048
+ dataloader_num_workers = training_args.dataloader_num_workers
1049
+ prefetch_factor = training_args.dataloader_prefetch_factor
1050
+
1051
+ metric = evaluate.load("wer")
1052
+ normalizer = (
1053
+ BasicTextNormalizer() if data_args.language is not None else EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
1054
+ )
1055
+ wer_threshold = data_args.wer_threshold
1056
+ use_pseudo_labels = data_args.use_pseudo_labels
1057
+ train_text_column_name = "whisper_transcript" if use_pseudo_labels else "text"
1058
+
1059
+ # 10.2: filter based on maximum number of training/evaluation samples
1060
+ if training_args.do_train and data_args.max_train_samples is not None:
1061
+ raw_datasets["train"] = (
1062
+ raw_datasets["train"].take(data_args.max_train_samples)
1063
+ if data_args.streaming
1064
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
1065
+ )
1066
+
1067
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1068
+ for eval_split in all_eval_splits:
1069
+ raw_datasets[eval_split] = (
1070
+ raw_datasets[eval_split].take(data_args.max_eval_samples)
1071
+ if data_args.streaming
1072
+ else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1073
+ )
1074
+
1075
+ # 10.3: filter training data based on WER threshold -> this is KEY to good distillation performance
1076
+ def is_wer_in_range(ground_truth, whisper_transcript):
1077
+ norm_ground_truth = normalizer(ground_truth)
1078
+ if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
1079
+ # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
1080
+ return False
1081
+ elif len(norm_ground_truth) > 0 and whisper_transcript is not None:
1082
+ norm_whisper_transcript = normalizer(whisper_transcript)
1083
+ wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
1084
+ return wer < wer_threshold
1085
+ else:
1086
+ # filter automatically since we can't know the WER
1087
+ return False
1088
+
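A minimal sketch of how the WER-threshold filter above behaves, assuming a hypothetical `wer_threshold` of 10 (%):

    import evaluate
    from transformers.models.whisper.english_normalizer import BasicTextNormalizer

    metric = evaluate.load("wer")
    normalizer = BasicTextNormalizer()

    ground_truth = "The cat sat on the mat."
    whisper_transcript = "The cat sat on a mat."
    wer = 100 * metric.compute(
        predictions=[normalizer(whisper_transcript)], references=[normalizer(ground_truth)]
    )
    print(wer)  # ~16.7 -> above the 10% threshold, so this sample would be discarded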
1089
+ filter_by_wer_threshold = partial(
1090
+ raw_datasets["train"].filter,
1091
+ function=is_wer_in_range,
1092
+ input_columns=["text", "whisper_transcript"],
1093
+ )
1094
+
1095
+ if wer_threshold is not None and use_pseudo_labels:
1096
+ raw_datasets["train"] = (
1097
+ filter_by_wer_threshold(num_proc=num_workers, desc="filtering train dataset by wer")
1098
+ if not data_args.streaming
1099
+ else filter_by_wer_threshold()
1100
+ )
1101
+
1102
+ # 10.4: pre-process training/evaluation datasets
1103
+ def prepare_train_dataset(batch):
1104
+ """
1105
+ Pre-process the raw dataset in a three stage process:
1106
+ 1. Convert the audio arrays to log-mel spectrogram inputs
1107
+ 2. Possibly filter the timestamp tokens from the token ids (depending on the timestamp probability)
1108
+ 3. Possibly add prompt tokens if conditioning on previous text (depending on the conditioning probability)
1109
+ """
1110
+ # process audio input
1111
+ audio = [sample["array"] for sample in batch["audio"]]
1112
+ inputs = feature_extractor(audio, sampling_rate=sampling_rate)
1113
+ batch["input_features"] = inputs.input_features
1114
+ batch["input_length"] = [len(sample) for sample in audio]
1115
+
1116
+ # process text targets - for training these are the Whisper-generated pseudo-labels
1117
+ input_str_batched = batch[train_text_column_name]
1118
+ condition_on_prev_batched = batch.get("condition_on_prev", len(input_str_batched) * [None])
1119
+
1120
+ all_token_ids = []
1121
+ all_token_ids_unprompted = []
1122
+ for prev_ids, input_str in zip(condition_on_prev_batched, input_str_batched):
1123
+ token_ids = tokenizer(input_str, add_special_tokens=not use_pseudo_labels).input_ids
1124
+
1125
+ # check whether we have timestamps in the PLs and filter if required
1126
+ has_timestamps = len(set(token_ids) & set(timestamp_ids)) > 0
1127
+ if has_timestamps:
1128
+ # sample from binomial distribution to get probability of training on timestamps
1129
+ predict_timestamps = bool(np.random.binomial(1, timestamp_probability))
1130
+ if not predict_timestamps:
1131
+ # filter timestamps and insert the <|notimestamps|> task token
1132
+ token_ids = [token for token in token_ids if token < timestamp_begin]
1133
+ token_ids.insert(timestamp_position, timestamp_begin)
1134
+
1135
+ all_token_ids_unprompted.append(token_ids)
1136
+ # check whether to condition on previous text - we do this with probability condition_on_prev_probability
1137
+ condition_on_prev = bool(np.random.binomial(1, condition_on_prev_probability))
1138
+ if not condition_on_prev:
1139
+ prev_ids = None
1140
+ elif "condition_on_prev" not in batch and len(all_token_ids_unprompted) > 1:
1141
+ # prompt ids are the penultimate token ids in the batch
1142
+ prev_ids = all_token_ids_unprompted[-2]
1143
+
1144
+ if prev_ids is not None:
1145
+ if has_timestamps and not predict_timestamps:
1146
+ # filter timestamp ids from prompt when not predicting timestamps
1147
+ prev_ids = [token for token in prev_ids if token < timestamp_begin]
1148
+
1149
+ # check that the length of the prompt does not exceed half the max label length (224)
1150
+ if len(prev_ids) > prompt_cutoff_length:
1151
+ prev_ids = prev_ids[-prompt_cutoff_length + 1 :]
1152
+ prev_ids = [decoder_prev_token_id] + prev_ids
1153
+
1154
+ # and that the total length of the labels does not exceed the max label length (448)
1155
+ if len(prev_ids + token_ids) > max_label_length:
1156
+ trim_length = len(prev_ids + token_ids) - max_label_length + 1
1157
+ prev_ids = prev_ids[trim_length:]
1158
+ prev_ids = [decoder_prev_token_id] + prev_ids
1159
+
1160
+ token_ids = prev_ids + token_ids
1161
+
1162
+ all_token_ids.append(token_ids)
1163
+
1164
+ batch["labels"] = all_token_ids
1165
+ return batch
1166
+
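A toy sketch of the timestamp-dropping step inside `prepare_train_dataset`, using made-up token ids chosen small for clarity (real Whisper special-token ids are in the ~50k range):

    timestamp_begin = 100        # hypothetical id of <|notimestamps|>
    timestamp_position = 3       # position of the task token for a multilingual checkpoint
    token_ids = [1, 2, 3, 101, 7, 8, 150, 4]   # 101 and 150 stand in for timestamp tokens

    token_ids = [token for token in token_ids if token < timestamp_begin]  # drop timestamps
    token_ids.insert(timestamp_position, timestamp_begin)                  # insert <|notimestamps|>
    print(token_ids)  # [1, 2, 3, 100, 7, 8, 4]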
1167
+ def prepare_eval_dataset(batch):
1168
+ # process audio input
1169
+ sample = batch["audio"]
1170
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1171
+ batch["input_features"] = inputs.input_features[0]
1172
+ batch["input_length"] = len(sample["array"])
1173
+
1174
+ # process targets - for evaluation these are the ground-truth transcriptions
1175
+ input_str = batch["text"]
1176
+ batch["labels"] = tokenizer(input_str).input_ids
1177
+ return batch
1178
+
1179
+ vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1180
+ if training_args.do_train:
1181
+ # with streaming mode we can only have 1 worker, whereas with non-streaming
1182
+ # we can use `num_workers` (which is much faster)
1183
+ # We gate the pre-processing function accordingly
1184
+ map_fn_train = partial(
1185
+ raw_datasets["train"].map,
1186
+ function=prepare_train_dataset,
1187
+ remove_columns=raw_datasets_train_features,
1188
+ batched=True,
1189
+ batch_size=data_args.preprocessing_batch_size,
1190
+ )
1191
+ if accelerator.is_main_process:
1192
+ vectorized_datasets["train"] = (
1193
+ map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1194
+ if not data_args.streaming
1195
+ else map_fn_train()
1196
+ )
1197
+ if training_args.do_eval:
1198
+ for eval_split in all_eval_splits:
1199
+ raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1200
+ map_fn_eval = partial(
1201
+ raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
1202
+ )
1203
+ if accelerator.is_main_process:
1204
+ vectorized_datasets[eval_split] = (
1205
+ map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1206
+ if not data_args.streaming
1207
+ else map_fn_eval()
1208
+ )
1209
+
1210
+ # 10.5: Filter training data with inputs longer than `max_input_length`
1211
+ def is_audio_in_length_range(length):
1212
+ return min_input_length < length < max_input_length
1213
+
1214
+ filter_by_audio_fn = partial(
1215
+ vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
1216
+ )
1217
+ if accelerator.is_main_process:
1218
+ vectorized_datasets = (
1219
+ filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
1220
+ if not data_args.streaming
1221
+ else filter_by_audio_fn()
1222
+ )
1223
+
1224
+ # 10.6: Filter training data with labels longer than `max_label_length`
1225
+ def is_labels_in_length_range(labels):
1226
+ return 0 < len(labels) <= max_label_length
1227
+
1228
+ filter_by_labels_fn = partial(
1229
+ vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1230
+ )
1231
+ if accelerator.is_main_process:
1232
+ vectorized_datasets = (
1233
+ filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1234
+ if not data_args.streaming
1235
+ else filter_by_labels_fn()
1236
+ )
1237
+
1238
+ # Pre-processing complete!
1239
+ # For large datasets it is advised to run the preprocessing on a
1240
+ # single machine first with `--preprocessing_only` since there will most likely
1241
+ # be a timeout when running the script in distributed mode.
1242
+ # In a second step, `--preprocessing_only` can then be set to `False` to load the
1243
+ # cached dataset
1244
+ if data_args.preprocessing_only:
1245
+ if data_args.streaming:
1246
+ raise ValueError(
1247
+ "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion "
1248
+ "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1249
+ "on the fly with streaming mode."
1250
+ )
1251
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1252
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1253
+ return
1254
+
1255
+ # 11. Define Evaluation Metrics
1256
+ def compute_metrics(preds, labels):
1257
+ # replace padded labels by the padding token
1258
+ for idx in range(len(labels)):
1259
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1260
+
1261
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
1262
+ # we do not want to group tokens when computing the metrics
1263
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1264
+ wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1265
+
1266
+ # normalize everything and re-compute the WER
1267
+ norm_pred_str = [normalizer(pred) for pred in pred_str]
1268
+ norm_label_str = [normalizer(label) for label in label_str]
1269
+ # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
1270
+ pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1271
+ label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1272
+ # filtering step to only evaluate the samples that correspond to non-zero normalized references:
1273
+ norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1274
+ norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1275
+
1276
+ wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1277
+ return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1278
+
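A short sketch of the two WER variants returned by `compute_metrics`: the orthographic WER is computed on the raw strings, the normalised WER after text normalisation (values hold for this toy pair only):

    import evaluate
    from transformers.models.whisper.english_normalizer import BasicTextNormalizer

    metric = evaluate.load("wer")
    normalizer = BasicTextNormalizer()

    pred, label = "Hello world!", "hello world"
    wer_ortho = 100 * metric.compute(predictions=[pred], references=[label])  # 100.0 (case + punctuation mismatch)
    wer = 100 * metric.compute(predictions=[normalizer(pred)], references=[normalizer(label)])  # 0.0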
1279
+ # 12. Define Training Schedule
1280
+ # Store some constants
1281
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1282
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1283
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1284
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1285
+
1286
+ if not data_args.streaming and training_args.max_steps < 0:
1287
+ num_epochs = int(training_args.num_train_epochs)
1288
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1289
+ total_train_steps = steps_per_epoch * num_epochs
1290
+ elif training_args.max_steps > 0:
1291
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1292
+ total_train_steps = int(training_args.max_steps)
1293
+ if not data_args.streaming:
1294
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1295
+ num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1296
+ else:
1297
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1298
+ num_epochs = sys.maxsize
1299
+ steps_per_epoch = total_train_steps
1300
+ else:
1301
+ raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1302
+
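For concreteness (hypothetical numbers, not from this run): with 100,000 training samples, `per_device_train_batch_size=32` on 8 processes and `gradient_accumulation_steps=2`, one epoch in the non-streaming branch above corresponds to

    steps_per_epoch = 100_000 // (32 * 8 * 2)   # = 195 optimisation steps

and `total_train_steps = steps_per_epoch * num_train_epochs`.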
1303
+ if training_args.eval_steps is None:
1304
+ logger.info(
1305
+ f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1306
+ )
1307
+ eval_steps = steps_per_epoch
1308
+ else:
1309
+ eval_steps = training_args.eval_steps
1310
+
1311
+ # 13. Define optimizer, LR scheduler, collator
1312
+ decay_parameters = get_parameter_names(
1313
+ student_model,
1314
+ [nn.LayerNorm],
1315
+ forbidden_module=[student_model.model.encoder] if training_args.freeze_encoder else None,
1316
+ )
1317
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
1318
+ optimizer_grouped_parameters = [
1319
+ {
1320
+ "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1321
+ "weight_decay": training_args.weight_decay,
1322
+ },
1323
+ {
1324
+ "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1325
+ "weight_decay": 0.0,
1326
+ },
1327
+ ]
1328
+ optimizer = torch.optim.AdamW(
1329
+ params=optimizer_grouped_parameters,
1330
+ lr=training_args.learning_rate,
1331
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1332
+ eps=training_args.adam_epsilon,
1333
+ )
1334
+
1335
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1336
+ lr_scheduler = get_scheduler(
1337
+ name=training_args.lr_scheduler_type,
1338
+ optimizer=optimizer,
1339
+ num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1340
+ num_training_steps=total_train_steps * accelerator.num_processes,
1341
+ )
1342
+
1343
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
1344
+ processor=processor,
1345
+ decoder_start_token_id=decoder_start_token_id,
1346
+ decoder_prev_token_id=decoder_prev_token_id,
1347
+ input_padding="longest",
1348
+ target_padding="max_length",
1349
+ max_target_length=max_label_length,
1350
+ )
1351
+
1352
+ # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1353
+ # so that we can still access the configs
1354
+ num_beams = (
1355
+ training_args.generation_num_beams
1356
+ if training_args.generation_num_beams is not None
1357
+ else getattr(student_model.generation_config, "num_beams", 1)
1358
+ )
1359
+
1360
+ gen_kwargs = {
1361
+ "max_length": max_label_length,
1362
+ "num_beams": num_beams,
1363
+ "return_timestamps": return_timestamps,
1364
+ }
1365
+ if is_multilingual:
1366
+ # forcing the language and task tokens helps multilingual models in their generations
1367
+ gen_kwargs.update(
1368
+ {
1369
+ "language": data_args.language,
1370
+ "task": data_args.task,
1371
+ }
1372
+ )
1373
+
1374
+ # 15. Prepare everything with accelerate
1375
+ student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1376
+ student_model, teacher_model, optimizer, lr_scheduler
1377
+ )
1378
+
1379
+ def kl_divergence(target_distribution, log_predicted_distribution, labels):
1380
+ kl_loss = nn.KLDivLoss(reduction="none")
1381
+ divergence = kl_loss(log_predicted_distribution, target_distribution)
1382
+ # ignore padded tokens in the divergence, i.e. keep only positions where the labels are not set to -100
1383
+ padding_mask = labels >= 0
1384
+ padding_mask = padding_mask.unsqueeze(-1)
1385
+ divergence = divergence * padding_mask
1386
+ # take the average over the mini-batch
1387
+ divergence = divergence.sum() / padding_mask.sum()
1388
+ return divergence
1389
+
1390
+ # Define gradient update step fn
1391
+ def train_step(
1392
+ batch,
1393
+ temperature=2.0,
1394
+ ):
1395
+ student_model.train()
1396
+ teacher_model.eval()
1397
+
1398
+ student_outputs = student_model(**batch)
1399
+ with torch.no_grad():
1400
+ if share_hidden_states:
1401
+ # if the student and teacher share the same frozen encoder then we don't have to recompute the
1402
+ # encoder hidden-states for the teacher model, we can just re-use from the student
1403
+ encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1404
+ teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1405
+ else:
1406
+ # do the full forward pass for the teacher model (encoder + decoder)
1407
+ teacher_outputs = teacher_model(**batch)
1408
+
1409
+ # CE (data) loss
1410
+ ce_loss = student_outputs.loss
1411
+ # rescale distribution by temperature to ensure gradients scale correctly
1412
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1413
+ # log softmax of student predictions for numerical stability
1414
+ student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1415
+ # KL-divergence loss (scaled by temperature)
1416
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1417
+
1418
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1419
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1420
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1421
+ return loss, metrics
1422
+
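A self-contained sketch of the distillation objective in `train_step`, with random logits and hypothetical `temperature=2.0` and `kl_weight=1.0` (the padding mask mirrors `kl_divergence` above):

    import torch
    import torch.nn as nn

    temperature, kl_weight = 2.0, 1.0
    student_logits = torch.randn(1, 4, 10)        # (batch, seq_len, vocab)
    teacher_logits = torch.randn(1, 4, 10)
    labels = torch.tensor([[2, 5, -100, -100]])   # last two positions are padding
    ce_loss = torch.tensor(1.5)                   # stand-in for the student's CE loss

    teacher_dist = nn.functional.softmax(teacher_logits / temperature, dim=-1)
    student_dist = nn.functional.log_softmax(student_logits / temperature, dim=-1)
    divergence = nn.KLDivLoss(reduction="none")(student_dist, teacher_dist)
    padding_mask = (labels >= 0).unsqueeze(-1)    # mask out padded positions
    kl_loss = (divergence * padding_mask).sum() / padding_mask.sum() * temperature**2

    loss = 0.8 * ce_loss + kl_weight * kl_loss    # Distil-Whisper weighting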
1423
+ # Define eval fn
1424
+ def eval_step(batch):
1425
+ student_model.eval()
1426
+ teacher_model.eval()
1427
+
1428
+ with torch.no_grad():
1429
+ student_outputs = student_model(**batch)
1430
+ if share_hidden_states:
1431
+ encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1432
+ teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1433
+ else:
1434
+ teacher_outputs = teacher_model(**batch)
1435
+
1436
+ # CE (data) loss
1437
+ ce_loss = student_outputs.loss
1438
+
1439
+ # log softmax / softmax for numerical stability
1440
+ student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1441
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1442
+ # temperature is always 1 for eval
1443
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1444
+
1445
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1446
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1447
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1448
+ return metrics
1449
+
1450
+ def generate_step(batch):
1451
+ student_model.eval()
1452
+ output_ids = accelerator.unwrap_model(student_model).generate(batch["input_features"], **gen_kwargs)
1453
+ output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1454
+ return output_ids
1455
+
1456
+ logger.info("***** Running training *****")
1457
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1458
+ if not data_args.streaming:
1459
+ logger.info(f" Num epochs = {num_epochs}")
1460
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1461
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1462
+ logger.info(
1463
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1464
+ )
1465
+ logger.info(f" Total optimization steps = {total_train_steps}")
1466
+
1467
+ # ======================== Training ================================
1468
+ train_time = 0
1469
+ train_start = time.time()
1470
+ steps_trained_progress_bar = tqdm(
1471
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1472
+ )
1473
+ continue_training = True
1474
+ epochs_trained = 0
1475
+ cur_step = 0
1476
+
1477
+ checkpoint = None
1478
+ if training_args.resume_from_checkpoint is not None:
1479
+ checkpoint = training_args.resume_from_checkpoint
1480
+ elif last_checkpoint is not None:
1481
+ checkpoint = last_checkpoint
1482
+
1483
+ if checkpoint is not None:
1484
+ accelerator.load_state(checkpoint)
1485
+ # Find num steps and epoch from saved state string pattern
1486
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1487
+ match = re.search(pattern, checkpoint)
1488
+ cur_step = int(match.group(1))
1489
+ epochs_trained = int(match.group(2))
1490
+
1491
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1492
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1493
+ logger.info(f" Continuing training from global step {cur_step}")
1494
+
1495
+ steps_trained_progress_bar.update(cur_step)
1496
+
1497
+ for epoch in range(0, epochs_trained):
1498
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1499
+
1500
+ if not data_args.streaming and training_args.max_steps < 0:
1501
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
1502
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1503
+ else:
1504
+ # Currently we don't know how many steps we've taken in the current epoch
1505
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1506
+ # This is "good enough" for our purposes but not fully correct
1507
+ resume_step = None
1508
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1509
+ else:
1510
+ resume_step = None
1511
+
1512
+ for epoch in range(epochs_trained, num_epochs):
1513
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1514
+ train_dataloader = DataLoader(
1515
+ vectorized_datasets["train"],
1516
+ collate_fn=data_collator,
1517
+ batch_size=per_device_train_batch_size,
1518
+ num_workers=dataloader_num_workers,
1519
+ prefetch_factor=prefetch_factor,
1520
+ pin_memory=training_args.dataloader_pin_memory,
1522
+ )
1523
+ train_dataloader = accelerator.prepare(train_dataloader)
1524
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1525
+ train_dataloader.dataset.set_epoch(epoch)
1526
+
1527
+ if resume_step is not None:
1528
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1529
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1530
+ resume_step = None
1531
+
1532
+ for batch in train_dataloader:
1533
+ with accelerator.accumulate(student_model):
1534
+ loss, train_metric = train_step(batch, temperature=training_args.temperature)
1535
+ accelerator.backward(loss)
1536
+ if accelerator.sync_gradients:
1537
+ accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1538
+ optimizer.step()
1539
+ lr_scheduler.step()
1540
+ optimizer.zero_grad()
1541
+
1542
+ # Check if the accelerator has performed an optimization step behind the scenes
1543
+ if accelerator.sync_gradients:
1544
+ steps_trained_progress_bar.update(1)
1545
+ cur_step += 1
1546
+
1547
+ if cur_step % training_args.logging_steps == 0:
1548
+ steps_trained_progress_bar.write(
1549
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1550
+ f" {train_metric['loss']}, Learning Rate:"
1551
+ f" {lr_scheduler.get_last_lr()[0]})"
1552
+ )
1553
+ log_metric(
1554
+ accelerator,
1555
+ metrics=train_metric,
1556
+ learning_rate=lr_scheduler.get_last_lr()[0],
1557
+ train_time=train_time + time.time() - train_start,
1558
+ step=cur_step,
1559
+ epoch=epoch,
1560
+ prefix="train",
1561
+ )
1562
+
1563
+ # save checkpoint and weights after each save_steps and at the end of training
1564
+ if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1565
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1566
+ accelerator.save_state(output_dir=intermediate_dir)
1567
+ accelerator.wait_for_everyone()
1568
+ if accelerator.is_main_process:
1569
+ rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1570
+
1571
+ if cur_step == total_train_steps:
1572
+ # un-wrap student model for save
1573
+ student_model = accelerator.unwrap_model(student_model)
1574
+ student_model.save_pretrained(training_args.output_dir)
1575
+ # re-wrap student model for final eval
1576
+ student_model = accelerator.prepare(student_model)
1577
+
1578
+ if training_args.push_to_hub:
1579
+ upload_folder(
1580
+ folder_path=training_args.output_dir,
1581
+ repo_id=repo_name,
1582
+ repo_type="model",
1583
+ commit_message=f"Saving train state of step {cur_step}",
1584
+ )
1585
+
1586
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1587
+ train_time += time.time() - train_start
1588
+ student_model.eval()
1589
+ # ======================== Evaluating ==============================
1590
+ for eval_split in all_eval_splits:
1591
+ eval_metrics = []
1592
+ eval_preds = []
1593
+ eval_labels = []
1594
+ eval_start = time.time()
1595
+
1596
+ validation_dataloader = DataLoader(
1597
+ vectorized_datasets[eval_split],
1598
+ collate_fn=data_collator,
1599
+ batch_size=per_device_eval_batch_size,
1600
+ drop_last=False,
1601
+ num_workers=dataloader_num_workers,
1602
+ prefetch_factor=prefetch_factor,
1603
+ pin_memory=training_args.dataloader_pin_memory,
1604
+ )
1605
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1606
+
1607
+ for batch in tqdm(
1608
+ validation_dataloader,
1609
+ desc=f"Evaluating {eval_split}...",
1610
+ position=2,
1611
+ disable=not accelerator.is_local_main_process,
1612
+ ):
1613
+ # Model forward
1614
+ eval_metric = eval_step(batch)
1615
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1616
+ eval_metrics.append(eval_metric)
1617
+
1618
+ # generation
1619
+ if training_args.predict_with_generate:
1620
+ generated_ids = generate_step(batch)
1621
+ # Gather all predictions and targets
1622
+ generated_ids, labels = accelerator.gather_for_metrics(
1623
+ (generated_ids, batch["labels"])
1624
+ )
1625
+ eval_preds.extend(generated_ids)
1626
+ eval_labels.extend(labels)
1627
+
1628
+ eval_time = time.time() - eval_start
1629
+ # normalize eval metrics
1630
+ eval_metrics = {
1631
+ key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1632
+ }
1633
+
1634
+ # compute WER metric
1635
+ wer_desc = ""
1636
+ if training_args.predict_with_generate:
1637
+ wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
1638
+ eval_preds, eval_labels
1639
+ )
1640
+ eval_metrics.update(wer_metric)
1641
+ wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1642
+ log_pred(
1643
+ accelerator,
1644
+ pred_str,
1645
+ label_str,
1646
+ norm_pred_str,
1647
+ norm_label_str,
1648
+ step=cur_step,
1649
+ prefix=eval_split,
1650
+ )
1651
+
1652
+ # Print metrics and update progress bar
1653
+ steps_trained_progress_bar.write(
1654
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1655
+ f" {wer_desc})"
1656
+ )
1657
+
1658
+ log_metric(
1659
+ accelerator,
1660
+ metrics=eval_metrics,
1661
+ train_time=eval_time,
1662
+ step=cur_step,
1663
+ epoch=epoch,
1664
+ prefix=eval_split,
1665
+ )
1666
+
1667
+ # flush the train metrics
1668
+ train_start = time.time()
1669
+
1670
+ # break condition
1671
+ if cur_step == total_train_steps:
1672
+ continue_training = False
1673
+ break
1674
+
1675
+ if not continue_training:
1676
+ break
1677
+
1678
+ accelerator.end_training()
1679
+
1680
+
1681
+ if __name__ == "__main__":
1682
+ main()
1683
+
run_init.sh ADDED
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env bash
2
+
3
+ python create_student_model.py \
4
+ --teacher_checkpoint "openai/whisper-large-v3" \
5
+ --encoder_layers 32 \
6
+ --decoder_layers 2 \
7
+ --save_dir "./distil-large-v3-init"
8
+
run_labelling.sh ADDED
@@ -0,0 +1,28 @@
1
+ #!/usr/bin/env bash
2
+
3
+ accelerate launch run_pseudo_labelling.py \
4
+ --model_name_or_path "openai/whisper-large-v3" \
5
+ --dataset_name "mozilla-foundation/common_voice_16_1" \
6
+ --dataset_config_name "de" \
7
+ --dataset_split_name "train+validation+test" \
8
+ --text_column_name "sentence" \
9
+ --id_column_name "path" \
10
+ --output_dir "../common_voice_16_1_de_pseudo_labelled" \
11
+ --wandb_project "distil-whisper-labelling" \
12
+ --per_device_eval_batch_size 64 \
13
+ --dtype "bfloat16" \
14
+ --attn_implementation "flash_attention_2" \
15
+ --logging_steps 500 \
16
+ --max_label_length 256 \
17
+ --concatenate_audio \
18
+ --preprocessing_batch_size 500 \
19
+ --preprocessing_num_workers 8 \
20
+ --dataloader_num_workers 8 \
21
+ --report_to "wandb" \
22
+ --language "de" \
23
+ --task "transcribe" \
24
+ --return_timestamps \
25
+ --streaming False \
26
+ --generation_num_beams 1 \
27
+ --push_to_hub
28
+
run_pseudo_labelling.py ADDED
@@ -0,0 +1,1015 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pseudo-labelling audio data using the Whisper model in preparation for distillation.
18
+ """
19
+ import csv
20
+
21
+ # You can also adapt this script for your own pseudo-labelling tasks. Pointers for this are left as comments.
22
+ import logging
23
+ import os
24
+ import string
25
+ import sys
26
+ import time
27
+ import warnings
28
+ from dataclasses import dataclass, field
29
+ from datetime import timedelta
30
+ from pathlib import Path
31
+ from typing import Any, Dict, List, Optional, Union
32
+
33
+ import datasets
34
+ import evaluate
35
+ import numpy as np
36
+ import torch
37
+ import transformers
38
+ from accelerate import Accelerator, InitProcessGroupKwargs
39
+ from accelerate.logging import get_logger
40
+ from datasets import (
41
+ DatasetDict,
42
+ IterableDatasetDict,
43
+ load_dataset,
44
+ )
45
+ from huggingface_hub import HfFolder, Repository, create_repo, get_full_repo_name
46
+ from torch.utils.data import DataLoader
47
+ from tqdm import tqdm
48
+ from transformers import (
49
+ HfArgumentParser,
50
+ Seq2SeqTrainingArguments,
51
+ WhisperConfig,
52
+ WhisperFeatureExtractor,
53
+ WhisperForConditionalGeneration,
54
+ WhisperProcessor,
55
+ WhisperTokenizerFast,
56
+ )
57
+ from transformers.models.whisper.english_normalizer import EnglishTextNormalizer, BasicTextNormalizer
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.34.0.dev0")
64
+
65
+ require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")
66
+
67
+ logger = get_logger(__name__)
68
+
69
+
70
+ @dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to distill from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained Whisper model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None,
81
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
82
+ )
83
+ tokenizer_name: Optional[str] = field(
84
+ default=None,
85
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
86
+ )
87
+ feature_extractor_name: Optional[str] = field(
88
+ default=None,
89
+ metadata={"help": "feature extractor name or path if not the same as model_name"},
90
+ )
91
+ processor_name: Optional[str] = field(
92
+ default=None,
93
+ metadata={"help": "processor name or path if not the same as model_name"},
94
+ )
95
+ cache_dir: Optional[str] = field(
96
+ default=None,
97
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
98
+ )
99
+ use_fast_tokenizer: bool = field(
100
+ default=True,
101
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
102
+ )
103
+ model_revision: str = field(
104
+ default="main",
105
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
106
+ )
107
+ subfolder: str = field(
108
+ default="",
109
+ metadata={
110
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
111
+ "specify the folder name here."
112
+ },
113
+ )
114
+ token: str = field(
115
+ default=None,
116
+ metadata={
117
+ "help": (
118
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
119
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
120
+ )
121
+ },
122
+ )
123
+ dtype: Optional[str] = field(
124
+ default="float32",
125
+ metadata={
126
+ "help": (
127
+ "The data type (dtype) in which to load the model weights. One of `float32` (full-precision), "
128
+ "`float16` or `bfloat16` (both half-precision)."
129
+ )
130
+ },
131
+ )
132
+ attn_implementation: Optional[str] = field(
133
+ default=None,
134
+ metadata={
135
+ "help": (
136
+ "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n"
137
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
138
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
139
+ "3. `flash_attention_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
140
+ )
141
+ },
142
+ )
143
+ attn_type: Optional[str] = field(
144
+ default=None,
145
+ metadata={
146
+ "help": "Deprecated. Use `attn_implementation` instead."
147
+ },
148
+ )
149
+ def __post_init__(self):
150
+ if self.attn_type is not None and self.attn_implementation is None:
151
+ # set attn_implementation in a backwards compatible way
152
+ if self.attn_type == "flash_attn":
153
+ self.attn_implementation = "sdpa"
154
+ elif self.attn_type == "flash_attn_2":
155
+ self.attn_implementation = "flash_attention_2"
156
+ elif self.attn_type in [None, "eager", "sdpa", "flash_attention_2"]:
157
+ self.attn_implementation = self.attn_type
158
+ else:
159
+ raise ValueError(
160
+ f"Argument `--attn_type` is deprecated, and set to an invalid option `{self.attn_type}`. You should omit the argument `--attn_type`, and instead set `--attn_implementation` to one of the following:\n"
161
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
162
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
163
+ "3. `flash_attention_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
164
+ )
165
+ warnings.warn(f"Argument `--attn_type` is deprecated. Use `--attn_implementation` instead. Inferring `--attn_implementation={self.attn_implementation}` from argument `--attn_type={self.attn_type}`.")
166
+ elif self.attn_type is not None and self.attn_implementation is not None:
167
+ raise ValueError("`--attn_type` and `--attn_implementation` are both specified. Only specify the argument `--attn_implementation`, since `--attn_type` is deprecated.")
168
+
169
+
170
+ @dataclass
171
+ class DataTrainingArguments:
172
+ """
173
+ Arguments pertaining to what data we are going to input our model for training and eval.
174
+ """
175
+
176
+ dataset_name: str = field(
177
+ default=None,
178
+ metadata={"help": "The name of the dataset to use (via the datasets library)."},
179
+ )
180
+ dataset_config_name: Optional[str] = field(
181
+ default=None,
182
+ metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
183
+ )
184
+ dataset_cache_dir: Optional[str] = field(
185
+ default=None,
186
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
187
+ )
188
+ overwrite_cache: bool = field(
189
+ default=False,
190
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
191
+ )
192
+ preprocessing_num_workers: Optional[int] = field(
193
+ default=None,
194
+ metadata={"help": "The number of processes to use for the preprocessing."},
195
+ )
196
+ preprocessing_batch_size: Optional[int] = field(
197
+ default=500,
198
+ metadata={"help": "The batch size to use for the dataset pre-processing."},
199
+ )
200
+ audio_column_name: str = field(
201
+ default="audio",
202
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
203
+ )
204
+ text_column_name: str = field(
205
+ default="text",
206
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'."},
207
+ )
208
+ id_column_name: str = field(
209
+ default="id",
210
+ metadata={"help": "The name of the dataset column containing the id data. Defaults to 'id'"},
211
+ )
212
+ speaker_id_column_name: str = field(
213
+ default=None,
214
+ metadata={"help": "The name of the dataset column containing the speaker id data. Defaults to None."},
215
+ )
216
+ max_duration_in_seconds: float = field(
217
+ default=30.0,
218
+ metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
219
+ )
220
+ max_label_length: int = field(
221
+ default=256,
222
+ metadata={"help": "Truncate transcriptions that are longer than `max_label_length` tokens."},
223
+ )
224
+ concatenate_audio: bool = field(
225
+ default=True,
226
+ metadata={"help": "Whether or not to concatenate the audio samples to `max_duration_in_seconds`."},
227
+ )
228
+ preprocessing_only: bool = field(
229
+ default=False,
230
+ metadata={
231
+ "help": (
232
+ "Whether to only do data preprocessing and skip training. This is"
233
+ " especially useful when data preprocessing errors out in distributed"
234
+ " training due to timeout. In this case, one should run the"
235
+ " preprocessing in a non-distributed setup with"
236
+ " `preprocessing_only=True` so that the cached datasets can"
237
+ " consequently be loaded in distributed training"
238
+ )
239
+ },
240
+ )
241
+ dataset_split_name: str = field(
242
+ default="train+validation+test",
243
+ metadata={
244
+ "help": (
245
+ "The name of the data set splits to use (via the datasets library)."
246
+ " Defaults to 'train+validation+test'. Multiple splits can be passed by splitting a"
247
+ " list through the '+' character, e.g. 'train+validation' will"
248
+ " pseudo-label both the 'train' and 'validation' splits sequentially."
249
+ )
250
+ },
251
+ )
252
+ wandb_project: str = field(
253
+ default="distil-whisper",
254
+ metadata={"help": "The name of the wandb project."},
255
+ )
256
+ streaming: bool = field(
257
+ default=False,
258
+ metadata={"help": "Whether to use dataset's streaming mode to load and pre-process the data."},
259
+ )
260
+ max_samples_per_split: Optional[int] = field(
261
+ default=None,
262
+ metadata={"help": "For debugging purposes, truncate the number of examples per split to this value if set."},
263
+ )
264
+ return_timestamps: bool = field(
265
+ default=False,
266
+ metadata={
267
+ "help": "Whether to return the timestamps with the text. This enables the Whisper timestamp logits processor during generation."
268
+ },
269
+ )
270
+ language: str = field(
271
+ default=None,
272
+ metadata={
273
+ "help": (
274
+ "Language for multilingual distillation. This argument should be set for multilingual distillation "
275
+ "only. For English speech recognition, it should be left as `None`."
276
+ )
277
+ },
278
+ )
279
+ task: str = field(
280
+ default="transcribe",
281
+ metadata={
282
+ "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."
283
+ "This argument should be set for multilingual distillation only. For English speech recognition, it should be left as `None`."
284
+ },
285
+ )
286
+ decode_token_ids: bool = field(
287
+ default=True,
288
+ metadata={"help": "Deprecated. The predicted token ids should always be decoded to text transcriptions."},
289
+ )
290
+ private_dataset: bool = field(
291
+ default=False,
292
+ metadata={"help": "Whether or not to create a private dataset for the pseudo-labelled data."},
293
+ )
294
+
295
+ def __post_init__(self):
296
+ if not self.decode_token_ids:
297
+ raise ValueError(
298
+ "The argument `--decode_token_ids` is deprecated. The token ids are now always decoded to "
299
+ "their corresponding text string. This is following a fix to the merges of the Whisper tokenizer "
300
+ "on the Hugging Face Hub: https://huggingface.co/openai/whisper-large-v2/discussions/100. "
301
+ "You should either omit the argument `--decode_token_ids`, or set it to True explicitly."
302
+ )
303
+
304
+ def shift_tokens_right(label_ids: np.ndarray, decoder_start_token_id: int) -> np.ndarray:
305
+ """
306
+ Shift label ids one token to the right.
307
+ """
308
+ shifted_label_ids = np.zeros_like(label_ids)
309
+ shifted_label_ids[:, 1:] = label_ids[:, :-1]
310
+ shifted_label_ids[:, 0] = decoder_start_token_id
311
+
312
+ return shifted_label_ids
313
+
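A quick sketch of `shift_tokens_right` on a toy label matrix (the start token id below is the usual Whisper `<|startoftranscript|>` id, used here purely for illustration):

    import numpy as np

    label_ids = np.array([[7, 8, 9], [4, 5, 6]])
    shifted = shift_tokens_right(label_ids, decoder_start_token_id=50258)
    # array([[50258, 7, 8],
    #        [50258, 4, 5]])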
314
+
315
+ @dataclass
316
+ class DataCollatorSpeechSeq2SeqWithPadding:
317
+ """
318
+ Data collator that will dynamically pad the inputs received.
319
+ Args:
320
+ processor ([`WhisperProcessor`])
321
+ The processor used for processing the data.
322
+ decoder_start_token_id (:obj: `int`)
323
+ The start-of-sequence token id of the decoder.
324
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
325
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
326
+ among:
327
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
328
+ sequence is provided).
329
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
330
+ maximum acceptable input length for the model if that argument is not provided.
331
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
332
+ different lengths).
333
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
334
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
335
+ See above for details.
336
+ max_target_length (:obj:`int`, `optional`):
337
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
338
+ """
339
+
340
+ processor: Any
341
+ decoder_start_token_id: int
342
+ input_padding: Union[bool, str] = "max_length"
343
+ target_padding: Union[bool, str] = "max_length"
344
+ max_target_length: Optional[int] = None
345
+
346
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
347
+ # split inputs and labels since they have to be of different lengths and need
348
+ # different padding methods
349
+ model_input_name = self.processor.model_input_names[0]
350
+
351
+ # dataloader returns a list of features which we convert to a dict
352
+ input_features = {model_input_name: [feature[model_input_name] for feature in features]}
353
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
354
+ file_ids = [feature["file_id"] for feature in features]
355
+
356
+ # reformat list to dict and set to pytorch format
357
+ batch = self.processor.feature_extractor.pad(
358
+ input_features,
359
+ padding=self.input_padding,
360
+ return_tensors="pt",
361
+ )
362
+
363
+ labels_batch = self.processor.tokenizer.pad(
364
+ label_features,
365
+ max_length=self.max_target_length,
366
+ padding=self.target_padding,
367
+ return_tensors="pt",
368
+ )
369
+
370
+ # replace padding with -100 to ignore correctly when computing the loss
371
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
372
+
373
+ # if bos token is appended in previous tokenization step,
374
+ # cut bos token here as it's appended later anyway
375
+ if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
376
+ labels = labels[:, 1:]
377
+
378
+ batch["labels"] = labels
379
+ batch["file_ids"] = file_ids
380
+
381
+ return batch
382
+
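A toy illustration of the label masking performed above, assuming a hypothetical pad token id of 0:

    import torch

    input_ids = torch.tensor([[5, 6, 0], [7, 0, 0]])
    attention_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
    labels = input_ids.masked_fill(attention_mask.ne(1), -100)
    # tensor([[   5,    6, -100],
    #         [   7, -100, -100]])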
383
+
384
+ def log_metric(
385
+ accelerator,
386
+ metrics: Dict,
387
+ train_time: float,
388
+ prefix: str = "eval",
389
+ ):
390
+ """Helper function to log all evaluation metrics with the correct prefixes and styling."""
391
+ log_metrics = {}
392
+ for k, v in metrics.items():
393
+ log_metrics[f"{prefix}/{k}"] = v
394
+ log_metrics[f"{prefix}/time"] = train_time
395
+ accelerator.log(log_metrics)
396
+
397
+
398
+ def log_pred(
399
+ accelerator,
400
+ pred_str: List[str],
401
+ label_str: List[str],
402
+ norm_pred_str: List[str],
403
+ norm_label_str: List[str],
404
+ prefix: str = "eval",
405
+ num_lines: int = 200000,
406
+ ):
407
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
408
+ if accelerator.is_main_process:
409
+ wandb_tracker = accelerator.get_tracker("wandb")
410
+ # pretty name for split
411
+ prefix = prefix.replace("/", "-")
412
+
413
+ # convert str data to a wandb compatible format
414
+ str_data = [[label_str[i], pred_str[i], norm_label_str[i], norm_pred_str[i]] for i in range(len(pred_str))]
415
+ # log as a table with the appropriate headers
416
+ wandb_tracker.log_table(
417
+ table_name=f"{prefix}/all_predictions",
418
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
419
+ data=str_data[:num_lines],
420
+ )
421
+
422
+ # log incorrect normalised predictions
423
+ str_data = np.asarray(str_data)
424
+ str_data_incorrect = str_data[str_data[:, -2] != str_data[:, -1]]
425
+ # log as a table with the appropriate headers
426
+ wandb_tracker.log_table(
427
+ table_name=f"{prefix}/incorrect_predictions",
428
+ columns=["Target", "Pred", "Norm Target", "Norm Pred"],
429
+ data=str_data_incorrect[:num_lines],
430
+ )
431
+
432
+
433
+ def main():
434
+ # 1. Parse input arguments
435
+ # We keep distinct sets of args, for cleaner separation of model/data/training related args
436
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
437
+
438
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
439
+ # If we pass only one argument to the script and it's the path to a json file,
440
+ # let's parse it to get our arguments.
441
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
442
+ else:
443
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
444
+
445
+ # 2. Initialize the accelerator
446
+ # We will let the accelerator handle device placement for us in this example
447
+ # We simply have to specify the training precision and any trackers being used
448
+ # We'll use the same dtype arguments as our JAX/Flax training script and convert
449
+ # it to accelerate format
450
+ if model_args.dtype == "float16":
451
+ mixed_precision = "fp16"
452
+ torch_dtype = torch.float16
453
+ elif model_args.dtype == "bfloat16":
454
+ mixed_precision = "bf16"
455
+ torch_dtype = torch.bfloat16
456
+ else:
457
+ mixed_precision = "no"
458
+ torch_dtype = torch.float32
459
+
460
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))
461
+
462
+ accelerator = Accelerator(
463
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
464
+ mixed_precision=mixed_precision,
465
+ log_with=training_args.report_to,
466
+ project_dir=training_args.output_dir,
467
+ kwargs_handlers=[kwargs],
468
+ )
469
+
470
+ accelerator.init_trackers(project_name=data_args.wandb_project)
471
+
472
+ # 3. Set-up basic logging
473
+ # Create one log on every process with the configuration for debugging
474
+ logging.basicConfig(
475
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
476
+ datefmt="%m/%d/%Y %H:%M:%S",
477
+ level=logging.INFO,
478
+ )
479
+ # Log a small summary on each process
480
+ logger.warning(
481
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
482
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
483
+ )
484
+
485
+ # Set the verbosity to info of the Transformers logger (on main process only)
486
+ if accelerator.is_local_main_process:
487
+ datasets.utils.logging.set_verbosity_warning()
488
+ transformers.utils.logging.set_verbosity_info()
489
+ else:
490
+ datasets.utils.logging.set_verbosity_error()
491
+ transformers.utils.logging.set_verbosity_error()
492
+ logger.info("Training/evaluation parameters %s", training_args)
493
+
494
+ # 4. Load dataset
495
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
496
+ token = model_args.token if model_args.token is not None else HfFolder().get_token()
497
+
498
+ data_splits = data_args.dataset_split_name.split("+")
499
+ for split in data_splits:
500
+ if data_args.streaming:
501
+ raw_datasets[split] = load_dataset(
502
+ data_args.dataset_name,
503
+ data_args.dataset_config_name,
504
+ split=split,
505
+ cache_dir=data_args.dataset_cache_dir,
506
+ token=token,
507
+ streaming=True,
508
+ )
509
+ else:
510
+ raw_datasets[split] = load_dataset(
511
+ data_args.dataset_name,
512
+ data_args.dataset_config_name,
513
+ split=split,
514
+ cache_dir=data_args.dataset_cache_dir,
515
+ token=token,
516
+ streaming=False,
517
+ num_proc=data_args.preprocessing_num_workers,
518
+ )
519
+
520
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
521
+ raise ValueError(
522
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset"
523
+ f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
524
+ " the correct audio column - one of"
525
+ f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
526
+ )
527
+
528
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
529
+ raise ValueError(
530
+ f"--text_column_name {data_args.text_column_name} not found in dataset"
531
+ f" '{data_args.dataset_name}'. Make sure to set `--text_column_name` to the"
532
+ " correct text column - one of"
533
+ f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
534
+ )
535
+
536
+ # 5. Load pretrained model, tokenizer, and feature extractor
537
+ config = WhisperConfig.from_pretrained(
538
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
539
+ cache_dir=model_args.cache_dir,
540
+ revision=model_args.model_revision,
541
+ token=token,
542
+ )
543
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
544
+ (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
545
+ cache_dir=model_args.cache_dir,
546
+ revision=model_args.model_revision,
547
+ token=token,
548
+ )
549
+ tokenizer = WhisperTokenizerFast.from_pretrained(
550
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
551
+ cache_dir=model_args.cache_dir,
552
+ use_fast=model_args.use_fast_tokenizer,
553
+ revision=model_args.model_revision,
554
+ token=token,
555
+ )
556
+ processor = WhisperProcessor.from_pretrained(
557
+ (model_args.processor_name if model_args.processor_name else model_args.model_name_or_path),
558
+ cache_dir=model_args.cache_dir,
559
+ revision=model_args.model_revision,
560
+ token=token,
561
+ )
562
+
563
+ model = WhisperForConditionalGeneration.from_pretrained(
564
+ model_args.model_name_or_path,
565
+ config=config,
566
+ cache_dir=model_args.cache_dir,
567
+ revision=model_args.model_revision,
568
+ subfolder=model_args.subfolder,
569
+ token=token,
570
+ low_cpu_mem_usage=True,
571
+ torch_dtype=torch_dtype,
572
+ attn_implementation=model_args.attn_implementation,
573
+ )
574
+ model.eval()
575
+
576
+ if model.config.decoder_start_token_id is None:
577
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
578
+
579
+ return_timestamps = data_args.return_timestamps
580
+ if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
581
+ is_multilingual = True
582
+ # We need to set the language and task ids for multilingual checkpoints
583
+ tokenizer.set_prefix_tokens(
584
+ language=data_args.language, task=data_args.task, predict_timestamps=return_timestamps
585
+ )
586
+ elif data_args.language is not None:
587
+ raise ValueError(
588
+ "Setting language token for an English-only checkpoint is not permitted. The language argument should "
589
+ "only be set for multilingual checkpoints."
590
+ )
591
+ else:
592
+ is_multilingual = False
593
+
594
+ # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
595
+ # so we just need to set the correct target sampling rate.
596
+ raw_datasets = raw_datasets.cast_column(
597
+ data_args.audio_column_name,
598
+ datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
599
+ )
600
+
601
+ # 7. Preprocessing the datasets.
602
+ # We need to read the audio files as arrays and tokenize the targets.
603
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
604
+ max_label_length = (
605
+ data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
606
+ )
607
+ audio_column_name = data_args.audio_column_name
608
+ sampling_rate = feature_extractor.sampling_rate
609
+
610
+ preprocessing_batch_size = data_args.preprocessing_batch_size
611
+ num_workers = data_args.preprocessing_num_workers
612
+ dataloader_num_workers = training_args.dataloader_num_workers
613
+
614
+ text_column_name = data_args.text_column_name
615
+ model_input_name = feature_extractor.model_input_names[0]
616
+ id_column_name = data_args.id_column_name
617
+ speaker_id_column_name = data_args.speaker_id_column_name
618
+ normalizer = (
619
+ BasicTextNormalizer() if data_args.language is not None else EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
620
+ )
621
+
622
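+ # number of special tokens at the start of a generation to skip when re-using predictions as prompts:
+ # 3 for multilingual checkpoints (<|startoftranscript|><|lang|><|task|>), 1 for English-only (<|startoftranscript|>)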
+ timestamp_position = 3 if is_multilingual else 1
623
+ decoder_prev_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
624
+ decoder_eot_token_id = tokenizer.eos_token_id
625
+
626
+ if data_args.max_samples_per_split is not None:
627
+ for split in data_splits:
628
+ raw_datasets[split] = (
629
+ raw_datasets[split].take(data_args.max_samples_per_split)
630
+ if data_args.streaming
631
+ else raw_datasets[split].select(range(data_args.max_samples_per_split))
632
+ )
633
+
634
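+ # sorting by speaker id groups utterances from the same speaker together, so that
+ # consecutive samples can be concatenated into longer segments below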
+ if speaker_id_column_name is not None:
635
+ raw_datasets = raw_datasets.sort(speaker_id_column_name)
636
+
637
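+ # pack consecutive samples from the same speaker into segments of up to `max_input_length`,
+ # recording for each segment whether it may be conditioned on the previous one (same speaker, split only by length)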
+ def concatenate_dataset(batch):
638
+ audio = [sample["array"] for sample in batch[audio_column_name]]
639
+ input_lengths = [len(sample) for sample in audio]
640
+
641
+ text = batch[text_column_name]
642
+ speaker_id = batch[speaker_id_column_name] if speaker_id_column_name else len(text) * [None]
643
+
644
+ concatenated_audio = []
645
+ concatenated_text = []
646
+ concatenated_speaker = []
647
+ condition_on_prev = []
648
+ audio_sample = audio[0]
649
+ text_sample = text[0]
650
+
651
+ for idx in range(1, len(audio)):
652
+ prev_speaker = speaker_id[idx - 1]
653
+ speaker = speaker_id[idx]
654
+
655
+ if len(audio_sample) + input_lengths[idx] < max_input_length:
656
+ if speaker == prev_speaker:
657
+ # we have no information about whether the segments follow on sequentially
658
+ # so we just ensure the same speaker as we concatenate across files
659
+ audio_sample = np.append(audio_sample, audio[idx])
660
+ # extra spaces in the text transcription don't matter, since we only use it for the WER computation
661
+ text_sample += " " + text[idx]
662
+ else:
663
+ # speakers do not follow sequentially, save the audio and start looping again
664
+ concatenated_audio.append(audio_sample)
665
+ concatenated_text.append(text_sample)
666
+ concatenated_speaker.append(speaker)
667
+ condition_on_prev.append(0)
668
+ audio_sample = audio[idx]
669
+ text_sample = text[idx]
670
+
671
+ else:
672
+ # concatenated audio exceeds max length, save the audio and start looping again
673
+ concatenated_audio.append(audio_sample)
674
+ concatenated_text.append(text_sample)
675
+ concatenated_speaker.append(speaker)
676
+ condition_on_prev.append(1)
677
+ audio_sample = audio[idx]
678
+ text_sample = text[idx]
679
+
680
+ batch[audio_column_name] = [{"array": array, "sampling_rate": sampling_rate} for array in concatenated_audio]
681
+ batch[text_column_name] = concatenated_text
682
+ batch[id_column_name] = concatenated_speaker
683
+ batch["condition_on_prev"] = condition_on_prev
684
+
685
+ return batch
686
+
687
+ raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
688
+ if data_args.concatenate_audio and not data_args.streaming:
689
+ raw_datasets = raw_datasets.map(
690
+ concatenate_dataset,
691
+ batched=True,
692
+ batch_size=preprocessing_batch_size,
693
+ num_proc=num_workers,
694
+ remove_columns=set(raw_datasets_features) - {audio_column_name, text_column_name, id_column_name, "condition_on_prev"},
695
+ desc="Concatenating dataset...",
696
+ )
697
+
698
+ raw_datasets = raw_datasets.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))
699
+ pretty_name = data_args.dataset_name.split("/")[-1]
700
+
701
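+ # give each sample a unique, human-readable id of the form {dataset}-{speaker}-{idx}
+ # (or {dataset}-{idx} if there is no speaker column) so transcriptions can be traced back to their source samples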
+ def postprocess_ids(speaker_ids, indices):
702
+ speaker_ids_formatted = []
703
+ for speaker, idx in zip(speaker_ids, indices):
704
+ formatted_idx = f"{pretty_name}-{speaker}-{idx}" if speaker is not None else f"{pretty_name}-{idx}"
705
+ speaker_ids_formatted.append(formatted_idx)
706
+ return {id_column_name: speaker_ids_formatted}
707
+
708
+ raw_datasets = raw_datasets.map(
709
+ postprocess_ids,
710
+ input_columns=[id_column_name],
711
+ with_indices=True,
712
+ desc="Setting sample idxs...",
713
+ batched=True,
714
+ batch_size=preprocessing_batch_size,
715
+ num_proc=num_workers,
716
+ )
717
+ else:
718
+ raise ValueError(
719
+ "Streaming mode is not yet compatible with concatenating audios to `max_duration_in_seconds`."
720
+ "Either set `--streaming=False` and download the audios locally, or open an issue on the Distil-Whisper repo to request this feature."
721
+ )
722
+
723
+ def prepare_dataset(batch):
724
+ # process audio
725
+ sample = batch[audio_column_name]
726
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
727
+ # process audio length
728
+ batch[model_input_name] = inputs.get(model_input_name)[0]
729
+
730
+ # process targets
731
+ input_str = batch[text_column_name]
732
+ batch["labels"] = tokenizer(input_str, max_length=max_label_length, truncation=True).input_ids
733
+
734
+ # record the id of the sample so it can be matched to its transcription later
735
+ batch["file_id"] = batch[id_column_name]
736
+ return batch
737
+
738
+ raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
739
+ if data_args.streaming:
740
+ vectorized_datasets = raw_datasets.map(prepare_dataset, remove_columns=raw_datasets_features)
741
+ else:
742
+ vectorized_datasets = raw_datasets.map(
743
+ prepare_dataset,
744
+ remove_columns=raw_datasets_features,
745
+ num_proc=num_workers,
746
+ desc="preprocess dataset",
747
+ )
748
+
749
+ # for large datasets it is advised to run the preprocessing on a
750
+ # single machine first with `args.preprocessing_only` since there will most likely
751
+ # be a timeout when running the script in distributed mode.
752
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
753
+ # cached dataset
754
+ if data_args.preprocessing_only:
755
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
756
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
757
+ return
758
+
759
+ if data_args.streaming and dataloader_num_workers > 0:
760
+ logger.warning(
761
+ "Using multiple dataloader num workers with streaming mode will result in different shards of "
762
+ "data being transcribed in parallel. This is not advised if you want to preserve the order of the "
763
+ "audio-text data."
764
+ )
765
+
766
+ # Handle the repository creation
767
+ output_dir = training_args.output_dir
768
+ if training_args.push_to_hub:
769
+ if training_args.hub_model_id is None:
770
+ repo_name = get_full_repo_name(
771
+ Path(output_dir).absolute().name,
772
+ token=token,
773
+ )
774
+ else:
775
+ repo_name = training_args.hub_model_id
776
+ create_repo(repo_name, exist_ok=True, token=token, repo_type="dataset", private=data_args.private_dataset)
777
+ repo = Repository(
778
+ output_dir,
779
+ clone_from=repo_name,
780
+ token=token,
781
+ repo_type="dataset",
782
+ )
783
+ # Ensure large txt files can be pushed to the Hub with git-lfs
784
+ with open(os.path.join(output_dir, ".gitattributes"), "r+") as f:
785
+ git_lfs_extensions = f.read()
786
+ if "*.csv" not in git_lfs_extensions:
787
+ f.write("*.csv filter=lfs diff=lfs merge=lfs -text")
788
+ else:
789
+ # this is where we'll save our transcriptions
790
+ if not os.path.exists(output_dir):
791
+ os.makedirs(output_dir)
792
+
793
+ # 8. Load Metric
794
+ metric = evaluate.load("wer")
795
+
796
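+ # decode predictions and references, normalise them, drop samples with empty normalised references,
+ # and compute the word error rate (WER) on the normalised text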
+ def compute_metrics(preds, labels, file_ids):
797
+ # replace padded labels by the padding token
798
+ for idx in range(len(labels)):
799
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
800
+
801
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=False, decode_with_timestamps=return_timestamps)
802
+ # we do not want to group tokens when computing the metrics
803
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
804
+
805
+ # normalize everything and re-compute the WER
806
+ norm_pred_str = [normalizer(pred) for pred in pred_str]
807
+ norm_label_str = [normalizer(label) for label in label_str]
808
+ # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
809
+ pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
810
+ label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
811
+ file_ids = [file_ids[i] for i in range(len(file_ids)) if len(norm_label_str[i]) > 0]
812
+ # filtering step to only evaluate the samples that correspond to non-zero normalized references:
813
+ norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
814
+ norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
815
+
816
+ wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
817
+
818
+ return {"wer": wer}, pred_str, label_str, norm_pred_str, norm_label_str, file_ids
819
+
820
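+ # strip all end-of-text tokens from each prediction and re-append a single trailing EOT,
+ # so that padded generations are reduced to their 'true' length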
+ def filter_eot_tokens(preds):
821
+ for idx in range(len(preds)):
822
+ # remove the EOT tokens to get the 'true' token length
823
+ token_ids = [token for token in preds[idx] if token != decoder_eot_token_id]
824
+ token_ids = token_ids + [decoder_eot_token_id]
825
+ preds[idx] = token_ids
826
+ return preds
827
+
828
+ # 9. Define the per-device eval batch size and data collator
829
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
830
+
831
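+ # pad the log-mel inputs to the longest example in the batch and the token labels to `max_label_length`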
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
832
+ processor=processor,
833
+ decoder_start_token_id=model.config.decoder_start_token_id, # <|startoftranscript|>
834
+ input_padding="longest",
835
+ target_padding="max_length",
836
+ max_target_length=max_label_length,
837
+ )
838
+
839
+ # 10. Define generation arguments - we need to do this before we wrap the models in DDP
840
+ # so that we can still access the configs
841
+ num_beams = (
842
+ training_args.generation_num_beams
843
+ if training_args.generation_num_beams is not None
844
+ else getattr(model.generation_config, "num_beams", 1)
845
+ )
846
+
847
+ gen_kwargs = {
848
+ "max_length": max_label_length,
849
+ "num_beams": num_beams,
850
+ "return_timestamps": return_timestamps,
851
+ }
852
+ if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
853
+ # forcing the language and task tokens helps multilingual models in their generations
854
+ gen_kwargs.update(
855
+ {
856
+ "language": data_args.language,
857
+ "task": data_args.task,
858
+ }
859
+ )
860
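+ # the language/task tokens are passed directly to generate, so clear any static forced decoder ids to avoid a conflict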
+ model.generation_config.forced_decoder_ids = None
861
+
862
+ # 11. Prepare everything with accelerate
863
+ model = accelerator.prepare(model)
864
+
865
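+ # run generation over one split, periodically writing the transcriptions to CSV (and optionally the Hub);
+ # WER is only computed for splits that carry reference transcriptions (validation/test)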
+ def eval_step_with_save(split="eval"):
866
+ # ======================== Evaluating ==============================
867
+ eval_preds = []
868
+ eval_labels = []
869
+ eval_ids = []
870
+ pred_str = []
871
+ eval_start = time.time()
872
+
873
+ eval_loader = DataLoader(
874
+ vectorized_datasets[split],
875
+ batch_size=per_device_eval_batch_size,
876
+ collate_fn=data_collator,
877
+ num_workers=dataloader_num_workers,
878
+ pin_memory=True,
879
+ )
880
+
881
+ eval_loader = accelerator.prepare(eval_loader)
882
+ batches = tqdm(eval_loader, desc=f"Evaluating {split}...", disable=not accelerator.is_local_main_process)
883
+
884
+ # make the split name pretty for librispeech etc
885
+ split = split.replace(".", "-").split("/")[-1]
886
+ output_csv = os.path.join(output_dir, f"{split}-transcription.csv")
887
+
888
+ for step, batch in enumerate(batches):
889
+ file_ids = batch.pop("file_ids")
890
+ # Generate predictions and pad to max generated length
891
+ generate_fn = model.module.generate if accelerator.num_processes > 1 else model.generate
892
+ generated_ids = generate_fn(batch["input_features"].to(dtype=torch_dtype), **gen_kwargs)
893
+ generated_ids = accelerator.pad_across_processes(generated_ids, dim=1, pad_index=tokenizer.pad_token_id)
894
+ # Gather all predictions and targets
895
+ file_ids, generated_ids, labels = accelerator.gather_for_metrics(
896
+ (file_ids, generated_ids, batch["labels"])
897
+ )
898
+ eval_preds.extend(generated_ids.cpu().numpy())
899
+ eval_labels.extend(labels.cpu().numpy())
900
+ eval_ids.extend(file_ids)
901
+
902
+ if step % training_args.logging_steps == 0 and step > 0:
903
+ batches.write(f"Saving transcriptions for split {split} step {step}")
904
+ accelerator.wait_for_everyone()
905
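+ # decode only the predictions generated since the last save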
+ pred_ids = eval_preds[len(pred_str):]
906
+ pred_ids = filter_eot_tokens(pred_ids)
907
+ pred_str.extend(
908
+ tokenizer.batch_decode(pred_ids, skip_special_tokens=False, decode_with_timestamps=return_timestamps)
909
+ )
910
+ csv_data = [[eval_ids[i], pred_str[i]] for i in range(len(eval_preds))]
911
+
912
+ with open(output_csv, "w", encoding="UTF8", newline="") as f:
913
+ writer = csv.writer(f)
914
+ # write multiple rows
915
+ writer.writerow(["file_id", "whisper_transcript"])
916
+ writer.writerows(csv_data)
917
+
918
+ if training_args.push_to_hub and accelerator.is_main_process:
919
+ repo.push_to_hub(
920
+ commit_message=f"Saving transcriptions for split {split} step {step}.",
921
+ blocking=False,
922
+ )
923
+
924
+ accelerator.wait_for_everyone()
925
+ eval_time = time.time() - eval_start
926
+
927
+ # compute WER metric for eval sets
928
+ wer_desc = ""
929
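+ # only report WER for held-out splits; for training splits we just save the pseudo-labels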
+ if "validation" in split or "test" in split:
930
+ eval_preds = filter_eot_tokens(eval_preds)
931
+ wer_metric, pred_str, label_str, norm_pred_str, norm_label_str, eval_ids = compute_metrics(
932
+ eval_preds, eval_labels, eval_ids
933
+ )
934
+ wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
935
+ # Save metrics + predictions
936
+ log_metric(
937
+ accelerator,
938
+ metrics=wer_metric,
939
+ train_time=eval_time,
940
+ prefix=split,
941
+ )
942
+ log_pred(
943
+ accelerator,
944
+ pred_str,
945
+ label_str,
946
+ norm_pred_str,
947
+ norm_label_str,
948
+ prefix=split,
949
+ )
950
+ else:
951
+ pred_ids = eval_preds[len(pred_str):]
952
+ pred_ids = filter_eot_tokens(pred_ids)
953
+ pred_str.extend(
954
+ tokenizer.batch_decode(pred_ids, skip_special_tokens=False, decode_with_timestamps=return_timestamps)
955
+ )
956
+
957
+ batches.write(f"Saving final transcriptions for split {split}.")
958
+ csv_data = [[eval_ids[i], pred_str[i]] for i in range(len(eval_preds))]
959
+ with open(output_csv, "w", encoding="UTF8", newline="") as f:
960
+ writer = csv.writer(f)
961
+ # write multiple rows
962
+ writer.writerow(["file_id", "whisper_transcript"])
963
+ writer.writerows(csv_data)
964
+
965
+ # Print metrics
966
+ logger.info(wer_desc)
967
+
968
+ if not data_args.streaming and accelerator.is_main_process:
969
+ raw_datasets[split] = raw_datasets[split].add_column("whisper_transcript", pred_str)
970
+ raw_datasets[split] = raw_datasets[split].add_column("eval_preds", eval_preds)
971
+
972
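+ # build the prompt for each sample from the previous sample's prediction: strip EOT and the leading
+ # special tokens, then prefix with <|startofprev|>; samples whose speaker changed get no prompt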
+ def add_concatenated_text(eval_preds, condition_on_prev):
973
+ concatenated_prev = [None]
974
+ for token_ids, condition in zip(eval_preds[:-1], condition_on_prev[1:]):
975
+ if not condition:
976
+ concatenated_prev.append(None)
977
+ else:
978
+ prompt_ids = [token for token in token_ids if token != decoder_eot_token_id]
979
+ prompt_ids = [decoder_prev_token_id] + prompt_ids[timestamp_position:]
980
+ concatenated_prev.append(prompt_ids)
981
+ return {"condition_on_prev": concatenated_prev}
982
+
983
+ raw_datasets[split] = raw_datasets[split].map(
984
+ add_concatenated_text,
985
+ input_columns=["eval_preds", "condition_on_prev"],
986
+ remove_columns=["eval_preds"],
987
+ desc="Setting condition on prev...",
988
+ batched=True,
989
+ batch_size=preprocessing_batch_size,
990
+ num_proc=num_workers,
991
+ )
992
+
993
+ logger.info("***** Running Labelling *****")
994
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
995
+ logger.info(
996
+ f" Total eval batch size (w. parallel & distributed) = {training_args.per_device_eval_batch_size * accelerator.num_processes}"
997
+ )
998
+ logger.info(f" Predict labels with timestamps = {return_timestamps}")
999
+ for split in data_splits:
1000
+ eval_step_with_save(split=split)
1001
+ accelerator.wait_for_everyone()
1002
+ if training_args.push_to_hub and accelerator.is_main_process:
1003
+ repo.push_to_hub(
1004
+ commit_message=f"Saving final transcriptions for split {split.replace('.', '-').split('/')[-1]}",
1005
+ blocking=False,
1006
+ )
1007
+ if not data_args.streaming and accelerator.is_main_process:
1008
+ raw_datasets.save_to_disk(output_dir, num_proc=num_workers)
1009
+ if training_args.push_to_hub:
1010
+ raw_datasets.push_to_hub(repo_name, config_name=data_args.dataset_config_name)
1011
+ accelerator.end_training()
1012
+
1013
+
1014
+ if __name__ == "__main__":
1015
+ main()