Jackmin801 committed on
Commit 4f24e0f
Parent: 33026dc

old flash attn

Files changed (2)
  1. flash_attn_triton.py +424 -0
  2. modeling_bert.py +11 -0
flash_attn_triton.py ADDED
@@ -0,0 +1,424 @@
+"""Triton implementation of Flash Attention.
+
+# Copyright (c) 2022, Tri Dao.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+*Experimental* implementation of FlashAttention in Triton.
+We use the FlashAttention implementation from Phil Tillet as a starting point.
+https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
+
+Changes:
+- Implement both causal and non-causal attention.
+- Implement both self-attention and cross-attention.
+- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
+- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
+- Support attention bias.
+- Speed up the forward pass a bit, and only store the LSE instead of m and l.
+- Make the backward for d=128 much faster by reducing register spilling.
+- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
+  small batch size * nheads.
+
+Caution:
+- If you plan to use headdim other than 64 and 128, you should test for race conditions
+  (due to the Triton compiler), as done in tests/test_flash_attn.py
+  "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
+  for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
+  that there are none left for other head dimensions.
+Differences between this Triton version and the CUDA version:
+- Triton version doesn't support dropout.
+- Triton forward is generally faster than CUDA forward.
+- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
+  It is slightly slower when headdim=128 and batch * nheads is large.
+- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
+"""
+
+import math
+
+import torch
+import triton  # type: ignore (reportMissingImports)
+import triton.language as tl  # type: ignore (reportMissingImports)
+from einops import repeat
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({
+            'BLOCK_M': 128,
+            'BLOCK_N': 128
+        },
+                      num_warps=8,
+                      num_stages=1),
+        # This config has a race condition when EVEN_M == False, disabling it for now.
+        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
+    ],
+    key=[
+        'CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL',
+        'BLOCK_HEADDIM'
+    ])
+@triton.heuristics({
+    'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0,
+    'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0,
+    'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM'],
+})
+@triton.jit
+def _fwd_kernel(
+    Q,
+    K,
+    V,
+    Bias,
+    Out,
+    Lse,
+    TMP,  # NOTE: TMP is a scratchpad buffer to work around a compiler bug
+    softmax_scale,
+    stride_qb,
+    stride_qh,
+    stride_qm,
+    stride_kb,
+    stride_kh,
+    stride_kn,
+    stride_vb,
+    stride_vh,
+    stride_vn,
+    stride_bb,
+    stride_bh,
+    stride_bm,
+    stride_ob,
+    stride_oh,
+    stride_om,
+    nheads,
+    seqlen_q,
+    seqlen_k,
+    seqlen_q_rounded,
+    headdim,
+    CACHE_KEY_SEQLEN_Q,
+    CACHE_KEY_SEQLEN_K,
+    BIAS_TYPE: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
+    BLOCK_HEADDIM: tl.constexpr,
+    EVEN_M: tl.constexpr,
+    EVEN_N: tl.constexpr,
+    EVEN_HEADDIM: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    start_m = tl.program_id(0)
+    off_hb = tl.program_id(1)
+    off_b = off_hb // nheads
+    off_h = off_hb % nheads
+    # off_b = tl.program_id(1)
+    # off_h = tl.program_id(2)
+    # off_hb = off_b * nheads + off_h
+    # initialize offsets
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_HEADDIM)
+    # Initialize pointers to Q, K, V
+    # Adding parentheses around indexing might use int32 math instead of int64 math?
+    # https://github.com/openai/triton/issues/741
+    # I'm seeing a tiny bit of difference (5-7us)
+    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (
+        offs_m[:, None] * stride_qm + offs_d[None, :])
+    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (
+        offs_n[:, None] * stride_kn + offs_d[None, :])
+    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (
+        offs_n[:, None] * stride_vn + offs_d[None, :])
+    if BIAS_TYPE == 'vector':
+        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
+    elif BIAS_TYPE == 'matrix':
+        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (
+            offs_m[:, None] * stride_bm + offs_n[None, :])
+    else:
+        raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
+    # initialize pointer to m and l
+    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
+    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
+    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
+    # load q: it will stay in SRAM throughout
+    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
+    # tl.load(q_ptrs), we get the wrong output!
+    if EVEN_M & EVEN_N:
+        if EVEN_HEADDIM:
+            q = tl.load(q_ptrs)
+        else:
+            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+    else:
+        if EVEN_HEADDIM:
+            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
+        else:
+            q = tl.load(q_ptrs,
+                        mask=(offs_m[:, None] < seqlen_q) &
+                        (offs_d[None, :] < headdim),
+                        other=0.0)
+    # loop over k, v and update accumulator
+    end_n = seqlen_k if not IS_CAUSAL else tl.minimum(
+        (start_m + 1) * BLOCK_M, seqlen_k)
+    for start_n in range(0, end_n, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        # -- compute qk ----
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
+            if EVEN_HEADDIM:
+                k = tl.load(k_ptrs + start_n * stride_kn)
+            else:
+                k = tl.load(k_ptrs + start_n * stride_kn,
+                            mask=offs_d[None, :] < headdim,
+                            other=0.0)
+        else:
+            if EVEN_HEADDIM:
+                k = tl.load(k_ptrs + start_n * stride_kn,
+                            mask=(start_n + offs_n)[:, None] < seqlen_k,
+                            other=0.0)
+            else:
+                k = tl.load(k_ptrs + start_n * stride_kn,
+                            mask=((start_n + offs_n)[:, None] < seqlen_k) &
+                            (offs_d[None, :] < headdim),
+                            other=0.0)
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        qk += tl.dot(q, k, trans_b=True)
+        # Trying to combine the two masks seems to make the result wrong
+        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
+            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
+                           float('-inf'))
+        if IS_CAUSAL:
+            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0,
+                           float('-inf'))
+        if BIAS_TYPE != 'none':
+            if BIAS_TYPE == 'vector':
+                if EVEN_N:
+                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
+                else:
+                    bias = tl.load(b_ptrs + start_n,
+                                   mask=(start_n + offs_n) < seqlen_k,
+                                   other=0.0).to(tl.float32)
+                bias = bias[None, :]
+            elif BIAS_TYPE == 'matrix':
+                if EVEN_M & EVEN_N:
+                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
+                else:
+                    bias = tl.load(b_ptrs + start_n,
+                                   mask=(offs_m[:, None] < seqlen_q) &
+                                   ((start_n + offs_n)[None, :] < seqlen_k),
+                                   other=0.0).to(tl.float32)
+            else:
+                raise ValueError(
+                    "BIAS_TYPE must be one of {'vector', 'matrix'}")
+            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
+            # can then fuse the mult and add into an fma instruction. But if we have bias we need
+            # to multiply with softmax_scale here.
+            qk = qk * softmax_scale + bias
+            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
+            p = tl.exp(qk - m_ij[:, None])
+        else:
+            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+            p = tl.exp(qk * softmax_scale - m_ij[:, None])
+        l_ij = tl.sum(p, 1)
+
+        # scale acc_o
+        acc_o_scale = tl.exp(m_i - m_ij)
+
+        # # -- update output accumulator --
+        # BUG: have to store and immediately load
+        tl.store(t_ptrs, acc_o_scale)
+        acc_o_scale = tl.load(t_ptrs)
+        acc_o = acc_o * acc_o_scale[:, None]
+        # update acc_o
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
+            if EVEN_HEADDIM:
+                v = tl.load(v_ptrs + start_n * stride_vn)
+            else:
+                v = tl.load(v_ptrs + start_n * stride_vn,
+                            mask=offs_d[None, :] < headdim,
+                            other=0.0)
+        else:
+            if EVEN_HEADDIM:
+                v = tl.load(v_ptrs + start_n * stride_vn,
+                            mask=(start_n + offs_n)[:, None] < seqlen_k,
+                            other=0.0)
+            else:
+                v = tl.load(v_ptrs + start_n * stride_vn,
+                            mask=((start_n + offs_n)[:, None] < seqlen_k) &
+                            (offs_d[None, :] < headdim),
+                            other=0.0)
+        p = p.to(v.dtype)
+        acc_o += tl.dot(p, v)
+
+        # -- update statistics
+        m_i = m_ij
+        l_i_new = tl.exp(lse_i - m_ij) + l_ij
+        lse_i = m_ij + tl.log(l_i_new)
+
+    o_scale = tl.exp(m_i - lse_i)
+    # BUG: have to store and immediately load
+    tl.store(t_ptrs, o_scale)
+    o_scale = tl.load(t_ptrs)
+    acc_o = acc_o * o_scale[:, None]
+    # rematerialize offsets to save registers
+    start_m = tl.program_id(0)
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    # write back l and m
+    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
+    tl.store(lse_ptrs, lse_i)
+    # initialize pointers to output
+    offs_n = tl.arange(0, BLOCK_HEADDIM)
+    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (
+        offs_m[:, None] * stride_om + offs_n[None, :])
+    if EVEN_M:
+        if EVEN_HEADDIM:
+            tl.store(out_ptrs, acc_o)
+        else:
+            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
+    else:
+        if EVEN_HEADDIM:
+            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
+        else:
+            tl.store(out_ptrs,
+                     acc_o,
+                     mask=(offs_m[:, None] < seqlen_q) &
+                     (offs_d[None, :] < headdim))
+
+def init_to_zero(name):
+    return lambda nargs: nargs[name].zero_()
+
+def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
+    # shape constraints
+    batch, seqlen_q, nheads, d = q.shape
+    _, seqlen_k, _, _ = k.shape
+    assert k.shape == (batch, seqlen_k, nheads, d)
+    assert v.shape == (batch, seqlen_k, nheads, d)
+    assert d <= 128, 'FlashAttention only supports head dimensions up to 128'
+    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
+    assert q.dtype in [torch.float16,
+                       torch.bfloat16], 'Only fp16 and bf16 are supported'
+    assert q.is_cuda and k.is_cuda and v.is_cuda
+    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
+
+    has_bias = bias is not None
+    bias_type = 'none'
+    if has_bias:
+        assert bias.dtype in [q.dtype, torch.float]
+        assert bias.is_cuda
+        assert bias.dim() == 4
+        if bias.stride(-1) != 1:
+            bias = bias.contiguous()
+        if bias.shape[2:] == (1, seqlen_k):
+            bias_type = 'vector'
+        elif bias.shape[2:] == (seqlen_q, seqlen_k):
+            bias_type = 'matrix'
+        else:
+            print(q.shape)
+            print(k.shape)
+            print(seqlen_q)
+            print(seqlen_k)
+            print(bias.shape)
+            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
+                               ' or (seqlen_q, seqlen_k)')
+        if bias.shape[:2] == (1, nheads):
+            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
+        elif bias.shape[:2] == (batch, 1):
+            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
+        elif bias.shape[:2] == (1, 1):
+            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
+            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
+        assert bias.shape[:2] == (
+            batch, nheads
+        ), f'First 2 dimensions of bias must be broadcastable to (batch, nheads) = ({batch}, {nheads}). Bias has shape: {bias.shape}'
+        assert bias is not None  # for type checking
+    bias_strides = (bias.stride(0), bias.stride(1),
+                    bias.stride(2)) if has_bias else (0, 0, 0)
+
+    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
+    lse = torch.empty((batch, nheads, seqlen_q_rounded),
+                      device=q.device,
+                      dtype=torch.float32)
+    tmp = torch.empty((batch, nheads, seqlen_q_rounded),
+                      device=q.device,
+                      dtype=torch.float32)
+    o = torch.empty_like(q)
+
+    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
+    # BLOCK = 128
+    # num_warps = 4 if d <= 64 else 8
+    grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
+    _fwd_kernel[grid](  # type: ignore
+        q,
+        k,
+        v,
+        bias,
+        o,
+        lse,
+        tmp,
+        softmax_scale,
+        q.stride(0),
+        q.stride(2),
+        q.stride(1),
+        k.stride(0),
+        k.stride(2),
+        k.stride(1),
+        v.stride(0),
+        v.stride(2),
+        v.stride(1),
+        *bias_strides,
+        o.stride(0),
+        o.stride(2),
+        o.stride(1),
+        nheads,
+        seqlen_q,
+        seqlen_k,
+        seqlen_q_rounded,
+        d,
+        seqlen_q // 32,
+        seqlen_k // 32,  # key for triton cache (limit number of compilations)
+        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
+        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
+        bias_type,
+        causal,
+        BLOCK_HEADDIM,
+        # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+        # num_warps=num_warps,
+        # num_stages=1,
+    )
+    return o, lse, softmax_scale  # softmax_scale could have been updated
+
+class _FlashAttnFunc(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
+        """Forward pass for FlashAttention.
+
+        Args:
+            ctx: autograd context
+            q: (batch_size, seqlen_q, nheads, headdim)
+            k: (batch_size, seqlen_k, nheads, headdim)
+            v: (batch_size, seqlen_k, nheads, headdim)
+            bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
+                For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
+                ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
+            causal (bool): whether to incorporate causal attention masking
+            softmax_scale (float, optional): scale factor for softmax
+        """
+        # Make sure that the last dimension is contiguous
+        q, k, v = [
+            x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]
+        ]
+        o, lse, ctx.softmax_scale = _flash_attn_forward(
+            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
+        ctx.save_for_backward(q, k, v, o, lse, bias)
+        ctx.causal = causal
+        return o
+
+    @staticmethod
+    def backward(ctx, do):
+        raise NotImplementedError
+
+flash_attn_func = _FlashAttnFunc.apply
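
For context, here is a minimal usage sketch of the flash_attn_func exported above. The (batch, seqlen, nheads, headdim) layout, the fp16/CUDA requirements, and the 'vector' bias shape follow the asserts and docstrings in this file; the concrete sizes, the explicit softmax scale, and the top-level import path are illustrative assumptions. Only the forward pass is usable here, since backward raises NotImplementedError.

import math

import torch
from flash_attn_triton import flash_attn_func  # assumed import path; the repo imports it as .flash_attn_triton

batch, seqlen, nheads, headdim = 2, 512, 12, 64
q = torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# 'vector' bias: last two dims are (1, seqlen_k), e.g. an ALiBi-style per-key bias.
bias = torch.randn(1, nheads, 1, seqlen, device='cuda', dtype=torch.float16)

# Arguments are positional because flash_attn_func is _FlashAttnFunc.apply:
# (q, k, v, bias, causal, softmax_scale).
out = flash_attn_func(q, k, v, bias, False, 1.0 / math.sqrt(headdim))
print(out.shape)  # (batch, seqlen, nheads, headdim), same as q
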
modeling_bert.py CHANGED
@@ -55,6 +55,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from .configuration_bert import JinaBertConfig
+from .flash_attn_triton import flash_attn_func
 
 try:
     from tqdm.autonotebook import trange
@@ -333,6 +334,16 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
+        if False:
+            b, s, h = hidden_states.shape
+            q = self.query(hidden_states)
+            k = self.key(hidden_states)
+            v = self.value(hidden_states)
+            q = self.transpose_for_scores(q)
+            k = self.transpose_for_scores(k)
+            v = self.transpose_for_scores(v)
+            attn = flash_attn_func(q, k, v, bias)
+            return (attn.view(b, s, h),)
         mixed_query_layer = self.query(hidden_states)
 
         # If this is instantiated as a cross-attention module, the keys
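
For reference, a minimal sketch of the tensor-layout round trip that the disabled branch above depends on: activations of shape (batch, seqlen, hidden) have to reach flash_attn_func as (batch, seqlen, nheads, headdim) and are flattened back to (batch, seqlen, hidden) afterwards. The sizes below are made up, and the plain reshape is only a stand-in for the query/key/value projections and transpose_for_scores calls in the diff, not a claim about what those methods return.

import torch

# Hypothetical sizes for illustration; hidden must equal nheads * headdim.
b, s, h = 2, 128, 768
nheads, headdim = 12, 64

hidden_states = torch.randn(b, s, h, device='cuda', dtype=torch.float16)

# flash_attn_func expects (batch, seqlen, nheads, headdim), per its docstring.
q = hidden_states.view(b, s, nheads, headdim)
k = hidden_states.view(b, s, nheads, headdim)
v = hidden_states.view(b, s, nheads, headdim)

# attn = flash_attn_func(q, k, v, bias)  # as called in the disabled branch
# context = attn.view(b, s, h)           # matches return (attn.view(b, s, h),) in the diff
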