File size: 24,686 Bytes
4f41cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda

# isort: on

# Schema of the ``fa2::fwd`` custom op, whose eager implementation wraps
# ``flash_attn_cuda.fwd``. It returns 9 tensors, in the order used by
# ``meta_fa2_fwd`` below: out, q_padded, k_padded, v_padded, out_padded,
# attn_bias_padded, softmax_lse, S_dmask (softmax output), rng_state.
# NOTE(review): callers pass ``None`` for ``out``, ``alibi_slopes``, ``attn_bias``
# and ``gen_``, and some outputs may be ``None``, while the schema declares plain
# ``Tensor``; ``Tensor?`` would describe this more accurately -- confirm against
# the torch version in use before changing the schema.
torch.library.define("fa2::fwd", "(Tensor q, Tensor k, Tensor v, Tensor out, Tensor alibi_slopes, float dropout_p, float softmax_scale, bool causal, int window_size_left, int window_size_right, Tensor attn_bias, bool return_softmax, Tensor gen_) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)")

@torch.library.impl("fa2::fwd", "default")
def cuda_fa2_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    alibi_slopes: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    attn_bias: torch.Tensor,
    return_softmax: bool,
    gen_: torch.Tensor,
):
    """Eager implementation of ``fa2::fwd``: run the compiled FlashAttention forward.

    Every argument except ``gen_`` is forwarded unchanged to ``flash_attn_cuda.fwd``;
    ``None`` is always passed in the generator slot instead of ``gen_``.

    Returns:
        9-tuple ``(out, q_padded, k_padded, v_padded, out_padded, attn_bias_padded,
        softmax_lse, S_dmask, rng_state)`` as produced by the kernel.
    """
    kernel_outputs = flash_attn_cuda.fwd(
        q, k, v, out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attn_bias,
        return_softmax,
        None,  # generator slot: gen_ is not forwarded
    )
    # Unpacking into nine names also asserts the kernel returned exactly nine tensors.
    (
        result_out,
        q_padded,
        k_padded,
        v_padded,
        out_padded,
        attn_bias_padded,
        softmax_lse,
        S_dmask,
        rng_state,
    ) = kernel_outputs
    return (
        result_out,
        q_padded,
        k_padded,
        v_padded,
        out_padded,
        attn_bias_padded,
        softmax_lse,
        S_dmask,
        rng_state,
    )

@torch.library.impl_abstract("fa2::fwd")
def meta_fa2_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    alibi_slopes: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    attn_bias: torch.Tensor,
    return_softmax: bool,
    gen_: torch.Tensor
):
    """Fake (abstract) implementation of ``fa2::fwd`` used for tracing / torch.compile.

    Allocates uninitialized tensors with the shapes, strides, dtypes and devices the
    eager kernel is expected to return; no computation is performed.

    Shapes assume ``q`` is (batch, seqlen_q, nheads, headdim) and ``k``/``v`` are
    (batch, seqlen_k, nheads_k, headdim), as described by the public
    ``flash_attn_*`` docstrings (nheads_k may be smaller than nheads for MQA/GQA).

    Returns:
        9-tuple ``(out, q_padded, k_padded, v_padded, out_padded, attn_bias_padded,
        softmax_lse, S_dmask, rng_state)``. ``attn_bias_padded`` is ``None`` when
        ``attn_bias`` is ``None``; ``S_dmask`` is ``None`` unless ``return_softmax``
        is true and ``dropout_p > 0``.
    """
    def round_multiple(x, m):
        # Round x up to the next multiple of m.
        return (x + m - 1) // m * m

    batch_size = q.shape[0]
    seqlen_q = q.shape[1]
    num_heads = q.shape[2]
    head_dim_og = q.shape[3]
    seqlen_k = k.shape[1]
    # K/V may carry fewer heads than Q (MQA/GQA); their padded copies keep their own
    # head counts rather than Q's.
    num_heads_k = k.shape[2]
    num_heads_v = v.shape[2]

    # The head dim is padded to a multiple of 8.
    head_dim = round_multiple(head_dim_og, 8)
    seqlen_q_rounded = round_multiple(seqlen_q, 128)
    seqlen_k_rounded = round_multiple(seqlen_k, 128)
    seqlen_q_rounded_8 = round_multiple(seqlen_q, 8)
    seqlen_k_rounded_8 = round_multiple(seqlen_k, 8)

    # `out` has the original head dim but the strides of a head_dim-padded buffer,
    # i.e. it describes a slice of `out_padded`'s layout.
    result_out = torch.empty_strided(
        (batch_size, seqlen_q, num_heads, head_dim_og),
        (head_dim * num_heads * seqlen_q, head_dim * num_heads, head_dim, 1),
        device=q.device,
        dtype=q.dtype,
    )
    q_padded = q.new_empty((batch_size, seqlen_q, num_heads, head_dim))
    k_padded = k.new_empty((batch_size, seqlen_k, num_heads_k, head_dim))
    v_padded = v.new_empty((batch_size, seqlen_k, num_heads_v, head_dim))
    out_padded = q.new_empty((batch_size, seqlen_q, num_heads, head_dim))

    attn_bias_padded = None
    if attn_bias is not None:
        attn_bias_padded = q.new_empty(
            (attn_bias.shape[0], attn_bias.shape[1], seqlen_q_rounded_8, seqlen_k_rounded_8)
        )

    # Upstream flash-attn allocates softmax_lse in float32 regardless of the input
    # dtype. NOTE(review): confirm this fork's kernel does the same.
    softmax_lse = torch.empty(
        (batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device
    )

    S_dmask = None
    if return_softmax and dropout_p > 0:
        S_dmask = q.new_empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded))

    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)

    return (
        result_out,
        q_padded,
        k_padded,
        v_padded,
        out_padded,
        attn_bias_padded,
        softmax_lse,
        S_dmask,
        rng_state,
    )

