2hx8pk65: saving weights and logs of step 10k
f74be82
Downloading and preparing dataset librispeech_asr/all (download: 57.14 GiB, generated: 59.44 GiB, post-processed: Unknown size, total: 116.59 GiB) to /home/sanchitgandhi/cache/huggingface/datasets/librispeech_asr/all/2.1.0/14c8bffddb861b4b3a4fcdff648a56980dbb808f3fc56f5a3d56b18ee88458eb...
INFO:__main__:Training/evaluation parameters FlaxSeq2SeqTrainingArguments(
_n_gpu=-1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
debug=,
deepspeed=None,
disable_tqdm=None,
do_eval=True,
do_predict=True,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=10000,
evaluation_strategy=no,
final_generation_max_length=200,
final_generation_num_beams=5,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=,
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
generation_length_penalty=1.2,
generation_max_length=200,
generation_num_beams=5,
gradient_accumulation_steps=1,
gradient_checkpointing=True,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
include_inputs_for_metrics=False,
jit_mode_eval=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=0.0001,
length_column_name=length,
load_best_model_at_end=False,
local_rank=-1,
log_level=passive,
log_level_replica=passive,
log_on_each_node=True,
logging_dir=None,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=25,
logging_strategy=steps,
lr_scheduler_type=linear,
matmul_precision=default,
max_grad_norm=1.0,
max_steps=50000,
metric_for_best_model=None,
mp_parameters=,
no_cuda=False,
num_train_epochs=3.0,
optim=adamw_hf,
output_dir=./,
overwrite_output_dir=True,
past_index=-1,
per_device_eval_batch_size=4,
per_device_train_batch_size=8,
precision=full,
predict_with_generate=True,
prediction_loss_only=False,
push_to_hub=True,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
ray_scope=last,
remove_unused_columns=True,
report_to=None,
resume_from_checkpoint=None,
run_name=None,
save_on_each_node=False,
save_steps=10000,
save_strategy=steps,
save_total_limit=None,
seed=42,
sharded_ddp=,
skip_memory_metrics=True,
sortish_sampler=False,
tf32=None,
torchdynamo=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_ipex=False,
use_legacy_prediction_loop=False,
warmup_ratio=0.0,
warmup_steps=500,
weight_decay=0.0,
xpu_backend=None,
)
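The arguments above set `learning_rate=0.0001`, `warmup_steps=500`, `max_steps=50000` and `lr_scheduler_type=linear`. The sketch below assumes that this means a linear warmup from 0 to the peak rate over the first 500 steps, followed by a linear decay to 0 at step 50,000, which is how the Hugging Face Flax example scripts typically build their schedule. The function name `linear_warmup_then_decay` is hypothetical and is not taken from the training script.

```python
def linear_warmup_then_decay(step, peak_lr=1e-4, warmup_steps=500, total_steps=50_000):
    """Learning rate at a given optimizer step (illustrative sketch, not the script's code)."""
    if step < warmup_steps:
        # Warmup phase: rise linearly from 0 to peak_lr.
        return peak_lr * step / warmup_steps
    # Decay phase: fall linearly from peak_lr to 0 between warmup_steps and total_steps.
    remaining = total_steps - warmup_steps
    return peak_lr * max(0.0, (total_steps - step) / remaining)


for s in (0, 250, 500, 10_000, 50_000):
    print(s, linear_warmup_then_decay(s))
```

Under this assumption, the step-10k checkpoint would have been saved at a learning rate of roughly 8.1e-5.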
INFO:__main__:JAX devices: 8, matmul precision: default
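With 8 JAX devices, `per_device_train_batch_size=8`, `per_device_eval_batch_size=4` and `gradient_accumulation_steps=1`, the number of examples per optimizer update can be worked out as below. This assumes plain data-parallel training in which every device processes its own per-device batch. The helper `global_batch_size` is a hypothetical illustration, not a function from the training script.

```python
def global_batch_size(per_device_batch_size: int, num_devices: int, grad_accum_steps: int = 1) -> int:
    # Data parallelism: each device handles its own micro-batch, and gradient
    # accumulation multiplies the number of examples contributing to one update.
    return per_device_batch_size * num_devices * grad_accum_steps


train_bs = global_batch_size(8, 8, 1)  # per_device_train_batch_size x JAX devices
eval_bs = global_batch_size(4, 8)      # per_device_eval_batch_size x JAX devices
print(train_bs, eval_bs, train_bs * 50_000)
```

Under that assumption, 50,000 steps at a global batch of 64 process 3,200,000 training examples.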
Downloading data files: 0% 0/7 [00:00<?, ?it/s]
Downloading data: 54% 184M/338M [00:01<00:02, 74.3MB/s]
Downloading data: 98% 332M/338M [00:03<00:00, 76.4MB/s]
Downloading data files: 14% 1/7 [00:04<00:25, 4.29s/it]
Downloading data: 100% 314M/314M [00:03<00:00, 100MB/s]
WARNING:datasets.builder:Reusing dataset librispeech_asr (/home/sanchitgandhi/cache/huggingface/datasets/librispeech_asr/all/2.1.0/14c8bffddb861b4b3a4fcdff648a56980dbb808f3fc56f5a3d56b18ee88458eb)
"num_beams": 6
}
},
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": "float32",
"torchscript": false,
"transformers_version": "4.21.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false,
"use_cache": true,
"use_scan": true,
"vocab_size": 50265
},
"decoder_start_token_id": 0,
"encoder": {
"_name_or_path": "",
"activation_dropout": 0.1,
"adapter_kernel_size": 3,
"adapter_stride": 2,
"add_adapter": true,
"add_cross_attention": false,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2ForPreTraining"
],
"attention_dropout": 0.1,
"bad_words_ids": null,
"bos_token_id": 1,
"chunk_size_feed_forward": 0,
"classifier_proj_size": 256,
"codevector_dim": 768,
"contrastive_logits_temperature": 0.1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"cross_attention_hidden_size": null,
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"decoder_start_token_id": null,
"diversity_loss_weight": 0.1,
"diversity_penalty": 0.0,
"do_sample": false,
"do_stable_layer_norm": true,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.0,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"fuse_matmuls": false,
"gradient_checkpointing": true,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"initializer_range": 0.02,
"intermediate_size": 4096,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-05,
"layerdrop": 0.0,
"length_penalty": 1.0,
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_prob": 0.1,
"max_length": 20,
"min_length": 0,
"model_type": "wav2vec2",
"no_repeat_ngram_size": 0,
"num_adapter_layers": 3,
"num_attention_heads": 16,
"num_beam_groups": 1,
"num_beams": 1,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"num_negatives": 100,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_size": 1024,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 0,
"prefix": null,
"problem_type": null,
"proj_codevector_dim": 768,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"task_specific_params": null,
"tdnn_dilation": [
1,
2,
3,
1,
1
],
"tdnn_dim": [
512,
512,
512,
512,
1500
],
"tdnn_kernel": [
5,
3,
3,
1,
1
],
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.21.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false,
"use_scan": true,
"use_weighted_layer_sum": false,
"vocab_size": 32,
"xvector_output_dim": 512
},
"eos_token_id": 2,
"is_encoder_decoder": true,
"max_length": 40,
"model_type": "speech-encoder-decoder",
"pad_token_id": 1,
"processor_class": "Wav2Vec2Processor",
"tie_word_embeddings": false,
"transformers_version": null,
"use_cache": false
}
loading feature extractor configuration file https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-scan/resolve/main/preprocessor_config.json from cache at /home/sanchitgandhi/.cache/huggingface/transformers/bc2232c616201c7d3d66ba3f6a7d1186306134838dfb19786149f0e16122787d.bbc1eb890a39c82e710a893223b8452ac5b78e8b57083b2f893aa7dc59d4ed69
Feature extractor Wav2Vec2FeatureExtractor {
"do_normalize": true,
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
"feature_size": 1,
"padding_side": "right",
"padding_value": 0.0,
"return_attention_mask": true,
"sampling_rate": 16000
}
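The encoder and feature-extractor settings logged above determine how many encoder frames each second of audio produces. The sketch below is a back-of-the-envelope calculation. It assumes unpadded (valid) convolutions in the feature encoder. For the adapter it assumes kernel 3 with padding 1, which is not shown in the log. The helper names are hypothetical. The constants are copied from the `conv_kernel`, `conv_stride`, `adapter_kernel_size`, `adapter_stride`, `num_adapter_layers` and `sampling_rate` fields.

```python
import math

# Copied from the logged encoder / feature-extractor configs.
CONV_KERNEL = [10, 3, 3, 3, 3, 2, 2]   # "conv_kernel"
CONV_STRIDE = [5, 2, 2, 2, 2, 2, 2]    # "conv_stride"
ADAPTER_KERNEL = 3                     # "adapter_kernel_size"
ADAPTER_STRIDE = 2                     # "adapter_stride"
NUM_ADAPTER_LAYERS = 3                 # "num_adapter_layers"
SAMPLING_RATE = 16_000                 # "sampling_rate" (Hz)
ADAPTER_PADDING = 1                    # assumption: not shown in the log


def conv_out_len(length: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a 1-D convolution (standard floor formula)."""
    return (length + 2 * padding - kernel) // stride + 1


def encoder_frames(num_samples: int, use_adapter: bool = True) -> int:
    """Number of encoder output frames for a waveform of `num_samples` samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        length = conv_out_len(length, kernel, stride)
    if use_adapter:
        for _ in range(NUM_ADAPTER_LAYERS):
            length = conv_out_len(length, ADAPTER_KERNEL, ADAPTER_STRIDE, ADAPTER_PADDING)
    return length


# Nominal hop between frames, ignoring edge effects.
conv_hop = math.prod(CONV_STRIDE)                              # 320 samples
adapter_hop = conv_hop * ADAPTER_STRIDE ** NUM_ADAPTER_LAYERS  # 2560 samples
print(f"{1000 * conv_hop / SAMPLING_RATE:.0f} ms per conv frame")        # 20 ms
print(f"{1000 * adapter_hop / SAMPLING_RATE:.0f} ms per encoder frame")  # 160 ms
print(encoder_frames(SAMPLING_RATE, use_adapter=False), encoder_frames(SAMPLING_RATE))  # 49 7
```

Under these assumptions, the adapter cuts the sequence the BART decoder cross-attends to by a further 8x. For example, a 15 s utterance (240,000 samples) comes out at 94 encoder frames.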
loading file https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-scan/resolve/main/vocab.json from cache at /home/sanchitgandhi/.cache/huggingface/transformers/86c0de13925d1534934e540ff4c9dd778f49761b4eaf59dae3335a4f6690a814.bfdcc444ff249bca1a95ca170ec350b442f81804d7df3a95a2252217574121d7
loading file https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-scan/resolve/main/merges.txt from cache at /home/sanchitgandhi/.cache/huggingface/transformers/7cf4fc91891684e1177d1c519689e4c310ebdec965e00d6e45134bb9227ab01b.f5b91da9e34259b8f4d88dbc97c740667a0e8430b96314460cdb04e86d4fc435
loading file https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-scan/resolve/main/tokenizer.json from cache at /home/sanchitgandhi/.cache/huggingface/transformers/c02f3f3009bfacaa24cfead1d0f7fbf4fc2fb5f8092f68703449f02aa3a28e03.393fa6a095aa312a3cce4d5263e471bd94ec0215e6c63448a6464d59ff900814
loading file https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-scan/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-scan/resolve/main/special_tokens_map.json from cache at /home/sanchitgandhi/.cache/huggingface/transformers/505d61b8f6e05764b5aec1483bfdd13a310681a5af54957263604323be3bbabf.a11ebb04664c067c8fe5ef8f8068b0f721263414a26058692f7b2e4ba2a1b342
loading file https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-scan/resolve/main/tokenizer_config.json from cache at /home/sanchitgandhi/.cache/huggingface/transformers/ff79c23164eac352d7f9651f3c3774a962ce80f81460d9e17d689235fa34ee80.0e8b2b497f91e23302894a5c1f19ced6334b0abd450a7bce75a67bf0f9ee5c54
loading weights file https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-scan/resolve/main/flax_model.msgpack from cache at /home/sanchitgandhi/.cache/huggingface/transformers/1279dc21f7dd9ed546f166e7e445e068b2672ddfa5386b2e3a3a973b8d668365.8e03496bb6919447aeb468483249e7b65dfb59c42989be9787af0aa6aa9b3f50
tcmalloc: large alloc 2353618944 bytes == 0xb1ee8000 @ 0x7f7cba873680 0x7f7cba894824 0x5fb391 0x64be71 0x5c6366 0x4f3b9e 0x651588 0x505a63 0x56bbfa 0x569dba 0x50bca0 0x56cc1f 0x569dba 0x5f6eb3 0x56bacd 0x569dba 0x6902a7 0x67f951 0x67f9cf 0x67fa71 0x681b97 0x6b9d32 0x6ba0bd 0x7f7cba6830b3 0x5fc5fe
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
All model checkpoint weights were used when initializing FlaxSpeechEncoderDecoderModel.
All the weights of FlaxSpeechEncoderDecoderModel were initialized from the model checkpoint at sanchit-gandhi/flax-wav2vec2-2-bart-large-scan.
If your task is similar to the task the model of the checkpoint was trained on, you can already use FlaxSpeechEncoderDecoderModel for predictions without further training.
/home/sanchitgandhi/transformers/src/transformers/modeling_flax_utils.py:904: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
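Every FutureWarning in this log comes from the top-level `jax.tree_map`, `jax.tree_leaves`, `jax.tree_flatten` and `jax.tree_unflatten` aliases. Below is a minimal sketch of the non-deprecated `jax.tree_util` spellings. It uses a toy parameter dict (hypothetical values standing in for the model `state`).

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a Flax parameter tree (hypothetical values).
state = {
    "encoder": {"kernel": jnp.ones((2, 3), dtype=jnp.float32)},
    "decoder": {"bias": jnp.zeros((3,), dtype=jnp.bfloat16)},
}

# Deprecated: jax.tree_map(lambda x: x.dtype, state)
param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)

# Deprecated: jax.tree_leaves / jax.tree_flatten / jax.tree_unflatten
leaves = jax.tree_util.tree_leaves(state)
flat, treedef = jax.tree_util.tree_flatten(state)
rebuilt = jax.tree_util.tree_unflatten(treedef, flat)

print(param_dtypes)
```

Newer JAX releases also expose the same functions under the `jax.tree` namespace (for example `jax.tree.map`). The `jax.tree_util` names shown here work across both older and newer versions.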
filtering data where the targets are ignored in scoring: 13% 37/282 [00:00<00:00, 367.66ba/s]
filtering data where the targets are ignored in scoring: 100% 282/282 [00:00<00:00, 461.57ba/s]
filtering data where the targets are ignored in scoring: 100% 3/3 [00:00<00:00, 381.50ba/s]
filtering data where the targets are ignored in scoring: 100% 3/3 [00:00<00:00, 481.27ba/s]
filtering data where the targets are ignored in scoring: 100% 3/3 [00:00<00:00, 504.26ba/s]
filtering data where the targets are ignored in scoring: 100% 3/3 [00:00<00:00, 459.68ba/s]
preprocess train dataset: 100% 281241/281241 [34:57<00:00, 134.07ex/s]
preprocess train dataset: 100% 2703/2703 [00:12<00:00, 212.47ex/s]
preprocess train dataset: 100% 2864/2864 [00:12<00:00, 235.45ex/s]
preprocess train dataset: 100% 2620/2620 [00:13<00:00, 199.87ex/s]
preprocess train dataset: 100% 2939/2939 [00:12<00:00, 235.49ex/s]
100% 282/282 [00:00<00:00, 547.89ba/s]
100% 282/282 [00:01<00:00, 158.89ba/s]
100% 282/282 [00:01<00:00, 160.92ba/s]
100% 3/3 [00:00<00:00, 739.78ba/s]
100% 3/3 [00:00<00:00, 765.24ba/s]
100% 3/3 [00:00<00:00, 836.91ba/s]
100% 3/3 [00:00<00:00, 751.98ba/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
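As the warning itself suggests, the message goes away once `TOKENIZERS_PARALLELISM` is set explicitly. The sketch below assumes it runs at the very top of the training script, before the tokenizer is first used and before any worker processes are forked. The value `"false"` keeps the fast tokenizer single-threaded, which sidesteps the fork-after-threads deadlock the warning guards against.

```python
import os

# Set before the tokenizer is first used, so the Rust tokenizer does not
# enable its parallelism in the parent process that later forks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print(os.environ["TOKENIZERS_PARALLELISM"])
```

Exporting the variable in the shell that launches `run_flax_speech_recognition_seq2seq.py` works equally well and keeps the script unchanged.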
run_flax_speech_recognition_seq2seq.py:1052: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
wer_metric = load_metric("wer")
Feature extractor saved in ./preprocessor_config.json
tokenizer config file saved in ./tokenizer_config.json
Special tokens file saved in ./special_tokens_map.json
Configuration saved in ./config.json
loading feature extractor configuration file ./preprocessor_config.json
loading configuration file ./config.json
/home/sanchitgandhi/transformers/src/transformers/configuration_utils.py:368: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.
warnings.warn(
Model config SpeechEncoderDecoderConfig {
"_name_or_path": "./",
"architectures": [
"SpeechEncoderDecoderModel"
],
"decoder": {
"_name_or_path": "",
"activation_dropout": 0.2,
"activation_function": "gelu",
"add_bias_logits": false,
"add_cross_attention": true,
"add_final_layer_norm": false,
"architectures": [
"BartModel"
],
"attention_dropout": 0.1,
"bad_words_ids": null,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"classif_dropout": 0.1,
"classifier_dropout": 0.0,
"cross_attention_hidden_size": null,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"diversity_penalty": 0.0,
"do_sample": false,
"dropout": 0.2,
"early_stopping": true,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": 0,
"forced_eos_token_id": 2,
"fuse_matmuls": false,
"gradient_checkpointing": true,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1",
"2": "LABEL_2"
},
"init_std": 0.02,
"is_decoder": true,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1,
"LABEL_2": 2
},
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 1024,
"min_length": 0,
"model_type": "bart",
"no_repeat_ngram_size": 3,
"normalize_before": false,
"num_beam_groups": 1,
"num_beams": 4,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 1,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"scale_embedding": false,
"sep_token_id": null,
"task_specific_params": {
"summarization": {
"length_penalty": 1.0,
"max_length": 128,
"min_length": 12,
"num_beams": 4
},
"summarization_cnn": {
"length_penalty": 2.0,
"max_length": 142,
"min_length": 56,
"num_beams": 4
},
"summarization_xsum": {
"length_penalty": 1.0,
"max_length": 62,
"min_length": 11,
"num_beams": 6
}
},
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": "float32",
"torchscript": false,
"transformers_version": "4.21.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false,
"use_cache": true,
"use_scan": true,
"vocab_size": 50265
},
"decoder_start_token_id": 0,
"encoder": {
"_name_or_path": "",
"activation_dropout": 0.2,
"adapter_kernel_size": 3,
"adapter_stride": 2,
"add_adapter": true,
"add_cross_attention": false,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2ForPreTraining"
],
"attention_dropout": 0.1,
"bad_words_ids": null,
"bos_token_id": 1,
"chunk_size_feed_forward": 0,
"classifier_proj_size": 256,
"codevector_dim": 768,
"contrastive_logits_temperature": 0.1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"cross_attention_hidden_size": null,
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"decoder_start_token_id": null,
"diversity_loss_weight": 0.1,
"diversity_penalty": 0.0,
"do_sample": false,
"do_stable_layer_norm": true,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.2,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"fuse_matmuls": false,
"gradient_checkpointing": true,
"hidden_act": "gelu",
"hidden_dropout": 0.2,
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"initializer_range": 0.02,
"intermediate_size": 4096,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-05,
"layerdrop": 0.0,
"length_penalty": 1.0,
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_prob": 0.1,
"max_length": 20,
"min_length": 0,
"model_type": "wav2vec2",
"no_repeat_ngram_size": 0,
"num_adapter_layers": 3,
"num_attention_heads": 16,
"num_beam_groups": 1,
"num_beams": 1,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"num_negatives": 100,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_size": 1024,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 0,
"prefix": null,
"problem_type": null,
"proj_codevector_dim": 768,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"task_specific_params": null,
"tdnn_dilation": [
1,
2,
3,
1,
1
],
"tdnn_dim": [
512,
512,
512,
512,
1500
],
"tdnn_kernel": [
5,
3,
3,
1,
1
],
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.21.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false,
"use_scan": true,
"use_weighted_layer_sum": false,
"vocab_size": 32,
"xvector_output_dim": 512
},
"eos_token_id": 2,
"is_encoder_decoder": true,
"max_length": 40,
"model_type": "speech-encoder-decoder",
"pad_token_id": 1,
"processor_class": "Wav2Vec2Processor",
"tie_word_embeddings": false,
"transformers_version": null,
"use_cache": false
}
loading feature extractor configuration file ./preprocessor_config.json
Feature extractor Wav2Vec2FeatureExtractor {
"do_normalize": true,
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
"feature_size": 1,
"padding_side": "right",
"padding_value": 0.0,
"return_attention_mask": true,
"sampling_rate": 16000
}
Didn't find file ./added_tokens.json. We won't load it.
loading file ./vocab.json
loading file ./merges.txt
loading file ./tokenizer.json
loading file None
loading file ./special_tokens_map.json
loading file ./tokenizer_config.json
WARNING:__main__:Unable to display metrics through TensorBoard because the package is not installed: Please run `pip install tensorboard` to enable.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/./ is already a clone of https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box. Make sure you pull the latest changes with `repo.git_pull()`.
WARNING:huggingface_hub.repository:/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/./ is already a clone of https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box. Make sure you pull the latest changes with `repo.git_pull()`.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/transform.py:319: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
mu = jax.tree_map( # First moment
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/transform.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
nu = jax.tree_map(jnp.zeros_like, params) # Second moment
INFO:__main__:***** Running training *****
INFO:__main__: Num examples = 281186
INFO:__main__: Num Epochs = 12
INFO:__main__: Instantaneous batch size per device = 8
INFO:__main__: Num gradient accumulation steps = 1
INFO:__main__: Total train batch size (w. parallel & distributed) = 64
INFO:__main__: Total optimization steps = 50000
INFO:__main__: Gradient checkpointing: True
INFO:__main__: Use scan: True
INFO:__main__: Fuse matmuls: False
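The training summary above is internally consistent. The arithmetic below reproduces the 4393 steps per epoch shown by the progress bar and the 12 epochs needed to reach 50000 optimization steps. It is a sketch that assumes the final partial batch of each epoch is dropped. The device count of 8 is derived from the logged batch sizes, not logged directly.

```python
import math

NUM_EXAMPLES = 281_186   # "Num examples"
PER_DEVICE_BATCH = 8     # "Instantaneous batch size per device"
GRAD_ACCUM_STEPS = 1     # "Num gradient accumulation steps"
TOTAL_BATCH = 64         # "Total train batch size (w. parallel & distributed)"
MAX_STEPS = 50_000       # "Total optimization steps"

num_devices = TOTAL_BATCH // (PER_DEVICE_BATCH * GRAD_ACCUM_STEPS)
steps_per_epoch = NUM_EXAMPLES // TOTAL_BATCH        # remainder of each epoch dropped
num_epochs = math.ceil(MAX_STEPS / steps_per_epoch)

print(num_devices, steps_per_epoch, num_epochs)  # 8 4393 12
```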
Epoch ... (1/12): 0% 0/12 [00:00<?, ?it/s]
Training...: 0% 0/4393 [00:00<?, ?it/s]
return jax.tree_map(
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:1266: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
grad = jax.tree_map(lambda g: g / total_samples, grad)
run_flax_speech_recognition_seq2seq.py:1267: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/linear_algebra.py:29: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
sum([jnp.sum(numerics.abs_sq(x)) for x in jax.tree_leaves(updates)]))
run_flax_speech_recognition_seq2seq.py:399: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
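Line 399 of the script rescales every gradient leaf by `casted_max_grad_norm / g_norm`, which is clipping by global norm. Below is a self-contained sketch of that technique. It assumes `max_grad_norm=1.0` from the training arguments. It also assumes the rescale applies only when the norm exceeds the threshold; that conditional is not visible in the log. The logged gradient norms (up to about 88) sit far above 1.0, so they are presumably measured before clipping.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_grad_norm: float = 1.0):
    """Rescale a gradient pytree so its global L2 norm is at most max_grad_norm."""
    # Global L2 norm over every leaf of the pytree.
    g_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(grads)))
    # Scale is 1 when already within the limit, otherwise max_grad_norm / g_norm.
    scale = jnp.where(g_norm > max_grad_norm, max_grad_norm / g_norm, 1.0)
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return clipped, g_norm


grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}  # global norm 5.0
clipped, g_norm = clip_by_global_norm(grads, max_grad_norm=1.0)
print(g_norm, clipped["w"])
```

With Optax, chaining `optax.clip_by_global_norm(max_grad_norm)` in front of the optimizer gives the same behaviour without hand-written tree maps.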
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/transform.py:82: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/transform.py:99: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/transform.py:106: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda t: t / bias_correction.astype(t.dtype), moment)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/transform.py:331: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
updates = jax.tree_map(
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/transform.py:610: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
updates = jax.tree_map(
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/transform.py:647: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
updates = jax.tree_map(
/home/sanchitgandhi/hf/lib/python3.8/site-packages/optax/_src/update.py:42: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(
run_flax_speech_recognition_seq2seq.py:1277: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
run_flax_speech_recognition_seq2seq.py:1286: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
return jax.tree_map(lambda x: x[0], tree)
Epoch ... (1/12): 0% 0/12 [05:07<?, ?it/s]
Training...: 42% 1825/4393 [2:33:45<3:55:59, 5.51s/it]
Step... (25 | Loss: 8.328937530517578, Learning Rate: 4.800000169780105e-06, Gradient Norm: 87.96355438232422)
Step... (50 | Loss: 6.823256969451904, Learning Rate: 9.800001862458885e-06, Gradient Norm: 37.718292236328125)
Step... (75 | Loss: 5.908570289611816, Learning Rate: 1.479999627918005e-05, Gradient Norm: 4.585142135620117)
Step... (100 | Loss: 5.766870975494385, Learning Rate: 1.979999797185883e-05, Gradient Norm: 32.34437942504883)
Step... (125 | Loss: 5.352975845336914, Learning Rate: 2.479999966453761e-05, Gradient Norm: 5.285919666290283)
Step... (150 | Loss: 5.732889652252197, Learning Rate: 2.9799994081258774e-05, Gradient Norm: 9.10770320892334)
Step... (175 | Loss: 5.380748271942139, Learning Rate: 3.480000304989517e-05, Gradient Norm: 2.8539175987243652)
Step... (200 | Loss: 5.2746663093566895, Learning Rate: 3.979999746661633e-05, Gradient Norm: 8.702874183654785)
Step... (225 | Loss: 5.10414457321167, Learning Rate: 4.479999915929511e-05, Gradient Norm: 3.008880853652954)
Step... (250 | Loss: 5.066227436065674, Learning Rate: 4.980000085197389e-05, Gradient Norm: 7.2485833168029785)
Step... (275 | Loss: 4.934179782867432, Learning Rate: 5.480000254465267e-05, Gradient Norm: 2.516850471496582)
Step... (300 | Loss: 4.820409297943115, Learning Rate: 5.980000423733145e-05, Gradient Norm: 6.128991603851318)
Step... (325 | Loss: 4.867349624633789, Learning Rate: 6.479999865405262e-05, Gradient Norm: 2.4680869579315186)
Step... (350 | Loss: 4.667332172393799, Learning Rate: 6.98000003467314e-05, Gradient Norm: 5.528798580169678)
Step... (375 | Loss: 4.901133060455322, Learning Rate: 7.480000203941017e-05, Gradient Norm: 2.432140350341797)
Step... (400 | Loss: 4.646675109863281, Learning Rate: 7.980000373208895e-05, Gradient Norm: 5.241644859313965)
Step... (425 | Loss: 4.873144149780273, Learning Rate: 8.480000542476773e-05, Gradient Norm: 2.560326099395752)
Step... (450 | Loss: 4.850857257843018, Learning Rate: 8.97999998414889e-05, Gradient Norm: 6.099117279052734)
Step... (475 | Loss: 4.680803298950195, Learning Rate: 9.480000881012529e-05, Gradient Norm: 2.29132342338562)
Step... (500 | Loss: 4.378286838531494, Learning Rate: 9.980000322684646e-05, Gradient Norm: 5.214158535003662)
Step... (525 | Loss: 4.8250627517700195, Learning Rate: 9.995151776820421e-05, Gradient Norm: 2.282822370529175)
Step... (550 | Loss: 4.181581020355225, Learning Rate: 9.990100807044655e-05, Gradient Norm: 5.4770355224609375)
Step... (575 | Loss: 4.76805305480957, Learning Rate: 9.98505056486465e-05, Gradient Norm: 2.273082971572876)
Step... (600 | Loss: 3.9572887420654297, Learning Rate: 9.980000322684646e-05, Gradient Norm: 4.990326404571533)
Step... (625 | Loss: 4.6408796310424805, Learning Rate: 9.97494935290888e-05, Gradient Norm: 2.1879968643188477)
Step... (650 | Loss: 3.7939791679382324, Learning Rate: 9.969899110728875e-05, Gradient Norm: 4.704588413238525)
Step... (675 | Loss: 4.548547744750977, Learning Rate: 9.96484886854887e-05, Gradient Norm: 2.147969961166382)
Step... (700 | Loss: 3.7140114307403564, Learning Rate: 9.959797898773104e-05, Gradient Norm: 4.6049981117248535)
Step... (725 | Loss: 4.475124359130859, Learning Rate: 9.954747656593099e-05, Gradient Norm: 2.3476402759552)
Step... (750 | Loss: 3.6484854221343994, Learning Rate: 9.949696686817333e-05, Gradient Norm: 4.747243881225586)
Step... (775 | Loss: 4.396439075469971, Learning Rate: 9.944646444637328e-05, Gradient Norm: 2.1694796085357666)
Step... (800 | Loss: 3.6524744033813477, Learning Rate: 9.939595474861562e-05, Gradient Norm: 5.665165424346924)
Step... (825 | Loss: 4.447864055633545, Learning Rate: 9.934545232681558e-05, Gradient Norm: 2.218409776687622)
Step... (850 | Loss: 3.3051741123199463, Learning Rate: 9.929494262905791e-05, Gradient Norm: 4.674701690673828)
Step... (875 | Loss: 4.333479881286621, Learning Rate: 9.924444020725787e-05, Gradient Norm: 2.1276533603668213)
Step... (900 | Loss: 3.417492389678955, Learning Rate: 9.91939305095002e-05, Gradient Norm: 4.565223217010498)
Step... (925 | Loss: 4.176679611206055, Learning Rate: 9.914342808770016e-05, Gradient Norm: 2.110710859298706)
Step... (950 | Loss: 3.143075704574585, Learning Rate: 9.909292566590011e-05, Gradient Norm: 4.653733730316162)
Step... (975 | Loss: 4.348679065704346, Learning Rate: 9.904241596814245e-05, Gradient Norm: 2.2036550045013428)
Step... (1000 | Loss: 2.7846860885620117, Learning Rate: 9.89919135463424e-05, Gradient Norm: 4.476178169250488)
Step... (1025 | Loss: 4.168126583099365, Learning Rate: 9.894141112454236e-05, Gradient Norm: 2.2347006797790527)
Step... (1050 | Loss: 2.9858577251434326, Learning Rate: 9.88909014267847e-05, Gradient Norm: 4.88425350189209)
Step... (1075 | Loss: 4.062693119049072, Learning Rate: 9.884039900498465e-05, Gradient Norm: 2.1905994415283203)
Step... (1100 | Loss: 2.5378994941711426, Learning Rate: 9.87898965831846e-05, Gradient Norm: 5.643308162689209)
Step... (1125 | Loss: 3.9322783946990967, Learning Rate: 9.873938688542694e-05, Gradient Norm: 2.507472276687622)
Step... (1150 | Loss: 2.2780089378356934, Learning Rate: 9.868888446362689e-05, Gradient Norm: 4.842682361602783)
Step... (1175 | Loss: 3.5336978435516357, Learning Rate: 9.863838204182684e-05, Gradient Norm: 2.5669169425964355)
Step... (1200 | Loss: 2.0677130222320557, Learning Rate: 9.858787234406918e-05, Gradient Norm: 8.548190116882324)
Step... (1225 | Loss: 3.1775474548339844, Learning Rate: 9.853736992226914e-05, Gradient Norm: 2.488553047180176)
Step... (1250 | Loss: 1.853069543838501, Learning Rate: 9.848686750046909e-05, Gradient Norm: 5.4602179527282715)
Step... (1275 | Loss: 2.4108481407165527, Learning Rate: 9.843635780271143e-05, Gradient Norm: 3.606497049331665)
Step... (1300 | Loss: 1.4020391702651978, Learning Rate: 9.838585538091138e-05, Gradient Norm: 4.786640644073486)
Step... (1325 | Loss: 1.7951806783676147, Learning Rate: 9.833535295911133e-05, Gradient Norm: 4.002208709716797)
Step... (1350 | Loss: 1.3696314096450806, Learning Rate: 9.828484326135367e-05, Gradient Norm: 5.143311500549316)
Step... (1375 | Loss: 1.1694923639297485, Learning Rate: 9.823434083955362e-05, Gradient Norm: 3.0316975116729736)
Step... (1400 | Loss: 1.3062351942062378, Learning Rate: 9.818383841775358e-05, Gradient Norm: 4.632654190063477)
Step... (1425 | Loss: 1.0138862133026123, Learning Rate: 9.813332871999592e-05, Gradient Norm: 2.421475648880005)
Step... (1450 | Loss: 1.0470666885375977, Learning Rate: 9.808282629819587e-05, Gradient Norm: 4.4383544921875)
Step... (1475 | Loss: 0.74889075756073, Learning Rate: 9.803232387639582e-05, Gradient Norm: 2.2315526008605957)
Step... (1500 | Loss: 0.7395172119140625, Learning Rate: 9.798181417863816e-05, Gradient Norm: 3.811149835586548)
Step... (1525 | Loss: 0.6412907838821411, Learning Rate: 9.793131175683811e-05, Gradient Norm: 1.9883421659469604)
Step... (1550 | Loss: 1.041448950767517, Learning Rate: 9.788080933503807e-05, Gradient Norm: 4.539083003997803)
Step... (1575 | Loss: 0.5415894985198975, Learning Rate: 9.78302996372804e-05, Gradient Norm: 1.5578198432922363)
Step... (1600 | Loss: 0.7398551106452942, Learning Rate: 9.777979721548036e-05, Gradient Norm: 4.611051082611084)
Step... (1625 | Loss: 0.4547487199306488, Learning Rate: 9.772929479368031e-05, Gradient Norm: 1.6650770902633667)
Step... (1650 | Loss: 0.49022048711776733, Learning Rate: 9.767878509592265e-05, Gradient Norm: 3.387274980545044)
Step... (1675 | Loss: 0.5778527855873108, Learning Rate: 9.76282826741226e-05, Gradient Norm: 1.6570004224777222)
Step... (1700 | Loss: 0.7402584552764893, Learning Rate: 9.757778025232255e-05, Gradient Norm: 4.378363132476807)
Step... (1725 | Loss: 0.4962593615055084, Learning Rate: 9.75272705545649e-05, Gradient Norm: 1.484971523284912)
Step... (1750 | Loss: 0.8252610564231873, Learning Rate: 9.747676813276485e-05, Gradient Norm: 3.8003933429718018)
Step... (1775 | Loss: 0.46149566769599915, Learning Rate: 9.74262657109648e-05, Gradient Norm: 1.3506320714950562)
Training...: 82% 3598/4393 [4:49:27<28:04, 2.12s/it]
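All step records share one fixed format. However, tqdm progress-bar lines such as the `Training...` line above are interleaved with them, and the record for step 1800 is missing where that bar appears. Below is a minimal sketch for pulling the records out as numbers, assuming only the line format visible in this log. The regular expression and `parse_step_lines` are illustrative and are not part of the training script.

```python
import re

# Matches the per-step records in this log, e.g.
# "Step... (3100 | Loss: 0.57..., Learning Rate: 9.47...e-05, Gradient Norm: 19.8...)"
# Fields are captured loosely so that values like "nan" or "inf" still parse via float().
STEP_LINE = re.compile(
    r"Step\.\.\. \((?P<step>\d+) \| Loss: (?P<loss>[^,]+), "
    r"Learning Rate: (?P<lr>[^,]+), Gradient Norm: (?P<grad_norm>[^)]+)\)"
)


def parse_step_lines(lines):
    """Return (step, loss, lr, grad_norm) tuples, skipping progress-bar and other lines."""
    records = []
    for line in lines:
        m = STEP_LINE.search(line)
        if m:
            records.append((int(m["step"]), float(m["loss"]),
                            float(m["lr"]), float(m["grad_norm"])))
    return records


sample = [
    "Step... (3075 | Loss: 0.17855629324913025, Learning Rate: 9.480000153416768e-05, Gradient Norm: 0.7571102380752563)",
    "Training...: 40% 1757/4393 [2:14:52<4:17:22, 5.86s/it]",
    "Step... (3100 | Loss: 0.5758547186851501, Learning Rate: 9.474949183641002e-05, Gradient Norm: 19.843076705932617)",
]
records = parse_step_lines(sample)
worst = max(records, key=lambda r: r[3])  # record with the largest gradient norm
print(f"{len(records)} step records; largest gradient norm at step {worst[0]}: {worst[3]:.3f}")
```

Parsing the whole file this way makes it straightforward to plot loss and gradient norm or to locate outliers, such as the gradient-norm spike logged at step 3100.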
Step... (1825 | Loss: 0.45814454555511475, Learning Rate: 9.732525359140709e-05, Gradient Norm: 1.927361011505127)
Step... (1850 | Loss: 0.6833104491233826, Learning Rate: 9.727474389364943e-05, Gradient Norm: 4.002225399017334)
Step... (1875 | Loss: 0.3850409984588623, Learning Rate: 9.722424147184938e-05, Gradient Norm: 1.310095191001892)
Step... (1900 | Loss: 0.6548380255699158, Learning Rate: 9.717373905004933e-05, Gradient Norm: 3.203636884689331)
Step... (1925 | Loss: 0.3904755711555481, Learning Rate: 9.712322935229167e-05, Gradient Norm: 1.244488000869751)
Step... (1950 | Loss: 0.6627995371818542, Learning Rate: 9.707272693049163e-05, Gradient Norm: 3.3569374084472656)
Step... (1975 | Loss: 0.29180118441581726, Learning Rate: 9.702222450869158e-05, Gradient Norm: 1.3752835988998413)
Step... (2000 | Loss: 0.5219539999961853, Learning Rate: 9.697171481093392e-05, Gradient Norm: 2.9133312702178955)
Step... (2025 | Loss: 0.3404647409915924, Learning Rate: 9.692121238913387e-05, Gradient Norm: 3.0273733139038086)
Step... (2050 | Loss: 0.7144688963890076, Learning Rate: 9.687070269137621e-05, Gradient Norm: 3.438058376312256)
Epoch ... (1/12): 8% 1/12 [5:50:57<64:20:33, 21057.59s/it]
Step... (2125 | Loss: 0.3445417582988739, Learning Rate: 9.671918815001845e-05, Gradient Norm: 1.2331606149673462)
Step... (2150 | Loss: 0.5462068915367126, Learning Rate: 9.666867845226079e-05, Gradient Norm: 2.630260467529297)
Step... (2175 | Loss: 0.26811814308166504, Learning Rate: 9.661817603046075e-05, Gradient Norm: 1.3439264297485352)
Step... (2200 | Loss: 0.44361090660095215, Learning Rate: 9.65676736086607e-05, Gradient Norm: 2.5680527687072754)
Step... (2225 | Loss: 0.2922044098377228, Learning Rate: 9.651716391090304e-05, Gradient Norm: 1.05724036693573)
Step... (2250 | Loss: 0.5575493574142456, Learning Rate: 9.646666148910299e-05, Gradient Norm: 3.3814682960510254)
Step... (2275 | Loss: 0.2683391869068146, Learning Rate: 9.641615906730294e-05, Gradient Norm: 0.9786010384559631)
Step... (2300 | Loss: 0.5252701044082642, Learning Rate: 9.636564936954528e-05, Gradient Norm: 2.9107272624969482)
Step... (2325 | Loss: 0.2841339111328125, Learning Rate: 9.631514694774523e-05, Gradient Norm: 1.2748677730560303)
Step... (2350 | Loss: 0.6419438719749451, Learning Rate: 9.626464452594519e-05, Gradient Norm: 2.8639562129974365)
Step... (2375 | Loss: 0.280693382024765, Learning Rate: 9.621413482818753e-05, Gradient Norm: 0.9972930550575256)
Step... (2400 | Loss: 0.4255363941192627, Learning Rate: 9.616363240638748e-05, Gradient Norm: 2.2353670597076416)
Step... (2425 | Loss: 0.2965167760848999, Learning Rate: 9.611312998458743e-05, Gradient Norm: 1.2774780988693237)
Step... (2450 | Loss: 0.6248699426651001, Learning Rate: 9.606262028682977e-05, Gradient Norm: 3.115103006362915)
Step... (2475 | Loss: 0.26660072803497314, Learning Rate: 9.601211786502972e-05, Gradient Norm: 0.9596872925758362)
Step... (2500 | Loss: 0.43441200256347656, Learning Rate: 9.596161544322968e-05, Gradient Norm: 2.2786362171173096)
Step... (2525 | Loss: 0.4871614873409271, Learning Rate: 9.591110574547201e-05, Gradient Norm: 3.725149154663086)
Step... (2550 | Loss: 0.7399157285690308, Learning Rate: 9.586060332367197e-05, Gradient Norm: 4.10433292388916)
Step... (2575 | Loss: 0.28419390320777893, Learning Rate: 9.581010090187192e-05, Gradient Norm: 1.104195475578308)
Step... (2600 | Loss: 0.3510257601737976, Learning Rate: 9.575959120411426e-05, Gradient Norm: 2.5480878353118896)
Step... (2625 | Loss: 0.2379504293203354, Learning Rate: 9.570908878231421e-05, Gradient Norm: 0.9049413204193115)
Step... (2650 | Loss: 0.4019342064857483, Learning Rate: 9.565858636051416e-05, Gradient Norm: 2.229780435562134)
Step... (2675 | Loss: 0.2592702805995941, Learning Rate: 9.56080766627565e-05, Gradient Norm: 1.0386632680892944)
Step... (2700 | Loss: 0.4528926610946655, Learning Rate: 9.555757424095646e-05, Gradient Norm: 2.613053560256958)
Step... (2725 | Loss: 0.19824065268039703, Learning Rate: 9.550707181915641e-05, Gradient Norm: 0.9244711399078369)
Step... (2750 | Loss: 0.32943522930145264, Learning Rate: 9.545656212139875e-05, Gradient Norm: 2.5082757472991943)
Step... (2775 | Loss: 0.2706708014011383, Learning Rate: 9.54060596995987e-05, Gradient Norm: 1.282505989074707)
Step... (2800 | Loss: 0.5671706795692444, Learning Rate: 9.535555727779865e-05, Gradient Norm: 3.151282787322998)
Step... (2825 | Loss: 0.19658076763153076, Learning Rate: 9.530504758004099e-05, Gradient Norm: 1.351545810699463)
Step... (2850 | Loss: 0.5204019546508789, Learning Rate: 9.525454515824094e-05, Gradient Norm: 3.0304579734802246)
Step... (2875 | Loss: 0.2571251094341278, Learning Rate: 9.52040427364409e-05, Gradient Norm: 1.3449103832244873)
Step... (2900 | Loss: 0.3958282172679901, Learning Rate: 9.515353303868324e-05, Gradient Norm: 2.664325475692749)
Step... (2925 | Loss: 0.3020494282245636, Learning Rate: 9.510303061688319e-05, Gradient Norm: 1.5144743919372559)
Step... (2950 | Loss: 0.3240774869918823, Learning Rate: 9.505252819508314e-05, Gradient Norm: 2.446608781814575)
Step... (2975 | Loss: 0.2148822844028473, Learning Rate: 9.500201849732548e-05, Gradient Norm: 1.4464852809906006)
Step... (3000 | Loss: 0.440028578042984, Learning Rate: 9.495151607552543e-05, Gradient Norm: 2.6232030391693115)
Step... (3025 | Loss: 0.21863850951194763, Learning Rate: 9.490100637776777e-05, Gradient Norm: 0.9206398725509644)
Step... (3050 | Loss: 0.36002713441848755, Learning Rate: 9.485050395596772e-05, Gradient Norm: 2.1922974586486816)
Step... (3075 | Loss: 0.17855629324913025, Learning Rate: 9.480000153416768e-05, Gradient Norm: 0.7571102380752563)
Step... (3100 | Loss: 0.5758547186851501, Learning Rate: 9.474949183641002e-05, Gradient Norm: 19.843076705932617)
Step... (3125 | Loss: 0.20221881568431854, Learning Rate: 9.469898941460997e-05, Gradient Norm: 0.7625271081924438)
Step... (3150 | Loss: 0.2861629128456116, Learning Rate: 9.464848699280992e-05, Gradient Norm: 1.6312204599380493)
Step... (3175 | Loss: 0.17547428607940674, Learning Rate: 9.459797729505226e-05, Gradient Norm: 0.9817888140678406)
Step... (3200 | Loss: 0.5097805857658386, Learning Rate: 9.454747487325221e-05, Gradient Norm: 2.885671854019165)
Step... (3225 | Loss: 0.18147142231464386, Learning Rate: 9.449697245145217e-05, Gradient Norm: 0.8246238231658936)
Step... (3250 | Loss: 0.5844413638114929, Learning Rate: 9.44464627536945e-05, Gradient Norm: 2.7525370121002197)
Step... (3275 | Loss: 0.1849716454744339, Learning Rate: 9.439596033189446e-05, Gradient Norm: 0.8131420016288757)
Step... (3300 | Loss: 0.4483461081981659, Learning Rate: 9.43454506341368e-05, Gradient Norm: 2.3861243724823)
Step... (3325 | Loss: 0.1761871725320816, Learning Rate: 9.429494821233675e-05, Gradient Norm: 0.7318664193153381)
Step... (3350 | Loss: 0.3498786389827728, Learning Rate: 9.424443851457909e-05, Gradient Norm: 3.470242500305176)
Step... (3375 | Loss: 0.18615788221359253, Learning Rate: 9.419393609277904e-05, Gradient Norm: 0.9352092742919922)
Step... (3400 | Loss: 0.27042198181152344, Learning Rate: 9.414342639502138e-05, Gradient Norm: 2.1422297954559326)
Step... (3425 | Loss: 0.1961713582277298, Learning Rate: 9.409292397322133e-05, Gradient Norm: 0.6911700367927551)
Step... (3450 | Loss: 0.2915394604206085, Learning Rate: 9.404242155142128e-05, Gradient Norm: 1.9624013900756836)
Step... (3475 | Loss: 0.21933865547180176, Learning Rate: 9.399191185366362e-05, Gradient Norm: 0.8972952365875244)
Step... (3500 | Loss: 0.5384829640388489, Learning Rate: 9.394140943186358e-05, Gradient Norm: 3.6333529949188232)
Step... (3525 | Loss: 0.1989884227514267, Learning Rate: 9.389090701006353e-05, Gradient Norm: 0.7745487093925476)
Step... (3550 | Loss: 0.5818794965744019, Learning Rate: 9.384039731230587e-05, Gradient Norm: 2.284414291381836)
Step... (3575 | Loss: 0.24821804463863373, Learning Rate: 9.378989489050582e-05, Gradient Norm: 10.061519622802734)
Step... (3600 | Loss: 0.2603316903114319, Learning Rate: 9.373939246870577e-05, Gradient Norm: 1.8977075815200806)
Step... (3625 | Loss: 0.2044754922389984, Learning Rate: 9.368888277094811e-05, Gradient Norm: 0.8418613076210022)
Step... (3650 | Loss: 0.4845062792301178, Learning Rate: 9.363838034914806e-05, Gradient Norm: 2.5268688201904297)
Step... (3675 | Loss: 0.16883544623851776, Learning Rate: 9.358787792734802e-05, Gradient Norm: 0.9651336669921875)
Step... (3700 | Loss: 0.5434038043022156, Learning Rate: 9.353736822959036e-05, Gradient Norm: 2.9204046726226807)
Step... (3725 | Loss: 0.27952197194099426, Learning Rate: 9.348686580779031e-05, Gradient Norm: 0.7691179513931274)
Step... (3750 | Loss: 0.3095962703227997, Learning Rate: 9.343635611003265e-05, Gradient Norm: 2.2058932781219482)
Step... (3775 | Loss: 0.1995597928762436, Learning Rate: 9.33858536882326e-05, Gradient Norm: 0.8971036672592163)
Step... (3800 | Loss: 0.34432289004325867, Learning Rate: 9.333535126643255e-05, Gradient Norm: 1.8248614072799683)
Step... (3825 | Loss: 0.161933034658432, Learning Rate: 9.328484156867489e-05, Gradient Norm: 0.8913214802742004)
Step... (3850 | Loss: 0.48422130942344666, Learning Rate: 9.323433914687485e-05, Gradient Norm: 2.36710786819458)
Step... (3875 | Loss: 0.1452736109495163, Learning Rate: 9.31838367250748e-05, Gradient Norm: 0.7015187740325928)
Step... (3900 | Loss: 0.2193298637866974, Learning Rate: 9.313332702731714e-05, Gradient Norm: 3.1408514976501465)
Step... (3925 | Loss: 0.18061470985412598, Learning Rate: 9.308282460551709e-05, Gradient Norm: 0.8163980841636658)
Step... (3950 | Loss: 0.3065214455127716, Learning Rate: 9.303232218371704e-05, Gradient Norm: 1.8904063701629639)
Step... (3975 | Loss: 0.19286735355854034, Learning Rate: 9.298181248595938e-05, Gradient Norm: 0.881371796131134)
Step... (4000 | Loss: 0.4714643061161041, Learning Rate: 9.293131006415933e-05, Gradient Norm: 2.6554923057556152)
Step... (4025 | Loss: 0.16019541025161743, Learning Rate: 9.288080764235929e-05, Gradient Norm: 0.6578169465065002)
Step... (4050 | Loss: 0.22082118690013885, Learning Rate: 9.283029794460163e-05, Gradient Norm: 1.89882230758667)
Step... (4075 | Loss: 0.1884138137102127, Learning Rate: 9.277979552280158e-05, Gradient Norm: 0.9139928221702576)
Step... (4100 | Loss: 0.42762622237205505, Learning Rate: 9.272929310100153e-05, Gradient Norm: 2.6736512184143066)
Step... (4125 | Loss: 0.17467239499092102, Learning Rate: 9.267878340324387e-05, Gradient Norm: 0.8163294792175293)
Step... (4150 | Loss: 0.3444820046424866, Learning Rate: 9.262828098144382e-05, Gradient Norm: 2.422739267349243)
Step... (4175 | Loss: 0.2216452807188034, Learning Rate: 9.257777855964378e-05, Gradient Norm: 0.821262001991272)
Step... (4200 | Loss: 0.4990865886211395, Learning Rate: 9.252726886188611e-05, Gradient Norm: 3.928582191467285)
Step... (4225 | Loss: 0.16187021136283875, Learning Rate: 9.247676644008607e-05, Gradient Norm: 0.8650521636009216)
Step... (4250 | Loss: 0.387790322303772, Learning Rate: 9.242626401828602e-05, Gradient Norm: 3.022831916809082)
Step... (4275 | Loss: 0.12846288084983826, Learning Rate: 9.237575432052836e-05, Gradient Norm: 0.6351085305213928)
Step... (4300 | Loss: 0.4357670843601227, Learning Rate: 9.232525189872831e-05, Gradient Norm: 2.2598319053649902)
Step... (4325 | Loss: 0.21033520996570587, Learning Rate: 9.227474947692826e-05, Gradient Norm: 0.8088125586509705)
Step... (4350 | Loss: 0.21965111792087555, Learning Rate: 9.22242397791706e-05, Gradient Norm: 1.4868179559707642)
Step... (4375 | Loss: 0.19302572309970856, Learning Rate: 9.217373735737056e-05, Gradient Norm: 1.0112311840057373)
Training...: 40% 1757/4393 [2:14:52<4:17:22, 5.86s/it]
Step... (4400 | Loss: 0.15398074686527252, Learning Rate: 9.212323493557051e-05, Gradient Norm: 0.7512454986572266)
Step... (4425 | Loss: 0.15918999910354614, Learning Rate: 9.207272523781285e-05, Gradient Norm: 0.7611452341079712)
Step... (4450 | Loss: 0.13304436206817627, Learning Rate: 9.20222228160128e-05, Gradient Norm: 0.6306468844413757)
Step... (4475 | Loss: 0.16709645092487335, Learning Rate: 9.197172039421275e-05, Gradient Norm: 0.8694853186607361)
Step... (4500 | Loss: 0.1157068982720375, Learning Rate: 9.192121069645509e-05, Gradient Norm: 0.5198965072631836)
Step... (4525 | Loss: 0.11202254891395569, Learning Rate: 9.187070827465504e-05, Gradient Norm: 0.776484489440918)
Step... (4550 | Loss: 0.14966146647930145, Learning Rate: 9.182019857689738e-05, Gradient Norm: 0.6543687582015991)
Step... (4575 | Loss: 0.13391637802124023, Learning Rate: 9.176969615509734e-05, Gradient Norm: 0.857903778553009)
Step... (4600 | Loss: 0.1459067016839981, Learning Rate: 9.171918645733967e-05, Gradient Norm: 0.604927659034729)
Step... (4625 | Loss: 0.12901170551776886, Learning Rate: 9.166868403553963e-05, Gradient Norm: 0.6328181028366089)
Step... (4650 | Loss: 0.1495911329984665, Learning Rate: 9.161817433778197e-05, Gradient Norm: 0.8638491034507751)
Step... (4675 | Loss: 0.13798420131206512, Learning Rate: 9.156767191598192e-05, Gradient Norm: 0.7175107002258301)
Step... (4700 | Loss: 0.1002354696393013, Learning Rate: 9.151716949418187e-05, Gradient Norm: 0.7489829659461975)
Step... (4725 | Loss: 0.19791680574417114, Learning Rate: 9.146665979642421e-05, Gradient Norm: 1.143479347229004)
Step... (4750 | Loss: 0.13573862612247467, Learning Rate: 9.141615737462416e-05, Gradient Norm: 1.257540225982666)
Step... (4775 | Loss: 0.12172726541757584, Learning Rate: 9.136565495282412e-05, Gradient Norm: 0.5807085633277893)
Step... (4800 | Loss: 0.10705988109111786, Learning Rate: 9.131514525506645e-05, Gradient Norm: 0.6314404606819153)
Step... (4825 | Loss: 0.13548623025417328, Learning Rate: 9.126464283326641e-05, Gradient Norm: 0.9735293388366699)
Step... (4850 | Loss: 0.12960033118724823, Learning Rate: 9.121414041146636e-05, Gradient Norm: 0.7907338738441467)
Step... (4875 | Loss: 0.07644310593605042, Learning Rate: 9.11636307137087e-05, Gradient Norm: 0.47419390082359314)
Step... (4900 | Loss: 0.09943026304244995, Learning Rate: 9.111312829190865e-05, Gradient Norm: 0.617705225944519)
Step... (4925 | Loss: 0.11053074896335602, Learning Rate: 9.10626258701086e-05, Gradient Norm: 0.6070181727409363)
Step... (4950 | Loss: 0.1151307001709938, Learning Rate: 9.101211617235094e-05, Gradient Norm: 1.1860915422439575)
Step... (4975 | Loss: 0.17635716497898102, Learning Rate: 9.09616137505509e-05, Gradient Norm: 1.0237845182418823)
Step... (5000 | Loss: 0.15253056585788727, Learning Rate: 9.091111132875085e-05, Gradient Norm: 0.6471536755561829)
Step... (5025 | Loss: 0.15132343769073486, Learning Rate: 9.086060163099319e-05, Gradient Norm: 0.7248958349227905)
Step... (5050 | Loss: 0.14301328361034393, Learning Rate: 9.081009920919314e-05, Gradient Norm: 0.64738929271698)
Step... (5075 | Loss: 0.1362609565258026, Learning Rate: 9.075959678739309e-05, Gradient Norm: 0.6385900378227234)
Step... (5100 | Loss: 0.1494288295507431, Learning Rate: 9.070908708963543e-05, Gradient Norm: 0.8640470504760742)
Step... (5125 | Loss: 0.13826008141040802, Learning Rate: 9.065858466783538e-05, Gradient Norm: 0.7304973006248474)
Step... (5150 | Loss: 0.1466275453567505, Learning Rate: 9.060808224603534e-05, Gradient Norm: 0.6393731832504272)
Step... (5175 | Loss: 0.14715760946273804, Learning Rate: 9.055757254827768e-05, Gradient Norm: 0.6790907382965088)
Step... (5200 | Loss: 0.13426756858825684, Learning Rate: 9.050707012647763e-05, Gradient Norm: 0.9185851812362671)
Step... (5225 | Loss: 0.12547633051872253, Learning Rate: 9.045656042871997e-05, Gradient Norm: 0.6549924612045288)
Step... (5250 | Loss: 0.1391286998987198, Learning Rate: 9.040605800691992e-05, Gradient Norm: 0.6910219192504883)
Step... (5275 | Loss: 0.10307434946298599, Learning Rate: 9.035555558511987e-05, Gradient Norm: 0.7248987555503845)
Step... (5300 | Loss: 0.15294413268566132, Learning Rate: 9.030504588736221e-05, Gradient Norm: 0.7696532607078552)
Step... (5325 | Loss: 0.1306358128786087, Learning Rate: 9.025454346556216e-05, Gradient Norm: 0.6660736203193665)
Step... (5350 | Loss: 0.17272095382213593, Learning Rate: 9.020404104376212e-05, Gradient Norm: 0.6890228390693665)
Step... (5375 | Loss: 0.13731196522712708, Learning Rate: 9.015353134600446e-05, Gradient Norm: 1.0845615863800049)
Step... (5400 | Loss: 0.10166216641664505, Learning Rate: 9.010302892420441e-05, Gradient Norm: 1.6075739860534668)
Step... (5425 | Loss: 0.13504694402217865, Learning Rate: 9.005252650240436e-05, Gradient Norm: 1.054123878479004)
Step... (5450 | Loss: 0.09731146693229675, Learning Rate: 9.00020168046467e-05, Gradient Norm: 0.4854496121406555)
Step... (5475 | Loss: 0.12537512183189392, Learning Rate: 8.995151438284665e-05, Gradient Norm: 0.9753578305244446)
Step... (5500 | Loss: 0.12291703373193741, Learning Rate: 8.99010119610466e-05, Gradient Norm: 0.6144852638244629)
Step... (5525 | Loss: 0.14224418997764587, Learning Rate: 8.985050226328894e-05, Gradient Norm: 0.7358378767967224)
Step... (5550 | Loss: 0.13874179124832153, Learning Rate: 8.97999998414889e-05, Gradient Norm: 0.7257264256477356)
Step... (5575 | Loss: 0.13608331978321075, Learning Rate: 8.974949741968885e-05, Gradient Norm: 3.697967290878296)
Step... (5600 | Loss: 0.13917957246303558, Learning Rate: 8.969898772193119e-05, Gradient Norm: 0.6503622531890869)
Step... (5625 | Loss: 0.135927215218544, Learning Rate: 8.964848530013114e-05, Gradient Norm: 0.8281283974647522)
Step... (5650 | Loss: 0.13524308800697327, Learning Rate: 8.95979828783311e-05, Gradient Norm: 0.594698429107666)
Step... (5675 | Loss: 0.13611824810504913, Learning Rate: 8.954747318057343e-05, Gradient Norm: 0.5706479549407959)
Step... (5700 | Loss: 0.13887955248355865, Learning Rate: 8.949697075877339e-05, Gradient Norm: 0.8225830793380737)
Step... (5725 | Loss: 0.12447820603847504, Learning Rate: 8.944646833697334e-05, Gradient Norm: 0.6664973497390747)
Step... (5750 | Loss: 0.17118078470230103, Learning Rate: 8.939595863921568e-05, Gradient Norm: 1.3190544843673706)
Step... (5775 | Loss: 0.15702074766159058, Learning Rate: 8.934545621741563e-05, Gradient Norm: 1.006943702697754)
Step... (5800 | Loss: 0.15159830451011658, Learning Rate: 8.929494651965797e-05, Gradient Norm: 0.8170256614685059)
Step... (5825 | Loss: 0.11707980930805206, Learning Rate: 8.924444409785792e-05, Gradient Norm: 1.4009065628051758)
Step... (5850 | Loss: 0.1458413302898407, Learning Rate: 8.919393440010026e-05, Gradient Norm: 0.9540699124336243)
Step... (5875 | Loss: 0.11078252643346786, Learning Rate: 8.914343197830021e-05, Gradient Norm: 0.6304067373275757)
Step... (5900 | Loss: 0.14289362728595734, Learning Rate: 8.909292955650017e-05, Gradient Norm: 0.6668269634246826)
Step... (5925 | Loss: 0.10349103808403015, Learning Rate: 8.90424198587425e-05, Gradient Norm: 0.8122596740722656)
Step... (5950 | Loss: 0.11653509736061096, Learning Rate: 8.899191743694246e-05, Gradient Norm: 0.6827570199966431)
Step... (5975 | Loss: 0.13345487415790558, Learning Rate: 8.89414077391848e-05, Gradient Norm: 0.6637983322143555)
Step... (6000 | Loss: 0.10303163528442383, Learning Rate: 8.889090531738475e-05, Gradient Norm: 0.6207571625709534)
Step... (6025 | Loss: 0.12258953601121902, Learning Rate: 8.884039561962709e-05, Gradient Norm: 0.5680171251296997)
Step... (6050 | Loss: 0.09622679650783539, Learning Rate: 8.878989319782704e-05, Gradient Norm: 0.5885440707206726)
Step... (6075 | Loss: 0.1243104338645935, Learning Rate: 8.8739390776027e-05, Gradient Norm: 0.6282617449760437)
Step... (6100 | Loss: 0.16120916604995728, Learning Rate: 8.868888107826933e-05, Gradient Norm: 0.6353745460510254)
Training...: 80% 3506/4393 [4:30:01<1:19:46, 5.40s/it]
Step... (6150 | Loss: 0.10419484227895737, Learning Rate: 8.858787623466924e-05, Gradient Norm: 0.6518651843070984)
Step... (6175 | Loss: 0.11813586205244064, Learning Rate: 8.853736653691158e-05, Gradient Norm: 0.6677852272987366)
Step... (6200 | Loss: 0.14711715281009674, Learning Rate: 8.848686411511153e-05, Gradient Norm: 0.5828253030776978)
Step... (6225 | Loss: 0.1347367763519287, Learning Rate: 8.843636169331148e-05, Gradient Norm: 0.569177508354187)
Step... (6250 | Loss: 0.1528175175189972, Learning Rate: 8.838585199555382e-05, Gradient Norm: 0.7070837616920471)
Step... (6275 | Loss: 0.09668218344449997, Learning Rate: 8.833534957375377e-05, Gradient Norm: 0.7941017150878906)
Step... (6300 | Loss: 0.12110703438520432, Learning Rate: 8.828484715195373e-05, Gradient Norm: 0.8058463931083679)
Step... (6325 | Loss: 0.10088901966810226, Learning Rate: 8.823433745419607e-05, Gradient Norm: 0.48208481073379517)
Step... (6350 | Loss: 0.12540197372436523, Learning Rate: 8.818383503239602e-05, Gradient Norm: 0.510494589805603)
Step... (6375 | Loss: 0.13623422384262085, Learning Rate: 8.813333261059597e-05, Gradient Norm: 0.8078494071960449)
Step... (6400 | Loss: 0.10657237470149994, Learning Rate: 8.808282291283831e-05, Gradient Norm: 0.5409206748008728)
Step... (6425 | Loss: 0.14446492493152618, Learning Rate: 8.803232049103826e-05, Gradient Norm: 0.579534113407135)
Step... (6450 | Loss: 0.21205583214759827, Learning Rate: 8.798181806923822e-05, Gradient Norm: 0.851882815361023)
Step... (6475 | Loss: 0.12547239661216736, Learning Rate: 8.793130837148055e-05, Gradient Norm: 0.5652810335159302)
Step... (6500 | Loss: 0.14671802520751953, Learning Rate: 8.788080594968051e-05, Gradient Norm: 0.909660816192627)
Step... (6525 | Loss: 0.15064595639705658, Learning Rate: 8.783030352788046e-05, Gradient Norm: 0.7164238095283508)
Step... (6550 | Loss: 0.14910094439983368, Learning Rate: 8.77797938301228e-05, Gradient Norm: 0.7907290458679199)
Step... (6575 | Loss: 0.10597706586122513, Learning Rate: 8.772929140832275e-05, Gradient Norm: 0.637769341468811)
Step... (6600 | Loss: 0.1302764117717743, Learning Rate: 8.76787889865227e-05, Gradient Norm: 0.7593493461608887)
Step... (6625 | Loss: 0.10654870420694351, Learning Rate: 8.762827928876504e-05, Gradient Norm: 0.6169654130935669)
Step... (6650 | Loss: 0.10767922550439835, Learning Rate: 8.7577776866965e-05, Gradient Norm: 0.5802223682403564)
Step... (6675 | Loss: 0.12482491880655289, Learning Rate: 8.752727444516495e-05, Gradient Norm: 0.5568225383758545)
Step... (6700 | Loss: 0.116365447640419, Learning Rate: 8.747676474740729e-05, Gradient Norm: 0.5982862710952759)
Step... (6725 | Loss: 0.07662271708250046, Learning Rate: 8.742626232560724e-05, Gradient Norm: 0.6089320778846741)
Step... (6750 | Loss: 0.15134650468826294, Learning Rate: 8.737575990380719e-05, Gradient Norm: 0.7963054776191711)
Step... (6775 | Loss: 0.14350146055221558, Learning Rate: 8.732525020604953e-05, Gradient Norm: 0.6256259083747864)
Step... (6800 | Loss: 0.14528611302375793, Learning Rate: 8.727474778424948e-05, Gradient Norm: 0.6886290907859802)
Step... (6825 | Loss: 0.12023483961820602, Learning Rate: 8.722424536244944e-05, Gradient Norm: 0.6041069030761719)
Step... (6850 | Loss: 0.16417433321475983, Learning Rate: 8.717373566469178e-05, Gradient Norm: 0.9486941695213318)
Step... (6875 | Loss: 0.13004860281944275, Learning Rate: 8.712323324289173e-05, Gradient Norm: 0.6948692798614502)
Step... (6900 | Loss: 0.15117710828781128, Learning Rate: 8.707273082109168e-05, Gradient Norm: 0.8188594579696655)
Step... (6925 | Loss: 0.10278861224651337, Learning Rate: 8.702222112333402e-05, Gradient Norm: 0.5591959357261658)
Step... (6950 | Loss: 0.1164846420288086, Learning Rate: 8.697171870153397e-05, Gradient Norm: 0.878705620765686)
Step... (6975 | Loss: 0.11927605420351028, Learning Rate: 8.692121627973393e-05, Gradient Norm: 0.5429350733757019)
Step... (7000 | Loss: 0.10895289480686188, Learning Rate: 8.687070658197626e-05, Gradient Norm: 0.6643750071525574)
Step... (7025 | Loss: 0.08085189759731293, Learning Rate: 8.682020416017622e-05, Gradient Norm: 0.7205637693405151)
Step... (7050 | Loss: 0.12590201199054718, Learning Rate: 8.676969446241856e-05, Gradient Norm: 0.6310210227966309)
Step... (7075 | Loss: 0.09471984952688217, Learning Rate: 8.671919204061851e-05, Gradient Norm: 0.6076512336730957)
Step... (7100 | Loss: 0.11090608686208725, Learning Rate: 8.666868234286085e-05, Gradient Norm: 0.5288007855415344)
Step... (7125 | Loss: 0.11144256591796875, Learning Rate: 8.66181799210608e-05, Gradient Norm: 0.5928792357444763)
Step... (7150 | Loss: 0.15605677664279938, Learning Rate: 8.656767749926075e-05, Gradient Norm: 0.7154860496520996)
Step... (7175 | Loss: 0.1493319422006607, Learning Rate: 8.651716780150309e-05, Gradient Norm: 0.6326310038566589)
Step... (7200 | Loss: 0.12305010110139847, Learning Rate: 8.646666537970304e-05, Gradient Norm: 0.6803693771362305)
Step... (7225 | Loss: 0.15390878915786743, Learning Rate: 8.641615568194538e-05, Gradient Norm: 0.6524022817611694)
Step... (7250 | Loss: 0.10154560953378677, Learning Rate: 8.636565326014534e-05, Gradient Norm: 0.5600380897521973)
Step... (7275 | Loss: 0.11497648060321808, Learning Rate: 8.631515083834529e-05, Gradient Norm: 0.6140745878219604)
Step... (7300 | Loss: 0.0939219743013382, Learning Rate: 8.626464114058763e-05, Gradient Norm: 0.8148360252380371)
Step... (7325 | Loss: 0.13607870042324066, Learning Rate: 8.621413871878758e-05, Gradient Norm: 1.3159329891204834)
Step... (7350 | Loss: 0.11052178591489792, Learning Rate: 8.616363629698753e-05, Gradient Norm: 0.5292452573776245)
Step... (7375 | Loss: 0.16697251796722412, Learning Rate: 8.611312659922987e-05, Gradient Norm: 0.8497327566146851)
Step... (7400 | Loss: 0.21488416194915771, Learning Rate: 8.606262417742983e-05, Gradient Norm: 0.9127271771430969)
Step... (7425 | Loss: 0.10327339172363281, Learning Rate: 8.601211447967216e-05, Gradient Norm: 2.964388370513916)
Step... (7450 | Loss: 0.13917475938796997, Learning Rate: 8.596161205787212e-05, Gradient Norm: 0.6416407823562622)
Step... (7475 | Loss: 0.14600548148155212, Learning Rate: 8.591110963607207e-05, Gradient Norm: 0.6200647354125977)
Step... (7500 | Loss: 0.07706329971551895, Learning Rate: 8.586059993831441e-05, Gradient Norm: 0.4497665464878082)
Step... (7525 | Loss: 0.10475296527147293, Learning Rate: 8.581009751651436e-05, Gradient Norm: 0.6419413089752197)
Step... (7550 | Loss: 0.13535363972187042, Learning Rate: 8.575959509471431e-05, Gradient Norm: 0.7294527292251587)
Step... (7575 | Loss: 0.11786855757236481, Learning Rate: 8.570908539695665e-05, Gradient Norm: 0.7611076235771179)
Step... (7600 | Loss: 0.12319649010896683, Learning Rate: 8.56585829751566e-05, Gradient Norm: 0.7079668045043945)
Step... (7625 | Loss: 0.17103444039821625, Learning Rate: 8.560808055335656e-05, Gradient Norm: 0.6459333300590515)
Step... (7650 | Loss: 0.1777440309524536, Learning Rate: 8.55575708555989e-05, Gradient Norm: 1.1776365041732788)
Step... (7675 | Loss: 0.10296334326267242, Learning Rate: 8.550706843379885e-05, Gradient Norm: 0.5230156779289246)
Step... (7700 | Loss: 0.13559919595718384, Learning Rate: 8.545655873604119e-05, Gradient Norm: 0.6376197338104248)
Step... (7725 | Loss: 0.16620127856731415, Learning Rate: 8.540605631424114e-05, Gradient Norm: 0.8585113286972046)
Step... (7750 | Loss: 0.10593363642692566, Learning Rate: 8.53555538924411e-05, Gradient Norm: 0.4990275204181671)
Step... (7775 | Loss: 0.10826694965362549, Learning Rate: 8.530504419468343e-05, Gradient Norm: 0.6208857297897339)
Step... (7800 | Loss: 0.10973000526428223, Learning Rate: 8.525454177288339e-05, Gradient Norm: 0.6479125022888184)
Step... (7825 | Loss: 0.14727739989757538, Learning Rate: 8.520403935108334e-05, Gradient Norm: 0.7955598831176758)
Step... (7850 | Loss: 0.12908253073692322, Learning Rate: 8.515352965332568e-05, Gradient Norm: 0.6523045301437378)
Epoch ... (1/12): 17% 2/12 [11:29:30<57:16:35, 20619.59s/it]
Step... (7900 | Loss: 0.13779808580875397, Learning Rate: 8.505252480972558e-05, Gradient Norm: 0.7253138422966003)
Step... (7925 | Loss: 0.1015205830335617, Learning Rate: 8.500201511196792e-05, Gradient Norm: 0.5427443385124207)
Step... (7950 | Loss: 0.116566501557827, Learning Rate: 8.495151269016787e-05, Gradient Norm: 0.654643714427948)
Step... (7975 | Loss: 0.09882434457540512, Learning Rate: 8.490101026836783e-05, Gradient Norm: 0.5266492962837219)
Step... (8000 | Loss: 0.13470995426177979, Learning Rate: 8.485050057061017e-05, Gradient Norm: 0.8038025498390198)
Step... (8025 | Loss: 0.15494322776794434, Learning Rate: 8.479999814881012e-05, Gradient Norm: 0.6313944458961487)
Step... (8050 | Loss: 0.09077829122543335, Learning Rate: 8.474949572701007e-05, Gradient Norm: 0.44213569164276123)
Step... (8075 | Loss: 0.13829872012138367, Learning Rate: 8.469898602925241e-05, Gradient Norm: 0.640563428401947)
Step... (8100 | Loss: 0.13687734305858612, Learning Rate: 8.464848360745236e-05, Gradient Norm: 0.7211358547210693)
Step... (8125 | Loss: 0.10291437059640884, Learning Rate: 8.459798118565232e-05, Gradient Norm: 1.2116658687591553)
Step... (8150 | Loss: 0.10889369994401932, Learning Rate: 8.454747148789465e-05, Gradient Norm: 0.5986737012863159)
Step... (8175 | Loss: 0.11770819872617722, Learning Rate: 8.449696906609461e-05, Gradient Norm: 0.5724157691001892)
Step... (8200 | Loss: 0.12205436080694199, Learning Rate: 8.444646664429456e-05, Gradient Norm: 0.7700273394584656)
Step... (8225 | Loss: 0.1144673302769661, Learning Rate: 8.43959569465369e-05, Gradient Norm: 1.5292632579803467)
Step... (8250 | Loss: 0.14595015347003937, Learning Rate: 8.434545452473685e-05, Gradient Norm: 1.5444319248199463)
Step... (8275 | Loss: 0.13142108917236328, Learning Rate: 8.42949521029368e-05, Gradient Norm: 0.5596189498901367)
Step... (8300 | Loss: 0.11982933431863785, Learning Rate: 8.424444240517914e-05, Gradient Norm: 0.48929890990257263)
Step... (8325 | Loss: 0.198286235332489, Learning Rate: 8.41939399833791e-05, Gradient Norm: 1.0410857200622559)
Step... (8350 | Loss: 0.11471763253211975, Learning Rate: 8.414343028562143e-05, Gradient Norm: 0.6537178158760071)
Step... (8375 | Loss: 0.09390740841627121, Learning Rate: 8.409292786382139e-05, Gradient Norm: 0.5476303100585938)
Step... (8400 | Loss: 0.07339417189359665, Learning Rate: 8.404242544202134e-05, Gradient Norm: 0.4375392198562622)
Step... (8425 | Loss: 0.08735157549381256, Learning Rate: 8.399191574426368e-05, Gradient Norm: 0.5681422352790833)
Step... (8450 | Loss: 0.16116836667060852, Learning Rate: 8.394141332246363e-05, Gradient Norm: 0.65557461977005)
Step... (8475 | Loss: 0.2733764350414276, Learning Rate: 8.389090362470597e-05, Gradient Norm: 1.9184274673461914)
Step... (8500 | Loss: 0.08071233332157135, Learning Rate: 8.384040120290592e-05, Gradient Norm: 0.4783284664154053)
Step... (8525 | Loss: 0.14461499452590942, Learning Rate: 8.378989150514826e-05, Gradient Norm: 0.659214437007904)
Step... (8550 | Loss: 0.1168070062994957, Learning Rate: 8.373938908334821e-05, Gradient Norm: 0.6005895137786865)
Step... (8575 | Loss: 0.0821458175778389, Learning Rate: 8.368888666154817e-05, Gradient Norm: 0.4110701382160187)
Step... (8600 | Loss: 0.13768349587917328, Learning Rate: 8.36383769637905e-05, Gradient Norm: 0.5835753083229065)
Step... (8625 | Loss: 0.07855778932571411, Learning Rate: 8.358787454199046e-05, Gradient Norm: 0.4884222745895386)
Step... (8650 | Loss: 0.14811816811561584, Learning Rate: 8.353737212019041e-05, Gradient Norm: 0.6351867914199829)
Step... (8675 | Loss: 0.1343218982219696, Learning Rate: 8.348686242243275e-05, Gradient Norm: 0.5733943581581116)
Step... (8700 | Loss: 0.16386865079402924, Learning Rate: 8.34363600006327e-05, Gradient Norm: 0.553588330745697)
Step... (8725 | Loss: 0.08179602771997452, Learning Rate: 8.338585757883266e-05, Gradient Norm: 1.0848875045776367)
Step... (8750 | Loss: 0.12045776844024658, Learning Rate: 8.3335347881075e-05, Gradient Norm: 0.5341562628746033)
Step... (8775 | Loss: 0.09336375445127487, Learning Rate: 8.328484545927495e-05, Gradient Norm: 0.5085514783859253)
Training...: 28% 1213/4393 [1:34:21<5:03:02, 5.72s/it]
Step... (8800 | Loss: 0.0734051838517189, Learning Rate: 8.32343430374749e-05, Gradient Norm: 0.5301979780197144)
Step... (8825 | Loss: 0.0855112299323082, Learning Rate: 8.318383333971724e-05, Gradient Norm: 0.5376077890396118)
Step... (8850 | Loss: 0.10068231076002121, Learning Rate: 8.313333091791719e-05, Gradient Norm: 1.0139813423156738)
Step... (8875 | Loss: 0.06252378225326538, Learning Rate: 8.308282849611714e-05, Gradient Norm: 0.48730936646461487)
Step... (8900 | Loss: 0.08460403978824615, Learning Rate: 8.303231879835948e-05, Gradient Norm: 0.4558766186237335)
Step... (8925 | Loss: 0.09871721267700195, Learning Rate: 8.298181637655944e-05, Gradient Norm: 0.6362630128860474)
Step... (8950 | Loss: 0.11740045249462128, Learning Rate: 8.293131395475939e-05, Gradient Norm: 0.6272518038749695)
Step... (8975 | Loss: 0.13512758910655975, Learning Rate: 8.288080425700173e-05, Gradient Norm: 0.5893998146057129)
Step... (9000 | Loss: 0.05581226944923401, Learning Rate: 8.283030183520168e-05, Gradient Norm: 0.48278042674064636)
Step... (9025 | Loss: 0.08583104610443115, Learning Rate: 8.277979941340163e-05, Gradient Norm: 0.5676160454750061)
Step... (9050 | Loss: 0.08980081230401993, Learning Rate: 8.272928971564397e-05, Gradient Norm: 0.5202922821044922)
Step... (9075 | Loss: 0.12204743176698685, Learning Rate: 8.267878729384392e-05, Gradient Norm: 0.6033841967582703)
Step... (9100 | Loss: 0.07934217900037766, Learning Rate: 8.262828487204388e-05, Gradient Norm: 0.5386624932289124)
Step... (9125 | Loss: 0.09263342618942261, Learning Rate: 8.257777517428622e-05, Gradient Norm: 0.49038422107696533)
Step... (9150 | Loss: 0.09809651970863342, Learning Rate: 8.252727275248617e-05, Gradient Norm: 0.5518418550491333)
Step... (9175 | Loss: 0.08282437175512314, Learning Rate: 8.247677033068612e-05, Gradient Norm: 0.6772879958152771)
Step... (9200 | Loss: 0.09388335794210434, Learning Rate: 8.242626063292846e-05, Gradient Norm: 0.8972507119178772)
Step... (9225 | Loss: 0.11405259370803833, Learning Rate: 8.237575821112841e-05, Gradient Norm: 0.5674875974655151)
Step... (9250 | Loss: 0.09774433821439743, Learning Rate: 8.232525578932837e-05, Gradient Norm: 0.9340572357177734)
Step... (9275 | Loss: 0.07160411030054092, Learning Rate: 8.22747460915707e-05, Gradient Norm: 0.7407450675964355)
Step... (9300 | Loss: 0.09505634754896164, Learning Rate: 8.222424366977066e-05, Gradient Norm: 0.6917828917503357)
Step... (9325 | Loss: 0.12817955017089844, Learning Rate: 8.217374124797061e-05, Gradient Norm: 1.1176568269729614)
Step... (9350 | Loss: 0.11201604455709457, Learning Rate: 8.212323155021295e-05, Gradient Norm: 0.725493848323822)
Step... (9375 | Loss: 0.16065749526023865, Learning Rate: 8.20727291284129e-05, Gradient Norm: 0.9583228230476379)
Step... (9400 | Loss: 0.076621875166893, Learning Rate: 8.202222670661286e-05, Gradient Norm: 0.5754097700119019)
Step... (9425 | Loss: 0.10213085263967514, Learning Rate: 8.19717170088552e-05, Gradient Norm: 0.740641713142395)
Step... (9450 | Loss: 0.09721554070711136, Learning Rate: 8.192121458705515e-05, Gradient Norm: 0.5298651456832886)
Step... (9475 | Loss: 0.08753546327352524, Learning Rate: 8.18707121652551e-05, Gradient Norm: 0.508285403251648)
Step... (9500 | Loss: 0.07296166568994522, Learning Rate: 8.182020246749744e-05, Gradient Norm: 0.8791430592536926)
Step... (9525 | Loss: 0.1059366762638092, Learning Rate: 8.176970004569739e-05, Gradient Norm: 0.9860773086547852)
Step... (9550 | Loss: 0.09692548960447311, Learning Rate: 8.171919034793973e-05, Gradient Norm: 0.4978175759315491)
Step... (9575 | Loss: 0.11797724664211273, Learning Rate: 8.166868792613968e-05, Gradient Norm: 0.6675213575363159)
Step... (9600 | Loss: 0.13216069340705872, Learning Rate: 8.161817822838202e-05, Gradient Norm: 19.723308563232422)
Step... (9625 | Loss: 0.08851467072963715, Learning Rate: 8.156767580658197e-05, Gradient Norm: 0.5197475552558899)
Step... (9650 | Loss: 0.08510927855968475, Learning Rate: 8.151717338478193e-05, Gradient Norm: 0.6016721725463867)
Step... (9675 | Loss: 0.07621399313211441, Learning Rate: 8.146666368702427e-05, Gradient Norm: 0.5809992551803589)
Step... (9700 | Loss: 0.10198916494846344, Learning Rate: 8.141616126522422e-05, Gradient Norm: 0.552069902420044)
Step... (9725 | Loss: 0.12253397703170776, Learning Rate: 8.136565156746656e-05, Gradient Norm: 0.5788819789886475)
Step... (9750 | Loss: 0.07940845936536789, Learning Rate: 8.131514914566651e-05, Gradient Norm: 0.5337548851966858)
Step... (9775 | Loss: 0.07598816603422165, Learning Rate: 8.126463944790885e-05, Gradient Norm: 0.5351502299308777)
Step... (9800 | Loss: 0.07974324375391006, Learning Rate: 8.12141370261088e-05, Gradient Norm: 0.4442839026451111)
Step... (9825 | Loss: 0.07073494046926498, Learning Rate: 8.116362732835114e-05, Gradient Norm: 0.517856776714325)
Step... (9850 | Loss: 0.08271614462137222, Learning Rate: 8.111312490655109e-05, Gradient Norm: 0.7158600687980652)
Step... (9875 | Loss: 0.12278185784816742, Learning Rate: 8.106262248475105e-05, Gradient Norm: 0.8880561590194702)
Step... (9900 | Loss: 0.10471830517053604, Learning Rate: 8.101211278699338e-05, Gradient Norm: 1.0475046634674072)
Step... (9925 | Loss: 0.07849471271038055, Learning Rate: 8.096161036519334e-05, Gradient Norm: 0.5359142422676086)
Step... (9950 | Loss: 0.0894748792052269, Learning Rate: 8.091110794339329e-05, Gradient Norm: 0.5352652072906494)
Step... (9975 | Loss: 0.1043805480003357, Learning Rate: 8.086059824563563e-05, Gradient Norm: 0.7022120952606201)
Training...: 28% 1213/4393 [1:34:27<5:03:02, 5.72s/it]
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:291: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
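The tree_flatten, tree_leaves and tree_unflatten warnings above are raised inside the installed flax package (core/lift.py, core/axes_scan.py, linen/transforms.py), not in the training script. Editing the script therefore cannot remove them. A flax release that calls the jax.tree_util names should stop them, but the exact version was not checked here. The sketch below only shows the one-to-one renames the warnings ask for. The small pytree is hypothetical and stands in for the scope and aval trees flax passes around.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree: a dict with one array and a tuple of two arrays.
tree = {"a": jnp.arange(3), "b": (jnp.zeros(2), jnp.ones(4))}

# Deprecated form: jax.tree_leaves(tree)
leaves = jax.tree_util.tree_leaves(tree)

# Deprecated form: jax.tree_flatten(tree)
flat, treedef = jax.tree_util.tree_flatten(tree)

# Deprecated form: jax.tree_unflatten(treedef, flat)
rebuilt = jax.tree_util.tree_unflatten(treedef, flat)

# Leaves come out in sorted-key order for dicts and positional order for tuples.
print([x.shape for x in leaves])  # [(3,), (2,), (4,)]
```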
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
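The warnings at lines 336 and 1308 differ from the flax ones because they come from run_flax_speech_recognition_seq2seq.py itself, so they can be fixed at the source. jax.tree_util.tree_map takes the same (function, tree) arguments as the deprecated jax.tree_map. The sketch below rewrites both script lines that way. The wrapper names `to_fp32` and `normalize_loss` and the example inputs are hypothetical and do not come from the script.

```python
import jax
import jax.numpy as jnp


def to_fp32(t):
    # Mirrors line 336: upcast bfloat16 leaves to float32, leave other dtypes alone.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
    )


def normalize_loss(loss, total_samples):
    # Mirrors line 1308: divide every leaf of the summed loss by the sample count.
    return jax.tree_util.tree_map(lambda l: l / total_samples, loss)


params = {"w": jnp.ones((2, 2), dtype=jnp.bfloat16), "step": jnp.array(3, dtype=jnp.int32)}
fp32_params = to_fp32(params)

loss = {"ce": jnp.array(8.0), "aux": jnp.array(2.0)}
avg_loss = normalize_loss(loss, 4)
print(fp32_params["w"].dtype, float(avg_loss["ce"]))  # float32 2.0
```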
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:291: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
batch_sizes |= {t.shape[0] for t in jax.tree_leaves(a)}
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:312: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(pad, tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
run_flax_speech_recognition_seq2seq.py:1308: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
loss = jax.tree_map(lambda l: l / total_samples, loss)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:112: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
scopes, treedef = jax.tree_flatten(scope_tree)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/lift.py:729: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
lengths = set(jax.tree_leaves(lengths))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:134: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
in_avals, in_tree = jax.tree_flatten(input_avals)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/linen/transforms.py:249: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
jax.tree_leaves(tree)))
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/core/axes_scan.py:146: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1127: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
cache = jax.tree_map(
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1038: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(gather_fn, nested)
/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/models/modeling_flax_speech_encoder_decoder.py:1213: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/jax_utils.py:321: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return out if static_return else jax.tree_map(unpad, out)
Evaluating ...: 100% 85/85 [23:11<00:00, 16.37s/it]
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/training/common_utils.py:51: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
device_metrics = jax.tree_map(lambda x: x[0], device_metrics)
/home/sanchitgandhi/hf/lib/python3.8/site-packages/flax/training/common_utils.py:45: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(stack_args, *forest)
run_flax_speech_recognition_seq2seq.py:1392: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
run_flax_speech_recognition_seq2seq.py:336: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
Epoch ... (1/12): 17% 2/12 [13:27:12<57:16:35, 20619.59s/it]
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
Configuration saved in /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/config.json
tcmalloc: large alloc 1226489856 bytes == 0x35593c000 @ 0x7f7cba873680 0x7f7cba893bdd 0x7f7b690721ff 0x7f7b6908142c 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6907c164 0x7f7b6907c91e 0x505166 0x56bbfa 0x569dba 0x5f6eb3 0x56cc1f 0x569dba 0x5f6eb3 0x56cc1f 0x5f6cd6 0x56bacd 0x569dba 0x50bca0 0x56cc1f 0x569dba 0x5f6eb3 0x56bacd 0x569dba 0x5f6eb3
tcmalloc: large alloc 2586787840 bytes == 0x3b7a32000 @ 0x7f7cba873680 0x7f7cba893bdd 0x7f7b690721ff 0x7f7b6908142c 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6908241d 0x7f7b6907c164 0x7f7b6907c91e 0x505166 0x56bbfa 0x569dba 0x5f6eb3 0x56cc1f 0x569dba 0x5f6eb3 0x56cc1f 0x5f6cd6 0x56bacd 0x569dba 0x50bca0 0x56cc1f 0x569dba 0x5f6eb3 0x56bacd 0x569dba 0x5f6eb3
tcmalloc: large alloc 2353618944 bytes == 0x452526000 @ 0x7f7cba873680 0x7f7cba894824 0x5fb391 0x7f7b6907c209 0x7f7b6907c91e 0x505166 0x56bbfa 0x569dba 0x5f6eb3 0x56cc1f 0x569dba 0x5f6eb3 0x56cc1f 0x5f6cd6 0x56bacd 0x569dba 0x50bca0 0x56cc1f 0x569dba 0x5f6eb3 0x56bacd 0x569dba 0x5f6eb3 0x56bacd 0x569dba 0x6902a7 0x67f951 0x67f9cf 0x67fa71 0x681b97 0x6b9d32
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Model weights saved in /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/flax_model.msgpack
tokenizer config file saved in ./tokenizer_config.json
Special tokens file saved in ./special_tokens_map.json
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible