PoSE: Efficient Context Window Extension of LLMs via Positional Skip-wise Training

Authors: Dawei Zhu, Nan Yang, Liang Wang, Yifan Song, Wenhao Wu, Furu Wei, Sujian Li

Abstract

Large Language Models (LLMs) are trained with a pre-defined context length, which restricts their use in scenarios requiring long inputs. Previous efforts to adapt LLMs to a longer length usually require fine-tuning at the target length (Full-length fine-tuning), which incurs intensive training costs. To decouple the training length from the target length for efficient context window extension, we propose Positional Skip-wisE (PoSE) training, which simulates long inputs using a fixed context window. This is achieved by first dividing the original context window into several chunks, then designing distinct skipping bias terms to manipulate the position indices of each chunk. These bias terms and the length of each chunk are altered for every training example, allowing the model to adapt to all positions within the target length. Experimental results show that PoSE greatly reduces memory and time overhead compared with Full-length fine-tuning, with minimal impact on performance. Leveraging this advantage, we have successfully extended the LLaMA model to 128k tokens using a 2k training context window. Furthermore, we empirically confirm that PoSE is compatible with all RoPE-based LLMs and position interpolation strategies. Notably, our method can potentially support infinite length, limited only by memory usage during inference. With ongoing progress in efficient inference, we believe PoSE can further scale the context window beyond 128k.
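The chunking and skipping-bias scheme described above can be sketched in Python as follows. This is a minimal illustration under simplifying assumptions, not the authors' implementation: it only generates position indices for a single example, the function name pose_position_ids is hypothetical, and the sampling choices (random chunk boundaries, skipping biases drawn uniformly and kept non-decreasing) are assumed for illustration. The paper's exact sampling procedure, including how the text content of each chunk is selected, may differ.

```python
import random
from typing import List, Optional


def pose_position_ids(train_len: int, target_len: int, num_chunks: int = 2,
                      rng: Optional[random.Random] = None) -> List[int]:
    """Hypothetical sketch of PoSE-style position ids for one training example.

    The train_len tokens are split into num_chunks contiguous chunks of random
    length. Chunk i is shifted by a skipping bias u_i, with u_0 = 0 and
    u_{i-1} <= u_i <= target_len - train_len, so ids stay strictly increasing
    and never exceed target_len - 1.
    """
    if rng is None:
        rng = random.Random()
    # Random chunk boundaries; every chunk keeps at least one token.
    cuts = sorted(rng.sample(range(1, train_len), num_chunks - 1))
    bounds = [0] + cuts + [train_len]
    max_skip = target_len - train_len

    position_ids: List[int] = []
    bias = 0  # the first chunk is not shifted
    for i in range(num_chunks):
        if i > 0:
            # Skipping bias for chunk i: uniform in [previous bias, max_skip].
            bias = rng.randint(bias, max_skip)
        start, end = bounds[i], bounds[i + 1]
        position_ids.extend(range(start + bias, end + bias))
    return position_ids
```

For example, pose_position_ids(2048, 16384) returns 2,048 strictly increasing indices, all below 16,384. Such a list could be passed as the position_ids argument of a RoPE model's forward pass (Hugging Face's LlamaForCausalLM accepts one), so that, across many training examples, a 2k window exposes the model to relative distances spanning the 16k target range.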

Released models

Context-Extended Versions of LLaMA (originally supporting a 2k context)

Model                      Context (tokens)  Interpolation  Link
LLaMA-7B-PoSE-Linear-16k   16,384            Linear         download link
LLaMA-7B-PoSE-NTK-16k      16,384            NTK            download link
LLaMA-7B-PoSE-YaRN-16k     16,384            YaRN           download link
LLaMA-7B-PoSE-Linear-96k   98,304            Linear         download link
LLaMA-7B-PoSE-YaRN-96k     98,304            YaRN           download link
LLaMA-7B-PoSE-YaRN-128k    131,072           YaRN           download link

Context-Extended Versions of LLaMA2 (originally supporting a 4k context)

Model                      Context (tokens)  Interpolation  Link
LLaMA2-7B-PoSE-Linear-16k  16,384            Linear         download link
LLaMA2-7B-PoSE-NTK-16k     16,384            NTK            download link
LLaMA2-7B-PoSE-YaRN-16k    16,384            YaRN           download link

Context-Extended Versions of Baichuan2 (originally supporting a 4k context)

Model                         Context (tokens)  Interpolation  Link
Baichuan2-7B-PoSE-Linear-16k  16,384            Linear         download link
baichuan2-7B-PoSE-NTK-16k     16,384            NTK            download link
baichuan2-7B-PoSE-YaRN-16k    16,384            YaRN           download link
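The Interpolation column names the position interpolation scheme applied to RoPE. As a rough illustration of how two of these schemes differ, the sketch below computes RoPE inverse frequencies under linear interpolation and under static NTK-aware base scaling, using the commonly cited rule base' = base * scale^(dim/(dim-2)). It is an assumption-laden sketch, not the code in pose_modeling_llama.py; YaRN (including the revised variant mentioned in the notice) is not shown, and the function name scaled_inv_freq is hypothetical.

```python
from typing import List


def scaled_inv_freq(dim: int, scale: float, method: str,
                    base: float = 10000.0) -> List[float]:
    """Hypothetical sketch: RoPE inverse frequencies under two interpolation schemes."""
    if method == "linear":
        # Linear interpolation: position m is treated as m / scale, which is
        # equivalent to dividing every RoPE frequency by scale.
        return [base ** (-2 * i / dim) / scale for i in range(dim // 2)]
    if method == "ntk":
        # NTK-aware scaling: keep positions unchanged and enlarge the base, so
        # the highest frequency is untouched and the lowest is divided by scale.
        new_base = base * scale ** (dim / (dim - 2))
        return [new_base ** (-2 * i / dim) for i in range(dim // 2)]
    raise ValueError(f"unknown interpolation method: {method}")
```

With dim=128 and scale=8, linear interpolation divides every frequency by 8, while NTK-aware scaling leaves the highest frequency at 1.0 and divides only the lowest frequency by exactly 8, interpolating smoothly in between. That is the usual motivation for NTK-style scaling: the high-frequency dimensions, which distinguish nearby positions, stay close to their original values.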

Notice

  • For YaRN interpolation, we use a revised version of YaRN in our experiments (see pose_modeling_llama.py), following the issue "inv_freq seems not calculated right".
  • We explicitly set the max_position_embeddings parameter in the configuration to the scaled (target) length. This differs slightly from the usage described in the Hugging Face LLaMA documentation. We made this adjustment because positional skip-wise training uses position indices that exceed the input length. This modification does not negatively affect model performance.
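To make this difference concrete, the sketch below builds two hypothetical configuration overrides for extending a 2k model to 16k with linear interpolation: one following the Hugging Face convention, where max_position_embeddings stays at the original length and the extension is expressed only through the scaling factor, and one following the PoSE convention described above. The rope_scaling layout ({"type": ..., "factor": ...}) mirrors the dictionary format documented for Hugging Face LlamaConfig, but the key names have changed across transformers releases, and the released PoSE checkpoints may store their scaling settings differently; the function name rope_config_overrides is likewise hypothetical.

```python
from typing import Any, Dict


def rope_config_overrides(original_len: int, target_len: int,
                          pose_style: bool) -> Dict[str, Any]:
    """Hypothetical config overrides for a linear-interpolation context extension."""
    factor = target_len / original_len  # e.g. 16384 / 2048 = 8.0
    return {
        # Hugging Face convention: keep the pre-training length here.
        # PoSE convention: store the scaled (target) length instead, because
        # skip-wise training already uses position ids up to target_len - 1.
        "max_position_embeddings": target_len if pose_style else original_len,
        # Assumed rope_scaling layout; key names vary across library versions.
        "rope_scaling": {"type": "linear", "factor": factor},
    }
```

Both variants carry the same scaling factor of 8.0; only the recorded maximum position differs.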

Citation

If you find this project useful in your research, please consider citing:

@misc{zhu2023pose,
      title={PoSE: Efficient Context Window Extension of LLMs via Positional Skip-wise Training}, 
      author={Dawei Zhu and Nan Yang and Liang Wang and Yifan Song and Wenhao Wu and Furu Wei and Sujian Li},
      year={2023},
      eprint={2309.10400},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Acknowledgement

  • This work uses LLaMA, GPT-J, and Baichuan as the pre-trained models.