Llama-3-CBHybridM-8B: Model Information

We are excited to release Cerebras hybrid dense/sparse attention versions of Llama-3.1-8B-Instruct optimized for long-context performance. This series includes two models: Llama3.1-CBHybridL-8B (25 sparse attention layers out of 32) and Llama3.1-CBHybridM-8B (28 sparse attention layers out of 32).

This model – Cerebras Llama3.1-CBHybridM-8B – was built on top of Llama-3.1-8B-Instruct using the sparse attention training features available in Cerebras Model Zoo Release 2.4. We created hybrid versions of Llama-3.1-8B-Instruct in which most of the self-attention layers are fine-tuned to perform sparse lambda-mask attention, which reduces KV cache memory usage by 1.6-1.7x while largely maintaining long-context performance.
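The lambda ("Λ"-shaped) mask, following LM-Infinite, lets each query attend to a handful of initial tokens plus a local causal window, so a sparse layer's KV cache only needs to keep those positions. Below is a minimal PyTorch sketch of such a mask; the prefix size and window size are illustrative placeholders, not the values used to train these models:

import torch

def lambda_attention_mask(seq_len: int, num_global: int = 16, window: int = 1024) -> torch.Tensor:
    # Boolean Λ-shaped mask: each query attends to the first `num_global`
    # tokens (global prefix) plus a causal local window of `window` tokens.
    # The default sizes here are illustrative, not this model's training settings.
    q = torch.arange(seq_len).unsqueeze(1)   # query positions (rows)
    k = torch.arange(seq_len).unsqueeze(0)   # key positions (columns)
    causal = k <= q
    local = (q - k) < window
    global_prefix = k < num_global
    return causal & (local | global_prefix)  # True = may attend

# Example: mask for a 4K-token sequence
mask = lambda_attention_mask(4096)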

You can find more information about Cerebras hybrid Llama models in our blog post: https://www.cerebras.ai/blog/compressing-kv-cache-memory-by-half-with-sparse-attention

Results

Our hybrid models retain most of their long-context performance while requiring much less KV cache memory:


| LongBench suite | Llama-3.1-8B-Instruct | Llama-3-CBHybridM-8B | Llama-3-CBHybridL-8B |
|---|---|---|---|
| KV cache memory*, GB | 2.147 | 1.275 | 1.376 |
| Single-doc QA | 54.197 | 54.507 | 56.187 |
| Multi-doc QA | 41.455 | 41.022 | 43.082 |
| Summarization | 26.1275 | 25.607 | 25.357 |
| Few-shot learning | 63.4075 | 64.42 | 65.183 |
| Synthetic | 97.29 | 96.75 | 98.0 |
| Code completion | 59.745 | 66.865 | 66.49 |
| Macro-mean (EN & ZH) | 57.037 | 58.195 | 59.05 |
| Macro-mean (EN) | 58.606 | 60.485 | 60.937 |
| HELMET suite (seq. len. 16K) | Llama-3.1-8B-Instruct | Llama-3-CBHybridM-8B | Llama-3-CBHybridL-8B |
|---|---|---|---|
| KV cache memory, GB | 2.147 | 1.275 | 1.376 |
| Recall | 99.6875 | 87.5625 | 95.1875 |
| Rerank | 52.6671 | 42.7879 | 45.5175 |
| RAG | 69.0417 | 68.625 | 69.4583 |
| LongdocQA | 32.061 | 34.419 | 35.2879 |
| ICL | 76 | 81.6 | 82.2 |
| Summarization | 26.278 | 22.4353 | 23.7324 |
| Macro-mean | 59.2892 | 56.2382 | 58.564 |

* KV cache memory usage is reported at a representative sequence length of 16K. Note that samples across LongBench tasks have variable length, with ~14.5K tokens at the 75th percentile of the sample-length distribution.
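For reference, the dense baseline's 2.147 GB figure follows directly from the Llama-3.1-8B architecture (32 decoder layers, 8 KV heads under grouped-query attention, head dimension 128) with a bf16 KV cache; a quick back-of-the-envelope check:

# KV cache size of the dense Llama-3.1-8B baseline at a 16K sequence length
num_layers, num_kv_heads, head_dim = 32, 8, 128
seq_len, bytes_per_value = 16 * 1024, 2          # bf16 = 2 bytes
kv_bytes = 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value  # keys + values
print(f"{kv_bytes / 1e9:.3f} GB")                # ~2.147 GB, matching the table above

The hybrid models' lower figures come from the sparse layers, which only need to cache the positions visible through the lambda mask.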

Example Usage

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "cerebras/Llama-3-CBHybridM-8B"

# Load the tokenizer and model; trust_remote_code=True picks up the custom
# sparse-attention model and tokenizer code shipped with this repository.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

messages = [
    {"role": "system", "content": "You are a wafer-scale chatbot who always responds in wafer speak!"},
    {"role": "user", "content": "Who are you?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
)
# Decode only the newly generated tokens (drop the prompt)
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))

Adding memory tokens for enhanced long-context performance

We found that adding auxiliary memory tokens to input sequences at regular intervals improves long-context performance. These tokens can be inserted into the input sequence using a helper tokenizer.insert_memory_tokens() method as shown below:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "cerebras/Llama-3-CBHybridM-8B"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

messages = [
    {"role": "system", "content": "You are a wafer-scale chatbot who always responds in wafer speak!"},
    {"role": "user", "content": "Who are you?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# Inserting 8 memory tokens per 256 tokens of original input:
input_ids = tokenizer.insert_memory_tokens(
    input_ids,
    episode_length=256,
    num_memory_tokens_per_episode=8
)

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
)
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))

In our ablations, inserting 8 memory tokens after every 256 tokens of original input gave the best accuracy. See our blog post for more details.
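At this setting the prompt grows by 8 tokens per 256-token episode, i.e. roughly a 3% increase in length; a quick estimate (assuming, as in the call above, that the helper adds 8 memory tokens per full 256-token episode):

# Rough prompt-length overhead of memory-token insertion at the setting above
orig_len = 16 * 1024                              # e.g. a 16K-token prompt
episode_length, tokens_per_episode = 256, 8
new_len = orig_len + (orig_len // episode_length) * tokens_per_episode
print(new_len, f"+{100 * (new_len / orig_len - 1):.1f}%")   # 16896, about +3.1%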

License

Built with Llama 3. Llama 3.1 is licensed under the Llama 3.1 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.

Llama3.1 Community License

Acceptable Use Policy

Acknowledgements

Our models are fine-tuned versions of Meta-Llama-3.1-8B-Instruct. The sparse attention mechanism used in the Llama-3-CBHybrid model series is from the LM-Infinite work of Han et al. See our blog post for the full list of references.

Citing this work

@misc{cerebras2025cb-hybrid-llama,
  author       = {Lazarevich, Ivan and Hassanpour, Mohammad and Venkatesh, Ganesh},
  title        = {Compressing KV cache memory by half with sparse attention},
  month        = {March},
  year         = {2025},
  howpublished = {\url{https://www.cerebras.ai/blog/compressing-kv-cache-memory-by-half-with-sparse-attention}}
}