sanchit-gandhi's picture
Saving weights and logs of epoch 1
8319971
raw
history blame
No virus
38 kB
05/07/2022 18:18:56 - INFO - __main__ - Training/evaluation parameters FlaxSeq2SeqTrainingArguments(
_n_gpu=-1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
debug=,
deepspeed=None,
disable_tqdm=None,
do_eval=True,
do_predict=True,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=5,
evaluation_strategy=no,
final_generation_max_length=50,
final_generation_num_beams=2,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
generation_length_penalty=1,
generation_max_length=40,
generation_num_beams=1,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_model_id=None,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=0.0003,
length_column_name=input_length,
load_best_model_at_end=False,
local_rank=-1,
log_level=passive,
log_level_replica=passive,
log_on_each_node=True,
logging_dir=None,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=1,
logging_strategy=steps,
lr_scheduler_type=linear,
matmul_precision=default,
max_grad_norm=1.0,
max_steps=15,
metric_for_best_model=None,
mp_parameters=,
no_cuda=False,
num_train_epochs=3.0,
optim=adamw_hf,
output_dir=./,
overwrite_output_dir=True,
past_index=-1,
per_device_eval_batch_size=2,
per_device_train_batch_size=2,
precision=full,
predict_with_generate=True,
prediction_loss_only=False,
push_to_hub=True,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
remove_unused_columns=True,
report_to=None,
resume_from_checkpoint=None,
run_name=None,
save_on_each_node=False,
save_steps=5,
save_strategy=steps,
save_total_limit=1,
seed=42,
sharded_ddp=,
skip_memory_metrics=True,
sortish_sampler=False,
tf32=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_legacy_prediction_loop=False,
warmup_ratio=0.0,
warmup_steps=500,
weight_decay=0.0,
xpu_backend=None,
)
05/07/2022 18:18:56 - INFO - __main__ - JAX devices: 8, matmul precision: default
05/07/2022 18:18:57 - WARNING - datasets.builder - Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)
05/07/2022 18:18:57 - WARNING - datasets.builder - Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)
loading configuration file ./config.json
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.
Model config SpeechEncoderDecoderConfig {
"_name_or_path": "./",
"architectures": [
"SpeechEncoderDecoderModel"
],
"decoder": {
"_name_or_path": "",
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": true,
"architectures": null,
"attention_dropout": 0.1,
"bad_words_ids": null,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"classifier_dropout": 0.0,
"cross_attention_hidden_size": null,
"d_model": 16,
"decoder_attention_heads": 4,
"decoder_ffn_dim": 4,
"decoder_layerdrop": 0.0,
"decoder_layers": 2,
"decoder_start_token_id": 2,
"diversity_penalty": 0.0,
"do_sample": false,
"dropout": 0.1,
"early_stopping": false,
"encoder_attention_heads": 4,
"encoder_ffn_dim": 4,
"encoder_layerdrop": 0.0,
"encoder_layers": 2,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": 2,
"fuse_matmuls": false,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"init_std": 0.02,
"is_decoder": true,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 100,
"min_length": 0,
"model_type": "bart",
"no_repeat_ngram_size": 0,
"num_beam_groups": 1,
"num_beams": 1,
"num_hidden_layers": 2,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 1,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"scale_embedding": false,
"sep_token_id": null,
"task_specific_params": null,
"temperature": 1.0,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.18.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false,
"use_cache": true,
"use_scan": true,
"vocab_size": 1000
},
"decoder_start_token_id": 0,
"encoder": {
"_name_or_path": "",
"activation_dropout": 0.1,
"adapter_kernel_size": 3,
"adapter_stride": 2,
"add_adapter": true,
"add_cross_attention": false,
"apply_spec_augment": true,
"architectures": null,
"attention_dropout": 0.1,
"bad_words_ids": null,
"bos_token_id": 1,
"chunk_size_feed_forward": 0,
"classifier_proj_size": 256,
"codevector_dim": 256,
"contrastive_logits_temperature": 0.1,
"conv_bias": false,
"conv_dim": [
32,
32,
32
],
"conv_kernel": [
8,
8,
8
],
"conv_stride": [
4,
4,
4
],
"cross_attention_hidden_size": null,
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"decoder_start_token_id": null,
"diversity_loss_weight": 0.1,
"diversity_penalty": 0.0,
"do_sample": false,
"do_stable_layer_norm": true,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.0,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"fuse_matmuls": false,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 16,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"initializer_range": 0.02,
"intermediate_size": 20,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-05,
"layerdrop": 0.0,
"length_penalty": 1.0,
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_prob": 0.1,
"max_length": 20,
"min_length": 0,
"model_type": "wav2vec2",
"no_repeat_ngram_size": 0,
"num_adapter_layers": 3,
"num_attention_heads": 2,
"num_beam_groups": 1,
"num_beams": 1,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 2,
"num_conv_pos_embeddings": 16,
"num_feat_extract_layers": 3,
"num_hidden_layers": 4,
"num_negatives": 10,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_size": 16,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 0,
"prefix": null,
"problem_type": null,
"proj_codevector_dim": 256,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"task_specific_params": null,
"tdnn_dilation": [
1,
2,
3,
1,
1
],
"tdnn_dim": [
512,
512,
512,
512,
1500
],
"tdnn_kernel": [
5,
3,
3,
1,
1
],
"temperature": 1.0,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.18.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false,
"use_scan": true,
"use_weighted_layer_sum": false,
"vocab_size": 32,
"xvector_output_dim": 512
},
"eos_token_id": 2,
"is_encoder_decoder": true,
"max_length": 40,
"model_type": "speech-encoder-decoder",
"pad_token_id": 1,
"processor_class": "Wav2Vec2Processor",
"tie_word_embeddings": false,
"transformers_version": null,
"use_cache": false
}
loading feature extractor configuration file ./preprocessor_config.json
Feature extractor Wav2Vec2FeatureExtractor {
"do_normalize": true,
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
"feature_size": 1,
"padding_side": "right",
"padding_value": 0.0,
"return_attention_mask": false,
"sampling_rate": 16000
}
Didn't find file ./added_tokens.json. We won't load it.
loading file ./vocab.json
loading file ./merges.txt
loading file ./tokenizer.json
loading file None
loading file ./special_tokens_map.json
loading file ./tokenizer_config.json
loading weights file ./flax_model.msgpack
05/07/2022 18:18:58 - WARNING - datasets.builder - Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)
05/07/2022 18:18:59 - WARNING - datasets.builder - Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)
All model checkpoint weights were used when initializing FlaxSpeechEncoderDecoderModel.
All the weights of FlaxSpeechEncoderDecoderModel were initialized from the model checkpoint at ./.
If your task is similar to the task the model of the checkpoint was trained on, you can already use FlaxSpeechEncoderDecoderModel for predictions without further training.
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 194.44ba/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 737.52ba/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 693.16ba/s]
Feature extractor saved in ./preprocessor_config.json
tokenizer config file saved in ./tokenizer_config.json
Special tokens file saved in ./special_tokens_map.json
Configuration saved in ./config.json
loading feature extractor configuration file ./preprocessor_config.json
loading configuration file ./config.json
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.
Model config SpeechEncoderDecoderConfig {
"_name_or_path": "./",
"architectures": [
"SpeechEncoderDecoderModel"
],
"decoder": {
"_name_or_path": "",
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": true,
"architectures": null,
"attention_dropout": 0.1,
"bad_words_ids": null,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"classifier_dropout": 0.0,
"cross_attention_hidden_size": null,
"d_model": 16,
"decoder_attention_heads": 4,
"decoder_ffn_dim": 4,
"decoder_layerdrop": 0.0,
"decoder_layers": 2,
"decoder_start_token_id": 2,
"diversity_penalty": 0.0,
"do_sample": false,
"dropout": 0.1,
"early_stopping": false,
"encoder_attention_heads": 4,
"encoder_ffn_dim": 4,
"encoder_layerdrop": 0.0,
"encoder_layers": 2,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": 2,
"fuse_matmuls": false,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"init_std": 0.02,
"is_decoder": true,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 100,
"min_length": 0,
"model_type": "bart",
"no_repeat_ngram_size": 0,
"num_beam_groups": 1,
"num_beams": 1,
"num_hidden_layers": 2,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 1,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"scale_embedding": false,
"sep_token_id": null,
"task_specific_params": null,
"temperature": 1.0,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.18.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false,
"use_cache": true,
"use_scan": true,
"vocab_size": 1000
},
"decoder_start_token_id": 0,
"encoder": {
"_name_or_path": "",
"activation_dropout": 0.1,
"adapter_kernel_size": 3,
"adapter_stride": 2,
"add_adapter": true,
"add_cross_attention": false,
"apply_spec_augment": true,
"architectures": null,
"attention_dropout": 0.1,
"bad_words_ids": null,
"bos_token_id": 1,
"chunk_size_feed_forward": 0,
"classifier_proj_size": 256,
"codevector_dim": 256,
"contrastive_logits_temperature": 0.1,
"conv_bias": false,
"conv_dim": [
32,
32,
32
],
"conv_kernel": [
8,
8,
8
],
"conv_stride": [
4,
4,
4
],
"cross_attention_hidden_size": null,
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"decoder_start_token_id": null,
"diversity_loss_weight": 0.1,
"diversity_penalty": 0.0,
"do_sample": false,
"do_stable_layer_norm": true,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.0,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"fuse_matmuls": false,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 16,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"initializer_range": 0.02,
"intermediate_size": 20,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-05,
"layerdrop": 0.0,
"length_penalty": 1.0,
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_prob": 0.1,
"max_length": 20,
"min_length": 0,
"model_type": "wav2vec2",
"no_repeat_ngram_size": 0,
"num_adapter_layers": 3,
"num_attention_heads": 2,
"num_beam_groups": 1,
"num_beams": 1,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 2,
"num_conv_pos_embeddings": 16,
"num_feat_extract_layers": 3,
"num_hidden_layers": 4,
"num_negatives": 10,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_size": 16,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 0,
"prefix": null,
"problem_type": null,
"proj_codevector_dim": 256,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"task_specific_params": null,
"tdnn_dilation": [
1,
2,
3,
1,
1
],
"tdnn_dim": [
512,
512,
512,
512,
1500
],
"tdnn_kernel": [
5,
3,
3,
1,
1
],
"temperature": 1.0,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.18.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false,
"use_scan": true,
"use_weighted_layer_sum": false,
"vocab_size": 32,
"xvector_output_dim": 512
},
"eos_token_id": 2,
"is_encoder_decoder": true,
"max_length": 40,
"model_type": "speech-encoder-decoder",
"pad_token_id": 1,
"processor_class": "Wav2Vec2Processor",
"tie_word_embeddings": false,
"transformers_version": null,
"use_cache": false
}
loading feature extractor configuration file ./preprocessor_config.json
Feature extractor Wav2Vec2FeatureExtractor {
"do_normalize": true,
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
"feature_size": 1,
"padding_side": "right",
"padding_value": 0.0,
"return_attention_mask": false,
"sampling_rate": 16000
}
Didn't find file ./added_tokens.json. We won't load it.
loading file ./vocab.json
loading file ./merges.txt
loading file ./tokenizer.json
loading file None
loading file ./special_tokens_map.json
loading file ./tokenizer_config.json
2022-05-07 18:19:09.719228: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2022-05-07 18:19:09.719261: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-172908eb2439798c.arrow
05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-172908eb2439798c.arrow
05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-3788df45b822e09d.arrow
05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-836619e0a5bdc111.arrow
05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-760f29b7172d4ca5.arrow
05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-760f29b7172d4ca5.arrow
05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-c95f128933ad1116.arrow
05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-0b5b21388629ee07.arrow
05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-885308cf79812c65.arrow
05/07/2022 18:19:11 - WARNING - huggingface_hub.repository - /home/sanchitgandhi/flax-dummy/./ is already a clone of https://huggingface.co/sanchit-gandhi/flax-dummy. Make sure you pull the latest changes with `repo.git_pull()`.
05/07/2022 18:19:12 - INFO - __main__ - ***** Running training *****
05/07/2022 18:19:12 - INFO - __main__ - Num examples = 70
05/07/2022 18:19:12 - INFO - __main__ - Num Epochs = 4
05/07/2022 18:19:12 - INFO - __main__ - Instantaneous batch size per device = 2
05/07/2022 18:19:12 - INFO - __main__ - Num gradient accumulation steps = 1
05/07/2022 18:19:12 - INFO - __main__ - Total train batch size (w. parallel & distributed) = 16
05/07/2022 18:19:12 - INFO - __main__ - Total optimization steps = 15
05/07/2022 18:19:12 - INFO - __main__ - Gradient checkpointing: False
05/07/2022 18:19:12 - INFO - __main__ - Use scan: True
05/07/2022 18:19:12 - INFO - __main__ - Fuse matmuls: False
/home/sanchitgandhi/flax-dummy/./ is already a clone of https://huggingface.co/sanchit-gandhi/flax-dummy. Make sure you pull the latest changes with `repo.git_pull()`.
Epoch ... (1/4): 0%| | 0/4 [00:00<?, ?it/s]
Epoch ... (1/4): 0%| | 0/4 [00:01<?, ?it/s]
Traceback (most recent call last):
File "run_flax_speech_recognition_seq2seq.py", line 1396, in <module>
main()
File "run_flax_speech_recognition_seq2seq.py", line 1303, in main
state, train_metric = p_train_step(state, batch)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/api.py", line 1979, in cache_miss
out_tree, out_flat = f_pmapped_(*args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/api.py", line 1855, in pmap_f
out = pxla.xla_pmap(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/core.py", line 1797, in bind
return map_bind(self, fun, *args, **params)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/core.py", line 1828, in map_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/core.py", line 1800, in process
return trace.process_map(self, fun, tracers, params)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/core.py", line 614, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 792, in xla_pmap_impl
compiled_fun, fingerprint = parallel_callable(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/linear_util.py", line 272, in memoized_fun
ans = call(fun, *args)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 820, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 993, in lower_parallel_callable
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 900, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1798, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1775, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "run_flax_speech_recognition_seq2seq.py", line 1086, in train_step
(loss, num_labels), grad = grad_fn(to_dtype(state.params), batch)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/api.py", line 954, in value_and_grad_f
ans, vjp_py, aux = _vjp(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/api.py", line 2413, in _vjp
out_primal, out_vjp, aux = ad.vjp(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/ad.py", line 121, in vjp
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/ad.py", line 106, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 592, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "run_flax_speech_recognition_seq2seq.py", line 1073, in compute_loss
logits = state.apply_fn(
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_speech_encoder_decoder.py", line 765, in __call__
return self.module.apply(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 1162, in apply
return apply(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/scope.py", line 806, in wrapper
y = fn(root, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 1446, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_speech_encoder_decoder.py", line 340, in __call__
decoder_outputs = self.decoder(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 752, in __call__
outputs = self.model(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 718, in __call__
return self.decoder(*args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 500, in __call__
outputs = self.layers(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 389, in __call__
hidden_states, _ = scan_with_axes(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 312, in wrapped_fn
ret = trafo_fn(module_scopes, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 213, in wrapper
y, out_variable_groups_xs_t = fn(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 286, in wrapper
y = fn(scopes, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 305, in core_fn
res = fn(cloned, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 312, in wrapped_fn
ret = trafo_fn(module_scopes, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 218, in wrapper
y, out_variable_groups_xs_t = fn(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 766, in inner
broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/axes_scan.py", line 135, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 592, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/axes_scan.py", line 111, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 750, in scanned
c, y = fn(scope, c, *args)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 305, in core_fn
res = fn(cloned, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 308, in __call__
hidden_states, self_attn_weights = self.self_attn(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 206, in __call__
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 357, in _broadcast_to
raise ValueError(msg.format(arr_shape, shape))
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Incompatible shapes for broadcasting: (2, 1, 1, 127) and requested shape (2, 1, 100, 100)
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "run_flax_speech_recognition_seq2seq.py", line 1396, in <module>
main()
File "run_flax_speech_recognition_seq2seq.py", line 1303, in main
state, train_metric = p_train_step(state, batch)
File "run_flax_speech_recognition_seq2seq.py", line 1086, in train_step
(loss, num_labels), grad = grad_fn(to_dtype(state.params), batch)
File "run_flax_speech_recognition_seq2seq.py", line 1073, in compute_loss
logits = state.apply_fn(
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_speech_encoder_decoder.py", line 765, in __call__
return self.module.apply(
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_speech_encoder_decoder.py", line 340, in __call__
decoder_outputs = self.decoder(
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 752, in __call__
outputs = self.model(
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 718, in __call__
return self.decoder(*args, **kwargs)
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 500, in __call__
outputs = self.layers(
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 389, in __call__
hidden_states, _ = scan_with_axes(
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/axes_scan.py", line 135, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/axes_scan.py", line 111, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 308, in __call__
hidden_states, self_attn_weights = self.self_attn(
File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 206, in __call__
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 357, in _broadcast_to
raise ValueError(msg.format(arr_shape, shape))
ValueError: Incompatible shapes for broadcasting: (2, 1, 1, 127) and requested shape (2, 1, 100, 100)