tmm1 committed
Commit 13f7efa
1 Parent(s): d773384

speed up flash-attn inference

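The commit adds a non-varlen fast path for autoregressive decoding: when the batch carries no padding (attention_mask is None or all ones), the hijacked forward now calls flash_attn_kvpacked_func directly on the padded (batch, seqlen, heads, head_dim) tensors instead of going through generate_qkv's unpad, varlen kernel, and re-pad round trip. The sketch below shows only the shape of that call; it is a minimal illustration assuming flash-attn 2.x on a CUDA device, and the tensor names and sizes are invented for the example rather than taken from the repo.

import torch
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func

# Toy shapes for one decode step: a single new query token attending to a 128-token KV cache.
bsz, q_len, kv_len, n_heads, head_dim = 2, 1, 128, 32, 128
q = torch.randn(bsz, q_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(bsz, kv_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(bsz, kv_len, n_heads, head_dim, dtype=torch.float16, device="cuda")

# Pack k and v along dim=2, mirroring torch.stack([key_states, value_states], 2) in the patch.
kv = torch.stack([k, v], dim=2)  # (bsz, kv_len, 2, n_heads, head_dim)

# No unpad/pad round trip is needed because nothing in the batch is padding.
is_causal = True  # the patch derives this flag from `past_key_value is not None` during cached decoding
out = flash_attn_kvpacked_func(q, kv, causal=is_causal)  # (bsz, q_len, n_heads, head_dim)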
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -16,6 +16,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

 try:
     from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+        flash_attn_kvpacked_func,
         flash_attn_varlen_kvpacked_func,
         flash_attn_varlen_qkvpacked_func,
     )
@@ -146,7 +147,7 @@ def flashattn_forward(
     else:
         # turn off FA causal mask after first inference autoregressive iteration
         # only on first autoregressive step q,k,v have same seqlen
-        is_causal = key_states.shape == query_states.shape
+        is_causal = past_key_value is not None

     if self.training and attention_mask.shape[0] == 1:
         # special handling using sample packing
@@ -163,14 +164,20 @@
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
         qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
-            query_states.transpose(1, 2),
-            key_states.transpose(1, 2),
-            value_states.transpose(1, 2),
+            query_states,
+            key_states,
+            value_states,
             qkvpacked=True,
             # We have disabled _prepare_decoder_attention_mask in LlamaModel
             # the attention_mask should be the same as the key_padding_mask
             key_padding_mask=attention_mask,
+            query_padding_mask=attention_mask[:, -query_states.size(1) :]
+            if attention_mask is not None
+            else None,
         )
         output_unpad = flash_attn_varlen_qkvpacked_func(
             qkv_unpad,
@@ -182,35 +189,48 @@
         )
         output = output_pad_fn(output_unpad)
     else:
-        (  # pylint: disable=unbalanced-tuple-unpacking
-            q_unpad,
-            kv_unpad,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            _,
-            _,
-            output_pad_fn,
-        ) = generate_qkv(
-            query_states.transpose(1, 2),
-            key_states.transpose(1, 2),
-            value_states.transpose(1, 2),
-            kvpacked=True,
-            key_padding_mask=attention_mask,
-        )
-        output_unpad = flash_attn_varlen_kvpacked_func(
-            q_unpad,
-            kv_unpad,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            0.0,
-            softmax_scale=None,
-            causal=is_causal,
-        )
-        output = output_pad_fn(output_unpad)
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        if attention_mask is None or attention_mask.all().item():
+            output = flash_attn_kvpacked_func(
+                query_states,
+                torch.stack([key_states, value_states], 2),
+                causal=is_causal,
+            )
+        else:
+            (  # pylint: disable=unbalanced-tuple-unpacking
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                _,
+                _,
+                output_pad_fn,
+            ) = generate_qkv(
+                query_states,
+                key_states,
+                value_states,
+                kvpacked=True,
+                key_padding_mask=attention_mask,
+                query_padding_mask=attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None,
+            )
+            output_unpad = flash_attn_varlen_kvpacked_func(
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+            )
+            output = output_pad_fn(output_unpad)

     attn_output = output
     if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
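When the batch does contain padding, the forward still takes the varlen route through generate_qkv and flash_attn_varlen_kvpacked_func, but it now also supplies a query_padding_mask: with a warm KV cache the query covers only the newest token(s), so the mask is sliced to its last q_len columns. A toy illustration of the two checks involved follows; the values and the use_fast_path name are invented for the example, not part of the patch.

import torch

# A left-padded batch of 2 sequences with KV length 6 (0 = padding, 1 = real token).
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1, 1]])

# Fast-path condition from the patch: no mask at all, or a mask that is all ones.
use_fast_path = attention_mask is None or attention_mask.all().item()
print(use_fast_path)  # False, because the first row is padded

# Varlen path: with one new token per decode step, the query padding mask is the
# last q_len columns of the key padding mask.
q_len = 1
query_padding_mask = attention_mask[:, -q_len:]
print(query_padding_mask)  # tensor([[1], [1]])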