dat
update all
731244f
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/5): 0%| | 0/5 [00:00<?, ?it/s]
2021-07-14 23:26:04.701487: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 0 failed: Resource exhausted: Attempting to allocate 17.0K. That was not possible. There are 48.0K free. Due to fragmentation, the largest contiguous region of free memory is 16.0K.; (0x0x0_HBM0)
Epoch ... (1/5): 0%| | 0/5 [14:02<?, ?it/s]
Traceback (most recent call last):
  File "./run_mlm_flax.py", line 806, in <module>
    train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size // grad_accum_steps)
  File "./run_mlm_flax.py", line 263, in generate_batch_splits
    batch_idx = np.split(samples_idx, sections_split)
  File "<__array_function__ internals>", line 5, in split
  File "/home/dat/pino/lib/python3.8/site-packages/numpy/lib/shape_base.py", line 874, in split
    return array_split(ary, indices_or_sections, axis)
  File "<__array_function__ internals>", line 5, in array_split
  File "/home/dat/pino/lib/python3.8/site-packages/numpy/lib/shape_base.py", line 790, in array_split
    sub_arys.append(_nx.swapaxes(sary[st:end], axis, 0))
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5009, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5028, in _gather
    y = lax.gather(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 984, in gather
    return gather_p.bind(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 603, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 249, in apply_primitive
    return compiled_fun(*args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 365, in _execute_compiled_primitive
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Resource exhausted: Attempting to allocate 17.0K. That was not possible. There are 48.0K free. Due to fragmentation, the largest contiguous region of free memory is 16.0K.; (0x0x0_HBM0)
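The traceback shows `np.split` being applied to `samples_idx` inside `generate_batch_splits`, and the slicing ends up in `jax/_src/numpy/lax_numpy.py` and `lax.gather`. That suggests `samples_idx` is a JAX device array, so each batch slice is dispatched as its own small device gather, and one of those 17.0K allocations fails on the nearly full, fragmented device memory (HBM). A minimal sketch, assuming the shuffled indices can live in host memory, converts them to NumPy once before splitting. The function below is a hypothetical re-implementation for illustration, not the script's actual code, and how `train_samples_idx` is produced in the real script is an assumption.

```python
import numpy as np


def generate_batch_splits(samples_idx, batch_size):
    """Split a 1-D array of sample indices into equal-sized batches on the host.

    Hypothetical sketch: np.asarray copies a JAX device array to host memory
    once, so the slicing inside np.split runs in NumPy instead of issuing one
    device-side gather per batch.
    """
    samples_idx = np.asarray(samples_idx)
    num_batches = len(samples_idx) // batch_size
    if num_batches == 0:
        # Fewer samples than one full batch: return no batches.
        return []
    # Drop the trailing samples that would form an incomplete batch.
    samples_idx = samples_idx[: num_batches * batch_size]
    return np.split(samples_idx, num_batches)


# Usage (assumption): generate the shuffled indices with NumPy on the host.
rng = np.random.default_rng(0)
train_samples_idx = rng.permutation(10)
train_batch_idx = generate_batch_splits(train_samples_idx, 4)
print([b.tolist() for b in train_batch_idx])
```

With 10 indices and a batch size of 4 this yields two batches of four indices and drops the remaining two. Doing the split on the host keeps the bookkeeping off the accelerator entirely; it does not free the device memory already held by the model, so the underlying memory pressure may still need separate attention.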