Spaces: lamhieu (Runtime error)
lamhieu committed · Commit 7a58a7d · 0 Parent(s)
chore: initialize the app
- .DS_Store +0 -0
- .gitattributes +35 -0
- LICENSE.txt +1 -0
- README.md +45 -0
- SelfExtend.py +199 -0
- app.py +388 -0
- requirements.txt +8 -0
- self_extend_patch/Llama.py +482 -0
- self_extend_patch/__init__.py +1 -0
- self_extend_patch/selfextend_flash_attn.py +199 -0
- self_extend_patch/selfextend_flash_attn_triton.py +278 -0
- self_extend_patch/triton_selfextend_flash_attn.py +250 -0
- style.css +24 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE.txt
ADDED
@@ -0,0 +1 @@
~
README.md
ADDED
@@ -0,0 +1,45 @@
---
title: Ghost 8B Beta (128k)
emoji: 👻 / 📚
colorFrom: indigo
colorTo: pink
sdk: gradio
sdk_version: 4.36.1
app_file: app.py
pinned: false
suggested_hardware: a10g-small
language:
- en
- vi
- es
- pt
- de
- it
- fr
- ko
- zh
license: other
license_name: ghost-llms
license_link: https://ghost-x.org/ghost-llms-license
tags:
- ghost
---

# ~

### Notes

The context-extension source code comes from "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning".

See the source code details [here](https://github.com/datamllab/LongLM).

```
@misc{jin2024llm,
      title={LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning},
      author={Hongye Jin and Xiaotian Han and Jingfeng Yang and Zhimeng Jiang and Zirui Liu and Chia-Yuan Chang and Huiyuan Chen and Xia Hu},
      year={2024},
      eprint={2401.01325},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
SelfExtend.py
ADDED
@@ -0,0 +1,199 @@
from types import MethodType
from functools import partial
import self_extend_patch as SE


def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
    """
    This function modifies a method of an instance of a model class.
    It is partly generated by ChatGPT.
    It replaces the method with the new method.
    Currently, we only use this function to modify the attention method of a model; it has not been tested beyond that.

    instance:
        instance of a model to modify.
    target_class_name:
        name of the attention class to modify. E.g. 'LlamaAttention', 'GPTNeoXAttention', etc.
    new_method:
        new method to replace the original method. E.g. 'self_extend_forward'.
        It should include a parameter 'self' to be bound to the instance.
    """
    target_found = False
    if visited_instances is None:
        visited_instances = set()
    # Unique identifier for the instance (using id() since an object's id is unique)
    instance_id = id(instance)
    if instance_id in visited_instances:
        target_found = False
        return target_found
    # Add the instance to the already-visited set
    visited_instances.add(instance_id)

    # Check if this instance is of the target class
    if instance.__class__.__name__ == target_class_name:
        bond_method = MethodType(new_method, instance)
        setattr(instance, target_method_name, bond_method)
        target_found = True
        return target_found
    elif hasattr(instance, '__dict__'):
        for attr_name, attr_value in instance.__dict__.items():
            if isinstance(attr_value, object) and not isinstance(attr_value, (list, tuple, dict, set)):
                _found = modify_method_of_instance(attr_value, target_class_name, target_method_name, new_method, visited_instances)
                if _found:
                    target_found = True
            elif isinstance(attr_value, (list, tuple)):
                for item in attr_value:
                    if isinstance(item, object):
                        _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True
            # If the attribute value is a dictionary, iterate over its values and recurse
            # E.g. for a ModuleList, its modules are stored in a dictionary: ._modules
            elif isinstance(attr_value, dict):
                for key, value in attr_value.items():
                    if isinstance(value, object):
                        _found = modify_method_of_instance(value, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True
            # If the attribute value is a set, iterate and recurse
            elif isinstance(attr_value, set):
                for item in attr_value:
                    if isinstance(item, object):
                        _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True

    return target_found


def apply(loaded_model, group_size, window_size, enable_flash_attention=False, scale_base=-1, flash_attention_impl="triton"):
    '''
    loaded_model:
        model to apply the self-attention extension to.
    group_size:
        group size for the self-attention extension.
    window_size:
        window size for the self-attention extension.
    scale_base:
        base for the scale, equal to the pretraining length.
        e.g. 4096 for Llama, 8192 for Gemma

        Two recommended scale factors:
            yarn: https://arxiv.org/abs/2309.00071
            log: https://arxiv.org/abs/2202.12172 ; https://kexue.fm/archives/8823
        This is helpful while retrieving a long sequence (e.g. a long passkey),
        but on real-world data the impact is minor (e.g. on LongBench, LEval).

        The reported results in our paper do not use this scale except for long passkey retrieval.
    '''
    arch_name = loaded_model.__class__.__name__
    if 'Llama' in arch_name:
        if enable_flash_attention:
            if flash_attention_impl == "flash_attn":
                self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
                modifed_1 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
                modifed_2 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
                print("Using flash_attn flash self_extend!!")
                if (not modifed_1) or (not modifed_2):
                    raise Exception(f"Failed to modify the attention method of {arch_name}")
            elif flash_attention_impl == "triton":
                self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward_triton, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
                modifed = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
                print("Using triton flash self_extend!!")
                if not modifed:
                    raise Exception(f"Failed to modify the attention method of {arch_name}")
            else:
                raise Exception("Need to set flash_attention_impl to 'flash_attn' or 'triton'.")
        else:
            self_extend_attention_forward = partial(SE.Llama.self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
            # After transformers 4.36 the default attention class is LlamaSdpaAttention; before 4.36, and again in 4.38, it is LlamaAttention.
            # print("loaded_model", loaded_model)
            modifed_2 = modify_method_of_instance(loaded_model, "LlamaAttention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Mistral' in arch_name:
        # Mistral shares the same architecture with Llama, so the implementation should be exchangeable.
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Mistral.flash_self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
            modifed_1 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modifed_2 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "forward", self_extend_attention_forward)
            if (not modifed_1) or (not modifed_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Mistral.self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
            modifed_2 = modify_method_of_instance(loaded_model, "MistralAttention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Gemma' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Gemma.flash_self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
            modifed_1 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modifed_2 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "forward", self_extend_attention_forward)
            if (not modifed_1) or (not modifed_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Gemma.self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
            modifed_2 = modify_method_of_instance(loaded_model, "GemmaAttention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Qwen2' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Qwen2.flash_self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
            modifed_1 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modifed_2 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "forward", self_extend_attention_forward)
            if (not modifed_1) or (not modifed_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Qwen2.self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
            modifed_2 = modify_method_of_instance(loaded_model, "Qwen2Attention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Phi' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Phi.flash_self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
            modifed_1 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modifed_2 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "forward", self_extend_attention_forward)
            if (not modifed_1) or (not modifed_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Phi.self_extend_forward, group_size_1=group_size, group_size_2=window_size, scale_base=scale_base)
            modifed_2 = modify_method_of_instance(loaded_model, "PhiAttention", "forward", self_extend_attention_forward)
            if not modifed_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    else:
        raise NotImplementedError
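For reference, here is a minimal usage sketch of the `apply` entry point above. It is not part of this commit, and the checkpoint name is an illustrative assumption; app.py below applies it to the Ghost 8B Beta weights with group_size=16, window_size=512 and the flash_attn implementation.

```
import torch
from transformers import AutoModelForCausalLM
import SelfExtend

# Illustrative checkpoint: any Llama-architecture model goes through the 'Llama' branch above.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Patch the attention forward in place: keys within `window_size` of the query keep exact
# relative positions, while positions further away are merged into groups of `group_size`.
SelfExtend.apply(model, group_size=16, window_size=512)
```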
app.py
ADDED
@@ -0,0 +1,388 @@
# pylint: skip-file

import subprocess

subprocess.run(
    f"pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
import SelfExtend
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 1536
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "123392"))

DESCRIPTION = """\
# Playground with Ghost 8B Beta (p)

**Ghost 8B Beta** is a large language model developed with goals that include excellent multilingual support, superior knowledge capabilities, and cost-effectiveness. The model comes in two context length versions, 8k and 128k, along with multilingual function tools support by default.

The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.

📋 Note: current model version is "disl-0x5" (10 Jul 2024), context length 128k (123392 tokens) and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
"""

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <h1 style="font-size: 26px; margin-bottom: 2px; opacity: 0.20;">👻 Ghost 8B Beta</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.10;">Ask and share whatever you want ~</p>
</div>
"""

LICENSE = """
<p/>

---
Ghost 8B Beta may give inaccurate information, including information about people, so please verify Ghost 8B Beta's answers. [Ghost 8B Beta](https://ghost-x.org/docs/models/ghost-8b-beta/) by [Ghost X](https://ghost-x.org).
"""

EXAMPLES = [
    ["What is the significance of the Higgs boson in the Standard Model of particle physics?"],
    ["Qu'est-ce que l'effet fondateur et comment influence-t-il la diversité génétique d'une population?"],
    ["Qual è il principio di Le Chatelier e come si applica agli equilibri chimici?"],
    ["¿Qué es una supernova y cuál es su importancia en la formación de elementos pesados en el universo?"],
    ["Qual é a definição formal de uma integral de linha e como é utilizada em física?"],
    ["Was versteht man unter dem Moho-Diskontinuität und welche Bedeutung hat sie für das Verständnis der Erdkruste?"],
    ["Hiện tượng nhà kính là gì và nó ảnh hưởng như thế nào đến biến đổi khí hậu toàn cầu?"],
    ["알고리즘의 시간 복잡도가 중요한 이유는 무엇이며, 시간 복잡도를 어떻게 분석하나요?"],
    ["什么是CRISPR-Cas9基因编辑技术,它在现代生物学研究中的作用是什么?"],
    ["Create a Python function that takes a list of integers and returns the list sorted in ascending order without using the built-in sort or sorted functions."],
    ["Écrivez une fonction en C++ qui trouve le plus long sous-tableau contigu avec une somme égale à zéro."],
    ["Scrivi una funzione in Java che calcola il fattoriale di un numero utilizzando la ricorsione."],
    ["Desarrolla una función en JavaScript que determine si una cadena de texto es un palíndromo, ignorando espacios y signos de puntuación."],
    ["Implemente uma função em C# que verifique se uma matriz quadrada é simétrica."],
    ["Schreiben Sie eine Funktion in Swift, die eine gegebene Zeichenfolge in umgekehrter Reihenfolge zurückgibt, ohne integrierte Funktionen zu verwenden."],
    ["Viết một hàm trong PHP để tìm tất cả các số nguyên tố trong một khoảng cho trước."],
    ["파이썬을 사용하여 주어진 이진 트리가 이진 탐색 트리인지 확인하는 함수를 작성하십시오."],
    ["用 Go 语言编写一个函数,计算给定字符串中每个字符出现的次数,并返回一个包含字符及其出现次数的映射。"],
    ["Can you help me design a detailed project plan for developing a machine learning model for predicting stock prices?"],
    ["Pouvez-vous m'aider à organiser un emploi du temps hebdomadaire pour maximiser la productivité de mon équipe de développement logiciel?"],
    ["Puoi aiutarmi a creare un piano di sviluppo per un'applicazione mobile che gestisce le prenotazioni di ristoranti?"],
    ["¿Podrías ayudarme a elaborar un plan detallado para la implementación de un sistema de gestión de contenido (CMS) en una empresa mediana?"],
    ["Você pode me ajudar a planejar uma estratégia de desenvolvimento para um sistema de comércio eletrônico escalável?"],
    ["Können Sie mir helfen, einen detaillierten Zeitplan für die Implementierung eines neuen ERP-Systems in unserem Unternehmen zu erstellen?"],
    ["Bạn có thể giúp tôi xây dựng một kế hoạch phát triển chi tiết cho dự án xây dựng hệ thống quản lý chuỗi cung ứng không?"],
    ["신경망 기반 이미지 인식 모델 개발을 위한 세부 프로젝트 계획을 세우는 데 도움을 줄 수 있나요?"],
    ["你能帮我制定一个详细的开发计划,用于创建一个基于区块链的分布式账本系统吗?"],
    ["Prove that the sum of the squares of any two sides of a right triangle is equal to the square of the hypotenuse."],
    ["Calculez la force gravitationnelle entre deux masses de 10 kg chacune séparées par une distance de 1 mètre."],
    ["Determina la formula molecolare di un composto che contiene il 40% di carbonio, il 6.67% di idrogeno e il 53.33% di ossigeno in massa."],
    ["Explica la teoría del ciclo económico de Schumpeter y cómo se aplica a la economía moderna."],
    ["Calcule a energia potencial gravitacional de um objeto de 5 kg a uma altura de 10 metros acima do solo (g = 9,8 m/s²)."],
    ["Beweisen Sie, dass jede Primzahl der Form 4k+1 als Summe zweier Quadrate geschrieben werden kann."],
    ["Tính nồng độ mol của dung dịch H₂SO₄ khi hoà tan 98 gam H₂SO₄ vào nước để được 1 lít dung dịch."],
    ["케인스 경제학의 핵심 개념과 그것이 현대 경제 정책에 미치는 영향을 설명하십시오."],
    ["计算一个质量为2 kg的物体在3米高处的重力势能(g = 9.8 m/s²)。"],
    ['Identify the author of a novel that features a dystopian society where "Big Brother" watches over its citizens and the protagonist works for the Ministry of Truth.'],
    ["Quel est le seul mammifère capable de voler activement, souvent associé à la nuit et capable d'écholocalisation?"],
    ["Qual è l'opera letteraria italiana che narra il viaggio immaginario di un poeta attraverso Inferno, Purgatorio e Paradiso, guidato da Virgilio e Beatrice?"],
    ["¿Qué insecto es conocido por su organización social compleja, su capacidad para producir miel y su comunicación mediante la danza?"],
    ["Qual é o fenômeno atmosférico que ocorre quando uma massa de ar quente se encontra com uma massa de ar frio, resultando em uma violenta tempestade giratória?"],
    ["Welches literarische Werk beschreibt die Geschichte eines jungen Mädchens, das durch einen Kaninchenbau in eine fantastische Welt voller skurriler Charaktere fällt?"],
    ["Động vật nào có thể tái sinh toàn bộ cơ thể từ một mảnh nhỏ của chính nó, thường sống dưới nước và có thể có nhiều xúc tu?"],
    ["어떤 자연 현상은 태양빛이 대기 중의 물방울에 반사되고 굴절되어 발생하며, 하늘에 나타나는 여러 색깔의 아치 형태를 띠나요?"],
    ["这部文学作品讲述了一位绅士和他的侍从的冒险故事,他们在"],
    ["Can you derive the Euler-Lagrange equation from the principle of stationary action in classical mechanics?"],
    ["Expliquez la notion de « différence ontologique » chez Martin Heidegger et son importance pour la phénoménologie."],
    ["Qual è il significato simbolico del colore blu nei dipinti di Giotto di Bondone durante il Rinascimento?"],
    ["¿Cómo afecta el cambio de código a la estructura gramatical en comunidades bilingües de habla español-inglés?"],
    ["Qual é o impacto da política monetária não convencional no controle da inflação durante uma crise econômica?"],
    ["Erklären Sie den Unterschied zwischen deterministischen und nicht-deterministischen endlichen Automaten und ihre Anwendungsbereiche."],
    ["Giải thích cơ chế của quá trình phiên mã ngược (reverse transcription) và tầm quan trọng của nó trong nghiên cứu HIV/AIDS."],
    ["조선시대 성리학이 한국 사회와 문화에 미친 영향을 설명하세요."],
    ["如何解释量子纠缠现象,以及它在量子计算中的潜在应用?"],
    ["How can you design a daily schedule that maximizes productivity for a remote worker who has multiple meetings and project deadlines?"],
    ["Quels sont les meilleures stratégies pour gérer les conflits au sein d'une équipe multiculturelle travaillant sur un projet commun?"],
    ["Quali sono i migliori consigli per mantenere un equilibrio tra vita professionale e vita privata in un ambiente lavorativo stressante?"],
    ["¿Cómo se puede elaborar un plan financiero personal efectivo que incluya ahorro para la jubilación, inversión y manejo de deudas?"],
    ["Quais são as melhores práticas para implementar metodologias ágeis em uma equipe de desenvolvimento de software?"],
    ["Welche Strategien können verwendet werden, um ein starkes berufliches Netzwerk aufzubauen und zu pflegen, insbesondere in der Tech-Branche?"],
    ["Những bước nào cần thiết để xây dựng một lộ trình phát triển sự nghiệp bền vững trong lĩnh vực công nghệ thông tin?"],
    ["프로젝트의 범위 변동을 효과적으로 관리하기 위한 최고의 방법은 무엇인가요?"],
    ["在快速变化的职场环境中,如何有效地实现工作与生活的平衡?"],
    ["Write an argumentative essay discussing the pros and cons of artificial intelligence in the workplace, including potential ethical concerns."],
    ["Analysez les impacts sociaux et économiques de la digitalisation sur les petites entreprises en France."],
    ["Scrivi un'email formale al direttore di una rivista per proporre un articolo sulla sostenibilità ambientale nelle città italiane."],
    ["Elabora un informe detallado sobre los efectos del cambio climático en la biodiversidad de la región amazónica."],
    ["Analise criticamente os principais pontos abordados no relatório anual do Banco Mundial sobre a pobreza global."],
    ["Erstellen Sie eine technische Dokumentation für die Implementierung eines neuen Software-Features in einer bestehenden Anwendung."],
    ["Viết một bài luận phân tích về tác động của cuộc cách mạng công nghiệp 4.0 đối với thị trường lao động Việt Nam."],
    ["인공지능의 윤리적 문제에 대한 연구 논문을 작성하고, 다양한 사례를 통해 그 영향을 분석하세요."],
    ["分析鲁迅的小说《阿Q正传》中反映的中国社会问题和作者的批判态度。"],
]

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model_id = "lamhieu/ghost-8b-beta-disl-0x5-8k"
    model_tk = os.getenv("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        token=model_tk,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=model_tk,
    )
    SelfExtend.apply(
        model,
        group_size=16,
        window_size=512,
        enable_flash_attention=True,
        flash_attention_impl="flash_attn",
    )
    model.generation_config.max_length = 123392


@spaces.GPU(duration=120)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1536,
    temperature: float = 0.4,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    input_ids = input_ids.to(model.device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
        )

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta")

chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=chatbot,
    fill_height=True,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.4),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.95),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
    ],
    stop_btn="Stop",
    cache_examples=False,
    examples=EXAMPLES,
    examples_per_page=9,
)

with gr.Blocks(fill_height=True, css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)
    # demo.launch(share=True)
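As a side note, the `generate` callback above follows the standard transformers streaming recipe. A stripped-down, standalone sketch of that pattern (the tiny checkpoint named here is only an assumed placeholder for illustration):

```
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"  # assumed tiny checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# model.generate runs in a background thread; the streamer yields decoded text pieces
# as they are produced, which the Gradio callback accumulates and re-yields.
Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20)).start()
for piece in streamer:
    print(piece, end="", flush=True)
```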
requirements.txt
ADDED
@@ -0,0 +1,8 @@
accelerate==0.30.1
bitsandbytes==0.43.1
gradio==4.37.2
scipy==1.13.0
sentencepiece==0.2.0
spaces==0.28.3
torch==2.0.0
transformers==4.41.0
self_extend_patch/Llama.py
ADDED
@@ -0,0 +1,482 @@
import torch
import numpy as np
import torch.nn as nn
import math
import warnings
from typing import Optional, Tuple
import torch.nn.functional as F
from transformers.cache_utils import Cache
from flash_attn import flash_attn_func, flash_attn_varlen_func
from .selfextend_flash_attn import self_extend_flash_forward


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin) if not q is None else None
    k_embed = (k * cos) + (rotate_half(k) * sin) if not k is None else None
    return q_embed, k_embed


def self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 1024,
    scale_base: Optional[int] = -1,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    if scale_base > 0:
        scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype)  # log scale
        # scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype)  # Yarn scale
    else:
        scaled_query = query_states

    past_key_value = getattr(self, "past_key_value", past_key_value)
    if past_key_value is not None:
        # sin and cos are specific to RoPE models; position_ids needed for the static cache
        cache_kwargs = {"cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    kv_seq_len = key_states.shape[-2]

    query_position = position_ids
    key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len)  # only consider bsz=1 for now.

    neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)  # , seq_len=None)
    neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)  # , seq_len=None)

    _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2  # in case that, for the smallest q position, g2 - g2//g1 exceeds the max position
    group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 // group_size_1
    group_key_position = key_position // group_size_1

    group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position)  # , seq_len=None)
    group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position)  # , seq_len=None)

    neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
    _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
    group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
    _, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)

    neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
    group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:  # no matter the length, we just slice it
        if cache_position is not None:
            causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
        else:
            causal_mask = attention_mask
        group_attn_weights = group_attn_weights + causal_mask
        neighbor_attn_weights = neighbor_attn_weights + causal_mask

    if q_len == 1:
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask[:, -group_size_2:] = 1
    elif q_len == kv_seq_len:
        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask = torch.tril(neighbor_attention_mask)
        if q_len - group_size_2 > 0:
            group_attention_mask = torch.tril(torch.ones((q_len - group_size_2, kv_seq_len - group_size_2), device=group_attn_weights.device))
            neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
    else:
        raise ValueError("q_len should be 1 or seq_len.")

    neighbor_attention_mask = neighbor_attention_mask.bool()
    attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def flash_self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 1024,
    scale_base: Optional[int] = -1,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Requires transformers >= 4.38.2 and flash_attn >= 2.5.6.
    a. Only supports a causal mask.
    b. Does not support attention_mask.
    c. Never tested with batch size > 1.
    d. Only supports q_len = 1 or q_len = seq_len.
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
        attention_mask = kwargs.pop("padding_mask")

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    if scale_base > 0:
        scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype)  # log scale
        # scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype)  # Yarn scale
    else:
        scaled_query = query_states

    past_key_value = getattr(self, "past_key_value", past_key_value)
    if past_key_value is not None:
        # sin and cos are specific to RoPE models; position_ids needed for the static cache
        cache_kwargs = {"cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    kv_seq_len = key_states.shape[-2]

    query_position = position_ids
    # only consider bsz=1 for now.
    key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len)
    attn_dropout = self.config.attention_dropout if self.training else 0.0
    if q_len == 1:
        # We implement the q_len == 1 case separately, by manipulating positions,
        # because our flash implementation does not work for the decoding stage at release time.

        neighbor_key_position = position_ids[:, -1] - key_position
        _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
        group_key_position = position_ids[:, -1] // group_size_1 - key_position // group_size_1 + (_re_group_size_2 - _re_group_size_2 // group_size_1)
        decode_key_position = torch.cat([group_key_position[:, :-group_size_2], neighbor_key_position[:, -group_size_2:]], dim=1)

        decode_k_cos, decode_k_sin = self.rotary_emb(value_states, decode_key_position)  # , seq_len=None)
        # import pdb; pdb.set_trace()
        # neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, cos, sin, query_position_ids)
        decode_query_states = scaled_query.transpose(1, 2).contiguous()  # position 0: cos 0 = 1, sin 0 = 0
        _, decode_key_states = apply_rotary_pos_emb(None, key_states, decode_k_cos, -decode_k_sin, decode_key_position)

        decode_key_states = repeat_kv(decode_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        decode_value_states = repeat_kv(value_states, self.num_key_value_groups).transpose(1, 2).contiguous()

        attn_output = flash_attn_func(decode_query_states, decode_key_states, decode_value_states, attn_dropout, softmax_scale=None, causal=True)
    elif q_len == kv_seq_len:
        # set correct position_ids & apply RoPE.
        neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)  # , seq_len=None)
        neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)  # , seq_len=None)

        _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2  # in case that, for the smallest q position, g2 - g2//g1 exceeds the max position
        group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 / group_size_1
        group_key_position = key_position // group_size_1

        group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position)  # , seq_len=None)
        group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position)  # , seq_len=None)

        neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
        _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
        group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
        _, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)

        neighbor_query_states = neighbor_query_states.transpose(1, 2).contiguous()
        neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        group_query_states = group_query_states.transpose(1, 2).contiguous()
        group_key_states = repeat_kv(group_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        value_states = repeat_kv(value_states, self.num_key_value_groups).transpose(1, 2).contiguous()

        attn_output = self_extend_flash_forward(self, query_position, group_size_2,
                                                neighbor_query_states, neighbor_key_states,
                                                group_query_states, group_key_states,
                                                value_states, attention_mask,
                                                bsz, q_len, kv_seq_len, attn_dropout)
    else:
        raise ValueError("q_len should be 1 or seq_len.")

    attn_output = attn_output.contiguous()
    attn_output = attn_output.view(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value


def lm_infinite_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 1024,
    initial_num: Optional[int] = 1,
    scale_base: Optional[int] = -1,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    if scale_base > 0:
        scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype)  # log scale
        # scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype)  # Yarn scale
    else:
        scaled_query = query_states

    past_key_value = getattr(self, "past_key_value", past_key_value)
    if past_key_value is not None:
        # sin and cos are specific to RoPE models; position_ids needed for the static cache
        cache_kwargs = {"cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    kv_seq_len = key_states.shape[-2]

    query_position = position_ids
    key_position = position_ids if q_len != 1 else torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len)  # only consider bsz=1 for now.

    neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)  # , seq_len=None)
    neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)  # , seq_len=None)

    _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2  # in case that, for the smallest q position, g2 - g2//g1 exceeds the max position
    group_query_position = query_position // group_size_1 + _re_group_size_2 - _re_group_size_2 / group_size_1
    group_key_position = key_position // group_size_1

    group_q_cos, group_q_sin = self.rotary_emb(value_states, group_query_position)  # , seq_len=None)
    group_k_cos, group_k_sin = self.rotary_emb(value_states, group_key_position)  # , seq_len=None)

    neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
    _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)
    group_query_states, _ = apply_rotary_pos_emb(scaled_query, None, group_q_cos, group_q_sin, None)
    _, group_key_states = apply_rotary_pos_emb(None, key_states, group_k_cos, group_k_sin, None)

    neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
    group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:  # no matter the length, we just slice it
        if cache_position is not None:
            causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
        else:
            causal_mask = attention_mask
        group_attn_weights = group_attn_weights + causal_mask
        neighbor_attn_weights = neighbor_attn_weights + causal_mask

    if q_len == 1:
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask[:, -group_size_2:] = 1
    elif q_len == kv_seq_len:
        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask = torch.tril(neighbor_attention_mask)
|
447 |
+
if q_len-group_size_2 > 0:
|
448 |
+
group_attention_mask = torch.tril(torch.ones((q_len-group_size_2, kv_seq_len-group_size_2), device=group_attn_weights.device))
|
449 |
+
neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
|
450 |
+
else:
|
451 |
+
raise ValueError("q_len should be 1 or seq_len.")
|
452 |
+
|
453 |
+
neighbor_attention_mask = neighbor_attention_mask.bool()
|
454 |
+
attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
|
455 |
+
|
456 |
+
# upcast attention to fp32
|
457 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
458 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
459 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
460 |
+
|
461 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
462 |
+
raise ValueError(
|
463 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
464 |
+
f" {attn_output.size()}"
|
465 |
+
)
|
466 |
+
|
467 |
+
|
468 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
469 |
+
|
470 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
471 |
+
|
472 |
+
if self.config.pretraining_tp > 1:
|
473 |
+
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
474 |
+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
475 |
+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
476 |
+
else:
|
477 |
+
attn_output = self.o_proj(attn_output)
|
478 |
+
|
479 |
+
if not output_attentions:
|
480 |
+
attn_weights = None
|
481 |
+
|
482 |
+
return attn_output, attn_weights, past_key_value
|
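
The grouped-position mapping above (keys at floor(pos / group_size_1), queries shifted by group_size_2 - group_size_2 // group_size_1) is what keeps relative distances inside the range seen during pretraining. A small self-contained illustration of that arithmetic (not part of the patch; the numbers are arbitrary):

import torch

G1, G2 = 8, 1024                # group_size_1, group_size_2
positions = torch.arange(4096)  # a context well beyond a 1k neighbor window

key_pos = positions // G1
query_pos = positions // G1 + G2 - G2 // G1

# largest relative distance the grouped attention ever sees for the last query
rel_max = (query_pos[-1] - key_pos).max().item()
print(rel_max)                  # 1407 with these numbers, far below the raw distance of 4095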
self_extend_patch/__init__.py
ADDED
@@ -0,0 +1 @@
from . import Llama
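
The package only re-exports the patched Llama module; wiring it into a loaded model happens elsewhere (SelfExtend.py). A minimal sketch of one way that wiring could look, assuming each attention layer's forward is rebound with functools.partial (the helper name apply_self_extend is illustrative, not part of this repo):

from functools import partial
from self_extend_patch import Llama

def apply_self_extend(model, group_size_1=8, group_size_2=1024):
    # hypothetical helper: rebind every attention layer's forward to one of the
    # patched forwards defined in Llama.py (here the eager lm_infinite_forward)
    for layer in model.model.layers:
        layer.self_attn.forward = partial(
            Llama.lm_infinite_forward,
            layer.self_attn,
            group_size_1=group_size_1,
            group_size_2=group_size_2,
        )
    return model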
self_extend_patch/selfextend_flash_attn.py
ADDED
@@ -0,0 +1,199 @@
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input  # needed to re-pad the unpadded varlen output below
import torch

# must replace the original flash forward method with the following one first, to enable the window feature.
def flash_attention2_forward_with_window_size(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    window_size=[-1, -1],
    return_attn_probs=False,
):
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
    first unpad the input, then computes the attention scores and pad the final attention scores.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to Flash Attention API
        key_states (`torch.Tensor`):
            Input key states to be passed to Flash Attention API
        value_states (`torch.Tensor`):
            Input value states to be passed to Flash Attention API
        attention_mask (`torch.Tensor`):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        dropout (`int`, *optional*):
            Attention dropout
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        window_size ([Int, Int], *optional*):
            The left & right window size for Flash Attention. Default to [-1, -1] which means no window size is used.
        return_attn_probs (`bool`, *optional*):
            Whether to return the attention softmax logsumexp and probabilities. Default to False.
    """
    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
        causal = self.is_causal and query_length != 1

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        attn_output_unpad, softmax_lse, S_dmask = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            return_attn_probs=True,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
    else:
        attn_output, softmax_lse, S_dmask = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            return_attn_probs=True,
        )

    if return_attn_probs:
        return attn_output, softmax_lse, S_dmask
    else:
        return attn_output


def self_extend_flash_forward(
    model_self,
    query_position,
    group_size_2,
    neighbor_query_states,
    neighbor_key_states,
    group_query_states,
    group_key_states,
    value_states,
    attention_mask,
    bsz,
    q_len,
    kv_seq_len,
    attn_dropout,
):
    if query_position.max() >= group_size_2:
        neighbor_attn_output, neighbor_softmax_lse_right_padded, neighbor_prob = model_self._flash_attention_forward(
            neighbor_query_states,
            neighbor_key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=attn_dropout,
            window_size=[group_size_2 - 1, 0],
            # right dim here does not matter and can be -1, or > 0 due to causal mask
            return_attn_probs=True,
        )

        group_attention_len = (
            kv_seq_len - group_size_2
        )  # here we should use kv_seq_len rather than max_kv_len since we have paddings in qkv and attention_mask

        group_attention_mask = attention_mask[:, :group_attention_len] if attention_mask is not None else None
        group_attn_output, group_softmax_lse_right_padded, group_prob = model_self._flash_attention_forward(
            group_query_states[:, -group_attention_len:, :, :],
            group_key_states[:, :group_attention_len, :, :],
            value_states[:, :group_attention_len, :, :],
            group_attention_mask,
            group_query_states[:, -group_attention_len:, :, :].shape[1],
            dropout=attn_dropout,
            window_size=[-1, -1],
            return_attn_probs=True,
        )  # note that kv and q's indexing are different! also query size could be different from kv length and very small during generation compared to prefilling

        # normalize lse first
        neighbor_seq_length = torch.Tensor([kv_seq_len,]).long().expand(bsz, 1) if attention_mask is None else torch.sum(attention_mask, axis=1, keepdim=True)  # [batch_size, 1]
        group_seq_length = torch.Tensor([group_attention_len,]).long().expand(bsz, 1) if attention_mask is None else torch.sum(attention_mask[:, :group_attention_len], axis=1, keepdim=True)  # [batch_size, 1]

        # convert align left to align right and convert exp(0) to 0
        neighbor_softmax_lse = torch.zeros_like(neighbor_softmax_lse_right_padded)
        group_softmax_lse = torch.zeros_like(group_softmax_lse_right_padded)
        for idx in range(bsz):
            if neighbor_seq_length[idx] > 0:
                neighbor_softmax_lse[idx, :, -neighbor_seq_length[idx] :] = neighbor_softmax_lse_right_padded[
                    idx, :, : neighbor_seq_length[idx]
                ]
            if group_seq_length[idx] > 0:
                group_softmax_lse[idx, :, -group_seq_length[idx] :] = group_softmax_lse_right_padded[
                    idx, :, : group_seq_length[idx]
                ]

        # attn_output size is [batch_size, max_seq_len (not the true one), query_length, dim]
        true_neighbor_seq_max_length = neighbor_softmax_lse.shape[
            -1
        ]  # it could be smaller than query_length due to the attention_mask
        true_group_seq_max_length = group_softmax_lse.shape[
            -1
        ]  # it could be smaller than group_query_layer[:, -group_attention_len:, :, :].shape[1] due to the attention_mask[:, :group_attention_len]

        neighbor_softmax_lse = neighbor_softmax_lse.transpose(1, 2).unsqueeze(
            -1
        )  # [batch_size, true_neighbor_seq_max_length, self.num_heads, 1]
        group_softmax_lse = group_softmax_lse.transpose(1, 2).unsqueeze(
            -1
        )  # [batch_size, true_group_seq_max_length, self.num_heads, 1]

        lse_gap = group_softmax_lse - neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :]
        # if torch.isinf(neighbor_softmax_lse).any() or torch.isnan(neighbor_softmax_lse).any():
        #     import pdb; pdb.set_trace()

        neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :] = 1 / (1 + torch.exp(lse_gap))
        neighbor_softmax_lse[:, :-true_group_seq_max_length, :, :] = 1.
        group_softmax_lse = 1 / (1 + torch.exp(-lse_gap))

        neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] = (
            neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] * neighbor_softmax_lse
        )
        group_attn_output[:, -true_group_seq_max_length:, ...] = (
            group_attn_output[:, -true_group_seq_max_length:, ...] * group_softmax_lse
        )
        attn_output = torch.empty_like(neighbor_attn_output).copy_(
            neighbor_attn_output
        )  # might be slightly faster than clone
        # attn_output[:, group_size_2:, ...] += group_attn_output
        attn_output[:, group_size_2 - kv_seq_len:, ...] += group_attn_output
        attn_output = torch.nan_to_num(attn_output, nan=0)

    else:
        attn_output = model_self._flash_attention_forward(
            neighbor_query_states,
            neighbor_key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=attn_dropout,
            window_size=[-1, -1],
        )

    return attn_output
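
The core of self_extend_flash_forward is merging two flash-attention passes (neighbor window and grouped rest) through their logsumexp values; the sigmoid-of-lse-gap weights reproduce a single softmax over the union of the two key sets. A standalone numerical check of that identity (an illustration only, not part of the file):

import torch

torch.manual_seed(0)
scores = torch.randn(10)        # attention logits over 10 keys
values = torch.randn(10, 4)     # one 4-dim value per key

full = torch.softmax(scores, dim=0) @ values

# split the keys into a "neighbor" part and a "group" part
s1, v1 = scores[:6], values[:6]
s2, v2 = scores[6:], values[6:]
lse1, lse2 = torch.logsumexp(s1, 0), torch.logsumexp(s2, 0)
o1 = torch.softmax(s1, dim=0) @ v1
o2 = torch.softmax(s2, dim=0) @ v2

# same weighting as in the patch: 1 / (1 + exp(lse_other - lse_self))
merged = o1 / (1 + torch.exp(lse2 - lse1)) + o2 / (1 + torch.exp(lse1 - lse2))
print(torch.allclose(full, merged, atol=1e-6))   # True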
self_extend_patch/selfextend_flash_attn_triton.py
ADDED
@@ -0,0 +1,278 @@
import math
import torch

import triton
import triton.language as tl


def self_extend_flash_forward_triton(
    model_self,
    query_position,
    group_size_2,
    neighbor_query_states,
    neighbor_key_states,
    group_query_states,
    group_key_states,
    value_states,
    attention_mask,
    bsz,
    q_len,
    kv_seq_len,
    attn_dropout,
):
    o = _self_extend_flash_forward_triton(q=neighbor_query_states,
                                          k=neighbor_key_states,
                                          q1=group_query_states,
                                          k1=group_key_states,
                                          v=value_states,
                                          causal=(q_len == kv_seq_len),
                                          sm_scale=1. / math.sqrt(neighbor_query_states.shape[-1]),
                                          window=group_size_2)
    o = o.transpose(1, 2).contiguous()
    # print("o", o.shape)
    return o


def _self_extend_flash_forward_triton(q, k, q1, k1, v, causal, sm_scale, window):
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}

    device = torch.cuda.device_of(q)
    with torch.cuda.device(device):
        o = torch.empty_like(q)
        BLOCK_M = 128
        BLOCK_N = 32
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _fwd_kernel[grid](
            q,
            k,
            q1,
            k1,
            v,
            sm_scale,
            L,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0],
            q.shape[1],
            q.shape[2],
            k.shape[2],
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL=Lk,
            IS_CAUSAL=causal,
            WINDOW=window,
            num_warps=8,
            num_stages=2)

    return o


@triton.heuristics(
    {
        "EVEN_M": lambda args: args["Q_CTX"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["KV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _fwd_kernel(
    Q,
    K,
    Q1,
    K1,
    V,
    sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_on,
    Z,
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    WINDOW: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # qvk_offset = off_hz * stride_qh
    q_offset = off_hz * stride_qh
    vk_offset = off_hz * stride_kh
    # vk_offset = q_offset

    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + vk_offset,
        shape=(KV_CTX, BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    Q1_block_ptr = tl.make_block_ptr(
        base=Q1 + q_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K1_block_ptr = tl.make_block_ptr(
        base=K1 + vk_offset,
        shape=(KV_CTX, BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + vk_offset,
        shape=(KV_CTX, BLOCK_DMODEL),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.4426950408889634

    # load q: it will stay in SRAM throughout
    if EVEN_M:
        q = tl.load(Q_block_ptr)
        q1 = tl.load(Q1_block_ptr)
    else:
        q = tl.load(Q_block_ptr, boundary_check=(1, 0))
        q1 = tl.load(Q1_block_ptr, boundary_check=(1, 0))

    q = (q * qk_scale).to(tl.bfloat16)
    q1 = (q1 * qk_scale).to(tl.bfloat16)

    # Dot I trick: it converts q1, q2 into mma layout and saves shared memory
    # better way to generate a eye matrix. avoid casting from bool
    offs_k = tl.arange(0, BLOCK_DMODEL)
    I = tl.where(offs_k[:, None] == offs_k,
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=tl.bfloat16),
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=tl.bfloat16))
    q = tl.dot(q, I).to(tl.bfloat16)
    q1 = tl.dot(q1, I).to(tl.bfloat16)

    # loop over k, v and update accumulator
    lo = 0
    if IS_CAUSAL:
        hi = tl.minimum(KV_CTX, (start_m + 1) * BLOCK_M)
    else:
        hi = KV_CTX

    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        if EVEN_N:
            k = tl.load(K_block_ptr)
            k1 = tl.load(K1_block_ptr)
            v = tl.load(V_block_ptr)
        else:
            k = tl.load(K_block_ptr, boundary_check=(1, 0))
            k1 = tl.load(K1_block_ptr, boundary_check=(1, 0))
            v = tl.load(V_block_ptr, boundary_check=(1, 0))

        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

        # Window masking
        mask = (KV_CTX - Q_CTX + offs_m[:, None]) >= (start_n + offs_n[None, :] + WINDOW)
        qk += tl.where(mask, tl.dot(q1, tl.trans(k1)), tl.dot(q, tl.trans(k)))

        # if not EVEN_N:
        #     mask = (start_n + offs_n) < KV_CTX
        #     qk = tl.where(mask, qk, float("-inf"))

        if IS_CAUSAL:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(mask, qk, float("-inf"))
        # qk += tl.dot(q, k)

        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])

        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.bfloat16), v)

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new

        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))
        K1_block_ptr = tl.advance(K1_block_ptr, (BLOCK_N, 0))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

    # write back l and m
    acc = acc * (1.0 / l_i[:, None])
    l_ptrs = L + off_hz * Q_CTX + offs_m

    mask_m = offs_m < Q_CTX
    l_i = m_i + tl.math.log2(l_i)
    if EVEN_M:
        tl.store(l_ptrs, l_i)
    else:
        tl.store(l_ptrs, l_i, mask=mask_m)

    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + q_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    if EVEN_M:
        tl.store(O_block_ptr, acc.to(tl.bfloat16))
    else:
        tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(1, 0))
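
Inside _fwd_kernel, the window mask decides per (query, key) pair whether the neighbor scores (q, k) or the grouped scores (q1, k1) contribute: keys at least WINDOW tokens behind the query's absolute position take the grouped branch. A plain-PyTorch reference of that selection, assuming right-aligned query positions (the shapes and names below are illustrative, not part of the kernel):

import torch

def reference_merge(neighbor_scores, group_scores, q_len, kv_len, window):
    # absolute position of each query when queries are right-aligned in the cache
    q_pos = torch.arange(kv_len - q_len, kv_len).unsqueeze(1)   # [q_len, 1]
    k_pos = torch.arange(kv_len).unsqueeze(0)                   # [1, kv_len]
    use_group = (q_pos - k_pos) >= window                       # far keys -> grouped scores
    return torch.where(use_group, group_scores, neighbor_scores)

# usage with dummy score matrices
n = torch.randn(5, 12)
g = torch.randn(5, 12)
merged = reference_merge(n, g, q_len=5, kv_len=12, window=4)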
self_extend_patch/triton_selfextend_flash_attn.py
ADDED
@@ -0,0 +1,250 @@
import torch
import math
import triton
import triton.language as tl


# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
#@triton.autotune(
#    configs=[
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=8),
#        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
#        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=8),
#        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=4),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=7, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=7, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=6, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=5, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=8),
#        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=6, num_warps=4),
#    ],
#    key=['N_CTX'],
#)
@triton.jit
def _attn_fwd_prefill(Q1, K1, Q2, K2, V, sm_scale, M, Out,  #
                      stride_qz, stride_qh, stride_qm, stride_qk,  #
                      stride_kz, stride_kh, stride_kn, stride_kk,  #
                      stride_vz, stride_vh, stride_vk, stride_vn,  #
                      stride_oz, stride_oh, stride_om, stride_on,  #
                      Z, H,  #
                      Q_CTX: tl.constexpr,  #
                      N_CTX: tl.constexpr,  #
                      WINDOW: tl.constexpr,  #
                      BLOCK_M: tl.constexpr,  #
                      BLOCK_DMODEL: tl.constexpr,  #
                      BLOCK_N: tl.constexpr,  #
                      ):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # block pointers
    Q1_block_ptr = tl.make_block_ptr(
        base=Q1 + qvk_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    Q2_block_ptr = tl.make_block_ptr(
        base=Q2 + qvk_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    K1_block_ptr = tl.make_block_ptr(
        base=K1 + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    K2_block_ptr = tl.make_block_ptr(
        base=K2 + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(Q_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.442695040888963  # 1.44269504 # 1/log(2)
    # load q: it will stay in SRAM throughout
    #q = tl.load(Q_block_ptr)
    if start_m * BLOCK_M + BLOCK_M > Q_CTX:
        q1 = tl.load(Q1_block_ptr, boundary_check=(0,), padding_option='zero')
        q2 = tl.load(Q2_block_ptr, boundary_check=(0,), padding_option='zero')
    else:
        q1 = tl.load(Q1_block_ptr)
        q2 = tl.load(Q2_block_ptr)
    #q1 = (q1 * qk_scale).to(tl.float16)
    #q2 = (q2 * qk_scale).to(tl.float16)

    lo = 0
    hi = (start_m + 1) * BLOCK_M
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) #?
        #qk = qk.to(tl.float16)
        # if use condition, qk has to be float32, then convert to float16...
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if start_n + BLOCK_N - 1 > start_m * BLOCK_M - 1:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, -1.0e6)  #float("-inf"))

        #qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # -- compute qk ----
        #k = tl.load(K_block_ptr)
        # case 1: only need group attention: q2, k2
        if BLOCK_N + start_n <= (start_m * BLOCK_M - WINDOW + 1):
            if BLOCK_N + start_n >= N_CTX:
                k2 = tl.load(K2_block_ptr, boundary_check=(1,), padding_option='zero')
                v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
            else:
                k2 = tl.load(K2_block_ptr)
                v = tl.load(V_block_ptr)
            #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
            qk += tl.dot(q2, k2)  #, out_dtype=tl.float16)
        else:
            # case 2: only need neighbor attention: q1, k1
            if start_n >= (start_m + 1) * BLOCK_M - WINDOW:
                if BLOCK_N + start_n >= N_CTX:
                    k1 = tl.load(K1_block_ptr, boundary_check=(1,), padding_option='zero')
                    v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
                else:
                    k1 = tl.load(K1_block_ptr)
                    v = tl.load(V_block_ptr)
                #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
                #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
                qk += tl.dot(q1, k1)  #, out_dtype=tl.float16)
            else:
                # case 3: need both q1, k1 and q2, k2
                if BLOCK_N + start_n >= N_CTX:
                    k1 = tl.load(K1_block_ptr, boundary_check=(1,), padding_option='zero')
                    k2 = tl.load(K2_block_ptr, boundary_check=(1,), padding_option='zero')
                    v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
                else:
                    k1 = tl.load(K1_block_ptr)
                    k2 = tl.load(K2_block_ptr)
                    v = tl.load(V_block_ptr)
                #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
                #qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
                qk1 = tl.dot(q1, k1)  #, out_dtype=tl.float16)
                qk2 = tl.dot(q2, k2)  #, out_dtype=tl.float16)
                #merge_mask = tl.abs((offs_m[:, None] - (start_n + offs_n[None, :]))) >= WINDOW
                #qk += tl.where(merge_mask, qk2, qk1)
                qk += tl.where(tl.abs(offs_m[:, None] - (start_n + offs_n[None, :])) < WINDOW, qk1, qk2)

        qk *= qk_scale

        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        #v = tl.load(V_block_ptr)
        #v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')
        acc += tl.dot(p.to(tl.float16), v)
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K1_block_ptr = tl.advance(K1_block_ptr, (0, BLOCK_N))
        K2_block_ptr = tl.advance(K2_block_ptr, (0, BLOCK_N))

    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * Q_CTX + offs_m
    if start_m * BLOCK_M + BLOCK_M >= Q_CTX:
        tl.store(m_ptrs, m_i, mask=offs_m < Q_CTX)
        tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,))
    else:
        tl.store(m_ptrs, m_i)
        tl.store(O_block_ptr, acc.to(Out.type.element_ty))


def prefill_flash_forward(q1, k1, q2, k2, v, q_len, seq_len, window, sm_scale=None):
    # shape constraints
    Lq, Lk, Lv = q1.shape[-1], k1.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}
    assert q_len == seq_len or q_len == 1
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(Lq)  # the default scale factor.
    o = torch.empty_like(q1, device=q1.device)
    block_m = 128
    block_n = 64  # if Lk <= 64 else 32
    num_stages = 4 if Lk <= 64 else 3
    num_warps = 4
    # Tuning for H100
    if torch.cuda.get_device_capability()[0] == 9:
        num_warps = 8
        num_stages = 7 if Lk >= 64 else 3
    grid = (triton.cdiv(q1.shape[2], block_m), q1.shape[0] * q1.shape[1], 1)
    M = torch.empty((q1.shape[0], q1.shape[1], q1.shape[2]), device=q1.device, dtype=torch.float32)
    with torch.cuda.device(v.device.index):
        # https://github.com/Dao-AILab/flash-attention/commit/9795159082f6e6c847db2bf4284fd17326c31fbd
        # to avoid the device issue.
        _attn_fwd_prefill[grid](
            q1, k1, q2, k2, v, sm_scale, M, o,  #
            q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),  #
            k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            q1.shape[0], q1.shape[1],  #
            Q_CTX=q_len,
            N_CTX=seq_len,  #
            BLOCK_M=block_m,  #
            BLOCK_N=block_n,  #
            WINDOW=window,
            BLOCK_DMODEL=Lk,  #
            num_warps=num_warps,  #
            num_stages=num_stages  #
        )

    return o
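
prefill_flash_forward expects [batch, heads, seq, head_dim] tensors, a window equal to group_size_2, and either q_len == seq_len (prefill) or q_len == 1. A hypothetical smoke test, assuming a CUDA device with Triton available (the shapes and dtype below are illustrative only):

import torch

if torch.cuda.is_available():
    bsz, heads, seq, dim = 1, 8, 2048, 64
    mk = lambda: torch.randn(bsz, heads, seq, dim, device="cuda", dtype=torch.float16)
    q1, k1, q2, k2, v = mk(), mk(), mk(), mk(), mk()   # neighbor and grouped q/k plus a shared v
    out = prefill_flash_forward(q1, k1, q2, k2, v, q_len=seq, seq_len=seq, window=1024)
    print(out.shape)   # torch.Size([1, 8, 2048, 64])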
style.css
ADDED
@@ -0,0 +1,24 @@
h1 {
  text-align: center;
  display: block;
}

#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}

.contain {
  max-width: 900px;
  margin: auto;
  padding-top: 1.5rem;
}

.s-pad {
  display: block;
  padding-top: 2rem;
  height: 1px;
  width: 100%;
}