Step 1 | loss:0.6448072195053101 lr:2e-05 tokens_per_second_per_gpu:42.45478579156796 grad_norm:DTensor(local_tensor=35.5, device_mesh=DeviceMesh([0, 1, 2, 3, 4, 5, 6, 7]), placements=(_NormPartial(reduce_op='sum', norm_type=2.0),))
Step 2 | loss:0.4103721082210541 lr:2e-05 tokens_per_second_per_gpu:148.2067891058881 grad_norm:DTensor(local_tensor=30.125, device_mesh=DeviceMesh([0, 1, 2, 3, 4, 5, 6, 7]), placements=(_NormPartial(reduce_op='sum', norm_type=2.0),))
Step 3 | loss:0.21746155619621277 lr:2e-05 tokens_per_second_per_gpu:169.78215214207822 grad_norm:DTensor(local_tensor=11.3125, device_mesh=DeviceMesh([0, 1, 2, 3, 4, 5, 6, 7]), placements=(_NormPartial(reduce_op='sum', norm_type=2.0),))
Step 4 | loss:0.38765090703964233 lr:2e-05 tokens_per_second_per_gpu:179.82346623590314 grad_norm:DTensor(local_tensor=14.875, device_mesh=DeviceMesh([0, 1, 2, 3, 4, 5, 6, 7]), placements=(_NormPartial(reduce_op='sum', norm_type=2.0),))