# Schema of the ``fa2::bwd`` custom op, whose eager implementation wraps
# ``flash_attn_cuda.bwd``. It returns 5 tensors, in the order used by
# ``meta_fa2_bwd`` below: dq, dk, dv, ds (gradient of attn_bias, only when
# attn_bias_require_grad) and a (batch, nheads, seqlen_q rounded up to 128)
# auxiliary buffer. ``seqlen_k_orig`` is only consumed by the fake implementation.
# NOTE(review): as with fa2::fwd, several arguments declared as plain ``Tensor``
# receive ``None`` from callers; consider ``Tensor?`` -- confirm first.
torch.library.define("fa2::bwd", "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor dq, Tensor dk, Tensor dv, Tensor alibi_slopes, float dropout_p, float softmax_scale, bool causal, int window_size_left, int window_size_right, bool deterministic, Tensor attn_bias, bool attn_bias_require_grad, Tensor ds, int seqlen_k_orig, Tensor gen_, Tensor rng_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")

@torch.library.impl("fa2::bwd", "default")
def cuda_fa2_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: torch.Tensor,
    dk: torch.Tensor,
    dv: torch.Tensor,
    alibi_slopes: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    deterministic: bool,
    attn_bias: torch.Tensor,
    attn_bias_require_grad: bool,
    ds: torch.Tensor,
    seqlen_k_orig: int,
    gen_: torch.Tensor,
    rng_sate: torch.Tensor
):
    """Eager implementation of ``fa2::bwd``: run the compiled FlashAttention backward.

    Arguments are forwarded to ``flash_attn_cuda.bwd`` in schema order, except that
    ``seqlen_k_orig`` is not passed at all and ``None`` is passed in the generator
    slot instead of ``gen_``. ``rng_sate`` (sic) receives the ``rng_state`` saved by
    the forward pass.

    Returns:
        5-tuple ``(dq, dk, dv, ds, aux)`` as produced by the kernel.
    """
    grads = flash_attn_cuda.bwd(
        dout, q, k, v, out, softmax_lse,
        dq, dk, dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        deterministic,
        attn_bias,
        attn_bias_require_grad,
        ds,
        None,  # generator slot: gen_ is not forwarded
        rng_sate,
    )
    # Unpacking into five names also asserts the kernel returned exactly five tensors.
    grad_q, grad_k, grad_v, grad_bias, aux = grads
    return grad_q, grad_k, grad_v, grad_bias, aux

