Authors:
(1) David Raposo, Google DeepMind (equal contribution);
(2) Sam Ritter, Google DeepMind;
(3) Blake Richards, Google DeepMind and McGill University & Mila;
(4) Timothy Lillicrap, Google DeepMind;
(5) Peter Conway Humphreys, Google DeepMind;
(6) Adam Santoro, Google DeepMind (equal contribution).
Editor’s note: this is part 3 of 5 of a study detailing a way to make transformer-based language models more efficient by dynamically allocating computational resources. Read the rest below.
Table of Links
- Introduction
- Background
- Implementing Mixture-of-Depths Transformers
  - 3.1. Defining a compute budget
  - 3.2. Routing around transformer blocks
  - 3.3. Routing schemes
  - 3.4. Routing implementation
  - 3.5. Sampling and 3.6. Training methods
- Results
  - 4.1. Training, isoFLOP comparisons
  - 4.2. Auto-regressive Evaluation and 4.3. Mixture-of-Depths-and-Experts (MoDE)
- Discussion and References
3. Implementing Mixture-of-Depths Transformers
Our high-level strategy is as follows:
• Set a static compute budget that is less than that of an equivalent vanilla transformer by limiting the number of tokens in a sequence that can participate in a block’s computations (i.e., self-attention and the subsequent MLP). For example, while a vanilla transformer permits all tokens in a sequence to participate in self-attention, we might limit that number to 50% of the tokens. See section 3.1.
• Use a per-block router to emit a scalar weight for each token, which expresses the router’s preference for that token to participate in a block’s computations or to route around it. See section 3.2.
• Identify the top-𝑘 scalar weights (per sequence, per block) to select those tokens that will participate in a block’s computations. Since precisely 𝑘 tokens will participate in the block’s computations, the computation graph and tensor sizes remain static throughout training; it is merely the tokens’ participation that is dynamic and context-sensitive, as determined by the router. See section 3.3.
We then discuss some complications when sampling post-training in section 3.5.
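To make the strategy concrete, the following is a minimal sketch of one such block in JAX. The shapes, the linear router, and the helper name `block_fn` (standing in for self-attention plus the MLP) are our own assumptions for illustration, not the paper’s implementation.

```python
import jax
import jax.numpy as jnp

def mod_block(x, router_params, block_fn, capacity):
    """Minimal sketch of one Mixture-of-Depths block.

    x:             [batch, seq_len, d_model] residual-stream activations.
    router_params: [d_model] weights of a linear router emitting one scalar per token.
    block_fn:      stands in for self-attention + MLP; assumed to return the
                   residual update that gets added back onto x.
    capacity:      k, the number of tokens per sequence allowed into the block.
    """
    # A per-block router emits a scalar weight for every token.
    router_logits = jnp.einsum('bsd,d->bs', x, router_params)     # [batch, seq_len]

    # Keep the k highest-weight tokens per sequence; re-sort them by position
    # so that causal self-attention inside block_fn sees tokens in order.
    _, idx = jax.lax.top_k(router_logits, capacity)               # [batch, k]
    idx = jnp.sort(idx, axis=-1)
    weights = jnp.take_along_axis(router_logits, idx, axis=1)     # [batch, k]

    # Process only the selected tokens and scale their update by the router weight.
    selected = jnp.take_along_axis(x, idx[..., None], axis=1)     # [batch, k, d_model]
    update = block_fn(selected) * weights[..., None]

    # Unselected tokens simply ride the residual stream unchanged.
    batch_idx = jnp.arange(x.shape[0])[:, None]
    return x.at[batch_idx, idx].add(update)

# e.g., with a dummy block, 2 sequences of 8 tokens, d_model = 4:
x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 4))
w = jax.random.normal(jax.random.PRNGKey(1), (4,))
y = mod_block(x, w, block_fn=lambda h: h, capacity=4)   # y.shape == (2, 8, 4)
```

Because k is fixed, the shapes of `selected` and `update` never change, so the computation graph stays static even though which tokens are processed varies with context.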
3.1. Defining a compute budget
To enforce a total compute budget per forward pass we leverage the notion of capacity, which defines the total number of tokens that comprise the input to a given computation (e.g., the tokens participating in self-attention, or the tokens sent to a given expert in an MoE transformer). For example, the self-attention and MLP in each vanilla transformer block have a capacity of 𝑇, the total number of tokens across the sequence and batch. MoE transformers, on the other hand, use a capacity less than 𝑇 per expert MLP so as to divide the total compute more evenly across the experts. But, since they use multiple experts per block, their total capacity is approximately equal to that of a vanilla transformer.
Generally, it is the token capacity that determines the total FLOPs for transformers that use conditional computation, rather than the outcomes of any particular routing decisions. This is because static-graph implementations account for the worst case; e.g., a computation’s inputs are padded up to its capacity even if relatively few tokens actually route to it, and tokens are dropped from the computation if its capacity is exceeded.
We can achieve our goal of using a smaller compute budget per forward pass compared to a vanilla transformer by lowering the capacity of the computations. However, using a smaller compute budget haphazardly will result in a performance degradation. We hypothesize that certain tokens might not require as much processing as others, and these tokens can be identified through learning. Therefore, if the network learns to choose the right tokens to fill up its capacities, then it may preserve its performance. In the following we describe routing schemes that can be used for this purpose.
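As a concrete, hypothetical illustration of fixing such a budget ahead of training, using the 50% figure from the strategy above and the batch shape stated later in section 3.6 (the function name and framing are ours):

```python
def block_capacity(batch_size: int, seq_len: int, capacity_fraction: float) -> int:
    """Tokens allowed into one block's self-attention + MLP per forward pass.

    A vanilla transformer block has capacity T = batch_size * seq_len;
    here we fix the capacity to a fraction of T before training begins.
    """
    return int(capacity_fraction * batch_size * seq_len)

# With 128 sequences of 2048 tokens and a 50% budget:
#   vanilla capacity T          = 128 * 2048 = 262,144 tokens per block
#   Mixture-of-Depths capacity  = block_capacity(128, 2048, 0.5) = 131,072 tokens,
#   i.e., 1024 tokens per sequence are allowed to enter each block.
```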
3.2. Routing around transformer blocks
We consider the setting whereby we route tokens to one of two computational paths: (1) self-attention and MLP blocks, and (2) a residual connection. The latter is computationally cheap, and results in a block output that is entirely determined by the value of its input. The former path is computationally expensive.
Intuitively, the total FLOPs per forward pass (and the wall-clock time to complete it) decreases in proportion to how aggressively we shrink the blocks’ capacities. However, downstream performance is also affected by how aggressively we shrink the blocks’ capacities, and by the routing algorithm we implement.
At one extreme, if we leave each block’s capacity at 𝑇 and route every token to (rather than around) each block, then we recover a vanilla transformer. At the other extreme, if we set each block’s capacity to 0 and route all tokens around each block, then we’re left with a very fast model that engages almost none of the transformer’s parameters and undoubtedly has poor downstream performance. We hypothesize that somewhere between these two extremes is an optimal model that performs as well as a vanilla transformer, if not better, while being faster to step.
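To make the first claim concrete, a back-of-the-envelope FLOP count (our own rough approximation, counting only the matrix multiplications with the usual 2 FLOPs per multiply-add) shows why shrinking capacity pays off at least linearly:

```python
def approx_block_flops(capacity: int, d_model: int, ffw_mult: int = 4) -> int:
    """Rough forward-pass FLOPs for one block as a function of token capacity.

    Ignores softmax, layer norms, and the split into attention heads.
    """
    n, d = capacity, d_model
    qkvo_projections = 2 * n * (4 * d * d)      # Q, K, V and output projections
    attention_scores = 2 * (2 * n * n * d)      # QK^T and the weighted sum of values
    mlp = 2 * n * (2 * ffw_mult * d * d)        # the two feed-forward matmuls
    return qkvo_projections + attention_scores + mlp

# Halving the capacity halves every n * d^2 term and quarters the n^2 * d
# attention term, so per-block FLOPs fall at least in proportion to capacity.
```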
3.3. Routing schemes
Naively, one can leverage stochasticity to route tokens, akin to layer or block “dropout”. We present this routing scheme as a control, and will show that it significantly under-performs relative to vanilla transformers.
We hypothesize that learned routing is preferable. Intuitively, the network should be able to learn which tokens require more or less processing than others. If we are correct that Transformers often expend more compute than they need to make their predictions, then it is an empirical question as to how aggressively we can shrink each block’s capacity, and hence, how many tokens we can afford to route around each block.
There are two learned routing schemes we consider (see figure 2): token-choice and expert-choice. In token-choice routing, a router produces per-token probability distributions across computational paths (e.g., across expert identities in MoE transformers). Tokens are then shuttled to the path they prefer, i.e., the one with the highest probability, and auxiliary losses ensure that not all tokens converge to the same path. Token-choice routing can have load-balancing problems, since there is no guarantee that tokens divide themselves appropriately between the possible paths. Expert-choice routing flips this recipe on its head: rather than having tokens choose the path they prefer, each path instead chooses the top-𝑘 tokens based on the tokens’ preferences. This ensures perfect load balance, since 𝑘 tokens are guaranteed to be shuttled to each path. However, it could result in over- or under-processing of some tokens, since a given token may be among the top-𝑘 for multiple paths, or for none of them.
We decided to leverage expert-choice routing for a few reasons. First, it obviates the need for an auxiliary balancing loss. Second, since the top-𝑘 operation depends on the magnitude of the router weights, this routing scheme allows for relative routing weights to help determine which tokens most need the block’s computations; routers can try to ensure that the most critical tokens are among the top-𝑘 by setting their weight appropriately, which is not possible with token-choice routing schemes. For our specific use-case, wherein one computational path is essentially a null operation, it might be critical that important tokens are routed away from the null operation. Third, because we only route through two paths, a single top-𝑘 operation can efficiently split the tokens into two mutually exclusive sets, one for each computational path, preventing the over- or under-processing problem mentioned above.
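For the two-path case, a single top-𝑘 suffices to realize expert-choice routing. The sketch below (our own naming and shapes) produces the routing mask for one block and, in the final comment, contrasts it with how token-choice would behave:

```python
import jax
import jax.numpy as jnp

def expert_choice_mask(router_logits, capacity):
    """Expert-choice routing for Mixture-of-Depths' two paths.

    The block chooses its top-k tokens per sequence; every other token takes
    the residual path. Exactly k tokens enter the block (perfect load balance)
    and the two sets are mutually exclusive by construction.

    router_logits: [batch, seq_len] scalar router outputs for one block.
    Returns a boolean mask: True = process in the block, False = route around it.
    """
    _, idx = jax.lax.top_k(router_logits, capacity)
    batch_idx = jnp.arange(router_logits.shape[0])[:, None]
    return jnp.zeros_like(router_logits, dtype=bool).at[batch_idx, idx].set(True)

# Token-choice routing would instead take an argmax over per-token path
# probabilities, which guarantees nothing about how many tokens pick each path.
```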
3.4. Routing implementation
As a reminder of the high-level intuition, each token is processed by a router to produce a scalar weight, and the top-𝑘 weights are then used to choose the token identities that will route through a transformer’s block, which comprises self-attention and the subsequent MLP.
Notably, we multiply the output of the function 𝑓 by the router weights. This places the router weights on the “gradient path”, subjecting them to the forces of gradient descent through the course of the language modeling task. (We experimented with versions in which the router weights also scale the outputs of tokens that bypass the block’s computations, but it proved sufficient, and simpler to implement, to apply the router weights only to the tokens that do not bypass the block.)
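A minimal single-sequence sketch of this detail (with a stand-in `f`, a linear router, and shapes of our own choosing) shows the consequence: because the router weight scales 𝑓’s output, the ordinary task loss produces gradients for the router parameters, as described above.

```python
import jax
import jax.numpy as jnp

def mod_output(router_w, x, capacity):
    """Selected tokens receive router_weight * f(x) on top of the residual
    stream; bypassed tokens are returned unchanged."""
    f = jax.nn.gelu                                  # stand-in for self-attention + MLP
    logits = x @ router_w                            # [seq_len] scalar router outputs
    weights, idx = jax.lax.top_k(logits, capacity)
    update = f(x[idx]) * weights[:, None]            # router weights scale f's output
    return x.at[idx].add(update)

x = jax.random.normal(jax.random.PRNGKey(0), (8, 16))      # 8 tokens, d_model = 16
router_w = jax.random.normal(jax.random.PRNGKey(1), (16,))

# The router weights sit on the gradient path: an ordinary task loss yields
# nonzero gradients with respect to the router parameters.
grads = jax.grad(lambda w: jnp.sum(mod_output(w, x, capacity=4) ** 2))(router_w)
```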
3.5. Sampling
While expert-choice routing has a number of advantages, it has one distinct problem: the top-𝑘 operation is non-causal. This means that whether a given token’s routing weight is among the top-𝑘 for the sequence depends on the values of the routing weights for tokens that come after it, which we don’t have access to when autoregressively sampling.
We tested two methods to work around this problem. The first introduces a simple auxiliary loss that empirically affects the primary language modeling objective by approximately 0.2–0.3%, but allows us to sample from the model autoregressively. We use a binary cross-entropy loss wherein the router’s outputs provide the logits, and the top-𝑘 selections of these logits provide the targets (i.e., 1 if a token was among the top-𝑘, and 0 if not). Intuitively, this loss centers the sigmoid of the router’s outputs around 0.5: tokens selected among the top-𝑘 are pressured to produce router outputs above 0.5, and those not among the top-𝑘 are pressured to produce outputs below 0.5. The second method introduces a small auxiliary MLP predictor (akin to a second router) that receives the same inputs as the router (with a stop gradient), but whose output predicts whether that token will be among the top-𝑘 in the sequence. This method does not affect the language modeling objective, and empirically does not significantly impact the step speed.
Equipped with these new methods, we can sample autoregressively by choosing to route tokens to or around a block based on the router’s output, which does not depend on any information from future tokens. We provide empirical evidence that this is a relatively easy auxiliary task that quickly achieves 99% accuracy.
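A sketch of the first method above, the auxiliary binary cross-entropy loss (the function name and the numerically stable BCE formulation are ours), together with the causal decision it enables at sampling time:

```python
import jax
import jax.numpy as jnp

def router_aux_loss(router_logits, capacity):
    """Auxiliary loss for one block: the (non-causal) top-k selections provide
    binary targets and the router outputs provide the logits, so selected
    tokens are pushed above 0.5 after a sigmoid and the rest below.

    router_logits: [batch, seq_len] per-token router outputs.
    """
    _, idx = jax.lax.top_k(router_logits, capacity)
    batch_idx = jnp.arange(router_logits.shape[0])[:, None]
    targets = jnp.zeros_like(router_logits).at[batch_idx, idx].set(1.0)
    targets = jax.lax.stop_gradient(targets)  # targets act as labels, not learned values

    # Numerically stable sigmoid binary cross-entropy with logits.
    z, y = router_logits, targets
    return jnp.mean(jnp.maximum(z, 0) - z * y + jnp.log1p(jnp.exp(-jnp.abs(z))))

# At sampling time the decision for the newest token is then purely causal:
# route it through the block iff sigmoid(router_logit) > 0.5, which requires
# no information about future tokens.
```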
3.6. Training methods
All models use the same basic hyperparameter configuration (e.g., cosine learning-rate schedules whose length equals 1× the number of training steps, a batch size of 128, and a sequence length of 2048), except for changes to the number of layers, heads, and embedding size used to produce differently sized models for the isoFLOP analyses.
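Read literally, the shared configuration might be captured as follows (variable names are ours; only the values stated above come from the text):

```python
# Shared across all models in the isoFLOP analyses (illustrative only).
SHARED_HYPERPARAMS = dict(
    batch_size=128,
    sequence_length=2048,
    lr_schedule="cosine",   # schedule length equal to 1x the number of training steps
)
# Model scale is then varied independently via the number of layers, the number
# of attention heads, and the embedding size.
```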