radna committed on
Commit
8067ae0
1 Parent(s): dcb1a0c

Update flash_attention.py

Files changed (1)
  1. flash_attention.py +110 -0
flash_attention.py CHANGED
@@ -0,0 +1,110 @@
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+try:
+    from .triton_flash_atn import _attention
+    from .triton_bert_pading import pad_input, unpad_input
+except ImportError:
+    print("FlashAttention is not installed.")
+
+
+class FlashAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
+        attention_dropout: The dropout rate to apply to the attention
+            (default: 0.0)
+    """
+
+    def __init__(
+        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
+    ):
+        super().__init__()
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(
+        self,
+        qkv,
+        key_padding_mask=None,
+        causal=False,
+        cu_seqlens=None,
+        max_s=None,
+        need_weights=False,
+    ):
+        """Implements the multihead softmax attention.
+
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value.
+                (B, S, 3, H, D) if key_padding_mask is None,
+                (nnz, 3, H, D) if unpadded.
+            key_padding_mask: a bool tensor of shape (B, S)
+        """
+        assert not need_weights
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
+        assert qkv.is_cuda
+
+        if cu_seqlens is None:
+            batch_size = qkv.shape[0]
+            seqlen = qkv.shape[1]
+            if key_padding_mask is None:
+                # No padding: flatten (B, S) into one packed sequence and build
+                # the cumulative sequence lengths expected by the kernel.
+                qkv = rearrange(qkv, "b s ... -> (b s) ...")
+                max_s = seqlen
+                cu_seqlens = torch.arange(
+                    0,
+                    (batch_size + 1) * seqlen,
+                    step=seqlen,
+                    dtype=torch.int32,
+                    device=qkv.device,
+                )
+                output = _attention.apply(
+                    qkv,
+                    cu_seqlens,
+                    max_s,
+                    self.dropout_p if self.training else 0.0,
+                    self.softmax_scale,
+                    causal,
+                )
+                output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
+            else:
+                # Padded input: drop the padded tokens, run attention on the
+                # packed tokens, then scatter the result back to (B, S, H, D).
+                nheads = qkv.shape[-2]
+                x = rearrange(qkv, "b s three h d -> b s (three h d)")
+                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+                x_unpad = rearrange(
+                    x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+                )
+                output_unpad = _attention.apply(
+                    x_unpad,
+                    cu_seqlens,
+                    max_s,
+                    self.dropout_p if self.training else 0.0,
+                    self.softmax_scale,
+                    causal,
+                )
+                output = rearrange(
+                    pad_input(
+                        rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+                        indices,
+                        batch_size,
+                        seqlen,
+                    ),
+                    "b s (h d) -> b s h d",
+                    h=nheads,
+                )
+        else:
+            # Caller already supplies packed sequences and cumulative lengths.
+            assert max_s is not None
+            output = _attention.apply(
+                qkv,
+                cu_seqlens,
+                max_s,
+                self.dropout_p if self.training else 0.0,
+                self.softmax_scale,
+                causal,
+            )
+
+        return output, None
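
For reference, a minimal usage sketch of the module added in this commit. It is illustrative only: it assumes the file is importable as flash_attention alongside the triton_flash_atn and triton_bert_pading kernels it depends on, that a CUDA device with fp16 support is available, and that the tensor sizes and variable names below are made up for the example.

# Usage sketch (illustrative, not part of the commit).
# Assumes flash_attention.py and its Triton kernel modules are importable
# and a CUDA device is available.
import torch
from flash_attention import FlashAttention  # assumed import path

attn = FlashAttention(attention_dropout=0.1).cuda().eval()  # dropout is only applied in training mode

B, S, H, D = 2, 128, 12, 64  # batch, sequence length, heads, head dim
qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=torch.float16)
key_padding_mask = torch.ones(B, S, dtype=torch.bool, device="cuda")  # True = token is kept

out, _ = attn(qkv, key_padding_mask=key_padding_mask, causal=True)
print(out.shape)  # torch.Size([2, 128, 12, 64])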