/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/3): 0%| | 0/3 [00:00<?, ?it/s]
Training...: 0%| | 0/503952 [01:49<?, ?it/s]
Epoch ... (1/3): 0%| | 0/3 [09:38<?, ?it/s]
Traceback (most recent call last):
  File "./run_mlm_flax_no_accum.py", line 673, in <module>
    train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1669, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 637, in xla_pmap_impl
    return compiled_fun(*args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1159, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Resource exhausted: Attempting to reserve 13.42G at the bottom of memory. That was not possible. There are 7.27G free, 0B reserved, and 7.22G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "./run_mlm_flax_no_accum.py", line 673, in <module>
    train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1159, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: Resource exhausted: Attempting to reserve 13.42G at the bottom of memory. That was not possible. There are 7.27G free, 0B reserved, and 7.22G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
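
The sizes in the final RuntimeError can be compared directly to see how far the requested reservation exceeds what the device can offer. The snippet below is only a rough sketch for reading that message: the helper names `parse_size` and `memory_shortfall` are hypothetical, the regular expressions are fitted to the wording of this one message rather than to any documented format, and treating the `G` suffix as GiB is an assumption.

```python
import re

# Size suffixes as they appear in the message above.
# Assumption: the suffixes are binary multiples (G = 2**30 bytes).
_UNITS = {"B": 1, "K": 2**10, "M": 2**20, "G": 2**30}


def parse_size(text):
    """Convert a size string such as '13.42G' or '0B' to a number of bytes."""
    match = re.fullmatch(r"([\d.]+)([BKMG])", text)
    if match is None:
        raise ValueError(f"unrecognised size: {text!r}")
    return float(match.group(1)) * _UNITS[match.group(2)]


def memory_shortfall(message):
    """Return requested minus reservable bytes from a 'Resource exhausted' message."""
    requested = re.search(r"reserve ([\d.]+[BKMG])", message)
    reservable = re.search(r"([\d.]+[BKMG]) reservable", message)
    if requested is None or reservable is None:
        raise ValueError("message does not match the expected wording")
    return parse_size(requested.group(1)) - parse_size(reservable.group(1))


# Message text copied from the log above.
msg = (
    "Resource exhausted: Attempting to reserve 13.42G at the bottom of memory. "
    "That was not possible. There are 7.27G free, 0B reserved, and 7.22G reservable."
)
shortfall_gib = memory_shortfall(msg) / 2**30
print(f"{shortfall_gib:.2f}G short")
```

For this log the request (13.42G) is nearly twice the reservable amount (7.22G), so the replicated computation would need to fit in roughly half the memory it currently asks for.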