Commit f74be82
sanchit-gandhi committed
Parent(s): 140399a

2hx8pk65: saving weights and logs of step 10k
Browse files

- .gitattributes +1 -0
- config.json +291 -0
- flax_model.msgpack +3 -0
- merges.txt +0 -0
- models/__init__.py +6 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/configuration_bart.cpython-38.pyc +0 -0
- models/__pycache__/configuration_speech_encoder_decoder.cpython-38.pyc +0 -0
- models/__pycache__/configuration_wav2vec2.cpython-38.pyc +0 -0
- models/__pycache__/modeling_flax_bart.cpython-38.pyc +0 -0
- models/__pycache__/modeling_flax_speech_encoder_decoder.cpython-38.pyc +0 -0
- models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc +0 -0
- models/configuration_bart.py +183 -0
- models/configuration_speech_encoder_decoder.py +121 -0
- models/configuration_wav2vec2.py +344 -0
- models/modeling_flax_bart.py +816 -0
- models/modeling_flax_speech_encoder_decoder.py +1245 -0
- models/modeling_flax_wav2vec2.py +975 -0
- nohup.out +0 -0
- preprocessor_config.json +9 -0
- run_flax_speech_recognition_seq2seq.py +1572 -0
- run_librispeech.sh +39 -0
- special_tokens_map.json +15 -0
- tokenizer.json +0 -0
- tokenizer_config.json +16 -0
- vocab.json +0 -0
- wandb/debug-internal.log +1 -0
- wandb/debug.log +1 -0
- wandb/latest-run +1 -0
- wandb/run-20220828_084407-nbdgecc9/files/config.yaml +36 -0
- wandb/run-20220828_084407-nbdgecc9/files/output.log +110 -0
- wandb/run-20220828_084407-nbdgecc9/files/requirements.txt +167 -0
- wandb/run-20220828_084407-nbdgecc9/files/wandb-metadata.json +59 -0
- wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json +1 -0
- wandb/run-20220828_084407-nbdgecc9/logs/debug-internal.log +144 -0
- wandb/run-20220828_084407-nbdgecc9/logs/debug.log +131 -0
- wandb/run-20220828_084407-nbdgecc9/run-nbdgecc9.wandb +3 -0
- wandb/run-20220828_085247-2hx8pk65/files/config.yaml +28 -0
- wandb/run-20220828_085247-2hx8pk65/files/media/table/eval/step_10k_10000_8b44e8a00a036a18ffdf.table.json +1 -0
- wandb/run-20220828_085247-2hx8pk65/files/output.log +0 -0
- wandb/run-20220828_085247-2hx8pk65/files/requirements.txt +167 -0
- wandb/run-20220828_085247-2hx8pk65/files/wandb-metadata.json +59 -0
- wandb/run-20220828_085247-2hx8pk65/files/wandb-summary.json +1 -0
- wandb/run-20220828_085247-2hx8pk65/logs/debug-internal.log +0 -0
- wandb/run-20220828_085247-2hx8pk65/logs/debug.log +25 -0
- wandb/run-20220828_085247-2hx8pk65/run-2hx8pk65.wandb +3 -0
.gitattributes
CHANGED
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wandb filter=lfs diff=lfs merge=lfs -text
config.json
ADDED
@@ -0,0 +1,291 @@
+{
+  "_name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
+  "architectures": [
+    "SpeechEncoderDecoderModel"
+  ],
+  "decoder": {
+    "_name_or_path": "",
+    "activation_dropout": 0.2,
+    "activation_function": "gelu",
+    "add_bias_logits": false,
+    "add_cross_attention": true,
+    "add_final_layer_norm": false,
+    "architectures": [
+      "BartModel"
+    ],
+    "attention_dropout": 0.1,
+    "bad_words_ids": null,
+    "bos_token_id": 0,
+    "chunk_size_feed_forward": 0,
+    "classif_dropout": 0.1,
+    "classifier_dropout": 0.0,
+    "cross_attention_hidden_size": null,
+    "d_model": 1024,
+    "decoder_attention_heads": 16,
+    "decoder_ffn_dim": 4096,
+    "decoder_layerdrop": 0.0,
+    "decoder_layers": 12,
+    "decoder_start_token_id": 2,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "dropout": 0.2,
+    "early_stopping": true,
+    "encoder_attention_heads": 16,
+    "encoder_ffn_dim": 4096,
+    "encoder_layerdrop": 0.0,
+    "encoder_layers": 12,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": 2,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": 0,
+    "forced_eos_token_id": 2,
+    "fuse_matmuls": false,
+    "gradient_checkpointing": true,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1",
+      "2": "LABEL_2"
+    },
+    "init_std": 0.02,
+    "is_decoder": true,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1,
+      "LABEL_2": 2
+    },
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "max_position_embeddings": 1024,
+    "min_length": 0,
+    "model_type": "bart",
+    "no_repeat_ngram_size": 3,
+    "normalize_before": false,
+    "num_beam_groups": 1,
+    "num_beams": 4,
+    "num_hidden_layers": 12,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": 1,
+    "prefix": null,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "scale_embedding": false,
+    "sep_token_id": null,
+    "task_specific_params": {
+      "summarization": {
+        "length_penalty": 1.0,
+        "max_length": 128,
+        "min_length": 12,
+        "num_beams": 4
+      },
+      "summarization_cnn": {
+        "length_penalty": 2.0,
+        "max_length": 142,
+        "min_length": 56,
+        "num_beams": 4
+      },
+      "summarization_xsum": {
+        "length_penalty": 1.0,
+        "max_length": 62,
+        "min_length": 11,
+        "num_beams": 6
+      }
+    },
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": "float32",
+    "torchscript": false,
+    "transformers_version": "4.21.0.dev0",
+    "typical_p": 1.0,
+    "use_bfloat16": false,
+    "use_cache": true,
+    "use_scan": true,
+    "vocab_size": 50265
+  },
+  "decoder_start_token_id": 0,
+  "encoder": {
+    "_name_or_path": "",
+    "activation_dropout": 0.2,
+    "adapter_kernel_size": 3,
+    "adapter_stride": 2,
+    "add_adapter": true,
+    "add_cross_attention": false,
+    "apply_spec_augment": true,
+    "architectures": [
+      "Wav2Vec2ForPreTraining"
+    ],
+    "attention_dropout": 0.1,
+    "bad_words_ids": null,
+    "bos_token_id": 1,
+    "chunk_size_feed_forward": 0,
+    "classifier_proj_size": 256,
+    "codevector_dim": 768,
+    "contrastive_logits_temperature": 0.1,
+    "conv_bias": true,
+    "conv_dim": [
+      512,
+      512,
+      512,
+      512,
+      512,
+      512,
+      512
+    ],
+    "conv_kernel": [
+      10,
+      3,
+      3,
+      3,
+      3,
+      2,
+      2
+    ],
+    "conv_stride": [
+      5,
+      2,
+      2,
+      2,
+      2,
+      2,
+      2
+    ],
+    "cross_attention_hidden_size": null,
+    "ctc_loss_reduction": "sum",
+    "ctc_zero_infinity": false,
+    "decoder_start_token_id": null,
+    "diversity_loss_weight": 0.1,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "do_stable_layer_norm": true,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": 2,
+    "exponential_decay_length_penalty": null,
+    "feat_extract_activation": "gelu",
+    "feat_extract_dropout": 0.0,
+    "feat_extract_norm": "layer",
+    "feat_proj_dropout": 0.2,
+    "feat_quantizer_dropout": 0.0,
+    "final_dropout": 0.0,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "fuse_matmuls": false,
+    "gradient_checkpointing": true,
+    "hidden_act": "gelu",
+    "hidden_dropout": 0.2,
+    "hidden_dropout_prob": 0.1,
+    "hidden_size": 1024,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "initializer_range": 0.02,
+    "intermediate_size": 4096,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "layer_norm_eps": 1e-05,
+    "layerdrop": 0.0,
+    "length_penalty": 1.0,
+    "mask_feature_length": 10,
+    "mask_feature_min_masks": 0,
+    "mask_feature_prob": 0.0,
+    "mask_time_length": 10,
+    "mask_time_min_masks": 2,
+    "mask_time_prob": 0.1,
+    "max_length": 20,
+    "min_length": 0,
+    "model_type": "wav2vec2",
+    "no_repeat_ngram_size": 0,
+    "num_adapter_layers": 3,
+    "num_attention_heads": 16,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_codevector_groups": 2,
+    "num_codevectors_per_group": 320,
+    "num_conv_pos_embedding_groups": 16,
+    "num_conv_pos_embeddings": 128,
+    "num_feat_extract_layers": 7,
+    "num_hidden_layers": 24,
+    "num_negatives": 100,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_size": 1024,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": 0,
+    "prefix": null,
+    "problem_type": null,
+    "proj_codevector_dim": 768,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "task_specific_params": null,
+    "tdnn_dilation": [
+      1,
+      2,
+      3,
+      1,
+      1
+    ],
+    "tdnn_dim": [
+      512,
+      512,
+      512,
+      512,
+      1500
+    ],
+    "tdnn_kernel": [
+      5,
+      3,
+      3,
+      1,
+      1
+    ],
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": null,
+    "torchscript": false,
+    "transformers_version": "4.21.0.dev0",
+    "typical_p": 1.0,
+    "use_bfloat16": false,
+    "use_scan": true,
+    "use_weighted_layer_sum": false,
+    "vocab_size": 32,
+    "xvector_output_dim": 512
+  },
+  "eos_token_id": 2,
+  "is_encoder_decoder": true,
+  "max_length": 40,
+  "model_type": "speech-encoder-decoder",
+  "pad_token_id": 1,
+  "processor_class": "Wav2Vec2Processor",
+  "tie_word_embeddings": false,
+  "transformers_version": null,
+  "use_cache": false
+}
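For orientation: this composite config nests a full Wav2Vec2 encoder config under "encoder" and a BART decoder config under "decoder". A minimal sketch of loading and inspecting it with the stock transformers class (assuming a standard transformers install; the repo id below is taken from the `_name_or_path` field above):

# Sketch only: the stock SpeechEncoderDecoderConfig parses the same nested JSON
# that this repo's custom config classes produce.
from transformers import SpeechEncoderDecoderConfig

config = SpeechEncoderDecoderConfig.from_pretrained(
    "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan"
)
print(config.encoder.model_type)  # "wav2vec2"
print(config.decoder.model_type)  # "bart"
print(config.decoder.vocab_size)  # 50265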
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4bbb8026d3a4c9acb651189cbf65ab582eb2284bbcae68d0c6512395b962329
+size 2353616717
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
models/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from models.configuration_bart import BartConfig
+from models.configuration_wav2vec2 import Wav2Vec2Config
+from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
+from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule
+from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
+from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
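With this repo's root on the Python path, the custom Flax classes are importable directly from `models`, per the re-exports above. A minimal sketch (the checkpoint path is illustrative):

# Sketch: load the Flax weights committed in this repo (flax_model.msgpack)
# through the custom model class. "./" assumes you run from the repo root.
from models import FlaxSpeechEncoderDecoderModel

model = FlaxSpeechEncoderDecoderModel.from_pretrained("./")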
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (762 Bytes)
models/__pycache__/configuration_bart.cpython-38.pyc
ADDED
Binary file (7.06 kB)
models/__pycache__/configuration_speech_encoder_decoder.cpython-38.pyc
ADDED
Binary file (4.64 kB)
models/__pycache__/configuration_wav2vec2.cpython-38.pyc
ADDED
Binary file (16.8 kB)
models/__pycache__/modeling_flax_bart.cpython-38.pyc
ADDED
Binary file (21.1 kB)
models/__pycache__/modeling_flax_speech_encoder_decoder.cpython-38.pyc
ADDED
Binary file (39.4 kB)
models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc
ADDED
Binary file (30.7 kB)
models/configuration_bart.py
ADDED
@@ -0,0 +1,183 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" BART model configuration"""
+import warnings
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
+    # See all BART models at https://huggingface.co/models?filter=bart
+}
+
+
+class BartConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the BART
+    [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by diving by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        num_labels: (`int`, *optional*, defaults to 3):
+            The number of labels to use in [`BartForSequenceClassification`].
+        forced_eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+        use_scan (`bool`, *optional*, defaults to `False`):
+            Whether or not to use nn.scan in the Flax Bart attention layers.
+
+    Example:
+
+    ```python
+    >>> from transformers import BartModel, BartConfig
+
+    >>> # Initializing a BART facebook/bart-large style configuration
+    >>> configuration = BartConfig()
+
+    >>> # Initializing a model from the facebook/bart-large style configuration
+    >>> model = BartModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+    model_type = "bart"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+    def __init__(
+        self,
+        vocab_size=50265,
+        max_position_embeddings=1024,
+        encoder_layers=12,
+        encoder_ffn_dim=4096,
+        encoder_attention_heads=16,
+        decoder_layers=12,
+        decoder_ffn_dim=4096,
+        decoder_attention_heads=16,
+        encoder_layerdrop=0.0,
+        decoder_layerdrop=0.0,
+        activation_function="gelu",
+        d_model=1024,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        classifier_dropout=0.0,
+        scale_embedding=False,
+        use_cache=True,
+        use_scan=False,
+        fuse_matmuls=False,
+        num_labels=3,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        is_encoder_decoder=True,
+        decoder_start_token_id=2,
+        forced_eos_token_id=2,
+        **kwargs
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.classifier_dropout = classifier_dropout
+        self.use_cache = use_cache
+        self.use_scan = use_scan
+        self.fuse_matmuls = fuse_matmuls
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+
+        super().__init__(
+            num_labels=num_labels,
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            decoder_start_token_id=decoder_start_token_id,
+            forced_eos_token_id=forced_eos_token_id,
+            **kwargs,
+        )
+
+        # ensure backward compatibility for BART CNN models
+        if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
+            self.forced_bos_token_id = self.bos_token_id
+            warnings.warn(
+                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
+                "The config can simply be saved and uploaded again to be fixed."
+            )
models/configuration_speech_encoder_decoder.py
ADDED
@@ -0,0 +1,121 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+from models.configuration_wav2vec2 import Wav2Vec2Config
+from models.configuration_bart import BartConfig
+from transformers import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class SpeechEncoderDecoderConfig(PretrainedConfig):
+    r"""
+    [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
+    [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
+    arguments, defining the encoder and decoder configs.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        kwargs (*optional*):
+            Dictionary of keyword arguments. Notably:
+
+                - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+                  the encoder config.
+                - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+                  the decoder config.
+
+    Examples:
+
+    ```python
+    >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
+
+    >>> # Initializing a Wav2Vec2 & BERT style configuration
+    >>> config_encoder = Wav2Vec2Config()
+    >>> config_decoder = BertConfig()
+
+    >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
+
+    >>> # Initializing a Wav2Vec2Bert model from a Wav2Vec2 & bert-base-uncased style configurations
+    >>> model = SpeechEncoderDecoderModel(config=config)
+
+    >>> # Accessing the model configuration
+    >>> config_encoder = model.config.encoder
+    >>> config_decoder = model.config.decoder
+    >>> # set decoder config to causal lm
+    >>> config_decoder.is_decoder = True
+    >>> config_decoder.add_cross_attention = True
+
+    >>> # Saving the model, including its configuration
+    >>> model.save_pretrained("my-model")
+
+    >>> # loading model and config from pretrained folder
+    >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
+    >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
+    ```"""
+    model_type = "speech-encoder-decoder"
+    is_composition = True
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if "encoder" not in kwargs or "decoder" not in kwargs:
+            raise ValueError(
+                f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
+            )
+
+        encoder_config = kwargs.pop("encoder")
+        decoder_config = kwargs.pop("decoder")
+
+        # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
+        self.encoder = Wav2Vec2Config(**encoder_config)
+        self.decoder = BartConfig(**decoder_config)
+        self.is_encoder_decoder = True
+
+    @classmethod
+    def from_encoder_decoder_configs(
+        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+    ) -> PretrainedConfig:
+        r"""
+        Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
+        configuration and decoder model configuration.
+
+        Returns:
+            [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
+        """
+        logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+        decoder_config.is_decoder = True
+        decoder_config.add_cross_attention = True
+
+        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
+
+        Returns:
+            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+        """
+        output = copy.deepcopy(self.__dict__)
+        output["encoder"] = self.encoder.to_dict()
+        output["decoder"] = self.decoder.to_dict()
+        output["model_type"] = self.__class__.model_type
+        return output
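Unlike the upstream transformers class, this fork hard-codes the sub-configs to `Wav2Vec2Config` and `BartConfig` in `__init__` (see the TODO above). A minimal sketch of composing a config with the repo's own classes:

# Sketch only, using the classes re-exported from `models`.
from models import Wav2Vec2Config, BartConfig, SpeechEncoderDecoderConfig

encoder_config = Wav2Vec2Config(add_adapter=True)  # adapter downsamples encoder output
decoder_config = BartConfig()
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
assert config.is_encoder_decoder and config.decoder.is_decoder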
models/configuration_wav2vec2.py
ADDED
@@ -0,0 +1,344 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Wav2Vec2 model configuration"""
+
+import functools
+import operator
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
+    # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
+}
+
+
+class Wav2Vec2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an
+    Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Wav2Vec2
+    [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32):
+            Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`]. Vocabulary size of the
+            model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
+            method of [`Wav2Vec2Model`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        final_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+            The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
+            normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+            convolutional layers.
+        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for output of the feature encoder.
+        feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+            extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probabilitiy for quantized feature encoder states.
+        conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+            feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+        conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+            A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+            of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+        conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+            length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+            *conv_dim*.
+        conv_bias (`bool`, *optional*, defaults to `False`):
+            Whether the 1D convolutional layers have a bias.
+        num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+            Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
+            embeddings layer.
+        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+            Number of groups of 1D convolutional positional embeddings layer.
+        do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
+            True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
+            False` corresponds to applying layer norm after the attention layer.
+        apply_spec_augment (`bool`, *optional*, defaults to `True`):
+            Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+            [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+            Recognition](https://arxiv.org/abs/1904.08779).
+        mask_time_prob (`float`, *optional*, defaults to 0.05):
+            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+            procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
+            reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
+            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+            actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+        mask_time_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the time axis.
+        mask_time_min_masks (`int`, *optional*, defaults to 2),:
+            The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
+            irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
+            mask_time_min_masks''
+        mask_feature_prob (`float`, *optional*, defaults to 0.0):
+            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+            masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
+            the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
+            span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
+            may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
+            True`.
+        mask_feature_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the feature axis.
+        mask_feature_min_masks (`int`, *optional*, defaults to 0),:
+            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+            step, irrespectively of `mask_feature_prob`. Only relevant if
+            ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
+        num_codevectors_per_group (`int`, *optional*, defaults to 320):
+            Number of entries in each quantization codebook (group).
+        num_codevector_groups (`int`, *optional*, defaults to 2):
+            Number of codevector groups for product codevector quantization.
+        contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
+            The temperature *kappa* in the contrastive loss.
+        feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probabilitiy for the output of the feature encoder that's used by the quantizer.
+        num_negatives (`int`, *optional*, defaults to 100):
+            Number of negative samples for the contrastive loss.
+        codevector_dim (`int`, *optional*, defaults to 256):
+            Dimensionality of the quantized feature vectors.
+        proj_codevector_dim (`int`, *optional*, defaults to 256):
+            Dimensionality of the final projection of both the quantized and the transformer features.
+        diversity_loss_weight (`int`, *optional*, defaults to 0.1):
+            The weight of the codebook diversity loss component.
+        ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+            instance of [`Wav2Vec2ForCTC`].
+        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+            of [`Wav2Vec2ForCTC`].
+        use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+            instance of [`Wav2Vec2ForSequenceClassification`].
+        classifier_proj_size (`int`, *optional*, defaults to 256):
+            Dimensionality of the projection before token mean-pooling for classification.
+        tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+            A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
+            module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
+        tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
+            *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
+        tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+            A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
+            *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
+        xvector_output_dim (`int`, *optional*, defaults to 512):
+            Dimensionality of the *XVector* embedding vectors.
+        add_adapter (`bool`, *optional*, defaults to `False`):
+            Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
+            warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
+        adapter_kernel_size (`int`, *optional*, defaults to 3):
+            Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+        adapter_stride (`int`, *optional*, defaults to 2):
+            Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+        num_adapter_layers (`int`, *optional*, defaults to 3):
+            Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
+            True`.
+        output_hidden_size (`int`, *optional*):
+            Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
+            if `add_adapter is True`.
+        use_scan (`bool`, *optional*, defaults to `False`):
+            Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
+
+    Example:
+
+    ```python
+    >>> from transformers import Wav2Vec2Model, Wav2Vec2Config
+
+    >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
+    >>> configuration = Wav2Vec2Config()
+
+    >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
+    >>> model = Wav2Vec2Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+    model_type = "wav2vec2"
+
+    def __init__(
+        self,
+        vocab_size=32,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout=0.1,
+        activation_dropout=0.1,
+        attention_dropout=0.1,
+        feat_proj_dropout=0.0,
+        feat_quantizer_dropout=0.0,
+        final_dropout=0.1,
+        layerdrop=0.1,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        feat_extract_norm="group",
+        feat_extract_activation="gelu",
+        conv_dim=(512, 512, 512, 512, 512, 512, 512),
+        conv_stride=(5, 2, 2, 2, 2, 2, 2),
+        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+        conv_bias=False,
+        num_conv_pos_embeddings=128,
+        num_conv_pos_embedding_groups=16,
+        do_stable_layer_norm=False,
+        apply_spec_augment=True,
+        mask_time_prob=0.05,
+        mask_time_length=10,
+        mask_time_min_masks=2,
+        mask_feature_prob=0.0,
+        mask_feature_length=10,
+        mask_feature_min_masks=0,
+        num_codevectors_per_group=320,
+        num_codevector_groups=2,
+        contrastive_logits_temperature=0.1,
+        num_negatives=100,
+        codevector_dim=256,
+        proj_codevector_dim=256,
+        diversity_loss_weight=0.1,
+        ctc_loss_reduction="sum",
+        ctc_zero_infinity=False,
+        use_weighted_layer_sum=False,
+        classifier_proj_size=256,
+        tdnn_dim=(512, 512, 512, 512, 1500),
+        tdnn_kernel=(5, 3, 3, 1, 1),
+        tdnn_dilation=(1, 2, 3, 1, 1),
+        xvector_output_dim=512,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        add_adapter=False,
+        adapter_kernel_size=3,
+        adapter_stride=2,
+        num_adapter_layers=3,
+        output_hidden_size=None,
+        use_scan=False,
+        fuse_matmuls=False,
+        **kwargs
+    ):
+        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+        self.hidden_size = hidden_size
+        self.feat_extract_norm = feat_extract_norm
+        self.feat_extract_activation = feat_extract_activation
+        self.conv_dim = list(conv_dim)
+        self.conv_stride = list(conv_stride)
+        self.conv_kernel = list(conv_kernel)
+        self.conv_bias = conv_bias
+        self.num_conv_pos_embeddings = num_conv_pos_embeddings
+        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+        self.num_feat_extract_layers = len(self.conv_dim)
+        self.num_hidden_layers = num_hidden_layers
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.num_attention_heads = num_attention_heads
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.feat_proj_dropout = feat_proj_dropout
+        self.final_dropout = final_dropout
+        self.layerdrop = layerdrop
+        self.layer_norm_eps = layer_norm_eps
+        self.initializer_range = initializer_range
+        self.vocab_size = vocab_size
+        self.do_stable_layer_norm = do_stable_layer_norm
+        self.use_weighted_layer_sum = use_weighted_layer_sum
+        self.use_scan = use_scan
+        self.fuse_matmuls = fuse_matmuls
+
+        if (
+            (len(self.conv_stride) != self.num_feat_extract_layers)
+            or (len(self.conv_kernel) != self.num_feat_extract_layers)
+            or (len(self.conv_dim) != self.num_feat_extract_layers)
+        ):
+            raise ValueError(
+                "Configuration for convolutional layers is incorrect. "
+                "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
+                f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
+                f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+            )
+
+        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
+        self.apply_spec_augment = apply_spec_augment
+        self.mask_time_prob = mask_time_prob
+        self.mask_time_length = mask_time_length
+        self.mask_time_min_masks = mask_time_min_masks
+        self.mask_feature_prob = mask_feature_prob
+        self.mask_feature_length = mask_feature_length
+        self.mask_feature_min_masks = mask_feature_min_masks
+
+        # parameters for pretraining with codevector quantized representations
+        self.num_codevectors_per_group = num_codevectors_per_group
+        self.num_codevector_groups = num_codevector_groups
+        self.contrastive_logits_temperature = contrastive_logits_temperature
+        self.feat_quantizer_dropout = feat_quantizer_dropout
+        self.num_negatives = num_negatives
+        self.codevector_dim = codevector_dim
+        self.proj_codevector_dim = proj_codevector_dim
+        self.diversity_loss_weight = diversity_loss_weight
+
+        # ctc loss
+        self.ctc_loss_reduction = ctc_loss_reduction
+        self.ctc_zero_infinity = ctc_zero_infinity
+
+        # adapter
+        self.add_adapter = add_adapter
+        self.adapter_kernel_size = adapter_kernel_size
+        self.adapter_stride = adapter_stride
+        self.num_adapter_layers = num_adapter_layers
+        self.output_hidden_size = output_hidden_size or hidden_size
+
+        # SequenceClassification-specific parameter. Feel free to ignore for other classes.
+        self.classifier_proj_size = classifier_proj_size
+
+        # XVector-specific parameters. Feel free to ignore for other classes.
+        self.tdnn_dim = list(tdnn_dim)
+        self.tdnn_kernel = list(tdnn_kernel)
+        self.tdnn_dilation = list(tdnn_dilation)
+        self.xvector_output_dim = xvector_output_dim
+
+    @property
+    def inputs_to_logits_ratio(self):
+        return functools.reduce(operator.mul, self.conv_stride, 1)
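The `inputs_to_logits_ratio` property at the end is simply the product of the conv strides, i.e. how many raw audio samples collapse into one encoder frame. A worked example with the defaults above:

# Worked example: conv_stride defaults to (5, 2, 2, 2, 2, 2, 2),
# so the feature encoder downsamples by 5 * 2**6 = 320.
from models.configuration_wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()
assert config.inputs_to_logits_ratio == 5 * 2**6 == 320
# At 16 kHz input audio, that is one feature frame every 320 / 16000 s = 20 ms.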
models/modeling_flax_bart.py
ADDED
@@ -0,0 +1,816 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Flax Bart model."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
import random
|
19 |
+
from functools import partial
|
20 |
+
from typing import Optional, Tuple
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
import flax.linen as nn
|
25 |
+
import jax
|
26 |
+
import jax.numpy as jnp
|
27 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
28 |
+
from flax.linen import combine_masks, make_causal_mask
|
29 |
+
from flax.linen import partitioning as nn_partitioning
|
30 |
+
from flax.linen.attention import dot_product_attention_weights
|
31 |
+
from jax import lax
|
32 |
+
from jax.random import PRNGKey
|
33 |
+
|
34 |
+
from transformers.modeling_flax_outputs import (
|
35 |
+
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
36 |
+
FlaxCausalLMOutputWithCrossAttentions,
|
37 |
+
)
|
38 |
+
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
39 |
+
|
40 |
+
from models import BartConfig
|
41 |
+
|
42 |
+
|
43 |
+
scan_with_axes = nn_partitioning.scan_with_axes
|
44 |
+
remat = nn_partitioning.remat
|
45 |
+
|
46 |
+
|
47 |
+
def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
|
48 |
+
"""
|
49 |
+
Shift input ids one token to the right.
|
50 |
+
"""
|
51 |
+
shifted_input_ids = np.zeros_like(input_ids)
|
52 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
53 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
54 |
+
|
55 |
+
shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
56 |
+
return shifted_input_ids
|
57 |
+
|
58 |
+
|
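As an aside, a minimal sketch of what this helper produces; the toy label values and the trailing -100 sentinel are assumptions for illustration:

import numpy as np

labels = np.array([[42, 43, -100]])
shifted = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
# shifted == [[2, 42, 43]]: the decoder start token is prepended, the final
# position is dropped, and any surviving -100 sentinel is replaced by pad_token_id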
59 |
+
class FlaxBartAttention(nn.Module):
|
60 |
+
config: BartConfig
|
61 |
+
embed_dim: int
|
62 |
+
num_heads: int
|
63 |
+
dropout: float = 0.0
|
64 |
+
causal: bool = False
|
65 |
+
bias: bool = True
|
66 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
67 |
+
|
68 |
+
def setup(self) -> None:
|
69 |
+
self.head_dim = self.embed_dim // self.num_heads
|
70 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
71 |
+
raise ValueError(
|
72 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
73 |
+
f" and `num_heads`: {self.num_heads})."
|
74 |
+
)
|
75 |
+
|
76 |
+
dense = partial(
|
77 |
+
nn.Dense,
|
78 |
+
self.embed_dim,
|
79 |
+
use_bias=self.bias,
|
80 |
+
dtype=self.dtype,
|
81 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
82 |
+
)
|
83 |
+
|
84 |
+
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
85 |
+
|
86 |
+
self.fused_proj = nn.Dense(
|
87 |
+
self.embed_dim * 3,
|
88 |
+
use_bias=self.bias,
|
89 |
+
dtype=self.dtype,
|
90 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
91 |
+
)
|
92 |
+
|
93 |
+
self.fused_key_value = nn.Dense(
|
94 |
+
self.embed_dim * 2,
|
95 |
+
use_bias=self.bias,
|
96 |
+
dtype=self.dtype,
|
97 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
98 |
+
)
|
99 |
+
|
100 |
+
self.out_proj = dense()
|
101 |
+
|
102 |
+
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
103 |
+
|
104 |
+
if self.causal:
|
105 |
+
self.causal_mask = make_causal_mask(
|
106 |
+
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
107 |
+
)
|
108 |
+
|
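A quick standalone check of what make_causal_mask builds; the sequence length of 4 is a toy assumption:

import jax.numpy as jnp
from flax.linen import make_causal_mask

mask = make_causal_mask(jnp.ones((1, 4), dtype="bool"), dtype="bool")
# mask.shape == (1, 1, 4, 4) and mask[0, 0] is lower-triangular, so each
# query position can only attend to itself and to earlier key positions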
109 |
+
def _split_heads(self, hidden_states):
|
110 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
|
111 |
+
|
112 |
+
def _merge_heads(self, hidden_states):
|
113 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
114 |
+
|
115 |
+
@nn.compact
|
116 |
+
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
117 |
+
"""
|
118 |
+
This function takes projected key, value states from a single input token and concatenates the states to cached
|
119 |
+
states from previous steps. This function is slightly adapted from the official Flax repository:
|
120 |
+
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
121 |
+
"""
|
122 |
+
# detect if we're initializing by absence of existing cache data.
|
123 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
124 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
125 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
126 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
127 |
+
|
128 |
+
if is_initialized:
|
129 |
+
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
130 |
+
# update key, value caches with our new 1d spatial slices
|
131 |
+
cur_index = cache_index.value
|
132 |
+
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
133 |
+
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
134 |
+
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
135 |
+
cached_key.value = key
|
136 |
+
cached_value.value = value
|
137 |
+
num_updated_cache_vectors = query.shape[1]
|
138 |
+
cache_index.value = cache_index.value + num_updated_cache_vectors
|
139 |
+
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
140 |
+
pad_mask = jnp.broadcast_to(
|
141 |
+
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
142 |
+
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
143 |
+
)
|
144 |
+
attention_mask = combine_masks(pad_mask, attention_mask)
|
145 |
+
return key, value, attention_mask
|
146 |
+
|
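The cache write in _concatenate_to_cache can be pictured with a module-free sketch; all shapes below are illustrative assumptions:

import jax.numpy as jnp
from jax import lax

cache = jnp.zeros((1, 4, 2, 8))   # (batch, max_length, num_heads, head_dim)
new_key = jnp.ones((1, 1, 2, 8))  # projected key for the current token
cur_index = 2                     # positions already written to the cache
cache = lax.dynamic_update_slice(cache, new_key, (0, cur_index, 0, 0))
# time step 2 now holds new_key; steps 0-1 and 3 are left untouched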
147 |
+
def __call__(
|
148 |
+
self,
|
149 |
+
hidden_states: jnp.ndarray,
|
150 |
+
key_value_states: Optional[jnp.ndarray] = None,
|
151 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
152 |
+
init_cache: bool = False,
|
153 |
+
deterministic: bool = True,
|
154 |
+
) -> Tuple[jnp.ndarray]:
|
155 |
+
"""Input shape: Batch x Time x Channel"""
|
156 |
+
|
157 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
158 |
+
# for the decoder
|
159 |
+
is_cross_attention = key_value_states is not None
|
160 |
+
batch_size = hidden_states.shape[0]
|
161 |
+
|
162 |
+
if self.config.fuse_matmuls:
|
163 |
+
# get key, value proj
|
164 |
+
if is_cross_attention:
|
165 |
+
# get query proj
|
166 |
+
query_states = self.q_proj(hidden_states)
|
167 |
+
# cross_attentions
|
168 |
+
attention_states = self.fused_key_value(key_value_states)
|
169 |
+
key_states, value_states = jnp.split(attention_states, 2, axis=-1)
|
170 |
+
else:
|
171 |
+
attention_states = self.fused_proj(hidden_states)
|
172 |
+
query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
|
173 |
+
|
174 |
+
else:
|
175 |
+
# get query proj
|
176 |
+
query_states = self.q_proj(hidden_states)
|
177 |
+
# get key, value proj
|
178 |
+
if is_cross_attention:
|
179 |
+
# cross_attentions
|
180 |
+
key_states = self.k_proj(key_value_states)
|
181 |
+
value_states = self.v_proj(key_value_states)
|
182 |
+
else:
|
183 |
+
# self_attention
|
184 |
+
key_states = self.k_proj(hidden_states)
|
185 |
+
value_states = self.v_proj(hidden_states)
|
186 |
+
|
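The fuse_matmuls branch above trades three matmuls for one larger one; the two paths agree whenever the fused kernel is the concatenation of the per-projection kernels, as this small numpy sketch (toy sizes, no bias) shows:

import numpy as np

d = 4
x = np.random.randn(2, 3, d)
wq, wk, wv = (np.random.randn(d, d) for _ in range(3))
fused = np.concatenate([wq, wk, wv], axis=-1)  # (d, 3 * d), assumed layout
q, k, v = np.split(x @ fused, 3, axis=-1)      # one matmul, then split
assert np.allclose(q, x @ wq) and np.allclose(k, x @ wk) and np.allclose(v, x @ wv)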
187 |
+
query_states = self._split_heads(query_states)
|
188 |
+
key_states = self._split_heads(key_states)
|
189 |
+
value_states = self._split_heads(value_states)
|
190 |
+
|
191 |
+
# handle cache: prepare causal attention mask
|
192 |
+
if self.causal:
|
193 |
+
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
194 |
+
if self.has_variable("cache", "cached_key"):
|
195 |
+
mask_shift = self.variables["cache"]["cache_index"]
|
196 |
+
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
197 |
+
causal_mask = lax.dynamic_slice(
|
198 |
+
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
199 |
+
)
|
200 |
+
else:
|
201 |
+
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
202 |
+
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
203 |
+
|
204 |
+
# combine masks if needed
|
205 |
+
if attention_mask is not None and self.causal:
|
206 |
+
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
207 |
+
attention_mask = combine_masks(attention_mask, causal_mask)
|
208 |
+
elif self.causal:
|
209 |
+
attention_mask = causal_mask
|
210 |
+
elif attention_mask is not None:
|
211 |
+
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
212 |
+
|
213 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
214 |
+
# and cache the keys and values step by step.
|
215 |
+
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
216 |
+
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
217 |
+
key_states, value_states, query_states, attention_mask
|
218 |
+
)
|
219 |
+
|
220 |
+
# Convert the boolean attention mask to an attention bias.
|
221 |
+
if attention_mask is not None:
|
222 |
+
# attention mask in the form of attention bias
|
223 |
+
attention_bias = lax.select(
|
224 |
+
attention_mask > 0,
|
225 |
+
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
226 |
+
jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
attention_bias = None
|
230 |
+
|
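The mask-to-bias conversion can be seen in isolation with a toy mask (a sketch):

import jax.numpy as jnp
from jax import lax

mask = jnp.array([[1, 1, 0]])  # 1 = attend, 0 = masked
bias = lax.select(
    mask > 0,
    jnp.full(mask.shape, 0.0),
    jnp.full(mask.shape, float("-inf")),
)
# bias == [[0., 0., -inf]]; added to the attention logits, the -inf entries
# contribute nothing after the softmax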
231 |
+
dropout_rng = None
|
232 |
+
if not deterministic and self.dropout > 0.0:
|
233 |
+
dropout_rng = self.make_rng("dropout")
|
234 |
+
|
235 |
+
attn_weights = dot_product_attention_weights(
|
236 |
+
query_states,
|
237 |
+
key_states,
|
238 |
+
bias=attention_bias,
|
239 |
+
dropout_rng=dropout_rng,
|
240 |
+
dropout_rate=self.dropout,
|
241 |
+
broadcast_dropout=True,
|
242 |
+
deterministic=deterministic,
|
243 |
+
dtype=self.dtype,
|
244 |
+
precision=None,
|
245 |
+
)
|
246 |
+
|
247 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
248 |
+
attn_output = self._merge_heads(attn_output)
|
249 |
+
attn_output = self.out_proj(attn_output)
|
250 |
+
|
251 |
+
return attn_output, attn_weights
|
252 |
+
|
253 |
+
|
254 |
+
class FlaxBartDecoderLayer(nn.Module):
|
255 |
+
config: BartConfig
|
256 |
+
dtype: jnp.dtype = jnp.float32
|
257 |
+
|
258 |
+
def setup(self) -> None:
|
259 |
+
self.embed_dim = self.config.d_model
|
260 |
+
self.self_attn = FlaxBartAttention(
|
261 |
+
config=self.config,
|
262 |
+
embed_dim=self.embed_dim,
|
263 |
+
num_heads=self.config.decoder_attention_heads,
|
264 |
+
dropout=self.config.attention_dropout,
|
265 |
+
causal=True,
|
266 |
+
dtype=self.dtype,
|
267 |
+
)
|
268 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
269 |
+
self.activation_fn = ACT2FN[self.config.activation_function]
|
270 |
+
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
|
271 |
+
|
272 |
+
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
273 |
+
self.encoder_attn = FlaxBartAttention(
|
274 |
+
config=self.config,
|
275 |
+
embed_dim=self.embed_dim,
|
276 |
+
num_heads=self.config.decoder_attention_heads,
|
277 |
+
dropout=self.config.attention_dropout,
|
278 |
+
dtype=self.dtype,
|
279 |
+
)
|
280 |
+
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
281 |
+
self.fc1 = nn.Dense(
|
282 |
+
self.config.decoder_ffn_dim,
|
283 |
+
dtype=self.dtype,
|
284 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
285 |
+
)
|
286 |
+
self.fc2 = nn.Dense(
|
287 |
+
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
288 |
+
)
|
289 |
+
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
290 |
+
|
291 |
+
def __call__(
|
292 |
+
self,
|
293 |
+
hidden_states: jnp.ndarray,
|
294 |
+
attention_mask: jnp.ndarray,
|
295 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
296 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
297 |
+
init_cache: bool = False,
|
298 |
+
output_attentions: bool = True,
|
299 |
+
deterministic: bool = True,
|
300 |
+
) -> Tuple[jnp.ndarray]:
|
301 |
+
|
302 |
+
if self.config.use_scan:
|
303 |
+
hidden_states = hidden_states[0]
|
304 |
+
|
305 |
+
residual = hidden_states
|
306 |
+
|
307 |
+
# Self Attention
|
308 |
+
hidden_states, self_attn_weights = self.self_attn(
|
309 |
+
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
|
310 |
+
)
|
311 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
312 |
+
hidden_states = residual + hidden_states
|
313 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
314 |
+
|
315 |
+
# Cross-Attention Block
|
316 |
+
cross_attn_weights = None
|
317 |
+
if encoder_hidden_states is not None:
|
318 |
+
residual = hidden_states
|
319 |
+
|
320 |
+
hidden_states, cross_attn_weights = self.encoder_attn(
|
321 |
+
hidden_states=hidden_states,
|
322 |
+
key_value_states=encoder_hidden_states,
|
323 |
+
attention_mask=encoder_attention_mask,
|
324 |
+
)
|
325 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
326 |
+
hidden_states = residual + hidden_states
|
327 |
+
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
328 |
+
|
329 |
+
# Fully Connected
|
330 |
+
residual = hidden_states
|
331 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
332 |
+
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
|
333 |
+
hidden_states = self.fc2(hidden_states)
|
334 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
335 |
+
hidden_states = residual + hidden_states
|
336 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
337 |
+
|
338 |
+
outputs = (hidden_states,)
|
339 |
+
|
340 |
+
if output_attentions:
|
341 |
+
outputs += (self_attn_weights, cross_attn_weights)
|
342 |
+
|
343 |
+
if self.config.use_scan:
|
344 |
+
outputs = (outputs, None)
|
345 |
+
|
346 |
+
return outputs
|
347 |
+
|
348 |
+
|
349 |
+
class FlaxBartDecoderLayerCollection(nn.Module):
|
350 |
+
config: BartConfig
|
351 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
352 |
+
|
353 |
+
@nn.compact
|
354 |
+
def __call__(
|
355 |
+
self,
|
356 |
+
hidden_states,
|
357 |
+
attention_mask,
|
358 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
359 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
360 |
+
deterministic: bool = True,
|
361 |
+
init_cache: bool = False,
|
362 |
+
output_attentions: bool = False,
|
363 |
+
output_hidden_states: bool = False,
|
364 |
+
return_dict: bool = True,
|
365 |
+
):
|
366 |
+
# decoder layers
|
367 |
+
all_hidden_states = () if output_hidden_states else None
|
368 |
+
all_self_attns = () if output_attentions else None
|
369 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
370 |
+
|
371 |
+
num_decoder_layers = self.config.decoder_layers
|
372 |
+
BlockDecoderLayer = (
|
373 |
+
remat(
|
374 |
+
FlaxBartDecoderLayer,
|
375 |
+
static_argnums=(4, 5, 6),
|
376 |
+
prevent_cse=not self.config.use_scan,
|
377 |
+
)
|
378 |
+
if self.config.gradient_checkpointing
|
379 |
+
else FlaxBartDecoderLayer
|
380 |
+
)
|
381 |
+
|
382 |
+
if self.config.use_scan:
|
383 |
+
# since all decoder layers are the same, we use nn.scan directly
|
384 |
+
assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
|
385 |
+
assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
|
386 |
+
hidden_states = (hidden_states,)
|
387 |
+
|
388 |
+
# TODO: add layerdrop in checkpointed scan (note: default value for layerdrop in config is zero)
|
389 |
+
hidden_states, _ = scan_with_axes(
|
390 |
+
BlockDecoderLayer,
|
391 |
+
variable_axes={"params": 0, "cache": 0},
|
392 |
+
split_rngs={"params": True, "dropout": True},
|
393 |
+
in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
|
394 |
+
length=num_decoder_layers,
|
395 |
+
)(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
|
396 |
+
hidden_states,
|
397 |
+
attention_mask,
|
398 |
+
encoder_hidden_states,
|
399 |
+
encoder_attention_mask,
|
400 |
+
init_cache,
|
401 |
+
output_attentions,
|
402 |
+
deterministic,
|
403 |
+
)
|
404 |
+
hidden_states = hidden_states[0]
|
405 |
+
|
406 |
+
else:
|
407 |
+
for layer in range(num_decoder_layers):
|
408 |
+
if output_hidden_states:
|
409 |
+
all_hidden_states += (hidden_states,)
|
410 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
411 |
+
dropout_probability = random.uniform(0, 1)
|
412 |
+
if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
|
413 |
+
layer_outputs = (None, None, None)
|
414 |
+
else:
|
415 |
+
layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
|
416 |
+
hidden_states,
|
417 |
+
attention_mask,
|
418 |
+
encoder_hidden_states,
|
419 |
+
encoder_attention_mask,
|
420 |
+
init_cache,
|
421 |
+
output_attentions,
|
422 |
+
deterministic,
|
423 |
+
)
|
424 |
+
|
425 |
+
hidden_states = layer_outputs[0]
|
426 |
+
if output_attentions:
|
427 |
+
all_self_attns += (layer_outputs[1],)
|
428 |
+
|
429 |
+
if encoder_hidden_states is not None:
|
430 |
+
all_cross_attentions += (layer_outputs[2],)
|
431 |
+
|
432 |
+
# add hidden states from the last decoder layer
|
433 |
+
if output_hidden_states:
|
434 |
+
all_hidden_states += (hidden_states,)
|
435 |
+
|
436 |
+
outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
|
437 |
+
|
438 |
+
if not return_dict:
|
439 |
+
return tuple(v for v in outputs if v is not None)
|
440 |
+
|
441 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
442 |
+
last_hidden_state=hidden_states,
|
443 |
+
hidden_states=all_hidden_states,
|
444 |
+
attentions=all_self_attns,
|
445 |
+
cross_attentions=all_cross_attentions,
|
446 |
+
)
|
447 |
+
|
448 |
+
|
449 |
+
class FlaxBartDecoder(nn.Module):
|
450 |
+
config: BartConfig
|
451 |
+
embed_tokens: nn.Embed
|
452 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
453 |
+
|
454 |
+
def setup(self):
|
455 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
456 |
+
|
457 |
+
embed_dim = self.config.d_model
|
458 |
+
self.padding_idx = self.config.pad_token_id
|
459 |
+
self.max_target_positions = self.config.max_position_embeddings
|
460 |
+
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
|
461 |
+
|
462 |
+
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
463 |
+
# and adjust num_embeddings appropriately. Other models don't have this hack
|
464 |
+
self.offset = 2
|
465 |
+
self.embed_positions = nn.Embed(
|
466 |
+
self.config.max_position_embeddings + self.offset,
|
467 |
+
embed_dim,
|
468 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
469 |
+
)
|
470 |
+
|
471 |
+
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
472 |
+
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
473 |
+
|
474 |
+
def __call__(
|
475 |
+
self,
|
476 |
+
input_ids,
|
477 |
+
attention_mask,
|
478 |
+
position_ids,
|
479 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
480 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
481 |
+
init_cache: bool = False,
|
482 |
+
output_attentions: bool = False,
|
483 |
+
output_hidden_states: bool = False,
|
484 |
+
return_dict: bool = True,
|
485 |
+
deterministic: bool = True,
|
486 |
+
):
|
487 |
+
input_shape = input_ids.shape
|
488 |
+
input_ids = input_ids.reshape(-1, input_shape[-1])
|
489 |
+
|
490 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
491 |
+
|
492 |
+
# embed positions
|
493 |
+
positions = self.embed_positions(position_ids + self.offset)
|
494 |
+
|
495 |
+
hidden_states = inputs_embeds + positions
|
496 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
497 |
+
|
498 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
499 |
+
|
500 |
+
outputs = self.layers(
|
501 |
+
hidden_states,
|
502 |
+
attention_mask,
|
503 |
+
encoder_hidden_states,
|
504 |
+
encoder_attention_mask,
|
505 |
+
deterministic=deterministic,
|
506 |
+
init_cache=init_cache,
|
507 |
+
output_attentions=output_attentions,
|
508 |
+
output_hidden_states=output_hidden_states,
|
509 |
+
return_dict=return_dict,
|
510 |
+
)
|
511 |
+
|
512 |
+
if not return_dict:
|
513 |
+
return outputs
|
514 |
+
|
515 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
516 |
+
last_hidden_state=outputs.last_hidden_state,
|
517 |
+
hidden_states=outputs.hidden_states,
|
518 |
+
attentions=outputs.attentions,
|
519 |
+
cross_attentions=outputs.cross_attentions,
|
520 |
+
)
|
521 |
+
|
522 |
+
|
523 |
+
class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
|
524 |
+
config_class = BartConfig
|
525 |
+
base_model_prefix: str = "model"
|
526 |
+
module_class: nn.Module = None
|
527 |
+
|
528 |
+
def __init__(
|
529 |
+
self,
|
530 |
+
config: BartConfig,
|
531 |
+
input_shape: Tuple[int] = (1, 1),
|
532 |
+
seed: int = 0,
|
533 |
+
dtype: jnp.dtype = jnp.float32,
|
534 |
+
_do_init: bool = True,
|
535 |
+
**kwargs
|
536 |
+
):
|
537 |
+
config.is_decoder = True
|
538 |
+
config.is_encoder_decoder = False
|
539 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
540 |
+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
541 |
+
|
542 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
543 |
+
# init input tensors
|
544 |
+
input_ids = jnp.zeros(input_shape, dtype="i4")
|
545 |
+
attention_mask = jnp.ones_like(input_ids)
|
546 |
+
|
547 |
+
batch_size, sequence_length = input_ids.shape
|
548 |
+
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
549 |
+
|
550 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
551 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
552 |
+
encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
|
553 |
+
encoder_attention_mask = attention_mask
|
554 |
+
module_init_outputs = self.module.init(
|
555 |
+
rngs,
|
556 |
+
input_ids,
|
557 |
+
attention_mask,
|
558 |
+
position_ids,
|
559 |
+
encoder_hidden_states,
|
560 |
+
encoder_attention_mask,
|
561 |
+
return_dict=False,
|
562 |
+
)
|
563 |
+
return module_init_outputs["params"]
|
564 |
+
|
565 |
+
def init_cache(self, batch_size, max_length):
|
566 |
+
r"""
|
567 |
+
Args:
|
568 |
+
batch_size (`int`):
|
569 |
+
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
570 |
+
max_length (`int`):
|
571 |
+
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
572 |
+
cache.
|
573 |
+
"""
|
574 |
+
# init input variables to retrieve cache
|
575 |
+
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
576 |
+
attention_mask = jnp.ones_like(input_ids, dtype="i4")
|
577 |
+
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
578 |
+
|
579 |
+
init_variables = self.module.init(
|
580 |
+
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
581 |
+
)
|
582 |
+
return unfreeze(init_variables["cache"])
|
583 |
+
|
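A usage sketch for the cache; `model` and `input_ids` are illustrative placeholders for a FlaxBartDecoderPreTrainedModel subclass and its tokenized prompt:

past_key_values = model.init_cache(batch_size=1, max_length=32)
outputs = model(input_ids[:, :1], past_key_values=past_key_values)
# the returned outputs carry the updated cache, which is fed back in on the
# next call so only one new token has to be processed per decoding step
past_key_values = outputs.past_key_values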
584 |
+
def __call__(
|
585 |
+
self,
|
586 |
+
input_ids: jnp.ndarray,
|
587 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
588 |
+
position_ids: Optional[jnp.ndarray] = None,
|
589 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
590 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
591 |
+
output_attentions: Optional[bool] = None,
|
592 |
+
output_hidden_states: Optional[bool] = None,
|
593 |
+
return_dict: Optional[bool] = None,
|
594 |
+
train: bool = False,
|
595 |
+
params: dict = None,
|
596 |
+
past_key_values: dict = None,
|
597 |
+
dropout_rng: PRNGKey = None,
|
598 |
+
):
|
599 |
+
"""
|
600 |
+
Args:
|
601 |
+
input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
|
602 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
603 |
+
|
604 |
+
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
605 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
606 |
+
|
607 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
608 |
+
|
609 |
+
For translation and summarization training, `decoder_input_ids` should be provided. If no
|
610 |
+
`decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
|
611 |
+
for denoising pre-training following the paper.
|
612 |
+
attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
|
613 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
614 |
+
be used by default.
|
615 |
+
|
616 |
+
If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
|
617 |
+
paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
|
618 |
+
position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
|
619 |
+
Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
|
620 |
+
range `[0, config.max_position_embeddings - 1]`.
|
621 |
+
encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
|
622 |
+
A sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
623 |
+
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
624 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
625 |
+
|
626 |
+
- 1 for tokens that are **not masked**,
|
627 |
+
- 0 for tokens that are **masked**.
|
628 |
+
|
629 |
+
[What are attention masks?](../glossary#attention-mask)
|
630 |
+
output_attentions (`bool`, *optional*):
|
631 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
632 |
+
tensors for more detail.
|
633 |
+
output_hidden_states (`bool`, *optional*):
|
634 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
635 |
+
more detail.
|
636 |
+
return_dict (`bool`, *optional*):
|
637 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
638 |
+
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
639 |
+
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
640 |
+
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
641 |
+
"""
|
642 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
643 |
+
output_hidden_states = (
|
644 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
645 |
+
)
|
646 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
647 |
+
|
648 |
+
if encoder_hidden_states is not None and encoder_attention_mask is None:
|
649 |
+
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
650 |
+
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
651 |
+
|
652 |
+
# prepare decoder inputs
|
653 |
+
if attention_mask is None:
|
654 |
+
attention_mask = jnp.ones_like(input_ids)
|
655 |
+
if position_ids is None:
|
656 |
+
batch_size, sequence_length = input_ids.shape
|
657 |
+
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
658 |
+
|
659 |
+
# Handle any PRNG if needed
|
660 |
+
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
661 |
+
|
662 |
+
inputs = {"params": params or self.params}
|
663 |
+
|
664 |
+
# if past_key_values are passed then the cache is already initialized; a private flag init_cache has to be passed
|
665 |
+
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
|
666 |
+
# changed by FlaxBartAttention module
|
667 |
+
if past_key_values:
|
668 |
+
inputs["cache"] = past_key_values
|
669 |
+
mutable = ["cache"]
|
670 |
+
else:
|
671 |
+
mutable = False
|
672 |
+
|
673 |
+
outputs = self.module.apply(
|
674 |
+
inputs,
|
675 |
+
input_ids=jnp.array(input_ids, dtype="i4"),
|
676 |
+
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
677 |
+
position_ids=jnp.array(position_ids, dtype="i4"),
|
678 |
+
encoder_hidden_states=encoder_hidden_states,
|
679 |
+
encoder_attention_mask=encoder_attention_mask,
|
680 |
+
output_attentions=output_attentions,
|
681 |
+
output_hidden_states=output_hidden_states,
|
682 |
+
return_dict=return_dict,
|
683 |
+
deterministic=not train,
|
684 |
+
rngs=rngs,
|
685 |
+
mutable=mutable,
|
686 |
+
)
|
687 |
+
|
688 |
+
# add updated cache to model output
|
689 |
+
if past_key_values is not None and return_dict:
|
690 |
+
outputs, past_key_values = outputs
|
691 |
+
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
692 |
+
return outputs
|
693 |
+
elif past_key_values is not None and not return_dict:
|
694 |
+
outputs, past_key_values = outputs
|
695 |
+
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
696 |
+
|
697 |
+
return outputs
|
698 |
+
|
699 |
+
|
700 |
+
class FlaxBartDecoderWrapper(nn.Module):
|
701 |
+
"""
|
702 |
+
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
703 |
+
used in combination with the [`EncoderDecoderModel`] framework.
|
704 |
+
"""
|
705 |
+
|
706 |
+
config: BartConfig
|
707 |
+
dtype: jnp.dtype = jnp.float32
|
708 |
+
|
709 |
+
def setup(self):
|
710 |
+
embed_dim = self.config.d_model
|
711 |
+
embed_tokens = nn.Embed(
|
712 |
+
self.config.vocab_size,
|
713 |
+
embed_dim,
|
714 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
715 |
+
)
|
716 |
+
self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)
|
717 |
+
|
718 |
+
def __call__(self, *args, **kwargs):
|
719 |
+
return self.decoder(*args, **kwargs)
|
720 |
+
|
721 |
+
|
722 |
+
class FlaxBartForCausalLMModule(nn.Module):
|
723 |
+
"""Bart Decoder Module with a language modeling head on top (linear layer with weights tied to the input embeddings)
|
724 |
+
e.g. for autoregressive tasks.
|
725 |
+
"""
|
726 |
+
|
727 |
+
config: BartConfig
|
728 |
+
dtype: jnp.dtype = jnp.float32
|
729 |
+
|
730 |
+
def setup(self):
|
731 |
+
self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
|
732 |
+
self.lm_head = nn.Dense(
|
733 |
+
self.config.vocab_size,
|
734 |
+
use_bias=False,
|
735 |
+
dtype=self.dtype,
|
736 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
737 |
+
)
|
738 |
+
|
739 |
+
def __call__(
|
740 |
+
self,
|
741 |
+
input_ids,
|
742 |
+
attention_mask,
|
743 |
+
position_ids,
|
744 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
745 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
746 |
+
init_cache: bool = False,
|
747 |
+
output_attentions: bool = False,
|
748 |
+
output_hidden_states: bool = False,
|
749 |
+
return_dict: bool = True,
|
750 |
+
deterministic: bool = True,
|
751 |
+
):
|
752 |
+
|
753 |
+
outputs = self.model(
|
754 |
+
input_ids,
|
755 |
+
attention_mask,
|
756 |
+
position_ids,
|
757 |
+
encoder_hidden_states,
|
758 |
+
encoder_attention_mask,
|
759 |
+
deterministic=deterministic,
|
760 |
+
init_cache=init_cache,
|
761 |
+
output_attentions=output_attentions,
|
762 |
+
output_hidden_states=output_hidden_states,
|
763 |
+
return_dict=return_dict,
|
764 |
+
)
|
765 |
+
|
766 |
+
hidden_states = outputs[0]
|
767 |
+
|
768 |
+
if self.config.tie_word_embeddings:
|
769 |
+
shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
|
770 |
+
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
771 |
+
else:
|
772 |
+
lm_logits = self.lm_head(hidden_states)
|
773 |
+
|
774 |
+
if not return_dict:
|
775 |
+
return (lm_logits,) + outputs[1:]
|
776 |
+
|
777 |
+
return FlaxCausalLMOutputWithCrossAttentions(
|
778 |
+
logits=lm_logits,
|
779 |
+
hidden_states=outputs.hidden_states,
|
780 |
+
attentions=outputs.attentions,
|
781 |
+
cross_attentions=outputs.cross_attentions,
|
782 |
+
)
|
783 |
+
|
784 |
+
|
785 |
+
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
|
786 |
+
"""Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
|
787 |
+
e.g. for autoregressive tasks.
|
788 |
+
"""
|
789 |
+
|
790 |
+
module_class = FlaxBartForCausalLMModule
|
791 |
+
|
792 |
+
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
793 |
+
# initializing the cache
|
794 |
+
batch_size, seq_length = input_ids.shape
|
795 |
+
|
796 |
+
past_key_values = self.init_cache(batch_size, max_length)
|
797 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
798 |
+
# But since the decoder uses a causal mask, those positions are masked anyway.
|
799 |
+
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
|
800 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
801 |
+
if attention_mask is not None:
|
802 |
+
position_ids = attention_mask.cumsum(axis=-1) - 1
|
803 |
+
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
804 |
+
else:
|
805 |
+
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
806 |
+
|
807 |
+
return {
|
808 |
+
"past_key_values": past_key_values,
|
809 |
+
"attention_mask": extended_attention_mask,
|
810 |
+
"position_ids": position_ids,
|
811 |
+
}
|
812 |
+
|
813 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
814 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
815 |
+
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
816 |
+
return model_kwargs
|
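To see how the two hooks above slot into decoding, here is a hypothetical greedy driver loop; the real loop lives in FlaxGenerationMixin.generate, and `model` / `input_ids` are illustrative placeholders; this sketch only mirrors the calling convention:

import jax.numpy as jnp

model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length=32)
outputs = model(input_ids, **model_kwargs)  # prime the cache on the prompt
model_kwargs = model.update_inputs_for_generation(outputs, model_kwargs)
while input_ids.shape[-1] < 32:
    next_token = outputs.logits[:, -1:].argmax(-1)  # greedy pick, shape (batch, 1)
    input_ids = jnp.concatenate([input_ids, next_token], axis=-1)
    outputs = model(next_token, **model_kwargs)
    model_kwargs = model.update_inputs_for_generation(outputs, model_kwargs)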
models/modeling_flax_speech_encoder_decoder.py
ADDED
@@ -0,0 +1,1245 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Classes to support Flax Speech-Encoder-Decoder architectures"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from functools import partial
|
19 |
+
from typing import Optional, Tuple, Union, Dict
|
20 |
+
|
21 |
+
import flax
|
22 |
+
import flax.linen as nn
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
26 |
+
from jax import lax
|
27 |
+
from jax.random import PRNGKey
|
28 |
+
import numpy as np
|
29 |
+
|
30 |
+
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
31 |
+
from transformers.modeling_flax_utils import FlaxPreTrainedModel
|
32 |
+
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput
|
33 |
+
from transformers.generation_flax_utils import FlaxLogitsProcessorList
|
34 |
+
from models import (
|
35 |
+
FlaxWav2Vec2Model,
|
36 |
+
FlaxWav2Vec2Module,
|
37 |
+
FlaxBartForCausalLM,
|
38 |
+
FlaxBartForCausalLMModule,
|
39 |
+
BartConfig,
|
40 |
+
Wav2Vec2Config,
|
41 |
+
SpeechEncoderDecoderConfig,
|
42 |
+
)
|
43 |
+
|
44 |
+
logger = logging.get_logger(__name__)
|
45 |
+
|
46 |
+
_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
|
47 |
+
|
48 |
+
SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
|
49 |
+
This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
|
50 |
+
autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
|
51 |
+
loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via
|
52 |
+
[`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
|
53 |
+
and should be fine-tuned on a downstream generative task, like summarization.
|
54 |
+
|
55 |
+
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
|
56 |
+
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
|
57 |
+
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, and
|
58 |
+
Aliaksei Severyn.
|
59 |
+
|
60 |
+
Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
|
61 |
+
Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
|
62 |
+
translation yields a significant performance improvement.
|
63 |
+
|
64 |
+
After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
|
65 |
+
model (see the examples for more information).
|
66 |
+
|
67 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
68 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
69 |
+
etc.)
|
70 |
+
|
71 |
+
This model is also a Flax Linen
|
72 |
+
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
73 |
+
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
74 |
+
|
75 |
+
Parameters:
|
76 |
+
config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
|
77 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
78 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
79 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
80 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
81 |
+
`jax.numpy.bfloat16` (on TPUs).
|
82 |
+
|
83 |
+
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
84 |
+
specified all the computation will be performed with the given `dtype`.
|
85 |
+
|
86 |
+
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
87 |
+
parameters.**
|
88 |
+
|
89 |
+
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
90 |
+
[`~FlaxPreTrainedModel.to_bf16`].
|
91 |
+
"""
|
92 |
+
|
93 |
+
SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
94 |
+
Args:
|
95 |
+
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
|
96 |
+
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
|
97 |
+
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
|
98 |
+
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
|
99 |
+
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
|
100 |
+
*numpy.ndarray*.
|
101 |
+
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
102 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
103 |
+
|
104 |
+
- 1 for tokens that are **not masked**,
|
105 |
+
- 0 for tokens that are **masked**.
|
106 |
+
|
107 |
+
[What are attention masks?](../glossary#attention-mask)
|
108 |
+
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
109 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
110 |
+
|
111 |
+
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
112 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
113 |
+
|
114 |
+
[What are input IDs?](../glossary#input-ids)
|
115 |
+
|
116 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
117 |
+
`past_key_values`).
|
118 |
+
|
119 |
+
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
|
120 |
+
created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
|
121 |
+
and prepending them with the `decoder_start_token_id`.
|
122 |
+
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
123 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
124 |
+
be used by default.
|
125 |
+
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
126 |
+
Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
|
127 |
+
range `[0, config.decoder.max_position_embeddings - 1]`.
|
128 |
+
output_hidden_states (`bool`, *optional*):
|
129 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
130 |
+
more detail.
|
131 |
+
return_dict (`bool`, *optional*):
|
132 |
+
If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
|
133 |
+
"""
|
134 |
+
|
135 |
+
SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
|
136 |
+
Args:
|
137 |
+
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
|
138 |
+
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
|
139 |
+
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
|
140 |
+
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
|
141 |
+
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
|
142 |
+
*numpy.ndarray*.
|
143 |
+
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
144 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
145 |
+
|
146 |
+
- 1 for tokens that are **not masked**,
|
147 |
+
- 0 for tokens that are **masked**.
|
148 |
+
|
149 |
+
[What are attention masks?](../glossary#attention-mask)
|
150 |
+
output_attentions (`bool`, *optional*):
|
151 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
152 |
+
tensors for more detail.
|
153 |
+
output_hidden_states (`bool`, *optional*):
|
154 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
155 |
+
more detail.
|
156 |
+
return_dict (`bool`, *optional*):
|
157 |
+
If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
|
158 |
+
"""
|
159 |
+
|
160 |
+
SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
|
161 |
+
Args:
|
162 |
+
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
163 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
164 |
+
|
165 |
+
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
166 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
167 |
+
|
168 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
169 |
+
|
170 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
171 |
+
`past_key_values`).
|
172 |
+
|
173 |
+
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
|
174 |
+
created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
|
175 |
+
and prepending them with the `decoder_start_token_id`.
|
176 |
+
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
|
177 |
+
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
178 |
+
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
179 |
+
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
180 |
+
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
181 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
182 |
+
|
183 |
+
- 1 for tokens that are **not masked**,
|
184 |
+
- 0 for tokens that are **masked**.
|
185 |
+
|
186 |
+
[What are attention masks?](../glossary#attention-mask)
|
187 |
+
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
188 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
189 |
+
be used by default.
|
190 |
+
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
191 |
+
Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
|
192 |
+
range `[0, config.decoder.max_position_embeddings - 1]`.
|
193 |
+
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
194 |
+
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
195 |
+
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
196 |
+
output_attentions (`bool`, *optional*):
|
197 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
198 |
+
tensors for more detail.
|
199 |
+
output_hidden_states (`bool`, *optional*):
|
200 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
201 |
+
more detail.
|
202 |
+
return_dict (`bool`, *optional*):
|
203 |
+
If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
|
204 |
+
plain tuple.
|
205 |
+
"""
|
206 |
+
|
207 |
+
@flax.struct.dataclass
|
208 |
+
class FlaxBeamSearchOutput(ModelOutput):
|
209 |
+
"""
|
210 |
+
Flax base class for outputs of generation models using beam search.
|
211 |
+
|
212 |
+
|
213 |
+
Args:
|
214 |
+
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
215 |
+
The generated sequences.
|
216 |
+
scores (`jnp.ndarray` of shape `(batch_size,)`):
|
217 |
+
The scores (log probabilities) of the generated sequences.
|
218 |
+
"""
|
219 |
+
|
220 |
+
sequences: jnp.ndarray = None
|
221 |
+
scores: jnp.ndarray = None
|
222 |
+
|
223 |
+
|
224 |
+
@flax.struct.dataclass
|
225 |
+
class BeamSearchState:
|
226 |
+
cur_len: jnp.ndarray
|
227 |
+
running_sequences: jnp.ndarray
|
228 |
+
running_scores: jnp.ndarray
|
229 |
+
sequences: jnp.ndarray
|
230 |
+
scores: jnp.ndarray
|
231 |
+
is_sent_finished: jnp.ndarray
|
232 |
+
model_kwargs: Dict[str, jnp.ndarray]
|
233 |
+
|
234 |
+
|
235 |
+
|
236 |
+
|
237 |
+
class FlaxSpeechEncoderDecoderModule(nn.Module):
|
238 |
+
config: SpeechEncoderDecoderConfig
|
239 |
+
dtype: jnp.dtype = jnp.float32
|
240 |
+
|
241 |
+
def setup(self):
|
242 |
+
encoder_config = self.config.encoder
|
243 |
+
decoder_config = self.config.decoder
|
244 |
+
|
245 |
+
# TODO: configure FlaxAutoModel mappings (required when trialling different encoder-decoder combinations)
|
246 |
+
encoder_module = FlaxWav2Vec2Module
|
247 |
+
decoder_module = FlaxBartForCausalLMModule
|
248 |
+
|
249 |
+
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
|
250 |
+
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
|
251 |
+
|
252 |
+
# encoder outputs might need to be projected to different dimension for decoder
|
253 |
+
if (
|
254 |
+
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
255 |
+
and self.decoder.config.cross_attention_hidden_size is None
|
256 |
+
):
|
257 |
+
self.enc_to_dec_proj = nn.Dense(
|
258 |
+
self.decoder.config.hidden_size,
|
259 |
+
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
|
260 |
+
dtype=self.dtype,
|
261 |
+
)
|
262 |
+
else:
|
263 |
+
self.enc_to_dec_proj = None
|
264 |
+
|
265 |
+
def _get_feat_extract_output_lengths(
|
266 |
+
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
|
267 |
+
):
|
268 |
+
"""
|
269 |
+
Computes the output length of the convolutional layers
|
270 |
+
"""
|
271 |
+
|
272 |
+
add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
|
273 |
+
|
274 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
275 |
+
# 1D convolutional layer output length formula taken
|
276 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
277 |
+
return (input_length - kernel_size) // stride + 1
|
278 |
+
|
279 |
+
for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
|
280 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
281 |
+
|
282 |
+
if add_adapter:
|
283 |
+
for _ in range(self.config.encoder.num_adapter_layers):
|
284 |
+
input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
|
285 |
+
|
286 |
+
return input_lengths
|
287 |
+
|
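As a worked example of the formula above: with the wav2vec2-base defaults (conv kernels (10, 3, 3, 3, 3, 2, 2), strides (5, 2, 2, 2, 2, 2, 2), no adapter; these defaults are an assumption about the encoder config), one second of 16 kHz audio maps to 49 encoder frames:

def conv_out_length(input_length, kernel_size, stride):
    return (input_length - kernel_size) // stride + 1

length = 16000  # samples in one second of 16 kHz audio
for k, s in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = conv_out_length(length, k, s)
print(length)  # 49, i.e. roughly one encoder frame every 20 ms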
288 |
+
def _get_encoder_module(self):
|
289 |
+
return self.encoder
|
290 |
+
|
291 |
+
def _get_projection_module(self):
|
292 |
+
return self.enc_to_dec_proj
|
293 |
+
|
294 |
+
def _get_decoder_module(self):
|
295 |
+
return self.decoder
|
296 |
+
|
297 |
+
def __call__(
|
298 |
+
self,
|
299 |
+
inputs,
|
300 |
+
attention_mask,
|
301 |
+
decoder_input_ids,
|
302 |
+
decoder_attention_mask,
|
303 |
+
decoder_position_ids,
|
304 |
+
encoder_outputs=None,
|
305 |
+
extract_features=None,
|
306 |
+
output_attentions: bool = False,
|
307 |
+
output_hidden_states: bool = False,
|
308 |
+
output_features: bool = False,
|
309 |
+
return_dict: bool = True,
|
310 |
+
deterministic: bool = True,
|
311 |
+
freeze_feature_encoder: bool = False,
|
312 |
+
):
|
313 |
+
if encoder_outputs is None:
|
314 |
+
encoder_outputs = self.encoder(
|
315 |
+
inputs,
|
316 |
+
attention_mask=attention_mask,
|
317 |
+
extract_features=extract_features,
|
318 |
+
output_attentions=output_attentions,
|
319 |
+
output_hidden_states=output_hidden_states,
|
320 |
+
output_features=output_features,
|
321 |
+
return_dict=return_dict,
|
322 |
+
deterministic=deterministic,
|
323 |
+
freeze_feature_encoder=freeze_feature_encoder,
|
324 |
+
)
|
325 |
+
|
326 |
+
if output_features:
|
327 |
+
return encoder_outputs
|
328 |
+
|
329 |
+
encoder_hidden_states = encoder_outputs[0]
|
330 |
+
|
331 |
+
# optionally project encoder_hidden_states
|
332 |
+
if self.enc_to_dec_proj is not None:
|
333 |
+
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
334 |
+
|
335 |
+
# compute correct encoder attention mask
|
336 |
+
if attention_mask is not None:
|
337 |
+
encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
|
338 |
+
encoder_hidden_states.shape[1], attention_mask
|
339 |
+
)
|
340 |
+
else:
|
341 |
+
encoder_attention_mask = None
|
342 |
+
|
343 |
+
# run the decoder, cross-attending over the (projected) encoder hidden states
|
344 |
+
decoder_outputs = self.decoder(
|
345 |
+
input_ids=decoder_input_ids,
|
346 |
+
attention_mask=decoder_attention_mask,
|
347 |
+
position_ids=decoder_position_ids,
|
348 |
+
encoder_hidden_states=encoder_hidden_states,
|
349 |
+
encoder_attention_mask=encoder_attention_mask,
|
350 |
+
output_attentions=output_attentions,
|
351 |
+
output_hidden_states=output_hidden_states,
|
352 |
+
return_dict=return_dict,
|
353 |
+
deterministic=deterministic,
|
354 |
+
)
|
355 |
+
|
356 |
+
if not return_dict:
|
357 |
+
return decoder_outputs + encoder_outputs
|
358 |
+
|
359 |
+
return FlaxSeq2SeqLMOutput(
|
360 |
+
logits=decoder_outputs.logits,
|
361 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
362 |
+
decoder_attentions=decoder_outputs.attentions,
|
363 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
364 |
+
encoder_last_hidden_state=encoder_hidden_states,
|
365 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
366 |
+
encoder_attentions=encoder_outputs.attentions,
|
367 |
+
)
|
368 |
+
|
369 |
+
|
370 |
+
@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
|
371 |
+
class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
372 |
+
r"""
|
373 |
+
[`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
|
374 |
+
with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
|
375 |
+
as decoder module when created with the [`~transformers.FlaxAutoModel.from_pretrained`] class method for the
|
376 |
+
encoder and [`~transformers.FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
|
377 |
+
"""
|
378 |
+
|
379 |
+
config_class = SpeechEncoderDecoderConfig
|
380 |
+
base_model_prefix: str = "speech_encoder_decoder"
|
381 |
+
module_class = FlaxSpeechEncoderDecoderModule
|
382 |
+
|
383 |
+
def __init__(
|
384 |
+
self,
|
385 |
+
config: SpeechEncoderDecoderConfig,
|
386 |
+
input_shape: Optional[Tuple] = None,
|
387 |
+
seed: int = 0,
|
388 |
+
dtype: jnp.dtype = jnp.float32,
|
389 |
+
_do_init: bool = True,
|
390 |
+
**kwargs
|
391 |
+
):
|
392 |
+
|
393 |
+
if not _do_init:
|
394 |
+
raise ValueError(
|
395 |
+
"`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
|
396 |
+
)
|
397 |
+
|
398 |
+
if config.decoder.cross_attention_hidden_size is not None:
|
399 |
+
# raise a ValueError, or optionally project enc to dec hidden_size (e.g. an encoder adapter layer)
|
400 |
+
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
401 |
+
raise ValueError(
|
402 |
+
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
403 |
+
"it has to be equal to the encoder's `hidden_size`. "
|
404 |
+
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
405 |
+
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
406 |
+
)
|
407 |
+
|
408 |
+
# make sure input & output embeddings are not tied
|
409 |
+
config.tie_word_embeddings = False
|
410 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
411 |
+
|
412 |
+
if input_shape is None:
|
413 |
+
# speech encoders almost always downsample the sequence length dimension
|
414 |
+
encoder_input_length = 1024
|
415 |
+
decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
|
416 |
+
input_shape = ((1, encoder_input_length), (1, decoder_input_length))
|
417 |
+
|
418 |
+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
419 |
+
|
420 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
421 |
+
encoder_input_shape, decoder_input_shape = input_shape
|
422 |
+
|
423 |
+
# init input DeviceArrays
|
424 |
+
inputs = jnp.zeros(encoder_input_shape, dtype="f4")
|
425 |
+
attention_mask = jnp.ones_like(inputs, dtype="i4")
|
426 |
+
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
|
427 |
+
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
428 |
+
|
429 |
+
batch_size, sequence_length = inputs.shape
|
430 |
+
|
431 |
+
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
|
432 |
+
if decoder_batch_size != batch_size:
|
433 |
+
raise ValueError(
|
434 |
+
f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
|
435 |
+
)
|
436 |
+
decoder_position_ids = jnp.broadcast_to(
|
437 |
+
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
|
438 |
+
)
|
439 |
+
|
440 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
441 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
442 |
+
|
443 |
+
return self.module.init(
|
444 |
+
rngs,
|
445 |
+
inputs,
|
446 |
+
attention_mask,
|
447 |
+
decoder_input_ids,
|
448 |
+
decoder_attention_mask,
|
449 |
+
decoder_position_ids,
|
450 |
+
)["params"]
|
451 |
+
|
452 |
+
def init_cache(self, batch_size, max_length, encoder_outputs):
|
453 |
+
r"""
|
454 |
+
Args:
|
455 |
+
batch_size (`int`):
|
456 |
+
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
457 |
+
max_length (`int`):
|
458 |
+
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
459 |
+
cache.
|
460 |
+
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
|
461 |
+
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
|
462 |
+
`attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*,
|
463 |
+
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
|
464 |
+
cross-attention of the decoder.
|
465 |
+
"""
|
466 |
+
# init input variables to retrieve cache
|
467 |
+
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
468 |
+
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
469 |
+
decoder_position_ids = jnp.broadcast_to(
|
470 |
+
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
|
471 |
+
)
|
472 |
+
|
473 |
+
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
|
474 |
+
decoder_module = module._get_decoder_module()
|
475 |
+
return decoder_module(
|
476 |
+
input_ids=decoder_input_ids,
|
477 |
+
attention_mask=decoder_attention_mask,
|
478 |
+
position_ids=decoder_position_ids,
|
479 |
+
**kwargs,
|
480 |
+
)
|
481 |
+
|
482 |
+
init_variables = self.module.init(
|
483 |
+
jax.random.PRNGKey(0),
|
484 |
+
decoder_input_ids=decoder_input_ids,
|
485 |
+
decoder_attention_mask=decoder_attention_mask,
|
486 |
+
decoder_position_ids=decoder_position_ids,
|
487 |
+
encoder_hidden_states=encoder_outputs[0],
|
488 |
+
init_cache=True,
|
489 |
+
method=_decoder_forward, # we only need to call the decoder to init the cache
|
490 |
+
)
|
491 |
+
return unfreeze(init_variables["cache"])
|
492 |
+
|
493 |
+
def _get_feat_extract_output_lengths(
|
494 |
+
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
|
495 |
+
):
|
496 |
+
return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
|
497 |
+
|
498 |
+
@add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
|
499 |
+
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
500 |
+
def encode(
|
501 |
+
self,
|
502 |
+
inputs: jnp.ndarray,
|
503 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
504 |
+
extract_features: Optional[jnp.ndarray] = None,
|
505 |
+
output_attentions: Optional[bool] = None,
|
506 |
+
output_hidden_states: Optional[bool] = None,
|
507 |
+
output_features: Optional[bool] = None,
|
508 |
+
return_dict: Optional[bool] = None,
|
509 |
+
train: bool = False,
|
510 |
+
freeze_feature_encoder: bool = False,
|
511 |
+
params: dict = None,
|
512 |
+
dropout_rng: PRNGKey = None,
|
513 |
+
):
|
514 |
+
r"""
|
515 |
+
Returns:
|
516 |
+
|
517 |
+
Example:
|
518 |
+
|
519 |
+
```python
|
520 |
+
>>> from transformers import FlaxSpeechEncoderDecoderModel
>>> import jax.numpy as jnp
|
521 |
+
|
522 |
+
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
|
523 |
+
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
524 |
+
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
|
525 |
+
... )
|
526 |
+
|
527 |
+
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
|
528 |
+
>>> encoder_outputs = model.encode(inputs)
|
529 |
+
```"""
|
530 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
531 |
+
output_hidden_states = (
|
532 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
533 |
+
)
|
534 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
535 |
+
|
536 |
+
if attention_mask is None:
|
537 |
+
attention_mask = jnp.ones_like(inputs, dtype="i4")
|
538 |
+
|
539 |
+
if extract_features is not None:
|
540 |
+
extract_features = jnp.array(extract_features, dtype="f4")
|
541 |
+
|
542 |
+
# Handle any PRNG if needed
|
543 |
+
rngs = {}
|
544 |
+
if dropout_rng is not None:
|
545 |
+
rngs["dropout"] = dropout_rng
|
546 |
+
|
547 |
+
def _encoder_forward(module, inputs, attention_mask, **kwargs):
|
548 |
+
encode_module = module._get_encoder_module()
|
549 |
+
return encode_module(inputs, attention_mask, **kwargs)
|
550 |
+
|
551 |
+
outputs = self.module.apply(
|
552 |
+
{"params": params or self.params},
|
553 |
+
inputs=jnp.array(inputs, dtype="f4"),
|
554 |
+
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
555 |
+
extract_features=extract_features,
|
556 |
+
output_attentions=output_attentions,
|
557 |
+
output_hidden_states=output_hidden_states,
|
558 |
+
output_features=output_features,
|
559 |
+
return_dict=return_dict,
|
560 |
+
deterministic=not train,
|
561 |
+
freeze_feature_encoder=freeze_feature_encoder,
|
562 |
+
rngs=rngs,
|
563 |
+
method=_encoder_forward,
|
564 |
+
)
|
565 |
+
|
566 |
+
if return_dict and not output_features:
|
567 |
+
outputs = FlaxBaseModelOutput(
|
568 |
+
last_hidden_state=outputs.last_hidden_state,
|
569 |
+
hidden_states=outputs.hidden_states,
|
570 |
+
attentions=outputs.attentions,
|
571 |
+
)
|
572 |
+
|
573 |
+
return outputs
|
574 |
+
|
575 |
+
@add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
|
576 |
+
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
577 |
+
def decode(
|
578 |
+
self,
|
579 |
+
decoder_input_ids,
|
580 |
+
encoder_outputs,
|
581 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
582 |
+
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
583 |
+
decoder_position_ids: Optional[jnp.ndarray] = None,
|
584 |
+
past_key_values: dict = None,
|
585 |
+
output_attentions: Optional[bool] = None,
|
586 |
+
output_hidden_states: Optional[bool] = None,
|
587 |
+
return_dict: Optional[bool] = None,
|
588 |
+
train: bool = False,
|
589 |
+
params: dict = None,
|
590 |
+
dropout_rng: PRNGKey = None,
|
591 |
+
):
|
592 |
+
r"""
|
593 |
+
Returns:
|
594 |
+
|
595 |
+
Example:
|
596 |
+
|
597 |
+
```python
|
598 |
+
>>> from transformers import FlaxSpeechEncoderDecoderModel
|
599 |
+
>>> import jax.numpy as jnp
|
600 |
+
|
601 |
+
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
|
602 |
+
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
603 |
+
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
|
604 |
+
... )
|
605 |
+
|
606 |
+
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
|
607 |
+
>>> encoder_outputs = model.encode(inputs)
|
608 |
+
|
609 |
+
>>> decoder_start_token_id = model.config.decoder.bos_token_id
|
610 |
+
>>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id
|
611 |
+
|
612 |
+
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
613 |
+
>>> logits = outputs.logits
|
614 |
+
```"""
|
615 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
616 |
+
output_hidden_states = (
|
617 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
618 |
+
)
|
619 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
620 |
+
|
621 |
+
encoder_hidden_states = encoder_outputs[0]
|
622 |
+
if encoder_attention_mask is None:
|
623 |
+
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
624 |
+
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
625 |
+
|
626 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
627 |
+
if decoder_attention_mask is None:
|
628 |
+
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
629 |
+
|
630 |
+
if decoder_position_ids is None:
|
631 |
+
if past_key_values is not None:
|
632 |
+
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
|
633 |
+
|
634 |
+
decoder_position_ids = jnp.broadcast_to(
|
635 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
636 |
+
)
|
637 |
+
|
638 |
+
# Handle any PRNG if needed
|
639 |
+
rngs = {}
|
640 |
+
if dropout_rng is not None:
|
641 |
+
rngs["dropout"] = dropout_rng
|
642 |
+
|
643 |
+
params = {"params": params or self.params}
|
644 |
+
|
645 |
+
# If `past_key_values` are passed, the cache is already initialized and the private flag `init_cache`
|
646 |
+
# has to be passed down to ensure the cache is used. The cache must also be marked as mutable so
|
647 |
+
# that it can be updated by the FlaxBartAttention module.
|
648 |
+
if past_key_values:
|
649 |
+
params["cache"] = past_key_values
|
650 |
+
mutable = ["cache"]
|
651 |
+
else:
|
652 |
+
mutable = False
|
653 |
+
|
654 |
+
def _decoder_forward(
|
655 |
+
module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
|
656 |
+
):
|
657 |
+
|
658 |
+
projection_module = module._get_projection_module()
|
659 |
+
decoder_module = module._get_decoder_module()
|
660 |
+
|
661 |
+
# optionally project encoder_hidden_states
|
662 |
+
if projection_module is not None:
|
663 |
+
encoder_hidden_states = projection_module(encoder_hidden_states)
|
664 |
+
|
665 |
+
return decoder_module(
|
666 |
+
decoder_input_ids,
|
667 |
+
decoder_attention_mask,
|
668 |
+
decoder_position_ids,
|
669 |
+
encoder_hidden_states,
|
670 |
+
**kwargs,
|
671 |
+
)
|
672 |
+
|
673 |
+
outputs = self.module.apply(
|
674 |
+
params,
|
675 |
+
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
676 |
+
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
677 |
+
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
678 |
+
encoder_hidden_states=encoder_hidden_states,
|
679 |
+
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
|
680 |
+
output_attentions=output_attentions,
|
681 |
+
output_hidden_states=output_hidden_states,
|
682 |
+
return_dict=return_dict,
|
683 |
+
deterministic=not train,
|
684 |
+
rngs=rngs,
|
685 |
+
mutable=mutable,
|
686 |
+
method=_decoder_forward,
|
687 |
+
)
|
688 |
+
|
689 |
+
# add updated cache to model output
|
690 |
+
if past_key_values is not None and return_dict:
|
691 |
+
outputs, past = outputs
|
692 |
+
outputs["past_key_values"] = unfreeze(past["cache"])
|
693 |
+
return outputs
|
694 |
+
elif past_key_values is not None and not return_dict:
|
695 |
+
outputs, past = outputs
|
696 |
+
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
697 |
+
|
698 |
+
return outputs
|
699 |
+
|
700 |
+
@add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
|
701 |
+
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
702 |
+
def __call__(
|
703 |
+
self,
|
704 |
+
inputs: jnp.ndarray,
|
705 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
706 |
+
extract_features: Optional[jnp.ndarray] = None,
|
707 |
+
decoder_input_ids: Optional[jnp.ndarray] = None,
|
708 |
+
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
709 |
+
decoder_position_ids: Optional[jnp.ndarray] = None,
|
710 |
+
output_attentions: Optional[bool] = None,
|
711 |
+
output_hidden_states: Optional[bool] = None,
|
712 |
+
output_features: Optional[bool] = None,
|
713 |
+
return_dict: Optional[bool] = None,
|
714 |
+
train: bool = False,
|
715 |
+
freeze_feature_encoder: bool = False,
|
716 |
+
params: dict = None,
|
717 |
+
dropout_rng: PRNGKey = None,
|
718 |
+
):
|
719 |
+
r"""
|
720 |
+
Returns:
|
721 |
+
|
722 |
+
Examples:
|
723 |
+
|
724 |
+
```python
|
725 |
+
>>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
>>> import jax.numpy as jnp
|
726 |
+
|
727 |
+
>>> # load a fine-tuned wav2vec2-2-bart model
|
728 |
+
>>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
|
729 |
+
>>> # load output tokenizer
|
730 |
+
>>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")
|
731 |
+
|
732 |
+
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
|
733 |
+
|
734 |
+
>>> # use bart's special bos, pad and eos tokens
|
735 |
+
>>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
|
736 |
+
>>> model.config.pad_token_id = model.decoder.config.pad_token_id
|
737 |
+
>>> model.config.eos_token_id = model.decoder.config.eos_token_id
|
738 |
+
|
739 |
+
>>> outputs = model.generate(inputs)
|
740 |
+
>>> # TODO: assert on the outputs, use a more interesting input, and check the dtype
|
741 |
+
```
|
742 |
+
"""
|
743 |
+
|
744 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
745 |
+
output_hidden_states = (
|
746 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
747 |
+
)
|
748 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
749 |
+
|
750 |
+
# prepare encoder inputs
|
751 |
+
if attention_mask is None:
|
752 |
+
attention_mask = jnp.ones_like(inputs, dtype="i4")
|
753 |
+
|
754 |
+
if extract_features is not None:
|
755 |
+
inputs = None # we can omit passing the inputs to the model to save memory
|
756 |
+
extract_features = jnp.array(extract_features, dtype="f4")
|
757 |
+
else:
|
758 |
+
inputs = jnp.array(inputs, dtype="f4")
|
759 |
+
|
760 |
+
# prepare decoder inputs
|
761 |
+
if decoder_input_ids is None:
|
762 |
+
raise ValueError(
|
763 |
+
"`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
|
764 |
+
)
|
765 |
+
if decoder_attention_mask is None:
|
766 |
+
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
767 |
+
if decoder_position_ids is None:
|
768 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
769 |
+
decoder_position_ids = jnp.broadcast_to(
|
770 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
771 |
+
)
|
772 |
+
|
773 |
+
# Handle any PRNG if needed
|
774 |
+
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
775 |
+
|
776 |
+
return self.module.apply(
|
777 |
+
{"params": params or self.params},
|
778 |
+
inputs=inputs,
|
779 |
+
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
780 |
+
extract_features=extract_features,
|
781 |
+
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
782 |
+
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
783 |
+
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
784 |
+
output_attentions=output_attentions,
|
785 |
+
output_hidden_states=output_hidden_states,
|
786 |
+
output_features=output_features,
|
787 |
+
return_dict=return_dict,
|
788 |
+
deterministic=not train,
|
789 |
+
freeze_feature_encoder=freeze_feature_encoder,
|
790 |
+
rngs=rngs,
|
791 |
+
)
|
792 |
+
|
793 |
+
def prepare_inputs_for_generation(
|
794 |
+
self,
|
795 |
+
decoder_input_ids,
|
796 |
+
max_length,
|
797 |
+
attention_mask: Optional[jnp.DeviceArray] = None,
|
798 |
+
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
799 |
+
encoder_outputs=None,
|
800 |
+
**kwargs
|
801 |
+
):
|
802 |
+
# initializing the cache
|
803 |
+
batch_size, seq_length = decoder_input_ids.shape
|
804 |
+
|
805 |
+
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
|
806 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
|
807 |
+
# But since the decoder uses a causal mask, those positions are masked anyways.
|
808 |
+
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
809 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
810 |
+
if decoder_attention_mask is not None:
|
811 |
+
decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
|
812 |
+
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
|
813 |
+
else:
|
814 |
+
decoder_position_ids = jnp.broadcast_to(
|
815 |
+
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
|
816 |
+
)
|
817 |
+
|
818 |
+
return {
|
819 |
+
"past_key_values": past_key_values,
|
820 |
+
"encoder_outputs": encoder_outputs,
|
821 |
+
"encoder_attention_mask": attention_mask,
|
822 |
+
"decoder_attention_mask": extended_attention_mask,
|
823 |
+
"decoder_position_ids": decoder_position_ids,
|
824 |
+
}
|
825 |
+
|
826 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
827 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
828 |
+
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
|
829 |
+
return model_kwargs
|
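A minimal sketch of the position-id bookkeeping implemented by the two methods above: `prepare_inputs_for_generation` derives the prompt positions, and `update_inputs_for_generation` then advances a single position id per decoding step.
```python
import jax.numpy as jnp

# Prompt of length 3 -> positions [[0, 1, 2]].
decoder_position_ids = jnp.arange(3, dtype="i4")[None, :]
for _ in range(3):
    # Same update as update_inputs_for_generation: keep the last position, add one.
    decoder_position_ids = decoder_position_ids[:, -1:] + 1
    print(decoder_position_ids)  # [[3]], then [[4]], then [[5]]
```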
830 |
+
|
831 |
+
@classmethod
|
832 |
+
def from_encoder_decoder_pretrained(
|
833 |
+
cls,
|
834 |
+
encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
835 |
+
decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
836 |
+
*model_args,
|
837 |
+
**kwargs
|
838 |
+
) -> FlaxPreTrainedModel:
|
839 |
+
r"""
|
840 |
+
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
|
841 |
+
checkpoints.
|
842 |
+
|
843 |
+
Params:
|
844 |
+
encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
|
845 |
+
Information necessary to initiate the encoder. Can be either:
|
846 |
+
|
847 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
848 |
+
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
849 |
+
user or organization name, like `dbmdz/bert-base-german-cased`.
|
850 |
+
- A path to a *directory* containing model weights saved using
|
851 |
+
[`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
852 |
+
|
853 |
+
decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
|
854 |
+
Information necessary to initiate the decoder. Can be either:
|
855 |
+
|
856 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
857 |
+
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
858 |
+
user or organization name, like `dbmdz/bert-base-german-cased`.
|
859 |
+
- A path to a *directory* containing model weights saved using
|
860 |
+
[`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
861 |
+
|
862 |
+
model_args (remaining positional arguments, *optional*):
|
863 |
+
All remaining positional arguments will be passed to the underlying model's `__init__` method.
|
864 |
+
|
865 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
866 |
+
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
|
867 |
+
`output_attentions=True`).
|
868 |
+
|
869 |
+
- To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
|
870 |
+
- To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
|
871 |
+
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
872 |
+
|
873 |
+
Behaves differently depending on whether a `config` is provided or automatically loaded.
|
874 |
+
|
875 |
+
Example:
|
876 |
+
|
877 |
+
```python
|
878 |
+
>>> from transformers import FlaxSpeechEncoderDecoderModel
|
879 |
+
|
880 |
+
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
|
881 |
+
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
882 |
+
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
|
883 |
+
... )
|
884 |
+
>>> # saving model after fine-tuning
|
885 |
+
>>> model.save_pretrained("./wav2vec2-2-bart-large")
|
886 |
+
>>> # load fine-tuned model
|
887 |
+
>>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
|
888 |
+
```"""
|
889 |
+
|
890 |
+
kwargs_encoder = {
|
891 |
+
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
|
892 |
+
}
|
893 |
+
|
894 |
+
kwargs_decoder = {
|
895 |
+
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
896 |
+
}
|
897 |
+
|
898 |
+
# remove encoder, decoder kwargs from kwargs
|
899 |
+
for key in kwargs_encoder.keys():
|
900 |
+
del kwargs["encoder_" + key]
|
901 |
+
for key in kwargs_decoder.keys():
|
902 |
+
del kwargs["decoder_" + key]
|
903 |
+
|
904 |
+
# Load and initialize the encoder and decoder
|
905 |
+
# The distinction between encoder and decoder at the model level is made
|
906 |
+
# by the value of the flag `is_decoder` that we need to set correctly.
|
907 |
+
encoder = kwargs_encoder.pop("model", None)
|
908 |
+
if encoder is None:
|
909 |
+
if encoder_pretrained_model_name_or_path is None:
|
910 |
+
raise ValueError(
|
911 |
+
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
912 |
+
"to be defined."
|
913 |
+
)
|
914 |
+
|
915 |
+
if "config" not in kwargs_encoder:
|
916 |
+
# TODO: AutoConfig .from_pretrained
|
917 |
+
encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
|
918 |
+
encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
|
919 |
+
)
|
920 |
+
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
921 |
+
logger.info(
|
922 |
+
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
923 |
+
"from a decoder model. Cross-attention and casual mask are disabled."
|
924 |
+
)
|
925 |
+
encoder_config.is_decoder = False
|
926 |
+
encoder_config.add_cross_attention = False
|
927 |
+
|
928 |
+
kwargs_encoder["config"] = encoder_config
|
929 |
+
|
930 |
+
# TODO: FlaxAutoModel .from_pretrained
|
931 |
+
encoder = FlaxWav2Vec2Model.from_pretrained(
|
932 |
+
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
|
933 |
+
)
|
934 |
+
|
935 |
+
decoder = kwargs_decoder.pop("model", None)
|
936 |
+
if decoder is None:
|
937 |
+
if decoder_pretrained_model_name_or_path is None:
|
938 |
+
raise ValueError(
|
939 |
+
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
940 |
+
"to be defined."
|
941 |
+
)
|
942 |
+
|
943 |
+
if "config" not in kwargs_decoder:
|
944 |
+
# TODO: AutoConfig .from_pretrained
|
945 |
+
decoder_config, kwargs_decoder = BartConfig.from_pretrained(
|
946 |
+
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
|
947 |
+
)
|
948 |
+
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
949 |
+
logger.info(
|
950 |
+
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
951 |
+
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
952 |
+
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
953 |
+
"cross attention layers."
|
954 |
+
)
|
955 |
+
decoder_config.is_decoder = True
|
956 |
+
decoder_config.add_cross_attention = True
|
957 |
+
|
958 |
+
kwargs_decoder["config"] = decoder_config
|
959 |
+
|
960 |
+
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
961 |
+
logger.warning(
|
962 |
+
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
963 |
+
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
964 |
+
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
965 |
+
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
966 |
+
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
967 |
+
)
|
968 |
+
|
969 |
+
# TODO: FlaxAutoModelForCausalLM .from_pretrained
|
970 |
+
decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
971 |
+
|
972 |
+
# instantiate config with corresponding kwargs
|
973 |
+
dtype = kwargs.pop("dtype", jnp.float32)
|
974 |
+
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
975 |
+
|
976 |
+
# make sure input & output word embeddings are not tied
|
977 |
+
config.tie_word_embeddings = False
|
978 |
+
|
979 |
+
# init model
|
980 |
+
model = cls(config, dtype=dtype)
|
981 |
+
model.params["encoder"] = encoder.params
|
982 |
+
model.params["decoder"] = decoder.params
|
983 |
+
|
984 |
+
return model
|
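The kwargs routing at the top of `from_encoder_decoder_pretrained` can be sketched in isolation; the example keyword arguments below are hypothetical:
```python
# Hypothetical call-site kwargs; prefixed keys are stripped and dispatched.
kwargs = {"encoder_add_adapter": True, "decoder_use_cache": False, "tie_word_embeddings": False}

kwargs_encoder = {k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")}
kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}

print(kwargs_encoder)  # {'add_adapter': True}
print(kwargs_decoder)  # {'use_cache': False}
# Unprefixed keys like tie_word_embeddings stay in kwargs for the parent config.
```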
985 |
+
|
986 |
+
def _beam_search(
|
987 |
+
self,
|
988 |
+
input_ids: None,
|
989 |
+
max_length: Optional[int] = None,
|
990 |
+
pad_token_id: Optional[int] = None,
|
991 |
+
eos_token_id: Optional[int] = None,
|
992 |
+
length_penalty: Optional[float] = None,
|
993 |
+
early_stopping: Optional[bool] = None,
|
994 |
+
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
995 |
+
trace: bool = True,
|
996 |
+
params: Optional[Dict[str, jnp.ndarray]] = None,
|
997 |
+
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
998 |
+
):
|
999 |
+
"""
|
1000 |
+
This beam search function is heavily inspired by Flax's official example:
|
1001 |
+
https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
|
1002 |
+
"""
|
1003 |
+
|
1004 |
+
def flatten_beam_dim(tensor):
|
1005 |
+
"""Flattens the first two dimensions of a non-scalar array."""
|
1006 |
+
# ignore scalars (e.g. cache index)
|
1007 |
+
if tensor.ndim == 0 or tensor.ndim == 1:
|
1008 |
+
return tensor
|
1009 |
+
elif tensor.ndim == 6:
|
1010 |
+
return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
|
1011 |
+
return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
|
1012 |
+
|
1013 |
+
def unflatten_beam_dim(tensor, batch_size, num_beams):
|
1014 |
+
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
|
1015 |
+
# ignore scalars (e.g. cache index)
|
1016 |
+
if tensor.ndim == 0 or tensor.ndim == 1:
|
1017 |
+
return tensor
|
1018 |
+
if tensor.ndim == 5:
|
1019 |
+
return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
|
1020 |
+
return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
|
1021 |
+
|
1022 |
+
def gather_beams(nested, beam_indices, batch_size, new_num_beams):
|
1023 |
+
"""
|
1024 |
+
Gathers the beam slices indexed by beam_indices into new beam array.
|
1025 |
+
"""
|
1026 |
+
batch_indices = jnp.reshape(
|
1027 |
+
jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
|
1028 |
+
)
|
1029 |
+
|
1030 |
+
def gather_fn(tensor):
|
1031 |
+
# ignore scalars (e.g. cache index)
|
1032 |
+
if tensor.ndim == 0 or tensor.ndim == 1:
|
1033 |
+
return tensor
|
1034 |
+
if tensor.ndim == 6:
|
1035 |
+
return tensor[:, batch_indices, beam_indices]
|
1036 |
+
return tensor[batch_indices, beam_indices]
|
1037 |
+
|
1038 |
+
return jax.tree_map(gather_fn, nested)
|
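A quick shape sketch of the beam-dimension helpers defined above (toy sizes, no model involved):
```python
import jax.numpy as jnp

batch_size, num_beams, hidden = 2, 3, 4
x = jnp.zeros((batch_size, num_beams, hidden))

# flatten_beam_dim: merge batch and beam axes for the model forward pass.
flat = x.reshape((batch_size * num_beams,) + x.shape[2:])
# unflatten_beam_dim: restore them for beam bookkeeping.
unflat = flat.reshape((batch_size, num_beams) + flat.shape[1:])

print(flat.shape, unflat.shape)  # (6, 4) (2, 3, 4)
```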
1039 |
+
|
1040 |
+
# init values
|
1041 |
+
max_length = max_length if max_length is not None else self.config.max_length
|
1042 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
1043 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
1044 |
+
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
1045 |
+
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
1046 |
+
|
1047 |
+
batch_size, num_beams, cur_len = input_ids.shape
|
1048 |
+
|
1049 |
+
eos_token_id = jnp.array(eos_token_id)
|
1050 |
+
pad_token_id = jnp.array(pad_token_id)
|
1051 |
+
cur_len = jnp.array(cur_len)
|
1052 |
+
|
1053 |
+
# per batch,beam-item holding current token in loop.
|
1054 |
+
sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
|
1055 |
+
running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
|
1056 |
+
running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
|
1057 |
+
|
1058 |
+
# per batch,beam-item state bit indicating if sentence has finished.
|
1059 |
+
is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
|
1060 |
+
|
1061 |
+
# per batch,beam-item score, logprobs
|
1062 |
+
running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
|
1063 |
+
scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
|
1064 |
+
|
1065 |
+
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
1066 |
+
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
1067 |
+
model = self.decode if self.config.is_encoder_decoder else self
|
1068 |
+
|
1069 |
+
# flatten beam dim
|
1070 |
+
if "encoder_outputs" in model_kwargs:
|
1071 |
+
model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
|
1072 |
+
model_kwargs["encoder_outputs"]["last_hidden_state"]
|
1073 |
+
)
|
1074 |
+
if "attention_mask" in model_kwargs:
|
1075 |
+
model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
|
1076 |
+
|
1077 |
+
# initialize model specific kwargs
|
1078 |
+
model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
|
1079 |
+
|
1080 |
+
# initialize state
|
1081 |
+
state = BeamSearchState(
|
1082 |
+
cur_len=cur_len,
|
1083 |
+
running_sequences=running_sequences,
|
1084 |
+
running_scores=running_scores,
|
1085 |
+
sequences=sequences,
|
1086 |
+
scores=scores,
|
1087 |
+
is_sent_finished=is_sent_finished,
|
1088 |
+
model_kwargs=model_kwargs,
|
1089 |
+
)
|
1090 |
+
|
1091 |
+
def beam_search_cond_fn(state):
|
1092 |
+
"""beam search state termination condition fn."""
|
1093 |
+
|
1094 |
+
# 1. is less than max length?
|
1095 |
+
not_max_length_yet = state.cur_len < max_length
|
1096 |
+
|
1097 |
+
# 2. can the new beams still improve?
|
1098 |
+
best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
|
1099 |
+
worst_finished_score = jnp.where(
|
1100 |
+
state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
|
1101 |
+
)
|
1102 |
+
improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
|
1103 |
+
|
1104 |
+
# 3. is there still a beam that has not finished?
|
1105 |
+
still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
|
1106 |
+
|
1107 |
+
return not_max_length_yet & still_open_beam & improvement_still_possible
|
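A numeric sketch of the "can the new beams still improve?" bound above, with made-up scores (running scores are sorted with the best beam last, matching the `jnp.flip(lax.top_k(...))` ordering used below):
```python
import jax.numpy as jnp

max_length, length_penalty = 20, 1.0
running_scores = jnp.array([[-2.0, -1.0]])  # best running beam last
scores = jnp.array([[-0.04, -0.03]])        # finished-beam scores
is_sent_finished = jnp.array([[True, True]])

# Best score a running beam could still reach under the length penalty.
best_running_score = running_scores[:, -1:] / (max_length**length_penalty)  # -0.05
worst_finished_score = jnp.where(
    is_sent_finished, jnp.min(scores, axis=1, keepdims=True), -1.0e7
)
# -0.04 > -0.05, so no running beam can beat the finished ones: stop early.
print(bool(jnp.all(worst_finished_score < best_running_score)))  # False
```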
1108 |
+
|
1109 |
+
def beam_search_body_fn(state, input_ids_length=1):
|
1110 |
+
"""beam search state update fn."""
|
1111 |
+
# 1. Forward current tokens
|
1112 |
+
# Collect the current position slice along length to feed the fast
|
1113 |
+
# autoregressive decoder model. Flatten the beam dimension into batch
|
1114 |
+
# dimension for feeding into the model.
|
1115 |
+
# unflatten beam dimension
|
1116 |
+
# Unflatten beam dimension in attention cache arrays
|
1117 |
+
input_token = flatten_beam_dim(
|
1118 |
+
lax.dynamic_slice(
|
1119 |
+
state.running_sequences,
|
1120 |
+
(0, 0, state.cur_len - input_ids_length),
|
1121 |
+
(batch_size, num_beams, input_ids_length),
|
1122 |
+
)
|
1123 |
+
)
|
1124 |
+
model_outputs = model(input_token, params=params, **state.model_kwargs)
|
1125 |
+
|
1126 |
+
logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
|
1127 |
+
cache = jax.tree_map(
|
1128 |
+
lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
|
1129 |
+
)
|
1130 |
+
|
1131 |
+
# adapt logits for FlaxMarianMTModel
|
1132 |
+
logits = self._adapt_logits_for_beam_search(logits)
|
1133 |
+
|
1134 |
+
# 2. Compute log probs
|
1135 |
+
# get log probabilities from logits,
|
1136 |
+
# process logits with processors (*e.g.* min_length, ...), and
|
1137 |
+
# add new logprobs to existing running logprobs scores.
|
1138 |
+
log_probs = jax.nn.log_softmax(logits)
|
1139 |
+
log_probs = logits_processor(
|
1140 |
+
flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
|
1141 |
+
)
|
1142 |
+
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
|
1143 |
+
log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
|
1144 |
+
vocab_size = log_probs.shape[2]
|
1145 |
+
log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
|
1146 |
+
|
1147 |
+
# 3. Retrieve top-K
|
1148 |
+
# Each item in batch has num_beams * vocab_size candidate sequences.
|
1149 |
+
# For each item, get the top 2*k candidates with the highest log-
|
1150 |
+
# probabilities. We gather the top 2*K beams here so that even if the best
|
1151 |
+
# K sequences reach EOS simultaneously, we have another K sequences
|
1152 |
+
# remaining to continue the live beam search.
|
1153 |
+
# Gather the top 2*K scores from _all_ beams.
|
1154 |
+
# Gather 2*k top beams.
|
1155 |
+
# Recover the beam index by floor division.
|
1156 |
+
# Recover token id by modulo division and expand Id array for broadcasting.
|
1157 |
+
# Update sequences for the 2*K top-k new sequences.
|
1158 |
+
beams_to_keep = 2 * num_beams
|
1159 |
+
topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
|
1160 |
+
topk_beam_indices = topk_indices // vocab_size
|
1161 |
+
topk_running_sequences = gather_beams(
|
1162 |
+
state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
|
1163 |
+
)
|
1164 |
+
topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
|
1165 |
+
topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
|
1166 |
+
|
1167 |
+
# 4. Check which sequences have ended
|
1168 |
+
# Update current sequences:
|
1169 |
+
# Did any of these sequences reach an end marker?
|
1170 |
+
# To prevent these just finished sequences from being added to the current sequences
|
1171 |
+
# set of active beam search sequences, set their log probs to a very large
|
1172 |
+
# negative value.
|
1173 |
+
did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
|
1174 |
+
running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
|
1175 |
+
# 5. Get running sequences scores for next
|
1176 |
+
# Determine the top k beam indices (from top 2*k beams) from log probs
|
1177 |
+
# and gather top k beams (from top 2*k beams).
|
1178 |
+
next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
|
1179 |
+
next_running_sequences, next_running_scores = gather_beams(
|
1180 |
+
[topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
|
1181 |
+
)
|
1182 |
+
|
1183 |
+
# 6. Process topk logits
|
1184 |
+
# Further process log probs:
|
1185 |
+
# - add length penalty
|
1186 |
+
# - make sure no scores can be added anymore if beam is full
|
1187 |
+
# - make sure still running sequences cannot be chosen as finalized beam
|
1188 |
+
topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
|
1189 |
+
beams_in_batch_are_full = (
|
1190 |
+
jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
|
1191 |
+
& early_stopping
|
1192 |
+
)
|
1193 |
+
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
|
1194 |
+
topk_log_probs += add_penalty * np.array(-1.0e7)
|
1195 |
+
|
1196 |
+
# 7. Get scores, sequences, is sentence finished for next.
|
1197 |
+
# Combine sequences, scores, and flags along the beam dimension and compare
|
1198 |
+
# new finished sequence scores to existing finished scores and select the
|
1199 |
+
# best from the new set of beams
|
1200 |
+
merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
|
1201 |
+
merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
|
1202 |
+
merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
|
1203 |
+
topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
|
1204 |
+
next_sequences, next_scores, next_is_sent_finished = gather_beams(
|
1205 |
+
[merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
|
1206 |
+
)
|
1207 |
+
|
1208 |
+
# 8. Update model kwargs.
|
1209 |
+
# Determine the top k beam indices from the original set of all beams.
|
1210 |
+
# With these, gather the top k beam-associated caches.
|
1211 |
+
next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
|
1212 |
+
next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
|
1213 |
+
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
|
1214 |
+
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
1215 |
+
|
1216 |
+
return BeamSearchState(
|
1217 |
+
cur_len=state.cur_len + 1,
|
1218 |
+
running_scores=next_running_scores,
|
1219 |
+
running_sequences=next_running_sequences,
|
1220 |
+
scores=next_scores,
|
1221 |
+
sequences=next_sequences,
|
1222 |
+
is_sent_finished=next_is_sent_finished,
|
1223 |
+
model_kwargs=next_model_kwargs,
|
1224 |
+
)
|
1225 |
+
|
1226 |
+
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
1227 |
+
if input_ids.shape[-1] > 1:
|
1228 |
+
state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
|
1229 |
+
|
1230 |
+
if not trace:
|
1231 |
+
state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
|
1232 |
+
else:
|
1233 |
+
state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
|
1234 |
+
|
1235 |
+
# Account for the edge-case where there are no finished sequences for a
|
1236 |
+
# particular batch item. If so, return running sequences for that batch item.
|
1237 |
+
none_finished = jnp.any(state.is_sent_finished, axis=1)
|
1238 |
+
sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
|
1239 |
+
scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
|
1240 |
+
|
1241 |
+
# return all beams for each batch and the best score
|
1242 |
+
sequences = sequences[:, :]
|
1243 |
+
scores = scores[:, -1]
|
1244 |
+
|
1245 |
+
return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
|
models/modeling_flax_wav2vec2.py
ADDED
@@ -0,0 +1,975 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Flax Wav2Vec2 model."""
|
16 |
+
|
17 |
+
from functools import partial
|
18 |
+
from typing import Optional, Tuple, Union
|
19 |
+
|
20 |
+
import flax
|
21 |
+
import flax.linen as nn
|
22 |
+
import jax
|
23 |
+
import jax.numpy as jnp
|
24 |
+
from flax.core.frozen_dict import FrozenDict
|
25 |
+
from flax.linen import partitioning as nn_partitioning
|
26 |
+
from flax.linen.attention import dot_product_attention_weights
|
27 |
+
from jax import lax
|
28 |
+
|
29 |
+
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
30 |
+
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
31 |
+
from transformers.utils import ModelOutput
|
32 |
+
|
33 |
+
from models import Wav2Vec2Config
|
34 |
+
|
35 |
+
scan_with_axes = nn_partitioning.scan_with_axes
|
36 |
+
remat = nn_partitioning.remat
|
37 |
+
|
38 |
+
|
39 |
+
@flax.struct.dataclass
|
40 |
+
class FlaxWav2Vec2BaseModelOutput(ModelOutput):
|
41 |
+
"""
|
42 |
+
Output type of [`FlaxWav2Vec2Model`], with potential hidden states and attentions.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
|
46 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
47 |
+
extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
|
48 |
+
Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
|
49 |
+
being the dimension of the last convolutional layer.
|
50 |
+
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
51 |
+
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
|
52 |
+
`(batch_size, sequence_length, hidden_size)`.
|
53 |
+
|
54 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
55 |
+
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
56 |
+
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
57 |
+
sequence_length)`.
|
58 |
+
|
59 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
60 |
+
heads.
|
61 |
+
"""
|
62 |
+
|
63 |
+
last_hidden_state: jnp.ndarray = None
|
64 |
+
extract_features: jnp.ndarray = None
|
65 |
+
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
66 |
+
attentions: Optional[Tuple[jnp.ndarray]] = None
|
67 |
+
|
68 |
+
|
69 |
+
WAV_2_VEC_2_START_DOCSTRING = r"""
|
70 |
+
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
|
71 |
+
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
|
72 |
+
Auli.
|
73 |
+
|
74 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
75 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
76 |
+
etc.)
|
77 |
+
|
78 |
+
This model is also a Flax Linen
|
79 |
+
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
80 |
+
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
81 |
+
|
82 |
+
Finally, this model supports inherent JAX features such as:
|
83 |
+
|
84 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
85 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
86 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
87 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
88 |
+
|
89 |
+
Parameters:
|
90 |
+
config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
|
91 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
92 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
93 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
94 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
95 |
+
`jax.numpy.bfloat16` (on TPUs).
|
96 |
+
|
97 |
+
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
98 |
+
specified all the computation will be performed with the given `dtype`.
|
99 |
+
|
100 |
+
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
101 |
+
parameters.**
|
102 |
+
|
103 |
+
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
104 |
+
[`~FlaxPreTrainedModel.to_bf16`].
|
105 |
+
"""
|
106 |
+
|
107 |
+
|
108 |
+
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
|
109 |
+
Args:
|
110 |
+
input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
111 |
+
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
|
112 |
+
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
|
113 |
+
soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
|
114 |
+
and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
|
115 |
+
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
116 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
117 |
+
1]`:
|
118 |
+
|
119 |
+
- 1 for tokens that are **not masked**,
|
120 |
+
- 0 for tokens that are **masked**.
|
121 |
+
|
122 |
+
[What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
|
123 |
+
if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
|
124 |
+
has `config.return_attention_mask == False`, such as
|
125 |
+
[wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
|
126 |
+
passed to avoid degraded performance when doing batched inference. For such models `input_values` should
|
127 |
+
simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
|
128 |
+
different results depending on whether `input_values` is padded or not.
|
129 |
+
mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
130 |
+
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
|
131 |
+
masked extracted features in *config.proj_codevector_dim* space.
|
132 |
+
output_attentions (`bool`, *optional*):
|
133 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
134 |
+
tensors for more detail.
|
135 |
+
output_hidden_states (`bool`, *optional*):
|
136 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
137 |
+
more detail.
|
138 |
+
return_dict (`bool`, *optional*):
|
139 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
140 |
+
"""


class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
    config: Wav2Vec2Config
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class FlaxConvWithWeightNorm(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            features=self.config.hidden_size,
            kernel_size=(self.config.num_conv_pos_embeddings,),
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=self.config.num_conv_pos_embedding_groups,
            dtype=self.dtype,
        )
        weight_shape = (
            self.conv.features,
            self.conv.features // self.conv.feature_group_count,
            self.conv.kernel_size[0],
        )
        self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
        self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
        self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
        self.prev_padding = self.conv.kernel_size[0] // 2

    def _get_normed_weights(self):
        weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
        normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
        normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
        return normed_kernel

    def __call__(self, hidden_states):
        kernel = self._get_normed_weights()
        hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
        hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
        return hidden_states
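

# The re-parameterisation above is weight normalization (Salimans & Kingma, 2016): the effective
# kernel is w = g * v / ||v||, so the direction (v) and magnitude (g) of each filter are learned
# independently. A minimal standalone sketch of the same computation:
#
#     import jax.numpy as jnp
#
#     v = jnp.ones((4, 2, 3))                             # direction parameters
#     g = jnp.linalg.norm(v, axis=(0, 1))[None, None, :]  # per-filter magnitude
#     w = g * v / jnp.linalg.norm(v, axis=(0, 1))[None, None, :]
#     # at initialisation w == v, since g is initialised to ||v|| (as in `setup` above)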


class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]
        self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0

    def __call__(self, hidden_states):
        hidden_states = hidden_states.transpose((0, 1, 2))

        hidden_states = self.conv(hidden_states)

        if self.num_pad_remove > 0:
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]
        hidden_states = self.activation(hidden_states)

        hidden_states = hidden_states.transpose((0, 1, 2))
        return hidden_states


class FlaxConvLayersCollection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.feat_extract_norm == "layer":
            # note that we can't use scan on the conv layers as they differ on a layer-by-layer basis
            BlockLayer = remat(FlaxWav2Vec2LayerNormConvLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2LayerNormConvLayer
            self.layers = [
                BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_feat_extract_layers)
            ]
        elif self.config.feat_extract_norm == "group":
            raise NotImplementedError("At the moment only ``config.feat_extract_norm == 'layer'`` is supported")
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

    def __call__(self, hidden_states):
        for i, conv_layer in enumerate(self.layers):
            hidden_states = conv_layer(hidden_states)
        return hidden_states


class FlaxWav2Vec2FeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, input_values, freeze_feature_encoder=False):
        hidden_states = input_values[:, :, None]
        hidden_states = self.conv_layers(hidden_states)
        if freeze_feature_encoder:
            hidden_states = jax.lax.stop_gradient(hidden_states)
        return hidden_states
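

# `freeze_feature_encoder` freezes the convolutional waveform encoder by cutting the gradient flow
# at its output rather than by excluding its parameters from the optimizer: under
# `jax.lax.stop_gradient` the backward pass treats the extracted features as constants, so the conv
# parameters receive zero gradient. A minimal sketch of the mechanism:
#
#     import jax
#     import jax.numpy as jnp
#
#     def loss(w, x):
#         feats = jax.lax.stop_gradient(w * x)  # "frozen" sub-network
#         return jnp.sum(feats ** 2)
#
#     print(jax.grad(loss)(2.0, 3.0))  # 0.0 -- no gradient reaches w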


class FlaxWav2Vec2FeatureProjection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.projection = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)

    def __call__(self, hidden_states, deterministic=True):
        norm_hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.projection(norm_hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states, norm_hidden_states


class FlaxWav2Vec2Attention(nn.Module):
    config: Wav2Vec2Config
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()

        self.fused_proj = nn.Dense(
            self.embed_dim * 3,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel"""

        if self.config.fuse_matmuls:
            attention_states = self.fused_proj(hidden_states)
            query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)

        else:
            # get query proj
            query_states = self.q_proj(hidden_states)

            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        if attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
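

# With `config.fuse_matmuls`, the three (embed, embed) projections become a single
# (embed, 3*embed) matmul, which tends to utilise TPU/GPU compute better; `jnp.split` then
# recovers q, k and v. Equivalence sketch (random placeholder weights, not tied to this module):
#
#     import numpy as np
#     import jax.numpy as jnp
#
#     x = np.random.randn(2, 5, 8).astype(np.float32)
#     wq, wk, wv = (np.random.randn(8, 8).astype(np.float32) for _ in range(3))
#     w_fused = np.concatenate([wq, wk, wv], axis=-1)  # (8, 24)
#     q, k, v = jnp.split(x @ w_fused, 3, axis=-1)
#     assert jnp.allclose(q, x @ wq, atol=1e-5)        # same result, one matmul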


class FlaxWav2Vec2FeedForward(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)

        self.intermediate_dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        if isinstance(self.config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
        else:
            self.intermediate_act_fn = self.config.hidden_act

        self.output_dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)

        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.attention = FlaxWav2Vec2Attention(
            config=self.config,
            embed_dim=self.config.hidden_size,
            num_heads=self.config.num_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
        if self.config.use_scan:
            hidden_states = hidden_states[0]
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights = self.attention(
            hidden_states, attention_mask=attention_mask, deterministic=deterministic
        )
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(
            self.final_layer_norm(hidden_states), deterministic=deterministic
        )

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if self.config.use_scan:
            outputs = (outputs, None)

        return outputs
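

# The "stable layer norm" variant applies LayerNorm *before* each sub-block (pre-LN) rather than
# after the residual addition (post-LN); this is the ordering used by wav2vec2-large checkpoints
# and is generally more stable to train. Schematically, for one layer (dropout omitted):
#
#     x = x + Attention(LayerNorm(x))     # pre-LN self-attention block
#     x = x + FeedForward(LayerNorm(x))   # pre-LN feed-forward block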


class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        num_layers = self.config.num_hidden_layers
        BlockEncoderLayer = (
            remat(
                FlaxWav2Vec2EncoderLayerStableLayerNorm,
                static_argnums=(2, 3),
                prevent_cse=not self.config.use_scan,
            )
            if self.config.gradient_checkpointing
            else FlaxWav2Vec2EncoderLayerStableLayerNorm
        )

        if self.config.use_scan:
            # since all encoder layers are the same, we use nn.scan directly
            assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
            assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
            hidden_states = (hidden_states,)

            hidden_states, _ = scan_with_axes(
                BlockEncoderLayer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
                length=num_layers,
            )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers")(
                hidden_states, attention_mask, deterministic, output_attentions
            )
            hidden_states = hidden_states[0]

        else:
            for layer in range(num_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs = BlockEncoderLayer(
                    self.config,
                    dtype=self.dtype,
                    name=str(layer),
                )(hidden_states, attention_mask, deterministic, output_attentions)

                hidden_states = layer_outputs[0]

                if output_attentions:
                    all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
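

# With `config.use_scan`, the stack of identical encoder layers is folded into a single
# `nn.scan`-transformed layer whose parameters gain a leading "layer" axis of size
# `num_hidden_layers`. This cuts compile time and unrolled-HLO size, at the cost of not being able
# to return per-layer hidden states or attentions. Minimal sketch of the idea on a toy module:
#
#     import flax.linen as nn
#     import jax, jax.numpy as jnp
#
#     class Dense1(nn.Module):
#         @nn.compact
#         def __call__(self, x, _):
#             return nn.Dense(4)(x), None  # (carry, per-step output)
#
#     Stack = nn.scan(Dense1, variable_axes={"params": 0}, split_rngs={"params": True}, length=3)
#     params = Stack().init(jax.random.PRNGKey(0), jnp.ones((1, 4)), None)
#     # params["params"]["Dense_0"]["kernel"].shape == (3, 4, 4): one leading axis per layer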


class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic=True,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):

        if attention_mask is not None:
            # make sure padded tokens are not attended to
            hidden_states = jnp.where(
                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
            )

        position_embeddings = self.pos_conv_embed(hidden_states)

        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.layer_norm(outputs[0])

        # update the last element in `hidden_states` after applying `layernorm` above
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)

        if not return_dict:
            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
        )


class FlaxWav2Vec2Adapter(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # hidden_states require down-projection if feature dims don't match
        if self.config.output_hidden_size != self.config.hidden_size:
            self.proj = nn.Dense(
                self.config.output_hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
            )
            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        else:
            self.proj = self.proj_layer_norm = None

        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        # down-project hidden_states if required
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj(hidden_states)
            hidden_states = self.proj_layer_norm(hidden_states)

        hidden_states = self.layers(hidden_states)

        return hidden_states


class FlaxWav2Vec2AdapterLayer(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            features=2 * self.config.output_hidden_size,
            kernel_size=(self.config.adapter_kernel_size,),
            strides=(self.config.adapter_stride,),
            padding=((1, 1),),
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = nn.glu(hidden_states, axis=2)

        return hidden_states


class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        BlockAdapterLayer = remat(FlaxWav2Vec2AdapterLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2AdapterLayer
        self.layers = [
            BlockAdapterLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_adapter_layers)
        ]

    def __call__(self, hidden_states):
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)

        return hidden_states
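

# Each adapter layer doubles the feature dim with its convolution and then halves it again via a
# gated linear unit, GLU(x) = a * sigmoid(b), where a and b are the two halves of x along the
# feature axis. Combined with the strided convolution, the adapter down-samples the encoder
# output in time before it reaches the decoder. Shape sketch:
#
#     import flax.linen as nn
#     import jax.numpy as jnp
#
#     x = jnp.ones((2, 10, 16))  # (batch, time, 2 * output_hidden_size)
#     y = nn.glu(x, axis=2)
#     print(y.shape)             # (2, 10, 8) -- feature dim halved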


class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Wav2Vec2Config
    base_model_prefix: str = "wav2vec2"
    main_input_name = "input_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: Wav2Vec2Config,
        input_shape: Tuple = (1, 1024),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        input_values = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_values)
        params_rng, dropout_rng = jax.random.split(rng, 2)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_features: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            batch_size, sequence_length = input_values.shape
            attention_mask = jnp.ones((batch_size, sequence_length))

        if extract_features is not None:
            extract_features = jnp.array(extract_features, dtype="f4")

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        return self.module.apply(
            inputs,
            jnp.array(input_values, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            mask_time_indices,
            extract_features,
            not train,
            output_attentions,
            output_hidden_states,
            output_features,
            freeze_feature_encoder,
            return_dict,
            rngs=rngs,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        return self.module._get_feature_vector_attention_mask(feature_vector_length, attention_mask, add_adapter=add_adapter)
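

# Illustrative driver for the wrapper above (names and shapes are placeholders; a 16 kHz input is
# assumed):
#
#     import jax.numpy as jnp
#
#     model = FlaxWav2Vec2Model(config)                 # `config` is a Wav2Vec2Config instance
#     input_values = jnp.zeros((1, 16000), dtype="f4")  # one second of audio
#     outputs = model(input_values, train=False)
#     # outputs.last_hidden_state: (1, num_frames, hidden size of the encoder output)
#
# With `train=True`, a `dropout_rng` should also be supplied so that `module.apply` can draw the
# dropout masks from the "dropout" RNG stream.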


class FlaxWav2Vec2Module(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
        self.masked_spec_embed = self.param(
            "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
        )

        if self.config.do_stable_layer_norm:
            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
        else:
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):

        # forward pass through the feature extractor if features not specified
        if extract_features is None:
            extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        if output_features:
            return extract_features

        # make sure that no loss is computed on padded inputs
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
        if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # these two operations make sure that all values
        # before the output lengths indices are attended to
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
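

# Worked example of the two helpers above, assuming the default wav2vec2 feature-encoder geometry
# (conv_kernel = (10, 3, 3, 3, 3, 2, 2), conv_stride = (5, 2, 2, 2, 2, 2, 2)): 16000 input samples
# shrink as
#     16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49
# i.e. roughly one feature vector per 20 ms of audio. The mask construction then places a 1 at
# index `output_length - 1` and back-fills it with a reversed cumulative sum:
#
#     import jax.numpy as jnp
#
#     m = jnp.zeros((1, 6)).at[jnp.arange(1), jnp.array([4]) - 1].set(1)  # [[0, 0, 0, 1, 0, 0]]
#     m = jnp.flip(jnp.flip(m, -1).cumsum(-1), -1).astype("bool")
#     # -> [[True, True, True, True, False, False]]: the first 4 feature frames are attended to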


class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    module_class = FlaxWav2Vec2Module


class FlaxWav2Vec2ForCTCModule(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.final_dropout)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # these two operations make sure that all values
        # before the output lengths indices are attended to
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask


class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
    module_class = FlaxWav2Vec2ForCTCModule
nohup.out
ADDED
The diff for this file is too large to render. See raw diff.
preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
{
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": true,
  "sampling_rate": 16000
}
run_flax_speech_recognition_seq2seq.py
ADDED
@@ -0,0 +1,1572 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the Flax library models for sequence to sequence speech recognition.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

import logging
import math
import os
import re
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import datasets
import numpy as np
from datasets import DatasetDict, load_dataset, load_metric
from tqdm import tqdm

import flax
import jax
import jax.numpy as jnp
import optax
import transformers
import wandb as wandb
from flax import core, jax_utils, struct, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from models import FlaxSpeechEncoderDecoderModel
from optax._src import linear_algebra
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoProcessor,
    AutoTokenizer,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    is_tensorboard_available,
)
from transformers.file_utils import get_full_repo_name
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.17.0.dev0")

require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

logger = logging.getLogger(__name__)


@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    feature_extractor_name: Optional[str] = field(
        default=None, metadata={"help": "Feature extractor name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    freeze_feature_encoder: bool = field(
        default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
    )
    activation_dropout: float = field(
        default=0.1,
        metadata={"help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."},
    )
    hidden_dropout: float = field(
        default=0.1,
        metadata={
            "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
        },
    )
    feat_proj_dropout: float = field(
        default=0.0,
        metadata={"help": "The feature projection dropout probability for the feature encoder representations."},
    )
    mask_time_prob: float = field(
        default=0.1,
        metadata={"help": "The SpecAugment dropout probability for the feature encoder representations."},
    )
    encoder_add_adapter: bool = field(
        default=True, metadata={"help": "Whether to add an adapter layer between the encoder and decoder."}
    )


@flax.struct.dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: str = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    dataset_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    text_column_name: str = field(
        default="text",
        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
    )
    id_column_name: str = field(
        default="id",
        metadata={"help": "The name of the dataset column containing the id data. Defaults to 'id'"},
    )
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={
            "help": "Filter audio files in the training set that are longer than `max_duration_in_seconds` seconds"
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0,
        metadata={
            "help": "Filter audio files in the training set that are shorter than `min_duration_in_seconds` seconds"
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    min_target_length: Optional[int] = field(
        default=0,
        metadata={
            "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
            "than this will be filtered."
        },
    )
    pad_input_to_multiple_of: Optional[int] = field(
        default=24000,
        metadata={
            "help": "If set will pad the input sequence to a multiple of the provided value. "
            "This is important to avoid triggering recompilations on TPU."
        },
    )
    pad_target_to_multiple_of: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set will pad the target sequence to a multiple of the provided value. "
            "This is important to avoid triggering recompilations on TPU. If unspecified, will default to "
            "`max_target_length`, the equivalent of padding the targets to max length."
        },
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": "Whether to only do data preprocessing and skip training. "
            "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
            "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
            "so that the cached datasets can consequently be loaded in distributed training"
        },
    )
    train_split_name: str = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    eval_split_name: str = field(
        default="validation",
        metadata={
            "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
        },
    )
    test_split_name: str = field(
        default="test",
        metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
    )
    do_lower_case: bool = field(
        default=True,
        metadata={"help": "Whether the target text should be lower cased."},
    )
    wandb_project: str = field(
        default="flax-speech-recognition-seq2seq",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_name: str = field(
        default=None,
        metadata={"help": "The name of the wandb run."},
    )
    wandb_job_type: str = field(
        default="Seq2Seq",
        metadata={"help": "The name of the wandb job type."},
    )
    log_first_ids: bool = field(
        default=True,
        metadata={
            "help": "Whether to log the first ids from the dataset. Defaults to `True`. If `False`, will log the "
            "first ids returned by the grouped length sampler."
        },
    )


# @flax.struct.dataclass
@dataclass
class FlaxSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
    precision: str = field(
        default="full",
        metadata={
            "help": "Whether to enable mixed-precision training. If true, the optimizer state is stored in "
            "half-precision (bfloat16) and computations are executed in half-precision. "
            "**Note that this only specifies the dtype of the computation and optimizer state. It does not "
            "influence the dtype of model parameters.**"
        },
    )
    matmul_precision: str = field(
        default="default",
        metadata={
            "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
            "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
            "This configuration option does not change the behaviour of such calls with explicit precision arguments; "
            "it only changes the behaviour of calls with no such argument provided. "
            "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
        },
    )
    generation_length_penalty: float = field(
        default=1,
        metadata={
            "help": "Exponential penalty to the length. 1.0 (default) means no penalty. Set to a value < 1.0 in "
            "order to encourage the model to generate shorter sequences, or to a value > 1.0 in order to encourage "
            "the model to produce longer sequences."
        },
    )
    final_generation_max_length: int = field(
        default=None,
        metadata={
            "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
            "to the `max_length` value of the model configuration."
        },
    )
    final_generation_num_beams: int = field(
        default=None,
        metadata={
            "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
            "to the `num_beams` value of the model configuration."
        },
    )

    def __post_init__(self):
        if self.final_generation_max_length is None:
            self.final_generation_max_length = self.generation_max_length
        if self.final_generation_num_beams is None:
            self.final_generation_num_beams = self.generation_num_beams

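
# `matmul_precision` is typically applied globally before any model computation, e.g. via
# `jax.default_matmul_precision`. Illustrative sketch (the flag value is assumed to come from the
# field above):
#
#     import jax
#
#     with jax.default_matmul_precision("bfloat16"):
#         ...  # matmuls/convs without an explicit `precision=` argument use bfloat16 inputs
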
def to_fp32(t):
    return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)


def to_bf16(t):
    return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)


class MixedPrecisionTrainState(struct.PyTreeNode):
    """Train state for use with a single Optax optimizer.
    Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py

    Synopsis::

        state = TrainState.create(
            apply_fn=model.apply,
            params=variables['params'],
            tx=tx)
        grad_fn = jax.grad(make_loss_fn(state.apply_fn))
        for batch in data:
            grads = grad_fn(state.params, batch)
            state = state.apply_gradients(grads=grads)

    Args:
        step: Counter starts at 0 and is incremented by every call to
            `.apply_gradients()`.
        apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
            convenience to have a shorter params list for the `train_step()` function
            in your training loop.
        params: The parameters to be updated by `tx` and used by `apply_fn`.
        tx: An Optax gradient transformation.
        opt_state: The state for `tx`.
        dropout_rng: PRNG key for stochastic operations.
        max_grad_norm: The maximum global l2 norm at which to clip gradients.
    """

    step: int
    apply_fn: Callable = struct.field(pytree_node=False)
    params: core.FrozenDict[str, Any]
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: optax.OptState
    dropout_rng: jnp.ndarray
    max_grad_norm: Optional[float] = 1.0

    def apply_gradients(self, *, grads, to_dtype, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
            grads: Gradients that have the same pytree structure as `.params`.
            to_dtype: Function that casts a pytree to the optimizer-state dtype.
            **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
            An updated instance of `self` with `step` incremented by one, `params`
            and `opt_state` updated by applying `grads`, and additional attributes
            replaced as specified by `kwargs`.
        """

        # clip gradients by global l2 norm
        casted_max_grad_norm = to_dtype(self.max_grad_norm)
        g_norm = linear_algebra.global_norm(grads)
        g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
        grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)

        # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
        # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
        updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)

        new_params = optax.apply_updates(self.params, updates)
        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=to_dtype(new_opt_state),
            **kwargs,
        )

    @classmethod
    def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
        """Creates a new instance with `step=0` and initialized `opt_state`."""
        # downcast optimizer state to bf16 if mixed-precision training
        opt_state = tx.init(to_dtype(params)) if tx is not None else None
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=opt_state,
            **kwargs,
        )

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))

|
430 |
+
|
431 |
+
def pad_to_max_length(data, tokenizer):
    # Get lengths of each row of data
    lens = np.array([len(i) for i in data])

    # Mask of valid places in each row
    mask = np.arange(lens.max()) < lens[:, None]

    # Setup output array and put elements from data into masked positions
    out = np.ones_like(mask, dtype=data.dtype) * tokenizer.pad_token_id
    out[mask] = np.concatenate(data)
    return out


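# Illustrative example (hypothetical values): for two label rows of unequal length and a
# tokenizer whose pad_token_id is 0, the ragged rows are right-padded to the longest row:
#     pad_to_max_length(np.array([[1, 2, 3], [4, 5]], dtype="object"), tokenizer)
#     # -> array([[1, 2, 3], [4, 5, 0]], dtype=object)

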
def shift_tokens_right(label_ids: np.ndarray, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift label ids one token to the right.
    """
    shifted_label_ids = np.zeros_like(label_ids)
    shifted_label_ids[:, 1:] = label_ids[:, :-1]
    shifted_label_ids[:, 0] = decoder_start_token_id

    return shifted_label_ids


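# Worked example (hypothetical token ids, assuming decoder_start_token_id = 2): the labels
# are shifted one position to the right to form the decoder inputs, with the start token
# prepended and the final label dropped:
#     shift_tokens_right(np.array([[5, 6, 7]]), 2)  # -> array([[2, 5, 6]])

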
@flax.struct.dataclass
class FlaxDataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor ([`Wav2Vec2Processor`])
            The processor used for processing the data.
        decoder_start_token_id (:obj: `int`)
            The begin-of-sentence token id of the decoder.
        input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
            See above for details.
        max_input_length (:obj:`float`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_target_length (:obj:`int`, `optional`):
            Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
        pad_input_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the input sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        pad_target_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the target sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Any
    decoder_start_token_id: int
    input_padding: Union[bool, str] = "longest"
    target_padding: Union[bool, str] = "max_length"
    max_input_length: Optional[float] = None
    max_target_length: Optional[int] = None
    pad_input_to_multiple_of: Optional[int] = None
    pad_target_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        input_ids = [feature["input_id"] for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # reformat list to dict and set to numpy format
        batch = self.processor.feature_extractor.pad(
            input_features,
            max_length=self.max_input_length,
            padding=self.input_padding,
            pad_to_multiple_of=self.pad_input_to_multiple_of,
            return_tensors="np",
        )

        labels_batch = self.processor.tokenizer.pad(
            label_features,
            max_length=self.max_target_length,
            padding=self.target_padding,
            pad_to_multiple_of=self.pad_target_to_multiple_of,
            return_tensors="np",
        )

        # if the bos token was appended in a previous tokenization step,
        # cut it here as it is appended again by shift_tokens_right
        labels = labels_batch["input_ids"]
        if (labels[:, 0] == self.decoder_start_token_id).all().item():
            labels = labels[:, 1:]
            labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]

        decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)

        # replace padding with -100 to ignore correctly when computing the loss
        labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
        labels = labels.filled(fill_value=-100)

        batch["inputs"] = batch.pop("input_values")
        batch["input_ids"] = input_ids
        batch["labels"] = labels
        batch["decoder_input_ids"] = decoder_input_ids

        return batch


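# Sketch of the collator's output (hypothetical shapes, assuming two features whose raw audio
# arrays have lengths 320 and 480 and input_padding="longest"): the returned dict contains
#     batch["inputs"]             # float array of shape (2, 480), zero-padded audio
#     batch["labels"]             # int array, padded positions filled with -100
#     batch["decoder_input_ids"]  # labels shifted right, starting with decoder_start_token_id
#     batch["input_ids"]          # the per-sample ids logged to wandb (not model inputs)

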
def get_grouped_indices(
    dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
) -> np.ndarray:
    """
    Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
    Function that returns a list of indices in which each slice of `batch_size` consecutive indices corresponds to elements of similar
    lengths. To do this, the indices are:

    - randomly permuted (if a JAX rng is specified)
    - grouped in mega-batches of size `mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.
    """
    lengths = dataset["input_length"]

    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
    num_samples = len(lengths)
    indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
    indices = np.asarray(indices)

    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = np.argmax(megabatch_maximums).item()
    # Switch to put the longest batch in first position
    # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
    megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]

    megabatches = np.array([i for megabatch in megabatches for i in megabatch])

    return megabatches


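# Worked example (hypothetical lengths): with batch_size=2, no rng, and a dataset whose
# "input_length" column is [1, 5, 3, 2], mega_batch_mult falls back to 1, so the megabatches
# [[0, 1], [2, 3]] are each sorted by descending length to [[1, 0], [3, 2]], giving the final
# index order [1, 0, 3, 2]: each consecutive pair of indices now has similar audio lengths.

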
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last_batch=True) -> np.ndarray:
    """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
    the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
    num_samples = len(samples_idx)
    if drop_last_batch:
        samples_to_remove = num_samples % batch_size
        if samples_to_remove != 0:
            samples_idx = samples_idx[:-samples_to_remove]
        sections_split = num_samples // batch_size
        samples_idx = samples_idx.reshape((sections_split, batch_size))
    else:
        sections_split = math.ceil(num_samples / batch_size)
        samples_idx = np.array_split(samples_idx, sections_split)
    return samples_idx


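# Illustrative example (hypothetical indices): for ten sample indices and batch_size=4,
#     generate_batch_splits(np.arange(10), 4, drop_last_batch=True)   # -> 2 batches of 4; indices 8, 9 dropped
#     generate_batch_splits(np.arange(10), 4, drop_last_batch=False)  # -> 3 batches of sizes 4, 3, 3

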
def write_train_metric(summary_writer, train_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)


def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)

    if pred_str is not None:
        # write out the actual predictions for debugging
        summary_writer.text("eval_predictions", "\n".join(pred_str), step)


def write_wandb_log(metrics, step, prefix=None):
    if jax.process_index() == 0:
        log_metrics = {}
        for k, v in metrics.items():
            if "layer" in k:
                log_metrics[f"{k}/"] = v
            elif prefix is not None:
                log_metrics[f"{prefix}/{k}"] = v
            else:
                log_metrics[k] = v
        wandb.log(log_metrics, step)


def write_wandb_pred(pred_str, label_str, eval_ids, step, prefix="eval", top_ids=None, final_step=True):
    if jax.process_index() == 0:
        top_ids = top_ids if top_ids else eval_ids
        num_beams = len(pred_str)
        # convert str data to a wandb compatible format
        str_data = []
        for id in top_ids:
            if id in eval_ids:
                idx = eval_ids.index(id)
                str_data.append([eval_ids[idx], label_str[idx]] + [pred_str[beam][idx] for beam in range(num_beams)])
        columns = ["id", "label_str"] + [f"beam_{i + 1}" for i in range(num_beams)]
        wandb.log(
            {f"{prefix}/step_{int(step / 1000)}k": wandb.Table(columns=columns, data=str_data[:50])},
            step,
        )
        if final_step:
            str_data = np.array(str_data)
            wandb.log(
                {f"{prefix}/step_{int(step / 1000)}k_all": wandb.Table(columns=columns, data=str_data[:200000])},
                step,
            )
            # log incorrect predictions only (rows where the label differs from the top beam)
            str_data = str_data[str_data[:, 1] != str_data[:, 2]]
            wandb.log(
                {f"{prefix}/step_{int(step / 1000)}k_incorrect": wandb.Table(columns=columns, data=str_data[:200000])},
                step,
            )


def create_learning_rate_fn(
    num_train_steps: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear decay learning rate function."""
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


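# Illustrative schedule values (hypothetical arguments): with num_train_steps=100,
# num_warmup_steps=10 and learning_rate=1e-4, the schedule rises linearly from 0 to 1e-4
# over steps 0-10, then decays linearly back to 0 over steps 10-100:
#     lr_fn = create_learning_rate_fn(100, 10, 1e-4)
#     lr_fn(0), lr_fn(10), lr_fn(55), lr_fn(100)  # -> 0.0, 1e-4, 5e-5, 0.0

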
def main():
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxSeq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # 2. Setup logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Set the verbosity to info of the Transformers logger.
    # We only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set up wandb run
    if jax.process_index() == 0:
        wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)

    logger.info("Training/evaluation parameters %s", training_args)

    # Set the default TPU matmul precision and display the number of devices
    jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
    logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")

    # TODO: 3. Detecting last checkpoint and eventually continue from last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # 4. Load dataset
    raw_datasets = DatasetDict()

    if training_args.do_train:
        raw_datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=data_args.train_split_name,
            cache_dir=data_args.dataset_cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    if training_args.do_eval:
        raw_datasets["eval"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=data_args.eval_split_name,
            cache_dir=data_args.dataset_cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    if training_args.do_predict:
        test_split = data_args.test_split_name.split("+")
        for split in test_split:
            raw_datasets[split] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=split,
                cache_dir=data_args.dataset_cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )

    if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
        raise ValueError(
            "Cannot skip training, evaluation and prediction all at once. At least one of "
            "training, evaluation or prediction has to be done."
        )

    # if not training, there is no need to run multiple epochs
    if not training_args.do_train:
        training_args.num_train_epochs = 1

    if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
        raise ValueError(
            f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--audio_column_name` to the correct audio column - one of "
            f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
        )

    if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
        raise ValueError(
            f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--text_column_name` to the correct text column - one of "
            f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
        )

    if data_args.log_first_ids and data_args.id_column_name not in next(iter(raw_datasets.values())).column_names:
        raise ValueError(
            f"--id_column_name {data_args.id_column_name} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--id_column_name` to the correct id column - one of "
            f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
        )

    # 5. Load pretrained model, tokenizer, and feature extractor
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # update config according to training and model args
    config.encoder.update(
        {
            "gradient_checkpointing": training_args.gradient_checkpointing,
            "hidden_dropout": model_args.hidden_dropout,
            "activation_dropout": model_args.activation_dropout,
            "feat_proj_dropout": model_args.feat_proj_dropout,
            "mask_time_prob": model_args.mask_time_prob,
            "add_adapter": model_args.encoder_add_adapter,
        }
    )
    config.decoder.update(
        {
            "gradient_checkpointing": training_args.gradient_checkpointing,
            "dropout": model_args.hidden_dropout,
            "activation_dropout": model_args.activation_dropout,
        }
    )

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # "full_mixed" additionally keeps the optimizer state in bf16; "half_mixed" runs the
    # computation in bf16 but keeps the optimizer state in fp32
    if training_args.precision == "full_mixed":
        dtype = jnp.bfloat16
        training_args.mixed_precision = True
    elif training_args.precision == "half_mixed":
        dtype = jnp.bfloat16
        training_args.mixed_precision = False
    else:
        dtype = jnp.float32
        training_args.mixed_precision = False

    model = FlaxSpeechEncoderDecoderModel.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=dtype,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    # 6. Resample speech dataset ALWAYS
    raw_datasets = raw_datasets.cast_column(
        data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
    )

    # 7. Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
    min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
    max_target_length = data_args.max_target_length
    min_target_length = data_args.min_target_length
    pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
    pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
    audio_column_name = data_args.audio_column_name
    num_workers = data_args.preprocessing_num_workers
    text_column_name = data_args.text_column_name
    id_column_name = data_args.id_column_name
    model_input_name = feature_extractor.model_input_names[0]
    do_lower_case = data_args.do_lower_case
    log_first_ids = data_args.log_first_ids
    dataset_name = data_args.dataset_name
    tedlium_contractions = [" 's", " 't", " 're", " 've", " 'm", " 'll", " 'd", " 'clock", " 'all"]
    gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
    gigaspeech_disfluencies = ["<other>", "<sil>"]
    swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
                        "[vocalized-noise]", "_1"]
    swb_punctuations = ["{", "}", "[", "]-", "]"]
    earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>"]
    ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
                       "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]

    if training_args.do_train and data_args.max_train_samples is not None:
        raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))

    if training_args.do_eval and data_args.max_eval_samples is not None:
        raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))

    if training_args.do_predict and data_args.max_test_samples is not None:
        for split in test_split:
            raw_datasets[split] = raw_datasets[split].select(range(data_args.max_test_samples))

    # filter data where the targets are ignored in scoring
    def is_target_labels(input_str):
        return input_str.lower() not in ignore_segments

    raw_datasets = raw_datasets.filter(
        is_target_labels,
        num_proc=num_workers,
        input_columns=[text_column_name],
        desc="filtering data where the targets are ignored in scoring",
    )

    def prepare_dataset(batch):
        # Pre-process audio
        try:
            sample = batch[audio_column_name]
        except ValueError:
            # E22: some samples are empty (no audio). Reading the empty audio array will trigger
            # a soundfile ValueError. For now, we'll manually set these arrays to a zero array.
            # They will be filtered in the subsequent filtering stage and so are
            # explicitly ignored during training.
            sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}

        # normalise audio (mean, std) to (0, 1)
        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
        # process audio length
        batch[model_input_name] = inputs.input_values[0]
        batch["input_length"] = len(batch["input_values"])
        batch["input_id"] = batch[id_column_name] if log_first_ids else None

        # 'Error correction' of targets
        input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]

        # LibriSpeech ASR
        if dataset_name == "librispeech_asr":
            pass  # no error correction necessary

        # VoxPopuli
        if dataset_name == "google/xtreme_s":
            pass  # no error correction necessary

        # Common Voice 9
        if dataset_name == "mozilla-foundation/common_voice_9_0":
            if input_str.startswith('"') and input_str.endswith('"'):
                # we can remove trailing quotation marks as they do not affect the transcription
                input_str = input_str[1:-1]
            # replace double quotation marks with single
            input_str = input_str.replace('""', '"')

        # TED-LIUM (Release 3)
        if dataset_name == "LIUM/tedlium":
            # delete the <unk> token from the text
            input_str = input_str.replace("<unk>", "")
            # replace spaced apostrophes with un-spaced (it 's -> it's)
            for contraction in tedlium_contractions:
                input_str = input_str.replace(contraction, contraction[1:])

        # GigaSpeech
        if dataset_name == "speechcolab/gigaspeech":
            for disfluency in gigaspeech_disfluencies:
                input_str = input_str.replace(disfluency, "")
            # convert spelled out punctuation to symbolic form
            for punctuation, replacement in gigaspeech_punctuation.items():
                input_str = input_str.replace(punctuation, replacement)

        # SWB: hide the path to the private HF dataset
        if "switchboard" in dataset_name:
            for disfluency in swb_disfluencies:
                input_str = input_str.replace(disfluency, "")
            # remove parenthesised text (test data only)
            input_str = re.sub(r"[\(].*?[\)]", "", input_str)
            for punctuation in swb_punctuations:
                input_str = input_str.replace(punctuation, "")
            # replace anomalous words with their correct transcriptions
            split_str = input_str.split("/")
            if len(split_str) > 1:
                input_str = " ".join(
                    [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])

        # Earnings 22: still figuring out best segmenting method. Thus, dataset name subject to change
        if "earnings22" in dataset_name:
            for disfluency in earnings_disfluencies:
                input_str = input_str.replace(disfluency, "")

        # SPGISpeech
        if dataset_name == "kensho/spgispeech":
            pass  # no error correction necessary

        # JIWER compliance (for WER/CER calc.)
        # remove multiple spaces
        input_str = re.sub(r"\s\s+", " ", input_str)
        # strip trailing spaces
        input_str = input_str.strip()

        # Finally, we tokenize the processed text
        batch["labels"] = tokenizer(input_str).input_ids
        batch["labels_length"] = len(batch["labels"])
        return batch

    vectorized_datasets = raw_datasets.map(
        prepare_dataset,
        remove_columns=next(iter(raw_datasets.values())).column_names,
        num_proc=num_workers,
        desc="preprocess train dataset",
    )

    # filter training data with inputs shorter than min_input_length or longer than max_input_length
    def is_audio_in_length_range(length):
        return length > min_input_length and length < max_input_length

    if training_args.do_train:
        vectorized_datasets["train"] = vectorized_datasets["train"].filter(
            is_audio_in_length_range,
            num_proc=num_workers,
            input_columns=["input_length"],
        )

    # filter data with targets shorter than min_target_length or longer than max_target_length
    def is_labels_in_length_range(length):
        return length > min_target_length and length < max_target_length

    if training_args.do_train:
        vectorized_datasets["train"] = vectorized_datasets["train"].filter(
            is_labels_in_length_range,
            num_proc=num_workers,
            input_columns=["labels_length"],
        )

    # filter data with targets of 2 tokens or fewer: <s></s> -> empty sentences
    def is_labels_greater_than_min(length):
        return length > 2

    vectorized_datasets = vectorized_datasets.filter(
        is_labels_greater_than_min,
        num_proc=num_workers,
        input_columns=["labels_length"],
    )

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with `args.preprocessing_only` since there will most likely
    # be a timeout when running the script in distributed mode.
    # In a second step `args.preprocessing_only` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
        logger.info(f"Data preprocessing finished. Files cached at {cache}.")
        return

    # 8. Load Metrics
    wer_metric = load_metric("wer")
    cer_metric = load_metric("cer")

    def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
        label_ids = (
            pad_to_max_length(np.array(label_ids, dtype="object"), tokenizer)
            if pad_target_to_multiple_of
            else label_ids
        )

        padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(padded_ids, skip_special_tokens=True)

        pred_ids = np.array(pred_ids)
        num_beams = pred_ids.shape[1]
        # decode on a beam-by-beam basis
        pred_str = [
            tokenizer.batch_decode(pred_ids[:, beam, :], skip_special_tokens=True)
            for beam in reversed(range(num_beams))
        ]
        # compute word/character error rate for the top beam
        wer = wer_metric.compute(predictions=pred_str[0], references=label_str)
        cer = cer_metric.compute(predictions=pred_str[0], references=label_str)

        return {"wer": wer, "cer": cer}, pred_str, label_str

    # 9. Save feature extractor, tokenizer and config
    feature_extractor.save_pretrained(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)
    config.save_pretrained(training_args.output_dir)

    processor = AutoProcessor.from_pretrained(training_args.output_dir)

    data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
        input_padding="longest",
        target_padding="longest",
        max_target_length=max_target_length,
        pad_input_to_multiple_of=pad_input_to_multiple_of,
        pad_target_to_multiple_of=pad_target_to_multiple_of if pad_target_to_multiple_of else max_target_length,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed. "
            "Please run `pip install tensorboard` to enable."
        )

    # 10. Handle the repository creation
    if training_args.push_to_hub:
        # track wandb run files with git-lfs
        with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
            git_lfs_extensions = f.read()
            if "*.wandb" not in git_lfs_extensions:
                f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # 11. Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constants
    max_steps = int(training_args.max_steps)
    gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    batch_size_per_update = train_batch_size * gradient_accumulation_steps
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    to_dtype = to_bf16 if training_args.mixed_precision else to_fp32

    if training_args.do_train:
        num_train_samples = len(vectorized_datasets["train"])
        steps_per_epoch = num_train_samples // batch_size_per_update
        if max_steps > 0:
            num_epochs = -(training_args.max_steps // -steps_per_epoch)
            total_train_steps = max_steps
        else:
            num_epochs = int(training_args.num_train_epochs)
            total_train_steps = steps_per_epoch * num_epochs

        # Create learning rate schedule
        linear_decay_lr_schedule_fn = create_learning_rate_fn(
            total_train_steps,
            training_args.warmup_steps,
            training_args.learning_rate,
        )

        # We use Optax's "masking" functionality to not apply weight decay
        # to bias and LayerNorm scale parameters. decay_mask_fn returns a
        # mask boolean with the same structure as the parameters.
        # The mask is True for parameters that should be decayed.
        # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
        # For FlaxT5, one should correct the layer norm parameter naming
        # accordingly - see `run_t5_mlm_flax.py` e.g.
        def decay_mask_fn(params):
            flat_params = traverse_util.flatten_dict(params)
            layer_norm_params = [
                (name, "scale")
                for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
            ]
            flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
            return traverse_util.unflatten_dict(flat_mask)

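        # Illustrative mask (hypothetical parameter paths): for flattened keys
        # ("encoder", "layer_norm", "scale"), ("encoder", "attention", "bias") and
        # ("encoder", "attention", "kernel"), decay_mask_fn marks only the kernel for
        # weight decay (True); biases and LayerNorm scales are excluded (False).
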
        if training_args.adafactor:
            # Create Adafactor optimizer
            optim = optax.adafactor(
                learning_rate=linear_decay_lr_schedule_fn,
                dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
                weight_decay_rate=training_args.weight_decay,
                weight_decay_mask=decay_mask_fn,
            )
        else:
            # Create AdamW optimizer
            optim = optax.adamw(
                learning_rate=linear_decay_lr_schedule_fn,
                b1=training_args.adam_beta1,
                b2=training_args.adam_beta2,
                eps=training_args.adam_epsilon,
                weight_decay=training_args.weight_decay,
                mask=decay_mask_fn,
            )
    else:
        num_epochs = 0
        total_train_steps = 0
        num_train_samples = 0
        optim = None

    # Setup train state
    state = MixedPrecisionTrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optim,
        to_dtype=to_dtype,
        dropout_rng=dropout_rng,
        max_grad_norm=training_args.max_grad_norm,
    )

    # Cross entropy loss
    def loss_fn(logits, labels):
        vocab_size = logits.shape[-1]
        # optax onehot always returns a float32 device array, need to downcast if performing mixed precision training
        onehot_targets = to_dtype(onehot(labels, vocab_size))
        loss = optax.softmax_cross_entropy(logits, onehot_targets)
        # ignore padded tokens in the loss, i.e. positions where the labels are set to -100
        padding = labels >= 0
        loss = loss * padding
        loss = loss.sum()
        num_labels = padding.sum()
        return loss, num_labels

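    # Sketch of the label masking (hypothetical values): for labels = [5, -100, 7], the mask
    # `labels >= 0` is [True, False, True], so the padded middle position contributes nothing
    # to the summed loss and num_labels = 2. The caller divides the summed loss by the total
    # label count, giving a per-token mean that is comparable across differently padded batches.
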
    # Define gradient update step fn
    def train_step(state, batch):
        # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params, minibatch):
            labels = minibatch.pop("labels")
            logits = state.apply_fn(
                **minibatch,
                params=params,
                dropout_rng=dropout_rng,
                freeze_feature_encoder=model_args.freeze_feature_encoder,
                train=True,
            )[0]
            loss, num_labels = loss_fn(logits, labels)
            return loss, num_labels

        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

        if gradient_accumulation_steps == 1:
            (loss, num_labels), grad = grad_fn(to_dtype(state.params), batch)

        # Custom gradient accumulation
        else:
            # add a first dimension over gradient_accumulation_steps for minibatch slices
            batch = jax.tree_map(
                lambda x: x.reshape(
                    gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
                ),
                batch,
            )

            def accum_minibatch_step(accum_grad, minibatch):
                # compute loss, num labels and grad over minibatch and accumulate
                (loss, num_labels), grad = grad_fn(to_dtype(state.params), minibatch)
                return jax.tree_map(jnp.add, accum_grad, grad), (loss, num_labels)

            # create an initial state for accumulating losses, num labels and gradients
            init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
            # loop accum minibatch step over the number of gradient accumulation steps
            grad, (loss, num_labels) = jax.lax.scan(accum_minibatch_step, init_grad, batch)

        grad = jax.lax.psum(grad, "batch")
        loss = jax.lax.psum(loss.sum(), "batch")
        total_samples = jax.lax.psum(num_labels.sum(), "batch")
        grad = jax.tree_map(lambda g: g / total_samples, grad)
        loss = jax.tree_map(lambda l: l / total_samples, loss)

        # update state
        new_state = state.apply_gradients(
            grads=grad,
            dropout_rng=new_dropout_rng,
            to_dtype=to_dtype,
        )

        # compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring
        layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
        logs = {
            "layer_grad_norm": layer_grad_norm,
            "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
            "decoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["decoder"])),
        }
        logs["grad_norm"] = jnp.linalg.norm([logs["encoder_grad_norm"], logs["decoder_grad_norm"]])

        # compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring
        layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
        logs["layer_param_norm"] = layer_param_norm
        logs["encoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["encoder"]))
        logs["decoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["decoder"]))
        logs["param_norm"] = jnp.linalg.norm([logs["encoder_param_norm"], logs["decoder_param_norm"]])

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics.update(logs)

        metrics = jax.lax.pmean(metrics, axis_name="batch")
        # metrics = to_fp32(metrics)

        return new_state, metrics

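    # Note on the normalization above (a sketch of the arithmetic, not extra logic): both the
    # summed loss and the summed gradients are divided by the total number of unmasked label
    # tokens across all devices and accumulation steps, so the effective update is the mean
    # per-token gradient regardless of batch shape or gradient_accumulation_steps.
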
    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss, num_labels = loss_fn(logits, labels)

        total_samples = jax.lax.psum(num_labels, "batch")
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_map(lambda l: l / total_samples, loss)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        # metrics = to_fp32(metrics)
        return metrics

    # Define generation function
    gen_kwargs = {
        "max_length": training_args.generation_max_length,
        "num_beams": training_args.generation_num_beams,
        "length_penalty": training_args.generation_length_penalty,
    }
    final_gen_kwargs = {
        "max_length": training_args.final_generation_max_length,
        "num_beams": training_args.final_generation_num_beams,
        "length_penalty": training_args.generation_length_penalty,
    }

    def generate_step(params, batch):
        model.params = params
        output_ids = model.generate(batch["inputs"], **gen_kwargs)
        return output_ids.sequences

    def final_generate_step(params, batch):
        model.params = params
        output_ids = model.generate(batch["inputs"], **final_gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
    if training_args.do_train:
        p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    if training_args.do_eval or training_args.do_predict:
        p_eval_step = jax.pmap(eval_step, "batch")

    if training_args.predict_with_generate:
        p_generate_step = jax.pmap(generate_step, "batch")
        p_final_generate_step = jax.pmap(final_generate_step, "batch")

    def run_evaluation(step, final_step=False):
        if training_args.do_eval:
            # ======================== Evaluating ==============================
            eval_metrics = []
            eval_preds = []
            eval_ids = []
            eval_labels = []

            # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
            eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
            eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last_batch=False)

            for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
                samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
                batch = data_collator(samples)
                eval_ids.extend(batch.pop("input_ids"))
                labels = batch["labels"]

                metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
                eval_metrics.append(metrics)

                # generation
                if training_args.predict_with_generate:
                    if not final_step:
                        generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
                        eval_preds.extend(
                            jax.device_get(
                                generated_ids.reshape(-1, gen_kwargs["num_beams"], gen_kwargs["max_length"])
                            )
                        )
                    else:
                        generated_ids = pad_shard_unpad(p_final_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
                        eval_preds.extend(
                            jax.device_get(
                                generated_ids.reshape(
                                    -1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"]
                                )
                            )
                        )
                    eval_labels.extend(labels)

            # normalize eval metrics
            eval_metrics = get_metrics(eval_metrics)
            eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
            eval_metrics = to_fp32(eval_metrics)

            # compute error rate metric and get predicted string (for debugging)
            error_rate_desc = ""
            pred_str = []
            label_str = []
            if training_args.predict_with_generate:
                error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
                eval_metrics.update(error_rate_metric)
                error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])

            # Print metrics and update progress bar
            desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
            epochs.write(desc)
            epochs.desc = desc

            # Save metrics
            write_wandb_log(eval_metrics, step, prefix="eval")
            write_wandb_pred(
                pred_str,
                label_str,
                eval_ids,
                step,
                top_ids=vectorized_datasets["eval"]["input_id"] if data_args.log_first_ids else None,
                final_step=final_step,
            )
            # if has_tensorboard and jax.process_index() == 0:
            #     write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)

    def save_checkpoint(step):
        # save and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            tokenizer.save_pretrained(training_args.output_dir)
            if training_args.push_to_hub:
                repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)

# Replicate the train state on each device
|
1432 |
+
state = state.replicate()
|
1433 |
+
|
1434 |
+
logger.info("***** Running training *****")
|
1435 |
+
logger.info(f" Num examples = {num_train_samples}")
|
1436 |
+
logger.info(f" Num Epochs = {num_epochs}")
|
1437 |
+
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
1438 |
+
logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
|
1439 |
+
logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
|
1440 |
+
logger.info(f" Total optimization steps = {total_train_steps}")
|
1441 |
+
logger.info(f" Gradient checkpointing: {config.encoder.gradient_checkpointing}")
|
1442 |
+
logger.info(f" Use scan: {config.encoder.use_scan}")
|
1443 |
+
logger.info(f" Fuse matmuls: {config.encoder.fuse_matmuls}")
|
1444 |
+
|
1445 |
+
train_time = cur_step = 0
|
1446 |
+
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
1447 |
+
for epoch in epochs:
|
1448 |
+
if training_args.do_train:
|
1449 |
+
# ======================== Training ================================
|
1450 |
+
train_start = time.time()
|
1451 |
+
|
1452 |
+
# Create sampling rng
|
1453 |
+
rng, input_rng = jax.random.split(rng)
|
1454 |
+
|
1455 |
+
# Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
|
1456 |
+
train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
|
1457 |
+
+        train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update, drop_last_batch=True)
+
+        # Gather the indices for creating the batch and do a training step
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
+            samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
+            batch = data_collator(samples)
+            batch.pop("input_ids")
+            batch = shard(batch.data)
+            state, train_metric = p_train_step(state, batch)
+
+            cur_step = epoch * (num_train_samples // batch_size_per_update) + step
+
+            if cur_step % training_args.logging_steps == 0:
+                # Save metrics
+                train_metric = unreplicate(train_metric)
+                train_time += time.time() - train_start
+                # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
+                write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
+                # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
+                # if has_tensorboard and jax.process_index() == 0:
+                #     write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
+                )
+
+            if cur_step % total_train_steps == 0:
+                break
+
+            if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
+                # run beam search at each eval step
+                run_evaluation(cur_step, final_step=False)
+
+            if cur_step % training_args.save_steps == 0:
+                save_checkpoint(cur_step)
+
+        if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
+            # run evaluation at the end of the epoch if eval steps are not specified
+            run_evaluation(cur_step, final_step=False)
+            save_checkpoint(cur_step)
+
+    if training_args.do_train:
+        save_checkpoint(cur_step)
+
+    cur_step = max_steps if max_steps > 0 else cur_step  # set step to max steps so that eval happens in alignment with training
+
+    if training_args.do_eval:
+        run_evaluation(cur_step, final_step=True)
+
+    # TODO: collapse 'do_predict' into the run_evaluation function
+    if training_args.do_predict:
+        # ======================== Prediction ==============================
+        for split in test_split:
+            pred_metrics = []
+            pred_generations = []
+            pred_ids = []
+            pred_labels = []
+
+            # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
+            pred_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
+            pred_batch_idx = generate_batch_splits(pred_samples_idx, eval_batch_size, drop_last_batch=False)
+
+            for i, batch_idx in enumerate(tqdm(pred_batch_idx, desc=f"Predicting {split}...", position=2)):
+                samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
+                batch = data_collator(samples)
+                pred_ids.extend(batch.pop("input_ids"))
+                labels = batch["labels"]
+
+                metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch.data,
+                                                                           min_device_batch=per_device_eval_batch_size)
+                pred_metrics.append(metrics)
+
+                # generation
+                if training_args.predict_with_generate:
+                    generated_ids = pad_shard_unpad(p_final_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
+                    pred_generations.extend(
+                        jax.device_get(
+                            generated_ids.reshape(-1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"])
+                        )
+                    )
+                    pred_labels.extend(labels)
+
+            # normalize eval metrics
+            pred_metrics = get_metrics(pred_metrics)
+            pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
+            pred_metrics = to_fp32(pred_metrics)
+
+            # compute error rate metric and get predicted string (for debugging)
+            error_rate_desc = ""
+            pred_str = []
+            label_str = []
+            if training_args.predict_with_generate:
+                error_rate_metric, pred_str, label_str = compute_metrics(pred_generations, pred_labels)
+                pred_metrics.update(error_rate_metric)
+                error_rate_desc = " ".join([f"{split} {key}: {value} |" for key, value in error_rate_metric.items()])
+
+            # Print metrics and update progress bar
+            desc = f"Step... ({cur_step}/{total_train_steps} | {split} Loss: {pred_metrics['loss']} | {error_rate_desc})"
+            epochs.write(desc)
+            epochs.desc = desc
+
+            # Save metrics
+            write_wandb_log(pred_metrics, cur_step, prefix=split)
+            write_wandb_pred(
+                pred_str,
+                label_str,
+                pred_ids,
+                cur_step,
+                prefix=split,
+                top_ids=vectorized_datasets[split]["input_id"] if data_args.log_first_ids else None,
+                final_step=True,
+            )
+
+
+if __name__ == "__main__":
+    main()
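The prediction loop above wraps the pmapped eval and generate steps in flax.jax_utils.pad_shard_unpad so that the final, ragged batch of each split still runs across all devices. A minimal sketch of that pattern (illustrative only, not part of this commit; dummy_step and the toy batch are made up):

import jax
import jax.numpy as jnp
from flax.jax_utils import pad_shard_unpad

@jax.pmap
def dummy_step(batch):
    # stand-in for p_eval_step / p_final_generate_step
    return batch["inputs"] * 2

# 5 examples: generally not divisible by jax.local_device_count()
batch = {"inputs": jnp.ones((5, 3))}

# pad up to a shardable size, run the pmapped fn, then strip the padding again;
# static_argnums=() because dummy_step takes no leading params argument
out = pad_shard_unpad(dummy_step, static_argnums=())(batch, min_device_batch=1)
print(out.shape)  # (5, 3): the padding rows are removed after the pmap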
run_librispeech.sh
ADDED
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+python run_flax_speech_recognition_seq2seq.py \
+        --dataset_name="librispeech_asr" \
+        --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
+        --dataset_config_name="all" \
+        --train_split_name="train.clean.100+train.clean.360+train.other.500" \
+        --eval_split_name="validation.clean" \
+        --test_split_name="validation.other+test.clean+test.other" \
+        --text_column_name="text" \
+        --id_column_name="id" \
+        --output_dir="./" \
+        --wandb_project="librispeech_960h" \
+        --wandb_name="flax-wav2vec2-2-bart-large-ls-960h-black-box" \
+        --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
+        --per_device_train_batch_size="8" \
+        --per_device_eval_batch_size="4" \
+        --learning_rate="1e-4" \
+        --warmup_steps="500" \
+        --logging_steps="25" \
+        --max_steps="50000" \
+        --eval_steps="10000" \
+        --save_steps="10000" \
+        --generation_max_length="200" \
+        --generation_num_beams="5" \
+        --generation_length_penalty="1.2" \
+        --hidden_dropout="0.2" \
+        --activation_dropout="0.2" \
+        --feat_proj_dropout="0.2" \
+        --overwrite_output_dir \
+        --gradient_checkpointing \
+        --freeze_feature_encoder \
+        --predict_with_generate \
+        --do_lower_case \
+        --do_eval \
+        --do_train \
+        --do_predict \
+        --push_to_hub \
+        --use_auth_token
+
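For orientation, the effective batch size implied by this launch script, as a sketch (assuming the 8 JAX devices and gradient_accumulation_steps=1 reported in output.log further down):

num_devices = 8                      # "JAX devices: 8" in output.log
per_device_train_batch_size = 8      # from the flag above
grad_accum = 1                       # gradient_accumulation_steps in output.log
batch_size_per_update = num_devices * per_device_train_batch_size * grad_accum
print(batch_size_per_update)            # 64 utterances per optimiser step
print(50_000 * batch_size_per_update)   # 3,200,000 utterances seen over max_steps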
special_tokens_map.json
ADDED
@@ -0,0 +1,15 @@
+{
+  "bos_token": "<s>",
+  "cls_token": "<s>",
+  "eos_token": "</s>",
+  "mask_token": {
+    "content": "<mask>",
+    "lstrip": true,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": "<pad>",
+  "sep_token": "</s>",
+  "unk_token": "<unk>"
+}
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
tokenizer_config.json
ADDED
@@ -0,0 +1,16 @@
+{
+  "add_prefix_space": false,
+  "bos_token": "<s>",
+  "cls_token": "<s>",
+  "eos_token": "</s>",
+  "errors": "replace",
+  "mask_token": "<mask>",
+  "model_max_length": 1024,
+  "name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
+  "pad_token": "<pad>",
+  "sep_token": "</s>",
+  "special_tokens_map_file": null,
+  "tokenizer_class": "BartTokenizer",
+  "trim_offsets": true,
+  "unk_token": "<unk>"
+}
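Since tokenizer_class is BartTokenizer, the decoder tokenizer saved here loads with the standard transformers API. A sketch (not in the commit; the repo id is the one this checkpoint is pushed to):

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained(
    "sanchit-gandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box"
)
print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)  # <s> </s> <pad>
print(tokenizer("he put on his hat").input_ids)  # BPE ids bounded by <s> ... </s>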
vocab.json
ADDED
The diff for this file is too large to render. See raw diff.
wandb/debug-internal.log
ADDED
@@ -0,0 +1 @@
+run-20220828_085247-2hx8pk65/logs/debug-internal.log
wandb/debug.log
ADDED
@@ -0,0 +1 @@
+run-20220828_085247-2hx8pk65/logs/debug.log
wandb/latest-run
ADDED
@@ -0,0 +1 @@
+run-20220828_085247-2hx8pk65
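The run directories that follow (nbdgecc9, which exited after roughly three seconds with code 1, and 2hx8pk65, the run that produced the step-10k artefacts) were created by wandb.init. A rough sketch of the shape of those calls (an assumption, not the script's exact code), with project and name taken from the launch flags:

import wandb

run = wandb.init(
    project="librispeech_960h",                           # --wandb_project
    name="flax-wav2vec2-2-bart-large-ls-960h-black-box",  # --wandb_name
)
wandb.log({"train/loss": 1.23}, step=25)  # illustrative metric at a logging step
run.finish()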
wandb/run-20220828_084407-nbdgecc9/files/config.yaml
ADDED
@@ -0,0 +1,36 @@
+wandb_version: 1
+
+_wandb:
+  desc: null
+  value:
+    cli_version: 0.12.15
+    framework: huggingface
+    huggingface_version: 4.21.0.dev0
+    is_jupyter_run: false
+    is_kaggle_kernel: false
+    python_version: 3.8.10
+    start_time: 1661676247
+    t:
+      1:
+      - 1
+      - 11
+      - 12
+      - 45
+      - 49
+      - 51
+      - 55
+      2:
+      - 1
+      - 11
+      - 12
+      - 45
+      - 49
+      - 51
+      - 55
+      3:
+      - 13
+      4: 3.8.10
+      5: 0.12.15
+      6: 4.21.0.dev0
+      8:
+      - 5
wandb/run-20220828_084407-nbdgecc9/files/output.log
ADDED
@@ -0,0 +1,110 @@
+INFO:__main__:Training/evaluation parameters FlaxSeq2SeqTrainingArguments(
+_n_gpu=-1,
+adafactor=False,
+adam_beta1=0.9,
+adam_beta2=0.999,
+adam_epsilon=1e-08,
+auto_find_batch_size=False,
+bf16=False,
+bf16_full_eval=False,
+data_seed=None,
+dataloader_drop_last=False,
+dataloader_num_workers=0,
+dataloader_pin_memory=True,
+ddp_bucket_cap_mb=None,
+ddp_find_unused_parameters=None,
+debug=,
+deepspeed=None,
+disable_tqdm=None,
+do_eval=True,
+do_predict=True,
+do_train=True,
+eval_accumulation_steps=None,
+eval_delay=0,
+eval_steps=10000,
+evaluation_strategy=no,
+final_generation_max_length=200,
+final_generation_num_beams=5,
+fp16=False,
+fp16_backend=auto,
+fp16_full_eval=False,
+fp16_opt_level=O1,
+fsdp=,
+fsdp_min_num_params=0,
+fsdp_transformer_layer_cls_to_wrap=None,
+full_determinism=False,
+generation_length_penalty=1.2,
+generation_max_length=200,
+generation_num_beams=5,
+gradient_accumulation_steps=1,
+gradient_checkpointing=True,
+greater_is_better=None,
+group_by_length=False,
+half_precision_backend=auto,
+hub_model_id=None,
+hub_private_repo=False,
+hub_strategy=every_save,
+hub_token=<HUB_TOKEN>,
+ignore_data_skip=False,
+include_inputs_for_metrics=False,
+jit_mode_eval=False,
+label_names=None,
+label_smoothing_factor=0.0,
+learning_rate=0.0001,
+length_column_name=length,
+load_best_model_at_end=False,
+local_rank=-1,
+log_level=passive,
+log_level_replica=passive,
+log_on_each_node=True,
+logging_dir=None,
+logging_first_step=False,
+logging_nan_inf_filter=True,
+logging_steps=25,
+logging_strategy=steps,
+lr_scheduler_type=linear,
+matmul_precision=default,
+max_grad_norm=1.0,
+max_steps=50000,
+metric_for_best_model=None,
+mp_parameters=,
+no_cuda=False,
+num_train_epochs=3.0,
+optim=adamw_hf,
+output_dir=./,
+overwrite_output_dir=True,
+past_index=-1,
+per_device_eval_batch_size=4,
+per_device_train_batch_size=8,
+precision=full,
+predict_with_generate=True,
+prediction_loss_only=False,
+push_to_hub=True,
+push_to_hub_model_id=None,
+push_to_hub_organization=None,
+push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
+ray_scope=last,
+remove_unused_columns=True,
+report_to=None,
+resume_from_checkpoint=None,
+run_name=None,
+save_on_each_node=False,
+save_steps=10000,
+save_strategy=steps,
+save_total_limit=None,
+seed=42,
+sharded_ddp=,
+skip_memory_metrics=True,
+sortish_sampler=False,
+tf32=None,
+torchdynamo=None,
+tpu_metrics_debug=False,
+tpu_num_cores=None,
+use_ipex=False,
+use_legacy_prediction_loop=False,
+warmup_ratio=0.0,
+warmup_steps=500,
+weight_decay=0.0,
+xpu_backend=None,
+)
+INFO:__main__:JAX devices: 8, matmul precision: default
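The arguments above pin down the optimisation schedule: lr_scheduler_type=linear with learning_rate=1e-4, warmup_steps=500 and max_steps=50000. A sketch of that schedule in plain optax, which is pinned in requirements.txt (the training script's own schedule helper may be structured differently):

import optax

peak_lr, warmup_steps, max_steps = 1e-4, 500, 50_000
# ramp linearly from 0 to the peak over the warmup, then decay linearly to 0
warmup = optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
decay = optax.linear_schedule(init_value=peak_lr, end_value=0.0, transition_steps=max_steps - warmup_steps)
lr_schedule = optax.join_schedules(schedules=[warmup, decay], boundaries=[warmup_steps])
print(lr_schedule(0), lr_schedule(500), lr_schedule(50_000))  # 0.0, 1e-4, 0.0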
wandb/run-20220828_084407-nbdgecc9/files/requirements.txt
ADDED
@@ -0,0 +1,167 @@
+absl-py==1.0.0
+aiohttp==3.8.1
+aiosignal==1.2.0
+anyio==3.5.0
+appdirs==1.4.4
+argon2-cffi-bindings==21.2.0
+argon2-cffi==21.3.0
+asttokens==2.0.5
+async-timeout==4.0.2
+attrs==21.4.0
+audioread==2.1.9
+babel==2.10.1
+backcall==0.2.0
+beautifulsoup4==4.11.1
+bleach==5.0.0
+certifi==2021.10.8
+cffi==1.15.0
+charset-normalizer==2.0.12
+chex==0.1.3
+click==8.1.3
+colorama==0.4.5
+commonmark==0.9.1
+cycler==0.11.0
+datasets==2.4.1.dev0
+debugpy==1.6.0
+decorator==5.1.1
+defusedxml==0.7.1
+dill==0.3.4
+dm-tree==0.1.7
+docker-pycreds==0.4.0
+entrypoints==0.4
+etils==0.6.0
+executing==0.8.3
+fastjsonschema==2.15.3
+filelock==3.6.0
+flatbuffers==2.0
+flax==0.5.3
+fonttools==4.33.3
+frozenlist==1.3.0
+fsspec==2022.3.0
+gitdb==4.0.9
+gitpython==3.1.27
+huggingface-hub==0.5.1
+idna==3.3
+ijson==3.1.4
+importlib-metadata==4.11.3
+importlib-resources==5.7.1
+iniconfig==1.1.1
+ipdb==0.13.9
+ipykernel==6.13.0
+ipython-genutils==0.2.0
+ipython==8.3.0
+jax==0.3.15
+jaxlib==0.3.15
+jedi==0.18.1
+jinja2==3.1.2
+jiwer==2.3.0
+joblib==1.1.0
+json5==0.9.6
+jsonschema==4.4.0
+jupyter-client==7.3.0
+jupyter-core==4.10.0
+jupyter-server==1.17.0
+jupyterlab-pygments==0.2.2
+jupyterlab-server==2.13.0
+jupyterlab==3.4.0
+kiwisolver==1.4.2
+librosa==0.9.1
+libtpu-nightly==0.1.dev20220722
+llvmlite==0.38.0
+markupsafe==2.1.1
+matplotlib-inline==0.1.3
+matplotlib==3.5.1
+mistune==0.8.4
+msgpack==1.0.3
+multidict==6.0.2
+multiprocess==0.70.12.2
+nbclassic==0.3.7
+nbclient==0.6.2
+nbconvert==6.5.0
+nbformat==5.4.0
+nest-asyncio==1.5.5
+nltk==3.7
+notebook-shim==0.1.0
+notebook==6.4.11
+numba==0.55.1
+numpy==1.21.0
+opt-einsum==3.3.0
+optax==0.1.2
+packaging==21.3
+pandas==1.4.2
+pandocfilters==1.5.0
+parso==0.8.3
+pathtools==0.1.2
+pexpect==4.8.0
+pickleshare==0.7.5
+pillow==9.1.0
+pip==20.0.2
+pkg-resources==0.0.0
+pluggy==1.0.0
+pooch==1.6.0
+prometheus-client==0.14.1
+promise==2.3
+prompt-toolkit==3.0.29
+protobuf==3.20.1
+psutil==5.9.0
+ptyprocess==0.7.0
+pure-eval==0.2.2
+py==1.11.0
+pyarrow==7.0.0
+pycparser==2.21
+pycryptodome==3.14.1
+pygments==2.12.0
+pyparsing==3.0.8
+pyrsistent==0.18.1
+pytest==7.1.2
+python-dateutil==2.8.2
+python-levenshtein==0.12.2
+pytz==2022.1
+pyyaml==6.0
+pyzmq==22.3.0
+regex==2022.4.24
+requests==2.27.1
+resampy==0.2.2
+responses==0.18.0
+rich==11.1.0
+rouge-score==0.1.2
+sacremoses==0.0.49
+scikit-learn==1.0.2
+scipy==1.8.0
+send2trash==1.8.0
+sentry-sdk==1.5.10
+seqeval==1.2.2
+setproctitle==1.2.3
+setuptools==44.0.0
+shortuuid==1.0.8
+six==1.16.0
+smmap==5.0.0
+sniffio==1.2.0
+soundfile==0.10.3.post1
+soupsieve==2.3.2.post1
+speechcolab==0.0.6a0
+stack-data==0.2.0
+tensorstore==0.1.21
+terminado==0.13.3
+threadpoolctl==3.1.0
+tinycss2==1.1.1
+tokenizers==0.12.1
+toml==0.10.2
+tomli==2.0.1
+toolz==0.11.2
+torch==1.11.0+cpu
+torchaudio==0.11.0+cpu
+tornado==6.1
+tqdm==4.64.0
+traitlets==5.1.1
+transformers==4.21.0.dev0
+typing-extensions==4.2.0
+urllib3==1.26.9
+wandb==0.12.15
+wcwidth==0.2.5
+webencodings==0.5.1
+websocket-client==1.3.2
+wheel==0.37.1
+xxhash==3.0.0
+yarl==1.7.2
+zipp==3.8.0
wandb/run-20220828_084407-nbdgecc9/files/wandb-metadata.json
ADDED
@@ -0,0 +1,59 @@
+{
+    "os": "Linux-5.11.0-1028-gcp-x86_64-with-glibc2.29",
+    "python": "3.8.10",
+    "heartbeatAt": "2022-08-28T08:44:08.435675",
+    "startedAt": "2022-08-28T08:44:07.234991",
+    "docker": null,
+    "cpu_count": 96,
+    "cuda": null,
+    "args": [
+        "--dataset_name=librispeech_asr",
+        "--model_name_or_path=sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
+        "--dataset_config_name=all",
+        "--train_split_name=train.clean.100+train.clean.360+train.other.500",
+        "--eval_split_name=validation.clean",
+        "--test_split_name=validation.other+test.clean+test.other",
+        "--text_column_name=text",
+        "--id_column_name=id",
+        "--output_dir=./",
+        "--wandb_project=librispeech_960h",
+        "--wandb_name=flax-wav2vec2-2-bart-large-ls-960h-black-box",
+        "--dataset_cache_dir=/home/sanchitgandhi/cache/huggingface/datasets",
+        "--per_device_train_batch_size=8",
+        "--per_device_eval_batch_size=4",
+        "--learning_rate=1e-4",
+        "--warmup_steps=500",
+        "--logging_steps=25",
+        "--max_steps=50000",
+        "--eval_steps=10000",
+        "--save_steps=10000",
+        "--generation_max_length=200",
+        "--generation_num_beams=5",
+        "--generation_length_penalty=1.2",
+        "--hidden_dropout=0.2",
+        "--activation_dropout=0.2",
+        "--feat_proj_dropout=0.2",
+        "--overwrite_output_dir",
+        "--gradient_checkpointing",
+        "--freeze_feature_encoder",
+        "--predict_with_generate",
+        "--do_lower_case",
+        "--do_eval",
+        "--do_train",
+        "--do_predict",
+        "--push_to_hub",
+        "--use_auth_token"
+    ],
+    "state": "running",
+    "program": "run_flax_speech_recognition_seq2seq.py",
+    "codePath": "run_flax_speech_recognition_seq2seq.py",
+    "git": {
+        "remote": "https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box",
+        "commit": "140399a622e2a82685fa4b9727f3d970b8bef9e0"
+    },
+    "email": "sanchit@huggingface.co",
+    "root": "/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box",
+    "host": "t1v-n-5966b949-w-0",
+    "username": "sanchitgandhi",
+    "executable": "/home/sanchitgandhi/hf/bin/python"
+}
wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json
ADDED
@@ -0,0 +1 @@
+{"_wandb": {"runtime": 3}}
wandb/run-20220828_084407-nbdgecc9/logs/debug-internal.log
ADDED
@@ -0,0 +1,144 @@
+2022-08-28 08:44:08,160 INFO MainThread:52894 [internal.py:wandb_internal():90] W&B internal server running at pid: 52894, started at: 2022-08-28 08:44:08.159804
+2022-08-28 08:44:08,162 INFO WriterThread:52894 [datastore.py:open_for_write():75] open: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/run-nbdgecc9.wandb
+2022-08-28 08:44:08,163 DEBUG SenderThread:52894 [sender.py:send():232] send: header
+2022-08-28 08:44:08,163 DEBUG SenderThread:52894 [sender.py:send():232] send: run
+2022-08-28 08:44:08,326 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: check_version
+2022-08-28 08:44:08,390 INFO SenderThread:52894 [dir_watcher.py:__init__():166] watching files in: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files
+2022-08-28 08:44:08,390 INFO SenderThread:52894 [sender.py:_start_run_threads():811] run started: nbdgecc9 with start time 1661676247
+2022-08-28 08:44:08,390 DEBUG SenderThread:52894 [sender.py:send():232] send: summary
+2022-08-28 08:44:08,391 INFO SenderThread:52894 [sender.py:_save_file():946] saving file wandb-summary.json with policy end
+2022-08-28 08:44:08,391 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: check_version
+2022-08-28 08:44:08,434 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: run_start
+2022-08-28 08:44:08,435 DEBUG HandlerThread:52894 [meta.py:__init__():35] meta init
+2022-08-28 08:44:08,435 DEBUG HandlerThread:52894 [meta.py:__init__():49] meta init done
+2022-08-28 08:44:08,435 DEBUG HandlerThread:52894 [meta.py:probe():209] probe
+2022-08-28 08:44:08,436 DEBUG HandlerThread:52894 [meta.py:_setup_git():199] setup git
+2022-08-28 08:44:08,456 DEBUG HandlerThread:52894 [meta.py:_setup_git():206] setup git done
+2022-08-28 08:44:08,456 DEBUG HandlerThread:52894 [meta.py:_save_pip():53] save pip
+2022-08-28 08:44:08,456 DEBUG HandlerThread:52894 [meta.py:_save_pip():67] save pip done
+2022-08-28 08:44:08,456 DEBUG HandlerThread:52894 [meta.py:probe():247] probe done
+2022-08-28 08:44:08,480 DEBUG SenderThread:52894 [sender.py:send():232] send: files
+2022-08-28 08:44:08,480 INFO SenderThread:52894 [sender.py:_save_file():946] saving file wandb-metadata.json with policy now
+2022-08-28 08:44:08,485 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: stop_status
+2022-08-28 08:44:08,485 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: stop_status
+2022-08-28 08:44:08,623 DEBUG SenderThread:52894 [sender.py:send():232] send: telemetry
+2022-08-28 08:44:08,935 INFO Thread-11 :52894 [upload_job.py:push():137] Uploaded file /tmp/tmpos_hhp45wandb/3f0zop6c-wandb-metadata.json
+2022-08-28 08:44:09,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_created():213] file/dir created: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log
+2022-08-28 08:44:09,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_created():213] file/dir created: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/requirements.txt
+2022-08-28 08:44:09,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_created():213] file/dir created: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json
+2022-08-28 08:44:09,393 INFO Thread-7 :52894 [dir_watcher.py:_on_file_created():213] file/dir created: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-metadata.json
+2022-08-28 08:44:09,690 DEBUG SenderThread:52894 [sender.py:send():232] send: telemetry
+2022-08-28 08:44:11,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_modified():226] file/dir modified: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log
+2022-08-28 08:44:12,001 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:12,002 DEBUG SenderThread:52894 [sender.py:send():232] send: exit
+2022-08-28 08:44:12,002 INFO SenderThread:52894 [sender.py:send_exit():368] handling exit code: 1
+2022-08-28 08:44:12,002 INFO SenderThread:52894 [sender.py:send_exit():370] handling runtime: 3
+2022-08-28 08:44:12,002 INFO SenderThread:52894 [sender.py:_save_file():946] saving file wandb-summary.json with policy end
+2022-08-28 08:44:12,002 INFO SenderThread:52894 [sender.py:send_exit():376] send defer
+2022-08-28 08:44:12,003 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:12,003 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:12,003 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 0
+2022-08-28 08:44:12,003 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:12,003 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 0
+2022-08-28 08:44:12,003 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 1
+2022-08-28 08:44:12,004 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:12,004 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 1
+2022-08-28 08:44:12,050 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:12,050 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 1
+2022-08-28 08:44:12,050 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 2
+2022-08-28 08:44:12,050 DEBUG SenderThread:52894 [sender.py:send():232] send: stats
+2022-08-28 08:44:12,051 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:12,051 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 2
+2022-08-28 08:44:12,051 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:12,051 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 2
+2022-08-28 08:44:12,051 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 3
+2022-08-28 08:44:12,051 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:12,051 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 3
+2022-08-28 08:44:12,051 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:12,051 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 3
+2022-08-28 08:44:12,051 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 4
+2022-08-28 08:44:12,051 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:12,052 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 4
+2022-08-28 08:44:12,052 DEBUG SenderThread:52894 [sender.py:send():232] send: summary
+2022-08-28 08:44:12,052 INFO SenderThread:52894 [sender.py:_save_file():946] saving file wandb-summary.json with policy end
+2022-08-28 08:44:12,052 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:12,052 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 4
+2022-08-28 08:44:12,052 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 5
+2022-08-28 08:44:12,052 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:12,052 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 5
+2022-08-28 08:44:12,052 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:12,052 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 5
+2022-08-28 08:44:12,104 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:12,199 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 6
+2022-08-28 08:44:12,200 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:12,200 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:12,200 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 6
+2022-08-28 08:44:12,200 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:12,200 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 6
+2022-08-28 08:44:12,200 INFO SenderThread:52894 [dir_watcher.py:finish():279] shutting down directory watcher
+2022-08-28 08:44:12,301 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:12,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_modified():226] file/dir modified: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log
+2022-08-28 08:44:12,392 INFO SenderThread:52894 [dir_watcher.py:_on_file_modified():226] file/dir modified: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json
+2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:_on_file_modified():226] file/dir modified: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/config.yaml
+2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():309] scan: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files
+2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log output.log
+2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-metadata.json wandb-metadata.json
+2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/config.yaml config.yaml
+2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/requirements.txt requirements.txt
+2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json wandb-summary.json
+2022-08-28 08:44:12,396 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 7
+2022-08-28 08:44:12,396 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:12,404 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:12,404 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 7
+2022-08-28 08:44:12,404 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:12,405 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 7
+2022-08-28 08:44:12,405 INFO SenderThread:52894 [file_pusher.py:finish():145] shutting down file pusher
+2022-08-28 08:44:12,501 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:12,501 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:12,603 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:12,603 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:12,704 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:12,705 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:12,806 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:12,806 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:12,860 INFO Thread-14 :52894 [upload_job.py:push():137] Uploaded file /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/requirements.txt
+2022-08-28 08:44:12,865 INFO Thread-12 :52894 [upload_job.py:push():137] Uploaded file /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log
+2022-08-28 08:44:12,866 INFO Thread-15 :52894 [upload_job.py:push():137] Uploaded file /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json
+2022-08-28 08:44:12,908 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:12,908 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:12,949 INFO Thread-13 :52894 [upload_job.py:push():137] Uploaded file /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/config.yaml
+2022-08-28 08:44:13,009 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:13,009 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:13,111 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:13,111 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:13,149 INFO Thread-6 :52894 [sender.py:transition_state():389] send defer: 8
+2022-08-28 08:44:13,149 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:13,149 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 8
+2022-08-28 08:44:13,150 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:13,150 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 8
+2022-08-28 08:44:13,213 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:13,272 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 9
+2022-08-28 08:44:13,272 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:13,273 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:13,273 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 9
+2022-08-28 08:44:13,273 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:13,273 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 9
+2022-08-28 08:44:13,273 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 10
+2022-08-28 08:44:13,274 DEBUG SenderThread:52894 [sender.py:send():232] send: final
+2022-08-28 08:44:13,274 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
+2022-08-28 08:44:13,274 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 10
+2022-08-28 08:44:13,274 DEBUG SenderThread:52894 [sender.py:send():232] send: footer
+2022-08-28 08:44:13,274 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
+2022-08-28 08:44:13,274 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 10
+2022-08-28 08:44:13,374 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
+2022-08-28 08:44:13,374 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
+2022-08-28 08:44:13,375 INFO SenderThread:52894 [file_pusher.py:join():150] waiting for file pusher
+2022-08-28 08:44:13,731 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: sampled_history
+2022-08-28 08:44:13,732 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: get_summary
+2022-08-28 08:44:13,732 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: shutdown
+2022-08-28 08:44:13,732 INFO HandlerThread:52894 [handler.py:finish():806] shutting down handler
+2022-08-28 08:44:14,274 INFO WriterThread:52894 [datastore.py:close():279] close: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/run-nbdgecc9.wandb
+2022-08-28 08:44:14,629 INFO SenderThread:52894 [sender.py:finish():1106] shutting down sender
+2022-08-28 08:44:14,629 INFO SenderThread:52894 [file_pusher.py:finish():145] shutting down file pusher
+2022-08-28 08:44:14,630 INFO SenderThread:52894 [file_pusher.py:join():150] waiting for file pusher
+2022-08-28 08:44:14,632 INFO MainThread:52894 [internal.py:handle_exit():80] Internal process exited
wandb/run-20220828_084407-nbdgecc9/logs/debug.log
ADDED
@@ -0,0 +1,131 @@
+2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_setup.py:_flush():75] Loading settings from /home/sanchitgandhi/.config/wandb/settings
+2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_setup.py:_flush():75] Loading settings from /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/settings
+2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_setup.py:_flush():75] Loading settings from environment variables: {}
+2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_setup.py:_flush():75] Inferring run settings from compute environment: {'program_relpath': 'run_flax_speech_recognition_seq2seq.py', 'program': 'run_flax_speech_recognition_seq2seq.py'}
+2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_init.py:_log_setup():437] Logging user logs to /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/logs/debug.log
+2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_init.py:_log_setup():438] Logging internal logs to /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/logs/debug-internal.log
+2022-08-28 08:44:07,237 INFO MainThread:51732 [wandb_init.py:init():471] calling init triggers
+2022-08-28 08:44:07,237 INFO MainThread:51732 [wandb_init.py:init():474] wandb.init called with sweep_config: {}
+config: {}
+2022-08-28 08:44:07,237 INFO MainThread:51732 [wandb_init.py:init():524] starting backend
+2022-08-28 08:44:07,237 INFO MainThread:51732 [backend.py:_multiprocessing_setup():97] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
+2022-08-28 08:44:07,348 INFO MainThread:51732 [backend.py:ensure_launched():217] starting backend process...
+2022-08-28 08:44:07,379 INFO MainThread:51732 [backend.py:ensure_launched():222] started backend process with pid: 52894
+2022-08-28 08:44:07,381 INFO MainThread:51732 [wandb_init.py:init():533] backend started and connected
+2022-08-28 08:44:07,392 INFO MainThread:51732 [wandb_init.py:init():597] updated telemetry
+2022-08-28 08:44:07,454 INFO MainThread:51732 [wandb_init.py:init():628] communicating run to backend with 30 second timeout
+2022-08-28 08:44:08,326 INFO MainThread:51732 [wandb_run.py:_on_init():1923] communicating current version
+2022-08-28 08:44:08,426 INFO MainThread:51732 [wandb_run.py:_on_init():1927] got version response upgrade_message: "wandb version 0.13.2 is available!  To upgrade, please run:\n $ pip install wandb --upgrade"
+
+2022-08-28 08:44:08,426 INFO MainThread:51732 [wandb_init.py:init():659] starting run threads in backend
+2022-08-28 08:44:08,485 INFO MainThread:51732 [wandb_run.py:_console_start():1897] atexit reg
+2022-08-28 08:44:08,485 INFO MainThread:51732 [wandb_run.py:_redirect():1770] redirect: SettingsConsole.REDIRECT
+2022-08-28 08:44:08,485 INFO MainThread:51732 [wandb_run.py:_redirect():1775] Redirecting console.
+2022-08-28 08:44:08,487 INFO MainThread:51732 [wandb_run.py:_redirect():1831] Redirects installed.
+2022-08-28 08:44:08,488 INFO MainThread:51732 [wandb_init.py:init():684] run started, returning control to user process
+2022-08-28 08:44:09,687 INFO MainThread:51732 [wandb_run.py:_atexit_cleanup():1866] got exitcode: 1
+2022-08-28 08:44:09,689 INFO MainThread:51732 [wandb_run.py:_restore():1838] restore
+2022-08-28 08:44:12,003 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 1
+}
+pusher_stats {
+uploaded_bytes: 2233
+total_bytes: 2233
+}
+
+2022-08-28 08:44:12,200 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 1
+}
+pusher_stats {
+uploaded_bytes: 2233
+total_bytes: 2233
+}
+
+2022-08-28 08:44:12,400 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 4
+}
+pusher_stats {
+uploaded_bytes: 2233
+total_bytes: 8131
+}
+
+2022-08-28 08:44:12,501 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 5
+}
+pusher_stats {
+uploaded_bytes: 2233
+total_bytes: 8157
+}
+
+2022-08-28 08:44:12,603 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 5
+}
+pusher_stats {
+uploaded_bytes: 8157
+total_bytes: 8157
+}
+
+2022-08-28 08:44:12,705 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 5
+}
+pusher_stats {
+uploaded_bytes: 8157
+total_bytes: 8157
+}
+
+2022-08-28 08:44:12,807 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 5
+}
+pusher_stats {
+uploaded_bytes: 8157
+total_bytes: 8157
+}
+
+2022-08-28 08:44:12,908 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 5
+}
+pusher_stats {
+uploaded_bytes: 8157
+total_bytes: 8157
+}
+
+2022-08-28 08:44:13,010 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 5
+}
+pusher_stats {
+uploaded_bytes: 8157
+total_bytes: 8157
+}
+
+2022-08-28 08:44:13,112 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 5
+}
+pusher_stats {
+uploaded_bytes: 8157
+total_bytes: 8157
+}
+
+2022-08-28 08:44:13,273 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
+wandb_count: 5
+}
+pusher_stats {
+uploaded_bytes: 8157
+total_bytes: 8157
+}
+
+2022-08-28 08:44:13,630 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: done: true
+exit_result {
+}
+file_counts {
+wandb_count: 5
+}
+pusher_stats {
+uploaded_bytes: 8157
+total_bytes: 8157
+}
+local_info {
+}
+
+2022-08-28 08:44:14,787 INFO MainThread:51732 [wandb_run.py:_footer_history_summary_info():3102] rendering history
+2022-08-28 08:44:14,787 INFO MainThread:51732 [wandb_run.py:_footer_history_summary_info():3134] rendering summary
+2022-08-28 08:44:14,789 INFO MainThread:51732 [wandb_run.py:_footer_sync_info():3057] logging synced files
wandb/run-20220828_084407-nbdgecc9/run-nbdgecc9.wandb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbe55913f815bc0f117700949409b7d3cb181dfad65b7969044117f11f40af4d
+size 3379
wandb/run-20220828_085247-2hx8pk65/files/config.yaml
ADDED
@@ -0,0 +1,28 @@
+wandb_version: 1
+
+_wandb:
+  desc: null
+  value:
+    cli_version: 0.12.15
+    framework: huggingface
+    huggingface_version: 4.21.0.dev0
+    is_jupyter_run: false
+    is_kaggle_kernel: false
+    python_version: 3.8.10
+    start_time: 1661676767
+    t:
+      1:
+      - 1
+      - 11
+      - 12
+      - 45
+      - 49
+      - 51
+      - 55
+      3:
+      - 13
+      4: 3.8.10
+      5: 0.12.15
+      6: 4.21.0.dev0
+      8:
+      - 5
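The table file below logs, for each eval utterance at step 10k, the reference transcription and the top five beam candidates. As a sketch (not in the commit), a single row can be scored with jiwer, the WER package pinned in requirements.txt:

import jiwer

label = (
    "he was in a fevered state of mind owing to the blight his wife's "
    "action threatened to cast upon his entire future"
)
beam_1 = (
    "he was in a fevered state of mind owing to the blight his "
    "action threatened to cast upon his entire future"
)
print(jiwer.wer(label, beam_1))  # ~0.045: one deletion ("wife's") in 22 reference words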
wandb/run-20220828_085247-2hx8pk65/files/media/table/eval/step_10k_10000_8b44e8a00a036a18ffdf.table.json
ADDED
@@ -0,0 +1 @@
+
{"columns": ["id", "label_str", "beam_1", "beam_2", "beam_3", "beam_4", "beam_5"], "data": [["2277-149896-0000", "he was in a fevered state of mind owing to the blight his wife's action threatened to cast upon his entire future", "he was in a fevered state of mind owing to the blight his action threatened to cast upon his entire future", "he was in a fevered state of mind owing to the blight his action threatened to cast upon his entire future", "he was in a fevered state of mind owing to the blight his action threatened to cast upon this entire future", "he was in a fevered state of mind owing to the blight his action threatened to cast upon his entire future", "he was in a fevered state of mind owing to the bight his action threatened to cast upon his entire future"], ["2277-149896-0001", "he would have to pay her the money which she would now regularly demand or there would be trouble it did not matter what he did", "he would have to pay her the money which she would now regularly demand or there would be trouble it did not matter what he did", "he would have to pay her the money which he would now regularly demand or there would be trouble it did not matter what he did", " he would have to pay her the money which she would now regularly demand or there would be trouble it did not matter what he did", "he would have to pay her the money which she'd now regularly demand or there would be trouble it did not matter what he did", "he would have to pay her the money which she could now regularly demand or there would be trouble it did not matter what he did"], ["2277-149896-0002", "hurstwood walked the floor mentally arranging the chief points of his situation", "hurstwood walked to the floor mentally arranging the chief points of his situation", "hirschwood walked to the floor mentally arranging the chief points of his situation", "herstwood walked to the floor mentally arranging the chief points of his situation", "hurstwood walked the floor mentally arranging the chief points of his situation", "hilstwood walked to the floor mentally arranging the chief points of his situation"], ["2277-149896-0003", "he also thought of his managerial position", "he also thought of his managerial position", "he also thought of this managerial position", " he also thought of his managerial position", "he also thought his managerial position", "here also thought of his managerial position"], ["2277-149896-0004", "how would the papers talk about it", "how would the papers talk about it", "how could the papers talk about it", "how'd the papers talk about it", "how did the papers talk about it", "how would the papers talk about it yes"], ["2277-149896-0005", "many little wrinkles gathered between his eyes as he contemplated this and his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this and his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this this and his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this in his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this and this brow moistened"], ["2277-149896-0006", "he could arrange that satisfactorily for carrie would be glad to wait if necessary", "he could arrange that satisfactorily for carrie would be glad to wait if necessary", "he could arrange that satisfactorily for carey would be glad to wait if necessary", "he 
could arrange that satisfactorily for carry would be glad to wait if necessary", "he could arrange this satisfactorily for carrie would be glad to wait if necessary", "he could arrange the satisfactorily for carrie would be glad to wait if necessary"], ["2277-149896-0007", "he would see how things turned out to morrow and then he would talk to her they were going to meet as usual", "he would see how things turned out tomorrow and then he would talk to her they were going to meet as usual", "he would see how things turned out to morrow and then he would talk to her they were going to meet as usual", "he would see how things turned out today and then he would talk to her they were going to meet as usual", "he would see how things turn out tomorrow and then he would talk to her they were going to meet as usual", "he could see how things turned out tomorrow and then he would talk to her they were going to meet as usual"], ["2277-149896-0008", "for some reason he felt as if something might come that way and was relieved when all the envelopes had been scanned and nothing suspicious noticed", "for some reason he felt as if something might come that way and was relieved when all the envelopes had been scanned and nothing suspicious noticed", "for some reason he felt as if something might come that way and was relieved when all the envelops had been scanned and nothing suspicious noticed", "for some reason he felt as if something might come this way and was relieved when all the envelopes had been scanned and nothing suspicious noticed", "from some reason he felt as if something might come that way and was relieved when all the envelopes had been scanned and nothing suspicious noticed", "for some reason he felt as if nothing might come that way and was relieved when all the envelopes had been scanned and nothing suspicious noticed"], ["2277-149896-0009", "while the danger had not lessened it had not as yet materialised and with him no news was good news", "while the danger had not lessened it had not as yet materialized and with him no news was good news", "while the danger had not lessened it had not as yet materialised and with him no news was good news", "while danger had not lessened it had not as yet materialized and with him no news was good news", "whilst the danger had not lessened it had not as yet materialized and with him no news was good news", "while the danger had not lessen it had not as yet materialized and with him no news was good news"], ["2277-149896-0010", "so little did he consider drouet that it never once occurred to him to worry about his finding out", "so little did he consider drouet that it never once occurred to him to worry about his finding out", "so little did he consider drue that it never once occurred to him to worry about his finding out", "so little did he consider drua that it never once occurred to him to worry about his finding out", "so little did he consider drura that it never once occurred to him to worry about his finding out", "so little did he consider druecca that it never once occurred to him to worry about his finding out"], ["2277-149896-0011", "he grew restless as he ruminated and then decided that perhaps it was nothing", "he grew restless as he ruminated and then decided that perhaps it was nothing", "he grew restless as he ruminated and then decided that perhaps it was nothing", "he grew restless as he ruminated and then decide that perhaps it was nothing", " he grew restless as he ruminated and then decided that perhaps it was nothing", "he 
grew restless as he ruminating and then decided that perhaps it was nothing"], ["2277-149896-0012", "she had not been able to get away this morning", "she had not been able to get away this morning", "he had not been able to get away this morning", " she had not been able to get away this morning", "the she had not been able to get away this morning", "the had not been able to get away this morning"], ["2277-149896-0013", "he would get one to day it would probably be on his desk when he got back he would look for it at once", "he would get one today it would probably be on his desk when he got back he would look for it at once", "he would get one to day it would probably be on his desk when he got back he would look for it at once", "he would get one tomorrow it would probably be on his desk when he got back he would look for it at once", "he could get one to day it would probably be on his desk when he got back he would look for it at once", "he could get one today it would probably be on his desk when he got back he would look for it at once"], ["2277-149896-0014", "after a time he gave up waiting and drearily headed for the madison car", "after a time he gave up waiting and drearily headed for the madison car", "after a time he gave up waiting and drearily headed for the mattinson car", "after a time he gave up waiting and drearily headed for the mattison car", "after a time he gave up waiting and drearily headed for the madeison car", "after a time he gave up waiting and drearily headed for the madezons car"], ["2277-149896-0015", "he went in and examined his letters but there was nothing from carrie", "he went in and examined his letters but there was nothing from carrie", "he went in and examined his letters but there was nothing from carey", "he went in and examined his letters but there was nothing from carry", "he went in and examined his letters but there was nothing from kerry", "he went in and examined his letters but there was nothing from cary"], ["2277-149896-0016", "fortunately there was nothing from his wife either", "fortunately there was nothing from his wife either", "fortunately there was nothing from his wife either", " fortunately there was nothing from his wife either", "fortunately there was nothing from this wife either", "forfortunately there was nothing from his wife either"], ["2277-149896-0017", "at one thirty he went to rector's for lunch and when he returned a messenger was waiting for him", "at one thirty he went to rector's for lunch and when he returned a messenger was waiting for him", "at one thirty he went to rector's for lunch and when he returned a messenger was waiting for him", "at one thirty he went to rectors for lunch and when he returned a messenger was waiting for him", "at one thirty he went to rectory's for lunch and when he returned a messenger was waiting for him", " at one thirty he went to rector's for lunch and when he returned a messenger was waiting for him"], ["2277-149896-0018", "his first impulse was to write but four words in reply go to the devil", "his first impulse was to write but four words in reply go to the devil", " his first impulse was to write but four words in reply go to the devil", "his first impulse was to write but four words in reply go to the devil", "his first impulse was to write but four words and reply go to the devil", "his first impulses was to write but four words in reply go to the devil"], ["2277-149896-0019", "but he compromised by telling the boy that there would be no reply", "but he compromised by 
telling the boy that there would be no reply", "but hecompromised by telling the boy that there would be no reply", "but he comprised by telling the boy that there would be no reply", "but he compromise by telling the boy that there would be no reply", " but he compromised by telling the boy that there would be no reply"], ["2277-149896-0020", "then he sat down in his chair and gazed without seeing contemplating the result of his work", "then he sat down in his chair and gazed without seeing contemplating the result of his work", "then he sat down in his chair and gazed without seeing contemplating the result of this work", " then he sat down in his chair and gazed without seeing contemplating the result of his work", "than he sat down in his chair and gazed without seeing contemplating the result of his work", "then he sat down in his chair and gazed without seeing contemplating the results of his work"], ["2277-149896-0021", "what would she do about that the confounded wretch", "what would she do about that the confounded wretch", "what would you do about that the confounded wretch", " what would she do about that the confounded wretch", "what could she do about that the confounded wretch", "but what would she do about that the confounded wretch"], ["2277-149896-0022", "later however his old discretion asserted itself", "later however his old discretion asserted itself", " later however his old discretion asserted itself", "later however his old discretion ascertained itself", "later however this old discretion asserted itself", "late however his old discretion asserted itself"], ["2277-149896-0023", "something had to be done a climax was near and she would not sit idle", "something had to be done a climax was near and she would not sit idle", " something had to be done a climax was near and she would not sit idle", "something had to be done the climax was near and she would not sit idle", "anything had to be done a climax was near and she would not sit idle", "nothing had to be done a climax was near and she would not sit idle"], ["2277-149896-0024", "he knew her well enough to know that when she had decided upon a plan she would follow it up", "he knew her well enough to know that when she had decided upon a plan she would follow it up", "he knew her well enough to know that when she had decided upon the plan she would follow it up", "he knew her well enough to know that when she decided upon a plan she would follow it up", "he knew her well enough to know that when she had decided upon a plan she should follow it up", "he knew her well enough to know when she had decided upon a plan she would follow it up"], ["2277-149896-0025", "he arose from his chair and went and looked out into the street", "he arose from his chair and went and looked out into the street", "he arose from his chair and went to looked out into the street", "he arose from his chair and went into the street", "he arose from his chair and went and looked out into this street", "he arose from his chair and went and looked out in the street"], ["2277-149896-0026", "the long drizzle had begun pedestrians had turned up collars and trousers at the bottom", "the long drizzle had begun pedestrians had turned up collars and trousers at the bottom", "the long drizzle had begun petersians had turned up collars and trousers at the bottom", "the long drizzle had begun pedestrians had turned up collars and trousers at the bottom", " the long drizzle had begun pedestrians had turned up collars and trousers at the bottom", "long 
drizzle had begun pedestrians had turned up collars and trousers at the bottom"], ["2277-149896-0027", "hurstwood almost exclaimed out loud at the insistency of this thing", "hurstwood almost exclaimed out loud at the insistency of this thing", "hirschwood almost exclaimed out loud at the insistency of this thing", "hurstwood almost exclaimed aloud at the insistency of this thing", "hilstwood almost exclaimed out loud at the insistency of this thing", "hurstwood almost exclaimed out loud at the insincerity of this thing"], ["2277-149896-0028", "he put on his hat and looked around for his umbrella", "he put on his hat and looked around for his umbrella", "he put on his hat and looked round for his umbrella", "he put on his hat and looked around for this umbrella", " he put on his hat and looked around for his umbrella", "he put on his hat and looked about for his umbrella"], ["2277-149896-0029", "he would have some arrangement of this thing", "he would have some arrangement of this thing", "he will have some arrangement of this thing", "he'd have some arrangement of this thing", "he could have some arrangement of this thing", " he would have some arrangement of this thing"], ["2277-149896-0030", "he began to wish that he had compromised in some way or other that he had sent the money perhaps he could do it up here", "he began to wish that he had compromised in some way or other that he had sent the money perhaps he could do it up here", "he began to wish that he had compromised in some way or another that he had sent the money perhaps he could do it up here", "he began to wish that he had comprised in some way or other that he had sent the money perhaps he could do it up here", "he began to wish he had compromised in some way or other that he had sent the money perhaps he could do it up here", "he began to wish that he had coordinated in some way or other that he had sent the money perhaps he could do it up here"], ["2277-149896-0031", "he would go in and see anyhow he would have no row", "he would go in and see anyhow he would have no row", "he would go in and see anyhow he would have no rowl", "he would go in and see anyhow he would have no rue", "he could go in and see anyhow he would have no row", "he would go in and see anyhow he would have no raoul"], ["2277-149896-0032", "by the time he reached his own street he was keenly alive to the difficulties of his situation and wished over and over that some solution would offer itself that he could see his way out", "by the time he reached his own street he was keenly alive to the difficulties of his situation and wished over and over that some solution would offer itself that he could see his way out", "by the time he reached his own street he was keenly alive to the difficulties of this situation and wished over and over that some solution would offer itself that he could see his way out", " by the time he reached his own street he was keenly alive to the difficulties of his situation and wished over and over that some solution would offer itself that he could see his way out", "by the time he reached his own street he was keenly alive to the troubles of his situation and wished over and over that some solution would offer itself that he could see his way out", "by the time he reached his own street he was keenly alive to the difficulties of his situation and wished over and over that some solution would offer itself that he could see this way out"], ["2277-149896-0033", "then he rang the bell no answer", "then he rang the bell no answer", 
"than he rang the bell no answer", " then he rang the bell no answer", "this he rang the bell no answer", "there he rang the bell no answer"], ["2277-149896-0034", "he rang again this time harder still no answer", "he rang again this time harder still no answer", "he wrang again this time harder still no answer", "he ring again this time harder still no answer", "he rang again this time harder still no answer", "he ringed again this time harder still no answer"], ["2277-149897-0000", "when hurstwood got back to his office again he was in a greater quandary than ever", "when hurstwood got back to his office again he was in a greater quandary than ever", "when hurstwood got back to his office again he was in a greater quondary than ever", "when hurstwood got back to his office again he was in a greater quandy than ever", "when hurstwood got back to his office again he was in a greater quadry than ever", "when hurstwood got back to his office again he was in a greater quorum than ever"], ["2277-149897-0001", "he could hardly realise how it had all come about", "he could hardly realize how it had all come about", "he could hardly realise how it had all come about", "he could hardly realize how it has all come about", "he could hardly realize how it had also come about", " he could hardly realize how it had all come about"], ["2277-149897-0002", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him that morning", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him that morning", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him this morning", " no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him that morning", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him at morning", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him the morning"], ["2277-149897-0003", "he saw that in the excitement of recent events he had not formulated a plan upon that score", "he saw that in the excitement of recent events he had not formulated a plan upon that score", "he saw that in the excitement of recent events he had not formulated a plan upon the score", "he saw that in the excitement of recent events he had not formulated a plan upon this score", "he saw that in the excitement of recent events he had not communicated a plan upon that score", "he saw that in the excited of recent events he had not formulated a plan upon that score"], ["2277-149897-0004", "he was getting some vague comfort out of a good cigar but it was no panacea for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no panacea for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no panegas for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no panatia for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no pennesia for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no panacea for the illness which affected him"], ["2277-149897-0005", "it was with great opposition after two or three hours of the most urgent mental affirmation and denial that at last he got an envelope 
placed in it the requested amount and slowly sealed it up", "it was with great opposition after two or three hours of the most urgent mental affirmation and denial that at last he got an envelope placed in it the requested amount and slowly sealed it up", "it was with great opposition after two or three hours of the most urgent mental affirmation and denial that at last he got an envelope placed in its requested amount and slowly sealed it up", "it was with great opposition after two or three hours of the most urgent mental affirmative and denial that at last he got an envelope placed in it the requested amount and slowly sealed it up", " it was with great opposition after two or three hours of the most urgent mental affirmation and denial that at last he got an envelope placed in it the requested amount and slowly sealed it up", "it was with great opposition after two or three hours of the most urgent mental affirmation and declaration that at last he got an envelope placed in it the requested amount and slowly sealed it up"], ["2277-149897-0006", "then he called harry the boy of all work around the place", "then he called harry the boy of all work around the place", "then he called harry the boy of all work round the place", " then he called harry the boy of all work around the place", "now he called harry the boy of all work around the place", "this he called harry the boy of all work around the place"], ["2277-149897-0007", "you take this to this address he said handing him the envelope and give it to missus hurstwood yes sir said the boy", "you take this to this address he said handing him the envelope and give it to missus hurstwood yes sir said the boy", " you take this to this address he said handing him the envelope and give it to missus hurstwood yes sir said the boy", "you take this to this addressed he said handing him the envelope and give it to missus hurstwood yes sir said the boy", "you take this to his address he said handing him the envelope and give it to missus hurstwood yes sir said the boy", "you take this to this address he said handing him the envelope and give it to missus hurstwood yes sir said the boy"], ["2277-149897-0008", "any answer i guess not", "any answer i guess not", "any answer i guess not i guess not", "any answer i guess not a guess not", "any answer i guessed not", "any answer i guess not the"], ["2277-149897-0009", "the boy hastened away and the manager fell to his musings", "the boy hastened away and the manager fell to his musings", "the boy hasted away and the manager fell to his musings", "the boy hastily away and the manager fell to his musings", " the boy hastened away and the manager fell to his musings", "the boy hastened away and the manager fell into his musings"], ["2277-149897-0010", "he was beaten for to night and he might just as well make the best of it", "he was beaten for to night and he might just as well make the best of it", "he was beaten for tonight and he might just as well make the best of it", "he was beaten for tomorrow and he might just as well make the best of it", "he was beaten for today and he might just as well make the best of it", "he was beaten for to night and he might just as well make the best of it"], ["2277-149897-0011", "she would take the envelope and know that she had triumphed", "she would take the envelope and know that she had triumphed", " she would take the envelope and know that she had triumphed", "he would take the envelope and know that she had triumphed", "the would take the envelope and know that 
she had triumphed", "we would take the envelope and know that she had triumphed"], ["2277-149897-0012", "if he only had that letter back he wouldn't send it", "if he only had that letter back he wouldn't send it", "if he only had that letter back he won't send it", "if he only had that letter back he couldn't send it", "if he only had that letter back he didn't send it", "if he only had that letter back he wuzn't send it"], ["2277-149897-0013", "for relief he arose and joined in conversation with a few friends who were drinking", "for relief he arose and joined in the conversation with a few friends who were drinking", "for relief he arose and joined in the conversation with the few friends who were drinking", "for relief he arose in the conversation with a few friends who were drinking", "for relief he arose in the conversation with the few friends who were drinking", "for relief he arose and joined in a conversation with a few friends who were drinking"], ["2277-149897-0014", "all the time his thoughts would run out to his home and see the scene being therein enacted", "all the time his thoughts would run out to his home and see the scene being therein enacted", "all this time his thoughts would run out to his home and see the scene being therein enacted", "all the time his thoughts would run out to his home and see the scene being therein enacted", " all the time his thoughts would run out to his home and see the scene being therein enacted", "all the time his thoughts would run out to his home and see the scene being herein enacted"]]}
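The table above is the wandb evaluation table logged at step 10k: one row per utterance, holding what appears to be the utterance id, the reference transcript, and the generated beams. As a minimal sketch (not part of this commit; the [id, reference, top beam, remaining beams] column layout is an assumption inferred from the rows), the table can be re-scored offline with jiwer, which is pinned in the requirements.txt diff below:

import json
from jiwer import wer  # jiwer==2.3.0 per requirements.txt

# Path of the logged table, taken from this commit's file listing.
path = ("wandb/run-20220828_085247-2hx8pk65/files/media/table/eval/"
        "step_10k_10000_8b44e8a00a036a18ffdf.table.json")

with open(path) as f:
    table = json.load(f)

# Assumed layout: row[0] = utterance id, row[1] = reference transcript,
# row[2] = highest-scoring beam, row[3:] = the remaining beams.
references = [row[1] for row in table["data"]]
predictions = [row[2] for row in table["data"]]
print(f"step-10k WER: {100 * wer(references, predictions):.2f}%")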
wandb/run-20220828_085247-2hx8pk65/files/output.log
ADDED
The diff for this file is too large to render.
See raw diff
wandb/run-20220828_085247-2hx8pk65/files/requirements.txt
ADDED
@@ -0,0 +1,167 @@
+absl-py==1.0.0
+aiohttp==3.8.1
+aiosignal==1.2.0
+anyio==3.5.0
+appdirs==1.4.4
+argon2-cffi-bindings==21.2.0
+argon2-cffi==21.3.0
+asttokens==2.0.5
+async-timeout==4.0.2
+attrs==21.4.0
+audioread==2.1.9
+babel==2.10.1
+backcall==0.2.0
+beautifulsoup4==4.11.1
+bleach==5.0.0
+certifi==2021.10.8
+cffi==1.15.0
+charset-normalizer==2.0.12
+chex==0.1.3
+click==8.1.3
+colorama==0.4.5
+commonmark==0.9.1
+cycler==0.11.0
+datasets==2.4.1.dev0
+debugpy==1.6.0
+decorator==5.1.1
+defusedxml==0.7.1
+dill==0.3.4
+dm-tree==0.1.7
+docker-pycreds==0.4.0
+entrypoints==0.4
+etils==0.6.0
+executing==0.8.3
+fastjsonschema==2.15.3
+filelock==3.6.0
+flatbuffers==2.0
+flax==0.5.3
+fonttools==4.33.3
+frozenlist==1.3.0
+fsspec==2022.3.0
+gitdb==4.0.9
+gitpython==3.1.27
+huggingface-hub==0.5.1
+idna==3.3
+ijson==3.1.4
+importlib-metadata==4.11.3
+importlib-resources==5.7.1
+iniconfig==1.1.1
+ipdb==0.13.9
+ipykernel==6.13.0
+ipython-genutils==0.2.0
+ipython==8.3.0
+jax==0.3.15
+jaxlib==0.3.15
+jedi==0.18.1
+jinja2==3.1.2
+jiwer==2.3.0
+joblib==1.1.0
+json5==0.9.6
+jsonschema==4.4.0
+jupyter-client==7.3.0
+jupyter-core==4.10.0
+jupyter-server==1.17.0
+jupyterlab-pygments==0.2.2
+jupyterlab-server==2.13.0
+jupyterlab==3.4.0
+kiwisolver==1.4.2
+librosa==0.9.1
+libtpu-nightly==0.1.dev20220722
+llvmlite==0.38.0
+markupsafe==2.1.1
+matplotlib-inline==0.1.3
+matplotlib==3.5.1
+mistune==0.8.4
+msgpack==1.0.3
+multidict==6.0.2
+multiprocess==0.70.12.2
+nbclassic==0.3.7
+nbclient==0.6.2
+nbconvert==6.5.0
+nbformat==5.4.0
+nest-asyncio==1.5.5
+nltk==3.7
+notebook-shim==0.1.0
+notebook==6.4.11
+numba==0.55.1
+numpy==1.21.0
+opt-einsum==3.3.0
+optax==0.1.2
+packaging==21.3
+pandas==1.4.2
+pandocfilters==1.5.0
+parso==0.8.3
+pathtools==0.1.2
+pexpect==4.8.0
+pickleshare==0.7.5
+pillow==9.1.0
+pip==20.0.2
+pkg-resources==0.0.0
+pluggy==1.0.0
+pooch==1.6.0
+prometheus-client==0.14.1
+promise==2.3
+prompt-toolkit==3.0.29
+protobuf==3.20.1
+psutil==5.9.0
+ptyprocess==0.7.0
+pure-eval==0.2.2
+py==1.11.0
+pyarrow==7.0.0
+pycparser==2.21
+pycryptodome==3.14.1
+pygments==2.12.0
+pyparsing==3.0.8
+pyrsistent==0.18.1
+pytest==7.1.2
+python-dateutil==2.8.2
+python-levenshtein==0.12.2
+pytz==2022.1
+pyyaml==6.0
+pyzmq==22.3.0
+regex==2022.4.24
+requests==2.27.1
+resampy==0.2.2
+responses==0.18.0
+rich==11.1.0
+rouge-score==0.1.2
+sacremoses==0.0.49
+scikit-learn==1.0.2
+scipy==1.8.0
+send2trash==1.8.0
+sentry-sdk==1.5.10
+seqeval==1.2.2
+setproctitle==1.2.3
+setuptools==44.0.0
+shortuuid==1.0.8
+six==1.16.0
+smmap==5.0.0
+sniffio==1.2.0
+soundfile==0.10.3.post1
+soupsieve==2.3.2.post1
+speechcolab==0.0.6a0
+stack-data==0.2.0
+tensorstore==0.1.21
+terminado==0.13.3
+threadpoolctl==3.1.0
+tinycss2==1.1.1
+tokenizers==0.12.1
+toml==0.10.2
+tomli==2.0.1
+toolz==0.11.2
+torch==1.11.0+cpu
+torchaudio==0.11.0+cpu
+tornado==6.1
+tqdm==4.64.0
+traitlets==5.1.1
+transformers==4.21.0.dev0
+typing-extensions==4.2.0
+urllib3==1.26.9
+wandb==0.12.15
+wcwidth==0.2.5
+webencodings==0.5.1
+websocket-client==1.3.2
+wheel==0.37.1
+xxhash==3.0.0
+yarl==1.7.2
+zipp==3.8.0
wandb/run-20220828_085247-2hx8pk65/files/wandb-metadata.json
ADDED
@@ -0,0 +1,59 @@
+{
+    "os": "Linux-5.11.0-1028-gcp-x86_64-with-glibc2.29",
+    "python": "3.8.10",
+    "heartbeatAt": "2022-08-28T08:52:48.553677",
+    "startedAt": "2022-08-28T08:52:47.513374",
+    "docker": null,
+    "cpu_count": 96,
+    "cuda": null,
+    "args": [
+        "--dataset_name=librispeech_asr",
+        "--model_name_or_path=sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
+        "--dataset_config_name=all",
+        "--train_split_name=train.clean.100+train.clean.360+train.other.500",
+        "--eval_split_name=validation.clean",
+        "--test_split_name=validation.other+test.clean+test.other",
+        "--text_column_name=text",
+        "--id_column_name=id",
+        "--output_dir=./",
+        "--wandb_project=librispeech_960h",
+        "--wandb_name=flax-wav2vec2-2-bart-large-ls-960h-black-box",
+        "--dataset_cache_dir=/home/sanchitgandhi/cache/huggingface/datasets",
+        "--per_device_train_batch_size=8",
+        "--per_device_eval_batch_size=4",
+        "--learning_rate=1e-4",
+        "--warmup_steps=500",
+        "--logging_steps=25",
+        "--max_steps=50000",
+        "--eval_steps=10000",
+        "--save_steps=10000",
+        "--generation_max_length=200",
+        "--generation_num_beams=5",
+        "--generation_length_penalty=1.2",
+        "--hidden_dropout=0.2",
+        "--activation_dropout=0.2",
+        "--feat_proj_dropout=0.2",
+        "--overwrite_output_dir",
+        "--gradient_checkpointing",
+        "--freeze_feature_encoder",
+        "--predict_with_generate",
+        "--do_lower_case",
+        "--do_eval",
+        "--do_train",
+        "--do_predict",
+        "--push_to_hub",
+        "--use_auth_token"
+    ],
+    "state": "running",
+    "program": "run_flax_speech_recognition_seq2seq.py",
+    "codePath": "run_flax_speech_recognition_seq2seq.py",
+    "git": {
+        "remote": "https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box",
+        "commit": "140399a622e2a82685fa4b9727f3d970b8bef9e0"
+    },
+    "email": "sanchit@huggingface.co",
+    "root": "/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box",
+    "host": "t1v-n-5966b949-w-0",
+    "username": "sanchitgandhi",
+    "executable": "/home/sanchitgandhi/hf/bin/python"
+}
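Since wandb-metadata.json records the program and its argv verbatim, the training invocation can be rebuilt from it. A minimal sketch (an illustration, not one of this repo's scripts; run_librispeech.sh is the actual launcher):

import json

# Rebuild the launch command from the metadata logged above.
with open("wandb/run-20220828_085247-2hx8pk65/files/wandb-metadata.json") as f:
    meta = json.load(f)

# meta["program"] is run_flax_speech_recognition_seq2seq.py; meta["args"]
# holds the flags exactly as they were passed on the command line. "python"
# stands in for the venv interpreter recorded under meta["executable"].
print(" ".join(["python", meta["program"], *meta["args"]]))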
wandb/run-20220828_085247-2hx8pk65/files/wandb-summary.json
ADDED
@@ -0,0 +1 @@
{"train/decoder_grad_norm": 0.5876523852348328, "train/decoder_param_norm": 1057.45703125, "train/encoder_grad_norm": 0.38440409302711487, "train/encoder_param_norm": 2316.3564453125, "train/grad_norm": 0.7022120952606201, "layer_grad_norm/": {"decoder": {"model": {"decoder": {"embed_positions": {"embedding": 0.10323784500360489}, "embed_tokens": {"embedding": 0.16808316111564636}, "layernorm_embedding": {"bias": 0.03703528642654419, "scale": 0.060806743800640106}, "layers": {"FlaxBartDecoderLayers": {"encoder_attn": {"k_proj": {"bias": 1.75027107616188e-05, "kernel": 0.030463965609669685}, "out_proj": {"bias": 0.024376848712563515, "kernel": 0.08760593086481094}, "q_proj": {"bias": 0.0016024636570364237, "kernel": 0.034829143434762955}, "v_proj": {"bias": 0.04787713289260864, "kernel": 0.07169140875339508}}, "encoder_attn_layer_norm": {"bias": 0.03529948368668556, "scale": 0.0380270853638649}, "fc1": {"bias": 0.013248836621642113, "kernel": 0.33658137917518616}, "fc2": {"bias": 0.030859898775815964, "kernel": 0.2677602767944336}, "final_layer_norm": {"bias": 0.1120176762342453, "scale": 0.05825764685869217}, "self_attn": {"k_proj": {"bias": 6.563532224390656e-06, "kernel": 0.047542572021484375}, "out_proj": {"bias": 0.068998321890831, "kernel": 0.15063460171222687}, "q_proj": {"bias": 0.003958633169531822, "kernel": 0.05425203591585159}, "v_proj": {"bias": 0.07329808175563812, "kernel": 0.198069229722023}}, "self_attn_layer_norm": {"bias": 0.023308640345931053, "scale": 0.030806636437773705}}}}}}, "encoder": {"adapter": {"layers": {"0": {"conv": {"bias": 0.04864540696144104, "kernel": 0.133722722530365}}, "1": {"conv": {"bias": 0.04470941796898842, "kernel": 0.09400613605976105}}, "2": {"conv": {"bias": 0.05692768096923828, "kernel": 0.1417897492647171}}}}, "encoder": {"layer_norm": {"bias": 0.16896693408489227, "scale": 0.08190205693244934}, "layers": {"FlaxWav2Vec2EncoderLayers": {"attention": {"k_proj": {"bias": 5.699832854588749e-06, "kernel": 0.03451818600296974}, "out_proj": {"bias": 0.004949449095875025, "kernel": 0.0711507499217987}, "q_proj": {"bias": 0.006232084706425667, "kernel": 0.03630899265408516}, "v_proj": {"bias": 0.021894006058573723, "kernel": 0.0699479877948761}}, "feed_forward": {"intermediate_dense": {"bias": 0.010628663003444672, "kernel": 0.08824677765369415}, "output_dense": {"bias": 0.0046577295288443565, "kernel": 0.07864432781934738}}, "final_layer_norm": {"bias": 0.053700175136327744, "scale": 0.06233147531747818}, "layer_norm": {"bias": 0.09289932250976562, "scale": 0.07505689561367035}}}, "pos_conv_embed": {"conv": {"bias": 0.001811191556043923, "weight_g": 0.04629991203546524, "weight_v": 0.05902065336704254}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.01040860079228878, "scale": 0.009696024470031261}, "projection": {"bias": 0.002452271291986108, "kernel": 0.06397733092308044}}, "masked_spec_embed": 0.0}}, 
"layer_param_norm/": {"decoder": {"model": {"decoder": {"embed_positions": {"embedding": 58.57985305786133}, "embed_tokens": {"embedding": 628.9428100585938}, "layernorm_embedding": {"bias": 2.4099645614624023, "scale": 13.944293022155762}, "layers": {"FlaxBartDecoderLayers": {"encoder_attn": {"k_proj": {"bias": 47.96258544921875, "kernel": 330.1817932128906}, "out_proj": {"bias": 6.197176456451416, "kernel": 226.72259521484375}, "q_proj": {"bias": 20.796918869018555, "kernel": 337.1412658691406}, "v_proj": {"bias": 3.727905035018921, "kernel": 230.9994354248047}}, "encoder_attn_layer_norm": {"bias": 10.427277565002441, "scale": 56.38846206665039}, "fc1": {"bias": 25.47351837158203, "kernel": 339.21954345703125}, "fc2": {"bias": 7.897115707397461, "kernel": 243.82398986816406}, "final_layer_norm": {"bias": 4.000784873962402, "scale": 63.70562744140625}, "self_attn": {"k_proj": {"bias": 59.513954162597656, "kernel": 278.91595458984375}, "out_proj": {"bias": 3.8339650630950928, "kernel": 131.7364501953125}, "q_proj": {"bias": 32.09528732299805, "kernel": 282.0332336425781}, "v_proj": {"bias": 2.626418352127075, "kernel": 140.15884399414062}}, "self_attn_layer_norm": {"bias": 8.851421356201172, "scale": 84.72929382324219}}}}}}, "encoder": {"adapter": {"layers": {"0": {"conv": {"bias": 0.5224539637565613, "kernel": 58.06698226928711}}, "1": {"conv": {"bias": 0.6238547563552856, "kernel": 55.76792907714844}}, "2": {"conv": {"bias": 0.8834269046783447, "kernel": 55.83806610107422}}}}, "encoder": {"layer_norm": {"bias": 0.2885725498199463, "scale": 4.501636505126953}, "layers": {"FlaxWav2Vec2EncoderLayers": {"attention": {"k_proj": {"bias": 19.359642028808594, "kernel": 551.2367553710938}, "out_proj": {"bias": 16.819419860839844, "kernel": 703.838134765625}, "q_proj": {"bias": 40.78517532348633, "kernel": 543.7529907226562}, "v_proj": {"bias": 15.60958194732666, "kernel": 695.4569091796875}}, "feed_forward": {"intermediate_dense": {"bias": 24.515138626098633, "kernel": 1373.99365234375}, "output_dense": {"bias": 20.76974868774414, "kernel": 1299.6435546875}}, "final_layer_norm": {"bias": 32.476783752441406, "scale": 141.65736389160156}, "layer_norm": {"bias": 7.329699516296387, "scale": 45.53441619873047}}}, "pos_conv_embed": {"conv": {"bias": 15.283638954162598, "weight_g": 21.029205322265625, "weight_v": 212.9462127685547}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.5982058644294739, "kernel": 8.08896541595459}, "layer_norm": {"bias": 10.069783210754395, "scale": 10.451257705688477}}, "1": {"conv": {"bias": 4.74075174331665, "kernel": 90.8435287475586}, "layer_norm": {"bias": 6.922820091247559, "scale": 19.5467586517334}}, "2": {"conv": {"bias": 6.7732415199279785, "kernel": 146.13897705078125}, "layer_norm": {"bias": 9.044225692749023, "scale": 19.424888610839844}}, "3": {"conv": {"bias": 5.224758148193359, "kernel": 159.10508728027344}, "layer_norm": {"bias": 8.319666862487793, "scale": 17.64743423461914}}, "4": {"conv": {"bias": 4.434978008270264, "kernel": 157.35813903808594}, "layer_norm": {"bias": 9.193974494934082, "scale": 15.562357902526855}}, "5": {"conv": {"bias": 5.297643661499023, "kernel": 131.1835174560547}, "layer_norm": {"bias": 10.735219955444336, "scale": 13.812533378601074}}, "6": {"conv": {"bias": 5.615579128265381, "kernel": 136.41822814941406}, "layer_norm": {"bias": 12.515308380126953, "scale": 11.152680397033691}}}}, "feature_projection": {"layer_norm": {"bias": 9.422893524169922, "scale": 27.84585189819336}, "projection": {"bias": 
4.289161682128906, "kernel": 88.30554962158203}}, "masked_spec_embed": 26.247730255126953}}, "train/learning_rate": 8.086059824563563e-05, "train/loss": 0.1043805480003357, "train/param_norm": 2546.3154296875, "_timestamp": 1661727380, "_runtime": 50613, "_step": 9975}
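The summary above stores per-layer gradient and parameter norms as nested dicts under "layer_grad_norm/" and "layer_param_norm/". A minimal sketch (an illustration with a hypothetical flatten helper, not repo tooling) that flattens the tree to list the largest gradient norms at this step:

import json

def flatten(tree, prefix=""):
    # Yield (path, value) pairs from a nested dict of scalar norms.
    for key, value in tree.items():
        if isinstance(value, dict):
            yield from flatten(value, f"{prefix}{key}/")
        else:
            yield f"{prefix}{key}", value

with open("wandb/run-20220828_085247-2hx8pk65/files/wandb-summary.json") as f:
    summary = json.load(f)

grad_norms = dict(flatten(summary["layer_grad_norm/"]))
for name, value in sorted(grad_norms.items(), key=lambda kv: kv[1], reverse=True)[:5]:
    print(f"{value:.4f}  {name}")

The all-zero feature_extractor entries in the summary are consistent with the --freeze_feature_encoder flag recorded in the run args above.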
wandb/run-20220828_085247-2hx8pk65/logs/debug-internal.log
ADDED
The diff for this file is too large to render.
See raw diff
wandb/run-20220828_085247-2hx8pk65/logs/debug.log
ADDED
@@ -0,0 +1,25 @@
+2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_setup.py:_flush():75] Loading settings from /home/sanchitgandhi/.config/wandb/settings
+2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_setup.py:_flush():75] Loading settings from /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/settings
+2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_setup.py:_flush():75] Loading settings from environment variables: {}
+2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_setup.py:_flush():75] Inferring run settings from compute environment: {'program_relpath': 'run_flax_speech_recognition_seq2seq.py', 'program': 'run_flax_speech_recognition_seq2seq.py'}
+2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:_log_setup():437] Logging user logs to /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_085247-2hx8pk65/logs/debug.log
+2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:_log_setup():438] Logging internal logs to /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_085247-2hx8pk65/logs/debug-internal.log
+2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:init():471] calling init triggers
+2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:init():474] wandb.init called with sweep_config: {}
+config: {}
+2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:init():524] starting backend
+2022-08-28 08:52:47,515 INFO MainThread:53859 [backend.py:_multiprocessing_setup():97] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
+2022-08-28 08:52:47,546 INFO MainThread:53859 [backend.py:ensure_launched():217] starting backend process...
+2022-08-28 08:52:47,572 INFO MainThread:53859 [backend.py:ensure_launched():222] started backend process with pid: 54989
+2022-08-28 08:52:47,574 INFO MainThread:53859 [wandb_init.py:init():533] backend started and connected
+2022-08-28 08:52:47,585 INFO MainThread:53859 [wandb_init.py:init():597] updated telemetry
+2022-08-28 08:52:47,649 INFO MainThread:53859 [wandb_init.py:init():628] communicating run to backend with 30 second timeout
+2022-08-28 08:52:48,479 INFO MainThread:53859 [wandb_run.py:_on_init():1923] communicating current version
+2022-08-28 08:52:48,543 INFO MainThread:53859 [wandb_run.py:_on_init():1927] got version response upgrade_message: "wandb version 0.13.2 is available! To upgrade, please run:\n $ pip install wandb --upgrade"
+
+2022-08-28 08:52:48,543 INFO MainThread:53859 [wandb_init.py:init():659] starting run threads in backend
+2022-08-28 08:52:48,582 INFO MainThread:53859 [wandb_run.py:_console_start():1897] atexit reg
+2022-08-28 08:52:48,582 INFO MainThread:53859 [wandb_run.py:_redirect():1770] redirect: SettingsConsole.REDIRECT
+2022-08-28 08:52:48,583 INFO MainThread:53859 [wandb_run.py:_redirect():1775] Redirecting console.
+2022-08-28 08:52:48,585 INFO MainThread:53859 [wandb_run.py:_redirect():1831] Redirects installed.
+2022-08-28 08:52:48,585 INFO MainThread:53859 [wandb_init.py:init():684] run started, returning control to user process
wandb/run-20220828_085247-2hx8pk65/run-2hx8pk65.wandb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:094e92de49c7288ddfac32754880e9359cb30d1406e2d3bdff46b108a8c651aa
+size 4469804