python3 -c "import jax; print(jax.devices())" ./run_mlm_flax_stream.py \ --output_dir="${MODEL_DIR}" \ --model_type="roberta" \ --config_name="${MODEL_DIR}" \ --tokenizer_name="${MODEL_DIR}" \ --dataset_name="mc4" \ --dataset_config_name="hi" \ --max_seq_length="256" \ --per_device_train_batch_size="128" \ --per_device_eval_batch_size="128" \ --learning_rate="3e-4" \ --warmup_steps="10000" \ --overwrite_output_dir \ --adam_beta1="0.9" \ --adam_beta2="0.98" \ --num_train_steps="10000" \ --num_eval_samples="5000" \ --preprocessing_num_workers="90" \ --logging_steps="250" \ --eval_steps="1000"