@torch.library.impl_abstract("fa2::bwd")
def meta_fa2_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: torch.Tensor,
    dk: torch.Tensor,
    dv: torch.Tensor,
    alibi_slopes: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    deterministic: bool,
    attn_bias: torch.Tensor,
    attn_bias_require_grad: bool,
    ds: torch.Tensor,
    seqlen_k_orig: int,
    gen_: torch.Tensor,
    rng_sate: torch.Tensor
):
    """Fake (abstract) implementation of ``fa2::bwd`` used for tracing / torch.compile.

    Allocates uninitialized tensors with the shapes, strides, dtypes and devices the
    eager kernel is expected to return; no computation is performed.

    Returns:
        5-tuple ``(dq, dk, dv, ds, softmax_d)``:
        - ``dq``/``dk``/``dv`` have the unpadded head dim with the strides of a
          head_dim-padded buffer; ``dk``/``dv`` use K's head count (MQA/GQA).
        - ``ds`` (gradient of ``attn_bias``) is ``None`` unless
          ``attn_bias_require_grad``; it is a (seqlen_q, seqlen_k_orig) slice of a
          buffer padded to multiples of 8 in both sequence dims.
        - ``softmax_d`` is a float32 (batch, nheads, seqlen_q rounded up to 128)
          auxiliary buffer.
    """
    def round_multiple(x, m):
        # Round x up to the next multiple of m.
        return (x + m - 1) // m * m

    batch_size = dout.shape[0]
    seqlen_q = dout.shape[1]
    num_heads = dout.shape[2]
    head_dim_og = dout.shape[3]
    seqlen_k = k.shape[1]
    # K/V may carry fewer heads than Q (MQA/GQA); their gradients match K/V.
    num_heads_k = k.shape[2]
    num_heads_v = v.shape[2]

    head_dim = round_multiple(head_dim_og, 8)
    seqlen_q_rounded = round_multiple(seqlen_q, 128)
    seqlen_q_round8 = round_multiple(seqlen_q, 8)
    seqlen_k_round8 = round_multiple(seqlen_k_orig, 8)

    grad_q = torch.empty_strided(
        (batch_size, seqlen_q, num_heads, head_dim_og),
        (head_dim * num_heads * seqlen_q, head_dim * num_heads, head_dim, 1),
        device=q.device,
        dtype=q.dtype,
    )
    # dk's size uses seqlen_k_orig while its strides use k.shape[1]; the autograd
    # functions in this file always pass seqlen_k_orig == k.shape[1].
    grad_k = torch.empty_strided(
        (batch_size, seqlen_k_orig, num_heads_k, head_dim_og),
        (head_dim * num_heads_k * seqlen_k, head_dim * num_heads_k, head_dim, 1),
        device=k.device,
        dtype=k.dtype,
    )
    grad_v = torch.empty_strided(
        (batch_size, seqlen_k, num_heads_v, head_dim_og),
        (head_dim * num_heads_v * seqlen_k, head_dim * num_heads_v, head_dim, 1),
        device=v.device,
        dtype=v.dtype,
    )

    grad_bias = None
    if attn_bias_require_grad:
        batch_size_bias = attn_bias.shape[0]
        num_heads_bias = attn_bias.shape[1]
        # Slice of a row-major (b, h, seqlen_q_round8, seqlen_k_round8) buffer: a
        # row along the seqlen_q dim is seqlen_k_round8 elements long.
        grad_bias = torch.empty_strided(
            (batch_size_bias, num_heads_bias, seqlen_q, seqlen_k_orig),
            (
                num_heads_bias * seqlen_q_round8 * seqlen_k_round8,
                seqlen_q_round8 * seqlen_k_round8,
                seqlen_k_round8,
                1,
            ),
            device=v.device,
            dtype=v.dtype,
        )

    # Upstream flash-attn allocates this buffer (softmax_d) in float32.
    # NOTE(review): confirm this fork's kernel does the same.
    softmax_d = torch.empty(
        (batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device
    )

    return grad_q, grad_k, grad_v, grad_bias, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function running FlashAttention on a packed ``qkv`` tensor.

    ``qkv`` is split along dim 2 into Q, K and V for the ``fa2::fwd`` custom op;
    the backward pass calls ``fa2::bwd`` and stacks the three gradients back along
    dim 2 into ``dqkv``.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        alibi_slopes,
        deterministic,
        attn_bias,
        return_softmax,
        return_ds
    ):
        """Run the forward pass; see ``flash_attn_qkvpacked_func`` for argument docs.

        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)

        # Match the bias dtype to the inputs before handing it to the kernel, as
        # FlashAttnFunc does.
        if attn_bias is not None:
            attn_bias = attn_bias.to(qkv.dtype)

        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        out, _q_padded, _k_padded, _v_padded, _out_padded, _attn_bias_padded, softmax_lse, S_dmask, rng_state = torch.ops.fa2.fwd(
            q,
            k,
            v,
            None,  # out: no preallocated output buffer
            alibi_slopes,
            dropout_p,
            softmax_scale,
            causal,
            window_size_left,
            window_size_right,
            attn_bias,
            return_softmax and dropout_p > 0,  # S_dmask is only requested with dropout
            None,  # gen_
        )

        ## WORKAROUND a Pytorch bug, should use _padded version of the tensors but this is buggy when passing them directly to save_for_backward
        ## For now, this breaks the backward when headdim is not a multiple of 8 and/or seqlen_q, seqlen_k are not a multiple of 8
        ## TODO: make the padding here instead (i.e. save the *_padded tensors returned above)
        ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size_left = window_size_left
        ctx.window_size_right = window_size_right
        ctx.deterministic = deterministic
        # Only compute a bias gradient when a bias exists and the caller asked for it.
        ctx.bias_requires_grad = attn_bias is not None and bool(return_ds)
        ctx.seqlen_k_orig = qkv.shape[1]

        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient per ``forward`` input (``None`` for non-differentiable ones).

        Incoming gradients for ``softmax_lse``/``S_dmask`` (``*args``) are ignored.
        """
        q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes = ctx.saved_tensors

        # Hand the kernel a contiguous dout, as FlashAttnFunc does.
        dout = dout.contiguous()
        dq, dk, dv, ds, _ = torch.ops.fa2.bwd(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None,  # dq
            None,  # dk
            None,  # dv
            alibi_slopes,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size_left,
            ctx.window_size_right,
            ctx.deterministic,
            attn_bias,
            ctx.bias_requires_grad,
            None,  # ds
            ctx.seqlen_k_orig,
            None,  # gen_
            rng_state
        )
        dqkv = torch.stack([dq, dk, dv], dim=2)
        # Order: qkv, dropout_p, softmax_scale, causal, window_size_left,
        # window_size_right, alibi_slopes, deterministic, attn_bias, return_softmax, return_ds.
        return dqkv, None, None, None, None, None, None, None, ds, None, None

