FlexPrefill
This repository provides the code for the paper "FlexPrefill: A Context-Aware Sparse Attention Mechanism for Efficient Long-Sequence Inference".
TL;DR
FlexPrefill is a dynamic, context-aware sparse attention mechanism that improves computational efficiency during long-sequence inference for large language models (LLMs). It adjusts the sparse attention pattern and the computational budget in real time, based on the demands of the input and the requirements of each attention head.
Requirements
To use FlexPrefill, you will need the following packages:
torch==2.4.0
triton==3.0.0
transformers==4.44.0
flash_attn
(optional) vllm==0.5.4
(optional)
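Before running anything, it can help to confirm that the installed versions match the pins above. The following is a minimal sketch using only the standard library; the helper name check_pins is hypothetical (not part of FlexPrefill), and it assumes each package is installed under the distribution name listed above.

```python
from importlib import metadata

# Pins copied from the requirements list above; flash_attn has no pinned version.
REQUIRED = {"torch": "2.4.0", "triton": "3.0.0", "transformers": "4.44.0", "flash_attn": None}
OPTIONAL = {"vllm": "0.5.4"}


def check_pins(pins, get_version=metadata.version):
    """Return {package: problem} for every package that is missing or mismatched.

    Hypothetical helper for illustration; not shipped with FlexPrefill.
    """
    problems = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # Ignore local build tags such as "+cu121" when comparing versions.
        if wanted is not None and found.split("+")[0] != wanted:
            problems[name] = f"found {found}, expected {wanted}"
    return problems


if __name__ == "__main__":
    for label, pins in (("required", REQUIRED), ("optional", OPTIONAL)):
        for name, problem in check_pins(pins).items():
            print(f"[{label}] {name}: {problem}")
```

An empty result means every listed package is present at the pinned version.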
Quick Start
Example Test
To run tests with a specific model, use the test script at tests/test_llm.py:
python tests/test_llm.py --model meta-llama/Llama-3.1-8B-Instruct --pattern flex_prefill --engine hf
FlexPrefill Sparse Attention Function
import torch
from flex_prefill import flex_prefill_attention
B, N, H, D = 1, 64000, 32, 64
gamma = 0.9
tau = 0.1
q = torch.randn(B, N, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, N, H // 4, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, N, H // 4, D, device="cuda", dtype=torch.bfloat16)
flex_prefill_output, computational_ratio = flex_prefill_attention(
q,
k,
v,
gamma,
tau,
min_budget=1024,
max_budget=None,
gqa_interleave=False,
block_size=128,
return_computational_ratio=True,
)
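This README does not define gamma, min_budget, or max_budget. One plausible reading, stated here as an assumption, is that gamma is a coverage threshold: for each query block, key blocks are kept in order of estimated attention score until they account for a fraction gamma of the total, subject to the token budgets. The pure-Python sketch below illustrates that greedy selection on a list of per-block scores. The function name blocks_for_budget is hypothetical, and this is not the Triton kernel used by flex_prefill_attention.

```python
import math


def blocks_for_budget(block_scores, gamma, block_size=128, min_budget=1024, max_budget=None):
    """Greedily keep the highest-scoring key blocks until they cover `gamma` of the total score.

    Illustrative only: the budgets are given in tokens and converted to block counts here,
    which is an assumption about how the real parameters behave.
    """
    total = sum(block_scores)
    # Visit blocks from highest to lowest score.
    order = sorted(range(len(block_scores)), key=lambda i: block_scores[i], reverse=True)
    min_blocks = math.ceil(min_budget / block_size)
    max_blocks = len(block_scores) if max_budget is None else math.ceil(max_budget / block_size)

    chosen, covered = [], 0.0
    for idx in order:
        if len(chosen) >= max_blocks:
            break  # hard cap on the number of kept blocks
        if covered >= gamma * total and len(chosen) >= min_blocks:
            break  # coverage target reached and the minimum budget is satisfied
        chosen.append(idx)
        covered += block_scores[idx]
    return sorted(chosen)


# Toy example with five key blocks and a one-block minimum budget.
kept = blocks_for_budget([50, 30, 10, 5, 5], gamma=0.85, min_budget=128)
print(kept)
```

With the shapes used above (N = 64000, block_size = 128) there are 500 key blocks, and min_budget = 1024 corresponds to 8 blocks under this reading.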
Faster Hugging Face Transformers Model Inference
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from flex_prefill import patch_model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
torch_dtype=torch.bfloat16,
_attn_implementation="flash_attention_2",
).cuda()
flex_prefill_config = {
"block_size": 128,
"flex_prefill_gamma": 0.9,
"flex_prefill_tau": 0.1,
"flex_prefill_min_budget": 1024,
"flex_prefill_max_budget": None,
}
patch_model(model, "flex_prefill", flex_prefill_config)
# `prompt` must be defined beforehand as the input text (a str).
input_ids = tokenizer(prompt, return_tensors="pt", return_attention_mask=False).input_ids.cuda()
output_ids = model.generate(input_ids, max_new_tokens=64)
output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
vLLM Model Inference
from vllm import LLM, SamplingParams
from flex_prefill import patch_model
model = LLM("meta-llama/Llama-3.1-8B-Instruct", enable_chunked_prefill=False, max_num_seqs=1)
sampling_params = SamplingParams(temperature=0, max_tokens=64)
flex_prefill_config = {
"block_size": 128,
"flex_prefill_gamma": 0.9,
"flex_prefill_tau": 0.1,
"flex_prefill_min_budget": 1024,
"flex_prefill_max_budget": None,
}
patch_model(model, "flex_prefill", flex_prefill_config)
# `prompt` must be defined beforehand as the input text (a str).
outputs = model.generate(prompts=[prompt], sampling_params=sampling_params)
output = outputs[0].outputs[0].text
Experiments
Experiment scripts are provided in the experiments folder. First, install the dependencies and download the necessary data and models:
bash install.sh
bash experiments/download_data.sh
bash experiments/download_model.sh
Then, you can run benchmark experiments:
bash experiments/scripts/flex_prefill/ruler.sh
bash experiments/scripts/flex_prefill/infinitebench.sh
The results will be saved in the experiments/result directory.
Supported Models
Currently, flex_prefill.patch_model only supports the following models:
- LLaMA: meta-llama/Meta-Llama-3.1-8B-Instruct
- Qwen2: Qwen/Qwen2-7B-Instruct
- ChatGLM4: THUDM/glm-4-9b-chat-1m
- Yi: 01-ai/Yi-9B-200K
flex_prefill can be used with both Hugging Face Transformers models and vLLM models, but the batch size must be 1.
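Because the batch size must be 1, multiple prompts need to be processed sequentially. A minimal sketch of such a loop follows; generate_one_at_a_time is a hypothetical helper (not part of flex_prefill) that takes any single-prompt generation function.

```python
from typing import Callable, Iterable, List


def generate_one_at_a_time(generate_fn: Callable[[str], str], prompts: Iterable[str]) -> List[str]:
    """Call generate_fn once per prompt so every forward pass sees a batch of size 1."""
    results = []
    for prompt in prompts:
        results.append(generate_fn(prompt))
    return results


# Possible wiring for the Hugging Face example above (model and tokenizer as defined there):
#
# def hf_generate(prompt):
#     input_ids = tokenizer(prompt, return_tensors="pt", return_attention_mask=False).input_ids.cuda()
#     output_ids = model.generate(input_ids, max_new_tokens=64)
#     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
#
# outputs = generate_one_at_a_time(hf_generate, ["first prompt", "second prompt"])
```

Keeping the loop outside the model call avoids padding and batching entirely, at the cost of running the prompts one after another.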
License
This project is licensed under the MIT License - see the LICENSE file for details.
Citation
If you use this code in your research, please cite the following paper:
@article{FlexPrefill2024,
title={FlexPrefill: A Context-Aware Sparse Attention Mechanism for Efficient Long-Sequence Inference},
author={Your Name and Collaborators},
journal={arXiv preprint},
year={2024}
}
Acknowledgments
We acknowledge the support from our collaborators and the community. Thank you for your contributions and feedback.
Contact
For any questions or comments about the paper or the code, please contact laixunhao@pku.edu.cn.
Enjoy using FlexPrefill, and feel free to contribute to the project by opening issues or submitting pull requests!