sanchit-gandhi (HF staff) committed on
Commit
b804ac9
1 Parent(s): 29a6ac7

Saving train state of step 500

.gitignore ADDED
@@ -0,0 +1 @@
1
+ wandb
checkpoint-500-epoch-6/optimizer.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6dfcbb8d9658d1c23fcc4dd9edb3822ac5914e7ac1dbc60ed30e3cbd3b46a41f
3
+ size 3652763351
checkpoint-500-epoch-6/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c0bda92d762ceac4074c24935347a016d4b3eae0db01a4859aa671ec66edc49
3
+ size 2588462170
checkpoint-500-epoch-6/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08efa1fa4270ed24b2ddf0858b252a2544edc723bdada91c58f7ff0c17eb5406
3
+ size 14344
checkpoint-500-epoch-6/scheduler.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fc1d5d5011b8f5d8cf0fccf1801a3bbade375e254a51919d338cccdd54e3ed8
3
+ size 1000
config.json ADDED
@@ -0,0 +1,276 @@
1
+ {
2
+ "_name_or_path": "/fsx/yoach/tmp/artefacts/training-400M-punctuated-v2/",
3
+ "architectures": [
4
+ "ParlerTTSForConditionalGeneration"
5
+ ],
6
+ "audio_encoder": {
7
+ "_name_or_path": "ylacombe/dac_44khZ_8kbps",
8
+ "add_cross_attention": false,
9
+ "architectures": [
10
+ "DACModel"
11
+ ],
12
+ "bad_words_ids": null,
13
+ "begin_suppress_tokens": null,
14
+ "bos_token_id": null,
15
+ "chunk_size_feed_forward": 0,
16
+ "codebook_size": 1024,
17
+ "cross_attention_hidden_size": null,
18
+ "decoder_start_token_id": null,
19
+ "diversity_penalty": 0.0,
20
+ "do_sample": false,
21
+ "early_stopping": false,
22
+ "encoder_no_repeat_ngram_size": 0,
23
+ "eos_token_id": null,
24
+ "exponential_decay_length_penalty": null,
25
+ "finetuning_task": null,
26
+ "forced_bos_token_id": null,
27
+ "forced_eos_token_id": null,
28
+ "frame_rate": 86,
29
+ "id2label": {
30
+ "0": "LABEL_0",
31
+ "1": "LABEL_1"
32
+ },
33
+ "is_decoder": false,
34
+ "is_encoder_decoder": false,
35
+ "label2id": {
36
+ "LABEL_0": 0,
37
+ "LABEL_1": 1
38
+ },
39
+ "latent_dim": 1024,
40
+ "length_penalty": 1.0,
41
+ "max_length": 20,
42
+ "min_length": 0,
43
+ "model_bitrate": 8,
44
+ "model_type": "dac",
45
+ "no_repeat_ngram_size": 0,
46
+ "num_beam_groups": 1,
47
+ "num_beams": 1,
48
+ "num_codebooks": 9,
49
+ "num_return_sequences": 1,
50
+ "output_attentions": false,
51
+ "output_hidden_states": false,
52
+ "output_scores": false,
53
+ "pad_token_id": null,
54
+ "prefix": null,
55
+ "problem_type": null,
56
+ "pruned_heads": {},
57
+ "remove_invalid_values": false,
58
+ "repetition_penalty": 1.0,
59
+ "return_dict": true,
60
+ "return_dict_in_generate": false,
61
+ "sampling_rate": 44100,
62
+ "sep_token_id": null,
63
+ "suppress_tokens": null,
64
+ "task_specific_params": null,
65
+ "temperature": 1.0,
66
+ "tf_legacy_loss": false,
67
+ "tie_encoder_decoder": false,
68
+ "tie_word_embeddings": true,
69
+ "tokenizer_class": null,
70
+ "top_k": 50,
71
+ "top_p": 1.0,
72
+ "torch_dtype": "float32",
73
+ "torchscript": false,
74
+ "typical_p": 1.0,
75
+ "use_bfloat16": false
76
+ },
77
+ "decoder": {
78
+ "_name_or_path": "/fsx/yoach/tmp/artefacts/decoder_400M/",
79
+ "activation_dropout": 0.0,
80
+ "activation_function": "gelu",
81
+ "add_cross_attention": true,
82
+ "architectures": [
83
+ "ParlerTTSForCausalLM"
84
+ ],
85
+ "attention_dropout": 0.0,
86
+ "bad_words_ids": null,
87
+ "begin_suppress_tokens": null,
88
+ "bos_token_id": 1025,
89
+ "chunk_size_feed_forward": 0,
90
+ "cross_attention_hidden_size": null,
91
+ "decoder_start_token_id": null,
92
+ "diversity_penalty": 0.0,
93
+ "do_sample": false,
94
+ "dropout": 0.1,
95
+ "early_stopping": false,
96
+ "encoder_no_repeat_ngram_size": 0,
97
+ "eos_token_id": 1024,
98
+ "exponential_decay_length_penalty": null,
99
+ "ffn_dim": 4096,
100
+ "finetuning_task": null,
101
+ "forced_bos_token_id": null,
102
+ "forced_eos_token_id": null,
103
+ "hidden_size": 1024,
104
+ "id2label": {
105
+ "0": "LABEL_0",
106
+ "1": "LABEL_1"
107
+ },
108
+ "initializer_factor": 0.02,
109
+ "is_decoder": true,
110
+ "is_encoder_decoder": false,
111
+ "label2id": {
112
+ "LABEL_0": 0,
113
+ "LABEL_1": 1
114
+ },
115
+ "layerdrop": 0.0,
116
+ "length_penalty": 1.0,
117
+ "max_length": 20,
118
+ "max_position_embeddings": 4096,
119
+ "min_length": 0,
120
+ "model_type": "parler_tts_decoder",
121
+ "no_repeat_ngram_size": 0,
122
+ "num_attention_heads": 16,
123
+ "num_beam_groups": 1,
124
+ "num_beams": 1,
125
+ "num_codebooks": 9,
126
+ "num_hidden_layers": 24,
127
+ "num_return_sequences": 1,
128
+ "output_attentions": false,
129
+ "output_hidden_states": false,
130
+ "output_scores": false,
131
+ "pad_token_id": 1024,
132
+ "prefix": null,
133
+ "problem_type": null,
134
+ "pruned_heads": {},
135
+ "remove_invalid_values": false,
136
+ "repetition_penalty": 1.0,
137
+ "return_dict": true,
138
+ "return_dict_in_generate": false,
139
+ "scale_embedding": false,
140
+ "sep_token_id": null,
141
+ "suppress_tokens": null,
142
+ "task_specific_params": null,
143
+ "temperature": 1.0,
144
+ "tf_legacy_loss": false,
145
+ "tie_encoder_decoder": false,
146
+ "tie_word_embeddings": false,
147
+ "tokenizer_class": null,
148
+ "top_k": 50,
149
+ "top_p": 1.0,
150
+ "torch_dtype": "float32",
151
+ "torchscript": false,
152
+ "typical_p": 1.0,
153
+ "use_bfloat16": false,
154
+ "use_cache": true,
155
+ "vocab_size": 1088
156
+ },
157
+ "decoder_start_token_id": 1025,
158
+ "is_encoder_decoder": true,
159
+ "model_type": "parler_tts",
160
+ "pad_token_id": 1024,
161
+ "text_encoder": {
162
+ "_name_or_path": "google/flan-t5-base",
163
+ "add_cross_attention": false,
164
+ "architectures": [
165
+ "T5ForConditionalGeneration"
166
+ ],
167
+ "bad_words_ids": null,
168
+ "begin_suppress_tokens": null,
169
+ "bos_token_id": null,
170
+ "chunk_size_feed_forward": 0,
171
+ "classifier_dropout": 0.0,
172
+ "cross_attention_hidden_size": null,
173
+ "d_ff": 2048,
174
+ "d_kv": 64,
175
+ "d_model": 768,
176
+ "decoder_start_token_id": 0,
177
+ "dense_act_fn": "gelu_new",
178
+ "diversity_penalty": 0.0,
179
+ "do_sample": false,
180
+ "dropout_rate": 0.1,
181
+ "early_stopping": false,
182
+ "encoder_no_repeat_ngram_size": 0,
183
+ "eos_token_id": 1,
184
+ "exponential_decay_length_penalty": null,
185
+ "feed_forward_proj": "gated-gelu",
186
+ "finetuning_task": null,
187
+ "forced_bos_token_id": null,
188
+ "forced_eos_token_id": null,
189
+ "id2label": {
190
+ "0": "LABEL_0",
191
+ "1": "LABEL_1"
192
+ },
193
+ "initializer_factor": 1.0,
194
+ "is_decoder": false,
195
+ "is_encoder_decoder": true,
196
+ "is_gated_act": true,
197
+ "label2id": {
198
+ "LABEL_0": 0,
199
+ "LABEL_1": 1
200
+ },
201
+ "layer_norm_epsilon": 1e-06,
202
+ "length_penalty": 1.0,
203
+ "max_length": 20,
204
+ "min_length": 0,
205
+ "model_type": "t5",
206
+ "n_positions": 512,
207
+ "no_repeat_ngram_size": 0,
208
+ "num_beam_groups": 1,
209
+ "num_beams": 1,
210
+ "num_decoder_layers": 12,
211
+ "num_heads": 12,
212
+ "num_layers": 12,
213
+ "num_return_sequences": 1,
214
+ "output_attentions": false,
215
+ "output_hidden_states": false,
216
+ "output_past": true,
217
+ "output_scores": false,
218
+ "pad_token_id": 0,
219
+ "prefix": null,
220
+ "problem_type": null,
221
+ "pruned_heads": {},
222
+ "relative_attention_max_distance": 128,
223
+ "relative_attention_num_buckets": 32,
224
+ "remove_invalid_values": false,
225
+ "repetition_penalty": 1.0,
226
+ "return_dict": true,
227
+ "return_dict_in_generate": false,
228
+ "sep_token_id": null,
229
+ "suppress_tokens": null,
230
+ "task_specific_params": {
231
+ "summarization": {
232
+ "early_stopping": true,
233
+ "length_penalty": 2.0,
234
+ "max_length": 200,
235
+ "min_length": 30,
236
+ "no_repeat_ngram_size": 3,
237
+ "num_beams": 4,
238
+ "prefix": "summarize: "
239
+ },
240
+ "translation_en_to_de": {
241
+ "early_stopping": true,
242
+ "max_length": 300,
243
+ "num_beams": 4,
244
+ "prefix": "translate English to German: "
245
+ },
246
+ "translation_en_to_fr": {
247
+ "early_stopping": true,
248
+ "max_length": 300,
249
+ "num_beams": 4,
250
+ "prefix": "translate English to French: "
251
+ },
252
+ "translation_en_to_ro": {
253
+ "early_stopping": true,
254
+ "max_length": 300,
255
+ "num_beams": 4,
256
+ "prefix": "translate English to Romanian: "
257
+ }
258
+ },
259
+ "temperature": 1.0,
260
+ "tf_legacy_loss": false,
261
+ "tie_encoder_decoder": false,
262
+ "tie_word_embeddings": false,
263
+ "tokenizer_class": null,
264
+ "top_k": 50,
265
+ "top_p": 1.0,
266
+ "torch_dtype": null,
267
+ "torchscript": false,
268
+ "typical_p": 1.0,
269
+ "use_bfloat16": false,
270
+ "use_cache": true,
271
+ "vocab_size": 32128
272
+ },
273
+ "torch_dtype": "float32",
274
+ "transformers_version": "4.41.0.dev0",
275
+ "vocab_size": 32128
276
+ }
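The composite config above nests three sub-configs: the DAC audio encoder, the Parler-TTS decoder and the Flan-T5 text encoder. A minimal sketch of how the training script below reads it, assuming the `parler_tts` package is installed and using a placeholder path for this repository:

from parler_tts import ParlerTTSConfig

config = ParlerTTSConfig.from_pretrained("path/to/this/repo")  # placeholder path, not the real repo id
print(config.decoder.num_codebooks)        # 9, one stream per DAC codebook
print(config.audio_encoder.sampling_rate)  # 44100
print(config.decoder.pad_token_id)         # 1024, also used as the audio padding token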
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "chunk_length_s": null,
3
+ "feature_extractor_type": "EncodecFeatureExtractor",
4
+ "feature_size": 1,
5
+ "overlap": null,
6
+ "padding_side": "right",
7
+ "padding_value": 0.0,
8
+ "return_attention_mask": true,
9
+ "sampling_rate": 44100
10
+ }
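This preprocessor config is read through `AutoFeatureExtractor`, mirroring how the training script below loads the DAC feature extractor; a minimal sketch with a placeholder path:

from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("path/to/this/repo")  # placeholder path
print(feature_extractor.sampling_rate)  # 44100, used to resample the target audio column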
run.sh ADDED
@@ -0,0 +1,51 @@
1
+ #!/usr/bin/env bash
2
+
3
+ python run_parler_tts_training.py \
4
+ --model_name_or_path parler-tts/parler_tts_mini_v0.1 \
5
+ --feature_extractor_name parler-tts/dac_44khZ_8kbps \
6
+ --description_tokenizer_name parler-tts/parler_tts_mini_v0.1 \
7
+ --prompt_tokenizer_name parler-tts/parler_tts_mini_v0.1 \
8
+ --report_to wandb \
9
+ --overwrite_output_dir true \
10
+ --train_dataset_name reach-vb/expresso-tagged-mistral-7b-instruct-v0.2 \
11
+ --train_metadata_dataset_name reach-vb/expresso-tagged-mistral-7b-instruct-v0.2 \
12
+ --train_dataset_config_name read \
13
+ --train_split_name train \
14
+ --eval_dataset_name reach-vb/expresso-tagged-mistral-7b-instruct-v0.2 \
15
+ --eval_metadata_dataset_name reach-vb/expresso-tagged-mistral-7b-instruct-v0.2 \
16
+ --eval_dataset_config_name read \
17
+ --eval_split_name train \
18
+ --max_eval_samples 8 \
19
+ --per_device_eval_batch_size 16 \
20
+ --target_audio_column_name audio \
21
+ --description_column_name text_description \
22
+ --prompt_column_name text \
23
+ --max_duration_in_seconds 20 \
24
+ --min_duration_in_seconds 2.0 \
25
+ --max_text_length 400 \
26
+ --preprocessing_num_workers 2 \
27
+ --do_train true \
28
+ --num_train_epochs 10 \
29
+ --gradient_accumulation_steps 4 \
30
+ --gradient_checkpointing true \
31
+ --per_device_train_batch_size 32 \
32
+ --learning_rate 3e-5 \
33
+ --adam_beta1 0.9 \
34
+ --adam_beta2 0.99 \
35
+ --weight_decay 0.01 \
36
+ --warmup_steps 100 \
37
+ --logging_steps 2 \
38
+ --freeze_text_encoder true \
39
+ --audio_encoder_per_device_batch_size 4 \
40
+ --dtype bfloat16 \
41
+ --seed 456 \
42
+ --output_dir ./ \
43
+ --temporary_save_to_disk ../audio_code_tmp/ \
44
+ --save_to_disk ../tmp_dataset_audio/ \
45
+ --dataloader_num_workers 4 \
46
+ --do_eval \
47
+ --predict_with_generate \
48
+ --include_inputs_for_metrics \
49
+ --group_by_length true \
50
+ --push_to_hub
51
+
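run.sh launches a single training process with `--output_dir ./`, so re-running the same command after this commit resumes from the newest `checkpoint-<step>-epoch-<epoch>` folder (see `get_last_checkpoint` in the script below). Once final weights are pushed, inference could look like the following sketch; the generation call is assumed from the `parler_tts` API imported by the script, and the model path is a placeholder:

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

model = ParlerTTSForConditionalGeneration.from_pretrained("path/to/this/repo")  # placeholder path
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

description = "A male speaker reads in a calm, happy tone."  # illustrative Expresso-style description
prompt = "Remember to save the training state every five hundred steps."  # illustrative transcript

input_ids = tokenizer(description, return_tensors="pt").input_ids
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    # generate() returns the decoded waveform at the audio encoder's 44.1 kHz sampling rate
    audio = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)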
run_parler_tts_training.py ADDED
@@ -0,0 +1,1760 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """ Train Parler-TTS using 🤗 Accelerate"""
18
+
19
+ import logging
20
+ import os
21
+ import re
22
+ import sys
23
+ import shutil
24
+ import time
25
+ from multiprocess import set_start_method
26
+ from datetime import timedelta
27
+
28
+
29
+ import evaluate
30
+ from tqdm import tqdm
31
+ from pathlib import Path
32
+ from dataclasses import dataclass, field
33
+ from typing import Dict, List, Optional, Union, Set
34
+
35
+ import datasets
36
+ import numpy as np
37
+ import torch
38
+ from torch.utils.data import DataLoader
39
+
40
+ from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
41
+
42
+ from huggingface_hub import Repository, create_repo
43
+ import transformers
44
+ from transformers import (
45
+ AutoFeatureExtractor,
46
+ AutoModel,
47
+ AutoProcessor,
48
+ AutoTokenizer,
49
+ HfArgumentParser,
50
+ Seq2SeqTrainingArguments,
51
+ )
52
+ from transformers.trainer_pt_utils import LengthGroupedSampler
53
+ from transformers import pipeline
54
+ from transformers.optimization import get_scheduler
55
+ from transformers.utils import send_example_telemetry
56
+ from transformers import AutoModel
57
+
58
+
59
+ from accelerate import Accelerator
60
+ from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
61
+ from accelerate.utils.memory import release_memory
62
+
63
+ from parler_tts import (
64
+ ParlerTTSForConditionalGeneration,
65
+ ParlerTTSConfig,
66
+ build_delay_pattern_mask,
67
+ )
68
+
69
+ from wandb import Audio
70
+
71
+
72
+ logger = logging.getLogger(__name__)
73
+
74
+
75
+ def list_field(default=None, metadata=None):
76
+ return field(default_factory=lambda: default, metadata=metadata)
77
+
78
+
79
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
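+ # checkpoint folders are named "checkpoint-<step>-epoch-<epoch>", e.g. the checkpoint-500-epoch-6 folder saved in this commit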
80
+
81
+
82
+ def get_last_checkpoint(folder):
83
+ content = os.listdir(folder)
84
+ checkpoints = [
85
+ path
86
+ for path in content
87
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
88
+ ]
89
+ if len(checkpoints) == 0:
90
+ return
91
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
92
+
93
+
94
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
95
+ """Helper function to sort saved checkpoints from oldest to newest."""
96
+ ordering_and_checkpoint_path = []
97
+
98
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
99
+
100
+ for path in glob_checkpoints:
101
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
102
+ if regex_match is not None and regex_match.groups() is not None:
103
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
104
+
105
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
106
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
107
+ return checkpoints_sorted
108
+
109
+
110
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
111
+ """Helper function to delete old checkpoints."""
112
+ if save_total_limit is None or save_total_limit <= 0:
113
+ return
114
+ # Check if we should delete older checkpoint(s)
115
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
116
+ if len(checkpoints_sorted) <= save_total_limit:
117
+ return
118
+
119
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
120
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
121
+ for checkpoint in checkpoints_to_be_deleted:
122
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
123
+ shutil.rmtree(checkpoint, ignore_errors=True)
124
+
125
+
126
+ def log_metric(
127
+ accelerator,
128
+ metrics: Dict,
129
+ train_time: float,
130
+ step: int,
131
+ epoch: int,
132
+ learning_rate: float = None,
133
+ prefix: str = "train",
134
+ ):
135
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
136
+ log_metrics = {}
137
+ for k, v in metrics.items():
138
+ log_metrics[f"{prefix}/{k}"] = v
139
+ log_metrics[f"{prefix}/time"] = train_time
140
+ log_metrics[f"{prefix}/epoch"] = epoch
141
+ if learning_rate is not None:
142
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
143
+ accelerator.log(log_metrics, step=step)
144
+
145
+
146
+ def log_pred(
147
+ accelerator,
148
+ pred_descriptions: List[str],
149
+ pred_prompts: List[str],
150
+ transcriptions: List[str],
151
+ audios: List[torch.Tensor],
152
+ sampling_rate: int,
153
+ step: int,
154
+ prefix: str = "eval",
155
+ num_lines: int = 200000,
156
+ ):
157
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
158
+ if accelerator.is_main_process:
159
+ wandb_tracker = accelerator.get_tracker("wandb")
160
+ # pretty name for current step: step 50000 -> step 50k
161
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
162
+ prefix_pretty = prefix.replace("/", "-")
163
+
164
+ # convert str data to a wandb compatible format
165
+ str_data = [[pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))]
166
+ # log as a table with the appropriate headers
167
+ wandb_tracker.log_table(
168
+ table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
169
+ columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
170
+ data=str_data[:num_lines],
171
+ step=step,
172
+ commit=False,
173
+ )
174
+
175
+ # wandb can only log 100 audios per step
176
+ wandb_tracker.log(
177
+ {
178
+ "Speech samples": [
179
+ Audio(
180
+ audio,
181
+ caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
182
+ sample_rate=sampling_rate,
183
+ )
184
+ for (i, audio) in enumerate(audios[: min(len(audios), 100)])
185
+ ]
186
+ },
187
+ step=step,
188
+ )
189
+
190
+
191
+ @dataclass
192
+ class ModelArguments:
193
+ """
194
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
195
+ """
196
+
197
+ model_name_or_path: str = field(
198
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
199
+ )
200
+ config_name: Optional[str] = field(
201
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
202
+ )
203
+ feature_extractor_name: Optional[str] = field(
204
+ default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
205
+ )
206
+ description_tokenizer_name: Optional[str] = field(
207
+ default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
208
+ )
209
+ prompt_tokenizer_name: Optional[str] = field(
210
+ default=None,
211
+ metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
212
+ )
213
+ cache_dir: Optional[str] = field(
214
+ default=None,
215
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
216
+ )
217
+ use_fast_tokenizer: bool = field(
218
+ default=True,
219
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
220
+ )
221
+ model_revision: str = field(
222
+ default="main",
223
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
224
+ )
225
+ pad_token_id: int = field(
226
+ default=None,
227
+ metadata={"help": "If specified, change the model pad token id."},
228
+ )
229
+ decoder_start_token_id: int = field(
230
+ default=None,
231
+ metadata={"help": "If specified, change the model decoder start token id."},
232
+ )
233
+ freeze_text_encoder: bool = field(
234
+ default=False,
235
+ metadata={"help": "Whether to freeze the text encoder."},
236
+ )
237
+ do_sample: bool = field(
238
+ default=True,
239
+ metadata={"help": "Whether to do sampling or greedy decoding."},
240
+ )
241
+ temperature: float = field(
242
+ default=1.0,
243
+ metadata={"help": "Temperature if sampling."},
244
+ )
245
+ max_length: int = field(
246
+ default=2580,
247
+ metadata={"help": "Generation max length."},
248
+ )
249
+ bandwidth: float = field(
250
+ default=6,
251
+ metadata={"help": "Audio encoder bandwidth."},
252
+ )
253
+ asr_model_name_or_path: str = field(
254
+ default="distil-whisper/distil-large-v2",
255
+ metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
256
+ )
257
+ clap_model_name_or_path: str = field(
258
+ default="laion/larger_clap_music_and_speech",
259
+ metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
260
+ )
261
+
262
+
263
+
264
+ @dataclass
265
+ class DataTrainingArguments:
266
+ """
267
+ Arguments pertaining to what data we are going to input our model for training and eval.
268
+
269
+ Using `HfArgumentParser` we can turn this class
270
+ into argparse arguments to be able to specify them on
271
+ the command line.
272
+ """
273
+
274
+ train_dataset_name: str = field(
275
+ default=None,
276
+ metadata={
277
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
278
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
279
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
280
+ },
281
+ )
282
+ train_dataset_config_name: Optional[str] = field(
283
+ default=None,
284
+ metadata={
285
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
286
+ "multiple datasets by separating dataset configs by a '+' symbol."
287
+ },
288
+ )
289
+ train_split_name: str = field(
290
+ default="train",
291
+ metadata={
292
+ "help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
293
+ },
294
+ )
295
+ train_dataset_samples: str = field(
296
+ default=None,
297
+ metadata={
298
+ "help": "Number of samples in the training data. Load and combine "
299
+ "multiple datasets by separating dataset samples by a '+' symbol."
300
+ },
301
+ )
302
+ train_metadata_dataset_name: str = field(
303
+ default=None,
304
+ metadata={
305
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
306
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
307
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
308
+ },
309
+ )
310
+ eval_dataset_name: str = field(
311
+ default=None,
312
+ metadata={
313
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
314
+ },
315
+ )
316
+ eval_dataset_config_name: Optional[str] = field(
317
+ default=None,
318
+ metadata={
319
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified"
320
+ },
321
+ )
322
+ eval_split_name: str = field(
323
+ default="test",
324
+ metadata={
325
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
326
+ },
327
+ )
328
+ eval_metadata_dataset_name: str = field(
329
+ default=None,
330
+ metadata={
331
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
332
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
333
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
334
+ },
335
+ )
336
+ target_audio_column_name: str = field(
337
+ default="audio",
338
+ metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
339
+ )
340
+ description_column_name: str = field(
341
+ default=None,
342
+ metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
343
+ )
344
+ prompt_column_name: str = field(
345
+ default=None,
346
+ metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
347
+ )
348
+ overwrite_cache: bool = field(
349
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
350
+ )
351
+ preprocessing_num_workers: Optional[int] = field(
352
+ default=None,
353
+ metadata={"help": "The number of processes to use for the preprocessing."},
354
+ )
355
+ max_train_samples: Optional[int] = field(
356
+ default=None,
357
+ metadata={
358
+ "help": (
359
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
360
+ "value if set."
361
+ )
362
+ },
363
+ )
364
+ max_eval_samples: Optional[int] = field(
365
+ default=None,
366
+ metadata={
367
+ "help": (
368
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
369
+ "value if set."
370
+ )
371
+ },
372
+ )
373
+ max_duration_in_seconds: float = field(
374
+ default=35.0,
375
+ metadata={
376
+ "help": (
377
+ "Filter out audio files that are longer than `max_duration_in_seconds` seconds. "
378
+ "Also, used to set maximum audio length if `pad_to_max_length=True`."
379
+ )
380
+ },
381
+ )
382
+ min_duration_in_seconds: float = field(
383
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
384
+ )
385
+ max_text_length: int = field(
386
+ default=500, metadata={"help": "If set, the maximum description length in characters."}
387
+ )
388
+ max_prompt_token_length: int = field(
389
+ default=None,
390
+ metadata={
391
+ "help": (
392
+ "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
393
+ "Also, used to set maximum prompt token length if `pad_to_max_length=True`."
394
+ )
395
+ },
396
+ )
397
+ max_description_token_length: int = field(
398
+ default=None,
399
+ metadata={
400
+ "help": (
401
+ "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
402
+ "Also, used to set maximum description token length if `pad_to_max_length=True`."
403
+ )
404
+ },
405
+ )
406
+ pad_to_max_length: bool = field(
407
+ default=False,
408
+ metadata={
409
+ "help": (
410
+ "If `True`, pad audio, prompt and description to a maximum length set with respectively "
411
+ "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
412
+ )
413
+ },
414
+ )
415
+ preprocessing_only: bool = field(
416
+ default=False,
417
+ metadata={
418
+ "help": (
419
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
420
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
421
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
422
+ " can consequently be loaded in distributed training."
423
+ " In this training script, `save_to_disk` must be set to the path in which the dataset should be saved. "
424
+ )
425
+ },
426
+ )
427
+ token: str = field(
428
+ default=None,
429
+ metadata={
430
+ "help": (
431
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
432
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
433
+ )
434
+ },
435
+ )
436
+ use_auth_token: bool = field(
437
+ default=None,
438
+ metadata={
439
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
440
+ },
441
+ )
442
+ trust_remote_code: bool = field(
443
+ default=False,
444
+ metadata={
445
+ "help": (
446
+ "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
447
+ "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
448
+ "execute code present on the Hub on your local machine."
449
+ )
450
+ },
451
+ )
452
+ add_audio_samples_to_wandb: bool = field(
453
+ default=False,
454
+ metadata={"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."},
455
+ )
456
+ id_column_name: str = field(default=None, metadata={"help": "id column name."})
457
+ wandb_project: str = field(
458
+ default="parler-speech",
459
+ metadata={"help": "The name of the wandb project."},
460
+ )
461
+ save_to_disk: str = field(
462
+ default=None,
463
+ metadata={
464
+ "help": "If set, will save the dataset to this path if this is an empty folder. If not empty, will load the datasets from it."
465
+ },
466
+ )
467
+ temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
468
+ pad_to_multiple_of: Optional[int] = field(
469
+ default=2,
470
+ metadata={"help": ("Pad tokenized sequences to a multiple of this value.")},
471
+ )
472
+
473
+
474
+ @dataclass
475
+ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
476
+ dtype: Optional[str] = field(
477
+ default="float32",
478
+ metadata={
479
+ "help": (
480
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
481
+ "`float16` or `bfloat16` (both half-precision)."
482
+ )
483
+ },
484
+ )
485
+ audio_encoder_per_device_batch_size: int = field(
486
+ default=8,
487
+ metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
488
+ )
489
+
490
+
491
+ @dataclass
492
+ class DataCollatorEncodecWithPadding:
493
+ """
494
+ Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
495
+ to `max_length` if `max_length` is set and `padding=max_length`.
496
+ """
497
+
498
+ feature_extractor: AutoFeatureExtractor
499
+ audio_column_name: str
500
+ feature_extractor_input_name: Optional[str] = "input_values"
501
+ max_length: Optional[int] = None
502
+ padding: Optional[str] = "longest"
503
+
504
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
505
+ # split inputs and labels since they have to be of different lengths and need
506
+ # different padding methods
507
+ audios = [feature[self.audio_column_name]["array"] for feature in features]
508
+ len_audio = [len(audio) for audio in audios]
509
+
510
+ batch = self.feature_extractor(audios, return_tensors="pt", padding=self.padding, max_length=self.max_length)
511
+ batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
512
+ return batch
513
+
514
+
515
+ @dataclass
516
+ class DataCollatorParlerTTSWithPadding:
517
+ """
518
+ Data collator that will dynamically pad the inputs received.
519
+ Args:
520
+ prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
521
+ The prompt_tokenizer used for processing the data.
522
+ description_tokenizer (:class:`~transformers.AutoTokenizer`)
523
+ The description_tokenizer used for processing the data.
524
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
525
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
526
+ among:
527
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
528
+ sequence is provided).
529
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
530
+ maximum acceptable input length for the model if that argument is not provided.
531
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
532
+ different lengths).
533
+ pad_to_multiple_of (:obj:`int`, `optional`):
534
+ If set will pad the sequence to a multiple of the provided value.
535
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
536
+ 7.5 (Volta).
537
+ """
538
+
539
+ prompt_tokenizer: AutoTokenizer
540
+ description_tokenizer: AutoTokenizer
541
+ padding: Union[bool, str] = "longest"
542
+ pad_to_multiple_of: Optional[int] = None
543
+ prompt_max_length: Optional[int] = None
544
+ description_max_length: Optional[int] = None
545
+ audio_max_length: Optional[int] = None
546
+
547
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
548
+ # split inputs and labels since they have to be of different lengths and need
549
+ # different padding methods
550
+
551
+ labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
552
+ # (bsz, seq_len, num_codebooks)
553
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
554
+ if self.audio_max_length is not None and self.padding == "max_length":
555
+ labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)))
556
+
557
+ input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
558
+
559
+ input_ids = self.description_tokenizer.pad(
560
+ input_ids,
561
+ return_tensors="pt",
562
+ padding=self.padding,
563
+ pad_to_multiple_of=self.pad_to_multiple_of,
564
+ max_length=self.description_max_length,
565
+ )
566
+
567
+ batch = {"labels": labels, **input_ids}
568
+
569
+ if self.audio_max_length is not None and self.padding == "max_length":
570
+ # if we do torch.compile, we need to also specify the attention_mask
571
+ decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
572
+ batch["decoder_attention_mask"] = decoder_attention_mask
573
+
574
+ prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
575
+ prompt_input_ids = self.prompt_tokenizer.pad(
576
+ prompt_input_ids,
577
+ return_tensors="pt",
578
+ padding=self.padding,
579
+ pad_to_multiple_of=self.pad_to_multiple_of,
580
+ max_length=self.prompt_max_length,
581
+ )
582
+
583
+ batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
584
+ if "attention_mask" in prompt_input_ids:
585
+ batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
586
+
587
+ return batch
588
+
589
+
590
+ def convert_dataset_str_to_list(
591
+ dataset_names,
592
+ dataset_config_names,
593
+ metadata_dataset_names=None,
594
+ splits=None,
595
+ dataset_samples=None,
596
+ default_split="train",
597
+ ):
598
+ if isinstance(dataset_names, str):
599
+ dataset_names = dataset_names.split("+")
600
+ dataset_config_names = dataset_config_names.split("+")
601
+ splits = splits.split("+") if splits is not None else None
602
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
603
+ metadata_dataset_names = metadata_dataset_names.split("+") if metadata_dataset_names is not None else None
604
+
605
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
606
+ if len(dataset_names) != len(dataset_config_names):
607
+ raise ValueError(
608
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
609
+ f" {len(dataset_config_names)} configs."
610
+ )
611
+
612
+ if splits is not None and len(splits) != len(dataset_names):
613
+ raise ValueError(
614
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
615
+ )
616
+
617
+ if metadata_dataset_names is not None and len(metadata_dataset_names) != len(dataset_names):
618
+ raise ValueError(
619
+ f"Ensure one metadata dataset is passed for each dataset, got {len(dataset_names)} datasets and {len(metadata_dataset_names)} metadata datasets."
620
+ )
621
+
622
+ if dataset_samples is not None:
623
+ if len(dataset_samples) != len(dataset_names):
624
+ raise ValueError(
625
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
626
+ f"{len(dataset_samples)} samples."
627
+ )
628
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
629
+ else:
630
+ dataset_samples = [None] * len(dataset_names)
631
+
632
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
633
+
634
+ dataset_names_dict = []
635
+ for i, ds_name in enumerate(dataset_names):
636
+ dataset_names_dict.append(
637
+ {
638
+ "name": ds_name,
639
+ "config": dataset_config_names[i],
640
+ "split": splits[i],
641
+ "metadata_dataset_name": metadata_dataset_names[i],
642
+ "samples": dataset_samples[i],
643
+ }
644
+ )
645
+ return dataset_names_dict
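+ # e.g. dataset_names="libritts+mls_eng" with dataset_config_names="clean+default" (hypothetical names)
+ # yields one dict per corpus with keys: "name", "config", "split", "metadata_dataset_name", "samples"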
646
+
647
+
648
+ def load_multiple_datasets(
649
+ accelerator: Accelerator,
650
+ dataset_names: Union[List, str],
651
+ dataset_config_names: Union[List, str],
652
+ metadata_dataset_names: Optional[str] = None,
653
+ splits: Optional[Union[List, str]] = None,
654
+ label_column_names: Optional[List] = None,
655
+ stopping_strategy: Optional[str] = "first_exhausted",
656
+ dataset_samples: Optional[Union[List, np.array]] = None,
657
+ streaming: Optional[bool] = False,
658
+ seed: Optional[int] = None,
659
+ id_column_name: Optional[str] = None,
660
+ columns_to_keep: Optional[Set[str]] = None,
661
+ prompt_column_name: Optional[str] = None,
662
+ sampling_rate: Optional[int] = None,
663
+ audio_column_name: Optional[str] = None,
664
+ **kwargs,
665
+ ) -> Union[Dataset, IterableDataset]:
666
+ dataset_names_dict = convert_dataset_str_to_list(
667
+ dataset_names, dataset_config_names, metadata_dataset_names, splits, label_column_names, dataset_samples
668
+ )
669
+
670
+ if dataset_samples is not None:
671
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
672
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
673
+ else:
674
+ probabilities = None
675
+
676
+ all_datasets = []
677
+ # iterate over the datasets we want to interleave
678
+ for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
679
+ with accelerator.main_process_first():
680
+ dataset = load_dataset(
681
+ dataset_dict["name"],
682
+ dataset_dict["config"],
683
+ split=dataset_dict["split"],
684
+ streaming=streaming,
685
+ **kwargs,
686
+ )
687
+ dataset_features = dataset.features.keys()
688
+
689
+ if sampling_rate is not None and audio_column_name is not None:
690
+ # resample target audio
691
+ dataset = dataset.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))
692
+
693
+ metadata_dataset_name = dataset_dict["metadata_dataset_name"]
694
+ if metadata_dataset_name is not None:
695
+ logger.info(
696
+ f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}'
697
+ )
698
+ metadata_dataset = load_dataset(
699
+ metadata_dataset_name,
700
+ dataset_dict["config"],
701
+ split=dataset_dict["split"],
702
+ streaming=streaming,
703
+ **kwargs,
704
+ )
705
+
706
+ # TODO(YL): I forgot to create unique ids for MLS english.
707
+ # To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
708
+ # if dataset_dict["name"] == "parler-tts/mls_eng_10k":
709
+ # def concat_ids(book_id, speaker_id, begin_time):
710
+ # return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
711
+ # dataset = dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
712
+ # metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
713
+ # metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
714
+
715
+ if dataset_dict["name"] != "parler-tts/mls_eng_10k":
716
+ if id_column_name is not None and id_column_name not in dataset.column_names:
717
+ raise ValueError(
718
+ f"id_column_name={id_column_name} but has not been found in the dataset columns"
719
+ f"- one of {', '.join(list(dataset.column_names))}."
720
+ )
721
+ if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
722
+ raise ValueError(
723
+ f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
724
+ f"- one of {', '.join(list(metadata_dataset.column_names))}."
725
+ )
726
+ elif id_column_name is not None:
727
+ metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
728
+
729
+ metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
730
+
731
+ if prompt_column_name is not None:
732
+ # We might have applied some transformations to the prompts (e.g punctuation restoration)
733
+ # so we make sure to remove it from the original dataset
734
+ if prompt_column_name in dataset.column_names:
735
+ logger.info(
736
+ f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - {dataset_dict['split']}"
737
+ )
738
+ dataset = dataset.remove_columns(prompt_column_name)
739
+
740
+ metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
741
+ metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
742
+
743
+ dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
744
+
745
+ if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
746
+ if (
747
+ len(
748
+ dataset.filter(
749
+ lambda id1, id2: id1 != id2,
750
+ input_columns=[id_column_name, f"metadata_{id_column_name}"],
751
+ )
752
+ )
753
+ != 0
754
+ ):
755
+ raise ValueError(
756
+ f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}"
757
+ )
758
+
759
+ dataset_features = dataset.features.keys()
760
+
761
+ if columns_to_keep is not None:
762
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
763
+ all_datasets.append(dataset)
764
+
765
+ if len(all_datasets) == 1:
766
+ # we have a single dataset so just return it as is
767
+ return all_datasets[0]
768
+
769
+ if streaming:
770
+ interleaved_dataset = interleave_datasets(
771
+ all_datasets,
772
+ stopping_strategy=stopping_strategy,
773
+ probabilities=probabilities,
774
+ seed=seed,
775
+ )
776
+ else:
777
+ with accelerator.main_process_first():
778
+ interleaved_dataset = concatenate_datasets(all_datasets)
779
+
780
+ return interleaved_dataset
781
+
782
+
783
+ def main():
784
+ # See all possible arguments in src/transformers/training_args.py
785
+ # or by passing the --help flag to this script.
786
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
787
+
788
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
789
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
790
+ # If we pass only one argument to the script and it's the path to a json file,
791
+ # let's parse it to get our arguments.
792
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
793
+ else:
794
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
795
+
796
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
797
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
798
+ send_example_telemetry("run_parler_tts", model_args, data_args)
799
+
800
+ if training_args.dtype == "float16":
801
+ mixed_precision = "fp16"
802
+ elif training_args.dtype == "bfloat16":
803
+ mixed_precision = "bf16"
804
+ else:
805
+ mixed_precision = "no"
806
+
807
+ if data_args.pad_to_max_length and (
808
+ data_args.max_duration_in_seconds is None
809
+ or data_args.max_prompt_token_length is None
810
+ or data_args.max_description_token_length is None
811
+ ):
812
+ raise ValueError(
813
+ "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
814
+ )
815
+
816
+ padding = "max_length" if data_args.pad_to_max_length else "longest"
817
+
818
+ ####### A. Preparation
819
+ kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
820
+ if training_args.torch_compile:
821
+ # TODO(YL): add more compile modes?
822
+ kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) # reduce-overhead
823
+
824
+ accelerator = Accelerator(
825
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
826
+ mixed_precision=mixed_precision,
827
+ log_with=training_args.report_to,
828
+ project_dir=training_args.output_dir,
829
+ kwargs_handlers=kwargs_handlers,
830
+ )
831
+
832
+ accelerator.init_trackers(
833
+ project_name=data_args.wandb_project,
834
+ config={
835
+ "learning_rate": training_args.learning_rate,
836
+ "model_name_or_path": model_args.model_name_or_path,
837
+ "num_train_epochs": training_args.num_train_epochs,
838
+ "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
839
+ "per_device_train_batch_size": training_args.per_device_train_batch_size,
840
+ "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
841
+ "mixed_precision": mixed_precision,
842
+ "lr_scheduler_type": training_args.lr_scheduler_type,
843
+ "warmup_steps": training_args.warmup_steps,
844
+ "freeze_text_encoder": model_args.freeze_text_encoder,
845
+ "max_duration_in_seconds": data_args.max_duration_in_seconds,
846
+ "weight_decay": training_args.weight_decay,
847
+ "adam_beta1": training_args.adam_beta1,
848
+ "adam_beta2": training_args.adam_beta2,
849
+ "temperature": model_args.temperature,
850
+ },
851
+ )
852
+
853
+ # Detecting last checkpoint and eventually continue from last checkpoint
854
+ last_checkpoint = None
855
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
856
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
857
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
858
+ raise ValueError(
859
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
860
+ "Use --overwrite_output_dir to overcome."
861
+ )
862
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
863
+ logger.info(
864
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
865
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
866
+ )
867
+
868
+ # Setup logging
869
+ logging.basicConfig(
870
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
871
+ datefmt="%m/%d/%Y %H:%M:%S",
872
+ handlers=[logging.StreamHandler(sys.stdout)],
873
+ )
874
+ logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
875
+
876
+ # Log a small summary on each process
877
+ logger.warning(
878
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
879
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
880
+ )
881
+
882
+ # Set the verbosity to info of the Transformers logger (on main process only)
883
+ if accelerator.is_local_main_process:
884
+ datasets.utils.logging.set_verbosity_warning()
885
+ transformers.utils.logging.set_verbosity_info()
886
+ else:
887
+ datasets.utils.logging.set_verbosity_error()
888
+ transformers.utils.logging.set_verbosity_error()
889
+
890
+ logger.info("Training/evaluation parameters %s", training_args)
891
+
892
+ # Set seed before initializing model.
893
+ set_seed(training_args.seed)
894
+ num_workers = data_args.preprocessing_num_workers
895
+
896
+ # 1. First, let's instantiate the feature extractor, tokenizers and model
897
+ # Note for distributed training, the .from_pretrained methods guarantee that only
898
+ # one local process can concurrently download model & vocab.
899
+
900
+ # load feature extractor
901
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
902
+ model_args.feature_extractor_name or model_args.model_name_or_path,
903
+ cache_dir=model_args.cache_dir,
904
+ token=data_args.token,
905
+ trust_remote_code=data_args.trust_remote_code,
906
+ )
907
+ sampling_rate = feature_extractor.sampling_rate
908
+
909
+ # load prompt tokenizer
910
+ prompt_tokenizer = AutoTokenizer.from_pretrained(
911
+ model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
912
+ cache_dir=model_args.cache_dir,
913
+ token=data_args.token,
914
+ trust_remote_code=data_args.trust_remote_code,
915
+ use_fast=model_args.use_fast_tokenizer,
916
+ padding_side="left", # prompt has to be padded on the left because it's prepended to the codebook hidden states
917
+ )
918
+
919
+ # load description tokenizer
920
+ description_tokenizer = AutoTokenizer.from_pretrained(
921
+ model_args.description_tokenizer_name or model_args.model_name_or_path,
922
+ cache_dir=model_args.cache_dir,
923
+ token=data_args.token,
924
+ trust_remote_code=data_args.trust_remote_code,
925
+ use_fast=model_args.use_fast_tokenizer,
926
+ )
927
+
928
+ if model_args.use_fast_tokenizer:
929
+ logger.warning(
930
+ "Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
931
+ )
932
+ prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
933
+ description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
934
+
935
+ # 2. Now, let's load the dataset
936
+
937
+ if data_args.save_to_disk is not None:
938
+ os.makedirs(data_args.save_to_disk, exist_ok=True)
939
+
940
+ # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
941
+ dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
942
+ if dataset_was_precomputed:
943
+ vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
944
+ else:
945
+ raw_datasets = DatasetDict()
946
+
947
+ columns_to_keep = {
948
+ "target_audio_column_name": data_args.target_audio_column_name,
949
+ "prompt_column_name": data_args.prompt_column_name,
950
+ }
951
+ if data_args.description_column_name is not None:
952
+ columns_to_keep["description_column_name"] = data_args.description_column_name
953
+
954
+ if training_args.do_train:
955
+ raw_datasets["train"] = load_multiple_datasets(
956
+ accelerator,
957
+ data_args.train_dataset_name,
958
+ data_args.train_dataset_config_name,
959
+ metadata_dataset_names=data_args.train_metadata_dataset_name,
960
+ splits=data_args.train_split_name,
961
+ dataset_samples=data_args.train_dataset_samples,
962
+ seed=training_args.seed,
963
+ cache_dir=model_args.cache_dir,
964
+ num_proc=data_args.preprocessing_num_workers,
965
+ id_column_name=data_args.id_column_name,
966
+ columns_to_keep=columns_to_keep.values(),
967
+ prompt_column_name=data_args.prompt_column_name,
968
+ audio_column_name=data_args.target_audio_column_name,
969
+ sampling_rate=sampling_rate,
970
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
971
+ )
972
+
973
+ for key in columns_to_keep:
974
+ if columns_to_keep[key] not in raw_datasets["train"].column_names:
975
+ raise ValueError(
976
+ f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
977
+ f" Make sure to set `--{key}` to the correct audio column - one of"
978
+ f" {', '.join(raw_datasets['train'].column_names)}."
979
+ )
980
+
981
+ if data_args.max_train_samples is not None:
982
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
983
+
984
+ if training_args.do_eval:
985
+ raw_datasets["eval"] = load_multiple_datasets(
986
+ accelerator,
987
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
988
+ data_args.eval_dataset_config_name
989
+ if data_args.eval_dataset_config_name
990
+ else data_args.train_dataset_config_name,
991
+ metadata_dataset_names=data_args.eval_metadata_dataset_name,
992
+ splits=data_args.eval_split_name,
993
+ cache_dir=model_args.cache_dir,
994
+ num_proc=data_args.preprocessing_num_workers,
995
+ id_column_name=data_args.id_column_name,
996
+ columns_to_keep=columns_to_keep.values(),
997
+ prompt_column_name=data_args.prompt_column_name,
998
+ audio_column_name=data_args.target_audio_column_name,
999
+ sampling_rate=sampling_rate,
1000
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
1001
+ )
1002
+
1003
+ if data_args.max_eval_samples is not None:
1004
+ raw_datasets["eval"] = (
1005
+ raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
1006
+ )
1007
+
1008
+ # 3. Next, let's load the config.
1009
+ config = ParlerTTSConfig.from_pretrained(
1010
+ model_args.model_name_or_path,
1011
+ cache_dir=model_args.cache_dir,
1012
+ token=data_args.token,
1013
+ trust_remote_code=data_args.trust_remote_code,
1014
+ )
1015
+
1016
+ # update pad token id and decoder_start_token_id
1017
+ config.update(
1018
+ {
1019
+ "pad_token_id": model_args.pad_token_id
1020
+ if model_args.pad_token_id is not None
1021
+ else config.pad_token_id,
1022
+ "decoder_start_token_id": model_args.decoder_start_token_id
1023
+ if model_args.decoder_start_token_id is not None
1024
+ else config.decoder_start_token_id,
1025
+ }
1026
+ )
1027
+
1028
+ # create model
1029
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
1030
+ model_args.model_name_or_path,
1031
+ cache_dir=model_args.cache_dir,
1032
+ config=config,
1033
+ token=data_args.token,
1034
+ trust_remote_code=data_args.trust_remote_code,
1035
+ )
1036
+
1037
+ # enable gradient checkpointing if necessary
1038
+ if training_args.gradient_checkpointing:
1039
+ model.gradient_checkpointing_enable()
1040
+
1041
+ # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
1042
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
1043
+ # so that we just need to set the correct target sampling rate and normalize the input
1044
+ # via the `feature_extractor`
1045
+
1046
+ # derive max & min input length for sample rate & max duration
1047
+ sampling_rate = feature_extractor.sampling_rate
1048
+ max_target_length = data_args.max_duration_in_seconds * sampling_rate
1049
+ min_target_length = data_args.min_duration_in_seconds * sampling_rate
1050
+ target_audio_column_name = data_args.target_audio_column_name
1051
+ description_column_name = data_args.description_column_name
1052
+ prompt_column_name = data_args.prompt_column_name
1053
+ feature_extractor_input_name = feature_extractor.model_input_names[0]
1054
+ audio_encoder_pad_token_id = config.decoder.pad_token_id
1055
+ audio_encoder_eos_token_id = config.decoder.eos_token_id
1056
+ audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
1057
+ max_length = model.generation_config.max_length
1058
+ num_codebooks = model.decoder.config.num_codebooks
1059
+ bandwidth = model_args.bandwidth
1060
+
1061
+ # Freeze Encoders
1062
+ model.freeze_encoders(model_args.freeze_text_encoder)
1063
+
1064
+ # Test all gather - used for warmup and avoiding timeout
1065
+ test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
1066
+ gathered_tensor = accelerator.gather(test_tensor)
1067
+ print("gathered_tensor", gathered_tensor)
1068
+ accelerator.wait_for_everyone()
1069
+
1070
+ if not dataset_was_precomputed:
1071
+ # Filter on text length
1072
+ if description_column_name is not None and data_args.max_text_length is not None:
1073
+ with accelerator.main_process_first():
1074
+ # keep only descriptions shorter than max_text_length characters
1075
+ raw_datasets = raw_datasets.filter(
1076
+ lambda x: len(x) < data_args.max_text_length,
1077
+ num_proc=num_workers,
1078
+ input_columns=[description_column_name],
1079
+ )
1080
+
1081
+ # Preprocessing the dataset.
1082
+ # We need to tokenize the texts.
1083
+ def pass_through_processors(description, prompt):
1084
+ batch = {}
1085
+
1086
+ batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
1087
+ batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
1088
+
1089
+ return batch
1090
+
1091
+ with accelerator.main_process_first():
1092
+ # this is a trick to avoid rewriting the entire audio column, which takes ages
1093
+ vectorized_datasets = raw_datasets.map(
1094
+ pass_through_processors,
1095
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1096
+ input_columns=[description_column_name, prompt_column_name],
1097
+ num_proc=num_workers,
1098
+ desc="preprocess datasets",
1099
+ )
1100
+
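For intuition, here is a small stand-alone sketch of what the mapping above produces for one example; the `t5-small` checkpoint and the example strings are assumptions, not values from this run.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")  # assumed stand-in for the prompt/description tokenizers
description = "A calm female voice with slight reverberation."
prompt = "Hey, how are you doing today?"
batch = {
    "input_ids": tok(description.strip())["input_ids"],
    "prompt_input_ids": tok(prompt.strip())["input_ids"],
}
print(batch)  # two lists of token ids; no audio is touched at this stage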
1101
+ # We use Accelerate to perform distributed inference
1102
+ # T5 doesn't support fp16
1103
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
1104
+
1105
+ # Now we encode the audio labels with encodec.
1106
+ ####### B. Encode audio
1107
+
1108
+ logger.info("*** Encode target audio with encodec ***")
1109
+
1110
+ # no need to prepare the audio_decoder because it is only used for inference, without mixed precision
1111
+ # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
1112
+ if training_args.torch_compile:
1113
+ audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
1114
+ else:
1115
+ audio_decoder = model.audio_encoder
1116
+
1117
+ encoder_data_collator = DataCollatorEncodecWithPadding(
1118
+ feature_extractor,
1119
+ audio_column_name=target_audio_column_name,
1120
+ feature_extractor_input_name=feature_extractor_input_name,
1121
+ max_length=max_target_length,
1122
+ padding=padding,
1123
+ )
1124
+
1125
+ def apply_audio_decoder(batch):
1126
+ len_audio = batch.pop("len_audio")
1127
+ audio_decoder.to(batch["input_values"].device).eval()
1128
+ with torch.no_grad():
1129
+ labels = audio_decoder.encode(**batch, bandwidth=bandwidth)["audio_codes"]
1130
+ output = {}
1131
+ output["len_audio"] = len_audio
1132
+ # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
1133
+ output["labels"] = labels.squeeze(0).transpose(1, 2)
1134
+ output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
1135
+ return output
1136
+
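As a rough, illustrative calculation of the `ratio` trick above (assuming a ~86 Hz DAC frame rate and 44.1 kHz audio; none of these numbers come from this run):

import torch

len_audio = torch.tensor([441_000, 220_500])  # a 10 s and a 5 s clip, padded together
num_frames = 860                              # labels.shape[-1] after encoding the padded batch
ratio = torch.ones_like(len_audio) * num_frames / len_audio.max()
print([int(r * l) for r, l in zip(ratio, len_audio)])  # ≈ [860, 430] frames kept per clip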
1137
+ for split in vectorized_datasets:
1138
+ data_loader = DataLoader(
1139
+ raw_datasets[split],
1140
+ batch_size=training_args.audio_encoder_per_device_batch_size,
1141
+ collate_fn=encoder_data_collator,
1142
+ num_workers=training_args.dataloader_num_workers,
1143
+ pin_memory=True,
1144
+ )
1145
+ data_loader = accelerator.prepare(data_loader)
1146
+
1147
+ all_generated_labels = []
1148
+ all_lens = []
1149
+ for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
1150
+ generate_labels = apply_audio_decoder(batch)
1151
+ generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
1152
+ generate_labels = accelerator.gather_for_metrics(generate_labels)
1153
+
1154
+ if accelerator.is_main_process:
1155
+ lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
1156
+ rat = generate_labels["ratio"].cpu().squeeze()
1157
+ lens = generate_labels["len_audio"].cpu().squeeze()
1158
+ lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
1159
+
1160
+ all_generated_labels.extend(lab)
1161
+ all_lens.extend(lens)
1162
+
1163
+ # (1, codebooks, seq_len) where seq_len=1
1164
+ bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
1165
+
1166
+ if accelerator.is_main_process:
1167
+ tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
1168
+ tmp_labels.save_to_disk(
1169
+ os.path.join(data_args.temporary_save_to_disk, split),
1170
+ num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
1171
+ )
1172
+ accelerator.wait_for_everyone()
1173
+ del all_generated_labels
1174
+
1175
+ tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
1176
+ with accelerator.main_process_first():
1177
+ vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
1178
+
1179
+ def postprocess_dataset(labels):
1180
+ # (1, codebooks, seq_len)
1181
+ labels = torch.tensor(labels).unsqueeze(0)
1182
+ # add bos
1183
+ labels = torch.cat([bos_labels, labels], dim=-1)
1184
+
1185
+ labels, delay_pattern_mask = build_delay_pattern_mask(
1186
+ labels,
1187
+ bos_token_id=audio_encoder_bos_token_id,
1188
+ pad_token_id=audio_encoder_eos_token_id,
1189
+ max_length=labels.shape[-1] + num_codebooks,
1190
+ num_codebooks=num_codebooks,
1191
+ )
1192
+
1193
+ # the first ids of the delay pattern mask are precisely the labels; we use the rest of the mask
1194
+ # to take care of EOS
1195
+ # we want labels to look like this:
1196
+ # - [B, a, b, E, E, E, E]
1197
+ # - [B, B, c, d, E, E, E]
1198
+ # - [B, B, B, e, f, E, E]
1199
+ # - [B, B, B, B, g, h, E]
1200
+ labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
1201
+
1202
+ # the first timestamp is associated with a row full of BOS, let's get rid of it
1203
+ # we also remove the last timestamps (full of PAD)
1204
+ output = {"labels": labels[:, 1:]}
1205
+ return output
1206
+
1207
+ with accelerator.main_process_first():
1208
+ vectorized_datasets[split] = vectorized_datasets[split].map(
1209
+ postprocess_dataset,
1210
+ num_proc=data_args.preprocessing_num_workers,  # this one is resource-consuming with many processes.
1211
+ input_columns=["labels"],
1212
+ desc="Postprocessing labeling",
1213
+ )
1214
+
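The delay pattern sketched in the comments above can be reproduced with a tiny toy; this is an illustration only, not the `build_delay_pattern_mask` used by the script, and the BOS/EOS ids are placeholders.

import torch

BOS, EOS = -2, -1                                 # stand-in ids for illustration only
codes = torch.tensor([[1, 2], [3, 4], [5, 6]])    # (num_codebooks=3, seq_len=2)
K, T = codes.shape
out = torch.full((K, T + K + 1), EOS)
for k in range(K):
    out[k, : k + 1] = BOS                         # codebook k starts with k+1 BOS tokens
    out[k, k + 1 : k + 1 + T] = codes[k]          # then its codes, shifted right by k
print(out)
# tensor([[-2,  1,  2, -1, -1, -1],
#         [-2, -2,  3,  4, -1, -1],
#         [-2, -2, -2,  5,  6, -1]])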
1215
+ accelerator.free_memory()
1216
+ del generate_labels, all_lens
1217
+
1218
+ with accelerator.main_process_first():
1219
+ # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
1220
+ # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
1221
+ # That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
1222
+
1223
+ def is_audio_in_length_range(length):
1224
+ return length > min_target_length and length < max_target_length
1225
+
1226
+ # keep only audio whose length lies between min_target_length and max_target_length
1227
+ vectorized_datasets = vectorized_datasets.filter(
1228
+ is_audio_in_length_range,
1229
+ num_proc=num_workers,
1230
+ input_columns=["target_length"],
1231
+ )
1232
+
1233
+ if description_column_name is not None and data_args.max_description_token_length is not None:
1234
+ with accelerator.main_process_first():
1235
+ # keep only descriptions shorter than max_description_token_length tokens
1236
+ vectorized_datasets = vectorized_datasets.filter(
1237
+ lambda x: len(x) < data_args.max_description_token_length,
1238
+ num_proc=num_workers,
1239
+ input_columns=["input_ids"],
1240
+ )
1241
+
1242
+ if data_args.max_prompt_token_length is not None:
1243
+ with accelerator.main_process_first():
1244
+ # keep only prompts shorter than max_prompt_token_length tokens
1245
+ vectorized_datasets = vectorized_datasets.filter(
1246
+ lambda x: len(x) < data_args.max_prompt_token_length,
1247
+ num_proc=num_workers,
1248
+ input_columns=["prompt_input_ids"],
1249
+ )
1250
+
1251
+ if data_args.save_to_disk is not None and not dataset_was_precomputed:
1252
+ if accelerator.is_main_process:
1253
+ vectorized_datasets.save_to_disk(
1254
+ data_args.save_to_disk,
1255
+ num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
1256
+ )
1257
+ logger.info(f"Dataset saved at {data_args.save_to_disk}")
1258
+
1259
+ audio_max_length = None
1260
+ if training_args.torch_compile:
1261
+ audio_max_length = max(vectorized_datasets["train"]["target_length"])
1262
+ with accelerator.main_process_first():
1263
+ max_sample = vectorized_datasets["train"].filter(
1264
+ lambda x: x == audio_max_length,
1265
+ num_proc=num_workers,
1266
+ input_columns=["target_length"],
1267
+ )
1268
+ audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
1269
+
1270
+ # for large datasets it is advised to run the preprocessing on a
1271
+ # single machine first with ``args.preprocessing_only`` since there will most likely
1272
+ # be a timeout when running the script in distributed mode.
1273
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
1274
+ # cached dataset
1275
+ if data_args.preprocessing_only and data_args.save_to_disk is None:
1276
+ raise ValueError(
1277
+ "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicate where to save the dataset locally."
1278
+ )
1279
+ elif data_args.preprocessing_only:
1280
+ logger.info(f"Data preprocessing finished. Files saved at {data_args.save_to_disk}")
1281
+ return
1282
+
1283
+ # 6. Next, we can prepare the training.
1284
+
1285
+ # Let's use CLAP similarity and WER as our evaluation metrics.
1286
+
1287
+ # Define evaluation metrics used during training, *i.e.* CLAP similarity and word error rate (WER)
1288
+ clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)
1289
+ clap_processor = AutoProcessor.from_pretrained(model_args.clap_model_name_or_path)
1290
+ metric = evaluate.load("wer")
1291
+
1292
+ def clap_similarity(texts, audios, device):
1293
+ clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
1294
+ clap.to(device)
1295
+ with torch.no_grad():
1296
+ text_features = clap.get_text_features(
1297
+ clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
1298
+ )
1299
+ audio_features = clap.get_audio_features(clap_inputs["input_features"])
1300
+
1301
+ cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
1302
+
1303
+ clap.to("cpu")
1304
+ clap_inputs.to("cpu")
1305
+ return cosine_sim.mean().to("cpu")
1306
+
1307
+ def wer(prompts, audios, device):
1308
+ asr_pipeline = pipeline(model=model_args.asr_model_name_or_path, device=device)
1309
+ transcriptions = asr_pipeline(
1310
+ [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
1311
+ batch_size=int(training_args.per_device_eval_batch_size),
1312
+ )
1313
+
1314
+ word_error = 100 * metric.compute(
1315
+ predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
1316
+ )
1317
+
1318
+ return word_error, [t["text"] for t in transcriptions]
1319
+
1320
+ eval_methods = {"clap": clap_similarity, "wer": wer}
1321
+
1322
+ def compute_metrics(audios, descriptions, prompts, device="cpu"):
1323
+ input_ids = descriptions
1324
+ texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
1325
+ prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
1326
+ audios = [a.cpu().numpy() for a in audios]
1327
+ results = {"clap": eval_methods["clap"](texts, audios, device)}
1328
+ word_error, transcriptions = eval_methods["wer"](prompts, audios, device)
1329
+ results["wer"] = word_error
1330
+
1331
+ return results, texts, prompts, audios, transcriptions
1332
+
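As a quick, illustrative sanity check of the WER metric wired up above (the strings are made up, not transcriptions from this run):

import evaluate

wer_metric = evaluate.load("wer")
score = wer_metric.compute(
    predictions=["hello there how are you"],
    references=["hello there how are you today"],
)
print(100 * score)  # ≈ 16.7: one missing word out of six reference words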
1333
+ # Define Training Schedule
1334
+ # Store some constants
1335
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1336
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1337
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1338
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1339
+
1340
+ if training_args.max_steps < 0:
1341
+ num_epochs = int(training_args.num_train_epochs)
1342
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1343
+ total_train_steps = steps_per_epoch * num_epochs
1344
+ elif training_args.max_steps > 0:
1345
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1346
+ total_train_steps = int(training_args.max_steps)
1347
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1348
+ num_epochs = sys.maxsize
1349
+ steps_per_epoch = total_train_steps
1350
+
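For concreteness, a toy version of the schedule arithmetic above with assumed numbers (none of these come from this training run):

num_train_examples = 10_000
per_device_train_batch_size = 8
num_processes = 4
gradient_accumulation_steps = 2
num_epochs = 5

train_batch_size = per_device_train_batch_size * num_processes                            # 32
steps_per_epoch = num_train_examples // (train_batch_size * gradient_accumulation_steps)  # 156
total_train_steps = steps_per_epoch * num_epochs                                          # 780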
1351
+ if training_args.eval_steps is None:
1352
+ logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
1353
+ eval_steps = steps_per_epoch
1354
+ else:
1355
+ eval_steps = training_args.eval_steps
1356
+
1357
+ # T5 doesn't support fp16
1358
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
1359
+
1360
+ # Define optimizer, LR scheduler, collator
1361
+ optimizer = torch.optim.AdamW(
1362
+ params=model.parameters(),
1363
+ lr=training_args.learning_rate,
1364
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1365
+ eps=training_args.adam_epsilon,
1366
+ weight_decay=training_args.weight_decay,
1367
+ )
1368
+
1369
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1370
+ lr_scheduler = get_scheduler(
1371
+ name=training_args.lr_scheduler_type,
1372
+ optimizer=optimizer,
1373
+ num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
1374
+ num_training_steps=total_train_steps * accelerator.num_processes,
1375
+ )
1376
+
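A hedged note on the `* accelerator.num_processes` factor above: assuming the prepared scheduler is ticked once per process for every optimizer step, the counts passed in are inflated accordingly, e.g.:

num_processes = 4                      # assumed world size
warmup_optimizer_steps = 500           # desired "real" warmup in optimizer steps
num_warmup_steps_for_scheduler = warmup_optimizer_steps * num_processes  # 2_000 scheduler ticks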
1377
+ # Instantiate custom data collator
1378
+ data_collator = DataCollatorParlerTTSWithPadding(
1379
+ prompt_tokenizer=prompt_tokenizer,
1380
+ description_tokenizer=description_tokenizer,
1381
+ pad_to_multiple_of=data_args.pad_to_multiple_of,
1382
+ padding=padding,
1383
+ prompt_max_length=data_args.max_prompt_token_length,
1384
+ description_max_length=data_args.max_description_token_length,
1385
+ audio_max_length=audio_max_length,
1386
+ )
1387
+
1388
+ # Prepare everything with accelerate
1389
+ model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
1390
+
1391
+ logger.info("***** Running training *****")
1392
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1393
+ logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
1394
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1395
+ logger.info(
1396
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1397
+ )
1398
+ logger.info(f" Total optimization steps = {total_train_steps}")
1399
+
1400
+ # ======================== Training ================================
1401
+ train_time = 0
1402
+ train_start = time.time()
1403
+ steps_trained_progress_bar = tqdm(
1404
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1405
+ )
1406
+ continue_training = True
1407
+ epochs_trained = 0
1408
+ cur_step = 0
1409
+
1410
+ checkpoint = None
1411
+ if training_args.resume_from_checkpoint is not None:
1412
+ checkpoint = training_args.resume_from_checkpoint
1413
+ elif last_checkpoint is not None:
1414
+ checkpoint = last_checkpoint
1415
+
1416
+ if accelerator.is_main_process:
1417
+ if training_args.push_to_hub:
1418
+ # Retrieve or infer repo_name
1419
+ repo_name = training_args.hub_model_id
1420
+ if repo_name is None:
1421
+ repo_name = Path(training_args.output_dir).absolute().name
1422
+ # Create repo and retrieve repo_id
1423
+ repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
1424
+ # Clone repo locally
1425
+ repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
1426
+
1427
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
1428
+ if "wandb" not in gitignore:
1429
+ gitignore.write("wandb\n")
1430
+ elif training_args.output_dir is not None:
1431
+ os.makedirs(training_args.output_dir, exist_ok=True)
1432
+ accelerator.wait_for_everyone()
1433
+
1434
+ # Now save everything to be able to create a single processor later
1435
+ # make sure all processes wait until data is saved
1436
+ with accelerator.main_process_first():
1437
+ # only the main process saves them
1438
+ if accelerator.is_main_process:
1439
+ # save feature extractor, tokenizer and config
1440
+ if (
1441
+ model_args.prompt_tokenizer_name is None
1442
+ and model_args.description_tokenizer_name
1443
+ or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
1444
+ ):
1445
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
1446
+ else:
1447
+ logger.warning(
1448
+ f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
1449
+ )
1450
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
1451
+
1452
+ feature_extractor.save_pretrained(training_args.output_dir)
1453
+ config.save_pretrained(training_args.output_dir)
1454
+
1455
+ if checkpoint is not None:
1456
+ accelerator.load_state(checkpoint)
1457
+ # Find num steps and epoch from saved state string pattern
1458
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1459
+ match = re.search(pattern, checkpoint)
1460
+ cur_step = int(match.group(1))
1461
+ epochs_trained = int(match.group(2))
1462
+
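The checkpoint-name parsing above expects directories of the form `checkpoint-<step>-epoch-<epoch>`; a hypothetical example of the regex at work:

import re

match = re.search(r"checkpoint-(\d+)-epoch-(\d+)", "output/checkpoint-1500-epoch-3")
print(int(match.group(1)), int(match.group(2)))  # 1500 3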
1463
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1464
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1465
+ logger.info(f" Continuing training from global step {cur_step}")
1466
+
1467
+ steps_trained_progress_bar.update(cur_step)
1468
+
1469
+ for epoch in range(0, epochs_trained):
1470
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1471
+
1472
+ if training_args.max_steps < 0:
1473
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
1474
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1475
+ else:
1476
+ # Currently we don't know how many steps we've taken in the current epoch
1477
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1478
+ # This is "good enough" for our purposes but not fully correct
1479
+ resume_step = None
1480
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1481
+ else:
1482
+ resume_step = None
1483
+
1484
+ gen_kwargs = {
1485
+ "do_sample": model_args.do_sample,
1486
+ "temperature": model_args.temperature,
1487
+ "max_length": model_args.max_length,
1488
+ }
1489
+
1490
+ # Define gradient update step fn
1491
+ def train_step(
1492
+ batch,
1493
+ accelerator,
1494
+ autocast_kwargs,
1495
+ ):
1496
+ model.train()
1497
+
1498
+ if mixed_precision == "fp16":
1499
+ # fp16 doesn't work with T5-like models
1500
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
1501
+ if training_args.parallel_mode.value != "distributed":
1502
+ encoder_outputs = model.text_encoder(
1503
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
1504
+ )
1505
+ else:
1506
+ encoder_outputs = model.module.text_encoder(
1507
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
1508
+ )
1509
+ batch["encoder_outputs"] = encoder_outputs
1510
+
1511
+ outputs = model(**batch)
1512
+ # CE (data) loss
1513
+ ce_loss = outputs.loss
1514
+
1515
+ metrics = {"loss": ce_loss}
1516
+ return ce_loss, metrics
1517
+
1518
+ # Define eval fn
1519
+ def eval_step(
1520
+ batch,
1521
+ accelerator,
1522
+ autocast_kwargs,
1523
+ ):
1524
+ eval_model = model if not training_args.torch_compile else model._orig_mod
1525
+ eval_model.eval()
1526
+
1527
+ if mixed_precision == "fp16":
1528
+ # fp16 doesn't work with T5-like models
1529
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
1530
+ with torch.no_grad():
1531
+ if training_args.parallel_mode.value != "distributed" or training_args.torch_compile:
1532
+ encoder_outputs = eval_model.text_encoder(
1533
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
1534
+ )
1535
+ else:
1536
+ encoder_outputs = eval_model.module.text_encoder(
1537
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
1538
+ )
1539
+ batch["encoder_outputs"] = encoder_outputs
1540
+
1541
+ with torch.no_grad():
1542
+ outputs = eval_model(**batch)
1543
+ # CE (data) loss
1544
+ ce_loss = outputs.loss
1545
+ metrics = {"loss": ce_loss}
1546
+ return metrics
1547
+
1548
+ def generate_step(batch):
1549
+ batch.pop("decoder_attention_mask", None)
1550
+ eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").eval()
1551
+ if training_args.torch_compile:
1552
+ eval_model = model._orig_mod
1553
+
1554
+ output_audios = eval_model.generate(**batch, **gen_kwargs)
1555
+ output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
1556
+ return output_audios
1557
+
1558
+ for epoch in range(epochs_trained, num_epochs):
1559
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1560
+ sampler = None
1561
+ if training_args.group_by_length:
1562
+ sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
1563
+ train_dataloader = DataLoader(
1564
+ vectorized_datasets["train"],
1565
+ collate_fn=data_collator,
1566
+ batch_size=per_device_train_batch_size,
1567
+ sampler=sampler,
1568
+ num_workers=training_args.dataloader_num_workers,
1569
+ pin_memory=training_args.dataloader_pin_memory,
1570
+ )
1571
+ train_dataloader = accelerator.prepare(train_dataloader)
1572
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1573
+ train_dataloader.dataset.set_epoch(epoch)
1574
+
1575
+ if resume_step is not None:
1576
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1577
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1578
+ resume_step = None
1579
+
1580
+ for batch in train_dataloader:
1581
+ with accelerator.accumulate(model):
1582
+ loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
1583
+ accelerator.backward(loss)
1584
+ if accelerator.sync_gradients:
1585
+ accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
1586
+ optimizer.step()
1587
+ lr_scheduler.step()
1588
+ optimizer.zero_grad()
1589
+
1590
+ # Check if the accelerator has performed an optimization step behind the scenes
1591
+ if accelerator.sync_gradients:
1592
+ steps_trained_progress_bar.update(1)
1593
+ cur_step += 1
1594
+
1595
+ if cur_step % training_args.logging_steps == 0:
1596
+ steps_trained_progress_bar.write(
1597
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1598
+ f" {train_metric['loss']}, Learning Rate:"
1599
+ f" {lr_scheduler.get_last_lr()[0]})"
1600
+ )
1601
+ log_metric(
1602
+ accelerator,
1603
+ metrics=train_metric,
1604
+ learning_rate=lr_scheduler.get_last_lr()[0],
1605
+ train_time=train_time + time.time() - train_start,
1606
+ step=cur_step,
1607
+ epoch=epoch,
1608
+ prefix="train",
1609
+ )
1610
+
1611
+ # save checkpoint and weights after each save_steps and at the end of training
1612
+ if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1613
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1614
+ # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
1615
+ # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
1616
+ accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
1617
+ accelerator.wait_for_everyone()
1618
+ if accelerator.is_main_process:
1619
+ rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1620
+
1621
+ if cur_step == total_train_steps:
1622
+ # un-wrap the model for saving
1623
+ unwrapped_model = accelerator.unwrap_model(model)
1624
+ unwrapped_model.save_pretrained(training_args.output_dir)
1625
+
1626
+ if training_args.push_to_hub:
1627
+ repo.push_to_hub(
1628
+ commit_message=f"Saving train state of step {cur_step}",
1629
+ blocking=False,
1630
+ )
1631
+
1632
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1633
+ train_time += time.time() - train_start
1634
+ # ======================== Evaluating ==============================
1635
+ eval_metrics = []
1636
+ eval_preds = []
1637
+ eval_descriptions = []
1638
+ eval_prompts = []
1639
+ eval_start = time.time()
1640
+
1641
+ # release training input batch
1642
+ batch = release_memory(batch)
1643
+
1644
+ validation_dataloader = DataLoader(
1645
+ vectorized_datasets["eval"],
1646
+ collate_fn=data_collator,
1647
+ batch_size=per_device_eval_batch_size,
1648
+ drop_last=False,
1649
+ num_workers=training_args.dataloader_num_workers,
1650
+ pin_memory=training_args.dataloader_pin_memory,
1651
+ )
1652
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1653
+
1654
+ for batch in tqdm(
1655
+ validation_dataloader,
1656
+ desc=f"Evaluating - Inference ...",
1657
+ position=2,
1658
+ disable=not accelerator.is_local_main_process,
1659
+ ):
1660
+ # Model forward
1661
+ eval_metric = eval_step(batch, accelerator, autocast_kwargs)
1662
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1663
+ eval_metrics.append(eval_metric)
1664
+
1665
+ if training_args.predict_with_generate:
1666
+ validation_dataloader = DataLoader(
1667
+ vectorized_datasets["eval"],
1668
+ collate_fn=data_collator,
1669
+ batch_size=per_device_eval_batch_size,
1670
+ drop_last=False,
1671
+ num_workers=training_args.dataloader_num_workers,
1672
+ pin_memory=training_args.dataloader_pin_memory,
1673
+ )
1674
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1675
+ # generation
1676
+ for batch in tqdm(
1677
+ validation_dataloader,
1678
+ desc=f"Evaluating - Generation ...",
1679
+ position=2,
1680
+ disable=not accelerator.is_local_main_process,
1681
+ ):
1682
+ generated_audios = generate_step(batch)
1683
+ # Gather all predictions and targets
1684
+ generated_audios, input_ids, prompts = accelerator.pad_across_processes(
1685
+ (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
1686
+ )
1687
+ generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
1688
+ (generated_audios, input_ids, prompts)
1689
+ )
1690
+ eval_preds.extend(generated_audios.to("cpu"))
1691
+ eval_descriptions.extend(input_ids.to("cpu"))
1692
+ eval_prompts.extend(prompts.to("cpu"))
1693
+
1694
+ eval_time = time.time() - eval_start
1695
+ # normalize eval metrics
1696
+ eval_metrics = {
1697
+ key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics]))
1698
+ for key in eval_metrics[0]
1699
+ }
1700
+
1701
+ # compute metrics
1702
+ metrics_desc = ""
1703
+ if training_args.predict_with_generate:
1704
+ metric_values, pred_descriptions, pred_prompts, audios, transcriptions = compute_metrics(
1705
+ eval_preds, eval_descriptions, eval_prompts, accelerator.device
1706
+ )
1707
+ eval_metrics.update(metric_values)
1708
+ metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
1709
+ if "wandb" in training_args.report_to:
1710
+ log_pred(
1711
+ accelerator,
1712
+ pred_descriptions,
1713
+ pred_prompts,
1714
+ transcriptions,
1715
+ audios,
1716
+ sampling_rate=sampling_rate,
1717
+ step=cur_step,
1718
+ prefix="eval",
1719
+ )
1720
+
1721
+ # Print metrics and update progress bar
1722
+ steps_trained_progress_bar.write(
1723
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1724
+ f" {metrics_desc})"
1725
+ )
1726
+
1727
+ log_metric(
1728
+ accelerator,
1729
+ metrics=eval_metrics,
1730
+ train_time=eval_time,
1731
+ step=cur_step,
1732
+ epoch=epoch,
1733
+ prefix="eval",
1734
+ )
1735
+
1736
+ # release eval batch and reset metrics
1737
+ eval_metrics = []
1738
+ eval_preds = []
1739
+ eval_descriptions = []
1740
+ eval_prompts = []
1741
+ batch = release_memory(batch)
1742
+
1743
+ # flush the train metrics
1744
+ train_start = time.time()
1745
+
1746
+ # break condition
1747
+ if cur_step == total_train_steps:
1748
+ continue_training = False
1749
+ break
1750
+
1751
+ if not continue_training:
1752
+ break
1753
+
1754
+ accelerator.end_training()
1755
+
1756
+
1757
+ if __name__ == "__main__":
1758
+ set_start_method("spawn")
1759
+ main()
1760
+
special_tokens_map.json ADDED
@@ -0,0 +1,125 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": {
105
+ "content": "</s>",
106
+ "lstrip": false,
107
+ "normalized": false,
108
+ "rstrip": false,
109
+ "single_word": false
110
+ },
111
+ "pad_token": {
112
+ "content": "<pad>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "unk_token": {
119
+ "content": "<unk>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ }
125
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,941 @@
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32058": {
493
+ "content": "<extra_id_41>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32059": {
501
+ "content": "<extra_id_40>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32060": {
509
+ "content": "<extra_id_39>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32061": {
517
+ "content": "<extra_id_38>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32062": {
525
+ "content": "<extra_id_37>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32063": {
533
+ "content": "<extra_id_36>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32064": {
541
+ "content": "<extra_id_35>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32065": {
549
+ "content": "<extra_id_34>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32066": {
557
+ "content": "<extra_id_33>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32067": {
565
+ "content": "<extra_id_32>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32068": {
573
+ "content": "<extra_id_31>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32069": {
581
+ "content": "<extra_id_30>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32070": {
589
+ "content": "<extra_id_29>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32071": {
597
+ "content": "<extra_id_28>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32072": {
605
+ "content": "<extra_id_27>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32073": {
613
+ "content": "<extra_id_26>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32074": {
621
+ "content": "<extra_id_25>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32075": {
629
+ "content": "<extra_id_24>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32076": {
637
+ "content": "<extra_id_23>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32077": {
645
+ "content": "<extra_id_22>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32078": {
653
+ "content": "<extra_id_21>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32079": {
661
+ "content": "<extra_id_20>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32080": {
669
+ "content": "<extra_id_19>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32081": {
677
+ "content": "<extra_id_18>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32082": {
685
+ "content": "<extra_id_17>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32083": {
693
+ "content": "<extra_id_16>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32084": {
701
+ "content": "<extra_id_15>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32085": {
709
+ "content": "<extra_id_14>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32086": {
717
+ "content": "<extra_id_13>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32087": {
725
+ "content": "<extra_id_12>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32088": {
733
+ "content": "<extra_id_11>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32089": {
741
+ "content": "<extra_id_10>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32090": {
749
+ "content": "<extra_id_9>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32091": {
757
+ "content": "<extra_id_8>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32092": {
765
+ "content": "<extra_id_7>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32093": {
773
+ "content": "<extra_id_6>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32094": {
781
+ "content": "<extra_id_5>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32095": {
789
+ "content": "<extra_id_4>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32096": {
797
+ "content": "<extra_id_3>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32097": {
805
+ "content": "<extra_id_2>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32098": {
813
+ "content": "<extra_id_1>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32099": {
821
+ "content": "<extra_id_0>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ }
828
+ },
829
+ "additional_special_tokens": [
830
+ "<extra_id_0>",
831
+ "<extra_id_1>",
832
+ "<extra_id_2>",
833
+ "<extra_id_3>",
834
+ "<extra_id_4>",
835
+ "<extra_id_5>",
836
+ "<extra_id_6>",
837
+ "<extra_id_7>",
838
+ "<extra_id_8>",
839
+ "<extra_id_9>",
840
+ "<extra_id_10>",
841
+ "<extra_id_11>",
842
+ "<extra_id_12>",
843
+ "<extra_id_13>",
844
+ "<extra_id_14>",
845
+ "<extra_id_15>",
846
+ "<extra_id_16>",
847
+ "<extra_id_17>",
848
+ "<extra_id_18>",
849
+ "<extra_id_19>",
850
+ "<extra_id_20>",
851
+ "<extra_id_21>",
852
+ "<extra_id_22>",
853
+ "<extra_id_23>",
854
+ "<extra_id_24>",
855
+ "<extra_id_25>",
856
+ "<extra_id_26>",
857
+ "<extra_id_27>",
858
+ "<extra_id_28>",
859
+ "<extra_id_29>",
860
+ "<extra_id_30>",
861
+ "<extra_id_31>",
862
+ "<extra_id_32>",
863
+ "<extra_id_33>",
864
+ "<extra_id_34>",
865
+ "<extra_id_35>",
866
+ "<extra_id_36>",
867
+ "<extra_id_37>",
868
+ "<extra_id_38>",
869
+ "<extra_id_39>",
870
+ "<extra_id_40>",
871
+ "<extra_id_41>",
872
+ "<extra_id_42>",
873
+ "<extra_id_43>",
874
+ "<extra_id_44>",
875
+ "<extra_id_45>",
876
+ "<extra_id_46>",
877
+ "<extra_id_47>",
878
+ "<extra_id_48>",
879
+ "<extra_id_49>",
880
+ "<extra_id_50>",
881
+ "<extra_id_51>",
882
+ "<extra_id_52>",
883
+ "<extra_id_53>",
884
+ "<extra_id_54>",
885
+ "<extra_id_55>",
886
+ "<extra_id_56>",
887
+ "<extra_id_57>",
888
+ "<extra_id_58>",
889
+ "<extra_id_59>",
890
+ "<extra_id_60>",
891
+ "<extra_id_61>",
892
+ "<extra_id_62>",
893
+ "<extra_id_63>",
894
+ "<extra_id_64>",
895
+ "<extra_id_65>",
896
+ "<extra_id_66>",
897
+ "<extra_id_67>",
898
+ "<extra_id_68>",
899
+ "<extra_id_69>",
900
+ "<extra_id_70>",
901
+ "<extra_id_71>",
902
+ "<extra_id_72>",
903
+ "<extra_id_73>",
904
+ "<extra_id_74>",
905
+ "<extra_id_75>",
906
+ "<extra_id_76>",
907
+ "<extra_id_77>",
908
+ "<extra_id_78>",
909
+ "<extra_id_79>",
910
+ "<extra_id_80>",
911
+ "<extra_id_81>",
912
+ "<extra_id_82>",
913
+ "<extra_id_83>",
914
+ "<extra_id_84>",
915
+ "<extra_id_85>",
916
+ "<extra_id_86>",
917
+ "<extra_id_87>",
918
+ "<extra_id_88>",
919
+ "<extra_id_89>",
920
+ "<extra_id_90>",
921
+ "<extra_id_91>",
922
+ "<extra_id_92>",
923
+ "<extra_id_93>",
924
+ "<extra_id_94>",
925
+ "<extra_id_95>",
926
+ "<extra_id_96>",
927
+ "<extra_id_97>",
928
+ "<extra_id_98>",
929
+ "<extra_id_99>"
930
+ ],
931
+ "clean_up_tokenization_spaces": true,
932
+ "eos_token": "</s>",
933
+ "extra_ids": 100,
934
+ "legacy": true,
935
+ "model_max_length": 512,
936
+ "pad_token": "<pad>",
937
+ "padding_side": "left",
938
+ "sp_model_kwargs": {},
939
+ "tokenizer_class": "T5Tokenizer",
940
+ "unk_token": "<unk>"
941
+ }