05/07/2022 18:18:56 - INFO - __main__ - Training/evaluation parameters FlaxSeq2SeqTrainingArguments( _n_gpu=-1, adafactor=False, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, bf16=False, bf16_full_eval=False, data_seed=None, dataloader_drop_last=False, dataloader_num_workers=0, dataloader_pin_memory=True, ddp_bucket_cap_mb=None, ddp_find_unused_parameters=None, debug=, deepspeed=None, disable_tqdm=None, do_eval=True, do_predict=True, do_train=True, eval_accumulation_steps=None, eval_delay=0, eval_steps=5, evaluation_strategy=no, final_generation_max_length=50, final_generation_num_beams=2, fp16=False, fp16_backend=auto, fp16_full_eval=False, fp16_opt_level=O1, generation_length_penalty=1, generation_max_length=40, generation_num_beams=1, gradient_accumulation_steps=1, gradient_checkpointing=False, greater_is_better=None, group_by_length=False, half_precision_backend=auto, hub_model_id=None, hub_strategy=every_save, hub_token=, ignore_data_skip=False, label_names=None, label_smoothing_factor=0.0, learning_rate=0.0003, length_column_name=input_length, load_best_model_at_end=False, local_rank=-1, log_level=passive, log_level_replica=passive, log_on_each_node=True, logging_dir=None, logging_first_step=False, logging_nan_inf_filter=True, logging_steps=1, logging_strategy=steps, lr_scheduler_type=linear, matmul_precision=default, max_grad_norm=1.0, max_steps=15, metric_for_best_model=None, mp_parameters=, no_cuda=False, num_train_epochs=3.0, optim=adamw_hf, output_dir=./, overwrite_output_dir=True, past_index=-1, per_device_eval_batch_size=2, per_device_train_batch_size=2, precision=full, predict_with_generate=True, prediction_loss_only=False, push_to_hub=True, push_to_hub_model_id=None, push_to_hub_organization=None, push_to_hub_token=, remove_unused_columns=True, report_to=None, resume_from_checkpoint=None, run_name=None, save_on_each_node=False, save_steps=5, save_strategy=steps, save_total_limit=1, seed=42, sharded_ddp=, skip_memory_metrics=True, sortish_sampler=False, tf32=None, tpu_metrics_debug=False, tpu_num_cores=None, use_legacy_prediction_loop=False, warmup_ratio=0.0, warmup_steps=500, weight_decay=0.0, xpu_backend=None, ) 05/07/2022 18:18:56 - INFO - __main__ - JAX devices: 8, matmul precision: default 05/07/2022 18:18:57 - WARNING - datasets.builder - Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b) 05/07/2022 18:18:57 - WARNING - datasets.builder - Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b) loading configuration file ./config.json You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2. Model config SpeechEncoderDecoderConfig { "_name_or_path": "./", "architectures": [ "SpeechEncoderDecoderModel" ], "decoder": { "_name_or_path": "", "activation_dropout": 0.0, "activation_function": "gelu", "add_cross_attention": true, "architectures": null, "attention_dropout": 0.1, "bad_words_ids": null, "bos_token_id": 0, "chunk_size_feed_forward": 0, "classifier_dropout": 0.0, "cross_attention_hidden_size": null, "d_model": 16, "decoder_attention_heads": 4, "decoder_ffn_dim": 4, "decoder_layerdrop": 0.0, "decoder_layers": 2, "decoder_start_token_id": 2, "diversity_penalty": 0.0, "do_sample": false, "dropout": 0.1, "early_stopping": false, "encoder_attention_heads": 4, "encoder_ffn_dim": 4, "encoder_layerdrop": 0.0, "encoder_layers": 2, "encoder_no_repeat_ngram_size": 0, "eos_token_id": 2, "exponential_decay_length_penalty": null, "finetuning_task": null, "forced_bos_token_id": null, "forced_eos_token_id": 2, "fuse_matmuls": false, "gradient_checkpointing": false, "id2label": { "0": "LABEL_0", "1": "LABEL_1" }, "init_std": 0.02, "is_decoder": true, "is_encoder_decoder": false, "label2id": { "LABEL_0": 0, "LABEL_1": 1 }, "length_penalty": 1.0, "max_length": 20, "max_position_embeddings": 100, "min_length": 0, "model_type": "bart", "no_repeat_ngram_size": 0, "num_beam_groups": 1, "num_beams": 1, "num_hidden_layers": 2, "num_return_sequences": 1, "output_attentions": false, "output_hidden_states": false, "output_scores": false, "pad_token_id": 1, "prefix": null, "problem_type": null, "pruned_heads": {}, "remove_invalid_values": false, "repetition_penalty": 1.0, "return_dict": true, "return_dict_in_generate": false, "scale_embedding": false, "sep_token_id": null, "task_specific_params": null, "temperature": 1.0, "tie_encoder_decoder": false, "tie_word_embeddings": true, "tokenizer_class": null, "top_k": 50, "top_p": 1.0, "torch_dtype": null, "torchscript": false, "transformers_version": "4.18.0.dev0", "typical_p": 1.0, "use_bfloat16": false, "use_cache": true, "use_scan": true, "vocab_size": 1000 }, "decoder_start_token_id": 0, "encoder": { "_name_or_path": "", "activation_dropout": 0.1, "adapter_kernel_size": 3, "adapter_stride": 2, "add_adapter": true, "add_cross_attention": false, "apply_spec_augment": true, "architectures": null, "attention_dropout": 0.1, "bad_words_ids": null, "bos_token_id": 1, "chunk_size_feed_forward": 0, "classifier_proj_size": 256, "codevector_dim": 256, "contrastive_logits_temperature": 0.1, "conv_bias": false, "conv_dim": [ 32, 32, 32 ], "conv_kernel": [ 8, 8, 8 ], "conv_stride": [ 4, 4, 4 ], "cross_attention_hidden_size": null, "ctc_loss_reduction": "sum", "ctc_zero_infinity": false, "decoder_start_token_id": null, "diversity_loss_weight": 0.1, "diversity_penalty": 0.0, "do_sample": false, "do_stable_layer_norm": true, "early_stopping": false, "encoder_no_repeat_ngram_size": 0, "eos_token_id": 2, "exponential_decay_length_penalty": null, "feat_extract_activation": "gelu", "feat_extract_dropout": 0.0, "feat_extract_norm": "layer", "feat_proj_dropout": 0.0, "feat_quantizer_dropout": 0.0, "final_dropout": 0.0, "finetuning_task": null, "forced_bos_token_id": null, "forced_eos_token_id": null, "fuse_matmuls": false, "gradient_checkpointing": false, "hidden_act": "gelu", "hidden_dropout": 0.1, "hidden_dropout_prob": 0.1, "hidden_size": 16, "id2label": { "0": "LABEL_0", "1": "LABEL_1" }, "initializer_range": 0.02, "intermediate_size": 20, "is_decoder": false, "is_encoder_decoder": false, "label2id": { "LABEL_0": 0, "LABEL_1": 1 }, "layer_norm_eps": 1e-05, "layerdrop": 0.0, "length_penalty": 1.0, "mask_feature_length": 10, "mask_feature_min_masks": 0, "mask_feature_prob": 0.0, "mask_time_length": 10, "mask_time_min_masks": 2, "mask_time_prob": 0.1, "max_length": 20, "min_length": 0, "model_type": "wav2vec2", "no_repeat_ngram_size": 0, "num_adapter_layers": 3, "num_attention_heads": 2, "num_beam_groups": 1, "num_beams": 1, "num_codevector_groups": 2, "num_codevectors_per_group": 320, "num_conv_pos_embedding_groups": 2, "num_conv_pos_embeddings": 16, "num_feat_extract_layers": 3, "num_hidden_layers": 4, "num_negatives": 10, "num_return_sequences": 1, "output_attentions": false, "output_hidden_size": 16, "output_hidden_states": false, "output_scores": false, "pad_token_id": 0, "prefix": null, "problem_type": null, "proj_codevector_dim": 256, "pruned_heads": {}, "remove_invalid_values": false, "repetition_penalty": 1.0, "return_dict": true, "return_dict_in_generate": false, "sep_token_id": null, "task_specific_params": null, "tdnn_dilation": [ 1, 2, 3, 1, 1 ], "tdnn_dim": [ 512, 512, 512, 512, 1500 ], "tdnn_kernel": [ 5, 3, 3, 1, 1 ], "temperature": 1.0, "tie_encoder_decoder": false, "tie_word_embeddings": true, "tokenizer_class": null, "top_k": 50, "top_p": 1.0, "torch_dtype": null, "torchscript": false, "transformers_version": "4.18.0.dev0", "typical_p": 1.0, "use_bfloat16": false, "use_scan": true, "use_weighted_layer_sum": false, "vocab_size": 32, "xvector_output_dim": 512 }, "eos_token_id": 2, "is_encoder_decoder": true, "max_length": 40, "model_type": "speech-encoder-decoder", "pad_token_id": 1, "processor_class": "Wav2Vec2Processor", "tie_word_embeddings": false, "transformers_version": null, "use_cache": false } loading feature extractor configuration file ./preprocessor_config.json Feature extractor Wav2Vec2FeatureExtractor { "do_normalize": true, "feature_extractor_type": "Wav2Vec2FeatureExtractor", "feature_size": 1, "padding_side": "right", "padding_value": 0.0, "return_attention_mask": false, "sampling_rate": 16000 } Didn't find file ./added_tokens.json. We won't load it. loading file ./vocab.json loading file ./merges.txt loading file ./tokenizer.json loading file None loading file ./special_tokens_map.json loading file ./tokenizer_config.json loading weights file ./flax_model.msgpack 05/07/2022 18:18:58 - WARNING - datasets.builder - Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b) 05/07/2022 18:18:59 - WARNING - datasets.builder - Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b) All model checkpoint weights were used when initializing FlaxSpeechEncoderDecoderModel. All the weights of FlaxSpeechEncoderDecoderModel were initialized from the model checkpoint at ./. If your task is similar to the task the model of the checkpoint was trained on, you can already use FlaxSpeechEncoderDecoderModel for predictions without further training. 100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 194.44ba/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 737.52ba/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 693.16ba/s] Feature extractor saved in ./preprocessor_config.json tokenizer config file saved in ./tokenizer_config.json Special tokens file saved in ./special_tokens_map.json Configuration saved in ./config.json loading feature extractor configuration file ./preprocessor_config.json loading configuration file ./config.json You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2. Model config SpeechEncoderDecoderConfig { "_name_or_path": "./", "architectures": [ "SpeechEncoderDecoderModel" ], "decoder": { "_name_or_path": "", "activation_dropout": 0.0, "activation_function": "gelu", "add_cross_attention": true, "architectures": null, "attention_dropout": 0.1, "bad_words_ids": null, "bos_token_id": 0, "chunk_size_feed_forward": 0, "classifier_dropout": 0.0, "cross_attention_hidden_size": null, "d_model": 16, "decoder_attention_heads": 4, "decoder_ffn_dim": 4, "decoder_layerdrop": 0.0, "decoder_layers": 2, "decoder_start_token_id": 2, "diversity_penalty": 0.0, "do_sample": false, "dropout": 0.1, "early_stopping": false, "encoder_attention_heads": 4, "encoder_ffn_dim": 4, "encoder_layerdrop": 0.0, "encoder_layers": 2, "encoder_no_repeat_ngram_size": 0, "eos_token_id": 2, "exponential_decay_length_penalty": null, "finetuning_task": null, "forced_bos_token_id": null, "forced_eos_token_id": 2, "fuse_matmuls": false, "gradient_checkpointing": false, "id2label": { "0": "LABEL_0", "1": "LABEL_1" }, "init_std": 0.02, "is_decoder": true, "is_encoder_decoder": false, "label2id": { "LABEL_0": 0, "LABEL_1": 1 }, "length_penalty": 1.0, "max_length": 20, "max_position_embeddings": 100, "min_length": 0, "model_type": "bart", "no_repeat_ngram_size": 0, "num_beam_groups": 1, "num_beams": 1, "num_hidden_layers": 2, "num_return_sequences": 1, "output_attentions": false, "output_hidden_states": false, "output_scores": false, "pad_token_id": 1, "prefix": null, "problem_type": null, "pruned_heads": {}, "remove_invalid_values": false, "repetition_penalty": 1.0, "return_dict": true, "return_dict_in_generate": false, "scale_embedding": false, "sep_token_id": null, "task_specific_params": null, "temperature": 1.0, "tie_encoder_decoder": false, "tie_word_embeddings": true, "tokenizer_class": null, "top_k": 50, "top_p": 1.0, "torch_dtype": null, "torchscript": false, "transformers_version": "4.18.0.dev0", "typical_p": 1.0, "use_bfloat16": false, "use_cache": true, "use_scan": true, "vocab_size": 1000 }, "decoder_start_token_id": 0, "encoder": { "_name_or_path": "", "activation_dropout": 0.1, "adapter_kernel_size": 3, "adapter_stride": 2, "add_adapter": true, "add_cross_attention": false, "apply_spec_augment": true, "architectures": null, "attention_dropout": 0.1, "bad_words_ids": null, "bos_token_id": 1, "chunk_size_feed_forward": 0, "classifier_proj_size": 256, "codevector_dim": 256, "contrastive_logits_temperature": 0.1, "conv_bias": false, "conv_dim": [ 32, 32, 32 ], "conv_kernel": [ 8, 8, 8 ], "conv_stride": [ 4, 4, 4 ], "cross_attention_hidden_size": null, "ctc_loss_reduction": "sum", "ctc_zero_infinity": false, "decoder_start_token_id": null, "diversity_loss_weight": 0.1, "diversity_penalty": 0.0, "do_sample": false, "do_stable_layer_norm": true, "early_stopping": false, "encoder_no_repeat_ngram_size": 0, "eos_token_id": 2, "exponential_decay_length_penalty": null, "feat_extract_activation": "gelu", "feat_extract_dropout": 0.0, "feat_extract_norm": "layer", "feat_proj_dropout": 0.0, "feat_quantizer_dropout": 0.0, "final_dropout": 0.0, "finetuning_task": null, "forced_bos_token_id": null, "forced_eos_token_id": null, "fuse_matmuls": false, "gradient_checkpointing": false, "hidden_act": "gelu", "hidden_dropout": 0.1, "hidden_dropout_prob": 0.1, "hidden_size": 16, "id2label": { "0": "LABEL_0", "1": "LABEL_1" }, "initializer_range": 0.02, "intermediate_size": 20, "is_decoder": false, "is_encoder_decoder": false, "label2id": { "LABEL_0": 0, "LABEL_1": 1 }, "layer_norm_eps": 1e-05, "layerdrop": 0.0, "length_penalty": 1.0, "mask_feature_length": 10, "mask_feature_min_masks": 0, "mask_feature_prob": 0.0, "mask_time_length": 10, "mask_time_min_masks": 2, "mask_time_prob": 0.1, "max_length": 20, "min_length": 0, "model_type": "wav2vec2", "no_repeat_ngram_size": 0, "num_adapter_layers": 3, "num_attention_heads": 2, "num_beam_groups": 1, "num_beams": 1, "num_codevector_groups": 2, "num_codevectors_per_group": 320, "num_conv_pos_embedding_groups": 2, "num_conv_pos_embeddings": 16, "num_feat_extract_layers": 3, "num_hidden_layers": 4, "num_negatives": 10, "num_return_sequences": 1, "output_attentions": false, "output_hidden_size": 16, "output_hidden_states": false, "output_scores": false, "pad_token_id": 0, "prefix": null, "problem_type": null, "proj_codevector_dim": 256, "pruned_heads": {}, "remove_invalid_values": false, "repetition_penalty": 1.0, "return_dict": true, "return_dict_in_generate": false, "sep_token_id": null, "task_specific_params": null, "tdnn_dilation": [ 1, 2, 3, 1, 1 ], "tdnn_dim": [ 512, 512, 512, 512, 1500 ], "tdnn_kernel": [ 5, 3, 3, 1, 1 ], "temperature": 1.0, "tie_encoder_decoder": false, "tie_word_embeddings": true, "tokenizer_class": null, "top_k": 50, "top_p": 1.0, "torch_dtype": null, "torchscript": false, "transformers_version": "4.18.0.dev0", "typical_p": 1.0, "use_bfloat16": false, "use_scan": true, "use_weighted_layer_sum": false, "vocab_size": 32, "xvector_output_dim": 512 }, "eos_token_id": 2, "is_encoder_decoder": true, "max_length": 40, "model_type": "speech-encoder-decoder", "pad_token_id": 1, "processor_class": "Wav2Vec2Processor", "tie_word_embeddings": false, "transformers_version": null, "use_cache": false } loading feature extractor configuration file ./preprocessor_config.json Feature extractor Wav2Vec2FeatureExtractor { "do_normalize": true, "feature_extractor_type": "Wav2Vec2FeatureExtractor", "feature_size": 1, "padding_side": "right", "padding_value": 0.0, "return_attention_mask": false, "sampling_rate": 16000 } Didn't find file ./added_tokens.json. We won't load it. loading file ./vocab.json loading file ./merges.txt loading file ./tokenizer.json loading file None loading file ./special_tokens_map.json loading file ./tokenizer_config.json 2022-05-07 18:19:09.719228: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory 2022-05-07 18:19:09.719261: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303) 05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-172908eb2439798c.arrow 05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-172908eb2439798c.arrow 05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-3788df45b822e09d.arrow 05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-836619e0a5bdc111.arrow 05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-760f29b7172d4ca5.arrow 05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-760f29b7172d4ca5.arrow 05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-c95f128933ad1116.arrow 05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-0b5b21388629ee07.arrow 05/07/2022 18:19:09 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /home/sanchitgandhi/cache/huggingface/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-885308cf79812c65.arrow 05/07/2022 18:19:11 - WARNING - huggingface_hub.repository - /home/sanchitgandhi/flax-dummy/./ is already a clone of https://huggingface.co/sanchit-gandhi/flax-dummy. Make sure you pull the latest changes with `repo.git_pull()`. 05/07/2022 18:19:12 - INFO - __main__ - ***** Running training ***** 05/07/2022 18:19:12 - INFO - __main__ - Num examples = 70 05/07/2022 18:19:12 - INFO - __main__ - Num Epochs = 4 05/07/2022 18:19:12 - INFO - __main__ - Instantaneous batch size per device = 2 05/07/2022 18:19:12 - INFO - __main__ - Num gradient accumulation steps = 1 05/07/2022 18:19:12 - INFO - __main__ - Total train batch size (w. parallel & distributed) = 16 05/07/2022 18:19:12 - INFO - __main__ - Total optimization steps = 15 05/07/2022 18:19:12 - INFO - __main__ - Gradient checkpointing: False 05/07/2022 18:19:12 - INFO - __main__ - Use scan: True 05/07/2022 18:19:12 - INFO - __main__ - Fuse matmuls: False /home/sanchitgandhi/flax-dummy/./ is already a clone of https://huggingface.co/sanchit-gandhi/flax-dummy. Make sure you pull the latest changes with `repo.git_pull()`. Epoch ... (1/4): 0%| | 0/4 [00:00 main() File "run_flax_speech_recognition_seq2seq.py", line 1303, in main state, train_metric = p_train_step(state, batch) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback return fun(*args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/api.py", line 1979, in cache_miss out_tree, out_flat = f_pmapped_(*args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/api.py", line 1855, in pmap_f out = pxla.xla_pmap( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/core.py", line 1797, in bind return map_bind(self, fun, *args, **params) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/core.py", line 1828, in map_bind outs = primitive.process(top_trace, fun, tracers, params) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/core.py", line 1800, in process return trace.process_map(self, fun, tracers, params) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/core.py", line 614, in process_call return primitive.impl(f, *tracers, **params) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 792, in xla_pmap_impl compiled_fun, fingerprint = parallel_callable( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/linear_util.py", line 272, in memoized_fun ans = call(fun, *args) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 820, in parallel_callable pmap_computation = lower_parallel_callable( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper return func(*args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 993, in lower_parallel_callable jaxpr, consts, replicas, parts, shards = stage_parallel_callable( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 900, in stage_parallel_callable jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper return func(*args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1798, in trace_to_jaxpr_final jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1775, in trace_to_subjaxpr_dynamic ans = fun.call_wrapped(*in_tracers_) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs)) File "run_flax_speech_recognition_seq2seq.py", line 1086, in train_step (loss, num_labels), grad = grad_fn(to_dtype(state.params), batch) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback return fun(*args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/api.py", line 954, in value_and_grad_f ans, vjp_py, aux = _vjp( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/api.py", line 2413, in _vjp out_primal, out_vjp, aux = ad.vjp( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/ad.py", line 121, in vjp out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/ad.py", line 106, in linearize jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper return func(*args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 592, in trace_to_jaxpr jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs)) File "run_flax_speech_recognition_seq2seq.py", line 1073, in compute_loss logits = state.apply_fn( File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_speech_encoder_decoder.py", line 765, in __call__ return self.module.apply( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback return fun(*args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 1162, in apply return apply( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/scope.py", line 806, in wrapper y = fn(root, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 1446, in scope_fn return fn(module.clone(parent=scope), *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn return prewrapped_fn(self, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method return self._call_wrapped_method(fun, args, kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method y = fun(self, *args, **kwargs) File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_speech_encoder_decoder.py", line 340, in __call__ decoder_outputs = self.decoder( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn return prewrapped_fn(self, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method return self._call_wrapped_method(fun, args, kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method y = fun(self, *args, **kwargs) File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 752, in __call__ outputs = self.model( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn return prewrapped_fn(self, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method return self._call_wrapped_method(fun, args, kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method y = fun(self, *args, **kwargs) File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 718, in __call__ return self.decoder(*args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn return prewrapped_fn(self, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method return self._call_wrapped_method(fun, args, kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method y = fun(self, *args, **kwargs) File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 500, in __call__ outputs = self.layers( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn return prewrapped_fn(self, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method return self._call_wrapped_method(fun, args, kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method y = fun(self, *args, **kwargs) File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 389, in __call__ hidden_states, _ = scan_with_axes( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn return prewrapped_fn(self, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 312, in wrapped_fn ret = trafo_fn(module_scopes, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 213, in wrapper y, out_variable_groups_xs_t = fn( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 286, in wrapper y = fn(scopes, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 305, in core_fn res = fn(cloned, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn return prewrapped_fn(self, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 312, in wrapped_fn ret = trafo_fn(module_scopes, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 218, in wrapper y, out_variable_groups_xs_t = fn( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 766, in inner broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/axes_scan.py", line 135, in scan_fn _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper return func(*args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 592, in trace_to_jaxpr jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs)) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/axes_scan.py", line 111, in body_fn broadcast_out, c, ys = fn(broadcast_in, c, *xs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/lift.py", line 750, in scanned c, y = fn(scope, c, *args) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 305, in core_fn res = fn(cloned, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn return prewrapped_fn(self, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method return self._call_wrapped_method(fun, args, kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method y = fun(self, *args, **kwargs) File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 308, in __call__ hidden_states, self_attn_weights = self.self_attn( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/transforms.py", line 1166, in wrapped_fn return prewrapped_fn(self, *args, **kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 350, in wrapped_module_method return self._call_wrapped_method(fun, args, kwargs) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method y = fun(self, *args, **kwargs) File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 206, in __call__ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 357, in _broadcast_to raise ValueError(msg.format(arr_shape, shape)) jax._src.traceback_util.UnfilteredStackTrace: ValueError: Incompatible shapes for broadcasting: (2, 1, 1, 127) and requested shape (2, 1, 100, 100) The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified. -------------------- The above exception was the direct cause of the following exception: Traceback (most recent call last): File "run_flax_speech_recognition_seq2seq.py", line 1396, in main() File "run_flax_speech_recognition_seq2seq.py", line 1303, in main state, train_metric = p_train_step(state, batch) File "run_flax_speech_recognition_seq2seq.py", line 1086, in train_step (loss, num_labels), grad = grad_fn(to_dtype(state.params), batch) File "run_flax_speech_recognition_seq2seq.py", line 1073, in compute_loss logits = state.apply_fn( File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_speech_encoder_decoder.py", line 765, in __call__ return self.module.apply( File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_speech_encoder_decoder.py", line 340, in __call__ decoder_outputs = self.decoder( File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 752, in __call__ outputs = self.model( File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 718, in __call__ return self.decoder(*args, **kwargs) File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 500, in __call__ outputs = self.layers( File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 389, in __call__ hidden_states, _ = scan_with_axes( File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/axes_scan.py", line 135, in scan_fn _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/flax/core/axes_scan.py", line 111, in body_fn broadcast_out, c, ys = fn(broadcast_in, c, *xs) File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 308, in __call__ hidden_states, self_attn_weights = self.self_attn( File "/home/sanchitgandhi/seq2seq-speech/models/modeling_flax_bart.py", line 206, in __call__ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) File "/home/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 357, in _broadcast_to raise ValueError(msg.format(arr_shape, shape)) ValueError: Incompatible shapes for broadcasting: (2, 1, 1, 127) and requested shape (2, 1, 100, 100)