dat
Saving weights and logs at step 1252
f291f93
/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3114: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax._check_user_dtype_supported(dtype, "zeros")
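The dtype warning above says 64-bit types are disabled by default. Here is a minimal sketch of the configuration route it names (`jax_enable_x64`), assuming 64-bit integers are actually wanted and that the flag can be set once at startup, before any arrays are created. Where this would go in `run_mlm_flax.py` is an assumption; exporting `JAX_ENABLE_X64=1` in the shell is the equivalent environment-variable route.

```python
import jax
import jax.numpy as jnp

# Turn on 64-bit types before creating any arrays. Without this, JAX
# truncates int64/float64 requests to int32/float32 and emits the
# UserWarning seen in the log.
jax.config.update("jax_enable_x64", True)

# With x64 enabled, the requested int64 dtype is kept as-is.
counts = jnp.zeros(4, dtype=jnp.int64)
print(counts.dtype)  # int64
```

Enabling x64 also makes float64 the default floating-point dtype, which can increase memory use. If the int64 request is incidental, asking for `jnp.int32` explicitly silences the warning without changing global behaviour.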
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
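The two deprecation warnings above name the replacements for `jax.host_count` and `jax.host_id`. The log does not show whether the old names are called by the training script or by a library it imports, so the sketch below only illustrates the renamed calls. The "only process 0 is the main process" check is a common multi-host convention included for illustration, not something taken from this run.

```python
import jax

# jax.process_count() replaces the deprecated jax.host_count(),
# and jax.process_index() replaces the deprecated jax.host_id().
num_processes = jax.process_count()
this_process = jax.process_index()

# Illustrative convention: treat process 0 as the one that writes
# checkpoints and logs in a multi-host run.
is_main_process = this_process == 0
print(f"process {this_process} of {num_processes}, main={is_main_process}")
```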
Epoch ... (1/5): 0%| | 0/5 [00:00<?, ?it/s]
Training...: 1%|▏ | 502/92767 [07:26<137:02:44, 5.35s/it]
Training...: 1%|▍ | 1006/92767 [13:17<34:59:10, 1.37s/it]
Training...: 2%|▋ | 1500/92767 [18:44<17:46:50, 1.43it/s]
Training...: 2%|▊ | 2000/92767 [24:35<17:40:47, 1.43it/s]
Training...: 3%|█ | 2502/92767 [30:50<133:42:23, 5.33s/it]
Training...: 3%|█▎ | 3000/92767 [36:17<17:29:26, 1.43it/s]
Epoch ... (1/5): 0%| | 0/5 [42:04<?, ?it/s]
Traceback (most recent call last):
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1647, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 637, in xla_pmap_impl
    return compiled_fun(*args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1152, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./run_mlm_flax.py", line 712, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)