Text Generation
Transformers
PyTorch
Safetensors
longllama
text-generation-inference
custom_code
Szymon Tworkowski committed on
Commit
2137eb1
1 Parent(s): 65ce2de

fix md/typos

Files changed (1)
  1. README.md +13 -14
README.md CHANGED
@@ -6,12 +6,11 @@ pipeline_tag: text-generation
 tags:
 - text-generation-inference
 ---
-<p align="center" width="100%"><img src="https://raw.githubusercontent.com/CStanKonrad/long_llama/main/assets/longllama.png" alt="LongLLaMA" style="width: 50%; display: block; margin: auto;"></p>
-
 # LongLLaMA: Focused Transformer Training for Context Scaling
+[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_colab.ipynb)
 
 
-[Colab](https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_colab.ipynb) | [TLDR](#TLDR) | [Overview](#Overview) | [Usage](#Usage) | [LongLLaMA performance](#LongLLaMA-performance) | [Authors](#Authors) | [Citation](#Citation) | [License](License) | [Acknowledgments](#Acknowledgments)
+[TLDR](#TLDR) | [Overview](#Overview) | [Usage](#Usage) | [LongLLaMA performance](#LongLLaMA-performance) | [Authors](#Authors) | [Citation](#Citation) | [License](License) | [Acknowledgments](#Acknowledgments)
 
 ## TLDR
 This repository contains the research preview of **LongLLaMA, a large language model capable of handling long contexts of 256k tokens or even more**.
@@ -24,7 +23,7 @@ LongLLaMA is built upon the foundation of [OpenLLaMA](https://github.com/openlm-
 
 
 **LongLLaMA** is an [OpenLLaMA](https://github.com/openlm-research/open_llama) model finetuned with the FoT method,
-with three layers used for context extension. Crucially, LongLLama is able to extrapolate much beyond the context length seen in training: $8k$. E.g., in the key retrieval task, it can handle inputs of length $256k$.
+with three layers used for context extension. **Crucially, LongLLama is able to extrapolate much beyond the context length seen in training: 8k. E.g., in the key retrieval task, it can handle inputs of length 256k**.
 
 <center>
 
@@ -66,7 +65,7 @@ prompt = "My name is Julien and I like to"
 input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 outputs = model(input_ids=input_ids)
 ```
-During the model call, one can provide the parameter `last_context_length` (default $1024$), which specifies the number of tokens left in the last context window. Tuning this parameter can improve generation as the first layers do not have access to memory. See details in [How LongLLaMA handles long inputs](#How-LongLLaMA-handles-long-inputs).
+During the model call, one can provide the parameter `last_context_length` (default 1024), which specifies the number of tokens left in the last context window. Tuning this parameter can improve generation as the first layers do not have access to memory. See details in [How LongLLaMA handles long inputs](#How-LongLLaMA-handles-long-inputs).
 
 ```python
 generation_output = model.generate(
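
For reference, a minimal sketch of passing `last_context_length` at generation time. It assumes the checkpoint is loaded with `trust_remote_code=True` (this repo ships custom modeling code) and that `generate` forwards the argument to the model; the dtype, the value 1792, and the sampling settings are illustrative, not taken from this README.

```python
# Sketch: generation with an explicit last_context_length (illustrative values).
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b")
model = AutoModelForCausalLM.from_pretrained(
    "syzymon/long_llama_3b", torch_dtype=torch.float32, trust_remote_code=True
)

prompt = "My name is Julien and I like to"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

generation_output = model.generate(
    input_ids=input_ids,
    max_new_tokens=256,
    num_beams=1,
    last_context_length=1792,  # illustrative; the default is 1024
    do_sample=True,
    temperature=1.0,
)
print(tokenizer.decode(generation_output[0]))
```
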
@@ -85,7 +84,7 @@ LongLLaMA has several other parameters:
 * `mem_layers` specifies layers endowed with memory (should be either an empty list or a list of all memory layers specified in the description of the checkpoint).
 * `mem_dtype` allows changing the type of memory cache
 * `mem_attention_grouping` can trade off speed for reduced memory usage.
-When equal to `(4, 2048)`, the memory layers will process at most $4*2048$ queries at once ($4$ heads and $2048$ queries for each head).
+When equal to `(4, 2048)`, the memory layers will process at most 4*2048 queries at once (4 heads and 2048 queries for each head).
 
 ```python
 import torch
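
A sketch of how the memory parameters listed above might be passed to `from_pretrained`. The argument names come from that list; the concrete values, the string form of `mem_dtype`, and `trust_remote_code=True` are assumptions.

```python
# Sketch: loading LongLLaMA with memory-related options (illustrative values).
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "syzymon/long_llama_3b",
    torch_dtype=torch.float32,
    trust_remote_code=True,            # needed for the repo's custom code
    mem_layers=[],                     # [] or the full list from the checkpoint description
    mem_dtype="bfloat16",              # dtype of the memory cache
    mem_attention_grouping=(4, 2048),  # at most 4 * 2048 queries per memory-layer call
)
```
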
@@ -103,7 +102,7 @@ model = AutoModelForCausalLM.from_pretrained(
 
 
 ### Drop-in use with LLaMA code
-LongLLaMA checkpoints can also be used as a drop-in replacement for LLaMA checkpoints in [Hugging Face implementation of LLaMA](https://huggingface.co/docs/transformers/main/model_doc/llama), but in this case, they will be limited to the original context length of $2048$.
+LongLLaMA checkpoints can also be used as a drop-in replacement for LLaMA checkpoints in [Hugging Face implementation of LLaMA](https://huggingface.co/docs/transformers/main/model_doc/llama), but in this case, they will be limited to the original context length of 2048.
 
 ```python
 from transformers import LlamaTokenizer, LlamaForCausalLM
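
A short sketch of this drop-in use, built around the imports shown in the hunk and the `syzymon/long_llama_3b` checkpoint id visible in the next hunk header; the dtype and generation settings are illustrative.

```python
# Sketch: LongLLaMA loaded as a plain LLaMA checkpoint (context capped at 2048 tokens).
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b")
model = LlamaForCausalLM.from_pretrained(
    "syzymon/long_llama_3b", torch_dtype=torch.float32  # dtype is illustrative
)

prompt = "My name is Julien and I like to"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```
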
@@ -115,11 +114,11 @@ model = LlamaForCausalLM.from_pretrained("syzymon/long_llama_3b", torch_dtype=to
 
 
 ### How LongLLaMA handles long inputs
-Inputs over $2048$ tokens are automatically split into windows $w_1, \ldots, w_m$. The first $m-2$ windows contain $2048$ tokens each, $w_{m-1}$ has no more than $2048$ tokens, and $w_m$ contains the number of tokens specified by `last_context_length`. The model processes the windows one by one extending the memory cache after each. If `use_cache` is `True`, the last window will not be loaded to the memory cache but to the local (generation) cache.
+Inputs over 2048 tokens are automatically split into windows w_1, \ldots, w_m. The first m-2 windows contain 2048 tokens each, w_{m-1} has no more than 2048 tokens, and w_m contains the number of tokens specified by `last_context_length`. The model processes the windows one by one extending the memory cache after each. If `use_cache` is `True`, the last window will not be loaded to the memory cache but to the local (generation) cache.
 
-The memory cache stores $(key, value)$ pairs for each head of the specified memory layers `mem_layers`. In addition to this, it stores attention masks.
+The memory cache stores (key, value) pairs for each head of the specified memory layers `mem_layers`. In addition to this, it stores attention masks.
 
-If `use_cache=True` (which is the case in generation), LongLLaMA will use two caches: the memory cache for the specified layers and the local (generation) cache for all layers. When the local cache exceeds $2048$ elements, its content is moved to the memory cache for the memory layers.
+If `use_cache=True` (which is the case in generation), LongLLaMA will use two caches: the memory cache for the specified layers and the local (generation) cache for all layers. When the local cache exceeds 2048 elements, its content is moved to the memory cache for the memory layers.
 
 For simplicity, context extension is realized with a memory cache and full attention in this repo. Replacing this simple mechanism with a KNN search over an external database is possible with systems like [Faiss](https://github.com/facebookresearch/faiss). This potentially would enable further context length scaling. We leave this as a future work.
 
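A toy illustration of the splitting rule described above (how inputs over 2048 tokens are partitioned into windows); this is not the model's actual implementation, and `split_into_windows` is a hypothetical helper that only reproduces the arithmetic of the description.

```python
# Toy helper: window sizes [w_1, ..., w_m] for an input of `num_tokens` tokens,
# following the description above (not the model's real code).
def split_into_windows(num_tokens, window=2048, last_context_length=1024):
    if num_tokens <= window:
        return [num_tokens]          # short inputs are not split
    last = min(last_context_length, num_tokens)
    remaining = num_tokens - last
    full, tail = divmod(remaining, window)
    sizes = [window] * full          # full windows of `window` tokens each
    if tail:
        sizes.append(tail)           # w_{m-1}: the remainder, at most `window` tokens
    sizes.append(last)               # w_m: `last_context_length` tokens
    return sizes

print(split_into_windows(5000))      # [2048, 1928, 1024]
```
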
@@ -139,10 +138,10 @@ Our LongLLaMA 3B model also shows improvements when using long context on two do
 
 | Context/Dataset | TREC | WebQS |
 | --- | --- | --- |
-| $2K$ | 67.0 | 21.2 |
-| $4K$ | 71.6 | 21.4 |
-| $6K$ | 72.9 | 22.2 |
-| $8K$ | **73.3** | **22.4** |
+| 2K | 67.0 | 21.2 |
+| 4K | 71.6 | 21.4 |
+| 6K | 72.9 | 22.2 |
+| 8K | **73.3** | **22.4** |
 
 </center>
 
 