CLEX: Continuous Length Extrapolation for Large Language Models

This repo contains the CLEX-Mixtral-8x7B-Chat-32K checkpoint.

Features and Highlights of CLEX

[Figure: CLEX diagram]

  • Simple and Clear: MINIMAL code and architecture changes. Only one up-and-down projection layer is introduced; NO recurrent memory caching or sparse attention is required.
  • Train Short, Test Long: NO performance drop on sequences 4x to 8x longer than the training sequences.
  • Continuous Length Extrapolation: Explicitly models the continuous dynamics of the context window size during length extrapolation.
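The first and third bullets can be made concrete with a toy sketch. As we understand the idea, the RoPE frequencies are treated as a function of a continuous scaling factor t (assumed here to be target length divided by training length) that evolves under learned dynamics dz/dt = g(z, t), where g is the single up-and-down projection. Everything below is an illustrative assumption, not the released implementation: the names `ToyUpDownDynamics` and `scaled_inv_freq` are hypothetical, the weights are random rather than trained, and forward Euler stands in for a proper ODE solver. The actual parameterization in the paper and code differs in its details.

```python
import math
import random


def rope_inv_freq(dim, base=10000.0):
    """Standard RoPE inverse frequencies theta_i = base^(-2i/dim), i = 0 .. dim/2 - 1."""
    return [base ** (-2.0 * i / dim) for i in range(dim // 2)]


class ToyUpDownDynamics:
    """Hypothetical dynamics g(z, t): one up-projection, tanh, one down-projection.

    Weights are random here; in a real model they would be learned.
    """

    def __init__(self, n_freqs, hidden, seed=0, init_std=0.01):
        rng = random.Random(seed)
        # Up-projection: (n_freqs + 1) inputs -> hidden units (the +1 input is t itself)
        self.w_up = [[rng.gauss(0.0, init_std) for _ in range(n_freqs + 1)]
                     for _ in range(hidden)]
        # Down-projection: hidden units -> one derivative per frequency
        self.w_down = [[rng.gauss(0.0, init_std) for _ in range(hidden)]
                       for _ in range(n_freqs)]

    def __call__(self, z, t):
        x = z + [t]  # condition the dynamics on the current scaling factor
        h = [math.tanh(sum(w * v for w, v in zip(row, x))) for row in self.w_up]
        return [sum(w * v for w, v in zip(row, h)) for row in self.w_down]


def scaled_inv_freq(dim, t, dynamics, steps=32, base=10000.0):
    """Integrate dz/dt = g(z, t) from t = 1 (training length) to the target scale t.

    z holds the log inverse frequencies; forward Euler is used only for simplicity.
    At t = 1 the step size is zero, so the original RoPE frequencies are returned.
    """
    z = [math.log(f) for f in rope_inv_freq(dim, base)]
    dt = (t - 1.0) / steps
    cur = 1.0
    for _ in range(steps):
        dz = dynamics(z, cur)
        z = [zi + dt * di for zi, di in zip(z, dz)]
        cur += dt
    return [math.exp(zi) for zi in z]
```

For example, `scaled_inv_freq(128, 4.0, ToyUpDownDynamics(n_freqs=64, hidden=128))` would give the toy frequencies for running a 32K-trained model at 128K. Because t is continuous, the same g serves any target length rather than a fixed set of scaling factors.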

If you have any questions, feel free to contact us. (Emails: guanzzh.chen@gmail.com, lixin4ever@gmail.com)

Model Zoo

| Model Name | Model Type | Starting Point | Train Data | Train Length | Max Test Length |
|---|---|---|---|---|---|
| CLEX-LLaMA-2-7B-16K | base | LLaMA-2-7B | RedPajama-Book | 16K | 64K |
| CLEX-LLaMA-2-7B-Chat-16K | chat | CLEX-7B-16K | UltraChat | 16K | 64K |
| CLEX-LLaMA-2-7B-64K | base | LLaMA-2-7B | RedPajama-Book | 64K | 256K |
| CLEX-Phi-2-32K | base | Phi-2-2.7B | LongCorpus-2.5B | 32K | 128K |
| CLEX-Mixtral-8x7B-32K | base | Mixtral-8x7B-v0.1 | LongCorpus-2.5B | 32K | >128K |
| CLEX-Mixtral-8x7B-Chat-32K (this checkpoint) | chat | CLEX-Mixtral-8x7B-32K | UltraChat 200k | 32K | >128K |

Usage

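import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "DAMO-NLP-SG/CLEX-Mixtral-8x7B-Chat-32K"
# trust_remote_code=True lets transformers run the CLEX modeling code shipped with the checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# device_map="auto" (requires the accelerate package) spreads the 46.7B-parameter model over available devices
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
# Put the tokenized prompt on the same device as the model's input layers
inputs = tokenizer("What is CLEX?", return_tensors="pt").to(model.device)
# max_new_tokens bounds only the generated continuation, not prompt + output
sample = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(sample[0], skip_special_tokens=True))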
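Prompts for this checkpoint can run well past the 32K training length, but very long inputs may still need to be cut to a token budget. The helper below is a minimal, hypothetical sketch: the name `fit_to_budget` and its head-plus-tail policy are our assumptions, not part of the CLEX code. It keeps a few leading tokens (for example an instruction) plus the most recent tokens, so the result never exceeds the budget.

```python
from typing import List, Sequence


def fit_to_budget(token_ids: Sequence[int], max_input_tokens: int, head_tokens: int = 0) -> List[int]:
    """Hypothetical helper: trim token_ids to at most max_input_tokens.

    Keeps the first `head_tokens` ids and fills the rest of the budget with the
    most recent ids; the middle of the input is dropped.
    """
    if max_input_tokens <= 0:
        raise ValueError("max_input_tokens must be positive")
    if len(token_ids) <= max_input_tokens:
        return list(token_ids)
    head_tokens = min(head_tokens, max_input_tokens)
    tail_tokens = max_input_tokens - head_tokens
    head = list(token_ids[:head_tokens])
    # Slice from len - tail_tokens so that tail_tokens == 0 yields an empty tail
    tail = list(token_ids[len(token_ids) - tail_tokens:])
    return head + tail


if __name__ == "__main__":
    ids = list(range(10))
    print(fit_to_budget(ids, 6, head_tokens=2))  # [0, 1, 6, 7, 8, 9]
```

With the loading code from the Usage section, one could apply it to `tokenizer(text)["input_ids"]` with an assumed budget (for instance 128K minus the number of tokens to generate), then pass `torch.tensor([ids], device=model.device)` as `input_ids` to `model.generate`. Keeping the head preserves the BOS token and any leading instruction.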

Evaluation

InfiniteBench

We also evaluate CLEX-Mixtral-8x7B-Chat-32K on InfiniteBench, a 128K-length benchmark covering retrieval, English, code, and math tasks. We compare CLEX-Mixtral-8x7B-Chat-32K with GPT-4, YaRN-Mistral-7B, Kimi-Chat, Claude 2, and the vanilla Mixtral-8x7B-Instruct-v0.1.

| Task Name | GPT-4 | YaRN-Mistral-7B | Kimi-Chat | Claude 2 | CLEX-Mixtral-8x7B-Chat-32K | Mixtral-8x7B-Instruct-v0.1 |
|---|---|---|---|---|---|---|
| Retrieve.PassKey | 100% | 92.71% | 98.14% | 97.80% | 99.72% | 96.78% |
| Retrieve.Number | 100% | 56.61% | 95.42% | 98.14% | 76.10% | 76.61% |
| Retrieve.KV | 89.00% | < 5% | 53.60% | 65.40% | < 5% | < 5% |
| En.Sum | 14.73% | 9.09% | 17.93% | 14.45% | 15.48% | 14.3% |
| En.QA | 22.22% | 9.55% | 16.52% | 11.97% | 15.52% | 16.81% |
| En.MC | 67.25% | 27.95% | 72.49% | 62.88% | 58.96% | 56.77% |
| En.Dia | 8.50% | 7.50% | 11.50% | 46.50% | 9% | < 5% |
| Code.Debug | 39.59% | < 5% | 18.02% | < 5% | 21.32% | < 5% |
| Code.Run | 23.25% | < 5% | < 5% | < 5% | < 5% | < 5% |
| Math.Calc | < 5% | < 5% | < 5% | < 5% | < 5% | < 5% |
| Math.Find | 60.00% | 17.14% | 12.57% | 32.29% | 28% | 26.57% |
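As a worked reading of the InfiniteBench table, the snippet below compares CLEX-Mixtral-8x7B-Chat-32K with Mixtral-8x7B-Instruct-v0.1 on the tasks where both have an exact score. Treating "< 5%" entries as missing is an assumption of this sketch; it leaves out En.Dia and Code.Debug, where CLEX scores 9% and 21.32% against Mixtral's "< 5%". The scores are copied from the table; the `head_to_head` function is a hypothetical helper.

```python
from statistics import mean

# Scores (%) copied from the InfiniteBench table; None marks an entry reported only as "< 5%".
clex = {
    "Retrieve.PassKey": 99.72, "Retrieve.Number": 76.10, "Retrieve.KV": None,
    "En.Sum": 15.48, "En.QA": 15.52, "En.MC": 58.96, "En.Dia": 9.0,
    "Code.Debug": 21.32, "Code.Run": None, "Math.Calc": None, "Math.Find": 28.0,
}
mixtral_instruct = {
    "Retrieve.PassKey": 96.78, "Retrieve.Number": 76.61, "Retrieve.KV": None,
    "En.Sum": 14.3, "En.QA": 16.81, "En.MC": 56.77, "En.Dia": None,
    "Code.Debug": None, "Code.Run": None, "Math.Calc": None, "Math.Find": 26.57,
}


def head_to_head(a, b):
    """Compare two systems only on tasks where both report an exact score."""
    shared = [t for t in a if a[t] is not None and b.get(t) is not None]
    wins_a = [t for t in shared if a[t] > b[t]]
    wins_b = [t for t in shared if b[t] > a[t]]
    avg_a = mean([a[t] for t in shared])
    avg_b = mean([b[t] for t in shared])
    return shared, wins_a, wins_b, avg_a, avg_b


shared, clex_wins, mix_wins, clex_avg, mix_avg = head_to_head(clex, mixtral_instruct)
# prints "6 shared tasks: CLEX wins 4, Mixtral-Instruct wins 2"
print(f"{len(shared)} shared tasks: CLEX wins {len(clex_wins)}, Mixtral-Instruct wins {len(mix_wins)}")
# prints "mean on shared tasks: CLEX 48.96% vs Mixtral-Instruct 47.97%"
print(f"mean on shared tasks: CLEX {clex_avg:.2f}% vs Mixtral-Instruct {mix_avg:.2f}%")
```

An unweighted mean over six tasks is a coarse summary: the two retrieval tasks dominate it, so the per-task wins are the more informative part of the output.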

Citation

If you find our project useful, please consider starring our repo and citing our paper:

@article{damonlpsg2023clex,
  author = {Chen, Guanzheng and Li, Xin and Meng, Zaiqiao and Liang, Shangsong and Bing, Lidong},
  title = {CLEX: Continuous Length Extrapolation for Large Language Models},
  year = 2023,
  journal = {arXiv preprint arXiv:2310.16450},
  url = {https://arxiv.org/abs/2310.16450}
}
Model size: 46.7B params (Safetensors; tensor types: F32, BF16)