sanchit-gandhi committed
Commit 37a08da
1 Parent(s): 9f3a8be

Upload folder using huggingface_hub

checkpoint-1120-epoch-6/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:72090bb2fa144ea827c28f55280c9b5d4e50d3f5a89359fcc850079f35309b1e
+ size 3652763351
checkpoint-1120-epoch-6/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bce47e0f3b778b4ad7ddf6164868262be9682558bb083a0d53841a87279b3537
+ size 2588462170
checkpoint-1120-epoch-6/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0ca49fbae7eb9705c8b805b1ab5f1049040cacf26d75207ae76b7e964251141d
+ size 14344
checkpoint-1120-epoch-6/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e9dda5edaf0996237b47609e5ead50419f750cf63cc30b02d73ce7ca51ed5b0
+ size 1000
checkpoint-1280-epoch-7/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a22ddbb565705b39d5bb86e2494d49f115f58b308e1207e78b98dc9c8456295
+ size 3652763351
checkpoint-1280-epoch-7/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ee2430b336c101ba8a0bf42fb405360191213a0085e9e16d048026292650a0a0
+ size 2588462170
checkpoint-1280-epoch-7/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2858b9303b7791ffdd884017cf45488bdde5e7e57da3915a2f8642af64676749
+ size 14408
checkpoint-1280-epoch-7/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e667786cde5557ba042df8b54052d7a5d1fc0343a39be07ae4e658be38c0fd5b
+ size 1000
checkpoint-640-epoch-3/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7676577b9a00401b141b802905c27c0e36fbbfb39be9fe4ab5c416440d084b97
+ size 3652763351
checkpoint-640-epoch-3/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7dc0f7aee183a683eaf160c4cb24b324321cd016f5d56da009862f2436b65f3
+ size 2588462170
checkpoint-640-epoch-3/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:256df18c89f256f371cd05fd9c75d1352980db04c485621dcf7c689ad1f421d6
+ size 14408
checkpoint-640-epoch-3/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c6e13ca1e843d20bdad4e8b9e4c4b9a971b6c11e81ca88ea991c25a1a7bb4b6
+ size 1000
checkpoint-800-epoch-4/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f1be446c0dbd4b6c8512cfca027c114f7ad54b35c3d4076c76436c03f8c7e8b1
+ size 3652763351
checkpoint-800-epoch-4/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:241896d2452407ad54eeff936427883c44a2e230e23b765744eacfd779d02213
+ size 2588462170
checkpoint-800-epoch-4/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:62803c1cff666cd37d714e1f111d0780398deb4f04ea085c8ed15a8efecea7f5
+ size 14344
checkpoint-800-epoch-4/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1b9eb049f1562b6873fae29ad37c91ec8873f5a9f839a2af7ddc4c6a15b79b18
+ size 1000
checkpoint-960-epoch-5/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fd7754aeea66398a842117378c23e420c4b9bfd2ee4e1c15ef940d894a0b6e35
+ size 3652763351
checkpoint-960-epoch-5/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:52b3e76b7fde9e53c910cfcb44db88274cc60c24569f3f51786d34fd6b06e9f1
+ size 2588462170
checkpoint-960-epoch-5/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:43e8de9bf1af8ea2926a2a17dba99400420a9bbdbc76642417f1086a83a5227e
+ size 14344
checkpoint-960-epoch-5/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:efa6daa533a1a0406a7fb0ba0638b044472c0f8f17f62573f1959073e81617fe
+ size 1000
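
Each checkpoint-<step>-epoch-<epoch> folder above holds an intermediate Accelerate training state: model weights (pytorch_model.bin), optimizer state (optimizer.bin), LR scheduler state (scheduler.bin) and RNG states (random_states_0.pkl), saved once per epoch with save_total_limit 5 (see finetuning_config.json below). The sketch below simply mirrors the _RE_CHECKPOINT regex and get_last_checkpoint helper from run_parler_tts_training.py further down, showing how the most recent checkpoint is picked when training resumes; the folder path is a placeholder.

```python
import os
import re

# Mirrors get_last_checkpoint() in run_parler_tts_training.py: checkpoint folders are named
# "checkpoint-<global step>-epoch-<epoch>" and the highest global step wins.
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")


def get_last_checkpoint(folder):
    checkpoints = [
        path
        for path in os.listdir(folder)
        if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
    ]
    if not checkpoints:
        return None
    return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))


print(get_last_checkpoint("."))  # in this repo: ./checkpoint-1280-epoch-7
```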
config.json ADDED
@@ -0,0 +1,276 @@
+ {
+ "_name_or_path": "parler-tts/parler_tts_mini_v0.1",
+ "architectures": [
+ "ParlerTTSForConditionalGeneration"
+ ],
+ "audio_encoder": {
+ "_name_or_path": "ylacombe/dac_44khZ_8kbps",
+ "add_cross_attention": false,
+ "architectures": [
+ "DACModel"
+ ],
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": null,
+ "chunk_size_feed_forward": 0,
+ "codebook_size": 1024,
+ "cross_attention_hidden_size": null,
+ "decoder_start_token_id": null,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": null,
+ "exponential_decay_length_penalty": null,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "frame_rate": 86,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "latent_dim": 1024,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "min_length": 0,
+ "model_bitrate": 8,
+ "model_type": "dac",
+ "no_repeat_ngram_size": 0,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_codebooks": 9,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": null,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sampling_rate": 44100,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": "float32",
+ "torchscript": false,
+ "typical_p": 1.0,
+ "use_bfloat16": false
+ },
+ "decoder": {
+ "_name_or_path": "/fsx/yoach/tmp/artefacts/decoder_400M/",
+ "activation_dropout": 0.0,
+ "activation_function": "gelu",
+ "add_cross_attention": true,
+ "architectures": [
+ "ParlerTTSForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": 1025,
+ "chunk_size_feed_forward": 0,
+ "cross_attention_hidden_size": null,
+ "decoder_start_token_id": null,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout": 0.1,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": 1024,
+ "exponential_decay_length_penalty": null,
+ "ffn_dim": 4096,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "hidden_size": 1024,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "initializer_factor": 0.02,
+ "is_decoder": true,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layerdrop": 0.0,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "max_position_embeddings": 4096,
+ "min_length": 0,
+ "model_type": "parler_tts_decoder",
+ "no_repeat_ngram_size": 0,
+ "num_attention_heads": 16,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_codebooks": 9,
+ "num_hidden_layers": 24,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": 1024,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "scale_embedding": false,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": false,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": "float32",
+ "torchscript": false,
+ "typical_p": 1.0,
+ "use_bfloat16": false,
+ "use_cache": true,
+ "vocab_size": 1088
+ },
+ "decoder_start_token_id": 1025,
+ "is_encoder_decoder": true,
+ "model_type": "parler_tts",
+ "pad_token_id": 1024,
+ "text_encoder": {
+ "_name_or_path": "google/flan-t5-base",
+ "add_cross_attention": false,
+ "architectures": [
+ "T5ForConditionalGeneration"
+ ],
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": null,
+ "chunk_size_feed_forward": 0,
+ "classifier_dropout": 0.0,
+ "cross_attention_hidden_size": null,
+ "d_ff": 2048,
+ "d_kv": 64,
+ "d_model": 768,
+ "decoder_start_token_id": 0,
+ "dense_act_fn": "gelu_new",
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout_rate": 0.1,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": 1,
+ "exponential_decay_length_penalty": null,
+ "feed_forward_proj": "gated-gelu",
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "initializer_factor": 1.0,
+ "is_decoder": false,
+ "is_encoder_decoder": true,
+ "is_gated_act": true,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layer_norm_epsilon": 1e-06,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "min_length": 0,
+ "model_type": "t5",
+ "n_positions": 512,
+ "no_repeat_ngram_size": 0,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_decoder_layers": 12,
+ "num_heads": 12,
+ "num_layers": 12,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_past": true,
+ "output_scores": false,
+ "pad_token_id": 0,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "relative_attention_max_distance": 128,
+ "relative_attention_num_buckets": 32,
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "task_specific_params": {
+ "summarization": {
+ "early_stopping": true,
+ "length_penalty": 2.0,
+ "max_length": 200,
+ "min_length": 30,
+ "no_repeat_ngram_size": 3,
+ "num_beams": 4,
+ "prefix": "summarize: "
+ },
+ "translation_en_to_de": {
+ "early_stopping": true,
+ "max_length": 300,
+ "num_beams": 4,
+ "prefix": "translate English to German: "
+ },
+ "translation_en_to_fr": {
+ "early_stopping": true,
+ "max_length": 300,
+ "num_beams": 4,
+ "prefix": "translate English to French: "
+ },
+ "translation_en_to_ro": {
+ "early_stopping": true,
+ "max_length": 300,
+ "num_beams": 4,
+ "prefix": "translate English to Romanian: "
+ }
+ },
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": false,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "typical_p": 1.0,
+ "use_bfloat16": false,
+ "use_cache": true,
+ "vocab_size": 32128
+ },
+ "torch_dtype": "float32",
+ "transformers_version": "4.41.0.dev0",
+ "vocab_size": 32128
+ }
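
config.json wires together the three sub-models of the fine-tuned checkpoint: a google/flan-t5-base text encoder, a 24-layer Parler-TTS decoder generating 9 DAC codebooks, and the DAC audio codec running at 44.1 kHz (86 frames per second). A minimal sketch for inspecting these nested configs, assuming the parler_tts package imported by the training script below is installed and using a placeholder repo id:

```python
from parler_tts import ParlerTTSConfig

# Placeholder repo id - substitute the id under which this checkpoint is hosted.
config = ParlerTTSConfig.from_pretrained("sanchit-gandhi/<this-repo>")

print(config.text_encoder._name_or_path)   # google/flan-t5-base
print(config.decoder.num_hidden_layers)    # 24
print(config.decoder.num_codebooks)        # 9
print(config.audio_encoder.sampling_rate)  # 44100
print(config.audio_encoder.frame_rate)     # 86
```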
finetuning_config.json ADDED
@@ -0,0 +1,54 @@
+ {
+ "model_name_or_path": "parler-tts/parler_tts_mini_v0.1",
+ "feature_extractor_name": "parler-tts/dac_44khZ_8kbps",
+ "description_tokenizer_name": "parler-tts/parler_tts_mini_v0.1",
+ "prompt_tokenizer_name": "parler-tts/parler_tts_mini_v0.1",
+ "report_to": ["wandb"],
+ "overwrite_output_dir": true,
+ "train_dataset_name": "ylacombe/expresso+reach-vb/jenny_tts_dataset+sanchit-gandhi/libritts_r_test+sanchit-gandhi/libritts_r_test",
+ "train_metadata_dataset_name": "reach-vb/expresso-tagged-w-speech-mistral-v3+ylacombe/jenny-tts-10k-tagged+parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/libritts_r_tags_tagged_10k_generated",
+ "train_dataset_config_name": "read+default+clean+other",
+ "train_split_name": "train[:-25]+train[:20%]+test.clean+test.other[:-31]",
+ "eval_dataset_name": "ylacombe/expresso+reach-vb/jenny_tts_dataset+sanchit-gandhi/libritts_r_test+sanchit-gandhi/libritts_r_test",
+ "eval_metadata_dataset_name": "reach-vb/expresso-tagged-w-speech-mistral-v3+ylacombe/jenny-tts-10k-tagged+parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/libritts_r_tags_tagged_10k_generated",
+ "eval_dataset_config_name": "read+default+clean+other",
+ "eval_split_name": "train+train[:20%]+test.clean+test.other",
+ "max_eval_samples": 8,
+ "per_device_eval_batch_size": 16,
+ "target_audio_column_name": "audio",
+ "description_column_name": "text_description",
+ "prompt_column_name": "text",
+ "max_duration_in_seconds": 30.0,
+ "min_duration_in_seconds": 2.0,
+ "max_text_length": 400,
+ "preprocessing_num_workers": 2,
+ "do_train": true,
+ "num_train_epochs": 8,
+ "max_steps": -1,
+ "gradient_accumulation_steps": 8,
+ "gradient_checkpointing": true,
+ "per_device_train_batch_size": 16,
+ "learning_rate": 0.00008,
+ "adam_beta1": 0.9,
+ "adam_beta2": 0.99,
+ "weight_decay": 0.01,
+ "lr_scheduler_type": "cosine",
+ "warmup_steps": 250,
+ "logging_steps": 5,
+ "freeze_text_encoder": true,
+ "audio_encoder_per_device_batch_size": 4,
+ "dtype": "bfloat16",
+ "seed": 456,
+ "output_dir": "./",
+ "temporary_save_to_disk": "../audio_code_tmp_constant/",
+ "save_to_disk": "../tmp_dataset_audio_constant/",
+ "dataloader_num_workers": 4,
+ "do_eval": true,
+ "predict_with_generate": true,
+ "include_inputs_for_metrics": true,
+ "save_strategy": "epoch",
+ "evaluation_strategy": "epoch",
+ "save_total_limit": 5,
+ "group_by_length": true
+ }
+
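
In finetuning_config.json, the train_dataset_name, train_dataset_config_name, train_split_name and train_metadata_dataset_name fields each hold a '+'-separated list; the training script's convert_dataset_str_to_list splits them and pairs the entries position by position, so the i-th dataset is loaded with the i-th config, split and metadata dataset. A simplified sketch of that pairing:

```python
# Simplified from convert_dataset_str_to_list() in run_parler_tts_training.py.
def pair_datasets(names: str, configs: str, splits: str, metadata: str):
    names, configs = names.split("+"), configs.split("+")
    splits, metadata = splits.split("+"), metadata.split("+")
    assert len(names) == len(configs) == len(splits) == len(metadata)
    return [
        {"name": n, "config": c, "split": s, "metadata_dataset_name": m}
        for n, c, s, m in zip(names, configs, splits, metadata)
    ]


pairs = pair_datasets(
    "ylacombe/expresso+reach-vb/jenny_tts_dataset",
    "read+default",
    "train[:-25]+train[:20%]",
    "reach-vb/expresso-tagged-w-speech-mistral-v3+ylacombe/jenny-tts-10k-tagged",
)
print(pairs[0])
# {'name': 'ylacombe/expresso', 'config': 'read', 'split': 'train[:-25]',
#  'metadata_dataset_name': 'reach-vb/expresso-tagged-w-speech-mistral-v3'}
```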
generation_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1025,
+ "decoder_start_token_id": 1025,
+ "do_sample": true,
+ "eos_token_id": 1024,
+ "guidance_scale": 1.0,
+ "max_length": 2580,
+ "min_new_tokens": 50,
+ "pad_token_id": 1024,
+ "transformers_version": "4.41.0.dev0"
+ }
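
generation_config.json carries the defaults used at inference and for the predict_with_generate evaluation: sampling is enabled, at least 50 new audio frames are produced, and generation stops after 2580 decoder steps (about 30 seconds of audio at the DAC frame rate of 86 Hz). A minimal inference sketch, assuming the parler_tts package is installed; the repo id is a placeholder:

```python
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

repo_id = "sanchit-gandhi/<this-repo>"  # placeholder
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

description = "A female speaker with a slightly expressive tone and very clear audio."
prompt = "Hey, how are you doing today?"

input_ids = tokenizer(description, return_tensors="pt").input_ids
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# do_sample, max_length and min_new_tokens are picked up from generation_config.json.
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio, model.config.audio_encoder.sampling_rate)
```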
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:45937d8dfa9f232e79f76f536602301dde6f2b5be7200414d69ff07bff1228c3
+ size 2588215392
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+ "chunk_length_s": null,
+ "feature_extractor_type": "EncodecFeatureExtractor",
+ "feature_size": 1,
+ "overlap": null,
+ "padding_side": "right",
+ "padding_value": 0.0,
+ "return_attention_mask": true,
+ "sampling_rate": 44100
+ }
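
preprocessor_config.json describes the feature extractor used to batch target audio for the DAC codec during pre-processing: an EncodecFeatureExtractor expecting mono waveforms at 44.1 kHz, padding on the right and returning a mask for the padded samples. A small sketch along the lines of DataCollatorEncodecWithPadding in the training script; the repo id is again a placeholder:

```python
import numpy as np
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/<this-repo>")  # placeholder

# Two dummy mono waveforms of different lengths, already at the expected 44.1 kHz rate.
audios = [np.zeros(44100, dtype=np.float32), np.zeros(22050, dtype=np.float32)]
batch = feature_extractor(
    audios,
    sampling_rate=feature_extractor.sampling_rate,  # 44100
    padding="longest",
    return_tensors="pt",
)
# input_values padded to the longest clip, plus a mask marking real vs. padded samples.
print({name: tuple(tensor.shape) for name, tensor in batch.items()})
```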
run_parler_tts_training.py ADDED
@@ -0,0 +1,1763 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """ Train Parler-TTS using 🤗 Accelerate"""
18
+
19
+ import logging
20
+ import os
21
+ import re
22
+ import shutil
23
+ import sys
24
+ import time
25
+ from dataclasses import dataclass, field
26
+ from datetime import timedelta
27
+ from pathlib import Path
28
+ from typing import Dict, List, Optional, Set, Union
29
+
30
+ import datasets
31
+ import evaluate
32
+ import numpy as np
33
+ import torch
34
+ import transformers
35
+ from accelerate import Accelerator
36
+ from accelerate.utils import AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, set_seed
37
+ from accelerate.utils.memory import release_memory
38
+ from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
39
+ from huggingface_hub import Repository, create_repo
40
+ from multiprocess import set_start_method
41
+ from torch.utils.data import DataLoader
42
+ from tqdm import tqdm
43
+ from transformers import (
44
+ AutoFeatureExtractor,
45
+ AutoModel,
46
+ AutoProcessor,
47
+ AutoTokenizer,
48
+ HfArgumentParser,
49
+ Seq2SeqTrainingArguments,
50
+ pipeline,
51
+ )
52
+ from transformers.optimization import get_scheduler
53
+ from transformers.trainer_pt_utils import LengthGroupedSampler
54
+ from transformers.utils import send_example_telemetry
55
+ from wandb import Audio
56
+
57
+ from parler_tts import (
58
+ ParlerTTSConfig,
59
+ ParlerTTSForConditionalGeneration,
60
+ build_delay_pattern_mask,
61
+ )
62
+
63
+
64
+ logger = logging.getLogger(__name__)
65
+
66
+
67
+ def list_field(default=None, metadata=None):
68
+ return field(default_factory=lambda: default, metadata=metadata)
69
+
70
+
71
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
72
+
73
+
74
+ def get_last_checkpoint(folder):
75
+ content = os.listdir(folder)
76
+ checkpoints = [
77
+ path
78
+ for path in content
79
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
80
+ ]
81
+ if len(checkpoints) == 0:
82
+ return
83
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
84
+
85
+
86
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
87
+ """Helper function to sort saved checkpoints from oldest to newest."""
88
+ ordering_and_checkpoint_path = []
89
+
90
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
91
+
92
+ for path in glob_checkpoints:
93
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
94
+ if regex_match is not None and regex_match.groups() is not None:
95
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
96
+
97
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
98
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
99
+ return checkpoints_sorted
100
+
101
+
102
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
103
+ """Helper function to delete old checkpoints."""
104
+ if save_total_limit is None or save_total_limit <= 0:
105
+ return
106
+ # Check if we should delete older checkpoint(s)
107
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
108
+ if len(checkpoints_sorted) <= save_total_limit:
109
+ return
110
+
111
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
112
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
113
+ for checkpoint in checkpoints_to_be_deleted:
114
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
115
+ shutil.rmtree(checkpoint, ignore_errors=True)
116
+
117
+
118
+ def log_metric(
119
+ accelerator,
120
+ metrics: Dict,
121
+ train_time: float,
122
+ step: int,
123
+ epoch: int,
124
+ learning_rate: float = None,
125
+ prefix: str = "train",
126
+ ):
127
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
128
+ log_metrics = {}
129
+ for k, v in metrics.items():
130
+ log_metrics[f"{prefix}/{k}"] = v
131
+ log_metrics[f"{prefix}/time"] = train_time
132
+ log_metrics[f"{prefix}/epoch"] = epoch
133
+ if learning_rate is not None:
134
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
135
+ accelerator.log(log_metrics, step=step)
136
+
137
+
138
+ def log_pred(
139
+ accelerator,
140
+ pred_descriptions: List[str],
141
+ pred_prompts: List[str],
142
+ transcriptions: List[str],
143
+ audios: List[torch.Tensor],
144
+ sampling_rate: int,
145
+ step: int,
146
+ prefix: str = "eval",
147
+ num_lines: int = 200000,
148
+ ):
149
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
150
+ if accelerator.is_main_process:
151
+ wandb_tracker = accelerator.get_tracker("wandb")
152
+ # pretty name for current step: step 50000 -> step 50k
153
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
154
+ prefix_pretty = prefix.replace("/", "-")
155
+
156
+ # convert str data to a wandb compatible format
157
+ str_data = [[pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))]
158
+ # log as a table with the appropriate headers
159
+ wandb_tracker.log_table(
160
+ table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
161
+ columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
162
+ data=str_data[:num_lines],
163
+ step=step,
164
+ commit=False,
165
+ )
166
+
167
+ # wandb can only load 100 audios per step
168
+ wandb_tracker.log(
169
+ {
170
+ f"Speech samples/{prefix}": [
171
+ Audio(
172
+ audio,
173
+ caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
174
+ sample_rate=sampling_rate,
175
+ )
176
+ for (i, audio) in enumerate(audios[: min(len(audios), 100)])
177
+ ]
178
+ },
179
+ step=step,
180
+ )
181
+
182
+
183
+ @dataclass
184
+ class ModelArguments:
185
+ """
186
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
187
+ """
188
+
189
+ model_name_or_path: str = field(
190
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
191
+ )
192
+ config_name: Optional[str] = field(
193
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
194
+ )
195
+ feature_extractor_name: Optional[str] = field(
196
+ default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
197
+ )
198
+ description_tokenizer_name: Optional[str] = field(
199
+ default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
200
+ )
201
+ prompt_tokenizer_name: Optional[str] = field(
202
+ default=None,
203
+ metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
204
+ )
205
+ cache_dir: Optional[str] = field(
206
+ default=None,
207
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
208
+ )
209
+ use_fast_tokenizer: bool = field(
210
+ default=True,
211
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
212
+ )
213
+ model_revision: str = field(
214
+ default="main",
215
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
216
+ )
217
+ pad_token_id: int = field(
218
+ default=None,
219
+ metadata={"help": "If specified, change the model pad token id."},
220
+ )
221
+ decoder_start_token_id: int = field(
222
+ default=None,
223
+ metadata={"help": "If specified, change the model decoder start token id."},
224
+ )
225
+ freeze_text_encoder: bool = field(
226
+ default=False,
227
+ metadata={"help": "Whether to freeze the text encoder."},
228
+ )
229
+ do_sample: bool = field(
230
+ default=True,
231
+ metadata={"help": "Whether to do sampling or greedy decoding."},
232
+ )
233
+ temperature: float = field(
234
+ default=1.0,
235
+ metadata={"help": "Temperature if sampling."},
236
+ )
237
+ max_length: int = field(
238
+ default=2580,
239
+ metadata={"help": "Generation max length."},
240
+ )
241
+ bandwidth: float = field(
242
+ default=6,
243
+ metadata={"help": "Audio encoder bandwidth."},
244
+ )
245
+ asr_model_name_or_path: str = field(
246
+ default="distil-whisper/distil-large-v2",
247
+ metadata={
248
+ "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
249
+ },
250
+ )
251
+ clap_model_name_or_path: str = field(
252
+ default="laion/larger_clap_music_and_speech",
253
+ metadata={
254
+ "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
255
+ },
256
+ )
257
+
258
+
259
+ @dataclass
260
+ class DataTrainingArguments:
261
+ """
262
+ Arguments pertaining to what data we are going to input our model for training and eval.
263
+
264
+ Using `HfArgumentParser` we can turn this class
265
+ into argparse arguments to be able to specify them on
266
+ the command line.
267
+ """
268
+
269
+ train_dataset_name: str = field(
270
+ default=None,
271
+ metadata={
272
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
273
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
274
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
275
+ },
276
+ )
277
+ train_dataset_config_name: Optional[str] = field(
278
+ default=None,
279
+ metadata={
280
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
281
+ "multiple datasets by separating dataset configs by a '+' symbol."
282
+ },
283
+ )
284
+ train_split_name: str = field(
285
+ default="train",
286
+ metadata={
287
+ "help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
288
+ },
289
+ )
290
+ train_dataset_samples: str = field(
291
+ default=None,
292
+ metadata={
293
+ "help": "Number of samples in the training data. Load and combine "
294
+ "multiple datasets by separating dataset samples by a '+' symbol."
295
+ },
296
+ )
297
+ train_metadata_dataset_name: str = field(
298
+ default=None,
299
+ metadata={
300
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
301
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
302
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
303
+ },
304
+ )
305
+ eval_dataset_name: str = field(
306
+ default=None,
307
+ metadata={
308
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
309
+ },
310
+ )
311
+ eval_dataset_config_name: Optional[str] = field(
312
+ default=None,
313
+ metadata={
314
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified"
315
+ },
316
+ )
317
+ eval_split_name: str = field(
318
+ default="test",
319
+ metadata={
320
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
321
+ },
322
+ )
323
+ eval_metadata_dataset_name: str = field(
324
+ default=None,
325
+ metadata={
326
+ "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
327
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
328
+ " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
329
+ },
330
+ )
331
+ target_audio_column_name: str = field(
332
+ default="audio",
333
+ metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
334
+ )
335
+ description_column_name: str = field(
336
+ default=None,
337
+ metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
338
+ )
339
+ prompt_column_name: str = field(
340
+ default=None,
341
+ metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
342
+ )
343
+ overwrite_cache: bool = field(
344
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
345
+ )
346
+ preprocessing_num_workers: Optional[int] = field(
347
+ default=None,
348
+ metadata={"help": "The number of processes to use for the preprocessing."},
349
+ )
350
+ max_train_samples: Optional[int] = field(
351
+ default=None,
352
+ metadata={
353
+ "help": (
354
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
355
+ "value if set."
356
+ )
357
+ },
358
+ )
359
+ max_eval_samples: Optional[int] = field(
360
+ default=None,
361
+ metadata={
362
+ "help": (
363
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
364
+ "value if set."
365
+ )
366
+ },
367
+ )
368
+ max_duration_in_seconds: float = field(
369
+ default=35.0,
370
+ metadata={
371
+ "help": (
372
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`."
373
+ "Also, used to set maximum audio length if `pad_to_max_length=True`."
374
+ )
375
+ },
376
+ )
377
+ min_duration_in_seconds: float = field(
378
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
379
+ )
380
+ max_text_length: int = field(
381
+ default=500, metadata={"help": "If set, max description lengths in number of characters."}
382
+ )
383
+ max_prompt_token_length: int = field(
384
+ default=None,
385
+ metadata={
386
+ "help": (
387
+ "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
388
+ "Also, used to set maximum prompt token length if `pad_to_max_length=True`."
389
+ )
390
+ },
391
+ )
392
+ max_description_token_length: int = field(
393
+ default=None,
394
+ metadata={
395
+ "help": (
396
+ "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
397
+ "Also, used to set maximum desription token length if `pad_to_max_length=True`."
398
+ )
399
+ },
400
+ )
401
+ pad_to_max_length: bool = field(
402
+ default=False,
403
+ metadata={
404
+ "help": (
405
+ "If `True`, pad audio, prompt and description to a maximum length set with respectively "
406
+ "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
407
+ )
408
+ },
409
+ )
410
+ preprocessing_only: bool = field(
411
+ default=False,
412
+ metadata={
413
+ "help": (
414
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
415
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
416
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
417
+ " can consequently be loaded in distributed training."
418
+ " In this training script, `save_to_disk` must be set to the path in which the dataset should be saved. "
419
+ )
420
+ },
421
+ )
422
+ token: str = field(
423
+ default=None,
424
+ metadata={
425
+ "help": (
426
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
427
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
428
+ )
429
+ },
430
+ )
431
+ use_auth_token: bool = field(
432
+ default=None,
433
+ metadata={
434
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
435
+ },
436
+ )
437
+ trust_remote_code: bool = field(
438
+ default=False,
439
+ metadata={
440
+ "help": (
441
+ "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
442
+ "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
443
+ "execute code present on the Hub on your local machine."
444
+ )
445
+ },
446
+ )
447
+ add_audio_samples_to_wandb: bool = field(
448
+ default=False,
449
+ metadata={"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."},
450
+ )
451
+ id_column_name: str = field(default=None, metadata={"help": "id column name."})
452
+ wandb_project: str = field(
453
+ default="parler-speech",
454
+ metadata={"help": "The name of the wandb project."},
455
+ )
456
+ save_to_disk: str = field(
457
+ default=None,
458
+ metadata={
459
+ "help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
460
+ },
461
+ )
462
+ temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
463
+ pad_to_multiple_of: Optional[int] = field(
464
+ default=2,
465
+ metadata={"help": ("Pad to multiple of for tokenizers.")},
466
+ )
467
+
468
+
469
+ @dataclass
470
+ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
471
+ dtype: Optional[str] = field(
472
+ default="float32",
473
+ metadata={
474
+ "help": (
475
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
476
+ "`float16` or `bfloat16` (both half-precision)."
477
+ )
478
+ },
479
+ )
480
+ audio_encoder_per_device_batch_size: int = field(
481
+ default=8,
482
+ metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
483
+ )
484
+
485
+
486
+ @dataclass
487
+ class DataCollatorEncodecWithPadding:
488
+ """
489
+ Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
490
+ to `max_length` if `max_length` is set and `padding=max_length`.
491
+ """
492
+
493
+ feature_extractor: AutoFeatureExtractor
494
+ audio_column_name: str
495
+ feature_extractor_input_name: Optional[str] = "input_values"
496
+ max_length: Optional[int] = None
497
+ padding: Optional[str] = "longest"
498
+
499
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
500
+ # split inputs and labels since they have to be of different lengths and need
501
+ # different padding methods
502
+ audios = [feature[self.audio_column_name]["array"] for feature in features]
503
+ len_audio = [len(audio) for audio in audios]
504
+
505
+ batch = self.feature_extractor(audios, return_tensors="pt", padding=self.padding, max_length=self.max_length, sampling_rate=self.feature_extractor.sampling_rate)
506
+ batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
507
+ return batch
508
+
509
+
510
+ @dataclass
511
+ class DataCollatorParlerTTSWithPadding:
512
+ """
513
+ Data collator that will dynamically pad the inputs received.
514
+ Args:
515
+ prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
516
+ The prompt_tokenizer used for processing the data.
517
+ description_tokenizer (:class:`~transformers.AutoTokenizer`)
518
+ The description_tokenizer used for processing the data.
519
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
520
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
521
+ among:
522
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
523
+ sequence is provided).
524
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
525
+ maximum acceptable input length for the model if that argument is not provided.
526
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
527
+ different lengths).
528
+ pad_to_multiple_of (:obj:`int`, `optional`):
529
+ If set will pad the sequence to a multiple of the provided value.
530
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
531
+ 7.5 (Volta).
532
+ """
533
+
534
+ prompt_tokenizer: AutoTokenizer
535
+ description_tokenizer: AutoTokenizer
536
+ padding: Union[bool, str] = "longest"
537
+ pad_to_multiple_of: Optional[int] = None
538
+ prompt_max_length: Optional[int] = None
539
+ description_max_length: Optional[int] = None
540
+ audio_max_length: Optional[int] = None
541
+
542
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
543
+ # split inputs and labels since they have to be of different lengths and need
544
+ # different padding methods
545
+
546
+ labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
547
+ # (bsz, seq_len, num_codebooks)
548
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
549
+ if self.audio_max_length is not None and self.padding == "max_length":
550
+ labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)))
551
+
552
+ input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
553
+
554
+ input_ids = self.description_tokenizer.pad(
555
+ input_ids,
556
+ return_tensors="pt",
557
+ padding=self.padding,
558
+ pad_to_multiple_of=self.pad_to_multiple_of,
559
+ max_length=self.description_max_length,
560
+ )
561
+
562
+ batch = {"labels": labels, **input_ids}
563
+
564
+ if self.audio_max_length is not None and self.padding == "max_length":
565
+ # if we do torch.compile, we need to also specify the attention_mask
566
+ decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
567
+ batch["decoder_attention_mask"] = decoder_attention_mask
568
+
569
+ prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
570
+ prompt_input_ids = self.prompt_tokenizer.pad(
571
+ prompt_input_ids,
572
+ return_tensors="pt",
573
+ padding=self.padding,
574
+ pad_to_multiple_of=self.pad_to_multiple_of,
575
+ max_length=self.prompt_max_length,
576
+ )
577
+
578
+ batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
579
+ if "attention_mask" in prompt_input_ids:
580
+ batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
581
+
582
+ return batch
583
+
584
+
585
+ def convert_dataset_str_to_list(
586
+ dataset_names,
587
+ dataset_config_names,
588
+ metadata_dataset_names=None,
589
+ splits=None,
590
+ dataset_samples=None,
591
+ default_split="train",
592
+ ):
593
+ if isinstance(dataset_names, str):
594
+ dataset_names = dataset_names.split("+")
595
+ dataset_config_names = dataset_config_names.split("+")
596
+ splits = splits.split("+") if splits is not None else None
597
+ dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
598
+ metadata_dataset_names = metadata_dataset_names.split("+") if metadata_dataset_names is not None else None
599
+
600
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
601
+ if len(dataset_names) != len(dataset_config_names):
602
+ raise ValueError(
603
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
604
+ f" {len(dataset_config_names)} configs."
605
+ )
606
+
607
+ if splits is not None and len(splits) != len(dataset_names):
608
+ raise ValueError(
609
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
610
+ )
611
+
612
+ if metadata_dataset_names is not None and len(metadata_dataset_names) != len(dataset_names):
613
+ raise ValueError(
614
+ f"Ensure one metadata dataset is passed for each dataset, got {len(dataset_names)} datasets and {len(metadata_dataset_names)} metadata datasets."
615
+ )
616
+
617
+ if dataset_samples is not None:
618
+ if len(dataset_samples) != len(dataset_names):
619
+ raise ValueError(
620
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
621
+ f"{len(dataset_samples)} samples."
622
+ )
623
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
624
+ else:
625
+ dataset_samples = [None] * len(dataset_names)
626
+
627
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
628
+
629
+ dataset_names_dict = []
630
+ for i, ds_name in enumerate(dataset_names):
631
+ dataset_names_dict.append(
632
+ {
633
+ "name": ds_name,
634
+ "config": dataset_config_names[i],
635
+ "split": splits[i],
636
+ "metadata_dataset_name": metadata_dataset_names[i],
637
+ "samples": dataset_samples[i],
638
+ }
639
+ )
640
+ return dataset_names_dict
641
+
642
+
643
+ def load_multiple_datasets(
644
+ accelerator: Accelerator,
645
+ dataset_names: Union[List, str],
646
+ dataset_config_names: Union[List, str],
647
+ metadata_dataset_names: Optional[str] = None,
648
+ splits: Optional[Union[List, str]] = None,
649
+ label_column_names: Optional[List] = None,
650
+ stopping_strategy: Optional[str] = "first_exhausted",
651
+ dataset_samples: Optional[Union[List, np.array]] = None,
652
+ streaming: Optional[bool] = False,
653
+ seed: Optional[int] = None,
654
+ id_column_name: Optional[str] = None,
655
+ columns_to_keep: Optional[Set[str]] = None,
656
+ prompt_column_name: Optional[str] = None,
657
+ sampling_rate: Optional[int] = None,
658
+ audio_column_name: Optional[str] = None,
659
+ **kwargs,
660
+ ) -> Union[Dataset, IterableDataset]:
661
+ dataset_names_dict = convert_dataset_str_to_list(
662
+ dataset_names, dataset_config_names, metadata_dataset_names, splits, label_column_names, dataset_samples
663
+ )
664
+
665
+ if dataset_samples is not None:
666
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
667
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
668
+ else:
669
+ probabilities = None
670
+
671
+ all_datasets = []
672
+ # iterate over the datasets we want to interleave
673
+ for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
674
+ with accelerator.main_process_first():
675
+ dataset = load_dataset(
676
+ dataset_dict["name"],
677
+ dataset_dict["config"],
678
+ split=dataset_dict["split"],
679
+ streaming=streaming,
680
+ **kwargs,
681
+ )
682
+ dataset_features = dataset.features.keys()
683
+
684
+ if sampling_rate is not None and audio_column_name is not None:
685
+ # resample target audio
686
+ dataset = dataset.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))
687
+
688
+ metadata_dataset_name = dataset_dict["metadata_dataset_name"]
689
+ if metadata_dataset_name is not None:
690
+ logger.info(
691
+ f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}'
692
+ )
693
+ metadata_dataset = load_dataset(
694
+ metadata_dataset_name,
695
+ dataset_dict["config"],
696
+ split=dataset_dict["split"],
697
+ streaming=streaming,
698
+ **kwargs,
699
+ )
700
+
701
+ # TODO(YL): I forgot to create unique ids for MLS english.
702
+ # To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
703
+ # if dataset_dict["name"] == "parler-tts/mls_eng_10k":
704
+ # def concat_ids(book_id, speaker_id, begin_time):
705
+ # return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
706
+ # dataset = dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
707
+ # metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
708
+ # metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
709
+
710
+ if dataset_dict["name"] != "parler-tts/mls_eng_10k":
711
+ if id_column_name is not None and id_column_name not in dataset.column_names:
712
+ raise ValueError(
713
+ f"id_column_name={id_column_name} but has not been found in the dataset columns"
714
+ f"- one of {', '.join(list(dataset.column_names))}."
715
+ )
716
+ if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
717
+ raise ValueError(
718
+ f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
719
+ f"- one of {', '.join(list(metadata_dataset.column_names))}."
720
+ )
721
+ elif id_column_name is not None:
722
+ metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
723
+
724
+ metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
725
+
726
+ if prompt_column_name is not None:
727
+ # We might have applied some transformations to the prompts (e.g punctuation restoration)
728
+ # so we make sure to remove it from the original dataset
729
+ if prompt_column_name in dataset.column_names:
730
+ logger.info(
731
+ f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']"
732
+ )
733
+ dataset.remove_columns(prompt_column_name)
734
+
735
+ metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
736
+ metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
737
+
738
+ dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
739
+
740
+ if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
741
+ if (
742
+ len(
743
+ dataset.filter(
744
+ lambda id1, id2: id1 != id2,
745
+ input_columns=[id_column_name, f"metadata_{id_column_name}"],
746
+ )
747
+ )
748
+ != 0
749
+ ):
750
+ raise ValueError(
751
+ f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}"
752
+ )
753
+
754
+ dataset_features = dataset.features.keys()
755
+
756
+ if columns_to_keep is not None:
757
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
758
+ all_datasets.append(dataset)
759
+
760
+ if len(all_datasets) == 1:
761
+ # we have a single dataset so just return it as is
762
+ return all_datasets[0]
763
+
764
+ if streaming:
765
+ interleaved_dataset = interleave_datasets(
766
+ all_datasets,
767
+ stopping_strategy=stopping_strategy,
768
+ probabilities=probabilities,
769
+ seed=seed,
770
+ )
771
+ else:
772
+ with accelerator.main_process_first():
773
+ interleaved_dataset = concatenate_datasets(all_datasets)
774
+
775
+ return interleaved_dataset
776
+
777
+
778
+ def main():
779
+ # See all possible arguments in src/transformers/training_args.py
780
+ # or by passing the --help flag to this script.
781
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
782
+
783
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
784
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
785
+ # If we pass only one argument to the script and it's the path to a json file,
786
+ # let's parse it to get our arguments.
787
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
788
+ else:
789
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
790
+
791
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
792
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
793
+ send_example_telemetry("run_parler_tts", model_args, data_args)
794
+
795
+ if training_args.dtype == "float16":
796
+ mixed_precision = "fp16"
797
+ elif training_args.dtype == "bfloat16":
798
+ mixed_precision = "bf16"
799
+ else:
800
+ mixed_precision = "no"
801
+
802
+ if data_args.pad_to_max_length and (
803
+ data_args.max_duration_in_seconds is None
804
+ or data_args.max_prompt_token_length is None
805
+ or data_args.max_description_token_length is None
806
+ ):
807
+ raise ValueError(
808
+ "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
809
+ )
810
+
811
+ padding = "max_length" if data_args.pad_to_max_length else "longest"
812
+
813
+ ####### A. Preparation
814
+ kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
815
+ if training_args.torch_compile:
816
+ # TODO(YL): add more compile modes?
817
+ kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) # reduce-overhead
818
+
819
+ accelerator = Accelerator(
820
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
821
+ mixed_precision=mixed_precision,
822
+ log_with=training_args.report_to,
823
+ project_dir=training_args.output_dir,
824
+ kwargs_handlers=kwargs_handlers,
825
+ )
826
+
827
+ accelerator.init_trackers(
828
+ project_name=data_args.wandb_project,
829
+ config={
830
+ "learning_rate": training_args.learning_rate,
831
+ "model_name_or_path": model_args.model_name_or_path,
832
+ "num_train_epochs": training_args.num_train_epochs,
833
+ "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
834
+ "per_device_train_batch_size": training_args.per_device_train_batch_size,
835
+ "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
836
+ "mixed_precision": mixed_precision,
837
+ "lr_scheduler_type": training_args.lr_scheduler_type,
838
+ "warmup_steps": training_args.warmup_steps,
839
+ "freeze_text_encoder": model_args.freeze_text_encoder,
840
+ "max_duration_in_seconds": data_args.max_duration_in_seconds,
841
+ "weight_decay": training_args.weight_decay,
842
+ "adam_beta1": training_args.adam_beta1,
843
+ "adam_beta2": training_args.adam_beta2,
844
+ "temperature": model_args.temperature,
845
+ },
846
+ )
847
+
848
+ # Detect the last checkpoint and, if present, resume from it
849
+ last_checkpoint = None
850
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
851
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
852
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
853
+ raise ValueError(
854
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
855
+ "Use --overwrite_output_dir to overcome."
856
+ )
857
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
858
+ logger.info(
859
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
860
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
861
+ )
862
+
863
+ # Setup logging
864
+ logging.basicConfig(
865
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
866
+ datefmt="%m/%d/%Y %H:%M:%S",
867
+ handlers=[logging.StreamHandler(sys.stdout)],
868
+ )
869
+ logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
870
+
871
+ # Log a small summary on each process
872
+ logger.warning(
873
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
874
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
875
+ )
876
+
877
+ # Set the verbosity to info of the Transformers logger (on main process only)
878
+ if accelerator.is_local_main_process:
879
+ datasets.utils.logging.set_verbosity_warning()
880
+ transformers.utils.logging.set_verbosity_info()
881
+ else:
882
+ datasets.utils.logging.set_verbosity_error()
883
+ transformers.utils.logging.set_verbosity_error()
884
+
885
+ logger.info("Training/evaluation parameters %s", training_args)
886
+
887
+ # Set seed before initializing model.
888
+ set_seed(training_args.seed)
889
+ num_workers = data_args.preprocessing_num_workers
890
+
891
+ # 1. First, let's instantiate the feature extractor, tokenizers and model
892
+ # Note for distributed training, the .from_pretrained methods guarantee that only
893
+ # one local process can concurrently download model & vocab.
894
+
895
+ # load feature extractor
896
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
897
+ model_args.feature_extractor_name or model_args.model_name_or_path,
898
+ cache_dir=model_args.cache_dir,
899
+ token=data_args.token,
900
+ trust_remote_code=data_args.trust_remote_code,
901
+ )
902
+ sampling_rate = feature_extractor.sampling_rate
903
+
904
+ # load prompt tokenizer
905
+ prompt_tokenizer = AutoTokenizer.from_pretrained(
906
+ model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
907
+ cache_dir=model_args.cache_dir,
908
+ token=data_args.token,
909
+ trust_remote_code=data_args.trust_remote_code,
910
+ use_fast=model_args.use_fast_tokenizer,
911
+ padding_side="left", # prompt has to be padded on the left bc it's preprend to codebooks hidden states
912
+ )
913
+
914
+ # load description tokenizer
915
+ description_tokenizer = AutoTokenizer.from_pretrained(
916
+ model_args.description_tokenizer_name or model_args.model_name_or_path,
917
+ cache_dir=model_args.cache_dir,
918
+ token=data_args.token,
919
+ trust_remote_code=data_args.trust_remote_code,
920
+ use_fast=model_args.use_fast_tokenizer,
921
+ )
922
+
923
+ if model_args.use_fast_tokenizer:
924
+ logger.warning(
925
+ "Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
926
+ )
927
+ prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
928
+ description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
929
+
930
+ # 2. Now, let's load the dataset
931
+
932
+ if data_args.save_to_disk is not None:
933
+ os.makedirs(data_args.save_to_disk, exist_ok=True)
934
+
935
+ # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
936
+ dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
937
+ if dataset_was_precomputed:
938
+ vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
939
+ else:
940
+ raw_datasets = DatasetDict()
941
+
942
+ columns_to_keep = {
943
+ "target_audio_column_name": data_args.target_audio_column_name,
944
+ "prompt_column_name": data_args.prompt_column_name,
945
+ }
946
+ if data_args.description_column_name is not None:
947
+ columns_to_keep["description_column_name"] = data_args.description_column_name
948
+
949
+ if training_args.do_train:
950
+ raw_datasets["train"] = load_multiple_datasets(
951
+ accelerator,
952
+ data_args.train_dataset_name,
953
+ data_args.train_dataset_config_name,
954
+ metadata_dataset_names=data_args.train_metadata_dataset_name,
955
+ splits=data_args.train_split_name,
956
+ dataset_samples=data_args.train_dataset_samples,
957
+ seed=training_args.seed,
958
+ cache_dir=model_args.cache_dir,
959
+ num_proc=data_args.preprocessing_num_workers,
960
+ id_column_name=data_args.id_column_name,
961
+ columns_to_keep=columns_to_keep.values(),
962
+ prompt_column_name=data_args.prompt_column_name,
963
+ audio_column_name=data_args.target_audio_column_name,
964
+ sampling_rate=sampling_rate,
965
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
966
+ )
967
+
968
+ for key in columns_to_keep:
969
+ if columns_to_keep[key] not in raw_datasets["train"].column_names:
970
+ raise ValueError(
971
+ f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
972
+ f" Make sure to set `--{key}` to the correct audio column - one of"
973
+ f" {', '.join(raw_datasets['train'].column_names)}."
974
+ )
975
+
976
+ if data_args.max_train_samples is not None:
977
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
978
+
979
+ if training_args.do_eval:
980
+ raw_datasets["eval"] = load_multiple_datasets(
981
+ accelerator,
982
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
983
+ data_args.eval_dataset_config_name
984
+ if data_args.eval_dataset_config_name
985
+ else data_args.train_dataset_config_name,
986
+ metadata_dataset_names=data_args.eval_metadata_dataset_name,
987
+ splits=data_args.eval_split_name,
988
+ cache_dir=model_args.cache_dir,
989
+ num_proc=data_args.preprocessing_num_workers,
990
+ id_column_name=data_args.id_column_name,
991
+ columns_to_keep=columns_to_keep.values(),
992
+ prompt_column_name=data_args.prompt_column_name,
993
+ audio_column_name=data_args.target_audio_column_name,
994
+ sampling_rate=sampling_rate,
995
+ # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
996
+ )
997
+
998
+ if data_args.max_eval_samples is not None:
999
+ raw_datasets["eval"] = (
1000
+ raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
1001
+ )
1002
+
1003
+ # 3. Next, let's load the config.
1004
+ config = ParlerTTSConfig.from_pretrained(
1005
+ model_args.model_name_or_path,
1006
+ cache_dir=model_args.cache_dir,
1007
+ token=data_args.token,
1008
+ trust_remote_code=data_args.trust_remote_code,
1009
+ )
1010
+
1011
+ # update pad token id and decoder_start_token_id
1012
+ config.update(
1013
+ {
1014
+ "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
1015
+ "decoder_start_token_id": (
1016
+ model_args.decoder_start_token_id
1017
+ if model_args.decoder_start_token_id is not None
1018
+ else config.decoder_start_token_id
1019
+ ),
1020
+ }
1021
+ )
1022
+
1023
+ # create model
1024
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
1025
+ model_args.model_name_or_path,
1026
+ cache_dir=model_args.cache_dir,
1027
+ config=config,
1028
+ token=data_args.token,
1029
+ trust_remote_code=data_args.trust_remote_code,
1030
+ )
1031
+
1032
+ # enable gradient checkpointing if necessary
1033
+ if training_args.gradient_checkpointing:
1034
+ model.gradient_checkpointing_enable()
1035
+
1036
+ # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
1037
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
1038
+ # so that we just need to set the correct target sampling rate and normalize the input
1039
+ # via the `feature_extractor`
1040
+
1041
+ # derive max & min input length for sample rate & max duration
1042
+ sampling_rate = feature_extractor.sampling_rate
1043
+ max_target_length = data_args.max_duration_in_seconds * sampling_rate
1044
+ min_target_length = data_args.min_duration_in_seconds * sampling_rate
1045
+ target_audio_column_name = data_args.target_audio_column_name
1046
+ description_column_name = data_args.description_column_name
1047
+ prompt_column_name = data_args.prompt_column_name
1048
+ feature_extractor_input_name = feature_extractor.model_input_names[0]
1049
+ audio_encoder_pad_token_id = config.decoder.pad_token_id
1050
+ audio_encoder_eos_token_id = config.decoder.eos_token_id
1051
+ audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
1052
+ max_length = model.generation_config.max_length
1053
+ num_codebooks = model.decoder.config.num_codebooks
1054
+ bandwidth = model_args.bandwidth
1055
+
1056
+ # Freeze Encoders
1057
+ model.freeze_encoders(model_args.freeze_text_encoder)
1058
+
1059
+ # Test all gather - used for warmup and to avoid a timeout
1060
+ test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
1061
+ gathered_tensor = accelerator.gather(test_tensor)
1062
+ print("gathered_tensor", gathered_tensor)
1063
+ accelerator.wait_for_everyone()
1064
+
1065
+ if not dataset_was_precomputed:
1066
+ # Filter on text length
1067
+ if description_column_name is not None and data_args.max_text_length is not None:
1068
+ with accelerator.main_process_first():
1069
+ # keep only descriptions shorter than max_text_length (in characters)
1070
+ raw_datasets = raw_datasets.filter(
1071
+ lambda x: len(x) < data_args.max_text_length,
1072
+ num_proc=num_workers,
1073
+ input_columns=[description_column_name],
1074
+ )
1075
+
1076
+ # Preprocessing the dataset.
1077
+ # We need to tokenize the texts.
1078
+ def pass_through_processors(description, prompt):
1079
+ batch = {}
1080
+
1081
+ batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
1082
+ batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
1083
+
1084
+ return batch
1085
+
1086
+ with accelerator.main_process_first():
1087
+ # this is a trick to avoid rewriting the entire audio column, which takes ages
1088
+ vectorized_datasets = raw_datasets.map(
1089
+ pass_through_processors,
1090
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1091
+ input_columns=[description_column_name, prompt_column_name],
1092
+ num_proc=num_workers,
1093
+ desc="preprocess datasets",
1094
+ )
1095
+
1096
+ # We use Accelerate to perform distributed inference
1097
+ # T5 doesn't support fp16
1098
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
1099
+
1100
+ # Now we encode the audio labels with encodec.
1101
+ ####### B. Encode audio
1102
+
1103
+ logger.info("*** Encode target audio with encodec ***")
1104
+
1105
+ # no need to prepare audio_decoder because it is only used for inference, without mixed precision
1106
+ # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
1107
+ if training_args.torch_compile:
1108
+ audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
1109
+ else:
1110
+ audio_decoder = model.audio_encoder
1111
+
1112
+ encoder_data_collator = DataCollatorEncodecWithPadding(
1113
+ feature_extractor,
1114
+ audio_column_name=target_audio_column_name,
1115
+ feature_extractor_input_name=feature_extractor_input_name,
1116
+ max_length=max_target_length,
1117
+ padding=padding,
1118
+ )
1119
+
1120
+ def apply_audio_decoder(batch):
1121
+ len_audio = batch.pop("len_audio")
1122
+ audio_decoder.to(batch["input_values"].device).eval()
1123
+ with torch.no_grad():
1124
+ labels = audio_decoder.encode(**batch, bandwidth=bandwidth)["audio_codes"]
1125
+ output = {}
1126
+ output["len_audio"] = len_audio
1127
+ # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
1128
+ output["labels"] = labels.squeeze(0).transpose(1, 2)
1129
+ output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
1130
+ return output
1131
+
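+ # Illustrative sketch (assumed shapes, for clarity only): for a batch of 4 clips
+ # encoded into 9 codebooks over 150 frames, `audio_codes` comes out as
+ # (1, 4, 9, 150) and the stored labels have shape (4, 150, 9) after the
+ # squeeze/transpose above.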
1132
+ for split in vectorized_datasets:
1133
+ data_loader = DataLoader(
1134
+ raw_datasets[split],
1135
+ batch_size=training_args.audio_encoder_per_device_batch_size,
1136
+ collate_fn=encoder_data_collator,
1137
+ num_workers=training_args.dataloader_num_workers,
1138
+ pin_memory=True,
1139
+ )
1140
+ data_loader = accelerator.prepare(data_loader)
1141
+
1142
+ all_generated_labels = []
1143
+ all_lens = []
1144
+ for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
1145
+ generate_labels = apply_audio_decoder(batch)
1146
+ generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
1147
+ generate_labels = accelerator.gather_for_metrics(generate_labels)
1148
+
1149
+ if accelerator.is_main_process:
1150
+ lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
1151
+ rat = generate_labels["ratio"].cpu().squeeze()
1152
+ lens = generate_labels["len_audio"].cpu().squeeze()
1153
+ lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
1154
+
1155
+ all_generated_labels.extend(lab)
1156
+ all_lens.extend(lens)
1157
+
1158
+ # (1, codebooks, seq_len) where seq_len=1
1159
+ bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
1160
+
1161
+ if accelerator.is_main_process:
1162
+ tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
1163
+ tmp_labels.save_to_disk(
1164
+ os.path.join(data_args.temporary_save_to_disk, split),
1165
+ num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
1166
+ )
1167
+ accelerator.wait_for_everyone()
1168
+ del all_generated_labels
1169
+
1170
+ tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
1171
+ with accelerator.main_process_first():
1172
+ vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
1173
+
1174
+ def postprocess_dataset(labels):
1175
+ # (1, codebooks, seq_len)
1176
+ labels = torch.tensor(labels).unsqueeze(0)
1177
+ # add bos
1178
+ labels = torch.cat([bos_labels, labels], dim=-1)
1179
+
1180
+ labels, delay_pattern_mask = build_delay_pattern_mask(
1181
+ labels,
1182
+ bos_token_id=audio_encoder_bos_token_id,
1183
+ pad_token_id=audio_encoder_eos_token_id,
1184
+ max_length=labels.shape[-1] + num_codebooks,
1185
+ num_codebooks=num_codebooks,
1186
+ )
1187
+
1188
+ # the first ids of the delay pattern mask are precisely the labels; we use the rest of the mask
1189
+ # to take care of EOS
1190
+ # we want labels to look like this:
1191
+ # - [B, a, b, E, E, E, E]
1192
+ # - [B, B, c, d, E, E, E]
1193
+ # - [B, B, B, e, f, E, E]
1194
+ # - [B, B, B, B, g, h, E]
1195
+ labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
1196
+
1197
+ # the first timestamp is associated with a row full of BOS, so let's get rid of it
1198
+ # we also remove the last timestamps (full of PAD)
1199
+ output = {"labels": labels[:, 1:]}
1200
+ return output
1201
+
1202
+ with accelerator.main_process_first():
1203
+ vectorized_datasets[split] = vectorized_datasets[split].map(
1204
+ postprocess_dataset,
1205
+ num_proc=data_args.preprocessing_num_workers, # this one is resource-consuming when many processes are used.
1206
+ input_columns=["labels"],
1207
+ desc="Postprocessing labeling",
1208
+ )
1209
+
1210
+ accelerator.free_memory()
1211
+ del generate_labels, all_lens
1212
+
1213
+ with accelerator.main_process_first():
1214
+ # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
1215
+ # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
1216
+ # That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
1217
+
1218
+ def is_audio_in_length_range(length):
1219
+ return length > min_target_length and length < max_target_length
1220
+
1221
+ # keep only audio whose length lies between min_target_length and max_target_length
1222
+ vectorized_datasets = vectorized_datasets.filter(
1223
+ is_audio_in_length_range,
1224
+ num_proc=num_workers,
1225
+ input_columns=["target_length"],
1226
+ )
1227
+
1228
+ if description_column_name is not None and data_args.max_description_token_length is not None:
1229
+ with accelerator.main_process_first():
1230
+ # keep only descriptions shorter than max_description_token_length (in tokens)
1231
+ vectorized_datasets = vectorized_datasets.filter(
1232
+ lambda x: len(x) < data_args.max_description_token_length,
1233
+ num_proc=num_workers,
1234
+ input_columns=["input_ids"],
1235
+ )
1236
+
1237
+ if data_args.max_prompt_token_length is not None:
1238
+ with accelerator.main_process_first():
1239
+ # keep only prompts shorter than max_prompt_token_length (in tokens)
1240
+ vectorized_datasets = vectorized_datasets.filter(
1241
+ lambda x: len(x) < data_args.max_prompt_token_length,
1242
+ num_proc=num_workers,
1243
+ input_columns=["prompt_input_ids"],
1244
+ )
1245
+
1246
+ if data_args.save_to_disk is not None and not dataset_was_precomputed:
1247
+ if accelerator.is_main_process:
1248
+ vectorized_datasets.save_to_disk(
1249
+ data_args.save_to_disk,
1250
+ num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
1251
+ )
1252
+ logger.info(f"Dataset saved at {data_args.save_to_disk}")
1253
+
1254
+ audio_max_length = None
1255
+ if training_args.torch_compile:
1256
+ audio_max_length = max(vectorized_datasets["train"]["target_length"])
1257
+ with accelerator.main_process_first():
1258
+ max_sample = vectorized_datasets["train"].filter(
1259
+ lambda x: x == audio_max_length,
1260
+ num_proc=num_workers,
1261
+ input_columns=["target_length"],
1262
+ )
1263
+ audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
1264
+
1265
+ # for large datasets it is advised to run the preprocessing on a
1266
+ # single machine first with ``args.preprocessing_only`` since there will most likely
1267
+ # be a timeout when running the script in distributed mode.
1268
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
1269
+ # cached dataset
1270
+ if data_args.preprocessing_only and data_args.save_to_disk is None:
1271
+ raise ValueError(
1272
+ "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicates where to save the dataset locally."
1273
+ )
1274
+ elif data_args.preprocessing_only:
1275
+ logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
1276
+ return
1277
+
1278
+ # 6. Next, we can prepare the training.
1279
+
1280
+ # Let's use CLAP similarity and WER as our evaluation metrics,
1281
+
1282
+ # Define evaluation metrics during training, *i.e.* CLAP similarity and WER
1283
+ clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)
1284
+ clap_processor = AutoProcessor.from_pretrained(model_args.clap_model_name_or_path)
1285
+ metric = evaluate.load("wer")
1286
+
1287
+ def clap_similarity(texts, audios, device):
1288
+ clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
1289
+ clap.to(device)
1290
+ with torch.no_grad():
1291
+ text_features = clap.get_text_features(
1292
+ clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
1293
+ )
1294
+ audio_features = clap.get_audio_features(clap_inputs["input_features"])
1295
+
1296
+ cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
1297
+
1298
+ clap.to("cpu")
1299
+ clap_inputs.to("cpu")
1300
+ return cosine_sim.mean().to("cpu")
1301
+
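+ # Illustrative note: the CLAP score logged here is the mean cosine similarity
+ # between the audio and text embeddings, so it lies in [-1, 1] and a higher value
+ # means the generated audio matches its description better.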
1302
+ def wer(prompts, audios, device):
1303
+ asr_pipeline = pipeline(model=model_args.asr_model_name_or_path, device=device)
1304
+ transcriptions = asr_pipeline(
1305
+ [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
1306
+ batch_size=int(training_args.per_device_eval_batch_size),
1307
+ )
1308
+
1309
+ word_error = 100 * metric.compute(
1310
+ predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
1311
+ )
1312
+
1313
+ return word_error, [t["text"] for t in transcriptions]
1314
+
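+ # Illustrative sketch (assumed numbers): if the ASR transcriptions contain 2 word
+ # errors over 10 reference words, `metric.compute` returns 0.2 and the reported
+ # WER is 100 * 0.2 = 20.0.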
1315
+ eval_methods = {"clap": clap_similarity, "wer": wer}
1316
+
1317
+ def compute_metrics(audios, descriptions, prompts, device="cpu"):
1318
+ input_ids = descriptions
1319
+ texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
1320
+ prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
1321
+ audios = [a.cpu().numpy() for a in audios]
1322
+ results = {"clap": eval_methods["clap"](texts, audios, device)}
1323
+ word_error, transcriptions = eval_methods["wer"](prompts, audios, device)
1324
+ results["wer"] = word_error
1325
+
1326
+ return results, texts, prompts, audios, transcriptions
1327
+
1328
+ # Define Training Schedule
1329
+ # Store some constants
1330
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1331
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1332
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1333
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1334
+
1335
+ if training_args.max_steps < 0:
1336
+ num_epochs = int(training_args.num_train_epochs)
1337
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1338
+ total_train_steps = steps_per_epoch * num_epochs
1339
+ elif training_args.max_steps > 0:
1340
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1341
+ total_train_steps = int(training_args.max_steps)
1342
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1343
+ num_epochs = sys.maxsize
1344
+ steps_per_epoch = total_train_steps
1345
+
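+ # Illustrative sketch (assumed values): with 3,200 training examples,
+ # per_device_train_batch_size=4 on 8 processes (train_batch_size=32) and
+ # gradient_accumulation_steps=2, steps_per_epoch = 3200 // (32 * 2) = 50,
+ # so num_train_epochs=10 gives total_train_steps = 500.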
1346
+ if training_args.evaluation_strategy == "epoch":
1347
+ eval_steps = steps_per_epoch
1348
+ elif training_args.eval_steps is None:
1349
+ logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
1350
+ eval_steps = steps_per_epoch
1351
+ else:
1352
+ eval_steps = training_args.eval_steps
1353
+
1354
+ if training_args.save_strategy == "epoch":
1355
+ save_steps = steps_per_epoch
1356
+ elif training_args.save_strategy == "steps":
1357
+ save_steps = training_args.save_steps
1358
+ else:
1359
+ save_steps = sys.maxsize
1360
+
1361
+ # T5 doesn't support fp16
1362
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
1363
+
1364
+ # Define optimizer, LR scheduler, collator
1365
+ optimizer = torch.optim.AdamW(
1366
+ params=model.parameters(),
1367
+ lr=training_args.learning_rate,
1368
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1369
+ eps=training_args.adam_epsilon,
1370
+ weight_decay=training_args.weight_decay,
1371
+ )
1372
+
1373
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1374
+ lr_scheduler = get_scheduler(
1375
+ name=training_args.lr_scheduler_type,
1376
+ optimizer=optimizer,
1377
+ num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
1378
+ num_training_steps=total_train_steps * accelerator.num_processes,
1379
+ )
1380
+
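+ # Illustrative sketch (assumed values): since Accelerate steps the scheduler once
+ # per process, 500 optimizer steps with 50 warmup steps on 8 processes are passed
+ # to get_scheduler as num_training_steps=4000 and num_warmup_steps=400, so the
+ # schedule still advances at the intended per-optimizer-step rate.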
1381
+ # Instantiate custom data collator
1382
+ data_collator = DataCollatorParlerTTSWithPadding(
1383
+ prompt_tokenizer=prompt_tokenizer,
1384
+ description_tokenizer=description_tokenizer,
1385
+ pad_to_multiple_of=data_args.pad_to_multiple_of,
1386
+ padding=padding,
1387
+ prompt_max_length=data_args.max_prompt_token_length,
1388
+ description_max_length=data_args.max_description_token_length,
1389
+ audio_max_length=audio_max_length,
1390
+ )
1391
+
1392
+ # Prepare everything with accelerate
1393
+ model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
1394
+
1395
+ logger.info("***** Running training *****")
1396
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1397
+ logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
1398
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1399
+ logger.info(
1400
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1401
+ )
1402
+ logger.info(f" Total optimization steps = {total_train_steps}")
1403
+
1404
+ # ======================== Training ================================
1405
+ train_time = 0
1406
+ train_start = time.time()
1407
+ steps_trained_progress_bar = tqdm(
1408
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1409
+ )
1410
+ continue_training = True
1411
+ epochs_trained = 0
1412
+ cur_step = 0
1413
+
1414
+ checkpoint = None
1415
+ if training_args.resume_from_checkpoint is not None:
1416
+ checkpoint = training_args.resume_from_checkpoint
1417
+ elif last_checkpoint is not None:
1418
+ checkpoint = last_checkpoint
1419
+
1420
+ if accelerator.is_main_process:
1421
+ if training_args.push_to_hub:
1422
+ # Retrieve or infer repo_name
1423
+ repo_name = training_args.hub_model_id
1424
+ if repo_name is None:
1425
+ repo_name = Path(training_args.output_dir).absolute().name
1426
+ # Create repo and retrieve repo_id
1427
+ repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
1428
+ # Clone repo locally
1429
+ repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
1430
+
1431
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
1432
+ if "wandb" not in gitignore:
1433
+ gitignore.write("wandb\n")
1434
+ elif training_args.output_dir is not None:
1435
+ os.makedirs(training_args.output_dir, exist_ok=True)
1436
+ accelerator.wait_for_everyone()
1437
+
1438
+ # Now save everything to be able to create a single processor later
1439
+ # make sure all processes wait until data is saved
1440
+ with accelerator.main_process_first():
1441
+ # only the main process saves them
1442
+ if accelerator.is_main_process:
1443
+ # save feature extractor, tokenizer and config
1444
+ if (
1445
+ model_args.prompt_tokenizer_name is None
1446
+ and model_args.description_tokenizer_name
1447
+ or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
1448
+ ):
1449
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
1450
+ else:
1451
+ logger.warning(
1452
+ "Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
1453
+ )
1454
+ prompt_tokenizer.save_pretrained(training_args.output_dir)
1455
+
1456
+ feature_extractor.save_pretrained(training_args.output_dir)
1457
+ config.save_pretrained(training_args.output_dir)
1458
+
1459
+ if checkpoint is not None:
1460
+ accelerator.load_state(checkpoint)
1461
+ # Find num steps and epoch from saved state string pattern
1462
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1463
+ match = re.search(pattern, checkpoint)
1464
+ cur_step = int(match.group(1))
1465
+ epochs_trained = int(match.group(2))
1466
+
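+ # Illustrative sketch (hypothetical directory name): a state saved as
+ # "checkpoint-500-epoch-2" resumes with cur_step=500 and epochs_trained=2.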
1467
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1468
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1469
+ logger.info(f" Continuing training from global step {cur_step}")
1470
+
1471
+ steps_trained_progress_bar.update(cur_step)
1472
+
1473
+ for epoch in range(0, epochs_trained):
1474
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1475
+
1476
+ if training_args.max_steps < 0:
1477
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
1478
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1479
+ else:
1480
+ # Currently we don't know how many steps we've taken in the current epoch
1481
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1482
+ # This is "good enough" for our purposes but not fully correct
1483
+ resume_step = None
1484
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1485
+ else:
1486
+ resume_step = None
1487
+
1488
+ gen_kwargs = {
1489
+ "do_sample": model_args.do_sample,
1490
+ "temperature": model_args.temperature,
1491
+ "max_length": model_args.max_length,
1492
+ }
1493
+
1494
+ # Define gradient update step fn
1495
+ def train_step(
1496
+ batch,
1497
+ accelerator,
1498
+ autocast_kwargs,
1499
+ ):
1500
+ model.train()
1501
+
1502
+ if mixed_precision == "fp16":
1503
+ # fp16 doesn't work with T5-like models
1504
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
1505
+ if training_args.parallel_mode.value != "distributed":
1506
+ encoder_outputs = model.text_encoder(
1507
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
1508
+ )
1509
+ else:
1510
+ encoder_outputs = model.module.text_encoder(
1511
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
1512
+ )
1513
+ batch["encoder_outputs"] = encoder_outputs
1514
+
1515
+ outputs = model(**batch)
1516
+ # CE (data) loss
1517
+ ce_loss = outputs.loss
1518
+
1519
+ metrics = {"loss": ce_loss}
1520
+ return ce_loss, metrics
1521
+
1522
+ # Define eval fn
1523
+ def eval_step(
1524
+ batch,
1525
+ accelerator,
1526
+ autocast_kwargs,
1527
+ ):
1528
+ eval_model = model if not training_args.torch_compile else model._orig_mod
1529
+ eval_model.eval()
1530
+
1531
+ if mixed_precision == "fp16":
1532
+ # fp16 doesn't work with T5-like models
1533
+ with accelerator.autocast(autocast_handler=autocast_kwargs):
1534
+ with torch.no_grad():
1535
+ if training_args.parallel_mode.value != "distributed" or training_args.torch_compile:
1536
+ encoder_outputs = eval_model.text_encoder(
1537
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
1538
+ )
1539
+ else:
1540
+ encoder_outputs = eval_model.module.text_encoder(
1541
+ input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
1542
+ )
1543
+ batch["encoder_outputs"] = encoder_outputs
1544
+
1545
+ with torch.no_grad():
1546
+ outputs = eval_model(**batch)
1547
+ # CE (data) loss
1548
+ ce_loss = outputs.loss
1549
+ metrics = {"loss": ce_loss}
1550
+ return metrics
1551
+
1552
+ def generate_step(batch):
1553
+ batch.pop("decoder_attention_mask", None)
1554
+ eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").eval()
1555
+ if training_args.torch_compile:
1556
+ eval_model = model._orig_mod
1557
+
1558
+ output_audios = eval_model.generate(**batch, **gen_kwargs)
1559
+ output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
1560
+ return output_audios
1561
+
1562
+ for epoch in range(epochs_trained, num_epochs):
1563
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1564
+ sampler = None
1565
+ if training_args.group_by_length:
1566
+ sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
1567
+ train_dataloader = DataLoader(
1568
+ vectorized_datasets["train"],
1569
+ collate_fn=data_collator,
1570
+ batch_size=per_device_train_batch_size,
1571
+ sampler=sampler,
1572
+ num_workers=training_args.dataloader_num_workers,
1573
+ pin_memory=training_args.dataloader_pin_memory,
1574
+ )
1575
+ train_dataloader = accelerator.prepare(train_dataloader)
1576
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1577
+ train_dataloader.dataset.set_epoch(epoch)
1578
+
1579
+ if resume_step is not None:
1580
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1581
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1582
+ resume_step = None
1583
+
1584
+ for batch in train_dataloader:
1585
+ with accelerator.accumulate(model):
1586
+ loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
1587
+ accelerator.backward(loss)
1588
+ if accelerator.sync_gradients:
1589
+ accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
1590
+ optimizer.step()
1591
+ lr_scheduler.step()
1592
+ optimizer.zero_grad()
1593
+
1594
+ # Check if the accelerator has performed an optimization step behind the scenes
1595
+ if accelerator.sync_gradients:
1596
+ steps_trained_progress_bar.update(1)
1597
+ cur_step += 1
1598
+
1599
+ if cur_step % training_args.logging_steps == 0:
1600
+ steps_trained_progress_bar.write(
1601
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1602
+ f" {train_metric['loss']}, Learning Rate:"
1603
+ f" {lr_scheduler.get_last_lr()[0]})"
1604
+ )
1605
+ log_metric(
1606
+ accelerator,
1607
+ metrics=train_metric,
1608
+ learning_rate=lr_scheduler.get_last_lr()[0],
1609
+ train_time=train_time + time.time() - train_start,
1610
+ step=cur_step,
1611
+ epoch=epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1612
+ prefix="train",
1613
+ )
1614
+
1615
+ # save checkpoint and weights after each save_steps and at the end of training
1616
+ if (cur_step % save_steps == 0) or cur_step == total_train_steps:
1617
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1618
+ # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
1619
+ # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
1620
+ accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
1621
+ accelerator.wait_for_everyone()
1622
+ if accelerator.is_main_process:
1623
+ rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1624
+
1625
+ if cur_step == total_train_steps:
1626
+ # un-wrap the model for saving
1627
+ unwrapped_model = accelerator.unwrap_model(model)
1628
+ unwrapped_model.save_pretrained(training_args.output_dir)
1629
+
1630
+ if training_args.push_to_hub:
1631
+ repo.push_to_hub(
1632
+ commit_message=f"Saving train state of step {cur_step}",
1633
+ blocking=False,
1634
+ )
1635
+
1636
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1637
+ train_time += time.time() - train_start
1638
+ # ======================== Evaluating ==============================
1639
+ eval_metrics = []
1640
+ eval_preds = []
1641
+ eval_descriptions = []
1642
+ eval_prompts = []
1643
+ eval_start = time.time()
1644
+
1645
+ # release training input batch
1646
+ batch = release_memory(batch)
1647
+
1648
+ validation_dataloader = DataLoader(
1649
+ vectorized_datasets["eval"],
1650
+ collate_fn=data_collator,
1651
+ batch_size=per_device_eval_batch_size,
1652
+ drop_last=False,
1653
+ num_workers=training_args.dataloader_num_workers,
1654
+ pin_memory=training_args.dataloader_pin_memory,
1655
+ )
1656
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1657
+
1658
+ for batch in tqdm(
1659
+ validation_dataloader,
1660
+ desc="Evaluating - Inference ...",
1661
+ position=2,
1662
+ disable=not accelerator.is_local_main_process,
1663
+ ):
1664
+ # Model forward
1665
+ eval_metric = eval_step(batch, accelerator, autocast_kwargs)
1666
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1667
+ eval_metrics.append(eval_metric)
1668
+
1669
+ if training_args.predict_with_generate:
1670
+ validation_dataloader = DataLoader(
1671
+ vectorized_datasets["eval"],
1672
+ collate_fn=data_collator,
1673
+ batch_size=per_device_eval_batch_size,
1674
+ drop_last=False,
1675
+ num_workers=training_args.dataloader_num_workers,
1676
+ pin_memory=training_args.dataloader_pin_memory,
1677
+ )
1678
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1679
+ # generation
1680
+ for batch in tqdm(
1681
+ validation_dataloader,
1682
+ desc="Evaluating - Generation ...",
1683
+ position=2,
1684
+ disable=not accelerator.is_local_main_process,
1685
+ ):
1686
+ generated_audios = generate_step(batch)
1687
+ # Gather all predictions and targets
1688
+ generated_audios, input_ids, prompts = accelerator.pad_across_processes(
1689
+ (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
1690
+ )
1691
+ generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
1692
+ (generated_audios, input_ids, prompts)
1693
+ )
1694
+ eval_preds.extend(generated_audios.to("cpu"))
1695
+ eval_descriptions.extend(input_ids.to("cpu"))
1696
+ eval_prompts.extend(prompts.to("cpu"))
1697
+
1698
+ eval_time = time.time() - eval_start
1699
+ # normalize eval metrics
1700
+ eval_metrics = {
1701
+ key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics]))
1702
+ for key in eval_metrics[0]
1703
+ }
1704
+
1705
+ # compute metrics
1706
+ metrics_desc = ""
1707
+ if training_args.predict_with_generate:
1708
+ metric_values, pred_descriptions, pred_prompts, audios, transcriptions = compute_metrics(
1709
+ eval_preds, eval_descriptions, eval_prompts, accelerator.device
1710
+ )
1711
+ eval_metrics.update(metric_values)
1712
+ metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
1713
+ if "wandb" in training_args.report_to:
1714
+ log_pred(
1715
+ accelerator,
1716
+ pred_descriptions,
1717
+ pred_prompts,
1718
+ transcriptions,
1719
+ audios,
1720
+ sampling_rate=sampling_rate,
1721
+ step=cur_step,
1722
+ prefix="eval",
1723
+ )
1724
+
1725
+ # Print metrics and update progress bar
1726
+ steps_trained_progress_bar.write(
1727
+ f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1728
+ f" {metrics_desc})"
1729
+ )
1730
+
1731
+ log_metric(
1732
+ accelerator,
1733
+ metrics=eval_metrics,
1734
+ train_time=eval_time,
1735
+ step=cur_step,
1736
+ epoch=epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1737
+ prefix="eval",
1738
+ )
1739
+
1740
+ # release the eval batch and reset the metrics
1741
+ eval_metrics = []
1742
+ eval_preds = []
1743
+ eval_descriptions = []
1744
+ eval_prompts = []
1745
+ batch = release_memory(batch)
1746
+
1747
+ # flush the train metrics
1748
+ train_start = time.time()
1749
+
1750
+ # break condition
1751
+ if cur_step == total_train_steps:
1752
+ continue_training = False
1753
+ break
1754
+
1755
+ if not continue_training:
1756
+ break
1757
+
1758
+ accelerator.end_training()
1759
+
1760
+
1761
+ if __name__ == "__main__":
1762
+ set_start_method("spawn")
1763
+ main()
special_tokens_map.json ADDED
@@ -0,0 +1,125 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": {
105
+ "content": "</s>",
106
+ "lstrip": false,
107
+ "normalized": false,
108
+ "rstrip": false,
109
+ "single_word": false
110
+ },
111
+ "pad_token": {
112
+ "content": "<pad>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "unk_token": {
119
+ "content": "<unk>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ }
125
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,941 @@
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32058": {
+ "content": "<extra_id_41>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32059": {
+ "content": "<extra_id_40>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32060": {
+ "content": "<extra_id_39>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32061": {
+ "content": "<extra_id_38>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32062": {
+ "content": "<extra_id_37>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32063": {
+ "content": "<extra_id_36>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32064": {
+ "content": "<extra_id_35>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32065": {
+ "content": "<extra_id_34>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32066": {
+ "content": "<extra_id_33>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32067": {
+ "content": "<extra_id_32>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32068": {
+ "content": "<extra_id_31>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32069": {
+ "content": "<extra_id_30>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32070": {
+ "content": "<extra_id_29>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32071": {
+ "content": "<extra_id_28>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32072": {
+ "content": "<extra_id_27>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32073": {
+ "content": "<extra_id_26>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32074": {
+ "content": "<extra_id_25>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32075": {
+ "content": "<extra_id_24>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32076": {
+ "content": "<extra_id_23>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32077": {
+ "content": "<extra_id_22>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32078": {
+ "content": "<extra_id_21>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32079": {
+ "content": "<extra_id_20>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32080": {
+ "content": "<extra_id_19>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32081": {
+ "content": "<extra_id_18>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32082": {
+ "content": "<extra_id_17>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32083": {
+ "content": "<extra_id_16>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32084": {
+ "content": "<extra_id_15>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32085": {
+ "content": "<extra_id_14>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32086": {
+ "content": "<extra_id_13>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32087": {
+ "content": "<extra_id_12>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32088": {
+ "content": "<extra_id_11>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32089": {
+ "content": "<extra_id_10>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32090": {
+ "content": "<extra_id_9>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32091": {
+ "content": "<extra_id_8>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32092": {
+ "content": "<extra_id_7>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32093": {
+ "content": "<extra_id_6>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32094": {
+ "content": "<extra_id_5>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32095": {
+ "content": "<extra_id_4>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32096": {
+ "content": "<extra_id_3>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32097": {
+ "content": "<extra_id_2>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32098": {
+ "content": "<extra_id_1>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32099": {
+ "content": "<extra_id_0>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "additional_special_tokens": [
+ "<extra_id_0>",
+ "<extra_id_1>",
+ "<extra_id_2>",
+ "<extra_id_3>",
+ "<extra_id_4>",
+ "<extra_id_5>",
+ "<extra_id_6>",
+ "<extra_id_7>",
+ "<extra_id_8>",
+ "<extra_id_9>",
+ "<extra_id_10>",
+ "<extra_id_11>",
+ "<extra_id_12>",
+ "<extra_id_13>",
+ "<extra_id_14>",
+ "<extra_id_15>",
+ "<extra_id_16>",
+ "<extra_id_17>",
+ "<extra_id_18>",
+ "<extra_id_19>",
+ "<extra_id_20>",
+ "<extra_id_21>",
+ "<extra_id_22>",
+ "<extra_id_23>",
+ "<extra_id_24>",
+ "<extra_id_25>",
+ "<extra_id_26>",
+ "<extra_id_27>",
+ "<extra_id_28>",
+ "<extra_id_29>",
+ "<extra_id_30>",
+ "<extra_id_31>",
+ "<extra_id_32>",
+ "<extra_id_33>",
+ "<extra_id_34>",
+ "<extra_id_35>",
+ "<extra_id_36>",
+ "<extra_id_37>",
+ "<extra_id_38>",
+ "<extra_id_39>",
+ "<extra_id_40>",
+ "<extra_id_41>",
+ "<extra_id_42>",
+ "<extra_id_43>",
+ "<extra_id_44>",
+ "<extra_id_45>",
+ "<extra_id_46>",
+ "<extra_id_47>",
+ "<extra_id_48>",
+ "<extra_id_49>",
+ "<extra_id_50>",
+ "<extra_id_51>",
+ "<extra_id_52>",
+ "<extra_id_53>",
+ "<extra_id_54>",
+ "<extra_id_55>",
+ "<extra_id_56>",
+ "<extra_id_57>",
+ "<extra_id_58>",
+ "<extra_id_59>",
+ "<extra_id_60>",
+ "<extra_id_61>",
+ "<extra_id_62>",
+ "<extra_id_63>",
+ "<extra_id_64>",
+ "<extra_id_65>",
+ "<extra_id_66>",
+ "<extra_id_67>",
+ "<extra_id_68>",
+ "<extra_id_69>",
+ "<extra_id_70>",
+ "<extra_id_71>",
+ "<extra_id_72>",
+ "<extra_id_73>",
+ "<extra_id_74>",
+ "<extra_id_75>",
+ "<extra_id_76>",
+ "<extra_id_77>",
+ "<extra_id_78>",
+ "<extra_id_79>",
+ "<extra_id_80>",
+ "<extra_id_81>",
+ "<extra_id_82>",
+ "<extra_id_83>",
+ "<extra_id_84>",
+ "<extra_id_85>",
+ "<extra_id_86>",
+ "<extra_id_87>",
+ "<extra_id_88>",
+ "<extra_id_89>",
+ "<extra_id_90>",
+ "<extra_id_91>",
+ "<extra_id_92>",
+ "<extra_id_93>",
+ "<extra_id_94>",
+ "<extra_id_95>",
+ "<extra_id_96>",
+ "<extra_id_97>",
+ "<extra_id_98>",
+ "<extra_id_99>"
+ ],
+ "clean_up_tokenization_spaces": true,
+ "eos_token": "</s>",
+ "extra_ids": 100,
+ "legacy": true,
+ "model_max_length": 512,
+ "pad_token": "<pad>",
+ "padding_side": "left",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "T5Tokenizer",
+ "unk_token": "<unk>"
+ }
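
For reference, a minimal sketch (not part of this commit) of how the tokenizer_config.json added above is consumed: load the tokenizer with transformers and check the 100 T5 sentinel tokens and left-side padding it registers. The local path "." is an assumed stand-in for wherever the uploaded tokenizer files are placed.

from transformers import AutoTokenizer

# Assumes the current directory holds the uploaded tokenizer files
# (tokenizer_config.json plus the SentencePiece model / tokenizer files).
tokenizer = AutoTokenizer.from_pretrained(".")

# extra_ids=100 registers <extra_id_0> ... <extra_id_99>; per the config above,
# <extra_id_0> sits at the top of the vocabulary (id 32099).
print(len(tokenizer.additional_special_tokens))         # 100
print(tokenizer.convert_tokens_to_ids("<extra_id_0>"))  # 32099
print(tokenizer.padding_side)                           # "left", as set in the config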