class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function running FlashAttention on ``q`` and a packed ``kv`` tensor.

    ``kv`` is split along dim 2 into K and V for the ``fa2::fwd`` custom op; the
    backward pass calls ``fa2::bwd`` and stacks the K and V gradients back along
    dim 2 into ``dkv``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        alibi_slopes,
        deterministic,
        attn_bias,
        return_softmax,
        return_ds
    ):
        """Run the forward pass; see ``flash_attn_kvpacked_func`` for argument docs.

        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # Match the bias dtype to the inputs before handing it to the kernel, as
        # FlashAttnFunc does.
        if attn_bias is not None:
            attn_bias = attn_bias.to(q.dtype)

        k, v = kv[:, :, 0], kv[:, :, 1]
        out, _q_padded, _k_padded, _v_padded, _out_padded, _attn_bias_padded, softmax_lse, S_dmask, rng_state = torch.ops.fa2.fwd(
            q,
            k,
            v,
            None,  # out: no preallocated output buffer
            alibi_slopes,
            dropout_p,
            softmax_scale,
            causal,
            window_size_left,
            window_size_right,
            attn_bias,
            return_softmax and dropout_p > 0,  # S_dmask is only requested with dropout
            None,  # gen_
        )

        ## WORKAROUND a Pytorch bug, should use _padded version of the tensors but this is buggy when passing them directly to save_for_backward
        ## For now, this breaks the backward when headdim is not a multiple of 8 and/or seqlen_q, seqlen_k are not a multiple of 8
        ## TODO: make the padding here instead (i.e. save the *_padded tensors returned above)
        ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size_left = window_size_left
        ctx.window_size_right = window_size_right
        ctx.deterministic = deterministic
        # Only compute a bias gradient when a bias exists and the caller asked for it.
        ctx.bias_requires_grad = attn_bias is not None and bool(return_ds)
        ctx.seqlen_k_orig = kv.shape[1]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient per ``forward`` input (``None`` for non-differentiable ones).

        Incoming gradients for ``softmax_lse``/``S_dmask`` (``*args``) are ignored.
        """
        q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes = ctx.saved_tensors

        # Hand the kernel a contiguous dout, as FlashAttnFunc does.
        dout = dout.contiguous()
        dq, dk, dv, ds, _ = torch.ops.fa2.bwd(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None,  # dq
            None,  # dk
            None,  # dv
            alibi_slopes,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size_left,
            ctx.window_size_right,
            ctx.deterministic,
            attn_bias,
            ctx.bias_requires_grad,
            None,  # ds
            ctx.seqlen_k_orig,
            None,  # gen_
            rng_state
        )
        dkv = torch.stack([dk, dv], dim=2)

        # Order: q, kv, dropout_p, softmax_scale, causal, window_size_left,
        # window_size_right, alibi_slopes, deterministic, attn_bias, return_softmax, return_ds.
        return dq, dkv, None, None, None, None, None, None, None, ds, None, None

