
Activation Beacon for Mistral

[Paper] [Github]

We apply Activation Beacon to Mistral-7B-Instruct-v0.2. The resulting model has the following features:

  • Effective: strong performance on long-context tasks.
  • Efficient: much lower memory usage and inference latency than full-attention models (a 128K context runs easily on a single A100 GPU).
  • Compatible: a plug-in module that adds long-context capabilities to Mistral without modifying any parameters of the original Mistral model.
  • Low-Cost Training: trained on only 2B tokens, with every training sample shorter than 20K tokens.
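To make the Efficient claim concrete, here is a rough estimate of KV-cache size only. It is a sketch under stated assumptions: Mistral-7B's public attention shapes (32 layers, 8 key/value heads, head dimension 128, bf16 values), and each retained beacon activation costing the same as one token's KV entry. The ratio of 8 and the helper names are illustrative. The estimate ignores model weights and temporary buffers, so it is not comparable to the measured memory in the PG19 table.

```python
# Assumed Mistral-7B attention shapes: 32 layers, 8 KV heads, head dim 128, bf16
NUM_LAYERS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
BYTES_PER_VALUE = 2  # bf16
GiB = 1024 ** 3


def kv_bytes_per_entry():
    # one key vector and one value vector per layer per KV head
    return 2 * NUM_LAYERS * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE


def full_attention_kv_bytes(context_len):
    # full attention caches keys/values for every token
    return context_len * kv_bytes_per_entry()


def beacon_kv_bytes(context_len, ratio, window=2048):
    # assumption: the context is kept as context_len // ratio beacon entries,
    # plus up to `window` raw entries for the window currently being processed
    condensed = context_len // ratio
    return (condensed + window) * kv_bytes_per_entry()


print(full_attention_kv_bytes(131072) / GiB)   # 16.0
print(beacon_kv_bytes(131072, ratio=8) / GiB)  # 2.25
```

With a condensing ratio of 8, the cached state for a 128K context shrinks by roughly 7x in this simplified accounting. Larger ratios shrink it further.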

Compared with activation-beacon-llama2-7b-chat, there are three major differences:

  • Training Data: we use more data for both pretraining (2B tokens from SlimPajama with a sequence length of 16384) and supervised fine-tuning (open-source long-context data plus thousands of synthetic long-context QA samples generated with GPT-4).
  • Sliding Window: the window size is increased to 2048.
  • Condensing Ratio: we train with condensing ratios of [2,4,8,16,32] during pretraining and [2,4,8] during fine-tuning. In both stages, the condensing ratios are mixed using the step-random strategy (see the paper for details).
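The sketch below illustrates how a sliding window of 2048 tokens and a mixed set of condensing ratios interact. It assumes that "step-random" means drawing one ratio uniformly at random for each window (step). It also assumes that a window of L raw tokens is condensed into L // ratio beacon activations. Both are our reading for illustration, and the paper is authoritative. The function name is hypothetical.

```python
import random

WINDOW_SIZE = 2048                   # sliding window size stated in this card
PRETRAIN_RATIOS = [2, 4, 8, 16, 32]  # condensing ratios used in pretraining
FINETUNE_RATIOS = [2, 4, 8]          # condensing ratios used in fine-tuning


def step_random_schedule(seq_len, ratios, window=WINDOW_SIZE, seed=None):
    """Assumed step-random mixing: one ratio drawn uniformly per window.

    Returns a list of (window_start, window_len, ratio, num_beacons) tuples.
    """
    rng = random.Random(seed)
    schedule = []
    for start in range(0, seq_len, window):
        length = min(window, seq_len - start)
        ratio = rng.choice(ratios)
        # assumption: `length` raw tokens condense into length // ratio beacons
        num_beacons = max(1, length // ratio)
        schedule.append((start, length, ratio, num_beacons))
    return schedule


# a 16384-token pretraining sample is split into 8 windows of 2048 tokens
for step in step_random_schedule(16384, PRETRAIN_RATIOS, seed=0):
    print(step)
```

Drawing the ratio per window, rather than once per sample, exposes the model to mixed compression levels within a single sequence.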

Evaluation

You can reproduce the following results by following the instructions here.

Needle in a Haystack

We evaluate the model on the Needle-In-A-Haystack task using the official setting.
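The following is a minimal sketch of how a needle-in-a-haystack input is typically assembled. A known fact (the needle) is placed at a chosen depth inside long filler text, and the model is then asked about it. The sentence-level placement, the example needle, and the helper names are assumptions for illustration. The official setting defines its own filler text, needle, and depth grid.

```python
def insert_needle(sentences, needle, depth):
    """Return a copy of `sentences` with `needle` inserted at fractional `depth`.

    depth=0.0 puts the needle first and depth=1.0 puts it last. Insertion
    happens between sentences, so no filler sentence is split.
    """
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be between 0 and 1")
    pos = round(len(sentences) * depth)
    return sentences[:pos] + [needle] + sentences[pos:]


def build_haystack(sentences, needle, depth):
    # join the filler and the needle into one long context string
    return " ".join(insert_needle(sentences, needle, depth))


# hypothetical usage: ten filler sentences, needle halfway through
filler = [f"Filler sentence {i}." for i in range(10)]
context = build_haystack(filler, "The secret word is apricot.", depth=0.5)
```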

LongBench

We evaluate the model on LongBench using a 32K context length.

Model | Single-Doc QA | Multi-Doc QA | Summarization
Mistral-7B-Instruct-v0.2 | 32.70 | 25.87 | 27.42
Yarn-Mistral-128K | 33.71 | 36.08 | 23.47
Activation-Beacon-Mistral-7B | 39.14 | 43.27 | 29.52
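LongBench's reference prediction script truncates inputs that exceed the length budget from the middle. This keeps the beginning (usually the instruction) and the end (usually the question). Below is a minimal sketch of that rule at the 32K budget used above, operating on token IDs. The function name is hypothetical, and we assume this evaluation applies the same rule.

```python
def truncate_middle(token_ids, max_length):
    """Keep the first and last max_length // 2 tokens, dropping the middle."""
    if len(token_ids) <= max_length:
        return list(token_ids)
    half = max_length // 2
    # slice the tail by absolute index so half == 0 yields an empty tail
    return list(token_ids[:half]) + list(token_ids[len(token_ids) - half:])


ids = list(range(40_000))           # stand-in for a tokenized long prompt
kept = truncate_middle(ids, 32_768)
print(len(kept), kept[0], kept[-1])
```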

InfiniteBench

We evaluate the model on InfiniteBench using a 128K context length. The results of Yarn-Mistral-128K are copied from the paper.

Model | LongBookQA Eng | LongBookSum Eng
Yarn-Mistral-128K | 9.55 | 9.09
Activation-Beacon-Mistral-7B | 26.81 | 12.49

Topic Retrieval

We evaluate the model on the Topic Retrieval task with [5,10,15,20,25,30,40,50,60,70] topics.
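A topic-retrieval input places N short conversations one after another, each about a different topic. It then asks which topic came first. The sketch below builds such a prompt. The turn format, the question wording, the example topics, and the helper names are hypothetical; they are not the benchmark's exact template.

```python
def build_topic_prompt(conversations):
    """conversations: list of (topic, text) pairs in the order they were discussed.

    Returns (prompt, expected_answer); the expected answer is the first topic.
    """
    if not conversations:
        raise ValueError("need at least one conversation")
    turns = [f"USER: Let's talk about {topic}.\nASSISTANT: {text}"
             for topic, text in conversations]
    turns.append("USER: What is the first topic we discussed? Only give me the topic name.")
    return "\n\n".join(turns), conversations[0][0]


def is_correct(prediction, answer):
    # lenient check: the expected topic appears anywhere in the model output
    return answer.lower() in prediction.lower()


# hypothetical usage with 5 topics, the smallest setting listed above
topics = ["gardening", "jazz", "volcanoes", "chess", "sourdough"]
convs = [(t, f"Here is a long discussion about {t}.") for t in topics]
prompt, answer = build_topic_prompt(convs)
```

Adding topics lengthens the input while the answer stays at the very beginning. The task therefore probes recall of the earliest part of the context.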

PG19 Perplexity

We evaluate sliding-window perplexity on the PG19 test set with a window size of 100K and a stride of 32K. We also report latency and GPU memory usage. For full-attention models, we enable flash-attention-2 and tensor parallelism. The evaluation is run on an 8xA800 machine.

Model | Perplexity | Latency (s) | Memory (GB)
Mistral-7B-Instruct-v0.2 | 8.83 | 14.02 | 525.6 (cannot run on a single GPU)
Yarn-Mistral-128K | 7.66 | 14.56 | 525.6 (cannot run on a single GPU)
Activation-Beacon-Mistral-7B | 8.16 | 3.06 | 27.4
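The strided sliding-window perplexity described above can be sketched as follows. This follows the strided evaluation in the Hugging Face Transformers perplexity guide. It assumes this card scores each token exactly once in the same way. The helper names are hypothetical, and the model call that produces each window's summed negative log-likelihood is left out.

```python
import math


def sliding_windows(seq_len, window=100_000, stride=32_000):
    """Yield (begin, end, num_scored) for strided sliding-window perplexity.

    Each window feeds tokens [begin, end) to the model but scores only the
    last `num_scored` tokens, i.e. those not scored by an earlier window.
    """
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + window, seq_len)
        num_scored = end - prev_end
        yield begin, end, num_scored
        prev_end = end
        if end == seq_len:
            break


def perplexity(nll_sums, token_counts):
    # corpus perplexity = exp(total negative log-likelihood / total scored tokens)
    return math.exp(sum(nll_sums) / sum(token_counts))


for begin, end, num_scored in sliding_windows(150_000):
    print(begin, end, num_scored)
```

With a 100K window and a 32K stride, every window after the first scores only its newest 32K tokens. Each of those tokens is conditioned on up to 68K tokens of preceding context inside the window.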

Passkey Retrieval

We evaluate the model on the Passkey Retrieval task using the official setting.
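Passkey retrieval hides a random number inside repeated filler text and asks the model to repeat it. The sketch below is modelled on the commonly used passkey format. The exact wording, the 5-digit key, the digit-matching rule, and the helper names are assumptions, so treat this as an illustration rather than the official script.

```python
import random
import re

# Assumed filler and instruction wording, modelled on the common passkey format
FILLER = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
TASK = ("There is an important info hidden inside a lot of irrelevant text. "
        "Find it and memorize it. I will quiz you about the important information there.")


def build_passkey_prompt(n_filler, depth, seed=None):
    """Hide a random 5-digit passkey after round(n_filler * depth) filler blocks."""
    rng = random.Random(seed)
    passkey = str(rng.randint(10000, 99999))
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key."
    pos = round(n_filler * depth)
    body = [FILLER] * pos + [needle] + [FILLER] * (n_filler - pos)
    prompt = TASK + " " + " ".join(body) + " What is the pass key? The pass key is"
    return prompt, passkey


def passkey_correct(prediction, passkey):
    # count a hit if the first run of digits in the prediction equals the passkey
    match = re.search(r"\d+", prediction)
    return match is not None and match.group() == passkey


prompt, key = build_passkey_prompt(n_filler=50, depth=0.5, seed=0)
```

The context length is controlled by `n_filler`. Sweeping `depth` checks whether retrieval depends on where the key sits in the input.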

Environment

torch>=2.1.1
transformers==4.39.3

Usage

import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/activation-beacon-mistral-7b"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)

model = model.cuda().eval()

with torch.no_grad():
  # short context
  messages = [{"role": "user", "content": "Tell me about yourself."}]
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
  outputs = model.generate(**inputs, max_new_tokens=50)
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
  print(f"Output:       {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

  # reset memory before new generation task
  model.memory.reset()

  # long context
  with open("data/infbench.json", encoding="utf-8") as f:
    example = json.load(f)
  messages = [{"role": "user", "content": example["context"]}]
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
  outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
  print("*"*20)
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
  print(f"Answers:      {example['answer']}")
  print(f"Prediction:   {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

NOTE: You may see a warning such as "This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (32768). Depending on the model, you may observe exceptions, performance degradation, or nothing at all." It is safe to ignore it.
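To run several independent inputs in one session, you can wrap the reset-between-tasks pattern from the usage example in a helper. This sketch uses only the calls shown in this card (`apply_chat_template`, `generate`, `model.memory.reset()`) plus the standard `model.device` property. The helper name `chat_once` is hypothetical. It assumes `generate` runs without gradient tracking, which Transformers' `generate` does by default. If in doubt, call it inside `with torch.no_grad():` as the example above does.

```python
def chat_once(model, tokenizer, prompt, **gen_kwargs):
    """Generate a reply to one standalone prompt and return only the new text.

    Hypothetical helper: the beacon memory is reset after generation, as the
    card's example does between tasks, so the next call starts clean.
    """
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True,
        return_tensors="pt", return_dict=True,
    ).to(model.device)
    try:
        outputs = model.generate(**inputs, **gen_kwargs)
    finally:
        model.memory.reset()  # clear activations condensed from this input
    prompt_len = inputs["input_ids"].shape[1]
    # keep only the tokens produced after the prompt
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
```

For example, `chat_once(model, tokenizer, example["context"], do_sample=False, max_new_tokens=20)` reproduces the long-context call above and returns just the prediction.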

Safetensors model size: 8.58B params, tensor type BF16