Authors:
(1) David Raposo, Google DeepMind (equal contribution);
(2) Sam Ritter, Google DeepMind;
(3) Blake Richards, Google DeepMind and McGill University & Mila;
(4) Timothy Lillicrap, Google DeepMind;
(5) Peter Conway Humphreys, Google DeepMind;
(6) Adam Santoro, Google DeepMind (equal contribution).
Editor’s note: this is part 4 of 5 of a study detailing a way to make transformer-based language models more efficient by dynamically allocating computational resources. Read the rest below.
Table of Links
- Introduction
- Background
- Implementing Mixture-of-Depths Transformers
  - 3.1. Defining a compute budget
  - 3.2. Routing around transformer blocks
  - 3.3. Routing schemes
  - 3.4. Routing implementation
  - 3.5. Sampling and 3.6. Training methods
- Results
  - 4.1. Training, isoFLOP comparisons
  - 4.2. Auto-regressive Evaluation and 4.3. Mixture-of-Depths-and-Experts (MoDE)
- Discussion and References
4. Results
4.1. Training, isoFLOP comparisons
We first trained models with a relatively small FLOP budget (6e18) to determine optimal hyperparameters (see figure 3). In general, we found that MoD transformers drag the baseline isoFLOP curve “down and to the right”: the optimal MoD transformer achieves a lower loss than the optimal baseline, and also has more parameters. A fortunate consequence of this effect is that there exist smaller MoD models that, while not themselves isoFLOP optimal for their hyperparameter setting, nevertheless perform as well as or better than the optimal baseline model while being faster to step. For example, a 220M-parameter MoD variant (figure 3, model #3) slightly outperforms the isoFLOP-optimal baseline (also 220M, figure 3, model #1), but is upwards of 60% faster to step during training. Crucially, when run on equivalent hardware these two model variants take approximately the same amount of wall-clock time to train (figure 3).
We tested routing on every block or on every other block, using capacities from 12.5% to 95% of the total sequence. Routing every other block was crucial for strong performance, as was aggressive capacity reduction: we observed gradual improvements as we reduced the capacity down to 12.5% of the total sequence (corresponding to 87.5% of tokens routing around blocks), with performance degrading beyond that point. It seems, then, that the network is robust to significant capacity reductions as long as there is frequent opportunity for full-capacity self-attention and MLP computations.
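To make these settings concrete, the short sketch below (our own illustration; `mod_block_config` and its arguments are not from the paper) shows how a capacity fraction and a routing frequency translate into per-block token capacities. The sweep described above corresponds to capacity fractions between 0.125 and 0.95, with routing applied either to every block or to every other block.

```python
def mod_block_config(num_blocks: int, seq_len: int,
                     capacity_fraction: float = 0.125,
                     route_every_other: bool = True) -> dict:
    """Map block index -> token capacity for one MoD configuration."""
    k = max(1, int(capacity_fraction * seq_len))  # tokens a routed block processes
    config = {}
    for i in range(num_blocks):
        is_routed = (i % 2 == 1) if route_every_other else True
        config[i] = k if is_routed else seq_len   # unrouted blocks keep the full sequence
    return config

# Example: 12 blocks, 2048-token sequences, 12.5% capacity on every other block.
# Routed blocks process only 256 tokens; the remaining tokens route around them.
print(mod_block_config(num_blocks=12, seq_len=2048))
```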
Learned routing is crucial, as MoD transformers that use stochastic routing (implemented using a top-𝑘 operation on router weights sampled from a Gaussian distribution) perform drastically worse than both the baseline and normal MoD transformer (figure 3).
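The numpy sketch below (our own, not the authors' implementation) illustrates the difference: with learned routing, the top-k selection depends on token content through trained router weights, whereas in the stochastic control the router weights are simply drawn from a Gaussian, so the selection is independent of the tokens.

```python
import numpy as np

def topk_mask(weights: np.ndarray, k: int) -> np.ndarray:
    """Boolean mask selecting the k tokens with the largest router weights."""
    idx = np.argsort(weights)[-k:]
    mask = np.zeros_like(weights, dtype=bool)
    mask[idx] = True
    return mask

rng = np.random.default_rng(0)
seq_len, d_model, k = 8, 4, 2

x = rng.normal(size=(seq_len, d_model))           # token embeddings
w_router = rng.normal(size=(d_model,))            # trained in practice; random here

learned_weights = x @ w_router                    # learned routing: weights depend on tokens
stochastic_weights = rng.normal(size=(seq_len,))  # stochastic control: token-independent noise

print(topk_mask(learned_weights, k))              # selection reflects token content
print(topk_mask(stochastic_weights, k))           # selection is arbitrary
```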
Figure 4 depicts an isoFLOP analysis for 6e18, 2e19, and 1e20 total FLOPs. The trend that FLOP-optimal MoD transformers have more parameters than the baseline continues at these larger FLOP budgets. Notably, there exist MoD variants that are appreciably faster to step than the isoFLOP-optimal baseline (measured as steps per second when training on equivalent hardware) while also achieving a lower loss. (In figure 4 we depict normalized FLOPs per forward pass rather than wall-clock step time per se, but in our experiments the two are tightly correlated; a similar plot of relative wall-clock step times shows the same basic trend.)
Step-wise speed gains come from two sources. First, the FLOP-per-parameter ratio in MoD transformers is less than in the baselines because some proportion of tokens are routed around blocks. So, for a given model size, a transformer requires fewer FLOPs per forward pass. Second, since isoFLOP-optimal MoD transformers are both bigger and achieve a lower loss than the isoFLOP-optimal baseline, there exist smaller MoD variants that perform as well or better than the isoFLOP-optimal baseline, and these variants are faster to step because they are smaller. Altogether, then, there exist MoD transformers that perform as well as isoFLOP-optimal baselines and are faster to step, both because they use fewer FLOPs per parameter and because they use fewer parameters.
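As a back-of-the-envelope illustration of the first effect (our own arithmetic, ignoring attention's quadratic term and the negligible router cost), consider the 12.5%-capacity, every-other-block configuration:

```python
capacity_fraction = 0.125

# Half the blocks run at full capacity; the other half process only 12.5% of tokens,
# so the per-parameter matmul FLOPs of a forward pass shrink accordingly.
flop_fraction = 0.5 * 1.0 + 0.5 * capacity_fraction
print(flop_fraction)  # 0.5625 -> roughly 44% fewer block FLOPs per forward pass
```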
Figure 4 also reveals another important finding: the optimal MoD transformer is that which uses as many FLOPs per forward pass as the isoFLOP optimal baseline. This finding allows one to directly predict which sized MoD transformer will perform optimally for a given isoFLOP training budget: one just needs to tune the model size for a given MoD configuration (i.e., capacity and routing frequency) to produce a model that uses as many FLOPs per forward pass as the isoFLOP-optimal baseline, and they will have the optimally performing MoD variant for that configuration. Empirically, we find that it is better to add depth than to add width when adding FLOPs to the model.
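The sketch below (a rough illustration under our own assumptions; `flops_per_forward` uses a crude transformer FLOP count, not the paper's accounting) shows the sizing rule in practice: grow the MoD model by depth, as the empirical note above suggests, until its FLOPs per forward pass match those of the isoFLOP-optimal baseline.

```python
def flops_per_forward(num_blocks: int, d_model: int, seq_len: int,
                      capacity_fraction: float = 1.0,
                      route_every_other: bool = True) -> int:
    """Crude per-forward-pass FLOP estimate for a (possibly MoD) transformer."""
    def block_flops(tokens: int) -> int:
        # ~24*d^2 FLOPs per token for the attention/MLP projections (multiply-adds
        # counted as 2 FLOPs), plus ~4*d*tokens^2 for the attention score/value matmuls.
        return 24 * d_model**2 * tokens + 4 * d_model * tokens**2

    kept = max(1, int(capacity_fraction * seq_len))
    total = 0
    for i in range(num_blocks):
        routed = route_every_other and i % 2 == 1
        total += block_flops(kept if routed else seq_len)
    return total

# Illustrative baseline: 24 blocks, d_model = 2048, 2048-token sequences.
baseline = flops_per_forward(num_blocks=24, d_model=2048, seq_len=2048)

# Add depth to a 12.5%-capacity, every-other-block MoD until it is FLOP-matched.
depth = 24
while flops_per_forward(depth, 2048, 2048, capacity_fraction=0.125) < baseline:
    depth += 1
print(depth)  # the depth at which this MoD configuration matches the baseline's FLOPs
```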
Nevertheless, while the FLOPs per forward pass determine which model will be isoFLOP optimal, they do not predict whether that optimum will improve upon the baseline (see figure 3); the best capacity appears to be an empirical matter. We found it best to use 12.5% capacity blocks, routed every other block.
We noticed that MoD transformers had memory savings relative to equivalently sized baseline models at larger sizes, with some variants requiring fewer total devices (i.e., a smaller TPU topology). We did not study this extensively, but we anticipate that as one scales to larger models, these savings could be an important consideration when choosing model variants to train, and could have significant positive effects in regards to the KV cache size during autoregressive sampling.
Figure 5 shows the routing decisions for an MoD transformer trained with interleaved routing blocks. Despite the aggressive routing around blocks, these transformers achieve performance improvements relative to baselines. We observe patterns that might warrant further study: some tokens appear to engage every block along the transformer’s depth, while others route around blocks whenever possible. Preliminary analyses suggest that the tokens that engage with blocks more frequently tend to correspond to output predictions with higher entropy, which possibly indicates predictions that are more difficult to make.
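The snippet below sketches the kind of preliminary analysis referred to above (our own illustration on synthetic placeholders; in practice the router selections and output logits would come from a trained MoD model): correlate how often each position is selected by routers across depth with the entropy of the model's prediction at that position.

```python
import numpy as np

rng = np.random.default_rng(0)
num_routed_blocks, seq_len, vocab = 6, 512, 1024

# Placeholders standing in for a trained model's router selections and output logits.
selected = rng.random((num_routed_blocks, seq_len)) < 0.125  # which positions each router keeps
logits = rng.normal(size=(seq_len, vocab))

probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
entropy = -(probs * np.log(probs)).sum(axis=-1)              # per-position prediction entropy

engagement = selected.mean(axis=0)                           # fraction of routed blocks engaged
print(np.corrcoef(engagement, entropy)[0, 1])                # correlation (meaningless on random data)
```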
4.2. Auto-regressive Evaluation
We evaluated MoD variants during auto-regressive sampling (see figure 6). Each model was tested on exactly the same held-out data comprising 256,000 sequences (500M tokens). When switching from the top-k routing method used during training to the causal, predictor-based routing method, we observed little performance degradation. As in the training setting, there exist MoD variants that outperform the isoFLOP-optimal baseline while requiring fewer FLOPs per forward pass. These results suggest that the compute savings offered by MoD transformers should translate beyond the training setting.
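A minimal sketch of what predictor-based routing looks like at sampling time (our own reading of the method; the names and the single-linear-layer predictor are illustrative simplifications, not the paper's implementation): a small auxiliary classifier makes a per-token, causal decision about whether the token would have been selected, so no future tokens are needed.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 16

# A small auxiliary predictor, trained (in the real method) to mimic top-k membership.
w_predictor = rng.normal(size=(d_model,))

def route_token_at_sampling(x_t: np.ndarray) -> bool:
    """Causal routing decision for a single token during auto-regressive sampling."""
    return float(x_t @ w_predictor) > 0.0  # equivalent to sigmoid(logit) > 0.5

x_t = rng.normal(size=(d_model,))
print(route_token_at_sampling(x_t))        # True -> the token is processed by the block
```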
4.3. Mixture-of-Depths-and-Experts (MoDE)
The MoD technique can be naturally integrated with MoE models (together comprising MoDE models), in addition to vanilla transformers. In figure 7 we present results showing that the performance improvements offered by MoD compound with those of MoE. We tried two variants: staged MoDE, which routes tokens around or towards blocks prior to the self-attention step, and integrated MoDE, which implements MoD routing by including “no-op” experts among the conventional MLP experts. The former is advantageous because it allows tokens to skip the self-attention step, while the latter is advantageous because it simplifies the routing machinery. We found that implementing MoDE in the integrated manner was distinctly better than simply reducing the capacity of experts in conventional MoE models and relying on token dropping to implement residual routing. We believe this is because with the integrated MoDE machinery, tokens explicitly learn to choose the residual path around the experts, rather than preferring an expert but being dropped when the capacity is reduced.
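To make the integrated variant concrete, here is a schematic numpy sketch (our own, with illustrative shapes and top-1 expert choice; gating weights, capacities, and batching are omitted): the router chooses among conventional MLP experts plus one explicit no-op expert that leaves the token on the residual path.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_hidden, num_experts = 8, 32, 4

# Expert parameters; the extra "expert" indexed num_experts is a parameter-free no-op.
w_in = rng.normal(size=(num_experts, d_model, d_hidden)) * 0.02
w_out = rng.normal(size=(num_experts, d_hidden, d_model)) * 0.02
w_router = rng.normal(size=(d_model, num_experts + 1)) * 0.02

def integrated_mode_mlp(x_t: np.ndarray) -> np.ndarray:
    """Route a single token to one MLP expert or to the no-op (residual) expert."""
    choice = int(np.argmax(x_t @ w_router))
    if choice == num_experts:                  # no-op expert: skip the computation entirely
        return x_t
    h = np.maximum(x_t @ w_in[choice], 0.0)    # ReLU MLP expert
    return x_t + h @ w_out[choice]             # residual connection around the expert

x_t = rng.normal(size=(d_model,))
print(integrated_mode_mlp(x_t))
```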