/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
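A note on the two UserWarnings before the crash: they are harmless but point at deprecated aliases that will be removed. A minimal sketch of the rename, assuming the script (not shown in this log) still calls the old host-centric helpers somewhere:

```python
import jax

# Deprecated aliases flagged by xla_bridge.py above -> current names.
# The semantics are unchanged; only the names moved.
num_processes = jax.process_count()  # was: jax.host_count()
process_id = jax.process_index()     # was: jax.host_id()

print(f"running as process {process_id} of {num_processes}")
```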
Epoch ... (1/5):   0%|          | 0/5 [00:00<?, ?it/s]
[...]
Total hbm usage >= 20.82G:
    reserved        530.00M
    program          20.30G
    arguments            0B

Output size 0B; shares 0B with arguments.

Program hbm requirement 20.30G:
    global           660.0K
    scoped           125.0K
    HLO temp         20.30G (63.5% utilization: Unpadded (12.44G) Padded (19.60G), 3.5% fragmentation (717.54M))

  Largest program allocations in hbm:

  1. Size: 1.54G
     Operator: op_type="dot_general" op_name="pmap(train_step)/dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/pino/lib/python3.8/site-packages/flax/linen/linear.py" source_line=175
     Shape: bf16[4,4096,50358]{1,2,0:T(8,128)(2,1)}
     Unpadded size: 1.54G
     Extra memory due to padding: 64.0K (1.0x expansion)
     XLA label: %fusion.1304.remat4 = bf16[4,4096,50358]{1,2,0:T(8,128)(2,1)} fusion(bf16[50358,768]{1,0:T(8,128)(2,1)} %copy.16213, f32[768]{0:T(1024)} %fusion.8859, f32[768]{0:T(1024)} %fusion.8860, f32[4,4096]{1,0:T(4,128)} %get-tuple-element.16597, f32[4,4096]{1,0:T(4...
     Allocation type: HLO temp
     ==========================

  2. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.135 = bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)} fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.485, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5710, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} %get-tuple-element.16812, f32[4,12,60,64,192]{3,4,2,1,0...
     Allocation type: HLO temp
     ==========================

  3. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.144.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.494, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5719, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} ...
     Allocation type: HLO temp
     ==========================

  4. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.143.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.493, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5718, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} ...
     Allocation type: HLO temp
     ==========================

  5. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.142.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.492, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5717, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} ...
     Allocation type: HLO temp
     ==========================

  6. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.141.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.491, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5716, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} ...
     Allocation type: HLO temp
     ==========================

  7. Size: 360.00M
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.134.remat_uncompressed = bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)} copy(bf16[4,12,60,64,512]{4,3,2,1,0:T(8,128)(2,1)} %fusion.134.remat_compressed)
     Allocation type: HLO temp
     ==========================

  8. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.140.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.490, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5715, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} ...
     Allocation type: HLO temp
     ==========================

  9. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.139.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.489, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5714, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} ...
     Allocation type: HLO temp
     ==========================

  10. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.138.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.488, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5713, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} ...
     Allocation type: HLO temp
     ==========================

  11. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.137.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.487, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5712, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} ...
     Allocation type: HLO temp
     ==========================

  12. Size: 360.00M
     Operator: op_type="div" op_name="pmap(train_step)/div" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=619
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.136.remat = (bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}, bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}) fusion(f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.486, f32[4,12,60,64]{3,2,1,0:T(8,128)} %fusion.5711, f32[4,12,60,64,64]{3,4,2,1,0:T(8,128)} ...
     Allocation type: HLO temp
     ==========================

  13. Size: 360.00M
     Shape: bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)}
     Unpadded size: 180.00M
     Extra memory due to padding: 180.00M (2.0x expansion)
     XLA label: %fusion.133.remat_uncompressed = bf16[4,12,60,64,512]{3,4,2,1,0:T(8,128)(2,1)} copy(bf16[4,12,60,64,512]{4,3,2,1,0:T(8,128)(2,1)} %fusion.133.remat_compressed)
     Allocation type: HLO temp
     ==========================

  14. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.378.remat5 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.17038, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14428, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.655), kind=kOut...
     Allocation type: HLO temp
     ==========================

  15. Size: 270.00M
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.310.remat_uncompressed = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} copy(f32[4,12,60,64,192]{4,3,2,1,0:T(8,128)} %fusion.310.remat_compressed)
     Allocation type: HLO temp
     ==========================

  16. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.386.remat6 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.17038, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.13900, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.639), kind=kOut...
     Allocation type: HLO temp
     ==========================

  17. Size: 270.00M
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.326.remat_uncompressed.remat2 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} copy(f32[4,12,60,64,192]{4,3,2,1,0:T(8,128)} %fusion.326.remat_compressed)
     Allocation type: HLO temp
     ==========================

  18. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=591
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.10361 = (f32[4,12,60,64]{3,2,1,0:T(8,128)}, f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}) fusion(s32[4,12,62,64,192]{3,4,2,1,0:T(8,128)} %get-tuple-element.18295, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14494, bf16[4,12,60,192,64]{3,2,1,0,4:T...
     Allocation type: HLO temp
     ==========================

  19. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.380.remat5 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.17038, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14296, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.651), kind=kOut...
     Allocation type: HLO temp
     ==========================

  20. Size: 270.00M
     Operator: op_type="dot_general" op_name="pmap(train_step)/jit(jvp(_einsum))/dot_general[ dimension_numbers=(((4,), (4,)), ((0, 1, 2), (0, 1, 2)))\n precision=None\n preferred_element_type=None ]" source_file="/home/dat/transformers/src/transformers/models/big_bird/modeling_flax_big_bird.py" source_line=584
     Shape: f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)}
     Unpadded size: 135.00M
     Extra memory due to padding: 135.00M (2.0x expansion)
     XLA label: %fusion.379.remat3 = f32[4,12,60,64,192]{3,4,2,1,0:T(8,128)} fusion(f32[4,60,64,192]{2,3,1,0:T(8,128)} %get-tuple-element.17038, bf16[4,12,64,64,64]{4,3,2,1,0:T(8,128)(2,1)} %copy.14362, bf16[4,12,60,192,64]{3,2,4,1,0:T(8,128)(2,1)} %fusion.653), kind=kOut...
     Allocation type: HLO temp
     ==========================

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./run_mlm_flax.py", line 709, in <module>
    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
KeyboardInterrupt
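The KeyboardInterrupt at the end is just the manual abort; the substantive failure is the HBM exhaustion reported above (the compiled program alone needs 20.30G, at 63.5% utilization with 3.5% fragmentation). The dominant allocations are the MLM-head logits (allocation 1, from flax/linen/linear.py:175) and a dozen BigBird block-sparse attention temporaries (modeling_flax_big_bird.py:584-619). The reported sizes can be sanity-checked by hand; the sketch below is plain arithmetic, not part of the training script:

```python
import math

def size_mib(shape, bytes_per_elem=2):
    """Unpadded tensor size in MiB; bf16 is 2 bytes per element."""
    return math.prod(shape) * bytes_per_elem / 2**20

# Allocation 1: MLM-head logits, bf16[4, 4096, 50358]
# (per-device batch 4 x sequence 4096 x vocab 50358).
print(size_mib((4, 4096, 50358)) / 1024)  # ~1.54 GiB, matching "Size: 1.54G"

# Allocations 2-13: attention temporaries, bf16[4, 12, 60, 64, 512]
# (per-device batch 4, 12 heads, 60 query blocks of 64 tokens).
print(size_mib((4, 12, 60, 64, 512)))  # 180.0 MiB unpadded; the T(8,128)
# tiled layout doubles it to 360.00M, the "2.0x expansion" in the report.
```

Every large temp carries the per-device batch size (4) and the 4096-token sequence (64 blocks of 64 tokens) in its shape, so halving either shrinks all of them roughly proportionally. Failing that, the usual trade of compute for memory is more aggressive rematerialization, e.g. wrapping encoder layers in flax.linen.remat / jax.checkpoint; the .remat suffixes in the XLA labels show the compiler is already rematerializing some of these fusions on its own.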