Llama-3-CBHybridL-8B: Model Information
We are excited to release Cerebras hybrid dense/sparse attention versions of Llama-3.1-8B-Instruct, optimized for long-context performance. The series includes two models: Llama-3-CBHybridL-8B (25 sparse attention layers out of 32) and Llama-3-CBHybridM-8B (28 sparse attention layers out of 32).
This model, Cerebras Llama-3-CBHybridL-8B, was built on top of Llama-3.1-8B-Instruct using the sparse attention training features available in Cerebras Model Zoo Release 2.4. We created hybrid versions of Llama-3.1-8B-Instruct in which most of the self-attention layers are fine-tuned to perform sparse lambda-mask attention, which reduces KV cache memory usage by 1.6-1.7x while largely preserving long-context performance.
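To make the attention pattern concrete, below is a minimal sketch of an LM-Infinite-style lambda-shaped mask: each query attends to a short global prefix plus a local window of recent tokens, so a sparse layer only needs to cache keys and values for those positions. The `prefix_len` and `window_len` values are illustrative placeholders, not the configuration used in our models.

```python
import torch

def lambda_attention_mask(seq_len: int, prefix_len: int, window_len: int) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask: entry (i, j) is True if query i may attend to key j.

    Illustrative only: prefix_len and window_len are placeholder values,
    not the settings used in Llama-3-CBHybridL-8B.
    """
    q = torch.arange(seq_len).unsqueeze(1)   # query positions (rows)
    k = torch.arange(seq_len).unsqueeze(0)   # key positions (columns)
    causal = k <= q                          # standard causal constraint
    global_prefix = k < prefix_len           # always attend to the first few tokens
    local_window = (q - k) < window_len      # plus a sliding window of recent tokens
    return causal & (global_prefix | local_window)

# Small example: 8 tokens, a 2-token global prefix, and a local window of 3
print(lambda_attention_mask(8, prefix_len=2, window_len=3).int())
```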
You can find more information about Cerebras hybrid Llama models in our blog post.
Results
Our hybrid models retain most of their long-context performance despite requiring much less KV cache memory:
| LongBench suite | Llama-3.1-8B-Instruct | Llama-3-CBHybridM-8B | Llama-3-CBHybridL-8B |
|---|---|---|---|
| KV cache memory*, GB | 2.147 | 1.275 | 1.376 |
| Single-doc QA | 54.197 | 54.507 | 56.187 |
| Multi-doc QA | 41.455 | 41.022 | 43.082 |
| Summarization | 26.1275 | 25.607 | 25.357 |
| Few-shot learning | 63.4075 | 64.42 | 65.183 |
| Synthetic | 97.29 | 96.75 | 98.0 |
| Code completion | 59.745 | 66.865 | 66.49 |
| Macro-mean (EN & ZH) | 57.037 | 58.195 | 59.05 |
| Macro-mean (EN) | 58.606 | 60.485 | 60.937 |

| HELMET suite (seq. len. 16K) | Llama-3.1-8B-Instruct | Llama-3-CBHybridM-8B | Llama-3-CBHybridL-8B |
|---|---|---|---|
| KV cache memory, GB | 2.147 | 1.275 | 1.376 |
| Recall | 99.6875 | 87.5625 | 95.1875 |
| Rerank | 52.6671 | 42.7879 | 45.5175 |
| RAG | 69.0417 | 68.625 | 69.4583 |
| LongdocQA | 32.061 | 34.419 | 35.2879 |
| ICL | 76 | 81.6 | 82.2 |
| Summarization | 26.278 | 22.4353 | 23.7324 |
| Macro-mean | 59.2892 | 56.2382 | 58.564 |
\* KV cache memory usage is reported at a representative sequence length of 16K; note, however, that samples across LongBench tasks have variable length, with ~14.5K tokens being the 75th percentile of the sample-length distribution.
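As a back-of-the-envelope check on the dense baseline, the 16K KV cache size follows directly from the Llama-3.1-8B architecture (32 layers, 8 KV heads under grouped-query attention, head dimension 128), assuming keys and values are stored in bf16:

```python
# Dense KV cache for Llama-3.1-8B-Instruct at a 16K context, assuming bf16 storage (2 bytes/value).
layers, kv_heads, head_dim, seq_len, dtype_bytes = 32, 8, 128, 16_384, 2
kv_cache_bytes = 2 * layers * kv_heads * head_dim * seq_len * dtype_bytes  # factor of 2 = keys + values
print(f"{kv_cache_bytes / 1e9:.3f} GB")  # 2.147 GB, matching the first column above
```

The hybrid models' smaller footprints (1.275 GB and 1.376 GB) come from the sparse lambda-mask layers retaining only a subset of KV entries at this sequence length, consistent with the 1.6-1.7x reduction quoted above.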
Example Usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "cerebras/Llama-3-CBHybridL-8B"

# Load the tokenizer and model (trust_remote_code is required for the custom hybrid-attention code).
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

messages = [
    {"role": "system", "content": "You are a wafer-scale chatbot who always responds in wafer speak!"},
    {"role": "user", "content": "Who are you?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
)

# Decode only the newly generated tokens.
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```
Adding memory tokens for enhanced long-context performance
We found that adding auxiliary memory tokens to input sequences at regular intervals improves long-context performance. These tokens can be inserted into the input sequence using the helper `tokenizer.insert_memory_tokens()` method, as shown below:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "cerebras/Llama-3-CBHybridL-8B"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

messages = [
    {"role": "system", "content": "You are a wafer-scale chatbot who always responds in wafer speak!"},
    {"role": "user", "content": "Who are you?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# Insert 8 memory tokens per 256 tokens of original input:
input_ids = tokenizer.insert_memory_tokens(
    input_ids,
    episode_length=256,
    num_memory_tokens_per_episode=8,
)

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
)

response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```
In our ablations, inserting 8 memory tokens after every 256 tokens of original input gave the best accuracy. See our blog post for more details.
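For intuition, here is a hypothetical sketch of what episode-based memory-token insertion could look like. The released tokenizer ships its own `insert_memory_tokens()` implementation with the actual memory-token IDs, so treat the helper below purely as an illustration of the split-and-append pattern.

```python
import torch

def insert_memory_tokens_sketch(input_ids: torch.Tensor,
                                memory_token_id: int,
                                episode_length: int = 256,
                                num_memory_tokens_per_episode: int = 8) -> torch.Tensor:
    """Illustrative re-implementation only: split the sequence into fixed-length
    episodes and append a small block of memory tokens after each one."""
    batch_size, seq_len = input_ids.shape
    memory_block = torch.full(
        (batch_size, num_memory_tokens_per_episode),
        memory_token_id,
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    chunks = []
    for start in range(0, seq_len, episode_length):
        chunks.append(input_ids[:, start:start + episode_length])
        chunks.append(memory_block)
    return torch.cat(chunks, dim=-1)
```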
License
Built with Llama 3. Llama 3.1 is licensed under the Llama 3.1 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.
Acknowledgements
Our models are fine-tuned versions of Meta-Llama-3.1-8B-Instruct. The sparse attention mechanism used in the Llama-3-CBHybrid
model series is from the LM-Infinite work of Han et al. See our blog post for the full list of references.
Citing this work
```bibtex
@misc{cerebras2025cb-hybrid-llama,
  author       = {Lazarevich, Ivan and Hassanpour, Mohammad and Venkatesh, Ganesh},
  title        = {Compressing KV cache memory by half with sparse attention},
  month        = {March},
  year         = {2025},
  howpublished = {\url{https://www.cerebras.ai/blog/compressing-kv-cache-memory-by-half-with-sparse-attention}}
}
```