gpt2-medium-4096

gpt2-medium-4096 is a 380M-parameter transformer model based on GPT-2, trained from scratch on 13.5 billion tokens from a subset (80%) of SlimPajama-627B.

This model is meant to be a basis for further experiments, particularly fine-tuning on phi-style data and iterative (daily) training. It is possible to fine-tune this model on a recent NVIDIA GPU with 12 GB of VRAM.
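
For orientation, a minimal fine-tuning sketch along those lines might look like the following. It assumes the checkpoint loads with transformers' GPT-2 classes and that bitsandbytes is installed; the repository id, learning rate, and the toy batch are placeholders, not values taken from the training runs described below.

```python
# Hedged sketch: fine-tuning gpt2-medium-4096 on a single 12 GB GPU.
# Assumes Hugging Face `transformers` and `bitsandbytes`; hyperparameters
# below are illustrative placeholders, not this card's training settings.
import torch
import bitsandbytes as bnb
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

model_id = "venketh/gpt2-medium-4096"  # repository id assumed from this card
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id, torch_dtype=torch.float16).cuda()
model.gradient_checkpointing_enable()  # trades compute for memory at n_ctx=4096

# Paged 8-bit AdamW keeps optimizer state small enough for 12 GB of VRAM.
optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-4)

batch = tokenizer("example text", return_tensors="pt").to("cuda")
out = model(**batch, labels=batch["input_ids"])  # batch_size=1, as on the card
out.loss.backward()       # a real run would add a GradScaler or mixed precision
optimizer.step()
optimizer.zero_grad()
```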

Parameters were chosen as follows:

  • Started with gpt2-medium's parameters.
  • Extended the context length to 4096, the largest size that fits in 12 GiB of VRAM; this requires batch_size=1 and a 4-bit (AdamW) or 8-bit (Paged AdamW) optimizer.
  • Raised n_layer slightly to use the remaining free VRAM.
          gpt2-medium   gpt2-medium-4096   gpt2-large
n_layer   24            26                 36
n_head    16            16                 20
n_embed   1024          1024               1280
n_ctx     1024          4096               1024
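
For reference, this shape can be expressed as a transformers GPT2Config. This is only a sketch: the four values from the table above come from this card, and everything else is left at the library's GPT-2 defaults.

```python
from transformers import GPT2Config, GPT2LMHeadModel

# Hyperparameters from the table above; all other settings (vocab size,
# dropout, etc.) are the transformers GPT-2 defaults.
config = GPT2Config(
    n_layer=26,        # gpt2-medium: 24, gpt2-large: 36
    n_head=16,
    n_embd=1024,       # transformers spells this field n_embd
    n_positions=4096,  # the extended 4096-token context (n_ctx)
)
model = GPT2LMHeadModel(config)  # randomly initialized
print(sum(p.numel() for p in model.parameters()) / 1e6)  # roughly 380M parameters
```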

Training was done on an RTX 3060 (12 GB) locally and on Portland State University's Coeus cluster on 2 x RTX A5000 GPUs using DDP. It took five days on the cluster followed by YY days on the RTX 3060.

Training on the RTX 3060 used batch_size=1 and 4-bit AdamW; training on Coeus (2 x RTX A5000) used batch_size=2 and the full AdamW optimizer. Training was done in float16 rather than bfloat16. The learning rate ramped up from 6e-5 to 4e-4 over the first 3000 iterations (786M tokens), stayed near 4e-4 for the next 11.7B tokens (with a very slight cosine falloff), and was then dropped to a constant 9e-5 for the next 1B tokens. The first 12.5B tokens came from a 50% subset of SlimPajama-627B; the next YY B came from a different 30% subset. The optimizer was also switched to 8-bit AdamW at that point.
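
Written out as code, my reading of that schedule is roughly the sketch below. The warmup boundary (3000 iterations for 786M tokens) implies about 262k tokens per iteration, so the iteration counts for the later phases are back-of-the-envelope conversions, and the depth of the "very slight" cosine falloff is a guess; only the 6e-5, 4e-4, and 9e-5 values come from the description above.

```python
import math

WARMUP_ITERS = 3_000       # covered ~786M tokens per the description above
COSINE_ITERS = 44_600      # ~11.7B tokens at ~262k tokens/iteration (assumed)
START_LR, PEAK_LR = 6e-5, 4e-4
COSINE_FLOOR = 3.5e-4      # "very slight" falloff -- the exact floor is not stated
FINAL_LR = 9e-5            # constant LR for the final ~1B tokens

def lr_at(it: int) -> float:
    """Learning rate at iteration `it` under the schedule described above."""
    if it < WARMUP_ITERS:                     # linear warmup 6e-5 -> 4e-4
        return START_LR + (PEAK_LR - START_LR) * it / WARMUP_ITERS
    if it < WARMUP_ITERS + COSINE_ITERS:      # shallow cosine decay near 4e-4
        t = (it - WARMUP_ITERS) / COSINE_ITERS
        return COSINE_FLOOR + 0.5 * (PEAK_LR - COSINE_FLOOR) * (1 + math.cos(math.pi * t))
    return FINAL_LR                           # constant tail at 9e-5

# Applied per iteration, e.g.:
#   for group in optimizer.param_groups:
#       group["lr"] = lr_at(iteration)
```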

[Figure: training and validation loss curves (train_val)]


Evaluations

            gpt2-medium   gpt2-medium-4096
hellaswag   0.3327        0.3095
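
This card does not state which harness produced these scores. One common way to measure HellaSwag accuracy for a GPT-2-style checkpoint is EleutherAI's lm-evaluation-harness; a rough sketch (the repository id, dtype, and batch size are assumptions) would be:

```python
import lm_eval  # EleutherAI lm-evaluation-harness (v0.4+)

# Sketch only: this card does not state how the scores above were measured.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=venketh/gpt2-medium-4096,dtype=float16",
    tasks=["hellaswag"],
    batch_size=8,
)
print(results["results"]["hellaswag"])  # accuracy and normalized accuracy
```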

Evaluation curve

Iters   val loss   hellaswag
 5500   3.2508     0.2698
16100   2.7633     0.2856
19700   2.7520     0.2891
28200   2.7155     0.2917
29900   2.6846     0.2922
31000   2.6607     0.2949
36900   2.6366     0.2965
47900   2.6818     0.2992
49000   2.6967     0.3058
50800   2.4078     0.3079
51650   2.4898     0.3095

Citations

@misc{cerebras2023slimpajama,
  author = {Soboleva, Daria and Al-Khateeb, Faisal and Myers, Robert and Steeves, Jacob R and Hestness, Joel and Dey, Nolan},
  title = {{SlimPajama: A 627B token cleaned and deduplicated version of RedPajama}},
  month = {June},
  year = {2023},
  howpublished = {\url{https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama}},
  url = {https://huggingface.co/datasets/cerebras/SlimPajama-627B},
}
@misc{li2023memory,
  title = {Memory Efficient Optimizers with 4-bit States},
  author = {Bingrui Li and Jianfei Chen and Jun Zhu},
  year = {2023},
  eprint = {2309.01507},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG},
  url = {https://github.com/thu-ml/low-bit-optimizers},
}
@article{dettmers2022optimizers,
  title = {8-bit Optimizers via Block-wise Quantization},
  author = {Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke},
  journal = {International Conference on Learning Representations (ICLR)},
  year = {2022},
}
Dataset used to train venketh/gpt2-medium-4096: cerebras/SlimPajama-627B