arxiv:2308.04623

Accelerating LLM Inference with Staged Speculative Decoding

Published on Aug 8, 2023 · Featured in Daily Papers on Aug 10, 2023

Abstract

Recent advances with large language models (LLMs) illustrate their diverse capabilities. We propose a novel algorithm, staged speculative decoding, to accelerate LLM inference in small-batch, on-device scenarios. We address the low arithmetic intensity of small-batch inference by improving upon previous work in speculative decoding. First, we restructure the speculative batch as a tree, which reduces generation costs and increases the expected tokens per batch. Second, we add a second stage of speculative decoding. Taken together, we reduce single-batch decoding latency by 3.16x with a 762M parameter GPT-2-L model while perfectly preserving output quality.
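To make the "second stage" idea concrete, here is a minimal, hedged sketch of nested speculative decoding: the small draft model is itself accelerated by an even smaller model before the large oracle model verifies anything. The three models are toy callables (token-id context -> next-token probabilities), the greedy accept rule and the single inner pass are simplifications for brevity, and none of this is the paper's actual implementation.

```python
import numpy as np

def propose(ctx, model, k):
    """Greedy k-token proposal from `model` (one stage's drafter)."""
    out = list(ctx)
    for _ in range(k):
        out.append(int(np.argmax(model(out))))
    return out[len(ctx):]

def verify(ctx, proposal, model):
    """Greedy verification: keep proposal tokens while `model` agrees,
    then stop after substituting the first token it disagrees on."""
    out = []
    for tok in proposal:
        fix = int(np.argmax(model(list(ctx) + out)))
        out.append(fix)
        if fix != tok:
            break
    return out

def staged_step(ctx, tiny, draft, oracle, k_inner=4, k_outer=16):
    # Second stage: the tiny model drafts for the small draft model, so the
    # draft model's own speculative batch is produced faster than plain
    # autoregression (run once here; a real implementation would loop).
    inner = verify(ctx, propose(ctx, tiny, k_inner), draft)
    # First stage: top up to k_outer draft tokens, then let the oracle verify.
    proposal = inner + propose(list(ctx) + inner, draft, k_outer - len(inner))
    return verify(ctx, proposal, oracle)

# Toy usage: three random stand-in "models" over a 10-token vocabulary.
rng = np.random.default_rng(0)
tables = [rng.random((10, 10)) for _ in range(3)]
models = [lambda ctx, t=t: t[ctx[-1] % 10] for t in tables]
print(staged_step([3], *models))
```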

Community


Interesting work! I have a couple of questions about 3.1 -

  1. What does it mean to move "compute from the end of very long sequences to the beginning"?
  2. In the same paragraph, is this like a beam search over the likely second and third tokens? For example, with a beam of size 3, the tree structure gives 3 possible sequences.
    Thanks

By moving compute from the end of very long sequences to the beginning, I mean that for a fixed batch size, you'd rather spend that compute on more probable completions than on less probable ones. Let's say I have a batch size of 16. Standard speculative decoding would structure the tree as a single path:

#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#
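As a small back-of-the-envelope aside (my own illustration, not a number from the paper), assume a fixed per-token acceptance probability p: the i-th token in the chain only counts if all i-1 earlier guesses were accepted, so the expected yield of the 16 slots is just a partial geometric sum.

```python
p = 0.8                                       # assumed chance the oracle accepts each draft token
expected = sum(p ** i for i in range(1, 17))  # expected tokens accepted from a 16-token chain
print(round(expected, 2))                     # ~3.89, so the deep end of the chain rarely pays off
```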

I prefer to structure the tree with many branches instead, since tokens on the shorter branches have a higher cumulative probability of being accepted than tokens near the end of a single long branch:

#-#-#-#-#-#-#
| +-#-#
+-#-#-#-#
|
+-#-#
|
+-#

In the above diagram, we've taken 4 branches at the first choice, and then 2 branches from the most likely node for the third token. (The first token is always already known, since it is the last output of the previous batch.)
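For anyone wondering how the branch budget gets allocated, below is a minimal sketch of one plausible construction (my own reading, not code from the paper): a fixed node budget is spent greedily on whichever candidate token has the highest cumulative draft probability, which naturally yields wide, shallow trees like the one above. `draft_model` is a hypothetical callable mapping a token-id context to next-token probabilities. The resulting paths are then what the oracle verifies in a single batched pass, typically with an attention mask so each node only attends to its ancestors.

```python
import heapq
import numpy as np

def build_draft_tree(prefix, draft_model, budget=16, top_k=4):
    # The root is the token we already know (the last output of the previous
    # batch), so it uses 1 of the `budget` node slots.
    root_path = (prefix[-1],)
    nodes = [(root_path, 1.0)]       # (path from root, cumulative probability)
    frontier = []                    # max-heap keyed by cumulative probability

    def push_children(path, cum_prob):
        probs = draft_model(list(prefix) + list(path[1:]))
        for tok in np.argsort(probs)[::-1][:top_k]:
            heapq.heappush(frontier, (-cum_prob * float(probs[tok]), path + (int(tok),)))

    push_children(root_path, 1.0)
    while frontier and len(nodes) < budget:
        neg_p, path = heapq.heappop(frontier)   # most probable unexpanded candidate
        nodes.append((path, -neg_p))
        push_children(path, -neg_p)
    return nodes                     # 16 nodes whose paths form the draft tree

# Toy usage: a random stand-in "model" over a 10-token vocabulary.
rng = np.random.default_rng(0)
table = rng.dirichlet(np.ones(10), size=10)   # row i = fake next-token distribution after token i
print(build_draft_tree([3, 1], lambda ctx: table[ctx[-1] % 10]))
```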

It is indeed very much like a beam search -- you actually get a free beam search when you do tree-structured speculative decoding. For the paper I ignore that, since I focus more on the latency and memory-bandwidth improvements and just match the target distribution, but it is also nice that you get a free improvement in quality :)
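A tiny illustration of the "free beam search" point (my framing, not code from the paper): once the oracle has verified the tree, more than one branch can be fully accepted, and instead of only matching the sampling distribution one could return the accepted path with the best cumulative score. The paths and log-probabilities below are made up.

```python
def best_accepted_path(accepted):
    # accepted: list of (token_ids, cumulative_log_prob) for fully verified branches
    return max(accepted, key=lambda item: item[1])[0]

print(best_accepted_path([((7, 2, 9), -1.3), ((7, 4), -0.9), ((7, 2), -0.7)]))  # -> (7, 2)
```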

Hope this helps!

I am also very interested in this work! But I am still a bit confused about the construction of the tree structure. Could you please give a more detailed illustration or a reference? Thanks!
