Saving weights and logs of step 1000
9c3de9e
2022-01-14 22:15:40.254500: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2022-01-14 22:15:40.254546: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
Epoch ... (1/3): 0%| | 0/3 [00:00<?, ?it/s]
Training...: 0%| | 0/39919 [02:25<?, ?it/s]
Epoch ... (1/3): 0%| | 0/3 [03:13<?, ?it/s]
Traceback (most recent call last):
  File "run_mlm_flax.py", line 815, in <module>
    main()
  File "run_mlm_flax.py", line 723, in main
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/data/flax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 2058, in cache_miss
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
  File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 1934, in f_pmapped
    out = pxla.xla_pmap(
  File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 1727, in bind
    return call_bind(self, fun, *args, **params)
  File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 1652, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 1730, in process
    return trace.process_map(self, fun, tracers, params)
  File "/data/flax/lib/python3.8/site-packages/jax/core.py", line 633, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/data/flax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 778, in xla_pmap_impl
    return compiled_fun(*args)
  File "/data/flax/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/data/flax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1502, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 12.83G at the bottom of memory. That was not possible. There are 13.18G free, 0B reserved, and 12.71G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "run_mlm_flax.py", line 815, in <module>
    main()
  File "run_mlm_flax.py", line 723, in main
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/data/flax/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1502, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 12.83G at the bottom of memory. That was not possible. There are 13.18G free, 0B reserved, and 12.71G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
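The final error says the replicated `p_train_step` tried to reserve 12.83G while only 12.71G was reservable. That is a shortfall of about 0.12G, even though 13.18G was reported free. The following is a minimal sketch that extracts those figures from the message and computes the gap. It assumes the message text matches the line above exactly and keeps sizes in the `G` units as printed, with no conversion. The regex and the helper name `parse_resource_exhausted` are hypothetical illustrations, not part of JAX or XLA.

```python
import re

# Hypothetical pattern for the RESOURCE_EXHAUSTED text shown in the log above.
# Sizes are captured in the "G" units the message prints; no conversion is done.
_PATTERN = re.compile(
    r"Attempting to reserve (?P<requested>[\d.]+)G.*?"
    r"There are (?P<free>[\d.]+)G free, (?P<reserved>\S+) reserved, "
    r"and (?P<reservable>[\d.]+)G reservable"
)


def parse_resource_exhausted(message: str) -> dict:
    """Hypothetical helper: pull the memory figures out of the error message."""
    match = _PATTERN.search(message)
    if match is None:
        raise ValueError("not a recognised RESOURCE_EXHAUSTED message")
    requested = float(match["requested"])
    free = float(match["free"])
    reservable = float(match["reservable"])
    return {
        "requested_g": requested,
        "free_g": free,
        "reserved": match["reserved"],  # kept as printed, e.g. "0B"
        "reservable_g": reservable,
        # How far the request exceeded the reservable space, rounded for display.
        "shortfall_g": round(requested - reservable, 2),
    }


# Usage with the message from the log above.
msg = (
    "RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 12.83G at the bottom "
    "of memory. That was not possible. There are 13.18G free, 0B reserved, and "
    "12.71G reservable."
)
info = parse_resource_exhausted(msg)
print(info["shortfall_g"])
```

The comparison that matters is against the reservable figure, not the free one. Going by the message, the allocation failed even though the free total exceeded the request.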