/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3132: UserWarning: Explicitly requested dtype requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax._check_user_dtype_supported(dtype, "zeros")
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/5): 0%| | 0/5 [00:00
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1669, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 637, in xla_pmap_impl
    return compiled_fun(*args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1159, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Resource exhausted: Attempting to reserve 12.60G at the bottom of memory. That was not possible. There are 5.86G free, 0B reserved, and 5.65G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "./run_mlm_flax.py", line 804, in
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1159, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: Resource exhausted: Attempting to reserve 12.60G at the bottom of memory. That was not possible. There are 5.86G free, 0B reserved, and 5.65G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
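The three UserWarnings at the top are not what stops the run (that is the "Resource exhausted" error), but each one names its own fix. A minimal sketch of both fixes follows. It assumes a JAX version recent enough to provide the renamed functions, which the warnings themselves imply, and it assumes the deprecated `jax.host_count` / `jax.host_id` calls come from your own script rather than from a library it imports. Enabling x64 is optional. With it on, default float dtypes also become 64-bit, which can raise device memory use, so it is worth weighing against the out-of-memory error in this log.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit dtypes. Set this at startup, before any JAX arrays are
# created; exporting JAX_ENABLE_X64=1 in the shell has the same effect.
jax.config.update("jax_enable_x64", True)

# With x64 enabled, an explicit int64 request is honored instead of being
# truncated to int32 (the first UserWarning in the log).
x = jnp.zeros((2, 3), dtype=jnp.int64)
print(x.dtype)  # int64

# Current names for the deprecated aliases flagged in the log:
#   jax.host_count() -> jax.process_count()
#   jax.host_id()    -> jax.process_index()
n_processes = jax.process_count()
process_id = jax.process_index()
print(f"process {process_id} of {n_processes}")
```

On a single host both calls behave like the old aliases did: `process_count()` returns the number of participating processes and `process_index()` returns this process's position among them.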