
A few weeks ago, Facebook released a new ML model (Data-efficient image Transformer, DeiT) that achieves state-of-the-art image classification performance using only the ImageNet dataset (1.2 million images). Previous state-of-the-art visual transformers could only reach this performance using hundreds of millions of images [1]. And how Facebook achieved this is the most interesting bit, since they didn't use any convolutions or a huge private dataset.
A lot of great machine learning papers are released almost daily. The reason I chose this one to review is that it uses some genuinely interesting techniques.
One of those techniques is attention (and transformers more broadly), which I don't want to cover thoroughly since there are tons of other articles about them. I am, however, going to give a quick overview just so we can explore DeiT properly.
Visual Transformers
Transformers and attention have been dominating the machine learning space for the last few years. They started in NLP and now they are moving to images.
Visual transformers use Multi-head Self-Attention layers. Those layers are based on the attention mechanism, which uses queries, keys, and values to “pay attention” to information from different representations at different positions.
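To make this concrete, here is a minimal PyTorch sketch of a multi-head self-attention layer. It illustrates the general mechanism only; the actual DeiT code builds on the timm library, and the names and defaults below are my own.

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Minimal multi-head self-attention: queries, keys, and values
    are linear projections of the same token sequence."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)    # one projection producing Q, K, V
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                     # x: (batch, tokens, dim)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)           # how much each token "pays attention" to the others
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
```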
A classic transformer block for images applies a Multi-head Self-Attention layer followed by a feed-forward network, each wrapped in layer normalization and a residual connection. One interesting bit is that the feed-forward network uses an activation function called the Gaussian Error Linear Unit (GELU), whose formulation comes from the idea of regularizing the model by randomly multiplying activations by zero, applied deterministically in expectation.
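And here is a sketch of how such a block could be assembled, reusing the SelfAttention module above; the layer sizes are illustrative, not the paper's exact configuration.

```python
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm encoder block: self-attention, then an MLP with GELU,
    each with a residual connection (illustrative sketch)."""
    def __init__(self, dim, num_heads=8, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SelfAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),                        # the Gaussian Error Linear Unit mentioned above
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))      # attention sub-layer + residual
        x = x + self.mlp(self.norm2(x))       # feed-forward sub-layer + residual
        return x
```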
The original visual transformer has some issues that this paper addresses, such as:
- It was trained on 300 million images (JFT-300M [1])
- Those 300 million images are a private dataset
- It did not generalize well when trained on smaller datasets.
Okay, now that we have covered the basics, let's take a look at what is special about this paper.
The newly introduced trick: the distillation token
What is distillation?
Knowledge distillation refers to the idea of model compression by teaching a smaller network, step by step, exactly what to do using a bigger already trained network. The ‘soft labels’ refer to the output feature maps by the bigger network after every convolution layer. The smaller network is then trained to learn the exact behavior of the bigger network by trying to replicate its outputs at every level (not just the final loss).
Source: Prakhar Ganesh
This is quite fascinating: just like we have teachers in the real world, in ML we have smaller networks mimicking larger networks in order to learn from them.
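To make the idea concrete, here is a minimal PyTorch sketch of the classic soft-label distillation loss: a temperature-scaled KL divergence between the teacher's and the student's softened outputs. The function name and the default temperature are my own illustration, not code from the paper.

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=3.0):
    """Classic soft-label knowledge distillation: push the student's
    softened output distribution towards the teacher's."""
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    log_soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    # KL divergence between the two softened distributions; the T^2 factor
    # keeps gradient magnitudes comparable across temperatures
    return F.kl_div(log_soft_student, soft_teacher, reduction="batchmean") * temperature ** 2
```

The temperature softens both distributions, so the student also learns from the relative probabilities the teacher assigns to the wrong classes, not just from the top prediction.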
Typical visual transformers use a trainable vector called the class token. This token plays the role of the conventional pooling layers found in Convolutional Neural Networks: through self-attention it gathers information from all the image patches into a single representation used for classification, which boosts the model's performance.
Facebook adds a distillation token that interacts with the class token and the patch embeddings through self-attention, right from the first layer. Like the class token, it is a trainable vector that is learned during training.
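Here is a rough sketch of how the class and distillation tokens could be prepended to the patch embeddings before the transformer blocks; the shapes and names are illustrative and not the exact DeiT implementation, which handles initialization and other details omitted here.

```python
import torch
import torch.nn as nn

class PatchEmbedWithTokens(nn.Module):
    """Split an image into patch embeddings and prepend the learnable
    class and distillation tokens (illustrative sketch)."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, dim=768):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Linear(patch_size * patch_size * in_chans, dim)  # linear patch projection
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))   # trainable class token
        self.dist_token = nn.Parameter(torch.zeros(1, 1, dim))  # trainable distillation token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, dim))

    def forward(self, x):                     # x: (B, C, H, W)
        B, C, H, W = x.shape
        p = self.patch_size
        # cut the image into non-overlapping p x p patches and flatten each one
        patches = x.unfold(2, p, p).unfold(3, p, p)                     # (B, C, H/p, W/p, p, p)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)
        patches = self.proj(patches)                                    # (B, num_patches, dim)
        cls = self.cls_token.expand(B, -1, -1)
        dist = self.dist_token.expand(B, -1, -1)
        tokens = torch.cat([cls, dist, patches], dim=1)                 # (B, num_patches + 2, dim)
        return tokens + self.pos_embed
```

At the output of the transformer, the class-token position feeds a classification head while the distillation-token position feeds a separate distillation head.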
Its objective is to minimize the Kullback-Leibler (KL) divergence between the softmax of the teacher and the softmax of the student model (this is called soft distillation). All you need to know about the KL divergence is that it measures how different two probability distributions are.
So essentially, the distillation token pushes the student network's predictions towards the teacher network's. This is quite an impressive and novel strategy!
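Putting the pieces together, here is a hedged sketch of how the two heads of the student could be trained: the class-token head against the ground-truth labels with cross-entropy, and the distillation-token head against the teacher's output with the soft-distillation loss sketched above. The weighting alpha and the temperature are illustrative; the paper also studies a hard-label variant of distillation.

```python
import torch.nn.functional as F

def deit_style_loss(student_cls_logits, student_dist_logits,
                    teacher_logits, labels, alpha=0.5, temperature=3.0):
    """Sketch of a soft-distillation objective: the class-token head learns
    from the ground-truth labels, while the distillation-token head learns
    to match the (convolutional) teacher's predictions."""
    ce = F.cross_entropy(student_cls_logits, labels)                          # supervised term
    kd = distillation_loss(student_dist_logits, teacher_logits, temperature)  # teacher term
    return (1 - alpha) * ce + alpha * kd
```

During training the teacher is kept frozen (run in eval mode under torch.no_grad()), so only the student is updated; at inference time the predictions of the two heads are combined.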
They have also verified [1] the usefulness of this new token by adding a second class token instead of the distillation token: the result was worse performance.
Note that the teacher network here is a Convolutional Neural Network.
Results
One of the best things about this paper is that Facebook has released the full code, the pretrained models, the paper, and pretty much everything else. They released 3 models of different sizes, and as the comparison in the paper shows, they all perform quite well even against one of the best and most recent convolutional networks, EfficientNet.
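If you want to try the released models yourself, they can be loaded through torch.hub from the facebookresearch/deit repository. The entry-point name below follows the repository's README at the time of writing and requires the timm package; check the repo in case it has changed.

```python
import torch

# Load a pretrained DeiT model released by Facebook AI (requires the timm package)
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_base_patch16_224', pretrained=True)
model.eval()

# Run it on a dummy 224x224 image batch
dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(dummy)   # (1, 1000) ImageNet class scores
print(logits.argmax(dim=-1))
```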
In summary, I think these are the 3 main tricks to Facebook’s achievement:
- The power of visual transformers and attention
- Knowledge distillation through the new distillation token, with patch embeddings playing the role that word embeddings play in NLP
- Not relying on convolutions
Final thoughts:
There is no such thing as a perfect model, and I am sure this one has a few flaws. However, it's quite interesting to see what the top AI researchers are doing. I hope you got the intuition behind the distillation token trick so that you can invent your own tricks in your ML projects!
I didn't want to dive into the mathematics (although I love math) so that the article would suit a larger audience. If you are interested in that, and in checking out more of their results, I suggest taking a look at the paper [1].
References:
[1] Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, and Hervé Jégou. Training data-efficient image transformers & distillation through attention. arXiv preprint, 2021.