/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/3):   0%|                                                                                                        | 0/3 [00:00<?, ?it/s]
Training...:   0%|                                                                                                       | 0/503952 [01:39<?, ?it/s]
Epoch ... (1/3):   0%|                                                                                                        | 0/3 [09:25<?, ?it/s]
Traceback (most recent call last):
  File "./run_mlm_flax_no_accum.py", line 684, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1669, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 637, in xla_pmap_impl
    return compiled_fun(*args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1159, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Resource exhausted: Attempting to reserve 11.03G at the bottom of memory. That was not possible. There are 7.51G free, 0B reserved, and 7.45G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "./run_mlm_flax_no_accum.py", line 684, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1159, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: Resource exhausted: Attempting to reserve 11.03G at the bottom of memory. That was not possible. There are 7.51G free, 0B reserved, and 7.45G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).