class FlashAttnFunc(torch.autograd.Function):
    """Autograd function running FlashAttention on separate ``q``, ``k`` and ``v``.

    The forward pass calls the ``fa2::fwd`` custom op and the backward pass calls
    ``fa2::bwd``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        alibi_slopes,
        deterministic,
        attn_bias,
        return_softmax,
        return_ds
    ):
        """Run the forward pass; see ``flash_attn_func`` for argument docs.

        Returns ``out``, or ``(out, softmax_lse, S_dmask)`` when ``return_softmax``.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # The bias is cast to q's dtype before it is handed to the kernel.
        if attn_bias is not None:
            attn_bias = attn_bias.to(q.dtype)

        out, _q_padded, _k_padded, _v_padded, _out_padded, _attn_bias_padded, softmax_lse, S_dmask, rng_state = torch.ops.fa2.fwd(
            q,
            k,
            v,
            None,  # out: no preallocated output buffer
            alibi_slopes,
            dropout_p,
            softmax_scale,
            causal,
            window_size_left,
            window_size_right,
            attn_bias,
            return_softmax and dropout_p > 0,  # S_dmask is only requested with dropout
            None,  # gen_
        )

        ## WORKAROUND a Pytorch bug, should use _padded version of the tensors but this is buggy when passing them directly to save_for_backward
        ## For now, this breaks the backward when headdim is not a multiple of 8 and/or seqlen_q, seqlen_k are not a multiple of 8
        ## TODO: make the padding here instead (i.e. save the *_padded tensors returned above)
        ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes)

        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size_left = window_size_left
        ctx.window_size_right = window_size_right
        ctx.deterministic = deterministic
        # Only compute a bias gradient when a bias exists and the caller asked for it.
        ctx.bias_requires_grad = attn_bias is not None and bool(return_ds)
        ctx.seqlen_k_orig = k.shape[1]

        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Return one gradient per ``forward`` input (``None`` for non-differentiable ones).

        Incoming gradients for ``softmax_lse``/``S_dmask`` (``*args``) are ignored.
        """
        q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes = ctx.saved_tensors

        dout = dout.contiguous()
        dq, dk, dv, ds, _ = torch.ops.fa2.bwd(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None,  # dq
            None,  # dk
            None,  # dv
            alibi_slopes,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size_left,
            ctx.window_size_right,
            ctx.deterministic,
            attn_bias,
            ctx.bias_requires_grad,
            None,  # ds
            ctx.seqlen_k_orig,
            None,  # gen_
            rng_state
        )

        # Order: q, k, v, dropout_p, softmax_scale, causal, window_size_left,
        # window_size_right, alibi_slopes, deterministic, attn_bias, return_softmax, return_ds.
        return dq, dk, dv, None, None, None, None, None, None, None, ds, None, None


def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size_left=-1,
    window_size_right=-1,  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    attn_bias=None,
    return_attn_probs=False,
    return_ds=False
):
    """FlashAttention over a packed QKV tensor. Set dropout_p to 0.0 during evaluation.

    Q, K and V are read from qkv[:, :, 0], qkv[:, :, 1] and qkv[:, :, 2]. In this
    implementation the backward pass computes separate Q, K and V gradients and
    stacks them into dqkv with torch.stack, so the packed layout does not avoid a
    gradient concatenation compared with flash_attn_func.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If (window_size_left, window_size_right) != (-1, -1), implements sliding window
    local attention. Query at position i will only attend to keys between
    [i - window_size_left, i + window_size_right] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size_left: int. Left extent of the sliding window; -1 means unbounded.
        window_size_right: int. Right extent of the sliding window; -1 means unbounded.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        attn_bias: optional tensor whose first two dims are (batch_size_bias, nheads_bias); its
            gradient has shape (batch_size_bias, nheads_bias, seqlen, seqlen).
            NOTE(review): presumably added to the attention scores, with size-1 batch/head
            dims broadcast -- confirm against the CUDA kernel.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        return_ds: bool. If True and attn_bias is given, the backward pass also computes the
            gradient of attn_bias; otherwise attn_bias receives no gradient.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),
            possibly padded. The output of softmax (possibly with different scaling). It also
            encodes the dropout pattern (negative means that location was dropped, nonnegative
            means it was kept). It is None when dropout_p == 0.
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        alibi_slopes,
        deterministic,
        attn_bias,
        return_attn_probs,
        return_ds
    )


