winglian committed
Commit b2edaae
1 Parent(s): b88f515

fix for flash attn w/ mistral w/o sample packing (#648)
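
Before this change, the Mistral monkeypatch only used FlashAttention when sample packing supplied cu_seqlens; every other call fell back to the plain matmul/softmax attention path. The diff below routes the unpacked cases through FlashAttention's padded and variable-length kernels and decides, per call, whether the kernel's causal mask should be applied. A minimal, self-contained sketch of that causal-mask decision follows; the helper name choose_causal and the toy tensor shapes are illustrative and not part of the commit:

import torch

def choose_causal(query_states: torch.Tensor, key_states: torch.Tensor, training: bool) -> bool:
    # Mirrors the is_causal logic added in this commit (illustrative helper only).
    if training:
        # During training q, k, v always share a sequence length, so the causal
        # mask is always applied.
        assert key_states.shape == query_states.shape
        return True
    # During generation, only the prefill step sees matching q/k shapes; later
    # decode steps attend a single new query token against the cached keys, so
    # the FlashAttention causal mask must be turned off.
    return key_states.shape == query_states.shape

# Prefill: query and key cover the same 7 prompt tokens -> causal.
q = torch.randn(1, 8, 7, 64)
print(choose_causal(q, q.clone(), training=False))  # True

# Decode step: 1 new query token vs. 8 cached key positions -> not causal.
q_step = torch.randn(1, 8, 1, 64)
k_cache = torch.randn(1, 8, 8, 64)
print(choose_causal(q_step, k_cache, training=False))  # False
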

src/axolotl/monkeypatch/mistral_attn_hijack_flash.py CHANGED
@@ -2,13 +2,17 @@
 # pylint: disable=duplicate-code

 import logging
-import math
 from typing import List, Optional, Tuple, Union

 import torch
 import transformers
 from einops import rearrange
-from torch import nn
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+    flash_attn_kvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_qkvpacked_func,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OriginalMistralDecoderLayer,
@@ -17,16 +21,6 @@ from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv

 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

-try:
-    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
-        flash_attn_varlen_qkvpacked_func,
-    )
-except ImportError:
-    from flash_attn.flash_attn_interface import (
-        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
-    )
-
-
 LOG = logging.getLogger("axolotl.monkeypatch.mistral")


@@ -108,6 +102,15 @@ def flashattn_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)

+    if self.training:
+        # during training q,k,v always have same seqlen
+        assert key_states.shape == query_states.shape
+        is_causal = True
+    else:
+        # turn off FA causal mask after first inference autoregressive iteration
+        # only on first autoregressive step q,k,v have same seqlen
+        is_causal = key_states.shape == query_states.shape
+
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -120,46 +123,84 @@
             qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-        attn_output = output
-        if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-        attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
-        attn_weights = None
+    elif query_states.shape == key_states.shape:
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
+            query_states,
+            key_states,
+            value_states,
+            qkvpacked=True,
+            # We have disabled _prepare_decoder_attention_mask in LlamaModel
+            # the attention_mask should be the same as the key_padding_mask
+            key_padding_mask=attention_mask,
+            query_padding_mask=attention_mask[:, -query_states.size(1) :]
+            if attention_mask is not None
+            else None,
+        )
+        output_unpad = flash_attn_varlen_qkvpacked_func(
+            qkv_unpad,
+            cu_seqlens_q,
+            max_seqlen_q,
+            0.0,
+            softmax_scale=None,
+            causal=is_causal,
+        )
+        output = output_pad_fn(output_unpad)
     else:
-        attn_weights = torch.matmul(
-            query_states, key_states.transpose(2, 3)
-        ) / math.sqrt(self.head_dim)
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        if attention_mask is None or attention_mask.all().item():
+            output = flash_attn_kvpacked_func(
+                query_states,
+                torch.stack([key_states, value_states], 2),
+                causal=is_causal,
             )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(
-            attn_weights, dim=-1, dtype=torch.float32
-        ).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
+        else:
+            (  # pylint: disable=unbalanced-tuple-unpacking
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                _,
+                _,
+                output_pad_fn,
+            ) = generate_qkv(
+                query_states,
+                key_states,
+                value_states,
+                kvpacked=True,
+                key_padding_mask=attention_mask,
+                query_padding_mask=attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None,
+            )
+            if q_unpad.dtype != kv_unpad.dtype:
+                kv_unpad = kv_unpad.to(q_unpad.dtype)
+            output_unpad = flash_attn_varlen_kvpacked_func(
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
             )
+            output = output_pad_fn(output_unpad)

-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = output
+    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")

     attn_output = self.o_proj(attn_output)

@@ -169,6 +210,105 @@ def flashattn_forward(
     return attn_output, attn_weights, past_key_value


+# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
+def generate_qkv(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    kvpacked=False,
+    qkvpacked=False,
+):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, d)
+        k: (batch_size, seqlen_k, nheads_k, d)
+        v: (batch_size, seqlen_k, nheads_k, d)
+        query_padding_mask: (batch_size, seqlen), bool
+        key_padding_mask: (batch_size, seqlen), bool
+    """
+    assert not (kvpacked and qkvpacked)
+    batch_size, seqlen_q, nheads, d = q.shape
+    _, seqlen_k, nheads_k, _ = k.shape
+    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
+    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+
+    if query_padding_mask is not None:
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
+            q, query_padding_mask
+        )
+
+        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
+            output_unpad, indices_q, batch_size, seqlen_q
+        )
+
+    else:
+        q_unpad = rearrange(q, "b s h d -> (b s) h d")
+        cu_seqlens_q = torch.arange(
+            0,
+            (batch_size + 1) * seqlen_q,
+            step=seqlen_q,
+            dtype=torch.int32,
+            device=q_unpad.device,
+        )
+        max_seqlen_q = seqlen_q
+
+        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
+            output_unpad, "(b s) h d -> b s h d", b=batch_size
+        )
+
+    if key_padding_mask is not None:
+        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
+        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+    else:
+        k_unpad = rearrange(k, "b s h d -> (b s) h d")
+        v_unpad = rearrange(v, "b s h d -> (b s) h d")
+        cu_seqlens_k = torch.arange(
+            0,
+            (batch_size + 1) * seqlen_k,
+            step=seqlen_k,
+            dtype=torch.int32,
+            device=k_unpad.device,
+        )
+        max_seqlen_k = seqlen_k
+
+    if qkvpacked:
+        assert nheads == nheads_k
+        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
+        qkv = torch.stack([q, k, v], dim=2)
+        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
+
+    if kvpacked:
+        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
+        kv = torch.stack([k, v], dim=2)
+        return (
+            q_unpad,
+            kv_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q,
+            kv,
+            output_pad_fn,
+        )
+
+    return (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
+    )
+
+
 def mistral_model_forward(
     self,
     input_ids: torch.LongTensor = None,
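
For reference, the varlen ("unpadded") FlashAttention entry points used above consume a flat token tensor plus cumulative sequence offsets instead of a padded batch. Below is a minimal pure-PyTorch sketch of the layout that generate_qkv builds from a boolean padding mask; the toy mask and tensor shapes are illustrative, not taken from the commit:

import torch

# Two sequences padded to length 5; True marks real tokens.
key_padding_mask = torch.tensor(
    [
        [True, True, True, False, False],  # effective length 3
        [True, True, True, True, True],    # effective length 5
    ]
)

seqlens = key_padding_mask.sum(dim=-1, dtype=torch.int32)  # tensor([3, 5])
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)  # tensor([0, 3, 8]) -- start offset of each sequence in the flat layout
max_seqlen = int(seqlens.max())  # 5

# Hidden states are flattened to (total_tokens, nheads, head_dim); the kernel
# uses cu_seqlens to know where each sequence starts and ends.
nheads, head_dim = 8, 64
hidden = torch.randn(2, 5, nheads, head_dim)
unpadded = hidden[key_padding_mask]  # shape: (8, nheads, head_dim)
assert unpadded.shape[0] == int(cu_seqlens[-1])

After the kernel runs, pad_input (wrapped as output_pad_fn in the patch) scatters the flat outputs back into the (batch, seqlen, nheads, head_dim) layout expected by the rest of the forward pass.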