/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3132: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the ///google/
  lax._check_user_dtype_supported(dtype, "zeros")
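The first warning comes from JAX's default 32-bit mode: a requested `int64` silently becomes `int32`. The warning's own instructions are cut off in this log. JAX's documented switch for 64-bit types is the `jax_enable_x64` option, which can also be set through the `JAX_ENABLE_X64` environment variable. Below is a minimal sketch, assuming a recent JAX release where `jax.config.update` is available. It is not taken from `run_mlm_flax.py`, and whether the script actually needs 64-bit indices is an assumption.

```python
import jax

# Enable 64-bit dtypes. This must run at startup, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With x64 enabled, the requested int64 is kept instead of being truncated to int32.
x = jnp.zeros(4, dtype=jnp.int64)
print(x.dtype)
```

Leaving x64 off is usually fine for token IDs and batch indices that fit in `int32`, and in that case the warning is harmless.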

/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(

/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id ; please update your code.
  warnings.warn(
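The two `xla_bridge.py` warnings are deprecation notices, not errors. The log shows that `jax.host_count` was renamed to `jax.process_count`. The second message is truncated here, but in current JAX the replacement for `jax.host_id` is `jax.process_index`. The log does not show which code makes these calls; it could be the training script or a library it imports. This is a minimal sketch of the current names. The `is_main_process` check is a hypothetical illustration of a common multi-host pattern, not code from the script.

```python
import jax

# Current names for the deprecated aliases flagged in the warnings.
num_processes = jax.process_count()  # formerly jax.host_count()
this_process = jax.process_index()   # formerly jax.host_id()

# Hypothetical usage: in multi-host training, often only process 0 logs or saves checkpoints.
is_main_process = this_process == 0
print(f"process {this_process} of {num_processes}, main={is_main_process}")
```

On a single machine this reports one process with index 0.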

Epoch ... (1/5): 0%| | 0/5 [00:00<?, ?it/s]
Training...: 0%|▏ | 50/25197 [01:48<3:11:36, 2.19it/s]
Epoch ... (1/5): 0%| | 0/5 [02:45<?, ?it/s]

Traceback (most recent call last):
  File "./run_mlm_flax.py", line 810, in <module>
    samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]