def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size_left=-1,
    window_size_right=-1,  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    attn_bias=None,
    return_attn_probs=False,
    return_ds=False
):
    """FlashAttention with a packed KV tensor. Set dropout_p to 0.0 during evaluation.

    K and V are read from kv[:, :, 0] and kv[:, :, 1]. In this implementation the
    backward pass computes separate K and V gradients and stacks them into dkv with
    torch.stack, so the packed layout does not avoid a gradient concatenation
    compared with flash_attn_func.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If (window_size_left, window_size_right) != (-1, -1), implements sliding window
    local attention. Query at position i will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size_left, i + seqlen_k - seqlen_q + window_size_right] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        kv: (batch_size, seqlen_k, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size_left: int. Left extent of the sliding window; -1 means unbounded.
        window_size_right: int. Right extent of the sliding window; -1 means unbounded.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        attn_bias: optional tensor whose first two dims are (batch_size_bias, nheads_bias); its
            gradient has shape (batch_size_bias, nheads_bias, seqlen_q, seqlen_k).
            NOTE(review): presumably added to the attention scores, with size-1 batch/head
            dims broadcast -- confirm against the CUDA kernel.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        return_ds: bool. If True and attn_bias is given, the backward pass also computes the
            gradient of attn_bias; otherwise attn_bias receives no gradient.
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k),
            possibly padded. The output of softmax (possibly with different scaling). It also
            encodes the dropout pattern (negative means that location was dropped, nonnegative
            means it was kept). It is None when dropout_p == 0.
    """
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        alibi_slopes,
        deterministic,
        attn_bias,
        return_attn_probs,
        return_ds
    )


def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size_left=-1,
    window_size_right=-1,  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    attn_bias=None,
    return_attn_probs=False,
    return_ds=False
):
    """FlashAttention on separate Q, K, V tensors. Set dropout_p to 0.0 during evaluation.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If (window_size_left, window_size_right) != (-1, -1), implements sliding window
    local attention. Query at position i will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size_left, i + seqlen_k - seqlen_q + window_size_right] inclusive.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size_left: int. Left extent of the sliding window; -1 means unbounded.
        window_size_right: int. Right extent of the sliding window; -1 means unbounded.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        attn_bias: optional tensor whose first two dims are (batch_size_bias, nheads_bias); it is
            cast to q.dtype before being passed to the kernel, and its gradient has shape
            (batch_size_bias, nheads_bias, seqlen_q, seqlen_k).
            NOTE(review): presumably added to the attention scores, with size-1 batch/head
            dims broadcast -- confirm against the CUDA kernel.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        return_ds: bool. If True and attn_bias is given, the backward pass also computes the
            gradient of attn_bias; otherwise attn_bias receives no gradient.
    Return:
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k),
            possibly padded. The output of softmax (possibly with different scaling). It also
            encodes the dropout pattern (negative means that location was dropped, nonnegative
            means it was kept). It is None when dropout_p == 0.
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        alibi_slopes,
        deterministic,
        attn_bias,
        return_attn_probs,
        return_ds,
    )