TPUv4-32

#1
by versae - opened

Hi,

Congrats on the development of the Finnish models! We were looking into how to use EasyLM to pre-train 7B Llama2 models on a TPUv4-32 and came across this repo with a modified version of EasyLM. Is the pretraining script ready to be launched on TPU VM pods? Did you try running anything bigger than 3B? We are not sure how the sharding should be set up for a TPUv4-32.

Cheers.

Finnish-NLP org

Hi @versae and thanks! This model is being trained on a TPUv4-32, so you can just use the script to launch the training :) The argument --jax_distributed.initialize_jax_distributed=True in the script enables training on pods. There is also a 7B model being pretrained; to get it running you need to change --mesh_dim='1,-1,1' to --mesh_dim='1,-1,4', but let me know if you find a better configuration for training a 7B-scale model on a v4-32.
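
For reference, here is a minimal sketch of how a mesh_dim spec such as '1,-1,4' can be expanded into a JAX device mesh. The (dp, fsdp, mp) axis names and the handling of -1 are assumptions based on how the flag is used here, not EasyLM's actual code:

```python
# Hypothetical sketch: turning a mesh_dim string into a JAX device mesh.
# Axis names and -1 inference are assumptions, not EasyLM's exact implementation.
import numpy as np
import jax
from jax.sharding import Mesh

def mesh_from_spec(spec: str) -> Mesh:
    dims = [int(x) for x in spec.split(',')]
    n_devices = jax.device_count()  # e.g. 16 JAX devices on a TPUv4-32 slice
    if -1 in dims:
        # Infer the single -1 axis so the product of all axes equals the device count
        known = int(np.prod([d for d in dims if d != -1]))
        dims[dims.index(-1)] = n_devices // known
    devices = np.array(jax.devices()).reshape(dims)
    return Mesh(devices, axis_names=('dp', 'fsdp', 'mp'))

# On a v4-32:
#   '1,-1,1' -> mesh shape (1, 16, 1): pure FSDP sharding
#   '1,-1,4' -> mesh shape (1, 4, 4):  FSDP combined with 4-way model parallelism
```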

Thanks @aapot! Really appreciate the tips. We'll test it right away and report back :) One last question: to save model checkpoints during training, did you just create a shared NFS disk or something similar?

Finnish-NLP org

You can set the address of your GCS bucket with --logger.output_dir and EasyLM will save checkpoints there.

Thanks! I'll try it out :)

Well, the code works like a charm! Thank you so much for sharing it.

One last question. I need to do exactly 2 epochs for an experiment, but EasyLM only lets you specify the number of train steps (I think). To calculate the number of steps I need to know the number of tokens in my dataset (20B), the sequence length (2048), the batch size (128), and the TPU configuration. Since I'm training on a v4-32, would that mean that 1 epoch = 20B / (2048 * 128 * 32) ≈ 2384 steps? However, in your logs I can see that after 5000 steps only around 1.3B tokens have been consumed:

{'accuracy': array(0.2635248, dtype=float32),
 'dataset_example_index': 1144022,
 'dataset_total_tokens': 1310720000,
 'epoch': 0,
 'gradient_norm': array(0.92961913, dtype=float32),
 'learning_rate': array(8.333329e-06, dtype=float32),
 'loss': array(4.5614524, dtype=float32),
 'param_norm': array(808.38196, dtype=float32),
 'step': 5000}

Is it the case that the batch size is the global batch size? If so, would that mean 20B / (2048 * 128) ≈ 76293 steps per epoch? That would be consistent with the logs: (1310720000 / 5000) * 76293 ≈ 20B.
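
To make that arithmetic concrete, here is the calculation written out, assuming the configured batch size is indeed the global batch size:

```python
# Worked example of the steps-per-epoch arithmetic, assuming the configured
# batch size (128) is the global batch size across all TPU hosts.
total_tokens = 20_000_000_000   # dataset size: 20B tokens
seq_len      = 2048
global_batch = 128              # sequences per optimizer step

tokens_per_step = seq_len * global_batch           # 262,144 tokens per step
steps_per_epoch = total_tokens // tokens_per_step  # 76,293 steps

# Sanity check against the logged metrics after 5,000 steps:
assert 5_000 * tokens_per_step == 1_310_720_000    # matches 'dataset_total_tokens'
print(steps_per_epoch)                             # 76293
```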

Do you have any estimate of how many tokens you can fit in a batch?

Finnish-NLP org
edited Mar 14

Good question; you could confirm that with the original authors at https://github.com/young-geng/EasyLM. I interpreted the batch size as global, but I remember wondering about the same thing a while back, so it would be good to clarify it, or to figure it out ourselves if the original authors don't reply. The training code should be fairly easy to modify, so it should also be possible to make it stop after 2 epochs instead of after a given number of steps. With the TPUv4-32, the 7B Llama model, and --mesh_dim='1,-1,4', I was able to use batch size 64 with sequence length 2048.
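
If it helps, here is a minimal sketch of an epoch-based stopping condition. The dummy iterator below just stands in for EasyLM's dataset loader; the only real assumption is that the loader reports an 'epoch' counter, as the metrics log above shows:

```python
# Hypothetical sketch: stop after a fixed number of epochs instead of a fixed
# number of steps. The dummy iterator stands in for EasyLM's dataset loader.
def dummy_batches(examples_per_epoch=10, max_epochs=100):
    for epoch in range(max_epochs):
        for i in range(examples_per_epoch):
            yield {'tokens': i}, {'epoch': epoch, 'dataset_example_index': i}

target_epochs = 2
for step, (batch, dataset_metrics) in enumerate(dummy_batches()):
    # ... run the usual training step on `batch` here ...
    if dataset_metrics['epoch'] >= target_epochs:  # the loader has wrapped around twice
        break
print(f"stopped at step {step}")  # after exactly 2 full epochs
```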

Got confirmation from the author:
[image: screenshot of the author's reply]

I'll test '1,-1,4', but also '1,32,1' for smaller models, as suggested by the creator of EasyDeL, @erfanzar.

Thanks again!

Finnish-NLP org

Cool, I have also been meaning to try the EasyDeL library but haven't had the time yet.

aapot changed discussion status to closed
