
Warm-Starting Knowledge Distillation for Transformer-based Language Models

by GPT-4 & Crumb


Transformer models have become a popular choice for natural language processing (NLP) tasks due to their ability to handle long-range dependencies and their strong performance on various NLP benchmarks. The transformer architecture was introduced in 2017 by Vaswani et al. and has since been used in many state-of-the-art models such as BERT and GPT. The decoder-only transformer is a variant that is commonly used for generative tasks in NLP. It uses masked self-attention to predict the next token in a sequence and has been shown to be powerful at predicting sequences of text.

Distillation [Bucila et al., 2006, Hinton et al., 2015] is a technique used in machine learning to compress a large model into a smaller one that can be used on devices with limited computational resources. In this technique, a smaller model (the student) is trained to mimic the behavior of a larger model (the teacher) by learning from its predictions. The student can often be trained on less data than the teacher, and its smaller size makes it faster and cheaper to run. This technique has been used to compress models like BERT and GPT-2 into smaller models like DistilBERT and DistilGPT-2, respectively. In this project we apply knowledge distillation to the second smallest Pythia model on the Pile dataset.


We follow Sanh et al. (2019) and Hinton et al. (2015) in using a distillation loss over the teacher's soft target probabilities, L_ce. Our training loss is a linear combination of the distillation loss L_ce and the supervised training loss L_clm: L_ce*(1-a) + L_clm*a, where a is set to 0.5 and the temperature for the distillation loss is set to 2.
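As a rough illustration of how the two terms combine, here is a minimal per-token sketch in plain Python. It is not our training code: the function names and toy logits are hypothetical, L_ce is taken to be the cross-entropy between temperature-softened teacher and student distributions, and L_clm is read as ordinary next-token cross-entropy. Some implementations also multiply the soft-target term by the squared temperature, as Hinton et al. (2015) suggest; this sketch assumes no such scaling.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits, softened by a temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """L_ce (assumed form): cross-entropy of the student against the teacher's soft targets."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))

def clm_loss(student_logits, target_index):
    """L_clm (assumed form): cross-entropy of the student against the true next token."""
    return -math.log(softmax(student_logits)[target_index])

def combined_loss(student_logits, teacher_logits, target_index, a=0.5, temperature=2.0):
    """L_ce*(1-a) + L_clm*a, with a=0.5 and temperature=2 as stated above."""
    l_ce = distillation_loss(student_logits, teacher_logits, temperature)
    l_clm = clm_loss(student_logits, target_index)
    return l_ce * (1 - a) + l_clm * a

# Toy logits over a 3-token vocabulary (hypothetical values, for illustration only).
student = [2.0, 0.5, -1.0]
teacher = [3.0, 0.0, -2.0]
print(combined_loss(student, teacher, target_index=0))
```

With a = 0.5 the two terms are weighted equally; setting a = 1 recovers plain supervised training and a = 0 trains on the teacher's soft targets alone.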

To maximize VRAM utilization while reaching a combined batch size of 4096 samples, we use a device batch size of 2 with 2048 gradient accumulation steps and a context length of 2048 tokens, with both the teacher and student models in bf16 precision. This let us use around 98.94% of the 12 gigabytes of VRAM on the RTX 3060 during training. Since the model trained for 64 steps, the training set totals approximately 537 million tokens. All training samples were taken from The Pile.
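The batch size and token count follow directly from the numbers above. A small sketch of the arithmetic (the variable names are ours, chosen for illustration):

```python
# Values stated in the text above.
device_batch_size = 2
grad_accum_steps = 2048
context_length = 2048      # tokens per sample
optimizer_steps = 64
vram_total_gb = 12
vram_utilization = 0.9894

# Samples contributing to each optimizer step.
effective_batch_size = device_batch_size * grad_accum_steps

# Tokens seen per optimizer step, and over the whole run.
tokens_per_step = effective_batch_size * context_length
total_tokens = tokens_per_step * optimizer_steps

print(effective_batch_size)                          # 4096
print(total_tokens)                                  # 536870912
print(round(vram_total_gb * vram_utilization, 2))    # 11.87
```

This assumes every sample fills the full 2048-token context, which matches the constant-length setup; with padded or truncated examples the count of non-padding tokens would be lower.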

A learning rate of 1e-4 was used in this study, with no learning rate schedule.


Sanh et al. (2019) suggest that a student around 40% of the size of its teacher can achieve similar performance in encoder models when trained from scratch with supervision. We instead warm-start our student from an existing checkpoint smaller than the teacher, keeping a similar ratio: the student is 43.75% the size of its teacher (70M versus 160M nominal parameters).

| model | piqa acc | winogrande acc | lambada ppl | lambada acc | arc acc | sciq acc | wsc acc | notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| pythia-70m (student base) | 59.85 | 51.22 | 140.81 | 21.40 | 17.15 | 65.00 | 36.53 | |
| pythia-160m (teacher) | 62.68 | 51.07 | 30.03 | 36.76 | 19.62 | 76.20 | 36.58 | |
| distilpythia (student) | 59.74 | **51.62** | 420.70 | 15.82 | **17.15** | 61.30 | **36.54** | trained on padded/truncated examples |
| distilpythia-cl (student) | 59.30 | 50.75 | 403.78 | 15.16 | 16.98 | 59.20 | **36.54** | trained on a constant-length dataset |

Table 1. The student before finetuning, teacher, and student after finetuning and their results on various benchmarks. Numbers in bold are where the student after finetuning matches or outperforms the student before finetuning.

The table compares the base student model (pythia-70m), the teacher model (pythia-160m), and the finetuned student models (distilpythia and distilpythia-cl) across various benchmarks. The goal is to assess whether the distilled student can match or exceed its base model while remaining smaller than the teacher.

From the table, we can observe the following:

  1. The pythia-160m (teacher) model outperforms pythia-70m (student base) in most benchmarks, except for Winogrande accuracy, where the student base has a slightly better performance (51.22% vs. 51.07%).

  2. The distilpythia (student) model, after finetuning, outperforms the pythia-70m (student base) on two benchmarks: Winogrande accuracy (51.62% vs. 51.22%) and WSC accuracy (36.54% vs. 36.53%), and matches it on ARC accuracy (17.15%). These margins are small (0.40 and 0.01 points), so they suggest at most that the finetuning process may be transferring some knowledge from the teacher model to the student model.
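The comparison against the student base can be checked mechanically from the numbers in Table 1. The sketch below is only an illustration (the helper name and dictionary layout are hypothetical); it treats LAMBADA perplexity as lower-is-better and every accuracy as higher-is-better.

```python
# True means higher is better; perplexity is the one lower-is-better metric.
HIGHER_IS_BETTER = {
    "piqa acc": True,
    "winogrande acc": True,
    "lambada ppl": False,
    "lambada acc": True,
    "arc acc": True,
    "sciq acc": True,
    "wsc acc": True,
}

# Values copied from Table 1, in the same column order as HIGHER_IS_BETTER.
METRICS = list(HIGHER_IS_BETTER)
base = dict(zip(METRICS, [59.85, 51.22, 140.81, 21.40, 17.15, 65.00, 36.53]))
students = {
    "distilpythia": dict(zip(METRICS, [59.74, 51.62, 420.70, 15.82, 17.15, 61.30, 36.54])),
    "distilpythia-cl": dict(zip(METRICS, [59.30, 50.75, 403.78, 15.16, 16.98, 59.20, 36.54])),
}

def matches_or_beats(student, reference):
    """Return the metrics where the student matches or outperforms the reference."""
    wins = []
    for metric, higher_is_better in HIGHER_IS_BETTER.items():
        s, r = student[metric], reference[metric]
        if (s >= r) if higher_is_better else (s <= r):
            wins.append(metric)
    return wins

for name, scores in students.items():
    print(name, matches_or_beats(scores, base))
# distilpythia ['winogrande acc', 'arc acc', 'wsc acc']
# distilpythia-cl ['wsc acc']
```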


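It might have worked, though it is hard to say. Training from scratch or for longer might give more performance gains. Also note the LAMBADA perplexity, which rose from 140.81 for the base student to 420.70 (distilpythia) and 403.78 (distilpythia-cl); it is unclear what happened there.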
