Achieving a 4x pretraining speedup over a strongly tuned T5-XXL baseline (Google’s top transformer).
Transformers have achieved enormous success in machine learning, especially in NLP (and now in image processing too). They are one of the hottest topics in the field, so it makes sense that Google would try to improve them.
A few days ago, Google released a huge new paper that proposes a method to significantly boost the number of parameters while keeping the number of floating-point operations (FLOPs, the standard measure of computational cost in ML) constant.
It’s well known that increasing the number of parameters increases a model’s capacity and its ability to learn (up to a certain point, of course). And as expected, the model achieves a 4x pre-training speedup over T5-XXL and a 7x speedup over T5-Base and T5-Large.
Their paper is much longer than most (around 30 pages), so I will highlight the most significant details to keep this article concise and to the point.
Pre-requisites
1. Mixture of Experts (MoE) algorithm
Mixture of experts refers to a machine learning technique where multiple experts (learners) are used to divide the problem space into homogeneous regions. … If the output is conditioned on multiple levels of probabilistic gating functions, the mixture is called a hierarchical mixture of experts.
Source: Wikipedia
One of the most interesting things about MoE is that while a typical deep learning network uses the same parameters for every input, an MoE network uses different parameters for different inputs. This makes the network more versatile.
And if you think about it, this means the model can have a huge (sparse) set of parameters, yet only a fraction of them is used for any given input, which is the essence of this paper.
MoE was introduced well before this paper; however, it suffered from training instabilities and computational cost issues, which this paper tackles.
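To make this concrete, here is a minimal sketch of a Mixture-of-Experts layer with top-k gating. This is my own illustration (the `TinyMoE` module, its sizes, and the simple per-expert loops are assumptions for readability), not the paper’s implementation:

```python
# A minimal Mixture-of-Experts sketch (illustrative only, not the paper's code).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMoE(nn.Module):
    def __init__(self, d_model=64, d_ff=128, num_experts=4, k=2):
        super().__init__()
        # Each expert is an independent feed-forward network.
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        )
        # The gate scores every expert for every token.
        self.gate = nn.Linear(d_model, num_experts)
        self.k = k

    def forward(self, x):  # x: (tokens, d_model)
        probs = F.softmax(self.gate(x), dim=-1)           # (tokens, num_experts)
        topk_probs, topk_idx = probs.topk(self.k, dim=-1)
        out = torch.zeros_like(x)
        # Only the selected experts run for each token -> sparse computation.
        for i, expert in enumerate(self.experts):
            for slot in range(self.k):
                mask = topk_idx[:, slot] == i
                if mask.any():
                    out[mask] += topk_probs[mask, slot].unsqueeze(-1) * expert(x[mask])
        return out

tokens = torch.randn(10, 64)
print(TinyMoE()(tokens).shape)  # torch.Size([10, 64])
```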
2. Distillation
Distillation is essentially “teaching” in ML: a smaller network mimics a larger and better network in order to learn from it without going through the full, lengthy training process. It is a bit different from transfer learning and offers significant performance boosts.
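As a rough sketch of the idea (my own, using the standard softened cross-entropy formulation rather than anything specific to the Switch Transformers paper; the shapes, temperature `T`, and weighting `alpha` are made up):

```python
# Illustrative knowledge-distillation loss: the student matches the teacher's
# softened output distribution plus the usual hard-label loss (assumed setup).
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft targets: KL divergence between temperature-softened distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: ordinary cross-entropy against the true labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
print(distillation_loss(student_logits, teacher_logits, labels))
```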
In fact, Facebook’s recent state-of-the-art DeiT paper, which classifies images with zero convolutions and less than 1% of the data used by previous state-of-the-art models, relies on distillation. If you are interested in finding out more about the distillation token trick, check out my article here:
3. Model and data sharding (parallelism)
When models and datasets are extremely large, you have to start splitting them across multiple cores. This is often challenging, but because the models here are mostly sparse (not all of the parameters are used at once), it becomes much more manageable. In fact, the main selling point is simple and efficient sparsity!
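As a toy illustration of expert parallelism (entirely my own numbers, not the paper’s Mesh TensorFlow sharding spec), each device only has to hold its own experts’ weights, so the total parameter count grows with the number of devices while the per-device memory stays flat:

```python
# Toy expert-parallelism sketch: assign experts to devices round-robin and
# count how many parameters each device has to store (made-up sizes).
num_experts, num_devices = 32, 8
placement = {e: e % num_devices for e in range(num_experts)}  # expert -> device

params_per_expert = 3_000_000  # hypothetical expert size
per_device_params = params_per_expert * num_experts // num_devices
print(placement[5], per_device_params)  # expert 5 lives on device 5; 12,000,000 params/device
```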
Paper highlights (how does it work briefly)
One of the most enjoyable things about this paper is the authors’ use of an “engineering mindset”. When working with this much computational power and this many model parameters, you have to be smart. This is why the first highlight is how tokens are routed to the correct expert (MoE) after attention.
The model starts off with a classic self-attention block (the essence of transformers). The attention part aggregates information and relates the individual items of the sequence to each other, transforming them into a new sequence where each token can gather information from every other token. [2]
After that, there is a feed-forward network (FFN) in which every token is processed in isolation. The FFN’s job is to determine the best representation of each token for the next group of layers. So basically, attention relates tokens to each other, while the feed-forward layers relate one layer to the next [2]. You can think of the FFN as a middleman translating between two entities: on one side there is a token that needs to be processed, and on the other side there is a group of experts.
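Here is a bare-bones sketch of that block, just to show the two roles side by side (the `Block` module and its dimensions are placeholders of mine, not the paper’s configuration):

```python
# A bare-bones Transformer block showing the two roles described above.
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, d_model=64, n_heads=4, d_ff=256):
        super().__init__()
        # Self-attention: every token gathers information from every other token.
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # Feed-forward network: applied to each token independently.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, x):  # x: (batch, seq_len, d_model)
        a, _ = self.attn(x, x, x)
        x = self.norm1(x + a)
        return self.norm2(x + self.ffn(x))  # a Switch layer would replace self.ffn

x = torch.randn(2, 16, 64)
print(Block()(x).shape)  # torch.Size([2, 16, 64])
```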
Now here is the interesting bit: all of what I have described above is the MoE transformer; we still haven’t got to the Switch Transformer!
The routing trick
The Switch Transformer introduces a “Switch” layer before the FFNs that essentially turns the FFNs into the experts: the Switch matches each token to the most suitable FFN / expert and routes it there.
For example, one FFN layer might become an expert in processing nouns, another might specialize in verbs, another in punctuation, and so on [2].
They call this concept Switch Routing, and it is essentially an upgrade of the Mixture of Experts. Previous MoE authors hypothesized that routing each token to a single expert wouldn’t work, because the router needs to compare at least two experts to get useful gradients. Google introduces a quite novel workaround in this area.
We instead use a simplified strategy where we route to only a single expert. We show this simplification preserves model quality, reduces routing computation and performs better. This k = 1 routing strategy is later referred to as a Switch layer. The benefits for the Switch layer are three-fold:
(1) The router computation is reduced as we are only routing a token to a single expert.
(2) The batch size (expert capacity) of each expert can be at least halved since each token is only being routed to a single expert.
(3) The routing implementation is simplified and communication costs are reduced.
Source: Switch Transformers paper
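In code, the k = 1 strategy boils down to taking the argmax of the router’s softmax. Below is my own simplified sketch; it ignores the capacity limit and the auxiliary load-balancing loss that the paper also uses:

```python
# Simplified top-1 ("switch") routing: each token goes to exactly one expert.
# My own sketch; the real implementation adds capacity limits and a
# load-balancing loss.
import torch
import torch.nn as nn
import torch.nn.functional as F

def switch_route(x, router, experts):
    """x: (tokens, d_model); router: Linear(d_model -> num_experts)."""
    probs = F.softmax(router(x), dim=-1)   # router probabilities per token
    gate, expert_idx = probs.max(dim=-1)   # k = 1: the single best expert
    out = torch.zeros_like(x)
    for i, expert in enumerate(experts):
        mask = expert_idx == i
        if mask.any():
            # Scale by the gate value so gradients still reach the router.
            out[mask] = gate[mask].unsqueeze(-1) * expert(x[mask])
    return out

d_model, num_experts = 64, 4
router = nn.Linear(d_model, num_experts)
experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_experts))
print(switch_route(torch.randn(10, d_model), router, experts).shape)
```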
We all know that ML relies on floating-point operations, and if you deploy a large distributed model, you have to send a lot of floating-point numbers between devices. Floats come in mainly two sizes, 16-bit and 32-bit: if you only use 16-bit numbers, the routing computations become numerically unstable, and you can’t send 32-bit numbers everywhere (which stable Switch routing would require) because of the communication cost.
So what did they do? They introduce a selective precision technique: tensors are communicated as 16-bit floats, selectively upcast to 32-bit floats locally for the numerically sensitive routing operations, and then downcast back to 16 bits. A simple solution to a difficult problem!
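Conceptually, it looks something like the sketch below. This is my own illustration of the idea in PyTorch, not the paper’s Mesh TensorFlow code; the function name and shapes are made up:

```python
# Selective precision, illustrated: keep tensors in 16-bit (bfloat16) for
# communication, but upcast to float32 only inside the router computation.
import torch
import torch.nn.functional as F

def router_probs_selective(x_bf16, router_weight_bf16):
    x32 = x_bf16.float()                    # upcast locally to float32
    w32 = router_weight_bf16.float()
    probs32 = F.softmax(x32 @ w32, dim=-1)  # numerically sensitive softmax in fp32
    return probs32.to(torch.bfloat16)       # downcast again before sending it on

x = torch.randn(10, 64).to(torch.bfloat16)
w = torch.randn(64, 4).to(torch.bfloat16)
print(router_probs_selective(x, w).dtype)   # torch.bfloat16
```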
They also optimize this process through the notion of a capacity factor, where each expert only processes up to a fixed number of tokens per batch; tokens that overflow this capacity skip the expert and are passed on through the residual connection.
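The capacity itself follows a simple formula given in the paper: expert capacity = (tokens per batch / number of experts) × capacity factor. A quick illustrative calculation:

```python
# Expert capacity as defined in the paper.
def expert_capacity(tokens_per_batch, num_experts, capacity_factor):
    return int((tokens_per_batch / num_experts) * capacity_factor)

# e.g. 1024 tokens per batch, 8 experts, capacity factor 1.25
print(expert_capacity(1024, 8, 1.25))  # 160 tokens per expert
```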
Furthermore, to alleviate some of the deployment issues (since these models are huge), they use distillation to compress the large sparse models into much smaller dense ones.
And the result is a 3–30% improvement in performance without an increase in the number of parameters. This proves the magic of distillation!
Final thoughts
There is a lot of work in this paper, and I haven’t covered it all! There is more in the paper about distributing the models and the data, but my aim was to present the highlights. It’s always great to see innovation, and I think the best parts are where they use engineering and computer-systems techniques to solve ML problems (such as routing and selective precision). This shows that ML is not only maths and statistics, but also computer science.
If you are interested in reading more about other novel papers, check out my articles here:
References:
[1] William Fedus, Barret Zoph, and Noam Shazeer. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. arXiv, 2021.
[2] Yannic Kilcher. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. YouTube.