lamhieu committed on
Commit
7a58a7d
0 Parent(s):

chore: initialize the app

.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1 @@
1
+ ~
README.md ADDED
@@ -0,0 +1,45 @@
1
+ ---
2
+ title: Ghost 8B Beta (128k)
3
+ emoji: 👻 / 📚
4
+ colorFrom: indigo
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.36.1
8
+ app_file: app.py
9
+ pinned: false
10
+ suggested_hardware: a10g-small
11
+ language:
12
+ - en
13
+ - vi
14
+ - es
15
+ - pt
16
+ - de
17
+ - it
18
+ - fr
19
+ - ko
20
+ - zh
21
+ license: other
22
+ license_name: ghost-llms
23
+ license_link: https://ghost-x.org/ghost-llms-license
24
+ tags:
25
+ - ghost
26
+ ---
27
+
28
+ # ~
29
+
30
+ ### Notes
31
+
32
+ The context-extension source code comes from "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning".
33
+
34
+ See the original source code [here](https://github.com/datamllab/LongLM).
35
+
36
+ ```
37
+ @misc{jin2024llm,
38
+ title={LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning},
39
+ author={Hongye Jin and Xiaotian Han and Jingfeng Yang and Zhimeng Jiang and Zirui Liu and Chia-Yuan Chang and Huiyuan Chen and Xia Hu},
40
+ year={2024},
41
+ eprint={2401.01325},
42
+ archivePrefix={arXiv},
43
+ primaryClass={cs.CL}
44
+ }
45
+ ```
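For reference, a minimal sketch of how the patch added in this commit is wired up, mirroring the call in `app.py` below (the model id, group size 16, and window size 512 are the values used in that file):

```
import torch
from transformers import AutoModelForCausalLM

import SelfExtend  # module added in this commit

# Load the base model; flash_attention_2 is needed for the "flash_attn" path.
model = AutoModelForCausalLM.from_pretrained(
    "lamhieu/ghost-8b-beta-disl-0x5-8k",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Patch the attention forward methods in place (values mirror app.py).
SelfExtend.apply(
    model,
    group_size=16,
    window_size=512,
    enable_flash_attention=True,
    flash_attention_impl="flash_attn",
)
```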
SelfExtend.py ADDED
@@ -0,0 +1,199 @@
1
+ from types import MethodType
2
+ from functools import partial
3
+ import self_extend_patch as SE
4
+
5
+ def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
6
+ """
7
+ This function modifies the method of an instance of a model class.
8
+ It was written partly with the help of ChatGPT.
9
+ It will replace the method with the new method.
10
+ Currently, we only use this function to modify the attention method of a model; it has not been tested beyond that.
11
+
12
+ instance:
13
+ instance of a model to modify.
14
+ target_class_name:
15
+ name of the attention class to modify. E.g. 'LlamaAttention', 'GPTNeoXAttention', etc.
16
+ new_method: new method to replace the original method. E.g. 'self_extend_forward'.
17
+ It should include a parameter 'self' to be bound to the instance.
18
+ """
19
+ target_found = False
20
+ if visited_instances is None:
21
+ visited_instances = set()
22
+ # Unique identifier for the instance (using id() since object's id is unique)
23
+ instance_id = id(instance)
24
+ if instance_id in visited_instances:
25
+ target_found = False
26
+ return target_found
27
+ # Add the instance to the visited_instances set
28
+ visited_instances.add(instance_id)
29
+
30
+ # Check if this instance is of the target class
31
+ if instance.__class__.__name__ == target_class_name:
32
+ bond_method = MethodType(new_method, instance)
33
+ setattr(instance, target_method_name, bond_method)
34
+ target_found = True
35
+ return target_found
36
+ elif hasattr(instance, '__dict__'):
37
+ for attr_name, attr_value in instance.__dict__.items():
38
+ if isinstance(attr_value, object) and not isinstance(attr_value, (list, tuple, dict, set)):
39
+ _found = modify_method_of_instance(attr_value, target_class_name, target_method_name, new_method, visited_instances)
40
+ if _found:
41
+ target_found = True
42
+ elif isinstance(attr_value, (list, tuple)):
43
+ for item in attr_value:
44
+ if isinstance(item, object):
45
+ _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
46
+ if _found:
47
+ target_found = True
48
+ # If attribute value is a dictionary, iterate over its values and recurse
49
+ # E.g., for a ModuleList, its modules are stored in a dictionary: ._modules
50
+ elif isinstance(attr_value, dict):
51
+ for key, value in attr_value.items():
52
+ if isinstance(value, object):
53
+ _found = modify_method_of_instance(value, target_class_name, target_method_name, new_method, visited_instances)
54
+ if _found:
55
+ target_found = True
56
+ # If attribute value is a set, iterate and recurse
57
+ elif isinstance(attr_value, set):
58
+ for item in attr_value:
59
+ if isinstance(item, object):
60
+ _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
61
+ if _found:
62
+ target_found = True
63
+
64
+ return target_found
65
+
66
+
67
+ def apply(loaded_model, group_size, window_size, enable_flash_attention=False, scale_base=-1, flash_attention_impl="triton"):
68
+ '''
69
+ loaded_model:
70
+ model to apply the self-attention extension.
71
+ group_size:
72
+ group size for the self-attention extension.
73
+ window_size:
74
+ window size for the self-attention extension.
75
+ scale_base:
76
+ base for the scale, equal to pretraining length.
77
+ e.g. 4096 for Llama, 8192 for Gemma
78
+
79
+ Two recommended scale factors:
80
+ yarn: https://arxiv.org/abs/2309.00071
81
+ log: https://arxiv.org/abs/2202.12172 ; https://kexue.fm/archives/8823
82
+ This is helpful when retrieving from a long sequence (e.g. a long passkey).
83
+ But on real-world data (e.g. LongBench, LEval), the impact is minor.
84
+
85
+ The results reported in our paper do not use this scale except for long passkey retrieval.
86
+ '''
87
+ arch_name = loaded_model.__class__.__name__
88
+ if 'Llama' in arch_name:
89
+ if enable_flash_attention:
90
+ if flash_attention_impl == "flash_attn":
91
+ self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward,
92
+ group_size_1=group_size,
93
+ group_size_2=window_size,
94
+ scale_base=scale_base)
95
+ modifed_1 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
96
+ modifed_2 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
97
+ print("Using flash_attn flash self_extend!!")
98
+ if (not modifed_1) or (not modifed_2):
99
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
100
+
101
+ elif flash_attention_impl == "triton":
102
+ self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward_triton,
103
+ group_size_1=group_size,
104
+ group_size_2=window_size,
105
+ scale_base=scale_base)
106
+ modifed = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
107
+ print("Using triton flash self_extend!!")
108
+ if (not modifed):
109
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
110
+ else:
111
+ raise Exception(f"Need to set the flash_attention_impl to 'flash_attn' or 'triton'.")
112
+
113
+
114
+ else:
115
+ self_extend_attention_forward = partial(SE.Llama.self_extend_forward,
116
+ group_size_1=group_size,
117
+ group_size_2=window_size,
118
+ scale_base=scale_base)
119
+ # Note: in transformers 4.36 the default attention class is LlamaSdpaAttention, but before 4.36 or in 4.38 it is LlamaAttention
120
+ # print("loaded_model", loaded_model)
121
+ modifed_2 = modify_method_of_instance(loaded_model, "LlamaAttention", "forward", self_extend_attention_forward)
122
+ if not modifed_2:
123
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
124
+ elif 'Mistral' in arch_name:
125
+ # Mistral shares the same architecture as Llama, so the implementations should be interchangeable.
126
+ if enable_flash_attention:
127
+ self_extend_attention_forward = partial(SE.Mistral.flash_self_extend_forward,
128
+ group_size_1=group_size,
129
+ group_size_2=window_size,
130
+ scale_base=scale_base)
131
+ modifed_1 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
132
+ modifed_2 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "forward", self_extend_attention_forward)
133
+ if (not modifed_1) or (not modifed_2):
134
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
135
+ else:
136
+ self_extend_attention_forward = partial(SE.Mistral.self_extend_forward,
137
+ group_size_1=group_size,
138
+ group_size_2=window_size,
139
+ scale_base=scale_base)
140
+ modifed_2 = modify_method_of_instance(loaded_model, "MistralAttention", "forward", self_extend_attention_forward)
141
+ if not modifed_2:
142
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
143
+ elif 'Gemma' in arch_name:
144
+ if enable_flash_attention:
145
+ self_extend_attention_forward = partial(SE.Gemma.flash_self_extend_forward,
146
+ group_size_1=group_size,
147
+ group_size_2=window_size,
148
+ scale_base=scale_base)
149
+ modifed_1 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
150
+ modifed_2 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "forward", self_extend_attention_forward)
151
+ if (not modifed_1) or (not modifed_2):
152
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
153
+ else:
154
+ self_extend_attention_forward = partial(SE.Gemma.self_extend_forward,
155
+ group_size_1=group_size,
156
+ group_size_2=window_size,
157
+ scale_base=scale_base)
158
+ modifed_2= modify_method_of_instance(loaded_model, "GemmaAttention", "forward", self_extend_attention_forward)
159
+ if not modifed_2:
160
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
161
+ elif 'Qwen2' in arch_name:
162
+ if enable_flash_attention:
163
+ self_extend_attention_forward = partial(SE.Qwen2.flash_self_extend_forward,
164
+ group_size_1=group_size,
165
+ group_size_2=window_size,
166
+ scale_base=scale_base)
167
+ modifed_1 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
168
+ modifed_2 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "forward", self_extend_attention_forward)
169
+ if (not modifed_1) or (not modifed_2):
170
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
171
+ else:
172
+ self_extend_attention_forward = partial(SE.Qwen2.self_extend_forward,
173
+ group_size_1=group_size,
174
+ group_size_2=window_size,
175
+ scale_base=scale_base)
176
+ modifed_2 = modify_method_of_instance(loaded_model, "Qwen2Attention", "forward", self_extend_attention_forward)
177
+ if not modifed_2:
178
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
179
+ elif 'Phi' in arch_name:
180
+ if enable_flash_attention:
181
+ self_extend_attention_forward = partial(SE.Phi.flash_self_extend_forward,
182
+ group_size_1=group_size,
183
+ group_size_2=window_size,
184
+ scale_base=scale_base)
185
+ modifed_1 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
186
+ modifed_2 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "forward", self_extend_attention_forward)
187
+ if (not modifed_1) or (not modifed_2):
188
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
189
+ else:
190
+ self_extend_attention_forward = partial(SE.Phi.self_extend_forward,
191
+ group_size_1=group_size,
192
+ group_size_2=window_size,
193
+ scale_base=scale_base)
194
+ modifed_2 = modify_method_of_instance(loaded_model, "PhiAttention", "forward", self_extend_attention_forward)
195
+ if not modifed_2:
196
+ raise Exception(f"Failed to modify the attention method of {arch_name}")
197
+ else:
198
+ raise NotImplementedError
199
+
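`modify_method_of_instance` above works by binding a plain function to one object with `types.MethodType` instead of editing the class. A small, self-contained illustration of that technique (toy classes here, not the transformers attention modules):

```
from types import MethodType


class Attention:
    def forward(self, x):
        return f"original forward({x})"


class Model:
    def __init__(self):
        self.attn = Attention()


def patched_forward(self, x):
    # 'self' is bound to the specific Attention instance at patch time.
    return f"patched forward({x})"


model = Model()
# Replace the method on this one instance only; the class stays untouched.
model.attn.forward = MethodType(patched_forward, model.attn)

print(model.attn.forward("hi"))   # patched forward(hi)
print(Attention().forward("hi"))  # original forward(hi)
```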
app.py ADDED
@@ -0,0 +1,388 @@
1
+ # pylint: skip-file
2
+
3
+ import subprocess
4
+
5
+ subprocess.run(
6
+ f"pip install flash-attn --no-build-isolation",
7
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
8
+ shell=True,
9
+ )
10
+
11
+ import os
12
+ from threading import Thread
13
+ from typing import Iterator
14
+
15
+ import gradio as gr
16
+ import spaces
17
+ import torch
18
+ import SelfExtend
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
20
+
21
+
22
+ MAX_MAX_NEW_TOKENS = 4096
23
+ DEFAULT_MAX_NEW_TOKENS = 1536
24
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "123392"))
25
+
26
+ DESCRIPTION = """\
27
+ # Playground with Ghost 8B Beta (p)
28
+
29
+ **Ghost 8B Beta** is a large language model developed with goals that include excellent multilingual support, superior knowledge capabilities, and cost-effectiveness. The model comes in two context length versions, 8k and 128k, along with multilingual function tools support by default.
30
+
31
+ The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.
32
+
33
+ 📋 Note: current model version is "disl-0x5" (10 Jul 2024), context length 128k (123392 tokens) and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
34
+ """
35
+
36
+
37
+ PLACEHOLDER = """
38
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
39
+ <h1 style="font-size: 26px; margin-bottom: 2px; opacity: 0.20;">👻 Ghost 8B Beta</h1>
40
+ <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.10;">Ask and share whatever you want ~</p>
41
+ </div>
42
+ """
43
+
44
+ LICENSE = """
45
+ <p/>
46
+
47
+ ---
48
+ Ghost 8B Beta may give inaccurate information, including information about people, so please verify Ghost 8B Beta's answers. [Ghost 8B Beta](https://ghost-x.org/docs/models/ghost-8b-beta/) by [Ghost X](https://ghost-x.org).
49
+ """
50
+
51
+ EXAMPLES = [
52
+ [
53
+ "What is the significance of the Higgs boson in the Standard Model of particle physics?"
54
+ ],
55
+ [
56
+ "Qu'est-ce que l'effet fondateur et comment influence-t-il la diversité génétique d'une population?"
57
+ ],
58
+ ["Qual è il principio di Le Chatelier e come si applica agli equilibri chimici?"],
59
+ [
60
+ "¿Qué es una supernova y cuál es su importancia en la formación de elementos pesados en el universo?"
61
+ ],
62
+ [
63
+ "Qual é a definição formal de uma integral de linha e como é utilizada em física?"
64
+ ],
65
+ [
66
+ "Was versteht man unter dem Moho-Diskontinuität und welche Bedeutung hat sie für das Verständnis der Erdkruste?"
67
+ ],
68
+ [
69
+ "Hiện tượng nhà kính là gì và nó ảnh hưởng như thế nào đến biến đổi khí hậu toàn cầu?"
70
+ ],
71
+ [
72
+ "알고리즘의 시간 복잡도가 중요한 이유는 무엇이며, 시간 복잡도를 어떻게 분석하나요?"
73
+ ],
74
+ ["什么是CRISPR-Cas9基因编辑技术,它在现代生物学研究中的作用是什么?"],
75
+ [
76
+ "Create a Python function that takes a list of integers and returns the list sorted in ascending order without using the built-in sort or sorted functions."
77
+ ],
78
+ [
79
+ "Écrivez une fonction en C++ qui trouve le plus long sous-tableau contigu avec une somme égale à zéro."
80
+ ],
81
+ [
82
+ "Scrivi una funzione in Java che calcola il fattoriale di un numero utilizzando la ricorsione."
83
+ ],
84
+ [
85
+ "Desarrolla una función en JavaScript que determine si una cadena de texto es un palíndromo, ignorando espacios y signos de puntuación."
86
+ ],
87
+ ["Implemente uma função em C# que verifique se uma matriz quadrada é simétrica."],
88
+ [
89
+ "Schreiben Sie eine Funktion in Swift, die eine gegebene Zeichenfolge in umgekehrter Reihenfolge zurückgibt, ohne integrierte Funktionen zu verwenden."
90
+ ],
91
+ [
92
+ "Viết một hàm trong PHP để tìm tất cả các số nguyên tố trong một khoảng cho trước."
93
+ ],
94
+ [
95
+ "파이썬을 사용하여 주어진 이진 트리가 이진 탐색 트리인지 확인하는 함수를 작성하십시오."
96
+ ],
97
+ [
98
+ "用 Go 语言编写一个函数,计算给定字符串中每个字符出现的次数,并返回一个包含字符及其出现次数的映射。"
99
+ ],
100
+ [
101
+ "Can you help me design a detailed project plan for developing a machine learning model for predicting stock prices?"
102
+ ],
103
+ [
104
+ "Pouvez-vous m'aider à organiser un emploi du temps hebdomadaire pour maximiser la productivité de mon équipe de développement logiciel?"
105
+ ],
106
+ [
107
+ "Puoi aiutarmi a creare un piano di sviluppo per un'applicazione mobile che gestisce le prenotazioni di ristoranti?"
108
+ ],
109
+ [
110
+ "¿Podrías ayudarme a elaborar un plan detallado para la implementación de un sistema de gestión de contenido (CMS) en una empresa mediana?"
111
+ ],
112
+ [
113
+ "Você pode me ajudar a planejar uma estratégia de desenvolvimento para um sistema de comércio eletrônico escalável?"
114
+ ],
115
+ [
116
+ "Können Sie mir helfen, einen detaillierten Zeitplan für die Implementierung eines neuen ERP-Systems in unserem Unternehmen zu erstellen?"
117
+ ],
118
+ [
119
+ "Bạn có thể giúp tôi xây dựng một kế hoạch phát triển chi tiết cho dự án xây dựng hệ thống quản lý chuỗi cung ứng không?"
120
+ ],
121
+ [
122
+ "신경망 기반 이미지 인식 모델 개발을 위한 세부 프로젝트 계획을 세우는 데 도움을 줄 수 있나요?"
123
+ ],
124
+ ["你能帮我制定一个详细的开发计划,用于创建一个基于区块链的分布式账本系统吗?"],
125
+ [
126
+ "Prove that the sum of the squares of any two sides of a right triangle is equal to the square of the hypotenuse."
127
+ ],
128
+ [
129
+ "Calculez la force gravitationnelle entre deux masses de 10 kg chacune séparées par une distance de 1 mètre."
130
+ ],
131
+ [
132
+ "Determina la formula molecolare di un composto che contiene il 40% di carbonio, il 6.67% di idrogeno e il 53.33% di ossigeno in massa."
133
+ ],
134
+ [
135
+ "Explica la teoría del ciclo económico de Schumpeter y cómo se aplica a la economía moderna."
136
+ ],
137
+ [
138
+ "Calcule a energia potencial gravitacional de um objeto de 5 kg a uma altura de 10 metros acima do solo (g = 9,8 m/s²)."
139
+ ],
140
+ [
141
+ "Beweisen Sie, dass jede Primzahl der Form 4k+1 als Summe zweier Quadrate geschrieben werden kann."
142
+ ],
143
+ [
144
+ "Tính nồng độ mol của dung dịch H₂SO₄ khi hoà tan 98 gam H₂SO₄ vào nước để được 1 lít dung dịch."
145
+ ],
146
+ ["케인스 경제학의 핵심 개념과 그것이 현대 경제 정책에 미치는 영향을 설명하십시오."],
147
+ ["计算一个质量为2 kg的物体在3米高处的重力势能(g = 9.8 m/s²)。"],
148
+ [
149
+ 'Identify the author of a novel that features a dystopian society where "Big Brother" watches over its citizens and the protagonist works for the Ministry of Truth.'
150
+ ],
151
+ [
152
+ "Quel est le seul mammifère capable de voler activement, souvent associé à la nuit et capable d'écholocalisation?"
153
+ ],
154
+ [
155
+ "Qual è l'opera letteraria italiana che narra il viaggio immaginario di un poeta attraverso Inferno, Purgatorio e Paradiso, guidato da Virgilio e Beatrice?"
156
+ ],
157
+ [
158
+ "¿Qué insecto es conocido por su organización social compleja, su capacidad para producir miel y su comunicación mediante la danza?"
159
+ ],
160
+ [
161
+ "Qual é o fenômeno atmosférico que ocorre quando uma massa de ar quente se encontra com uma massa de ar frio, resultando em uma violenta tempestade giratória?"
162
+ ],
163
+ [
164
+ "Welches literarische Werk beschreibt die Geschichte eines jungen Mädchens, das durch einen Kaninchenbau in eine fantastische Welt voller skurriler Charaktere fällt?"
165
+ ],
166
+ [
167
+ "Động vật nào có thể tái sinh toàn bộ cơ thể từ một mảnh nhỏ của chính nó, thường sống dưới nước và có thể có nhiều xúc tu?"
168
+ ],
169
+ [
170
+ "어떤 자연 현상은 태양빛이 대기 중의 물방울에 반사되고 굴절되어 발생하며, 하늘에 나타나는 여러 색깔의 아치 형태를 띠나요?"
171
+ ],
172
+ ["这部文学作品讲述了一位绅士和他的侍从的冒险故事,他们在"],
173
+ [
174
+ "Can you derive the Euler-Lagrange equation from the principle of stationary action in classical mechanics?"
175
+ ],
176
+ [
177
+ "Expliquez la notion de « différence ontologique » chez Martin Heidegger et son importance pour la phénoménologie."
178
+ ],
179
+ [
180
+ "Qual è il significato simbolico del colore blu nei dipinti di Giotto di Bondone durante il Rinascimento?"
181
+ ],
182
+ [
183
+ "¿Cómo afecta el cambio de código a la estructura gramatical en comunidades bilingües de habla español-inglés?"
184
+ ],
185
+ [
186
+ "Qual é o impacto da política monetária não convencional no controle da inflação durante uma crise econômica?"
187
+ ],
188
+ [
189
+ "Erklären Sie den Unterschied zwischen deterministischen und nicht-deterministischen endlichen Automaten und ihre Anwendungsbereiche."
190
+ ],
191
+ [
192
+ "Giải thích cơ chế của quá trình phiên mã ngược (reverse transcription) và tầm quan trọng của nó trong nghiên cứu HIV/AIDS."
193
+ ],
194
+ ["조선시대 성리학이 한국 사회와 문화에 미친 영향을 설명하세요."],
195
+ ["如何解释量子纠缠现象,以及它在量子计算中的潜在应用?"],
196
+ [
197
+ "How can you design a daily schedule that maximizes productivity for a remote worker who has multiple meetings and project deadlines?"
198
+ ],
199
+ [
200
+ "Quels sont les meilleures stratégies pour gérer les conflits au sein d'une équipe multiculturelle travaillant sur un projet commun?"
201
+ ],
202
+ [
203
+ "Quali sono i migliori consigli per mantenere un equilibrio tra vita professionale e vita privata in un ambiente lavorativo stressante?"
204
+ ],
205
+ [
206
+ "¿Cómo se puede elaborar un plan financiero personal efectivo que incluya ahorro para la jubilación, inversión y manejo de deudas?"
207
+ ],
208
+ [
209
+ "Quais são as melhores práticas para implementar metodologias ágeis em uma equipe de desenvolvimento de software?"
210
+ ],
211
+ [
212
+ "Welche Strategien können verwendet werden, um ein starkes berufliches Netzwerk aufzubauen und zu pflegen, insbesondere in der Tech-Branche?"
213
+ ],
214
+ [
215
+ "Những bước nào cần thiết để xây dựng một lộ trình phát triển sự nghiệp bền vững trong lĩnh vực công nghệ thông tin?"
216
+ ],
217
+ ["프로젝트의 범위 변동을 효과적으로 관리하기 위한 최고의 방법은 무엇인가요?"],
218
+ ["在快速变化的职场环境中,如何有效地实现工作与生活的平衡?"],
219
+ [
220
+ "Write an argumentative essay discussing the pros and cons of artificial intelligence in the workplace, including potential ethical concerns."
221
+ ],
222
+ [
223
+ "Analysez les impacts sociaux et économiques de la digitalisation sur les petites entreprises en France."
224
+ ],
225
+ [
226
+ "Scrivi un'email formale al direttore di una rivista per proporre un articolo sulla sostenibilità ambientale nelle città italiane."
227
+ ],
228
+ [
229
+ "Elabora un informe detallado sobre los efectos del cambio climático en la biodiversidad de la región amazónica."
230
+ ],
231
+ [
232
+ "Analise criticamente os principais pontos abordados no relatório anual do Banco Mundial sobre a pobreza global."
233
+ ],
234
+ [
235
+ "Erstellen Sie eine technische Dokumentation für die Implementierung eines neuen Software-Features in einer bestehenden Anwendung."
236
+ ],
237
+ [
238
+ "Viết một bài luận phân tích về tác động của cuộc cách mạng công nghiệp 4.0 đối với thị trường lao động Việt Nam."
239
+ ],
240
+ [
241
+ "인공지능의 윤리적 문제에 대한 연구 논문을 작성하고, 다양한 사례를 통해 그 영향을 분석하세요."
242
+ ],
243
+ ["分析鲁迅的小说《阿Q正传》中反映的中国社会问题和作者的批判态度。"],
244
+ ]
245
+
246
+ if not torch.cuda.is_available():
247
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
248
+
249
+
250
+ if torch.cuda.is_available():
251
+ model_id = "lamhieu/ghost-8b-beta-disl-0x5-8k"
252
+ model_tk = os.getenv("HF_TOKEN", None)
253
+ model = AutoModelForCausalLM.from_pretrained(
254
+ model_id,
255
+ device_map="auto",
256
+ torch_dtype=torch.bfloat16,
257
+ attn_implementation="flash_attention_2",
258
+ trust_remote_code=True,
259
+ token=model_tk,
260
+ )
261
+ tokenizer = AutoTokenizer.from_pretrained(
262
+ model_id,
263
+ trust_remote_code=True,
264
+ token=model_tk,
265
+ )
266
+ SelfExtend.apply(
267
+ model,
268
+ group_size=16,
269
+ window_size=512,
270
+ enable_flash_attention=True,
271
+ flash_attention_impl="flash_attn",
272
+ )
273
+ model.generation_config.max_length = 123392
274
+
275
+
276
+ @spaces.GPU(duration=120)
277
+ def generate(
278
+ message: str,
279
+ chat_history: list[tuple[str, str]],
280
+ system_prompt: str,
281
+ max_new_tokens: int = 1536,
282
+ temperature: float = 0.4,
283
+ top_p: float = 0.95,
284
+ top_k: int = 50,
285
+ repetition_penalty: float = 1.0,
286
+ ) -> Iterator[str]:
287
+ conversation = []
288
+ if system_prompt:
289
+ conversation.append({"role": "system", "content": system_prompt})
290
+ for user, assistant in chat_history:
291
+ conversation.extend(
292
+ [
293
+ {"role": "user", "content": user},
294
+ {"role": "assistant", "content": assistant},
295
+ ]
296
+ )
297
+ conversation.append({"role": "user", "content": message})
298
+
299
+ input_ids = tokenizer.apply_chat_template(
300
+ conversation, add_generation_prompt=True, return_tensors="pt"
301
+ )
302
+ input_ids = input_ids.to(model.device)
303
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
304
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
305
+ gr.Warning(
306
+ f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
307
+ )
308
+
309
+ streamer = TextIteratorStreamer(
310
+ tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
311
+ )
312
+ generate_kwargs = dict(
313
+ input_ids=input_ids,
314
+ streamer=streamer,
315
+ max_new_tokens=max_new_tokens,
316
+ do_sample=True,
317
+ top_p=top_p,
318
+ top_k=top_k,
319
+ temperature=temperature,
320
+ repetition_penalty=repetition_penalty,
321
+ )
322
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
323
+ t.start()
324
+
325
+ outputs = []
326
+ for text in streamer:
327
+ outputs.append(text)
328
+ yield "".join(outputs)
329
+
330
+
331
+ chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta")
332
+
333
+ chat_interface = gr.ChatInterface(
334
+ fn=generate,
335
+ chatbot=chatbot,
336
+ fill_height=True,
337
+ additional_inputs=[
338
+ gr.Textbox(label="System prompt", lines=6),
339
+ gr.Slider(
340
+ label="Max new tokens",
341
+ minimum=1,
342
+ maximum=MAX_MAX_NEW_TOKENS,
343
+ step=1,
344
+ value=DEFAULT_MAX_NEW_TOKENS,
345
+ ),
346
+ gr.Slider(
347
+ label="Temperature",
348
+ minimum=0.1,
349
+ maximum=2.0,
350
+ step=0.1,
351
+ value=0.4,
352
+ ),
353
+ gr.Slider(
354
+ label="Top-p (nucleus sampling)",
355
+ minimum=0.05,
356
+ maximum=1.0,
357
+ step=0.05,
358
+ value=0.95,
359
+ ),
360
+ gr.Slider(
361
+ label="Top-k",
362
+ minimum=1,
363
+ maximum=100,
364
+ step=1,
365
+ value=50,
366
+ ),
367
+ gr.Slider(
368
+ label="Repetition penalty",
369
+ minimum=1.0,
370
+ maximum=2.0,
371
+ step=0.05,
372
+ value=1.0,
373
+ ),
374
+ ],
375
+ stop_btn="Stop",
376
+ cache_examples=False,
377
+ examples=EXAMPLES,
378
+ examples_per_page=9,
379
+ )
380
+
381
+ with gr.Blocks(fill_height=True, css="style.css") as demo:
382
+ gr.Markdown(DESCRIPTION)
383
+ chat_interface.render()
384
+ gr.Markdown(LICENSE)
385
+
386
+ if __name__ == "__main__":
387
+ demo.queue(max_size=20).launch(share=True)
388
+ # demo.launch(share=True)
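Once the Space is up, it can also be queried programmatically. A sketch using `gradio_client`; the Space id and the `/chat` endpoint name (Gradio's usual default for `gr.ChatInterface`) are assumptions here, and the positional arguments follow the order of `additional_inputs` above:

```
from gradio_client import Client

client = Client("lamhieu/ghost-8b-beta-8k")  # hypothetical Space id

result = client.predict(
    "What is the significance of the Higgs boson?",  # message
    "",       # system prompt
    1536,     # max new tokens
    0.4,      # temperature
    0.95,     # top-p
    50,       # top-k
    1.0,      # repetition penalty
    api_name="/chat",  # assumed default endpoint for gr.ChatInterface
)
print(result)
```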
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ accelerate==0.30.1
2
+ bitsandbytes==0.43.1
3
+ gradio==4.37.2
4
+ scipy==1.13.0
5
+ sentencepiece==0.2.0
6
+ spaces==0.28.3
7
+ torch==2.0.0
8
+ transformers==4.41.0
self_extend_patch/Llama.py ADDED
@@ -0,0 +1,482 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import math
+ import warnings
5
+ from typing import Optional, Tuple
6
+ import torch.nn.functional as F
7
+ from transformers.cache_utils import Cache
8
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
9
+ from .selfextend_flash_attn import self_extend_flash_forward
10
+
11
+
12
+
13
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
14
+ """
15
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
16
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
17
+ """
18
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
19
+ if n_rep == 1:
20
+ return hidden_states
21
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
22
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
23
+
24
+ def rotate_half(x):
25
+ """Rotates half the hidden dims of the input."""
26
+ x1 = x[..., : x.shape[-1] // 2]
27
+ x2 = x[..., x.shape[-1] // 2 :]
28
+ return torch.cat((-x2, x1), dim=-1)
29
+
30
+
31
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
32
+ """Applies Rotary Position Embedding to the query and key tensors.
33
+
34
+ Args:
35
+ q (`torch.Tensor`): The query tensor.
36
+ k (`torch.Tensor`): The key tensor.
37
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
38
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
39
+ position_ids (`torch.Tensor`, *optional*):
40
+ Deprecated and unused.
41
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
42
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
43
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
44
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
45
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
46
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
47
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
48
+ Returns:
49
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
50
+ """
51
+ cos = cos.unsqueeze(unsqueeze_dim)
52
+ sin = sin.unsqueeze(unsqueeze_dim)
53
+ q_embed = (q * cos) + (rotate_half(q) * sin) if not q is None else None
54
+ k_embed = (k * cos) + (rotate_half(k) * sin) if not k is None else None
55
+ return q_embed, k_embed
56
+
57
+
58
+
59
+
60
+ def self_extend_forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ attention_mask: Optional[torch.Tensor] = None,
64
+ position_ids: Optional[torch.LongTensor] = None,
65
+ past_key_value: Optional[Cache] = None,
66
+ output_attentions: bool = False,
67
+ use_cache: bool = False,
68
+ cache_position: Optional[torch.LongTensor] = None,
69
+ group_size_1: Optional[float] = 8,
70
+ group_size_2: Optional[float] = 1024,
71
+ scale_base: Optional[int] = -1,
72
+ **kwargs,
73
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
74
+ if "padding_mask" in kwargs:
75
+ warnings.warn(
76
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
77
+ )
78
+
79
+ bsz, q_len, _ = hidden_states.size()
80
+
81
+
82
+ if self.config.pretraining_tp > 1:
83
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
84
+ query_slices = self.q_proj.weight.split(
85
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
86
+ )
87
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
88
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
89
+
90
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
91
+ query_states = torch.cat(query_states, dim=-1)
92
+
93
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
94
+ key_states = torch.cat(key_states, dim=-1)
95
+
96
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
97
+ value_states = torch.cat(value_states, dim=-1)
98
+
99
+ else:
100
+ query_states = self.q_proj(hidden_states)
101
+ key_states = self.k_proj(hidden_states)
102
+ value_states = self.v_proj(hidden_states)
103
+
104
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
105
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
106
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
107
+
108
+ if scale_base > 0:
109
+ scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype) # log scale
110
+ #scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype) # Yarn scale
111
+ else:
112
+ scaled_query = query_states
113
+
114
+ past_key_value = getattr(self, "past_key_value", past_key_value)
115
+ if past_key_value is not None:
116
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
117
+ cache_kwargs = {"cache_position": cache_position}
118
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
119
+ kv_seq_len = key_states.shape[-2]
120
+
121
+ query_position = position_ids
122
+ key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len) # only consider bsz=1 for now.
123
+
124
+
125
+
126
+ neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)#, seq_len=None)
127
+ neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)#, seq_len=None)
128
+
129
+
130
+ _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2 # in case the smallest grouped q position, g2 - g2//g1, would exceed the max position
131
+ group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 // group_size_1
132
+ group_key_position = key_position // group_size_1
133
+
134
+ group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position)#, seq_len=None)
135
+ group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position)#, seq_len=None)
136
+
137
+
138
+
139
+ neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
140
+ _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
141
+ group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
142
+ _, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)
143
+
144
+
145
+
146
+ neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
147
+ group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
148
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
149
+
150
+
151
+
152
+ neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
153
+ group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
154
+
155
+
156
+ if attention_mask is not None: # no matter the length, we just slice it
157
+ if cache_position is not None:
158
+ causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
159
+ else:
160
+ causal_mask = attention_mask
161
+ group_attn_weights = group_attn_weights + causal_mask
162
+ neighbor_attn_weights = neighbor_attn_weights + causal_mask
163
+
164
+ if q_len == 1:
165
+ neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
166
+ neighbor_attention_mask[:, -group_size_2:] = 1
167
+ elif q_len == kv_seq_len:
168
+ neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
169
+ neighbor_attention_mask = torch.tril(neighbor_attention_mask)
170
+ if q_len-group_size_2 > 0:
171
+ group_attention_mask = torch.tril(torch.ones((q_len-group_size_2, kv_seq_len-group_size_2), device=group_attn_weights.device))
172
+ neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
173
+ else:
174
+ raise ValueError("q_len should be 1 or seq_len.")
175
+
176
+ neighbor_attention_mask = neighbor_attention_mask.bool()
177
+ attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
178
+
179
+ # upcast attention to fp32
180
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
181
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
182
+ attn_output = torch.matmul(attn_weights, value_states)
183
+
184
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
185
+ raise ValueError(
186
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
187
+ f" {attn_output.size()}"
188
+ )
189
+
190
+
191
+ attn_output = attn_output.transpose(1, 2).contiguous()
192
+
193
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
194
+
195
+ if self.config.pretraining_tp > 1:
196
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
197
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
198
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
199
+ else:
200
+ attn_output = self.o_proj(attn_output)
201
+
202
+ if not output_attentions:
203
+ attn_weights = None
204
+
205
+ return attn_output, attn_weights, past_key_value
206
+
207
+ def flash_self_extend_forward(
208
+ self,
209
+ hidden_states: torch.Tensor,
210
+ attention_mask: Optional[torch.Tensor] = None,
211
+ position_ids: Optional[torch.LongTensor] = None,
212
+ past_key_value: Optional[Cache] = None,
213
+ output_attentions: bool = False,
214
+ use_cache: bool = False,
215
+ group_size_1: Optional[float] = 8,
216
+ group_size_2: Optional[float] = 1024,
217
+ scale_base: Optional[int] = -1,
218
+ cache_position: Optional[torch.LongTensor] = None,
219
+ **kwargs,
220
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
221
+ """
222
+ Requires transformers >= 4.38.2 and flash_attn >= 2.5.6.
223
+ a. Only supports a causal mask.
224
+ b. Does not support attention_mask.
225
+ c. Never tested with batch size > 1.
226
+ d. Only supports q_len = 1 or q_len = seq_len.
227
+ """
228
+ if "padding_mask" in kwargs:
229
+ warnings.warn(
230
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
231
+ )
232
+ attention_mask = kwargs.pop("padding_mask")
233
+
234
+ bsz, q_len, _ = hidden_states.size()
235
+
236
+ query_states = self.q_proj(hidden_states)
237
+ key_states = self.k_proj(hidden_states)
238
+ value_states = self.v_proj(hidden_states)
239
+
240
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
241
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
242
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
243
+
244
+ if scale_base > 0:
245
+ scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype) # log scale
246
+ #scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype) # Yarn scale
247
+ else:
248
+ scaled_query = query_states
249
+
250
+ past_key_value = getattr(self, "past_key_value", past_key_value)
251
+ if past_key_value is not None:
252
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
253
+ cache_kwargs = {"cache_position": cache_position}
254
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
255
+ kv_seq_len = key_states.shape[-2]
256
+
257
+ query_position = position_ids
258
+ # only consider bsz=1 for now.
259
+ key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len)
260
+ attn_dropout = self.config.attention_dropout if self.training else 0.0
261
+ if q_len == 1:
262
+ # We implement the case q_len == 1 separately, by manipulating positions.
263
+ # because our flash implementation did not work for the decoding stage at release time.
264
+
265
+ neighbor_key_position = position_ids[:, -1] - key_position
266
+ _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
267
+ group_key_position = position_ids[:, -1]//group_size_1 - key_position//group_size_1 + (_re_group_size_2 - _re_group_size_2//group_size_1)
268
+ decode_key_position = torch.cat([group_key_position[:, :-group_size_2], neighbor_key_position[:,-group_size_2:]], dim=1)
269
+
270
+ decode_k_cos, decode_k_sin = self.rotary_emb(value_states, decode_key_position)#, seq_len=None)
271
+ #import pdb; pdb.set_trace()
272
+ #neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, cos, sin, query_position_ids)
273
+ decode_query_states = scaled_query.transpose(1,2).contiguous() # position 0: cos 0 = 1, sin 0 = 0
274
+ _, decode_key_states = apply_rotary_pos_emb(None, key_states, decode_k_cos, -decode_k_sin, decode_key_position)
275
+
276
+ decode_key_states = repeat_kv(decode_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
277
+ decode_value_states = repeat_kv(value_states, self.num_key_value_groups).transpose(1, 2).contiguous()
278
+
279
+ attn_output = flash_attn_func(decode_query_states,
280
+ decode_key_states,
281
+ decode_value_states,
282
+ attn_dropout,
283
+ softmax_scale=None,
284
+ causal=True)
285
+ elif q_len == kv_seq_len:
286
+ # set correct position_ids & apply RoPE.
287
+ neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)#, seq_len=None)
288
+ neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)#, seq_len=None)
289
+
290
+ _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2 # in case that, the smallest q position, g2-g2//g1 exceed the max position
291
+ group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 / group_size_1
292
+ group_key_position = key_position // group_size_1
293
+
294
+ group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position)#, seq_len=None)
295
+ group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position)#, seq_len=None)
296
+
297
+ neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
298
+ _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
299
+ group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
300
+ _, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)
301
+
302
+
303
+ neighbor_query_states = neighbor_query_states.transpose(1, 2).contiguous()
304
+ neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
305
+ group_query_states = group_query_states.transpose(1, 2).contiguous()
306
+ group_key_states = repeat_kv(group_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
307
+ value_states = repeat_kv(value_states, self.num_key_value_groups).transpose(1, 2).contiguous()
308
+
309
+ attn_output = self_extend_flash_forward(self,
310
+ query_position,
311
+ group_size_2,
312
+ neighbor_query_states,
313
+ neighbor_key_states,
314
+ group_query_states,
315
+ group_key_states,
316
+ value_states,
317
+ attention_mask,
318
+ bsz,
319
+ q_len,
320
+ kv_seq_len,
321
+ attn_dropout,
322
+ )
323
+ else:
324
+ raise ValueError("q_len should be 1 or seq_len.")
325
+
326
+ attn_output = attn_output.contiguous()
327
+ attn_output = attn_output.view(bsz, q_len, -1).contiguous()
328
+ attn_output = self.o_proj(attn_output)
329
+
330
+ if not output_attentions:
331
+ attn_weights = None
332
+ return attn_output, attn_weights, past_key_value
333
+
334
+
335
+
336
+ def lm_infinite_forward(
337
+ self,
338
+ hidden_states: torch.Tensor,
339
+ attention_mask: Optional[torch.Tensor] = None,
340
+ position_ids: Optional[torch.LongTensor] = None,
341
+ past_key_value: Optional[Cache] = None,
342
+ output_attentions: bool = False,
343
+ use_cache: bool = False,
344
+ cache_position: Optional[torch.LongTensor] = None,
345
+ group_size_1: Optional[float] = 8,
346
+ group_size_2: Optional[float] = 1024,
347
+ initial_num: Optional[int] = 1,
348
+ scale_base: Optional[int] = -1,
349
+ **kwargs,
350
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
351
+ if "padding_mask" in kwargs:
352
+ warnings.warn(
353
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
354
+ )
355
+
356
+ bsz, q_len, _ = hidden_states.size()
357
+
358
+
359
+ if self.config.pretraining_tp > 1:
360
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
361
+ query_slices = self.q_proj.weight.split(
362
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
363
+ )
364
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
365
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
366
+
367
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
368
+ query_states = torch.cat(query_states, dim=-1)
369
+
370
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
371
+ key_states = torch.cat(key_states, dim=-1)
372
+
373
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
374
+ value_states = torch.cat(value_states, dim=-1)
375
+
376
+ else:
377
+ query_states = self.q_proj(hidden_states)
378
+ key_states = self.k_proj(hidden_states)
379
+ value_states = self.v_proj(hidden_states)
380
+
381
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
382
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
383
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
384
+
385
+ if scale_base > 0:
386
+ scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype) # log scale
387
+ #scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype) # Yarn scale
388
+ else:
389
+ scaled_query = query_states
390
+
391
+ past_key_value = getattr(self, "past_key_value", past_key_value)
392
+ if past_key_value is not None:
393
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
394
+ cache_kwargs = {"cache_position": cache_position}
395
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
396
+ kv_seq_len = key_states.shape[-2]
397
+
398
+ query_position = position_ids
399
+ key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len) # only consider bsz=1 for now.
400
+
401
+
402
+
403
+ neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)#, seq_len=None)
404
+ neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)#, seq_len=None)
405
+
406
+
407
+ _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2 # in case the smallest grouped q position, g2 - g2//g1, would exceed the max position
408
+ group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 // group_size_1
409
+ group_key_position = key_position // group_size_1
410
+
411
+ group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position)#, seq_len=None)
412
+ group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position)#, seq_len=None)
413
+
414
+
415
+
416
+ neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
417
+ _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
418
+ group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
419
+ _, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)
420
+
421
+
422
+
423
+ neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
424
+ group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
425
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
426
+
427
+
428
+
429
+ neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
430
+ group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
431
+
432
+
433
+ if attention_mask is not None: # no matter the length, we just slice it
434
+ if cache_position is not None:
435
+ causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
436
+ else:
437
+ causal_mask = attention_mask
438
+ group_attn_weights = group_attn_weights + causal_mask
439
+ neighbor_attn_weights = neighbor_attn_weights + causal_mask
440
+
441
+ if q_len == 1:
442
+ neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
443
+ neighbor_attention_mask[:, -group_size_2:] = 1
444
+ elif q_len == kv_seq_len:
445
+ neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
446
+ neighbor_attention_mask = torch.tril(neighbor_attention_mask)
447
+ if q_len-group_size_2 > 0:
448
+ group_attention_mask = torch.tril(torch.ones((q_len-group_size_2, kv_seq_len-group_size_2), device=group_attn_weights.device))
449
+ neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
450
+ else:
451
+ raise ValueError("q_len should be 1 or seq_len.")
452
+
453
+ neighbor_attention_mask = neighbor_attention_mask.bool()
454
+ attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
455
+
456
+ # upcast attention to fp32
457
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
458
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
459
+ attn_output = torch.matmul(attn_weights, value_states)
460
+
461
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
462
+ raise ValueError(
463
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
464
+ f" {attn_output.size()}"
465
+ )
466
+
467
+
468
+ attn_output = attn_output.transpose(1, 2).contiguous()
469
+
470
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
471
+
472
+ if self.config.pretraining_tp > 1:
473
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
474
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
475
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
476
+ else:
477
+ attn_output = self.o_proj(attn_output)
478
+
479
+ if not output_attentions:
480
+ attn_weights = None
481
+
482
+ return attn_output, attn_weights, past_key_value
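The heart of `self_extend_forward` above is the position remapping: keys inside the neighbor window keep their exact relative distance, while keys outside it are merged into coarse groups, with an offset so the two regions line up at the window boundary. A short sketch of that mapping in isolation, using the same formula as the code above (toy sizes; `group_size_1 = 4` and `group_size_2 = 8` are illustrative, not the defaults):

```
import torch

group_size_1 = 4   # group size (illustrative)
group_size_2 = 8   # neighbor window (illustrative)
seq_len = 16

query_position = torch.arange(seq_len).view(1, -1)
key_position = torch.arange(seq_len).view(1, -1)

# Same offset trick as in self_extend_forward: shift grouped query positions so
# grouped distances continue from the edge of the neighbor window.
_re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2
group_query_position = (
    query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 // group_size_1
)
group_key_position = key_position // group_size_1

# Relative distances seen by the last query token:
exact = query_position[0, -1] - key_position[0]
grouped = group_query_position[0, -1] - group_key_position[0]
print(exact)    # 15, 14, ..., 1, 0  -> used for the last `group_size_2` keys
print(grouped)  # 9, 9, 9, 9, 8, 8, 8, 8, 7, ...  -> used for the older keys
```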
self_extend_patch/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import Llama
self_extend_patch/selfextend_flash_attn.py ADDED
@@ -0,0 +1,199 @@
1
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import pad_input
2
+ import torch
3
+
4
+ # The original flash forward method must be replaced with the following one first, to enable the window feature.
5
+ def flash_attention2_forward_with_window_size(
6
+ self,
7
+ query_states,
8
+ key_states,
9
+ value_states,
10
+ attention_mask,
11
+ query_length,
12
+ dropout=0.0,
13
+ softmax_scale=None,
14
+ window_size=[-1, -1],
15
+ return_attn_probs=False,
16
+ ):
17
+ """
18
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
19
+ first unpad the input, then computes the attention scores and pad the final attention scores.
20
+
21
+ Args:
22
+ query_states (`torch.Tensor`):
23
+ Input query states to be passed to Flash Attention API
24
+ key_states (`torch.Tensor`):
25
+ Input key states to be passed to Flash Attention API
26
+ value_states (`torch.Tensor`):
27
+ Input value states to be passed to Flash Attention API
28
+ attention_mask (`torch.Tensor`):
29
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
30
+ position of padding tokens and 1 for the position of non-padding tokens.
31
+ dropout (`float`, *optional*):
32
+ Attention dropout
33
+ softmax_scale (`float`, *optional*):
34
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
35
+ window_size (`[int, int]`, *optional*):
36
+ The left & right window size for Flash Attention. Default to [-1, -1] which means no window size is used.
37
+ return_attn_probs (`bool`, *optional*):
38
+ Whether to return the attention softmax log-sum-exp and probabilities. Default to False.
39
+ """
40
+ if not self._flash_attn_uses_top_left_mask:
41
+ causal = self.is_causal
42
+ else:
43
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
44
+ causal = self.is_causal and query_length != 1
45
+
46
+ # Contains at least one padding token in the sequence
47
+ if attention_mask is not None:
48
+ batch_size = query_states.shape[0]
49
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
50
+ query_states, key_states, value_states, attention_mask, query_length
51
+ )
52
+
53
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
54
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
55
+ attn_output_unpad, softmax_lse, S_dmask = flash_attn_varlen_func(
56
+ query_states,
57
+ key_states,
58
+ value_states,
59
+ cu_seqlens_q=cu_seqlens_q,
60
+ cu_seqlens_k=cu_seqlens_k,
61
+ max_seqlen_q=max_seqlen_in_batch_q,
62
+ max_seqlen_k=max_seqlen_in_batch_k,
63
+ dropout_p=dropout,
64
+ softmax_scale=softmax_scale,
65
+ causal=causal,
66
+ window_size=window_size,
67
+ return_attn_probs=True,
68
+ )
69
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
70
+ else:
71
+ attn_output, softmax_lse, S_dmask = flash_attn_func(
72
+ query_states,
73
+ key_states,
74
+ value_states,
75
+ dropout,
76
+ softmax_scale=softmax_scale,
77
+ causal=causal,
78
+ window_size=window_size,
79
+ return_attn_probs=True,
80
+ )
81
+
82
+ if return_attn_probs:
83
+ return attn_output, softmax_lse, S_dmask
84
+ else:
85
+ return attn_output
86
+
87
+ def self_extend_flash_forward(
88
+ model_self,
89
+ query_position,
90
+ group_size_2,
91
+ neighbor_query_states,
92
+ neighbor_key_states,
93
+ group_query_states,
94
+ group_key_states,
95
+ value_states,
96
+ attention_mask,
97
+ bsz,
98
+ q_len,
99
+ kv_seq_len,
100
+ attn_dropout,
101
+ ):
102
+
103
+ if query_position.max() >= group_size_2:
104
+ neighbor_attn_output, neighbor_softmax_lse_right_padded, neighbor_prob = model_self._flash_attention_forward(
105
+ neighbor_query_states,
106
+ neighbor_key_states,
107
+ value_states,
108
+ attention_mask,
109
+ q_len,
110
+ dropout=attn_dropout,
111
+ window_size=[group_size_2 - 1, 0],
112
+ # the right window size does not matter here and can be -1 or > 0, because the causal mask already hides future positions
113
+ return_attn_probs=True,
114
+ )
115
+
116
+ group_attention_len = (
117
+ kv_seq_len - group_size_2
118
+ ) # here we should use kv_seq_len rather than max_kv_len since we have paddings in qkv and attention_mask
119
+
120
+ group_attention_mask = attention_mask[:, :group_attention_len] if attention_mask is not None else None
121
+ group_attn_output, group_softmax_lse_right_padded, group_prob = model_self._flash_attention_forward(
122
+ group_query_states[:, -group_attention_len:, :, :],
123
+ group_key_states[:, :group_attention_len, :, :],
124
+ value_states[:, :group_attention_len, :, :],
125
+ group_attention_mask,
126
+ group_query_states[:, -group_attention_len:, :, :].shape[1],
127
+ dropout=attn_dropout,
128
+ window_size=[-1, -1],
129
+ return_attn_probs=True,
130
+ ) # note that kv and q are indexed differently! the query length can also differ from the kv length and be very small during generation compared to prefilling
131
+
132
+
133
+ # normalize lse first
134
+ neighbor_seq_length = torch.Tensor([kv_seq_len,]).long().expand(bsz, 1) if attention_mask is None else torch.sum(attention_mask, axis=1, keepdim=True) # [batch_size, 1]
135
+ group_seq_length = torch.Tensor([group_attention_len,]).long().expand(bsz, 1) if attention_mask is None else torch.sum(attention_mask[:, :group_attention_len], axis=1, keepdim=True) # [batch_size, 1]
136
+
137
+ # convert align left to align right and convert exp(0) to 0
138
+ neighbor_softmax_lse = torch.zeros_like(neighbor_softmax_lse_right_padded)
139
+ group_softmax_lse = torch.zeros_like(group_softmax_lse_right_padded)
140
+ for idx in range(bsz):
141
+ if neighbor_seq_length[idx] > 0:
142
+ neighbor_softmax_lse[idx, :, -neighbor_seq_length[idx] :] = neighbor_softmax_lse_right_padded[
143
+ idx, :, : neighbor_seq_length[idx]
144
+ ]
145
+ if group_seq_length[idx] > 0:
146
+ group_softmax_lse[idx, :, -group_seq_length[idx] :] = group_softmax_lse_right_padded[
147
+ idx, :, : group_seq_length[idx]
148
+ ]
149
+
150
+ # attn_output size is [batch_size, max_seq_len (not the true one), query_length, dim]
151
+ true_neighbor_seq_max_length = neighbor_softmax_lse.shape[
152
+ -1
153
+ ] # it could be smaller than query_length due to the attention_mask
154
+ true_group_seq_max_length = group_softmax_lse.shape[
155
+ -1
156
+ ] # it could be smaller than group_query_layer[:, -group_attention_len:, :, :].shape[1] due to the attention_mask[:, :group_attention_len]
157
+
158
+ neighbor_softmax_lse = neighbor_softmax_lse.transpose(1, 2).unsqueeze(
159
+ -1
160
+ ) # [batch_size, true_neighbor_seq_max_length, self.num_heads, 1]
161
+ group_softmax_lse = group_softmax_lse.transpose(1, 2).unsqueeze(
162
+ -1
163
+ ) # [batch_size, true_group_seq_max_length, self.num_heads, 1]
164
+
165
+ lse_gap = group_softmax_lse - neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :]
166
+ #if torch.isinf(neighbor_softmax_lse).any() or torch.isnan(neighbor_softmax_lse).any():
167
+ # import pdb; pdb.set_trace()
168
+
169
+ neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :] = 1 / (1 + torch.exp(lse_gap))
170
+ neighbor_softmax_lse[:, :-true_group_seq_max_length, :, :] = 1.
171
+ group_softmax_lse = 1 / (1 + torch.exp(-lse_gap))
172
+
173
+
174
+
175
+ neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] = (
176
+ neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] * neighbor_softmax_lse
177
+ )
178
+ group_attn_output[:, -true_group_seq_max_length:, ...] = (
179
+ group_attn_output[:, -true_group_seq_max_length:, ...] * group_softmax_lse
180
+ )
181
+ attn_output = torch.empty_like(neighbor_attn_output).copy_(
182
+ neighbor_attn_output
183
+ ) # might be slightly faster than clone
184
+ #attn_output[:, group_size_2:, ...] += group_attn_output
185
+ attn_output[:, group_size_2-kv_seq_len:, ...] += group_attn_output
186
+ attn_output = torch.nan_to_num(attn_output, nan=0)
187
+
188
+ else:
189
+ attn_output = model_self._flash_attention_forward(
190
+ neighbor_query_states,
191
+ neighbor_key_states,
192
+ value_states,
193
+ attention_mask,
194
+ q_len,
195
+ dropout=attn_dropout,
196
+ window_size=[-1, -1],
197
+ )
198
+
199
+ return attn_output
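`self_extend_flash_forward` above runs FlashAttention twice, once over the neighbor window and once over the grouped remainder of the keys, and then combines the two outputs using their log-sum-exp (LSE) values: each output is weighted by `1 / (1 + exp(lse_other - lse_self))`, i.e. a sigmoid of the LSE gap. The toy example below checks that this merge reproduces attention over the union of the two (disjoint) key sets; the sizes and random tensors are made up for illustration.

```python
# Toy check (illustration only): merging two attention outputs via their LSEs is
# exact when the key sets are disjoint, which is what the code above relies on.
import torch

torch.manual_seed(0)
d = 8
q = torch.randn(d)
k_a, v_a = torch.randn(5, d), torch.randn(5, d)   # "neighbor" keys/values
k_b, v_b = torch.randn(3, d), torch.randn(3, d)   # "group" keys/values

def attn(q, k, v):
    scores = k @ q / d ** 0.5
    return torch.softmax(scores, dim=-1) @ v, torch.logsumexp(scores, dim=-1)

out_a, lse_a = attn(q, k_a, v_a)
out_b, lse_b = attn(q, k_b, v_b)

# weight_a = Z_a / (Z_a + Z_b) = sigmoid(lse_a - lse_b); same merge as 1/(1+exp(lse_gap))
merged = torch.sigmoid(lse_a - lse_b) * out_a + torch.sigmoid(lse_b - lse_a) * out_b

# reference: a single softmax over the concatenated key set
ref, _ = attn(q, torch.cat([k_a, k_b]), torch.cat([v_a, v_b]))
print(torch.allclose(merged, ref, atol=1e-6))  # True
```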
self_extend_patch/selfextend_flash_attn_triton.py ADDED
@@ -0,0 +1,278 @@
1
+ import math
2
+ import torch
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
+ def self_extend_flash_forward_triton(
9
+ model_self,
10
+ query_position,
11
+ group_size_2,
12
+ neighbor_query_states,
13
+ neighbor_key_states,
14
+ group_query_states,
15
+ group_key_states,
16
+ value_states,
17
+ attention_mask,
18
+ bsz,
19
+ q_len,
20
+ kv_seq_len,
21
+ attn_dropout,
22
+ ):
23
+
24
+ o = _self_extend_flash_forward_triton(q=neighbor_query_states,
25
+ k=neighbor_key_states,
26
+ q1=group_query_states,
27
+ k1=group_key_states,
28
+ v=value_states,
29
+ causal=(q_len == kv_seq_len),
30
+ sm_scale=1. / math.sqrt(neighbor_query_states.shape[-1]),
31
+ window=group_size_2)
32
+ o = o.transpose(1, 2).contiguous()
33
+ # print("o", o.shape)
34
+ return o
35
+
36
+
37
+
38
+
39
+
40
+
41
+ def _self_extend_flash_forward_triton(q, k, q1, k1, v, causal, sm_scale, window):
42
+ # shape constraints
43
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
44
+ assert Lq == Lk and Lk == Lv
45
+ assert Lk in {16, 32, 64, 128}
46
+
47
+ device = torch.cuda.device_of(q)
48
+ with torch.cuda.device(device):
49
+ o = torch.empty_like(q)
50
+ BLOCK_M = 128
51
+ BLOCK_N = 32
52
+ grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
53
+ L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
54
+ _fwd_kernel[grid](
55
+ q,
56
+ k,
57
+ q1,
58
+ k1,
59
+ v,
60
+ sm_scale,
61
+ L,
62
+ o,
63
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
64
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
65
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
66
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
67
+ q.shape[0],
68
+ q.shape[1],
69
+ q.shape[2],
70
+ k.shape[2],
71
+ BLOCK_M=BLOCK_M,
72
+ BLOCK_N=BLOCK_N,
73
+ BLOCK_DMODEL=Lk,
74
+ IS_CAUSAL=causal,
75
+ WINDOW=window,
76
+ num_warps=8,
77
+ num_stages=2)
78
+
79
+ return o
80
+
81
+
82
+
83
+
84
+ @triton.heuristics(
85
+ {
86
+ "EVEN_M": lambda args: args["Q_CTX"] % args["BLOCK_M"] == 0,
87
+ "EVEN_N": lambda args: args["KV_CTX"] % args["BLOCK_N"] == 0,
88
+ }
89
+ )
90
+ @triton.jit
91
+ def _fwd_kernel(
92
+ Q,
93
+ K,
94
+ Q1,
95
+ K1,
96
+ V,
97
+ sm_scale,
98
+ L,
99
+ Out,
100
+ stride_qz, stride_qh, stride_qm, stride_qk,
101
+ stride_kz, stride_kh, stride_kn, stride_kk,
102
+ stride_vz, stride_vh, stride_vn, stride_vk,
103
+ stride_oz, stride_oh, stride_om, stride_on,
104
+ Z,
105
+ H,
106
+ Q_CTX,
107
+ KV_CTX,
108
+ BLOCK_M: tl.constexpr,
109
+ BLOCK_DMODEL: tl.constexpr,
110
+ BLOCK_N: tl.constexpr,
111
+ IS_CAUSAL: tl.constexpr,
112
+ WINDOW: tl.constexpr,
113
+ EVEN_M: tl.constexpr,
114
+ EVEN_N: tl.constexpr
115
+
116
+ ):
117
+ start_m = tl.program_id(0)
118
+ off_hz = tl.program_id(1)
119
+ # qvk_offset = off_hz * stride_qh
120
+ q_offset = off_hz * stride_qh
121
+ vk_offset = off_hz * stride_kh
122
+ # vk_offset = q_offset
123
+
124
+ Q_block_ptr = tl.make_block_ptr(
125
+ base=Q + q_offset,
126
+ shape=(Q_CTX, BLOCK_DMODEL),
127
+ strides=(stride_qm, stride_qk),
128
+ offsets=(start_m * BLOCK_M, 0),
129
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
130
+ order=(1, 0)
131
+ )
132
+ K_block_ptr = tl.make_block_ptr(
133
+ base=K + vk_offset,
134
+ shape=(KV_CTX, BLOCK_DMODEL),
135
+ strides=(stride_kn, stride_kk),
136
+ offsets=(0, 0),
137
+ block_shape=(BLOCK_N, BLOCK_DMODEL),
138
+ order=(1, 0)
139
+ )
140
+ Q1_block_ptr = tl.make_block_ptr(
141
+ base=Q1 + q_offset,
142
+ shape=(Q_CTX, BLOCK_DMODEL),
143
+ strides=(stride_qm, stride_qk),
144
+ offsets=(start_m * BLOCK_M, 0),
145
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
146
+ order=(1, 0)
147
+ )
148
+ K1_block_ptr = tl.make_block_ptr(
149
+ base=K1 + vk_offset,
150
+ shape=(KV_CTX, BLOCK_DMODEL),
151
+ strides=(stride_kn, stride_kk),
152
+ offsets=(0, 0),
153
+ block_shape=(BLOCK_N, BLOCK_DMODEL),
154
+ order=(1, 0)
155
+ )
156
+ V_block_ptr = tl.make_block_ptr(
157
+ base=V + vk_offset,
158
+ shape=(KV_CTX, BLOCK_DMODEL),
159
+ strides=(stride_vn, stride_vk),
160
+ offsets=(0, 0),
161
+ block_shape=(BLOCK_N, BLOCK_DMODEL),
162
+ order=(1, 0)
163
+ )
164
+
165
+ # initialize offsets
166
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
167
+ offs_n = tl.arange(0, BLOCK_N)
168
+
169
+ # initialize pointer to m and l
170
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
171
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
172
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
173
+
174
+ # scale sm_scale by log_2(e) and use
175
+ # 2^x instead of exp in the loop because CSE and LICM
176
+ # don't work as expected with `exp` in the loop
177
+ qk_scale = sm_scale * 1.4426950408889634
178
+
179
+ # load q: it will stay in SRAM throughout
180
+ if EVEN_M:
181
+ q = tl.load(Q_block_ptr)
182
+ q1 = tl.load(Q1_block_ptr)
183
+ else:
184
+ q = tl.load(Q_block_ptr, boundary_check=(1,0))
185
+ q1 = tl.load(Q1_block_ptr, boundary_check=(1,0))
186
+
187
+ q = (q * qk_scale).to(tl.bfloat16)
188
+ q1 = (q1 * qk_scale).to(tl.bfloat16)
189
+
190
+
191
+ # Dot I trick: it converts q1, q2 into mma layout and saves shared memory
192
+ # better way to generate a eye matrix. avoid casting from bool
193
+ offs_k = tl.arange(0, BLOCK_DMODEL)
194
+ I = tl.where(offs_k[:, None] == offs_k,
195
+ tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=tl.bfloat16),
196
+ tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=tl.bfloat16))
197
+ q = tl.dot(q, I).to(tl.bfloat16)
198
+ q1 = tl.dot(q1, I).to(tl.bfloat16)
199
+
200
+
201
+ # loop over k, v and update accumulator
202
+ lo = 0
203
+ if IS_CAUSAL:
204
+ hi = tl.minimum(KV_CTX, (start_m + 1) * BLOCK_M)
205
+ else:
206
+ hi = KV_CTX
207
+
208
+ for start_n in range(lo, hi, BLOCK_N):
209
+ # -- load k, v --
210
+ if EVEN_N:
211
+ k = tl.load(K_block_ptr)
212
+ k1 = tl.load(K1_block_ptr)
213
+ v = tl.load(V_block_ptr)
214
+ else:
215
+ k = tl.load(K_block_ptr, boundary_check=(1,0))
216
+ k1 = tl.load(K1_block_ptr, boundary_check=(1,0))
217
+ v = tl.load(V_block_ptr, boundary_check=(1,0))
218
+
219
+ # -- compute qk ---
220
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
221
+
222
+ # Window masking
223
+ mask = ( KV_CTX - Q_CTX + offs_m[:, None]) >= (start_n + offs_n[None, :] + WINDOW)
224
+ qk += tl.where(mask, tl.dot(q1, tl.trans(k1)), tl.dot(q, tl.trans(k)))
225
+
226
+ # if not EVEN_N:
227
+ # mask = (start_n + offs_n) < KV_CTX
228
+ # qk = tl.where(mask, qk, float("-inf"))
229
+
230
+ if IS_CAUSAL:
231
+ mask = offs_m[:, None] >= (start_n + offs_n[None, :])
232
+ qk = tl.where(mask, qk, float("-inf"))
233
+ # qk += tl.dot(q, k)
234
+
235
+ # -- compute scaling constant ---
236
+ m_i_new = tl.maximum(m_i, tl.max(qk, 1))
237
+ alpha = tl.math.exp2(m_i - m_i_new)
238
+ p = tl.math.exp2(qk - m_i_new[:, None])
239
+
240
+ # -- scale and update acc --
241
+ acc_scale = l_i * 0 + alpha # workaround some compiler bug
242
+ acc *= acc_scale[:, None]
243
+ acc += tl.dot(p.to(tl.bfloat16), v)
244
+
245
+ # -- update m_i and l_i --
246
+ l_i = l_i * alpha + tl.sum(p, 1)
247
+ m_i = m_i_new
248
+
249
+ # update pointers
250
+ K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))
251
+ K1_block_ptr = tl.advance(K1_block_ptr, (BLOCK_N, 0))
252
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
253
+
254
+
255
+ # write back l and m
256
+ acc = acc * (1.0 / l_i[:, None])
257
+ l_ptrs = L + off_hz * Q_CTX + offs_m
258
+
259
+ mask_m = offs_m < Q_CTX
260
+ l_i = m_i + tl.math.log2(l_i)
261
+ if EVEN_M:
262
+ tl.store(l_ptrs, l_i)
263
+ else:
264
+ tl.store(l_ptrs, l_i, mask=mask_m)
265
+
266
+ # write back O
267
+ O_block_ptr = tl.make_block_ptr(
268
+ base=Out + q_offset,
269
+ shape=(Q_CTX, BLOCK_DMODEL),
270
+ strides=(stride_om, stride_on),
271
+ offsets=(start_m * BLOCK_M, 0),
272
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
273
+ order=(1, 0)
274
+ )
275
+ if EVEN_M:
276
+ tl.store(O_block_ptr, acc.to(tl.bfloat16))
277
+ else:
278
+ tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(1,0))
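The Triton kernel above decides, per (query, key) pair, whether to use the grouped or the neighbor scores with the predicate `KV_CTX - Q_CTX + offs_m >= start_n + offs_n + WINDOW`. Below is a small sketch of the same predicate in plain PyTorch, with made-up sizes, to make the geometry easier to see.

```python
# Illustration of the kernel's window predicate (names and sizes are made up):
# a key uses the grouped (floored-position) scores when it sits at least WINDOW
# tokens behind the query; closer keys use the neighbor (exact-position) scores.
import torch

def use_group_scores(q_len, kv_len, window):
    i = torch.arange(q_len).unsqueeze(1)    # query rows (offs_m), right-aligned to kv
    j = torch.arange(kv_len).unsqueeze(0)   # key columns (start_n + offs_n)
    return (kv_len - q_len + i) >= (j + window)

print(use_group_scores(q_len=6, kv_len=6, window=3).int())
```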
self_extend_patch/triton_selfextend_flash_attn.py ADDED
@@ -0,0 +1,250 @@
1
+ import torch
2
+ import math
3
+ import triton
4
+ import triton.language as tl
5
+
6
+
7
+
8
+
9
+ # Auto-tuning is disabled here to keep startup fast. Uncommenting
10
+ # the code below and commenting out the equivalent parameters is convenient for
11
+ # re-tuning.
12
+ #@triton.autotune(
13
+ # configs=[
14
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=8),
15
+ # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
16
+ # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=8),
17
+ # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
18
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
19
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=4),
20
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=4),
21
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=4),
22
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
23
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=7, num_warps=8),
24
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=7, num_warps=8),
25
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=6, num_warps=8),
26
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=5, num_warps=8),
27
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=8),
28
+ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=6, num_warps=4),
29
+ # ],
30
+ # key=['N_CTX'],
31
+ #)
32
+ @triton.jit
33
+ def _attn_fwd_prefill(Q1, K1, Q2, K2, V, sm_scale, M, Out, #
34
+ stride_qz, stride_qh, stride_qm, stride_qk, #
35
+ stride_kz, stride_kh, stride_kn, stride_kk, #
36
+ stride_vz, stride_vh, stride_vk, stride_vn, #
37
+ stride_oz, stride_oh, stride_om, stride_on, #
38
+ Z, H, #
39
+ Q_CTX: tl.constexpr, #
40
+ N_CTX: tl.constexpr, #
41
+ WINDOW: tl.constexpr, #
42
+ BLOCK_M: tl.constexpr, #
43
+ BLOCK_DMODEL: tl.constexpr, #
44
+ BLOCK_N: tl.constexpr, #
45
+ ):
46
+ start_m = tl.program_id(0)
47
+ off_hz = tl.program_id(1)
48
+ off_z = off_hz // H
49
+ off_h = off_hz % H
50
+ qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
51
+
52
+ # block pointers
53
+ Q1_block_ptr = tl.make_block_ptr(
54
+ base=Q1 + qvk_offset,
55
+ shape=(Q_CTX, BLOCK_DMODEL),
56
+ strides=(stride_qm, stride_qk),
57
+ offsets=(start_m * BLOCK_M, 0),
58
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
59
+ order=(1, 0),
60
+ )
61
+ Q2_block_ptr = tl.make_block_ptr(
62
+ base=Q2 + qvk_offset,
63
+ shape=(Q_CTX, BLOCK_DMODEL),
64
+ strides=(stride_qm, stride_qk),
65
+ offsets=(start_m * BLOCK_M, 0),
66
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
67
+ order=(1, 0),
68
+ )
69
+ V_block_ptr = tl.make_block_ptr(
70
+ base=V + qvk_offset,
71
+ shape=(N_CTX, BLOCK_DMODEL),
72
+ strides=(stride_vk, stride_vn),
73
+ offsets=(0, 0),
74
+ block_shape=(BLOCK_N, BLOCK_DMODEL),
75
+ order=(1, 0),
76
+ )
77
+ K1_block_ptr = tl.make_block_ptr(
78
+ base=K1 + qvk_offset,
79
+ shape=(BLOCK_DMODEL, N_CTX),
80
+ strides=(stride_kk, stride_kn),
81
+ offsets=(0, 0),
82
+ block_shape=(BLOCK_DMODEL, BLOCK_N),
83
+ order=(0, 1),
84
+ )
85
+ K2_block_ptr = tl.make_block_ptr(
86
+ base=K2 + qvk_offset,
87
+ shape=(BLOCK_DMODEL, N_CTX),
88
+ strides=(stride_kk, stride_kn),
89
+ offsets=(0, 0),
90
+ block_shape=(BLOCK_DMODEL, BLOCK_N),
91
+ order=(0, 1),
92
+ )
93
+ O_block_ptr = tl.make_block_ptr(
94
+ base=Out + qvk_offset,
95
+ shape=(Q_CTX, BLOCK_DMODEL),
96
+ strides=(stride_om, stride_on),
97
+ offsets=(start_m * BLOCK_M, 0),
98
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
99
+ order=(1, 0),
100
+ )
101
+ # initialize offsets
102
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
103
+ offs_n = tl.arange(0, BLOCK_N)
104
+ # initialize pointer to m and l
105
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
106
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
107
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
108
+ # load scales
109
+ qk_scale = sm_scale
110
+ qk_scale *= 1.4426950408889634  # log2(e) = 1/ln(2), so exp2 can replace exp
111
+ # load q: it will stay in SRAM throughout
112
+ #q = tl.load(Q_block_ptr)
113
+ if start_m * BLOCK_M + BLOCK_M > Q_CTX:
114
+ q1 = tl.load(Q1_block_ptr, boundary_check=(0,), padding_option='zero')
115
+ q2 = tl.load(Q2_block_ptr, boundary_check=(0,), padding_option='zero')
116
+ else:
117
+ q1 = tl.load(Q1_block_ptr)
118
+ q2 = tl.load(Q2_block_ptr)
119
+ #q1 = (q1 * qk_scale).to(tl.float16)
120
+ #q2 = (q2 * qk_scale).to(tl.float16)
121
+
122
+ lo = 0
123
+ hi = (start_m + 1) * BLOCK_M
124
+ # loop over k, v and update accumulator
125
+ for start_n in range(lo, hi, BLOCK_N):
126
+ start_n = tl.multiple_of(start_n, BLOCK_N)
127
+ #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) #?
128
+ #qk = qk.to(tl.float16)
129
+ # if use condition, qk has to be float32, then convert to float16...
130
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
131
+ if start_n + BLOCK_N - 1 > start_m * BLOCK_M - 1:
132
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, -1.0e6)#float("-inf"))
133
+
134
+ #qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
135
+ # -- compute qk ----
136
+ #k = tl.load(K_block_ptr)
137
+ # case 1: only need group attention: q2, k2
138
+ if BLOCK_N + start_n <= (start_m * BLOCK_M - WINDOW + 1):
139
+ if BLOCK_N + start_n >= N_CTX:
140
+ k2 = tl.load(K2_block_ptr, boundary_check=(1,), padding_option='zero')
141
+ v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
142
+ else:
143
+ k2 = tl.load(K2_block_ptr)
144
+ v = tl.load(V_block_ptr)
145
+ #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
146
+ #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
147
+ qk += tl.dot(q2, k2)#, out_dtype=tl.float16)
148
+ else:
149
+ #case 2: only need neighbor attention: q1, k1
150
+ if start_n >= (start_m+1) * BLOCK_M - WINDOW:
151
+ if BLOCK_N + start_n >= N_CTX:
152
+ k1 = tl.load(K1_block_ptr, boundary_check=(1,), padding_option='zero')
153
+ v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
154
+ else:
155
+ k1 = tl.load(K1_block_ptr)
156
+ v = tl.load(V_block_ptr)
157
+ #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
158
+ #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
159
+ qk += tl.dot(q1, k1)#, out_dtype=tl.float16)
160
+ else:
161
+ #case 3: need both q1, k1 and q2, k2
162
+ if BLOCK_N + start_n >= N_CTX:
163
+ k1 = tl.load(K1_block_ptr, boundary_check=(1,), padding_option='zero')
164
+ k2 = tl.load(K2_block_ptr, boundary_check=(1,), padding_option='zero')
165
+ v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
166
+ else:
167
+ k1 = tl.load(K1_block_ptr)
168
+ k2 = tl.load(K2_block_ptr)
169
+ v = tl.load(V_block_ptr)
170
+ #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
171
+ #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
172
+ qk1 = tl.dot(q1, k1)#, out_dtype=tl.float16)
173
+ qk2 = tl.dot(q2, k2)#, out_dtype=tl.float16)
174
+ #merge_mask = tl.abs((offs_m[:, None] - (start_n + offs_n[None, :]))) >= WINDOW
175
+ #qk += tl.where(merge_mask, qk2, qk1)
176
+ qk += tl.where(tl.abs(offs_m[:, None] - (start_n + offs_n[None, :])) < WINDOW, qk1, qk2)
177
+
178
+ qk *= qk_scale
179
+
180
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
181
+ qk = qk - m_ij[:, None]
182
+ p = tl.math.exp2(qk)
183
+ l_ij = tl.sum(p, 1)
184
+ # -- update m_i and l_i
185
+ alpha = tl.math.exp2(m_i - m_ij)
186
+ l_i = l_i * alpha + l_ij
187
+ # -- update output accumulator --
188
+ acc = acc * alpha[:, None]
189
+ # update acc
190
+ #v = tl.load(V_block_ptr)
191
+ #v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
192
+ acc += tl.dot(p.to(tl.float16), v)
193
+ # update m_i and l_i
194
+ m_i = m_ij
195
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
196
+ K1_block_ptr = tl.advance(K1_block_ptr, (0, BLOCK_N))
197
+ K2_block_ptr = tl.advance(K2_block_ptr, (0, BLOCK_N))
198
+
199
+ # epilogue
200
+ m_i += tl.math.log2(l_i)
201
+ acc = acc / l_i[:, None]
202
+ m_ptrs = M + off_hz * Q_CTX + offs_m
203
+ if start_m * BLOCK_M + BLOCK_M >= Q_CTX:
204
+ tl.store(m_ptrs, m_i, mask=offs_m < Q_CTX)
205
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,))
206
+ else:
207
+ tl.store(m_ptrs, m_i)
208
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty))
209
+
210
+
211
+ def prefill_flash_forward(q1, k1, q2, k2, v, q_len, seq_len, window, sm_scale=None):
212
+ # shape constraints
213
+ Lq, Lk, Lv = q1.shape[-1], k1.shape[-1], v.shape[-1]
214
+ assert Lq == Lk and Lk == Lv
215
+ assert Lk in {16, 32, 64, 128}
216
+ assert q_len == seq_len or q_len == 1
217
+ if sm_scale is None:
218
+ sm_scale = 1.0 / math.sqrt(Lq) # the default scale factor.
219
+ o = torch.empty_like(q1, device=q1.device)
220
+ block_m = 128
221
+ block_n = 64 # if Lk <= 64 else 32
222
+ num_stages = 4 if Lk <= 64 else 3
223
+ num_warps = 4
224
+ # Tuning for H100
225
+ if torch.cuda.get_device_capability()[0] == 9:
226
+ num_warps = 8
227
+ num_stages = 7 if Lk >= 64 else 3
228
+ grid = (triton.cdiv(q1.shape[2], block_m), q1.shape[0] * q1.shape[1], 1)
229
+ M = torch.empty((q1.shape[0], q1.shape[1], q1.shape[2]), device=q1.device, dtype=torch.float32)
230
+ with torch.cuda.device(v.device.index):
231
+ # https://github.com/Dao-AILab/flash-attention/commit/9795159082f6e6c847db2bf4284fd17326c31fbd
232
+ # to avoid the device issue .
233
+ _attn_fwd_prefill[grid](
234
+ q1, k1, q2, k2, v, sm_scale, M, o, #
235
+ q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3), #
236
+ k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3), #
237
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
238
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
239
+ q1.shape[0], q1.shape[1], #
240
+ Q_CTX=q_len,
241
+ N_CTX=seq_len, #
242
+ BLOCK_M=block_m, #
243
+ BLOCK_N=block_n, #
244
+ WINDOW=window,
245
+ BLOCK_DMODEL=Lk, #
246
+ num_warps=num_warps, #
247
+ num_stages=num_stages #
248
+ )
249
+
250
+ return o
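A hypothetical smoke test for `prefill_flash_forward` is sketched below. The tensor layout `[batch, heads, seq_len, head_dim]`, the fp16 dtype, and the chosen sizes are assumptions inferred from the strides and dtype casts in the kernel; it needs a CUDA device with Triton installed.

```python
# Hypothetical smoke test, not part of the commit. Layout [batch, heads, seq, head_dim]
# and fp16 are assumed from the kernel's strides and tl.float16 casts.
import torch
from self_extend_patch.triton_selfextend_flash_attn import prefill_flash_forward

bsz, heads, seq_len, head_dim = 1, 8, 1024, 64

def rand(*shape):
    return torch.randn(*shape, device="cuda", dtype=torch.float16)

q1, k1 = rand(bsz, heads, seq_len, head_dim), rand(bsz, heads, seq_len, head_dim)  # neighbor q/k
q2, k2 = rand(bsz, heads, seq_len, head_dim), rand(bsz, heads, seq_len, head_dim)  # grouped q/k
v = rand(bsz, heads, seq_len, head_dim)

out = prefill_flash_forward(q1, k1, q2, k2, v, q_len=seq_len, seq_len=seq_len, window=512)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
```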
style.css ADDED
@@ -0,0 +1,24 @@
1
+ h1 {
2
+ text-align: center;
3
+ display: block;
4
+ }
5
+
6
+ #duplicate-button {
7
+ margin: auto;
8
+ color: white;
9
+ background: #1565c0;
10
+ border-radius: 100vh;
11
+ }
12
+
13
+ .contain {
14
+ max-width: 900px;
15
+ margin: auto;
16
+ padding-top: 1.5rem;
17
+ }
18
+
19
+ .s-pad {
20
+ display: block;
21
+ padding-top: 2rem;
22
+ height: 1px;
23
+ width: 100%;
24
+ }