arxiv:2404.02258

Mixture-of-Depths: Dynamically allocating compute in transformer-based language models

Published on Apr 2 · Featured in Daily Papers on Apr 4

Abstract

Transformer-based language models spread FLOPs uniformly across input sequences. In this work we demonstrate that transformers can instead learn to dynamically allocate FLOPs (or compute) to specific positions in a sequence, optimising the allocation along the sequence for different layers across the model depth. Our method enforces a total compute budget by capping the number of tokens (k) that can participate in the self-attention and MLP computations at a given layer. The tokens to be processed are determined by the network using a top-k routing mechanism. Since k is defined a priori, this simple procedure uses a static computation graph with known tensor sizes, unlike other conditional computation techniques. Nevertheless, since the identities of the k tokens are fluid, this method can expend FLOPs non-uniformly across the time and model depth dimensions. Thus, compute expenditure is entirely predictable in sum total, but dynamic and context-sensitive at the token-level. Not only do models trained in this way learn to dynamically allocate compute, they do so efficiently. These models match baseline performance for equivalent FLOPs and wall-clock times to train, but require a fraction of the FLOPs per forward pass, and can be upwards of 50% faster to step during post-training sampling.
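
As a rough illustration of the mechanism described above, here is a minimal PyTorch sketch of what capped top-k routing around a single transformer block could look like. This is not the authors' implementation: the names `MoDBlock` and `capacity`, the sigmoid gate, and the assumption that the wrapped block returns a residual update are all illustrative choices.

```python
import torch
import torch.nn as nn

class MoDBlock(nn.Module):
    """Sketch of a mixture-of-depths layer: only the k highest-scoring
    tokens go through the expensive attention + MLP block; every other
    token rides the residual stream through unchanged."""

    def __init__(self, block: nn.Module, d_model: int, capacity: int):
        super().__init__()
        self.block = block                    # assumed to return the *update* to add to the residual stream
        self.router = nn.Linear(d_model, 1)   # one scalar score per token
        self.capacity = capacity              # k is fixed a priori, so all tensor shapes are static

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        k = min(self.capacity, T)

        scores = self.router(x).squeeze(-1)                  # (B, T)
        _, topk_idx = torch.topk(scores, k, dim=-1)          # which tokens get compute at this layer
        idx, _ = torch.sort(topk_idx, dim=-1)                # keep original order for causal attention
        gate = torch.sigmoid(torch.gather(scores, 1, idx))   # gate so the router receives gradients

        # Run the heavy computation on the selected tokens only.
        gather_idx = idx.unsqueeze(-1).expand(B, k, D)
        selected = torch.gather(x, 1, gather_idx)            # (B, k, D)
        update = self.block(selected) * gate.unsqueeze(-1)

        # Add the gated update back at the selected positions; skipped tokens pass through as-is.
        return x.scatter_add(1, gather_idx, update)
```

Because k is fixed ahead of time, every tensor above has a static shape, which is the property the abstract contrasts with other conditional-computation techniques; the dynamism is only in which positions end up in `idx`.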

Community

MoD TL;DR:

  • Caps the number of tokens (k) that can participate in the self-attention and MLP computations at a given layer.
  • Adds a top-k router that learns which tokens should be processed at each layer.

Results:

  • MoD matches baseline performance with 66% faster training at scales up to 1B parameters
  • Can be combined with MoEs into MoDE (see the sketch after this list)
  • Inference improvements might be limited to batch_size=1
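
On the "combined with MoEs into MoDE" point, the snippet below is just one way to picture the combination (a simplified sketch under my own assumptions, not necessarily how the paper implements MoDE): an MoE router that includes a no-op choice, so a token can either be sent to an expert MLP or skip the layer's computation entirely.

```python
import torch
import torch.nn as nn

class IntegratedMoDELayer(nn.Module):
    """Sketch: an MoE layer whose router can also pick a "no-op" choice,
    so a token is either processed by one expert MLP or skips the layer
    entirely (mixture-of-depths behaviour folded into MoE routing)."""

    def __init__(self, d_model: int, d_hidden: int, num_experts: int):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
            for _ in range(num_experts)
        ])
        self.router = nn.Linear(d_model, num_experts + 1)  # +1 logit for the no-op choice

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        probs = self.router(x).softmax(dim=-1)   # (B, T, E + 1)
        choice = probs.argmax(dim=-1)            # top-1 routing per token

        out = x.clone()                          # no-op: token passes through on the residual stream
        for e, expert in enumerate(self.experts):
            mask = choice == e                   # tokens assigned to expert e
            if mask.any():
                # Gate by the router probability so the router still gets gradients.
                out[mask] = x[mask] + probs[..., e][mask].unsqueeze(-1) * expert(x[mask])
        return out
```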

Do you expect MoDs to scale up to >=13B parameters?

Am I correct that all these Mo* approaches are not for the GPU-poor? MoE on a single GPU doesn't give much benefit, and while MoD on a 13B-parameter model would strongly reduce latency, you would still have a hard time running a 13B-parameter model on a single smaller GPU.

My blog post featuring this paper: https://rb.gy/zqqrm0

Awesome work! Here is a blog that was also nice to skim/read: https://huggingface.co/blog/joey00072/mixture-of-depth-is-vibe

Here are some interesting insights from the blog:

  • With MoD, you can't easily do batched inference, since each token in a different batch element can get routed around the block. You can if you use a mask, but at that point it's essentially the same as running inference on a normal model, plus the overhead of a router (see the toy snippet after this list).
  • Processing the whole sequence at once doesn't give much of a speedup either: as above, some tokens go through a block while others skip it, and at inference time we don't want fixed-capacity routing.
  • Existing speedup techniques such as speculative decoding won't work, or won't be as useful as they are in normal models.
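
To make the first two points concrete, here is a small hypothetical PyTorch snippet (toy shapes and random scores of my own, not code from the blog or paper) showing why per-sequence top-k routing clashes with batching:

```python
import torch

# Toy illustration of why per-sequence routing clashes with batched inference.
torch.manual_seed(0)
B, T, D, k = 2, 8, 16, 4
x = torch.randn(B, T, D)
scores = torch.randn(B, T)                    # stand-in for the router's scores
idx = torch.topk(scores, k, dim=-1).indices   # (B, k): a *different* set of positions per sequence
print(idx)

# Option A: gather each sequence's top-k tokens. Shapes stay static at (B, k, D),
# but the selected positions differ across the batch, so KV caches and attention
# layouts can no longer be shared uniformly across batch elements.

# Option B: keep the full (B, T) layout and mask out the skipped tokens.
keep = torch.zeros(B, T, dtype=torch.bool)
keep[torch.arange(B).unsqueeze(1), idx] = True
block_update = torch.randn(B, T, D)           # pretend we ran the block on ALL T tokens anyway
masked_update = block_update * keep.unsqueeze(-1)
# The outputs are "correct", but the block still ran on every token,
# so the FLOPs savings disappear -- which is the commenter's point.
```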
