speed up flash-attn inference
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
@@ -16,6 +16,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 try:
     from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+        flash_attn_kvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
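The newly imported flash_attn_kvpacked_func runs dense (non-varlen) attention straight on padded tensors, so no unpad/repad round trip is needed when nothing in the batch is padded. A minimal standalone sketch of the call, assuming flash-attn v2's layout of (batch, seqlen, n_heads, head_dim) for q and (batch, seqlen, 2, n_heads, head_dim) for the packed kv; the sizes are invented and a CUDA device with fp16 tensors is assumed.

import torch
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func

bsz, q_len, kv_len, n_heads, head_dim = 2, 1, 128, 32, 128
q = torch.randn(bsz, q_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(bsz, kv_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(bsz, kv_len, n_heads, head_dim, dtype=torch.float16, device="cuda")

# pack k and v along a new dim 2, the same shape produced by
# torch.stack([key_states, value_states], 2) in the last hunk below
kv = torch.stack([k, v], dim=2)                      # (bsz, kv_len, 2, n_heads, head_dim)
out = flash_attn_kvpacked_func(q, kv, causal=False)  # (bsz, q_len, n_heads, head_dim)
print(out.shape)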
@@ -146,7 +147,7 @@ def flashattn_forward(
     else:
         # turn off FA causal mask after first inference autoregressive iteration
         # only on first autoregressive step q,k,v have same seqlen
-        is_causal =
+        is_causal = past_key_value is not None
 
     if self.training and attention_mask.shape[0] == 1:
         # special handling using sample packing
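For context on the two comments above: with a KV cache, q, k and v only share a sequence length on the first (prefill) step; every later step passes a single new query token while the cached keys and values keep growing, which is why the causal flag is handled differently at inference time. A rough, runnable shape trace under that assumption; the sizes are invented and no flash-attn call is involved.

import torch

n_heads, head_dim, prompt_len = 32, 128, 16

# prefill: q, k and v all cover the prompt, so causal masking applies
k_cache = torch.randn(1, prompt_len, n_heads, head_dim)
v_cache = torch.randn(1, prompt_len, n_heads, head_dim)

# decode: each step adds one query token while the cache grows by one,
# so query and key shapes no longer match after the first step
for step in range(3):
    q_new = torch.randn(1, 1, n_heads, head_dim)
    k_cache = torch.cat([k_cache, torch.randn(1, 1, n_heads, head_dim)], dim=1)
    v_cache = torch.cat([v_cache, torch.randn(1, 1, n_heads, head_dim)], dim=1)
    print(step, tuple(q_new.shape), tuple(k_cache.shape))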
@@ -163,14 +164,20 @@ def flashattn_forward(
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
         qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
-            query_states
-            key_states
-            value_states
+            query_states,
+            key_states,
+            value_states,
             qkvpacked=True,
             # We have disabled _prepare_decoder_attention_mask in LlamaModel
             # the attention_mask should be the same as the key_padding_mask
             key_padding_mask=attention_mask,
+            query_padding_mask=attention_mask[:, -query_states.size(1) :]
+            if attention_mask is not None
+            else None,
         )
         output_unpad = flash_attn_varlen_qkvpacked_func(
             qkv_unpad,
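The varlen path feeds the kernels a flat concatenation of the real (unpadded) tokens plus cumulative sequence-length offsets, and the new query_padding_mask argument keeps only the mask columns that belong to the current query tokens. A small sketch of that convention; the mask values are invented, and generate_qkv is assumed to derive the same kind of offsets internally.

import torch

# two sequences of true lengths 3 and 5, padded to 5
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.bool)

seqlens = attention_mask.sum(dim=1)                                    # tensor([3, 5])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).int()  # tensor([0, 3, 8])
max_seqlen = int(seqlens.max())                                        # 5

# during decoding only the trailing q_len columns describe the new query tokens,
# which is what attention_mask[:, -query_states.size(1) :] selects above
q_len = 1
query_padding_mask = attention_mask[:, -q_len:]                        # shape (2, 1)
print(cu_seqlens, max_seqlen, query_padding_mask.shape)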
@@ -182,35 +189,48 @@ def flashattn_forward(
         )
         output = output_pad_fn(output_unpad)
     else:
-        (
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        if attention_mask is None or attention_mask.all().item():
+            output = flash_attn_kvpacked_func(
+                query_states,
+                torch.stack([key_states, value_states], 2),
+                causal=is_causal,
+            )
+        else:
+            (  # pylint: disable=unbalanced-tuple-unpacking
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                _,
+                _,
+                output_pad_fn,
+            ) = generate_qkv(
+                query_states,
+                key_states,
+                value_states,
+                kvpacked=True,
+                key_padding_mask=attention_mask,
+                query_padding_mask=attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None,
+            )
+            output_unpad = flash_attn_varlen_kvpacked_func(
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+            )
+            output = output_pad_fn(output_unpad)
 
     attn_output = output
     if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
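Taken together, the added transposes convert from the (bsz, n_heads, seqlen, head_dim) layout that HF's LlamaAttention keeps internally to the (bsz, seqlen, n_heads, head_dim) layout the flash-attn kernels expect, and the attention_mask check picks the dense fast path whenever nothing is padded. A small sketch of both ideas with invented sizes; it runs on CPU and makes no flash-attn call.

import torch

bsz, n_heads, q_len, head_dim = 2, 32, 1, 128

# HF keeps (bsz, n_heads, seqlen, head_dim); flash-attn wants (bsz, seqlen, n_heads, head_dim)
query_states = torch.randn(bsz, n_heads, q_len, head_dim)
print(tuple(query_states.transpose(1, 2).shape))    # (2, 1, 32, 128)

# the "no padding" test behind the fast path: an all-ones mask (or no mask) means every
# position is real, so unpadding buys nothing and flash_attn_kvpacked_func is called directly
attention_mask = torch.ones(bsz, 33, dtype=torch.bool)
print(attention_mask.all().item())                  # True  -> dense kvpacked path
attention_mask[0, :5] = False                       # simulate left padding on one sequence
print(attention_mask.all().item())                  # False -> generate_qkv + varlen kernel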