/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/3): 0%| | 0/3 [00:00
    train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
  File "./run_mlm_flax_no_accum.py", line 255, in generate_batch_splits
    batch_idx = np.split(samples_idx, sections_split)
  File "<__array_function__ internals>", line 5, in split
  File "/home/dat/pino/lib/python3.8/site-packages/numpy/lib/shape_base.py", line 874, in split
    return array_split(ary, indices_or_sections, axis)
  File "<__array_function__ internals>", line 5, in array_split
  File "/home/dat/pino/lib/python3.8/site-packages/numpy/lib/shape_base.py", line 790, in array_split
    sub_arys.append(_nx.swapaxes(sary[st:end], axis, 0))
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5009, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5028, in _gather
    y = lax.gather(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 984, in gather
    return gather_p.bind(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 603, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 249, in apply_primitive
    return compiled_fun(*args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 365, in _execute_compiled_primitive
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Resource exhausted: Attempting to allocate 17.0K. That was not possible. There are 48.0K free. Due to fragmentation, the largest contiguous region of free memory is 16.0K.; (0x0x0_HBM0)
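The two UserWarnings are only deprecation notices (jax.host_count is now jax.process_count, jax.host_id is now jax.process_index) and are unrelated to the crash. The crash itself comes from `np.split` slicing `samples_idx` through JAX's `_rewriting_take` and `lax.gather`. That means the index array lives on the accelerator (HBM0), so each batch slice becomes a small device computation competing for device memory that is almost exhausted and fragmented. One possible workaround, sketched below under the assumption that the shuffled indices only need to live on the host, is to convert them to a NumPy array before splitting. The function is a hypothetical reconstruction modeled on the `generate_batch_splits` name in the traceback, not the actual code of `run_mlm_flax_no_accum.py`.

```python
import numpy as np


def generate_batch_splits(samples_idx, batch_size):
    """Hypothetical sketch: split sample indices into equal host-side batches.

    Assumption: samples_idx may arrive as a JAX device array; np.asarray copies
    it to host memory so the slicing below runs in NumPy, not on the device.
    """
    samples_idx = np.asarray(samples_idx)
    num_samples = len(samples_idx)

    # Drop the trailing samples that do not fill a complete batch.
    samples_to_remove = num_samples % batch_size
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]

    sections_split = num_samples // batch_size
    if sections_split == 0:
        # Fewer samples than one batch: np.split would reject 0 sections.
        return []
    return np.split(samples_idx, sections_split)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    train_samples_idx = rng.permutation(10)  # shuffled indices built on the host
    batches = generate_batch_splits(train_samples_idx, 4)
    print(len(batches), [b.shape for b in batches])
```

If the permutation is currently produced with `jax.random.permutation`, generating it with NumPy instead (for example `np.random.default_rng(seed).permutation(n)`) would keep it off the device entirely. This is a suggestion based on the traceback, not a confirmed fix for this particular run.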