Knowledge Distillation with Vision Transformers
We are going to learn about Knowledge Distillation, the method behind DistilGPT2 and DistilBERT, two of the most downloaded models on the Hugging Face Hub!
Presumably, we’ve all had teachers who “teach” by simply giving us the correct answers and then testing us on questions we haven’t seen before. This is analogous to supervised learning, where we train a model on a labeled dataset. Instead of training solely on labels, however, we can use Knowledge Distillation to arrive at a much smaller model that performs comparably to the larger model and runs much faster to boot.
Intuition Behind Knowledge Distillation
Imagine you were given a multiple-choice question asking which character is Harry Potter’s arch-rival, with Ron Weasley, Neville Longbottom, and Draco Malfoy as the answer choices.
If someone just tells you, “The answer is Draco Malfoy,” that doesn’t teach you much about each character’s relationship with Harry Potter.
On the other hand, if someone tells you, “I am very confident it is not Ron Weasley, I am somewhat confident it is not Neville Longbottom, and I am very confident that it is Draco Malfoy”, this gives you some information about these characters’ relationships to Harry Potter! This is precisely the kind of information that gets passed down to our student model under the Knowledge Distillation paradigm.
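To make this concrete, here is a small sketch of how a teacher’s softened output carries this extra ranking information, whereas a one-hot label does not. The logit values are made up purely for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for the three answer choices
# (Ron Weasley, Neville Longbottom, Draco Malfoy) -- illustrative values only.
teacher_logits = torch.tensor([-3.0, 0.5, 4.0])

# A hard label only tells the student "the answer is Draco Malfoy".
hard_label = F.one_hot(torch.tensor(2), num_classes=3).float()
print(hard_label)    # tensor([0., 0., 1.])

# The teacher's softened prediction also encodes *relative* confidence.
# A temperature T > 1 flattens the distribution so the less likely classes
# still contribute a useful signal for the student to learn from.
T = 3.0
soft_targets = F.softmax(teacher_logits / T, dim=-1)
print(soft_targets)  # roughly tensor([0.07, 0.22, 0.71])
```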
Distilling the Knowledge in a Neural Network
In the paper Distilling the Knowledge in a Neural Network, Hinton et al. introduced the training methodology known as knowledge distillation, taking inspiration from insects, of all things. Just as insects transition from a larval form optimized for extracting energy and nutrients from the environment to an adult form optimized for entirely different demands, a large, cumbersome model can be used to extract structure from the training data and then distill that knowledge into a smaller, more efficient model for deployment.
The essence of Knowledge Distillation is to use the predicted logits from a teacher network to pass information to a smaller, more efficient student model. We do this by rewriting the loss function to include a distillation loss, which encourages the student’s distribution over the output space to approximate the teacher’s.
The distillation loss is formulated as:

$$
L_{\text{KD}} = T^{2} \cdot \mathrm{KL}\left( \sigma\!\left(\frac{z_t}{T}\right) \,\Big\|\, \sigma\!\left(\frac{z_s}{T}\right) \right)
$$

where $z_t$ and $z_s$ are the teacher’s and student’s logits, $T$ is the softmax temperature, and $\sigma$ denotes the softmax function. The KL term is the Kullback-Leibler divergence between the teacher’s and the student’s output distributions. The overall loss for the student model is then the sum of this distillation loss and the standard cross-entropy loss over the ground-truth labels.
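Putting the two terms together, here is a minimal PyTorch sketch of what such a combined loss could look like. The class name, the default temperature, and the mixing coefficient `alpha` are illustrative choices rather than the exact code from the notebook (a plain sum of the two terms corresponds to weighting them equally):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss(nn.Module):
    """Cross-entropy on the ground-truth labels plus a KL-divergence term
    that pulls the student's distribution toward the teacher's softened one."""

    def __init__(self, temperature: float = 2.0, alpha: float = 0.5):
        super().__init__()
        self.temperature = temperature  # softens both distributions
        self.alpha = alpha              # weight on the distillation term

    def forward(self, student_logits, teacher_logits, labels):
        T = self.temperature

        # KL(teacher || student) on temperature-scaled distributions;
        # the T**2 factor keeps gradient magnitudes comparable across temperatures.
        distillation_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=-1),
            F.softmax(teacher_logits / T, dim=-1),
            reduction="batchmean",
        ) * (T ** 2)

        # Standard supervised loss on the hard labels.
        ce_loss = F.cross_entropy(student_logits, labels)

        return self.alpha * distillation_loss + (1.0 - self.alpha) * ce_loss


# Usage sketch: the teacher is frozen, only the student is trained.
# loss_fn = DistillationLoss(temperature=2.0, alpha=0.5)
# with torch.no_grad():
#     teacher_logits = teacher(pixel_values).logits
# student_logits = student(pixel_values).logits
# loss = loss_fn(student_logits, teacher_logits, labels)
```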
To see this loss function implemented and a fully worked-out example in Python, check out the notebook for this section.