TinyViT: Fast Pretraining Distillation for Small Vision Transformers
Another interesting paper on training small models for computer vision, specifically vision transformers (ViT). The computer vision aspect of this paper isn’t as interesting as the training process they outline: an efficient knowledge distillation pipeline.
A primer on knowledge distillation
Smaller models, while more hardware-efficient, usually don’t match the performance of their larger counterparts on most tasks, but knowledge distillation provides a way to close the gap.
It is a technique that adds an extra component to the loss term, pushing the smaller model to behave like the larger one. Specifically, for cross-entropy loss functions[1], where we push the logits of the smaller model to converge onto a class, we additionally push those same logits to be closer in distribution to the larger model’s logits (for the same inputs).
Here’s some PyTorch-flavored pseudocode to help understand:

import torch
import torch.nn.functional as F

larger_model = LargerModel()    # teacher, kept frozen
smaller_model = SmallerModel()  # student, being trained
inputs, labels = next(iter(data_loader))

smaller_model_logits = smaller_model(inputs)
with torch.no_grad():  # no gradients flow through the teacher
    larger_model_logits = larger_model(inputs)

cross_entropy_loss = F.cross_entropy(smaller_model_logits, labels)
kl_div_loss = F.kl_div(F.log_softmax(smaller_model_logits, dim=-1),
                       F.softmax(larger_model_logits, dim=-1), reduction="batchmean")
total_loss = cross_entropy_loss + kl_div_loss
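In practice (going back to Hinton et al.’s original distillation recipe), the logits on both sides are usually divided by a temperature before the softmax, and the two loss terms are weighted against each other; both knobs are omitted above for brevity.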
Why do we need the larger model’s logits at all when we already have the ground truth data?!
Consider this:
Model Inputs: I had a good time last night. The James Bond movie was <MASK>
Ground Truth: I had a good time last night. The James Bond movie was amazing
A model would be trained such that it should predict “amazing” for the mask token, and everything other than this token would be penalized. But in reality a lot of tokens could fill that spot, like “good”, “great” or “thrilling”.
Models produce logits (and, after a softmax, probabilities) for each possible token in the mask position. A larger model would likely assign the highest probability to the expected token from the ground truth:
Small model probabilities:
amazing: 0.55
good: 0.03
great: 0.02
thrilling: 0.01
<others>: 0.39
Large model probabilities:
amazing: 0.75
good: 0.1
great: 0.1
thrilling: 0.03
<others>: 0.02
We can see that the large model not only assigns a higher probability to the correct token, but also a better distribution over the plausible alternatives. The hypothesis is that the larger model’s ability to generalize is captured in these output distributions, and that it can be distilled into a smaller model. This is where the KL penalty term comes in: it pushes the smaller model’s output distribution toward that of the larger model.
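To make that concrete, here’s a quick sketch computing the KL term for the toy distributions above (treating <others> as a single bucket; in a real model it would be one entry per vocabulary token):

import torch
import torch.nn.functional as F

# amazing, good, great, thrilling, <others> — probabilities, not raw logits
small = torch.tensor([0.55, 0.03, 0.02, 0.01, 0.39])
large = torch.tensor([0.75, 0.10, 0.10, 0.03, 0.02])

# F.kl_div takes log-probabilities as its first argument and computes
# KL(large || small); it shrinks toward 0 as the student matches the teacher
print(F.kl_div(small.log(), large, reduction="sum"))  # ≈ 0.49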
Training setup
In a typical knowledge distillation setup, both the larger model (aka the teacher) and the smaller model (aka the student) are kept in memory.
This is pretty inefficient: for every set of inputs, the teacher takes far more memory and time to generate its logits than the student does, and since we never backprop through the teacher, that cost buys nothing beyond the logits themselves.
We’re only interested in optimizing the student model, so we really just need the large model’s logits for every corresponding input. The first optimization we could do is generate the teacher logits offline and save them to disk, removing the need to keep both models in memory during the knowledge distillation. The downside is that with large training data coupled with a large number of target classes, disk usage grows. For vision-specific training, data augmentation adds another multiplier to the storage requirements, since the teacher produces a different set of logits for each augmented view of an input.
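A minimal sketch of that offline pass (the model and loader names are placeholders from the pseudocode above; the augmentations must be replayable, e.g. by saving their parameters or seeds alongside the logits, so the student later sees the exact inputs the teacher scored):

import torch

larger_model.eval()
saved_logits = []
with torch.no_grad():
    for inputs, _ in data_loader:  # includes augmented views
        saved_logits.append(larger_model(inputs).cpu())
torch.save(torch.cat(saved_logits), "teacher_logits.pt")
# Student training then reads these from disk instead of running the teacher.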
TinyViT
The second optimization (as suggested in TinyViT) is to store only the top-k logits and their indices from the teacher model instead of the full set. When computing the student’s KL penalty, we reconstruct the full (approximate) teacher logits from the saved values and indices, setting the remaining logits to an appropriate value[2]. The disk savings are clear, and the paper shows that using only the top-k logits has a modest impact on student model quality.
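Here’s a rough sketch of that round trip, assuming teacher_probs holds the softmaxed teacher outputs for a batch and that the leftover probability mass is spread uniformly over the unsaved classes (one reasonable choice; k and the names are illustrative):

import torch

k, num_classes = 10, teacher_probs.size(-1)
# Offline: keep only the top-k probabilities and their class indices
topk_vals, topk_idx = teacher_probs.topk(k, dim=-1)  # this is what hits the disk

# Online: rebuild an approximate teacher distribution; a uniform leftover
# leaves the saved entries untouched, preserving their relative importance
leftover = (1.0 - topk_vals.sum(dim=-1, keepdim=True)) / (num_classes - k)
approx = leftover.expand(-1, num_classes).clone()
approx.scatter_(-1, topk_idx, topk_vals)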
This amortizes the cost of the teacher’s forward passes across multiple student models, allowing for cheaper hyperparameter optimization, or for designing multiple students for different hardware targets like mobile, servers or IoT devices.
[1] Knowledge distillation can be applied to regression problems too, but that’s less common.
[2] How the remaining logits are set matters: a careless choice would change the relative importance of the saved logits versus the ones we didn’t save.