Writing in the Margins: Better Inference Pattern for Long Context Retrieval
Abstract
In this paper, we introduce Writing in the Margins (WiM), a new inference pattern for Large Language Models designed to optimize the handling of long input sequences in retrieval-oriented tasks. This approach leverages the chunked prefill of the key-value cache to perform segment-wise inference, which enables efficient processing of extensive contexts along with the generation and classification of intermediate information ("margins") that guide the model towards specific tasks. This method increases computational overhead marginally while significantly enhancing the performance of off-the-shelf models without the need for fine-tuning. Specifically, we observe that WiM provides an average enhancement of 7.5% in accuracy for reasoning skills (HotpotQA, MultiHop-RAG) and more than a 30.0% increase in the F1-score for aggregation tasks (CWE). Additionally, we show how the proposed pattern fits into an interactive retrieval design that provides end-users with ongoing updates about the progress of context processing, and pinpoints the integration of relevant information into the final response. We release our implementation of WiM using Hugging Face Transformers library at https://github.com/writer/writing-in-the-margins.
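As a rough illustration of the segment-wise pattern described in the abstract, here is a schematic Python sketch. It is not the authors' implementation (their repository linked above has that). `generate`, `classify`, and `answer` are hypothetical stand-ins for LLM calls, the toy versions below just match keywords, and the KV-cache prefill is only indicated in comments. The loop splits the context into segments, writes a margin for each one, keeps only the margins classified as relevant, and builds the final answer from the kept margins.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-ins for LLM calls (not the WiM repository's API):
GenerateFn = Callable[[List[str], str], str]   # (segment, query) -> margin
ClassifyFn = Callable[[str, str], bool]        # (margin, query) -> relevant?
AnswerFn = Callable[[List[str], str], str]     # (kept margins, query) -> answer


def split_into_segments(sentences: List[str], segment_size: int) -> List[List[str]]:
    """Group the context into segments of `segment_size` sentences
    (a simplification; a real implementation would chunk by tokens)."""
    return [sentences[i:i + segment_size] for i in range(0, len(sentences), segment_size)]


def writing_in_the_margins(
    sentences: List[str],
    query: str,
    segment_size: int,
    generate: GenerateFn,
    classify: ClassifyFn,
    answer: AnswerFn,
) -> Tuple[str, List[str]]:
    kept_margins: List[str] = []
    for segment in split_into_segments(sentences, segment_size):
        # A real implementation would prefill this segment into the KV cache
        # and condition the margin on all segments seen so far; this stub
        # only passes the current segment to `generate`.
        margin = generate(segment, query)
        # Keep only margins classified as relevant to the query.
        if classify(margin, query):
            kept_margins.append(margin)
    return answer(kept_margins, query), kept_margins


# Toy keyword-matching stand-ins so the sketch runs without a model.
def toy_generate(segment: List[str], query: str) -> str:
    return " ".join(s for s in segment if query in s)


def toy_classify(margin: str, query: str) -> bool:
    return margin != ""


def toy_answer(margins: List[str], query: str) -> str:
    return " ".join(margins)


context = ["Paris is in France.", "Berlin is in Germany.",
           "Rome is in Italy.", "The Louvre is in Paris."]
result, margins = writing_in_the_margins(context, "Paris", 2, toy_generate, toy_classify, toy_answer)
print(margins)  # ['Paris is in France.', 'The Louvre is in Paris.']
```

Returning the kept margins alongside the answer loosely mirrors the interactive design mentioned in the abstract, where users can see which intermediate information fed into the final response.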
Community
Link to the code: https://github.com/writer/writing-in-the-margins
Congrats on the paper 🔥 Amazing work!
Amazing work, congratulations!
This is an automated message from the Librarian Bot. I found the following papers similar to this paper.
The following papers were recommended by the Semantic Scholar API:
- Finch: Prompt-guided Key-Value Cache Compression (2024)
- Characterizing Prompt Compression Methods for Long Context Inference (2024)
- Retrieval Augmented Generation or Long-Context LLMs? A Comprehensive Study and Hybrid Approach (2024)
- CompAct: Compressing Retrieved Documents Actively for Question Answering (2024)
- ChatQA 2: Bridging the Gap to Proprietary LLMs in Long Context and RAG Capabilities (2024)
Nicely done! Great paper!
This is really cool!!
However, I have gotten myself very confused. Can someone explain why the following is true? I think I am missing something simple.
"By splitting a prompt of length L into N chunks, each of size K, where N = L/K, the overall memory complexity of prefilling is reduced from O(L^2) to O(LK)."
In the chunked case, am I right to assume that the first chunk has a memory cost of K * K, since each of the K tokens in the chunk attends to the others in that chunk?
For the second chunk, the cost is 2K * K, since its K tokens now also attend to the K tokens of the first chunk. I think this pattern continues for every subsequent chunk, so chunk i costs iK * K = iK^2, giving:
Total Cost = Cost Chunk 1 + ... + Cost Chunk N
= K^2 + 2K^2 + ... + NK^2
= K^2(1+...+N)
= K^2 (N)(N+1)/2
= K^2 (L/K)(L/K + 1)/2   (substituting N = L/K)
= K^2 (L^2/K^2 + L/K)/2
= (L^2 + LK)/2
So I am confused about why the paper says O(LK), since my total grows as O(L^2). I think there must be some fundamental flaw in my understanding.
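The algebra in the question can be checked numerically. A minimal Python sketch, assuming (as the question does) that a chunk's cost is counted as query tokens × key tokens attended to, and that L is a multiple of K; the function names are illustrative only:

```python
def accumulated_chunk_cost(L: int, K: int) -> int:
    """Sum the per-chunk costs from the derivation: chunk i (1-indexed) has
    K query tokens attending to i*K key tokens, i.e. a cost of i*K*K."""
    if L % K != 0:
        raise ValueError("this check assumes L is a multiple of K")
    N = L // K
    return sum((i * K) * K for i in range(1, N + 1))


def closed_form_cost(L: int, K: int) -> int:
    """The derivation's final line, (L^2 + L*K) / 2 (an integer whenever K divides L)."""
    return (L * L + L * K) // 2


for L, K in [(8, 2), (1024, 128), (4096, 512)]:
    print(L, K, accumulated_chunk_cost(L, K), closed_form_cost(L, K))
```

Both functions agree, so the summed total really does grow quadratically in L.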
Since chunked prefill is done "step by step" (each step is a forward pass through the model), the worst-case memory usage occurs during the last step, in which the last chunk is prefilled. In that step, each of the K tokens in the chunk attends to all L tokens in the sequence, so the memory complexity is K × L, i.e. O(LK).
Memory is allocated and then released after each forward pass (except for the KV cache, which grows only linearly with L, and the model's parameters). That's why we don't "accumulate" the cost over all chunks: your sum is the total work across all steps, while O(LK) is the peak memory of any single step.
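To make this concrete, here is a minimal NumPy sketch of single-head causal attention with chunked prefill. It simplifies heavily (one head, one layer, no batching, no model weights) and the function names are hypothetical, not taken from the WiM code or Transformers. It checks that prefilling chunk by chunk against a growing KV cache reproduces full causal attention, and it records the shape of the score matrix materialized at each step: the last step allocates a K × L buffer, whereas full prefill allocates L × L.

```python
import numpy as np


def causal_attention(q, k, v, q_offset):
    """Scaled dot-product attention with a causal mask.

    q: (Tq, d) queries for positions q_offset .. q_offset + Tq - 1
    k, v: (Tk, d) keys/values for positions 0 .. Tk - 1
    Returns the outputs and the shape of the score matrix that was built.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)  # (Tq, Tk): the per-step O(Tq * Tk) buffer
    q_pos = np.arange(q_offset, q_offset + q.shape[0])[:, None]
    k_pos = np.arange(k.shape[0])[None, :]
    scores = np.where(k_pos <= q_pos, scores, -np.inf)  # mask future tokens
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, scores.shape


def chunked_prefill(q, k, v, chunk_size):
    """Prefill chunk by chunk, appending each chunk to a growing KV cache."""
    L, d = q.shape
    k_cache, v_cache = np.empty((0, d)), np.empty((0, d))
    outputs, score_shapes = [], []
    for start in range(0, L, chunk_size):
        end = min(start + chunk_size, L)
        # The KV cache persists across steps and grows linearly in L.
        k_cache = np.concatenate([k_cache, k[start:end]])
        v_cache = np.concatenate([v_cache, v[start:end]])
        out, shape = causal_attention(q[start:end], k_cache, v_cache, q_offset=start)
        outputs.append(out)
        score_shapes.append(shape)  # this step's score buffer is then discarded
    return np.concatenate(outputs), score_shapes


rng = np.random.default_rng(0)
L, K, d = 12, 4, 8
q, k, v = (rng.standard_normal((L, d)) for _ in range(3))
full_out, full_shape = causal_attention(q, k, v, q_offset=0)
chunked_out, shapes = chunked_prefill(q, k, v, chunk_size=K)
print(np.allclose(full_out, chunked_out))  # True
print(full_shape, shapes)  # (12, 12) [(4, 4), (4, 8), (4, 12)]
```

The largest per-step buffer is 4 × 12 = K × L, while summing all the shapes gives the quadratic total from the question.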
A video summary is now available here - https://youtu.be/JODc9ku5djA
I love it!
We recently built an independent implementation of this paper in optillm, our open-source optimizing LLM proxy: https://github.com/codelion/optillm/blob/main/optillm/plugins/memory_plugin.py
We used it as the basis for optillm's memory plugin, which gives LLMs short-term memory. It helps improve accuracy on long-context retrieval and can even give LLMs unbounded context if needed.
We were able to match SOTA on Google's recent Frames benchmark (https://huggingface.co/datasets/google/frames-benchmark) using only gpt-4o-mini, compared with Gemini 1.5 Flash, whose context length is 10x larger.