Authors:
(1) David Raposo, Google DeepMind (equal contribution);
(2) Sam Ritter, Google DeepMind;
(3) Blake Richards, Google DeepMind and McGill University & Mila;
(4) Timothy Lillicrap, Google DeepMind;
(5) Peter Conway Humphreys, Google DeepMind;
(6) Adam Santoro, Google DeepMind (equal contribution).
Editor’s note: this is part 5 of 5 of a study detailing a way to make transformer-based language models more efficient by dynamically allocating computational resources. Read the rest below.
Table of Links
- Introduction
- Background
- Implementing Mixture-of-Depths Transformers
  - 3.1. Defining a compute budget
  - 3.2. Routing around transformer blocks
  - 3.3. Routing schemes
  - 3.4. Routing implementation
  - 3.5. Sampling and 3.6. Training methods
- Results
  - 4.1. Training, isoFLOP comparisons
  - 4.2. Auto-regressive Evaluation and 4.3. Mixture-of-Depths-and-Experts (MoDE)
- Discussion and References
5. Discussion
Mixture-of-Depths transformers empirically demonstrate that one can improve on isoFLOP-optimal baseline performance with models that use fewer FLOPs per forward pass. This means that—for a given training FLOP budget—we can train models that are both faster and better performing than their baseline counterparts. Previously, to train models that are both faster and as- or better-performing than isoFLOP-optimal models, one would have to use surplus compute to overtrain smaller models (notably, this overtraining technique is still possible with MoD transformers, and speed gains should compound).
While MoD transformers require fewer FLOPs per forward pass, one cannot forgo FLOPs indiscriminately. Rather, it is crucial to use learned routing decisions—much like in Mixture-of-Experts transformers—to determine whether a token should participate in self-attention and the subsequent MLP (requiring FLOPs), or not (saving FLOPs). We can then spend any saved FLOPs by, for example, making the model bigger or training it for longer. Our results show that FLOPs are indeed used inefficiently in vanilla transformer models, and that there are more efficient ways to expend them.
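A minimal sketch of this routing pattern—not the paper's exact implementation—assuming PyTorch and a hypothetical `block` module that implements standard self-attention plus an MLP: a scalar router scores every token, only the top-k tokens (a fraction set by `capacity`) are processed by the block, and the rest flow through the residual stream unchanged.

```python
import torch
import torch.nn as nn

class MoDBlock(nn.Module):
    def __init__(self, d_model: int, block: nn.Module, capacity: float = 0.125):
        super().__init__()
        self.router = nn.Linear(d_model, 1)  # scalar routing weight per token
        self.block = block                   # standard transformer block (attention + MLP)
        self.capacity = capacity             # fraction of tokens that receive compute

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, d_model]
        b, s, d = x.shape
        k = max(1, int(s * self.capacity))
        scores = self.router(x).squeeze(-1)               # [batch, seq_len]
        top_idx = torch.topk(scores, k, dim=-1).indices
        top_idx = torch.sort(top_idx, dim=-1).values      # keep token order for causal masking
        gather_idx = top_idx.unsqueeze(-1).expand(-1, -1, d)
        selected = torch.gather(x, 1, gather_idx)         # [batch, k, d_model]
        # Multiply the block output by the router score so the router receives gradients.
        gate = torch.gather(scores, 1, top_idx).unsqueeze(-1)
        processed = selected + self.block(selected) * gate
        # Unselected tokens skip the block entirely and keep their residual value.
        return x.scatter(1, gather_idx, processed)
```

Only the `k` selected tokens per sequence incur attention and MLP FLOPs, which is where the per-forward-pass savings come from.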
Learned routing mechanisms are sometimes non-causal; that is, information about the future is used to determine a given token’s routing decision. This is generally true for top-k routing mechanisms, which are useful because they forgo the need for auxiliary balancing losses. However, top-k routing mechanisms present difficulties for post-training autoregressive sampling, where it is impossible to use information about future token identities to determine routing decisions. In this work we show that one can successfully use a top-k routing scheme during training without requiring it during later autoregressive sampling. Either a simple auxiliary classifier, or an auxiliary loss on the router, is sufficient to learn the top-k routing decisions well enough to mimic them during autoregressive sampling, with minimal to no performance degradation.
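A hedged sketch of the auxiliary-classifier strategy, with hypothetical names: a small per-token classifier is trained (with a stop-gradient on its input) to predict whether the training-time top-k router would have selected a token, and at sampling time this causal predictor stands in for the non-causal top-k selection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalRoutePredictor(nn.Module):
    """Hypothetical auxiliary classifier: predicts, from a token's representation
    alone, whether the training-time top-k router would have selected it."""
    def __init__(self, d_model: int):
        super().__init__()
        self.classifier = nn.Linear(d_model, 1)

    def aux_loss(self, x: torch.Tensor, top_idx: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, d_model]; top_idx: [batch, k] indices chosen by the top-k router.
        targets = torch.zeros(x.shape[:2], device=x.device)
        targets.scatter_(1, top_idx, 1.0)                 # 1.0 where the token made the top-k cut
        logits = self.classifier(x.detach()).squeeze(-1)  # stop-gradient: only the classifier is trained
        return F.binary_cross_entropy_with_logits(logits, targets)

    @torch.no_grad()
    def route(self, x_t: torch.Tensor) -> torch.Tensor:
        # Causal decision for newly sampled tokens: no future information required.
        return torch.sigmoid(self.classifier(x_t)).squeeze(-1) > 0.5
```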
Intuitively, a token might learn to route around blocks because the prediction being made at that step is easier, and hence, does not require as much compute. However, this strategy is undoubtedly not all that the network learns. If a token does not participate in self-attention at a certain block, then later tokens will also not be able to attend to it. Thus, whether tokens decide to route or not impacts both the current step’s prediction and future predictions via causal self-attention, and how the network balances these effects is guided by their influence on the overall language modeling objective.
This insight opens the door to MoD variants that decouple the routing for queries, keys and values. For example, perhaps a token would prefer to be among the queries, but not the keys, for a given self-attention computation. One can imagine extending this idea even further into the domain of “long-term memory”: perhaps there are tokens that would be extremely valuable as keys, regardless of whether it is useful for them to also be among the queries at the step of their occurrence. Learned routing could be a powerful mechanism for deciding which tokens these might be, perhaps funnelling them into a long-term memory buffer that is available during future self-attention. One advantage of such an approach to long-term memory is that tokens decide once, at the moment of “memory encoding”, whether they should be retrieved in the future. This is more computationally efficient than performing a full content-based lookup across an entire memory buffer for each step in the future, and could be one step towards drastically increasing the context-length available for making a prediction.
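Purely as an illustration of this speculative variant (the names, capacities, and thresholds below are assumptions, not the paper's method), one could give queries and keys/values separate routers, so that a "memorable" token can enter the key/value set without also being a query at its own step.

```python
import torch
import torch.nn as nn

class DecoupledQKRouter(nn.Module):
    def __init__(self, d_model: int, q_capacity: float = 0.25, kv_capacity: float = 0.5):
        super().__init__()
        self.q_router = nn.Linear(d_model, 1)   # decides which tokens act as queries
        self.kv_router = nn.Linear(d_model, 1)  # decides which tokens act as keys/values
        self.q_capacity, self.kv_capacity = q_capacity, kv_capacity

    def forward(self, x: torch.Tensor):
        # x: [batch, seq_len, d_model]; returns index sets for queries and keys/values.
        b, s, _ = x.shape
        q_k = max(1, int(s * self.q_capacity))
        kv_k = max(1, int(s * self.kv_capacity))
        q_idx = torch.topk(self.q_router(x).squeeze(-1), q_k, dim=-1).indices
        kv_idx = torch.topk(self.kv_router(x).squeeze(-1), kv_k, dim=-1).indices
        # A token routed only to the key/value set is never queried at its own step,
        # but remains attendable by future tokens, i.e. a crude "long-term memory" entry.
        return torch.sort(q_idx, dim=-1).values, torch.sort(kv_idx, dim=-1).values
```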
Unlike MoE transformers that route between effectively the same computation (usually MLPs), MoD transformers demonstrate the value of routing among different types of computations. In this work the types were either the conventional transformer block, or a null computation (functionally equivalent to multiplying by zero). However, one can imagine extending this idea further by routing between even more types of computation. For example, perhaps some tokens are routed to “memory lookup” functions, and others are routed to “tool use” functions. In general, the routing machinery we deployed provides a knob for adjusting the types of computations available to the network and their relative cost (in total FLOPs); if one wants to introduce an expensive computation, then this can be offset by setting its capacity to some small amount, and hence, by routing only a small number of tokens to it.
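A hedged sketch of this generalization, with illustrative path names and capacities: a router scores each token against several computation paths (a standard block, a null path, and a hypothetical expensive "tool" path), and each path has its own capacity so that costly computations only receive a small number of tokens.

```python
import torch
import torch.nn as nn

class MultiPathRouter(nn.Module):
    def __init__(self, d_model: int, paths: dict[str, float]):
        # paths maps a path name to its capacity (fraction of tokens), e.g.
        # {"block": 0.125, "null": 1.0, "tool": 0.01} -- illustrative numbers only.
        super().__init__()
        self.paths = paths
        self.router = nn.Linear(d_model, len(paths))

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        # x: [batch, seq_len, d_model]
        b, s, _ = x.shape
        logits = self.router(x)  # [batch, seq_len, num_paths]
        routes = {}
        for i, (name, capacity) in enumerate(self.paths.items()):
            k = max(1, int(s * capacity))
            # Tokens most strongly routed to this path, capped by its capacity.
            routes[name] = torch.topk(logits[..., i], k, dim=-1).indices
        return routes  # downstream code applies each path's computation to its tokens
```

The per-path capacity is the "knob" mentioned above: shrinking a path's capacity bounds the total FLOPs it can consume, no matter how expensive a single invocation of that path is.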
Altogether, MoD transformers are another tool one can use to tune a model’s compute per forward pass (and hence inference time). The machinery used to implement MoD is also generic, and opens the door to many extensions and to integration with other techniques, such as MoE.