We apply Activation Beacon to Mistral-7B-Instruct-v0.2. The resulting model has the following features:
- Effective: strong performance on long-context tasks.
- Efficient: significantly lower memory usage and inference latency than full-attention models (a 128K context runs on a single A100 GPU).
- Compatible: a plug-in module that adds long-context capabilities to Mistral without modifying any parameters of the original Mistral model.
- Low-Cost Training: trained on 2B tokens, with every training sample shorter than 20K tokens.
Compared with activation-beacon-llama2-7b-chat, there are three major differences:
- Training Data: we use more data for both pretraining (2B tokens from SlimPajama with a sequence length of 16384) and supervised finetuning (open-source long-context data plus thousands of synthetic long-context QA samples generated with GPT-4).
- Sliding Window: the window size is increased to 2048.
- Condensing Ratio: we train with condensing ratios of [2, 4, 8, 16, 32] during pretraining and [2, 4, 8] during finetuning. In both stages, the condensing ratios are mixed with the step-random strategy (see the paper for details).
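To make the window size and condensing ratios concrete, here is a minimal sketch of the bookkeeping. It assumes, as a simplification, that each completed window of 2048 raw tokens is condensed into `2048 / ratio` beacon activations while the current window stays uncondensed, and it reads "step-random" as drawing one ratio uniformly at random per training step. The helper names (`beacons_per_window`, `cached_activations`, `sample_ratio`) are hypothetical and not part of the released code; see the paper for the exact scheme.

```python
import random

WINDOW_SIZE = 2048                    # sliding window size stated above
PRETRAIN_RATIOS = [2, 4, 8, 16, 32]   # condensing ratios used in pretraining
FINETUNE_RATIOS = [2, 4, 8]           # condensing ratios used in finetuning


def beacons_per_window(window_size: int, ratio: int) -> int:
    """Beacon activations one full window condenses into (assumed window_size / ratio)."""
    if window_size % ratio != 0:
        raise ValueError("ratio must divide the window size")
    return window_size // ratio


def cached_activations(context_len: int, window_size: int, ratio: int) -> int:
    """Rough count of activations kept after reading `context_len` tokens.

    Assumption: every completed window is replaced by its beacons, while the
    last (possibly partial) window is still held as raw tokens.
    """
    full_windows, remainder = divmod(context_len, window_size)
    if remainder == 0 and full_windows > 0:
        # The final full window is the active window, so it is not condensed yet.
        full_windows, remainder = full_windows - 1, window_size
    return full_windows * beacons_per_window(window_size, ratio) + remainder


def sample_ratio(ratios: list[int], rng: random.Random) -> int:
    """Hypothetical step-random mixing: draw one condensing ratio per training step."""
    return rng.choice(ratios)


if __name__ == "__main__":
    rng = random.Random(0)
    for step in range(3):
        r = sample_ratio(PRETRAIN_RATIOS, rng)
        print(f"step {step}: ratio={r}, beacons/window={beacons_per_window(WINDOW_SIZE, r)}")
    # A 128K-token context condensed at a fixed ratio of 8.
    print(cached_activations(128 * 1024, WINDOW_SIZE, 8))
```

Under these assumptions, a 128K-token context read at ratio 8 keeps 63 × 256 + 2048 = 18,176 activations, compared with 131,072 for full attention.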
Evaluation
The following results can be reproduced by following the instructions here.
Needle in a Haystack
We evaluate the model on the Needle-In-A-Haystack task using the official setting.
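Needle-In-A-Haystack hides a short "needle" fact at a controlled depth inside long filler text, asks the model to retrieve it, and sweeps over context lengths and depths. The sketch below shows only the insertion step on a list of tokens; the function names, the whitespace tokenization, and the example needle are illustrative assumptions, and the official harness adds details (such as aligning the insertion point to a sentence boundary) that are omitted here.

```python
def insert_needle(haystack: list[str], needle: list[str], depth_percent: float) -> list[str]:
    """Insert `needle` so that it starts `depth_percent`% of the way into `haystack`."""
    if not 0 <= depth_percent <= 100:
        raise ValueError("depth_percent must be in [0, 100]")
    position = int(len(haystack) * depth_percent / 100)
    return haystack[:position] + needle + haystack[position:]


def depth_grid(num_intervals: int) -> list[float]:
    """Evenly spaced insertion depths from 0% to 100% inclusive."""
    return [100 * i / num_intervals for i in range(num_intervals + 1)]


if __name__ == "__main__":
    haystack = ("filler " * 20).split()        # stand-in for long essay text
    needle = "the secret number is 42 .".split()  # hypothetical needle
    for depth in depth_grid(4):
        tokens = insert_needle(haystack, needle, depth)
        # Index where the needle starts inside the combined context.
        print(depth, tokens.index("secret") - 1)
```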
LongBench
We evaluate the model on LongBench with a 32K context length.
Model | Single Doc QA | Multi Doc QA | Summarization |
---|---|---|---|
Mistral-7B-Instruct-v0.2 | 32.70 | 25.87 | 27.42 |
Yarn-Mistral-128K | 33.71 | 36.08 | 23.47 |
Activation-Beacon-Mistral-7B | 39.14 | 43.27 | 29.52 |
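Inputs longer than the context budget have to be shortened. LongBench's reference prediction script truncates from the middle, keeping the first and last halves of the token budget, because instructions and questions usually sit at the start and end of the prompt. A minimal sketch of that rule on token-id lists, assuming the same rule applies to the 32K evaluation above (the function name is hypothetical):

```python
def truncate_middle(token_ids: list[int], max_length: int) -> list[int]:
    """Keep the first and last max_length // 2 tokens when the input is too long."""
    if len(token_ids) <= max_length:
        return token_ids
    half = max_length // 2
    # Slice the tail by absolute index so that half == 0 yields an empty tail.
    return token_ids[:half] + token_ids[len(token_ids) - half:]


if __name__ == "__main__":
    print(truncate_middle(list(range(10)), 4))                 # first two and last two ids
    print(len(truncate_middle(list(range(40_000)), 32_768)))   # e.g. a 32K budget
```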
InfiniteBench
We evaluate the model on InfiniteBench with a 128K context length. The results of Yarn-Mistral-128K are copied from the paper.
Model | LongBookQA Eng | LongBookSum Eng |
---|---|---|
Yarn-Mistral-128K | 9.55 | 9.09 |
Activation-Beacon-Mistral-7B | 26.81 | 12.49 |
Topic Retrieval
We evaluate the model on the Topic Retrieval task with [5, 10, 15, 20, 25, 30, 40, 50, 60, 70] topics.
PG19 Perplexity
We evaluate sliding-window perplexity on the PG19 test set with a window size of 100K and a stride of 32K, and also report latency and GPU memory usage. For full-attention models, we enable flash-attention-2 and tensor parallelism. All evaluations run on an 8xA800 machine.
Model | Perplexity | Latency (s) | Memory (GB) |
---|---|---|---|
Mistral-7B-Instruct-v0.2 | 8.83 | 14.02 | 525.6 (cannot run on a single GPU) |
Yarn-Mistral-128K | 7.66 | 14.56 | 525.6 (cannot run on a single GPU) |
Activation-Beacon-Mistral-7B | 8.16 | 3.06 | 27.4 |
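The PG19 protocol can be read as the strided evaluation described in the Hugging Face Transformers perplexity guide: each window sees up to 100K tokens, consecutive windows start 32K tokens apart, and only tokens not already scored by an earlier window contribute to the loss. A dependency-free sketch of the window schedule and the final perplexity, assuming this scheme (helper names are hypothetical; the model forward pass is left out):

```python
import math


def sliding_windows(seq_len: int, window: int, stride: int) -> list[tuple[int, int, int]]:
    """Return (begin, end, num_scored) for each evaluation window.

    Only the last `num_scored` tokens of each window are scored, so every
    token of the sequence is counted exactly once.
    """
    spans = []
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + window, seq_len)
        spans.append((begin, end, end - prev_end))
        prev_end = end
        if end == seq_len:
            break
    return spans


def perplexity(total_nll: float, num_tokens: int) -> float:
    """Perplexity from the summed negative log-likelihood (natural log)."""
    return math.exp(total_nll / num_tokens)


if __name__ == "__main__":
    for span in sliding_windows(250_000, window=100_000, stride=32_000):
        print(span)
```

Each window reuses up to 68K tokens of already-scored context, which is why a model's latency and memory at 100K tokens dominate this benchmark.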
Passkey Retrieval
We evaluate the model on the Passkey Retrieval task using the official setting.
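Passkey retrieval hides a random number inside repetitive filler text and asks the model to recall it. The sketch below builds such a prompt following the widely used template from the Landmark Attention work; treat the exact wording, the filler sentence, and the helper name `build_passkey_prompt` as assumptions, since the official setting may differ in details.

```python
import random

TASK = ("There is an important info hidden inside a lot of irrelevant text. "
        "Find it and memorize them. I will quiz you about the important information there.")
FILLER = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
QUESTION = "What is the pass key? The pass key is"


def build_passkey_prompt(n_filler: int, depth: float, passkey: int) -> str:
    """Place the passkey sentence after roughly `depth` (0.0 to 1.0) of the filler repetitions."""
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be in [0, 1]")
    n_before = round(n_filler * depth)
    info = f"The pass key is {passkey}. Remember it. {passkey} is the pass key."
    before = " ".join([FILLER] * n_before)
    after = " ".join([FILLER] * (n_filler - n_before))
    return "\n".join([TASK, before, info, after, QUESTION])


if __name__ == "__main__":
    rng = random.Random(0)
    key = rng.randint(10000, 99999)  # random 5-digit passkey
    prompt = build_passkey_prompt(n_filler=50, depth=0.5, passkey=key)
    print(len(prompt.split()), prompt.endswith(QUESTION))
```

The context length is controlled by `n_filler`, so sweeping it together with `depth` gives a length-by-position grid similar to Needle-In-A-Haystack.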
Environment
torch>=2.1.1
transformers==4.39.3
Usage
```python
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/activation-beacon-mistral-7b"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.cuda().eval()

with torch.no_grad():
    # short context
    messages = [{"role": "user", "content": "Tell me about yourself."}]
    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=50)
    print(f"Input Length: {inputs['input_ids'].shape[1]}")
    # decodes the prompt together with the generated reply
    print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

    # reset memory before a new generation task
    model.memory.reset()

    # long context
    with open("data/infbench.json", encoding="utf-8") as f:
        example = json.load(f)
    messages = [{"role": "user", "content": example["context"]}]
    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
    # greedy decoding; slice off the prompt so only new tokens remain
    outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
    print("*" * 20)
    print(f"Input Length: {inputs['input_ids'].shape[1]}")
    print(f"Answers: {example['answer']}")
    print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
```
NOTE: Warnings such as "This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (32768). Depending on the model, you may observe exceptions, performance degradation, or nothing at all." are expected and can be safely ignored.