gpt2-medium-4096 is a 380M-parameter transformer model based on GPT-2. It was trained from scratch on 13.5 billion tokens from a subset (80%) of SlimPajama-627B.
This model is meant to be a basis for further experiments, particularly fine-tuning on phi-style data and iterative (daily) training. It can be fine-tuned on a recent NVIDIA GPU with 12 GB of VRAM.
Parameters were chosen as follows:
- Started with gpt2-medium's parameters.
- Extended the context length to 4096, the largest size I could fit in VRAM on a 12 GiB GPU. Fitting this model in VRAM on a 12 GiB GPU requires batch_size=1 and a 4-bit (AdamW) or 8-bit (Paged AdamW) optimizer.
- Raised n_layer slightly to use the remaining free VRAM.
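To give a rough sense of why low-bit optimizers matter on a 12 GiB card, here is a back-of-the-envelope sketch of the memory taken by AdamW's two moment buffers at different state precisions. It uses the 380M parameter figure from this card; the helper name `adam_state_gib` is made up for illustration, and the estimate ignores quantization metadata, weights, gradients, and activations.

```python
N_PARAMS = 380e6  # approximate parameter count stated in this model card


def adam_state_gib(n_params: float, bits_per_state: int) -> float:
    """Estimate memory (GiB) for AdamW's two per-parameter moment buffers.

    Assumption: each parameter has exactly two optimizer states and the
    small per-block quantization metadata is ignored.
    """
    total_bytes = n_params * 2 * bits_per_state / 8
    return total_bytes / 2**30


for name, bits in [("32-bit AdamW", 32), ("8-bit AdamW", 8), ("4-bit AdamW", 4)]:
    print(f"{name}: {adam_state_gib(N_PARAMS, bits):.2f} GiB of optimizer state")
```

Under these assumptions, full-precision states take roughly 2.8 GiB, while the 8-bit and 4-bit variants cut that to a quarter and an eighth respectively, leaving more room for activations at a 4096-token context.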
parameter | gpt2-medium | gpt2-medium-4096 | gpt2-large |
---|---|---|---|
n_layer | 24 | 26 | 36 |
n_head | 16 | 16 | 20 |
n_embed | 1024 | 1024 | 1280 |
n_ctx | 1024 | 4096 | 1024 |
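As a sanity check on the table, the parameter count can be estimated from these four values. The sketch below assumes the standard GPT-2 layout (biases on every linear layer, learned position embeddings, output head tied to the token embedding) and the stock GPT-2 vocabulary size of 50257; neither is stated in this card, so a padded vocabulary would shift the total slightly. The function name `gpt2_param_count` is hypothetical.

```python
def gpt2_param_count(n_layer: int, n_embd: int, n_ctx: int,
                     vocab_size: int = 50257) -> int:
    """Count parameters of a standard GPT-2 style model.

    Assumptions: tied input/output embeddings, biases on all linear
    layers, and a 4x MLP expansion. n_head does not affect the count.
    """
    d = n_embd
    embeddings = vocab_size * d + n_ctx * d  # token + position embeddings
    # Per block: attention (4*d*d weights) + MLP (8*d*d weights),
    # plus 9*d biases and 4*d layer-norm parameters.
    per_layer = 12 * d * d + 13 * d
    final_layer_norm = 2 * d
    return embeddings + n_layer * per_layer + final_layer_norm


for name, cfg in {
    "gpt2-medium": dict(n_layer=24, n_embd=1024, n_ctx=1024),
    "gpt2-medium-4096": dict(n_layer=26, n_embd=1024, n_ctx=4096),
    "gpt2-large": dict(n_layer=36, n_embd=1280, n_ctx=1024),
}.items():
    print(f"{name}: {gpt2_param_count(**cfg) / 1e6:.1f}M parameters")
```

With these assumptions, gpt2-medium-4096 comes out at about 383M parameters, in line with the 380M figure quoted above. Most of the increase over gpt2-medium comes from the two extra layers; the longer position-embedding table adds only about 3M.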
The model was trained locally on an RTX 3060 12 GB and on Portland State University's Coeus cluster on 2 x RTX A5000 using DDP. Training took five days on the cluster, followed by YY days on the RTX 3060.
Training on the RTX 3060 used batch_size=1 with 4-bit AdamW. Training on Coeus (2 x RTX A5000) used batch_size=2 with the full AdamW optimizer. The model was trained in 'float16' rather than 'bfloat16'. The learning rate ramped up from 6e-5 to 4e-4 over the first 3000 iterations (786M tokens) and stayed near 4e-4 for the next 11.7B tokens (with a very slight cosine cooldown). The LR was then dropped to a constant 9e-5 for the next 1B tokens. The first 12.5B tokens came from a 50% subset of SlimPajama-627B; the next YY B came from a different 30% subset. The optimizer was also switched to 8-bit AdamW.
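The learning-rate schedule described above can be sketched as a function of the iteration number. This is an illustration, not the actual training code: the warmup length and the three learning rates come from the text, while the tokens-per-iteration figure is inferred from "3000 iterations (786M tokens)", and `COSINE_HORIZON` and `LR_FLOOR` are hypothetical values chosen only so that the cosine decay stays "very slight".

```python
import math

# Values stated in the text.
WARMUP_ITERS = 3000
LR_START, LR_PEAK, LR_FINAL = 6e-5, 4e-4, 9e-5

# Assumption: 786M tokens / 3000 iterations is 64 sequences of 4096 tokens.
TOKENS_PER_ITER = 64 * 4096
PLATEAU_ITERS = round(11.7e9 / TOKENS_PER_ITER)  # "next 11.7B tokens"

# Hypothetical: the text does not give the cosine horizon or floor.
COSINE_HORIZON = 10 * PLATEAU_ITERS
LR_FLOOR = 4e-5


def lr_at(it: int) -> float:
    """Return the learning rate at iteration `it` under the assumptions above."""
    if it < WARMUP_ITERS:
        # Linear warmup from LR_START to LR_PEAK.
        return LR_START + (LR_PEAK - LR_START) * it / WARMUP_ITERS
    if it < WARMUP_ITERS + PLATEAU_ITERS:
        # Slow cosine decay; only the first tenth of the curve is traversed.
        progress = (it - WARMUP_ITERS) / COSINE_HORIZON
        coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
        return LR_FLOOR + coeff * (LR_PEAK - LR_FLOOR)
    # Final phase: constant learning rate.
    return LR_FINAL


for it in (0, 1500, 3000, 25000, 47000, 50000):
    print(f"iter {it:>6}: lr = {lr_at(it):.2e}")
```

Under the inferred 262,144 tokens per iteration, the constant 9e-5 phase would begin around iteration 47,600; treat that number as an estimate derived from the token counts rather than a logged value.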
Evaluations
benchmark | gpt2-medium | gpt2-medium-4096 |
---|---|---|
hellaswag | 0.3327 | 0.3095 |
Evaluation curve
Iters | val loss | hellaswag |
---|---|---|
5500 | 3.2508 | 0.2698 |
16100 | 2.7633 | 0.2856 |
19700 | 2.7520 | 0.2891 |
28200 | 2.7155 | 0.2917 |
29900 | 2.6846 | 0.2922 |
31000 | 2.6607 | 0.2949 |
36900 | 2.6366 | 0.2965 |
47900 | 2.6818 | 0.2992 |
49000 | 2.6967 | 0.3058 |
50800 | 2.4078 | 0.3079 |
51650 | 2.4898 | 0.3095 |
@misc{cerebras2023slimpajama,
author = {Soboleva, Daria and Al-Khateeb, Faisal and Myers, Robert and Steeves, Jacob R and Hestness, Joel and Dey, Nolan},
title = {{SlimPajama: A 627B token cleaned and deduplicated version of RedPajama}},
month = jun,
year = 2023,
howpublished = {\url{https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama}},
url = {https://huggingface.co/datasets/cerebras/SlimPajama-627B},
}
@misc{li2023memory,
title={Memory Efficient Optimizers with 4-bit States},
author={Bingrui Li and Jianfei Chen and Jun Zhu},
year={2023},
eprint={2309.01507},
archivePrefix={arXiv},
primaryClass={cs.LG},
url = {https://github.com/thu-ml/low-bit-optimizers}
}
@article{dettmers2022optimizers,
title={8-bit Optimizers via Block-wise Quantization},
author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke},
journal={9th International Conference on Learning Representations, ICLR},
year={2022}
}