Token Merging for fast LLM inference: Background and first trials with Mistral
Background
Decoder-only models have revolutionized natural language processing by exhibiting remarkable generative abilities. Yet the process used to generate text is brute force: the model predicts a token, appends it to the sequence, and repeats until a maximum length is reached or an end-of-sequence token is generated. At the core of this process lies the transformer architecture, whose self-attention mechanism requires more and more computation as the sequence grows (its cost is quadratic in the sequence length).
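The generation loop described above can be sketched as follows. This is a minimal, model-agnostic illustration of greedy decoding: `greedy_generate` and `toy_model` are hypothetical names, and the toy model simply predicts the next integer so the loop can run without any real LLM.

```python
from typing import Callable, List

def greedy_generate(
    next_token_logits: Callable[[List[int]], List[float]],
    prompt: List[int],
    eos_id: int,
    max_len: int,
) -> List[int]:
    """Greedy autoregressive decoding: predict a token, append it, repeat."""
    seq = list(prompt)
    while len(seq) < max_len:
        # One full forward pass over the whole current sequence per new token
        logits = next_token_logits(seq)
        token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        seq.append(token)
        if token == eos_id:  # stop as soon as end-of-sequence is produced
            break
    return seq


# Toy "model" (hypothetical): always predicts (last token + 1) mod 5, with 4 as EOS
def toy_model(seq: List[int]) -> List[float]:
    nxt = (seq[-1] + 1) % 5
    return [1.0 if i == nxt else 0.0 for i in range(5)]


print(greedy_generate(toy_model, [0], eos_id=4, max_len=10))  # [0, 1, 2, 3, 4]
```

Without a KV cache, each step re-processes the whole prefix, and every self-attention layer compares all pairs of its tokens, which is where the growing cost comes from.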
How could we make inference faster without losing prediction quality?
We could perform simpler operations in the network: that is quantization (8-bit, 4-bit, ternary...).
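As a reminder of what quantization does, here is a minimal sketch of symmetric ("absmax") 8-bit quantization of a single weight vector. The function names are hypothetical and production libraries use per-channel or per-block scales and other refinements, so treat this as illustrative only.

```python
from typing import List, Tuple

def quantize_absmax_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric (absmax) 8-bit quantization with a single scale per vector."""
    absmax = max(abs(w) for w in weights)
    scale = absmax / 127 if absmax > 0 else 1.0  # map [-absmax, absmax] onto [-127, 127]
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale

def dequantize_int8(q: List[int], scale: float) -> List[float]:
    """Recover approximate float weights from the integer codes and the scale."""
    return [v * scale for v in q]


weights = [0.5, -1.0, 0.25, 0.0]
q, scale = quantize_absmax_int8(weights)
approx = dequantize_int8(q, scale)
# Each value is recovered up to a rounding error of at most scale / 2
print(q, max(abs(a - b) for a, b in zip(weights, approx)))
```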
We could use a smaller model. The paper "The Unreasonable Ineffectiveness of the Deeper Layers" studies the impact of dropping layers from a model. Its results show that dropping some of the deepest layers does not significantly hurt the model's performance while improving latency.
The question that "naturally" emerges is: do we need the full sequence of tokens to predict the next token? Or could we compress (merge) it and feed a reduced sequence instead?
The objective is to modify the forward pass of an LLM so that it merges the sequence, once or several times, inside the call. This should require no re-training or fine-tuning: just a simple block that optimizes the process.
To support this idea, the diffusion world already has a library that averages redundant tokens in the forward pass to reduce latency.
Proposed technique
Hence, I am looking to average several token embeddings together. A simple average is not suitable, as it tends to shrink the magnitude of the averaged vectors. The mergekit library relies on a relatively unknown averaging technique called SLERP (Spherical Linear Interpolation). It interpolates between two vectors along the arc joining them rather than along a straight line, preserving their spherical geometry. Unfortunately, SLERP only works on a pair of vectors, so the sequence must be preprocessed first.
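For reference, SLERP between $v_0$ and $v_1$ with interpolation weight $t$ is

$$\mathrm{slerp}(t; v_0, v_1) = \frac{\sin((1-t)\theta)}{\sin\theta}\, v_0 + \frac{\sin(t\theta)}{\sin\theta}\, v_1, \qquad \cos\theta = \frac{v_0 \cdot v_1}{\lVert v_0\rVert\,\lVert v_1\rVert}.$$

Below is a minimal pure-Python sketch of it. It is loosely modeled on the approach commonly used for model merging (angle measured on normalized copies, linear fallback for degenerate cases) but it is not mergekit's code; `slerp` here is a hypothetical helper. The example shows the magnitude argument: for two orthogonal unit vectors, the plain mean has norm about 0.707, while the SLERP midpoint keeps norm 1.

```python
import math
from typing import List

def slerp(t: float, v0: List[float], v1: List[float], eps: float = 1e-8) -> List[float]:
    """Spherical linear interpolation between two vectors (t=0 -> v0, t=1 -> v1).

    The angle is measured between normalized copies, and the weights are applied
    to the original vectors. Falls back to linear interpolation when a vector is
    ~zero or the two are (anti)parallel, where sin(theta) would be ~0.
    """
    if t == 0.0:
        return list(v0)
    if t == 1.0:
        return list(v1)
    n0 = math.sqrt(sum(a * a for a in v0))
    n1 = math.sqrt(sum(b * b for b in v1))
    if n0 < eps or n1 < eps:
        return [(1 - t) * a + t * b for a, b in zip(v0, v1)]
    cos_theta = sum(a * b for a, b in zip(v0, v1)) / (n0 * n1)
    cos_theta = max(-1.0, min(1.0, cos_theta))  # guard against rounding outside [-1, 1]
    theta = math.acos(cos_theta)
    sin_theta = math.sin(theta)
    if sin_theta < eps:
        return [(1 - t) * a + t * b for a, b in zip(v0, v1)]
    w0 = math.sin((1 - t) * theta) / sin_theta
    w1 = math.sin(t * theta) / sin_theta
    return [w0 * a + w1 * b for a, b in zip(v0, v1)]


mid = slerp(0.5, [1.0, 0.0], [0.0, 1.0])
mean = [0.5, 0.5]
print(math.hypot(*mid), math.hypot(*mean))  # ~1.0 vs ~0.707
```

For two vectors of equal norm, the SLERP result keeps that norm for any t, which is exactly the property the plain average lacks.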
The process of merging is as follows:
- Check the sequence length
- If even, prepend and append a "NULL" token (all values set to 0)
- If odd, prepend and append a NULL token, and also insert one at the penultimate position (just before the last token)
- The sequence length is now always even
- Reshape the sequence from (batch size, sequence length, dimension) to (batch size, sequence length/2, 2, dimension)
- Generate an array of interpolation weights ("temperatures") of shape (batch size, sequence length/2), with all values set to 0.5
- For pairs containing a NULL token, set the temperature to 0 or 1 so that the non-null token is fully preserved
- Apply a pairwise SLERP
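The padding and pairing steps above can be sketched as a small planning function. `merge_plan` is a hypothetical helper, not the repo's `merge_tokens`; it works on token indices only so the layout is easy to inspect. Note that for the pairs to line up, an even-length input only needs the two outer NULL tokens, while an odd-length input also needs a NULL just before its last token; either way the padded length is even.

```python
from typing import List, Optional, Tuple

# (left index, right index, SLERP weight); None marks an inserted NULL token
Pair = Tuple[Optional[int], Optional[int], float]

def merge_plan(n: int) -> List[Pair]:
    """Hypothetical helper: how a sequence of n tokens is padded and paired."""
    if n < 2:
        raise ValueError("need at least two tokens")
    idx: List[Optional[int]] = list(range(n))
    if n % 2 == 1:
        idx.insert(n - 1, None)      # odd length: extra NULL just before the last token
    padded = [None] + idx + [None]   # NULL at both ends; length is now even
    plan: List[Pair] = []
    for left, right in zip(padded[0::2], padded[1::2]):
        if left is None:
            t = 1.0   # keep the real right-hand token untouched (the first token)
        elif right is None:
            t = 0.0   # keep the real left-hand token untouched (e.g. the last token)
        else:
            t = 0.5   # genuine SLERP midpoint of two real tokens
        plan.append((left, right, t))
    return plan


print(merge_plan(4))  # [(None, 0, 1.0), (1, 2, 0.5), (3, None, 0.0)]
print(merge_plan(5))  # [(None, 0, 1.0), (1, 2, 0.5), (3, None, 0.0), (4, None, 0.0)]
```

Each triple would then be passed to a pairwise SLERP with weight t. The merged length is n/2 + 1 for even n and (n + 3)/2 for odd n, i.e. almost half the original.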
The resulting sequence is almost half as long and is fed to the remaining layers after the merge. The intuition behind the NULL tokens is to keep two tokens intact: the first token, which plays a central role in sharing information (an attention sink), and the last token, which the model needs to attend to fully for grammatical purposes.
Here is a diagram representing the process:
A very simple implementation where you merge the tokens before the language modelling head would be:
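import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from forward_slerp import merge_tokens  # helper from the llm_slerp_generation repo
mistral = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
sentence = "[INST] What is the biggest challenge in life ? [/INST]"
tokens = tokenizer(sentence, return_tensors="pt")
with torch.no_grad():
    # Run the decoder stack only (no LM head) and keep the final hidden states
    hidden_state = mistral.model(**tokens).last_hidden_state  # (batch, seq_len, dim)
    # Merge pairs of token embeddings with SLERP before the LM head
    merged = merge_tokens(hidden_state)
    logits = mistral.lm_head(merged)  # (batch, merged_len, vocab_size)
# Greedy next-token prediction from the last merged position
next_token_id = logits[:, -1, :].argmax(dim=-1)
print(tokenizer.decode(next_token_id))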
Benchmarking the merging
I created a merged-inference version of the Mistral Instruct 7B v0.2 model. The tests were done on a single H100 GPU.
The evaluation involves three parts: agreement between the next token predicted by the merged forward pass and by the unmerged one, the estimated latency speed-up (which could be improved), and the AlpacaEval benchmark.
For the first two parts, I applied the merge at different depths of the model, referred to as the "layer cutoff", with input sequences of varying length. The sequences are 5k texts randomly picked from the CNN dataset.
For the AlpacaEval benchmark, I applied the merge at the 20th layer and decoded up to 4096 tokens.
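The notion of "layer cutoff" can be pictured with a toy forward pass. `forward_with_cutoff`, the identity layers and the every-other-token merge are hypothetical stand-ins (not the repo's code), chosen only to show which layers see which sequence length. With Mistral 7B's 32 decoder layers and a cutoff of 20, roughly the first 20 layers would process the full sequence and the last 12 the merged one.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Block = Callable[[List[Vector]], List[Vector]]

def forward_with_cutoff(
    layers: Sequence[Block],
    hidden: List[Vector],
    cutoff: int,
    merge: Block,
) -> List[Vector]:
    """Run layers[:cutoff] on the full sequence, merge once, then run the rest."""
    for layer in layers[:cutoff]:
        hidden = layer(hidden)   # these layers see the full-length sequence
    hidden = merge(hidden)       # e.g. the pairwise SLERP merge; roughly halves the length
    for layer in layers[cutoff:]:
        hidden = layer(hidden)   # remaining layers only see the shorter sequence
    return hidden


# Toy usage: identity "layers" that record the sequence length they receive,
# and a stand-in merge that keeps every other token (NOT the SLERP merge).
seen_lengths: List[int] = []

def recording_layer(h: List[Vector]) -> List[Vector]:
    seen_lengths.append(len(h))
    return h

out = forward_with_cutoff([recording_layer] * 4, [[0.0]] * 8, cutoff=1, merge=lambda h: h[::2])
print(seen_lengths)  # [8, 4, 4, 4]
```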
Next Token Prediction
Unsurprisingly (or maybe not, I don't know what you expected), the agreement between the merged and the base predictions grows linearly with the depth at which the merge is applied. However, even at lower layers, the merged model achieves good results, quickly reaching 80% agreement with the base model. No clear pattern emerges with respect to sequence length.
Accuracy between the merged inference and unmerged inference
Looking at top-3 and top-5 accuracy, the results are even better. Very short sequences tend to have worse results, but this is uncertain, as the texts begin with metadata (e.g. "LONDON, England (Reuters) --").
Top 3 accuracy between the merged inference and unmerged inference
Top 5 accuracy between the merged inference and unmerged inference
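One plausible way to compute these agreement scores is sketched below. The definition is an assumption, since the post does not spell it out: "top-k accuracy" is taken to mean that the unmerged model's greedy next token appears among the merged model's k most likely tokens, so k = 1 is plain next-token agreement. The function names and the toy logits are hypothetical.

```python
from typing import List, Sequence

def top_k(logits: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest logits, highest first."""
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]

def top_k_agreement(
    base_logits: Sequence[Sequence[float]],
    merged_logits: Sequence[Sequence[float]],
    k: int,
) -> float:
    """Share of examples whose base (unmerged) greedy token is in the merged top-k."""
    hits = sum(
        top_k(base, 1)[0] in top_k(merged, k)
        for base, merged in zip(base_logits, merged_logits)
    )
    return hits / len(base_logits)


base = [[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]]
merged = [[0.2, 0.7, 0.1], [0.1, 0.3, 0.6]]
print([top_k_agreement(base, merged, k) for k in (1, 2, 3)])  # [0.5, 0.5, 1.0]
```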
Estimated latency speed up
The measured speed-up may understate what is achievable, as my implementation does not manage the KV cache.
Latency speed-up (y axis: how many times faster the merged inference is)
The results seem fragile at this stage, as I would expect a significantly higher speed-up when merging at the first layers. However, long sequences are the ones benefiting the most from this approach.
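A back-of-the-envelope cost model helps frame what speed-up to expect for a given cutoff. The sketch below is an idealized model (an assumption, not derived from the benchmark): every layer costs seq_len ** exponent, with exponent 1 approximating the linear projections and MLP and exponent 2 the attention scores, and the merge itself, the LM head and fixed per-layer overheads are treated as free. `estimated_speedup` is a hypothetical helper.

```python
def estimated_speedup(n_layers: int, cutoff: int, exponent: int, ratio: float = 0.5) -> float:
    """Idealized speed-up of merging once after `cutoff` layers.

    Assumes each layer costs seq_len ** exponent and that the merge shortens
    the sequence by `ratio` (about 0.5 for pairwise merging); ignores the cost
    of the merge, the LM head and any fixed per-layer overhead.
    """
    full_cost = float(n_layers)
    merged_cost = cutoff + (n_layers - cutoff) * ratio ** exponent
    return full_cost / merged_cost


# Mistral 7B has 32 decoder layers; compare a cutoff of 0 with the cutoff of 20 used for AlpacaEval
for cutoff in (0, 20):
    print(cutoff, round(estimated_speedup(32, cutoff, 1), 2), round(estimated_speedup(32, cutoff, 2), 2))
```

Under these assumptions, merging at the very first layer bounds the gain between 2× (linear cost) and 4× (quadratic cost), while a cutoff of 20 gives roughly 1.2× to 1.4×. Fixed costs that do not shrink with the sequence weigh relatively more on short inputs, which is one plausible reason why long sequences benefit the most.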
AlpacaEval 2.0 Benchmark
To see how good (or bad) the model is when merged at inference (layer cutoff set at 20), I compared a merged version to a vanilla one on the AlpacaEval benchmark (805 prompts). The leaderboard and annotations are available in the repo, in the "code/alpaca" folder.
Two comments can be made:
- Merging does not excessively degrade output quality: the merged model loses 4% of win rate (not length-controlled) and 7% of win rate (length-controlled). Even though this is not a rigorous evaluation, it shows that the model is still "alive" and not too far from the base one.
- Merging tokens increases the verbosity of the model: the average output length grows by 600 tokens. This comes from feeding a shorter sequence to the later layers, which changes the positional encodings they see and tends to delay the occurrence of the eos_token. I have not yet managed to implement a coherent positional encoding but plan to do so in the future.
While averaging every pair of tokens, the model still outperforms gemma-7b-it, text_davinci_001 and nous-hermes-13b, ranking 88th out of 145 tested models.
Conclusion and future works
As mentioned, this work has been conducted solely on a Mistral 7B model. The implementation of this merging idea could differ depending on the model's underlying architecture. I have also been asked whether this technique could scale to bigger models, which tend to be less prone to the "needle in the haystack" effect. I will repeat this work on different models to see how well (or badly) it goes.
Also, the merging code is not yet optimized for a truly fast forward call, as it does not handle the KV cache. Another major topic is handling the positional encoding to limit the "over-generation" effect.
I intend to build an improved version of this technique, eventually as a wrapper class around any causal large language model on Hugging Face, enabling faster inference (like accelerate).
In the end, this work highlights the need for a dual-architecture world for LLMs: one for training and one for generation.
I would be pleased to hear your feedback and opinions, and to see if we can push this even further.
Links
Repo: https://github.com/samchaineau/llm_slerp_generation
HF account: https://huggingface.co/samchain