mustafaaljadery committed
Commit ece19aa
1 Parent(s): e70c9c5

Update README.md

Files changed (1)
  1. README.md +61 -3
README.md CHANGED
@@ -1,3 +1,61 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ ---
+ # Gemma 2B - 10M Context
+
+ Gemma 2B with recurrent local attention and a context length of up to 10M tokens. Our implementation uses **<32GB** of memory!
+
+ ![Graphic of our implementation context](./images/graphic.png)
+
+ **Features:**
+
+ - 10M sequence length on Gemma 2B.
+ - Runs on less than 32GB of memory.
+ - Native inference on Apple Silicon using MLX.
+ - Strong retrieval performance on needle-in-a-haystack evaluations.
+
+ ## Quick Start
+
+ > **Note:** This is a very early checkpoint of the model, trained for only 200 steps. We plan on training on a lot more tokens!
+
+ Download the model from Hugging Face - [Huggingface Model](https://huggingface.co/mustafaaljadery/gemma-10M-safetensor) - and run the inference script:
+
+ ```bash
+ python main.py
+ ```
+
+ Edit the inference code in `main.py` to use the specific prompt you desire:
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer
+
+ # GemmaForCausalLM and generate are defined in this repo's code (see main.py).
+ model_path = "./models/gemma-2b-10m"
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = GemmaForCausalLM.from_pretrained(
+     model_path,
+     torch_dtype=torch.bfloat16,
+ )
+
+ prompt_text = "Summarize this harry potter book..."
+
+ with torch.no_grad():
+     generated_text = generate(
+         model, tokenizer, prompt_text, max_length=512, temperature=0.8
+     )
+
+ print(generated_text)
+ ```
+
+ ## How does this work?
+
+ The largest bottleneck (in terms of memory) for LLMs is attention. The KV cache grows linearly with sequence length, and the attention matrix in vanilla multi-head attention grows quadratically, which together limit how long a sequence you can fit in memory.
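+
+ As a rough illustration (an estimate based on the published Gemma 2B configuration of 18 layers with a single 256-dimensional KV head, stored in bfloat16 - not a measurement of our code), a vanilla KV cache at 10M tokens would already need on the order of 180GB:
+
+ ```python
+ # Back-of-envelope KV-cache size for vanilla attention (estimate, not a measurement).
+ layers, kv_heads, head_dim, bytes_per_value = 18, 1, 256, 2  # Gemma 2B config, bfloat16
+ tokens = 10_000_000
+ kv_bytes = 2 * layers * kv_heads * head_dim * bytes_per_value * tokens  # 2x for K and V
+ print(f"{kv_bytes / 1e9:.0f} GB")  # ~184 GB, far above the 32GB target
+ ```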
+
+ Our approach splits attention into local attention blocks, as outlined by [InfiniAttention](https://arxiv.org/abs/2404.07143). We then apply recurrence across those local attention blocks to arrive at the final result: 10M-context global attention. A minimal sketch of the idea is shown below.
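+
+ As a rough sketch (not our actual code), the snippet below processes a long sequence in fixed-size blocks: each block runs ordinary causal attention locally, reads from a small compressive memory summarizing all earlier blocks, and then folds itself into that memory before moving on, so memory use stays roughly constant in the total sequence length. The names (`recurrent_local_attention`, `block_size`) are illustrative only.
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def recurrent_local_attention(q, k, v, block_size=2048):
+     """Toy sketch of recurrent local attention with a compressive memory,
+     in the spirit of InfiniAttention. q, k, v have shape (seq_len, dim)."""
+     seq_len, dim = q.shape
+     memory = torch.zeros(dim, dim, dtype=q.dtype)  # compressed past: sum of K^T V
+     norm = torch.zeros(dim, dtype=q.dtype)         # normalization: sum of keys
+     outputs = []
+     for start in range(0, seq_len, block_size):
+         qb, kb, vb = (x[start:start + block_size] for x in (q, k, v))
+         # 1) Ordinary (quadratic) causal attention, restricted to this block.
+         local = F.scaled_dot_product_attention(
+             qb.unsqueeze(0), kb.unsqueeze(0), vb.unsqueeze(0), is_causal=True
+         ).squeeze(0)
+         # 2) Linear-attention read from the memory of all previous blocks.
+         sigma_q = F.elu(qb) + 1
+         mem_out = (sigma_q @ memory) / (sigma_q @ norm + 1e-6).unsqueeze(-1)
+         # 3) Mix the two paths (InfiniAttention uses a learned gate; fixed 50/50 here).
+         outputs.append(0.5 * local + 0.5 * mem_out)
+         # 4) Fold this block into the recurrent memory before the next iteration.
+         sigma_k = F.elu(kb) + 1
+         memory = memory + sigma_k.T @ vb
+         norm = norm + sigma_k.sum(dim=0)
+     return torch.cat(outputs, dim=0)
+ ```
+
+ In a full model this recurrence would live inside each attention layer; the sketch only conveys the control flow that keeps memory bounded.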
+
+ A lot of the inspiration for our ideas comes from the [Transformer-XL](https://arxiv.org/abs/1901.02860) paper.
+
+ ## Credits
+
+ This was built by:
+
+ - [Mustafa Aljadery](https://www.maxaljadery.com/)
+ - [Siddharth Sharma](https://stanford.edu/~sidshr/)
+ - [Aksh Garg](https://www.linkedin.com/in/aksh-garg/)