/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3114: UserWarning: Explicitly requested dtype requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax._check_user_dtype_supported(dtype, "zeros")
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
Epoch ... (1/5):   0%|          | 0/5 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "./run_mlm_flax.py", line 709, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1647, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 624, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(fun, backend, axis_name, axis_size,
  File "/home/dat/pino/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 899, in parallel_callable
    compiled = xla.backend_compile(backend, built, compile_options)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 360, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 20.61G of 15.48G hbm. Exceeded hbm capacity by 5.13G.

Total hbm usage >= 21.13G:
    reserved        530.00M
    program          20.61G
    arguments            0B

Output size 0B; shares 0B with arguments.

Program hbm requirement 20.61G:
    global            900.0K
    scoped            924.0K
    HLO temp          20.61G (63.0% utilization: Unpadded (12.43G) Padded (19.71G), 4.4% fragmentation (918.84M))

  Largest program allocations in hbm:

  1. Size: 1.54G
     Operator: op_type="dot_general" op_name="pmap(train_step)/dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/pino/lib/python3.8/site-packages/flax/linen/linear.py" source_line=175
     Shape: bf16[4,4096,50358]{1,2,0:T(8,128)(2,1)}
     Unpadded size: 1.54G
     Extra memory due to padding: 64.0K (1.0x expansion)
     XLA label: %fusion.3615.remat4 = bf16[4,4096,50358]{1,2,0:T(8,128)(2,1)} fusion(bf16[50358,768]{1,0:T(8,128)(2,1)} %get-tuple-element.22628, f32[768]{0:T(1024)} %fusion.10158, f32[768]{0:T(1024)} %fusion.10159, f32[4,4096]{1,0:T(4,128)} %get-tuple-element.20129, f32[...
     Allocation type: HLO temp
     ==========================
  2. Size: 360.00M
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2444.remat_uncompressed = bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)} copy(bf16[4,12,60,64,512]{4,3,2,1,0:T(8,128)(2,1)} %fusion.2444.remat_compressed)
     Allocation type: HLO temp
     ==========================
  3. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2454.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2804, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7916, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================
  4. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2453.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2803, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7915, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================
  5. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2452.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2802, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7914, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================
  6. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2451.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2801, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7913, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================
  7. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2445 = bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)} fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2795, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7907, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} %get-tuple-element.20342, f32[4,12,60,64,192]{3,4,2,1...
     Allocation type: HLO temp
     ==========================
  8. Size: 360.00M
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2443.remat_uncompressed = bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)} copy(bf16[4,12,60,64,512]{4,3,2,1,0:T(8,128)(2,1)} %fusion.2443.remat_compressed)
     Allocation type: HLO temp
     ==========================
  9. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2450.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2800, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7912, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================
  10. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2449.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2799, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7911, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================
  11. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2448.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2798, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7910, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================
  12. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2447.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2797, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7909, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================
  13. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.2446.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.2796, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.7908, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================
  14. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.2689.remat3 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.20556, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14362, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.2964), kind=kO...
     Allocation type: HLO temp
     ==========================
  15. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.2690.remat3 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.20556, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14296, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.2962), kind=kO...
     Allocation type: HLO temp
     ==========================
  16. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.2688.remat3 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.20556, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14428, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.2966), kind=kO...
     Allocation type: HLO temp
     ==========================
  17. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.2691.remat3 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.20556, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14230, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.2960), kind=kO...
     Allocation type: HLO temp
     ==========================
  18. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.2692.remat3 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.20556, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14164, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.2958), kind=kO...
     Allocation type: HLO temp
     ==========================
  19. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.2693.remat3 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.20556, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14098, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.2956), kind=kO...
     Allocation type: HLO temp
     ==========================
  20. Size: 270.00M
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.2616.remat_uncompressed = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} copy(f32[4,12,60,64,192]{4,3,2,1,0:T(8,128)} %fusion.2616.remat_compressed)
     Allocation type: HLO temp
     ==========================

The stack trace below excludes JAX-internal frames.
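The headline allocation is easy to sanity-check against its shape: entry 1 is the bf16[4,4096,50358] logits tensor produced by the final vocabulary projection (flax/linen/linear.py:175 in the trace), and at 2 bytes per bf16 element that accounts for the whole reported 1.54G. A quick check in plain Python (not part of the training script):

```python
# Verify allocation 1: bf16[4,4096,50358] at 2 bytes per bf16 element.
elements = 4 * 4096 * 50358        # per-device batch x sequence x vocab
size_bytes = elements * 2          # bf16 is 2 bytes wide
size_gib = size_bytes / 1024**3
print(f"{size_gib:.2f}G")          # -> 1.54G
```

Since the per-device batch (4) and sequence length (4096) multiply directly into this buffer, either knob shrinks the largest allocation proportionally.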
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "./run_mlm_flax.py", line 709, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/xla.py", line 360, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 20.61G of 15.48G hbm. Exceeded hbm capacity by 5.13G.
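The error itself points at the fix: the program needs 20.61G of HBM against a 15.48G budget, so peak memory must drop by roughly a quarter. The usual options are a smaller per-device batch (the leading 4 in every shape above), more aggressive rematerialization (e.g. Flax's nn.remat around the BigBird layers), or gradient accumulation. The sketch below illustrates only the accumulation idea, on a hypothetical scalar least-squares model rather than the actual Flax train step: averaging per-microbatch gradients reproduces the full-batch gradient while only one microbatch's activations need to be live at a time.

```python
def grad_w(w, xs, ys):
    """Gradient of the mean squared error of w*x against y, w.r.t. w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, microbatch):
    """Sum per-microbatch gradients (un-averaged), then divide by the
    full batch size -- the memory-friendly equivalent of grad_w."""
    total = 0.0
    for i in range(0, len(xs), microbatch):
        xb, yb = xs[i:i + microbatch], ys[i:i + microbatch]
        total += grad_w(w, xb, yb) * len(xb)
    return total / len(xs)

xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 5.0, 9.0]
full = grad_w(0.5, xs, ys)
accum = accumulated_grad(0.5, xs, ys, microbatch=2)
assert abs(full - accum) < 1e-12   # same update, lower peak memory
```

In the real script the analogous change is to halve the per-device batch and step the optimizer every second microbatch; throughput drops somewhat, but the HLO temp requirement scales down with the batch dimension and should fit under the 15.48G cap.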