Authors:
(1) Soham De, Google DeepMind (equal contribution);
(2) Samuel L. Smith, Google DeepMind (equal contribution);
(3) Anushan Fernando, Google DeepMind (equal contribution);
(4) Aleksandar Botev, Google DeepMind (equal contribution);
(5) George Cristian-Muraru, Google DeepMind (equal contribution);
(6) Albert Gu, Work done while at Google DeepMind;
(7) Ruba Haroun, Google DeepMind;
(8) Leonard Berrada, Google DeepMind;
(9) Yutian Chen, Google DeepMind;
(10) Srivatsan Srinivasan, Google DeepMind;
(11) Guillaume Desjardins, Google DeepMind;
(12) Arnaud Doucet, Google DeepMind;
(13) David Budden, Google DeepMind;
(14) Yee Whye Teh, Google DeepMind;
(15) Razvan Pascanu, Google DeepMind;
(16) Nando De Freitas, Google DeepMind;
(17) Caglar Gulcehre, Google DeepMind.
Table of Links
1. Introduction
2. Model Architecture
3. Recurrent Models Scale as Efficiently as Transformers
3.1. Scaling curves
3.2. Evaluation on downstream tasks
4. Training Recurrent Models Efficiently on Device
4.1. Model parallelism for large scale training
4.2. Efficient linear recurrences on device
4.3. Training speed on longer sequences
5. Inference Speed
5.1. A simple model of the decode step
5.2. Results
6. Long Context Modeling
6.1. Improving next token prediction with longer contexts
6.2. Copy and retrieval capabilities
7. Related Works
8. Conclusion, Acknowledgements, and References
A. RG-LRU Recurrence Gate
B. Complex-Gated Linear Recurrent Unit (CG-LRU)
C. Model Scale Hyper-Parameters
D. Efficient Linear Recurrences on Device
E. The Local Attention Window Size of Griffin
F. Inference Speeds
G. Improving Next Token Prediction with Longer Contexts: Additional Results
H. Additional Details of the Copy and Retrieval Tasks
5.2. Results
Here, we present inference results for models with 1B parameters. As our baseline, we use an MQA Transformer, which is significantly faster during inference than the standard MHA Transformer often used in the literature. We compare three models: i) the MQA Transformer, ii) Hawk, and iii) Griffin, and report both latency and throughput for each.
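One way to see why MQA is the stronger baseline: during decoding, every new token appends a key and a value vector per layer to the KV cache, and each decode step must read that entire cache from memory. The minimal sketch below compares the per-token cache footprint of MHA and MQA. It assumes a hypothetical configuration (24 layers, 16 heads of dimension 128, bfloat16 storage), not the paper's actual hyper-parameters, and the helper name `kv_cache_bytes_per_token` is illustrative only.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """Bytes of key/value cache added per token, per sequence.

    The leading factor of 2 accounts for storing both a key and a value vector.
    """
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value


# Hypothetical 1B-scale configuration (not the paper's exact hyper-parameters).
NUM_LAYERS = 24
NUM_HEADS = 16
HEAD_DIM = 128

# MHA keeps one key/value head per query head; MQA shares a single key/value head.
mha = kv_cache_bytes_per_token(NUM_LAYERS, NUM_HEADS, HEAD_DIM)
mqa = kv_cache_bytes_per_token(NUM_LAYERS, 1, HEAD_DIM)

print(f"MHA: {mha} bytes/token, MQA: {mqa} bytes/token, ratio: {mha // mqa}x")
```

Under these assumptions MQA shrinks the cache by a factor equal to the number of heads, which reduces the memory traffic per decode step accordingly.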
Latency. We compare the latency of the models at a batch size of 16, both with an empty prefill and with a prefill of 4096 tokens, as shown in Figure 4. Hawk and Griffin achieve lower sampling latency than the MQA Transformer on long sequences, and the gap widens as the sequence length and the prefill length (which together determine the size of the KV cache) increase. Griffin achieves latency similar to Hawk's, demonstrating the excellent compatibility of linear recurrences and local attention.
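The dependence of the cache on sequence length can be illustrated with a rough sketch. It assumes a hypothetical configuration: 24 layers for every model, key/value head dimension 128, a recurrence width of 2048, a local attention window of 1024 tokens, bfloat16 storage, and a Griffin that uses local attention in one layer out of three. The recurrent state is simplified to a single vector per layer. None of these numbers are taken from the paper's hyper-parameter tables.

```python
BYTES = 2          # bfloat16
HEAD_DIM = 128     # hypothetical key/value head dimension
STATE_DIM = 2048   # hypothetical recurrence width
WINDOW = 1024      # hypothetical local attention window


def attention_cache(num_layers, tokens_kept):
    # MQA: one shared key head and one shared value head per layer.
    return 2 * num_layers * HEAD_DIM * tokens_kept * BYTES


def recurrent_state(num_layers):
    # Simplified fixed-size hidden state per layer, independent of sequence length.
    return num_layers * STATE_DIM * BYTES


def cache_bytes(model, seq_len):
    """Per-sequence cache size after seq_len tokens (prefill plus sampled)."""
    if model == "mqa":      # 24 global MQA attention layers
        return attention_cache(24, seq_len)
    if model == "hawk":     # 24 recurrent layers
        return recurrent_state(24)
    if model == "griffin":  # 16 recurrent layers + 8 local MQA attention layers
        return recurrent_state(16) + attention_cache(8, min(seq_len, WINDOW))
    raise ValueError(f"unknown model: {model}")


# e.g. a short generation, a 4096-token sequence, and a 4096-token prefill
# followed by 4096 sampled tokens; sizes are reported in KiB.
for seq_len in (512, 4096, 4096 + 4096):
    row = {m: cache_bytes(m, seq_len) // 1024 for m in ("mqa", "hawk", "griffin")}
    print(seq_len, row)
```

The global MQA cache grows linearly with the number of tokens, Hawk's state stays constant, and Griffin's cache stops growing once the sequence exceeds the local window. This mirrors the trend in latency as the prefill and sequence lengths increase.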
Throughput. We compare the maximum throughput (tokens/s) of the same models when sampling 512, 1024, 2048 and 4096 tokens following an empty prompt, as shown in Figure 1(b). Both Griffin and Hawk achieve significantly higher throughput than the MQA Transformer baseline. This is partly because the recurrent models have lower latency, but primarily because their caches are smaller, so Griffin and Hawk can fit larger batch sizes than the MQA Transformer on a single device. Hawk achieves higher throughput than Griffin because, at large batch sizes, the size of Griffin's local attention cache eventually becomes comparable to the size of the parameters.
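The batch-size argument can be made concrete with back-of-the-envelope arithmetic. A decode step loads the parameters once regardless of batch size, so fitting more sequences per device amortizes that cost. The sketch below computes the largest batch whose caches fit in the memory left over after the weights. It assumes a hypothetical 16 GiB device, about 1B parameters in bfloat16, and per-sequence cache sizes after 4096 tokens under a hypothetical configuration (24 layers, head dimension 128, recurrence width 2048, local window 1024, Griffin with 8 local attention layers). It ignores activations and other overheads, and `max_batch_size` is an illustrative helper rather than anything from the paper.

```python
GIB = 2**30
DEVICE_MEMORY = 16 * GIB   # hypothetical accelerator memory
PARAM_BYTES = 2 * 10**9    # ~1B parameters stored in bfloat16

# Hypothetical per-sequence cache sizes (bytes) after 4096 tokens.
CACHE_BYTES_PER_SEQ = {
    "mqa": 2 * 24 * 128 * 4096 * 2,                    # global KV cache, grows with length
    "hawk": 24 * 2048 * 2,                             # fixed-size recurrent state
    "griffin": 16 * 2048 * 2 + 2 * 8 * 128 * 1024 * 2,  # state + local-window KV cache
}


def max_batch_size(memory_bytes, param_bytes, cache_bytes_per_seq):
    """Largest batch whose caches fit in the memory left after the weights."""
    free = memory_bytes - param_bytes
    return max(free, 0) // cache_bytes_per_seq


batches = {
    name: max_batch_size(DEVICE_MEMORY, PARAM_BYTES, size)
    for name, size in CACHE_BYTES_PER_SEQ.items()
}
print(batches)

# Batch size at which Griffin's total cache equals the parameter memory
# (roughly 470 with these hypothetical numbers).
print(PARAM_BYTES / CACHE_BYTES_PER_SEQ["griffin"])
```

With these assumptions, both recurrent models fit far larger batches than the MQA Transformer. Well before Griffin reaches its memory limit, however, its local attention cache grows to the size of the parameters, which is consistent with Hawk's throughput advantage at large batch sizes.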