Authors:
(1) Soham De, Google DeepMind, with equal contributions;
(2) Samuel L. Smith, Google DeepMind, with equal contributions;
(3) Anushan Fernando, Google DeepMind, with equal contributions;
(4) Aleksandar Botev, Google DeepMind, with equal contributions;
(5) George Cristian-Muraru, Google DeepMind, with equal contributions;
(6) Albert Gu, Work done while at Google DeepMind;
(7) Ruba Haroun, Google DeepMind;
(8) Leonard Berrada, Google DeepMind;
(9) Yutian Chen, Google DeepMind;
(10) Srivatsan Srinivasan, Google DeepMind;
(11) Guillaume Desjardins, Google DeepMind;
(12) Arnaud Doucet, Google DeepMind;
(13) David Budden, Google DeepMind;
(14) Yee Whye Teh, Google DeepMind;
(15) Razvan Pascanu, Google DeepMind;
(16) Nando De Freitas, Google DeepMind;
(17) Caglar Gulcehre, Google DeepMind.
Table of Links
1 Introduction
2 Model Architecture
3 Recurrent Models Scale as Efficiently as Transformers
3.1. Scaling curves
3.2. Evaluation on downstream tasks
4 Training Recurrent Models Efficiently on Device and 4.1. Model parallelism for large scale training
4.2. Efficient linear recurrences on device
4.3. Training speed on longer sequences
5. Inference Speed
5.1. A simple model of the decode step
5.2. Results
6. Long Context Modeling and 6.1. Improving next token prediction with longer contexts
6.2. Copy and retrieval capabilities
7. Related Works
8. Conclusion, Acknowledgements, and References
A. RG-LRU Recurrence Gate
B. Complex-Gated Linear Recurrent Unit (CG-LRU)
C. Model Scale Hyper-Parameters
D. Efficient Linear Recurrences on Device
E. The Local Attention Window Size of Griffin
F. Inference Speeds
G. Improving Next Token Prediction with Longer Contexts: Additional Results
H. Additional Details of the Copy and Retrieval Tasks
5. Inference Speed
Inference in LLMs is composed of two stages. In the “prefill” stage, we receive and process the prompt; this is effectively a forward pass of the model. Since the prompt can be processed in parallel across the sequence, most model operations are compute bound during this stage. We therefore expect the relative speeds of Transformers and recurrent models during prefill to be similar to their relative speeds during training, which we discussed in Section 4.
Prefill is followed by a “decode” stage, in which we sample tokens auto-regressively from the model. As we show below, recurrent models have lower latency and higher throughput during the decode stage, especially at longer sequence lengths, where the key-value (KV) cache used in attention can get large.
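To see why the KV cache dominates decode-time memory at long sequence lengths, a rough back-of-the-envelope sketch can help. The Python snippet below is only illustrative: the model dimensions (batch size, layer count, head counts, state size) are hypothetical placeholders, not the configurations used in the paper, and the formulas are the standard per-token accounting for an attention KV cache versus a fixed-size recurrent state.

```python
# Rough sketch: decode-time cache memory for attention vs. a recurrent block.
# All dimensions below are illustrative placeholders, not the paper's models.

def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    # A standard attention KV cache stores one key and one value vector
    # per token, per layer, per KV head; it grows linearly with seq_len.
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_value

def recurrent_state_bytes(batch, n_layers, state_dim, bytes_per_value=2):
    # A recurrent block keeps a fixed-size state per sequence,
    # independent of how many tokens have already been processed.
    return batch * n_layers * state_dim * bytes_per_value

for seq_len in (2_048, 8_192, 32_768):
    kv = kv_cache_bytes(batch=16, seq_len=seq_len, n_layers=26, n_kv_heads=8, head_dim=128)
    rnn = recurrent_state_bytes(batch=16, n_layers=26, state_dim=2_560)
    print(f"seq_len={seq_len:6d}  KV cache ~ {kv / 2**30:5.2f} GiB   recurrent state ~ {rnn / 2**20:5.2f} MiB")
```

Under these made-up dimensions the KV cache grows from a few GiB to tens of GiB as the sequence length increases, while the recurrent state stays constant, which is the source of the latency and throughput advantage discussed below.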
There are two main metrics to consider when evaluating inference speed. The first is latency, which measures the time taken to generate a specified number of tokens at a certain batch size. The second is throughput, which measures the largest number of tokens per second that can be generated on a single device when sampling a specified number of tokens. Since throughput is given by tokens sampled times batch size divided by latency, one can improve it either by reducing the latency or by reducing memory usage to enable larger batch sizes on device. Latency is useful to consider for real-time applications that require a quick response. Throughput is also useful, as it tells us the maximum number of tokens we could sample from a particular model in a given time. This property matters for applications such as Reinforcement Learning from Human Feedback (RLHF) or scoring language model outputs, as done in AlphaCode (Li et al., 2022), where being able to output a large number of tokens in a given time is an appealing feature.
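The relationship between the two metrics can be made concrete with a minimal sketch, assuming made-up timings rather than any measured results from the paper. It simply applies the definition above: throughput is tokens sampled times batch size divided by latency, so a larger batch can raise throughput even if latency per batch gets somewhat worse.

```python
# Minimal illustration of the throughput/latency relationship described above.
# The timings and batch sizes are hypothetical placeholders, not measured results.

def throughput_tokens_per_sec(tokens_sampled, batch_size, latency_sec):
    # Throughput = (tokens sampled per sequence * batch size) / latency.
    return tokens_sampled * batch_size / latency_sec

# Two ways to improve throughput: reduce latency at a fixed batch size,
# or fit a larger batch on device (e.g. by using less cache memory).
print(throughput_tokens_per_sec(tokens_sampled=512, batch_size=16, latency_sec=8.0))   # 1024.0 tokens/s
print(throughput_tokens_per_sec(tokens_sampled=512, batch_size=64, latency_sec=10.0))  # 3276.8 tokens/s
```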