[21:35:52] - INFO - absl - A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3132: UserWarning: Explicitly requested dtype requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax._check_user_dtype_supported(dtype, "zeros")
[21:35:52] - INFO - __main__ - RESTORING CHECKPOINT FROM ./...
tcmalloc: large alloc 1530273792 bytes == 0xd8c8e000 @ 0x7f41b5469680 0x7f41b548a824 0x5f7b11 0x648631 0x5c38e6 0x4f30e6 0x64ee88 0x505653 0x56acb6 0x568d9a 0x5f5b33 0x56aadf 0x568d9a 0x68cdc7 0x67e161 0x67e1df 0x67e281 0x67e627 0x6b6e62 0x6b71ed 0x7f41b527e0b3 0x5f96de
restoring state of multisteps optimizer
[21:35:55] - INFO - __main__ - checkpoint restored
Traceback (most recent call last):
  File "./run_mlm_flax.py", line 712, in <module>
    state = restore_model_checkpoint(training_args.resume_from_checkpoint, state)
  File "./run_mlm_flax.py", line 314, in restore_model_checkpoint
    inner_opt_state = reinstantiate_states(opt_state.inner_opt_state)
  File "./run_mlm_flax.py", line 294, in reinstantiate_states
    cls = getattr(optax, type(state).__name__)
AttributeError: module 'optax' has no attribute 'list'
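
The final frame of the traceback suggests that `reinstantiate_states` received a plain Python `list` and tried to resolve the name `list` as an attribute of `optax`. The sketch below reproduces that pattern without optax installed and shows one possible guard. It is an illustration, not the actual code from `run_mlm_flax.py`: `fake_optax`, the stand-in `ScaleByAdamState` class, `lookup_state_class`, and the body of `reinstantiate_states` are all assumptions made for the example.

```python
from types import SimpleNamespace
from typing import Any, NamedTuple


class ScaleByAdamState(NamedTuple):
    """Hypothetical stand-in for an optimizer state class (not the real optax class)."""
    count: int
    mu: Any
    nu: Any


# Hypothetical stand-in for the `optax` namespace, so the sketch runs without optax.
fake_optax = SimpleNamespace(ScaleByAdamState=ScaleByAdamState)


def lookup_state_class(state, namespace=fake_optax):
    """Mirror of the failing pattern: resolve a class from the state's type name."""
    return getattr(namespace, type(state).__name__)


def reinstantiate_states(state, namespace=fake_optax):
    """Rebuild named-tuple states by class name, recursing into plain lists and tuples."""
    is_namedtuple = isinstance(state, tuple) and hasattr(state, "_fields")
    if isinstance(state, (list, tuple)) and not is_namedtuple:
        # Plain containers have no counterpart in the namespace, so recurse into
        # their elements instead of looking the container type up by name.
        return type(state)(reinstantiate_states(s, namespace) for s in state)
    cls = getattr(namespace, type(state).__name__, None)
    if cls is None or not is_namedtuple:
        # Unknown or non-namedtuple leaf: return it unchanged.
        return state
    return cls(**{field: getattr(state, field) for field in state._fields})


# A restored inner state that contains a plain (empty) list next to a real state.
restored = [ScaleByAdamState(count=3, mu=[0.1], nu=[0.2]), []]

error_message = None
try:
    lookup_state_class(restored)  # type name is "list", which the namespace lacks
except AttributeError as err:
    error_message = str(err)

rebuilt = reinstantiate_states(restored)
```

Whether recursing into containers is the right fix depends on how the checkpoint was serialized and which optax version produced it, so treat the guard as a starting point for debugging rather than a confirmed solution.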