---
license: apache-2.0
datasets:
- EleutherAI/pile
language:
- en
---

# Warm-Starting Knowledge Distillation for Transformer-based Language Models

*by GPT-4 & Crumb*

### Introduction

Transformer models have become a popular choice for natural language processing (NLP) tasks due to their ability to handle long-range dependencies and their strong performance on a wide range of NLP benchmarks. The transformer architecture was introduced in 2017 by [Vaswani et al.](https://arxiv.org/abs/1706.03762) and has since been used in many state-of-the-art models such as BERT and GPT. The decoder-only transformer is a variant that is commonly used for generative tasks in NLP. It uses masked self-attention to predict the next token in a sequence and has proven effective at modeling sequences of text.

Distillation \[[Bucila et al., 2006](https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf), [Hinton et al., 2015](https://arxiv.org/abs/1503.02531)\] is a machine learning technique for compressing a large model into a smaller one that can run on devices with limited computational resources. A smaller model (the student) is trained to mimic the behavior of a larger model (the teacher) by learning from its predictions. The student is typically trained on less data than the teacher, and its smaller size makes it faster and cheaper to run. This technique has been used to compress models like BERT and GPT-2 into smaller models like DistilBERT and DistilGPT-2, respectively.

In this project we apply knowledge distillation to the second-smallest [Pythia](https://arxiv.org/pdf/2304.01373.pdf) model on the [Pile](https://arxiv.org/abs/2101.00027) dataset.

### Method

We follow the work of [Sanh et al. (2019)](https://arxiv.org/abs/1910.01108) and [Hinton et al. (2015)](https://arxiv.org/abs/1503.02531) in using a distillation loss over the soft target probabilities, `L_ce`.
We combine this distillation loss `L_ce` linearly with the supervised training loss `L_clm`. Our combined loss function is `L_ce*(1-a) + L_clm*a`, where `a` is set to 0.5 and the temperature parameter `T` of the distillation loss is set to 2.

To maximize VRAM utilization while reaching a combined batch size of 4096 samples, we use a device batch size of 2 with 2048 gradient accumulation steps and a context length of 2048 tokens, with both the teacher and the student model in bf16 precision. This allowed us to use around 98.94% of the 12 gigabytes of VRAM on the RTX 3060 card during training. Because the model trained for 64 steps, the training set totals approximately 537 million tokens (4096 samples × 2048 tokens × 64 steps = 536,870,912 tokens). All training samples were taken from [The Pile](https://arxiv.org/abs/2101.00027). We used a constant learning rate of 1e-4 with no learning rate schedule.

### Evaluation

[Sanh et al. (2019)](https://arxiv.org/abs/1910.01108) suggest that a student around 40% of the size of its teacher can achieve similar performance in encoder models when trained from scratch with supervision. We warm-start our student from an existing checkpoint that is smaller than the teacher and keeps a similar ratio: the student is 43.75% the size of its teacher.

| model | piqa acc | winogrande acc | lambada ppl | lambada acc | arc acc | sciq acc | wsc acc | notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| pythia-70m (student base) | 59.85 | 51.22 | 140.81 | 21.40 | 17.15 | 65.00 | 36.53 | |
| pythia-160m (teacher) | 62.68 | 51.07 | 30.03 | 36.76 | 19.62 | 76.20 | 36.58 | |
| distilpythia (student) | 59.74 | **51.62** | 420.70 | 15.82 | **17.15** | 61.30 | **36.54** | trained on padded/truncated examples |
| distilpythia-cl (student) | 59.30 | 50.75 | 403.78 | 15.16 | 16.98 | 59.20 | **36.54** | trained on a constant-length dataset |
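
For readers who want to see the combined loss from the Method section in concrete terms, here is a minimal sketch in plain Python for a single token position. It is not the training code used for this model: the function names are hypothetical, `L_ce` is assumed to be the cross-entropy of the student against the teacher's temperature-softened distribution, and no `T^2` rescaling is applied because the text above does not mention one. A real training loop would compute the same quantities with batched tensor operations.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits, temperature=1.0):
    """Temperature-scaled log-softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """L_ce (assumed form): cross-entropy of the student against the
    teacher's soft targets, both softened with the same temperature."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_log_probs = log_softmax(student_logits, temperature)
    return -sum(t * s for t, s in zip(teacher_probs, student_log_probs))


def clm_loss(student_logits, target_index):
    """L_clm: negative log-likelihood of the true next token (T = 1)."""
    return -log_softmax(student_logits)[target_index]


def combined_loss(student_logits, teacher_logits, target_index,
                  a=0.5, temperature=2.0):
    """L_ce*(1-a) + L_clm*a, with a = 0.5 and T = 2 as in the Method section."""
    l_ce = distillation_loss(student_logits, teacher_logits, temperature)
    l_clm = clm_loss(student_logits, target_index)
    return l_ce * (1 - a) + l_clm * a


# Toy example: a 4-token vocabulary where token 2 is the true next token.
student = [1.0, 0.2, 2.5, -0.5]
teacher = [0.5, 0.1, 3.0, -1.0]
print(combined_loss(student, teacher, target_index=2))
```

Using soft-target cross-entropy rather than KL divergence changes the loss value only by the teacher's entropy, which is constant with respect to the student, so the student's gradients are the same under either choice.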