Jackmin108 committed
Commit 622abd4
Parent: 33026dc

Allow flash attn (#1)


- old flash attn (4f24e0fef7d21fdb772aa2981062937a55ab26ab)
- set flash attn as option in config (6f3de154e580bce3f90396153114f28ddc5d672b)
- use Tri Dao's version with backprop (c3b584d86f66bedb3802fc3309fff5d1f04246ca)

Files changed (3)
  1. configuration_bert.py +4 -0
  2. flash_attn_triton.py +1160 -0
  3. modeling_bert.py +23 -1
configuration_bert.py CHANGED
@@ -127,6 +127,8 @@ class JinaBertConfig(PretrainedConfig):
         emb_pooler (`str`, *optional*, defaults to `None`):
             The function to use for pooling the last layer embeddings to get the sentence embeddings.
             Should be one of `None`, `"mean"`.
+        with_flash (`bool`, *optional*, defaults to `False`):
+            Whether to use flash attention. Only works for `triton==2.0.0.dev20230208`
 
     Examples:
 
@@ -164,6 +166,7 @@ class JinaBertConfig(PretrainedConfig):
         classifier_dropout=None,
         feed_forward_type="original",
         emb_pooler=None,
+        with_flash=False,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -185,6 +188,7 @@ class JinaBertConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.feed_forward_type = feed_forward_type
         self.emb_pooler = emb_pooler
+        self.with_flash = with_flash
 
 
 class JinaBertOnnxConfig(OnnxConfig):
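
For reference, a minimal usage sketch of the new flag (not part of the commit): it assumes `configuration_bert.py` is importable on the Python path and that the Triton build mentioned in the docstring is installed at runtime.

# Hypothetical usage sketch for the `with_flash` option added above.
from configuration_bert import JinaBertConfig

config = JinaBertConfig(with_flash=True)  # defaults to False
assert config.with_flash                  # the flash path itself is exercised by modeling_bert.py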
flash_attn_triton.py ADDED
@@ -0,0 +1,1160 @@
1
+ """
2
+ *Experimental* implementation of FlashAttention in Triton.
3
+ Tested with triton==2.0.0.dev20221202.
4
+ Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
5
+ other than 64:
6
+ https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
7
+ We'll update this implementation with the new Triton backend once this is fixed.
8
+
9
+ We use the FlashAttention implementation from Phil Tillet as a starting point.
10
+ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
11
+
12
+ Changes:
13
+ - Implement both causal and non-causal attention.
14
+ - Implement both self-attention and cross-attention.
15
+ - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
16
+ - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
17
+ - Support attention bias.
18
+ - Speed up the forward pass a bit, and only store the LSE instead of m and l.
19
+ - Make the backward for d=128 much faster by reducing register spilling.
20
+ - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
21
+ small batch size * nheads.
22
+
23
+ Caution:
24
+ - This is an *experimental* implementation. The forward pass should be quite robust but
25
+ I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
26
+ - This implementation has only been tested on A100.
27
+ - If you plan to use headdim other than 64 and 128, you should test for race conditions
28
+ (due to the Triton compiler), as done in tests/test_flash_attn.py
29
+ "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
30
+ for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
31
+ that there are none left for other head dimensions.
32
+
33
+ Differences between this Triton version and the CUDA version:
34
+ - Triton version doesn't support dropout.
35
+ - Triton forward is generally faster than CUDA forward, while Triton backward is
36
+ generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
37
+ than CUDA forward + backward.
38
+ - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
39
+ - Triton version supports attention bias, while CUDA version doesn't.
40
+ """
41
+
42
+ import math
43
+
44
+ import torch
45
+ import triton
46
+ import triton.language as tl
47
+
48
+
49
+ # Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
50
+ # @triton.autotune(
51
+ # configs=[
52
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
53
+ # # This config has a race condition when EVEN_M == False, disabling it for now.
54
+ # # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
55
+ # ],
56
+ # key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
57
+ # )
58
+ @triton.heuristics(
59
+ {
60
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
61
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
62
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
63
+ }
64
+ )
65
+ @triton.jit
66
+ def _fwd_kernel(
67
+ Q,
68
+ K,
69
+ V,
70
+ Bias,
71
+ Out,
72
+ Lse,
73
+ TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
74
+ softmax_scale,
75
+ stride_qb,
76
+ stride_qh,
77
+ stride_qm,
78
+ stride_kb,
79
+ stride_kh,
80
+ stride_kn,
81
+ stride_vb,
82
+ stride_vh,
83
+ stride_vn,
84
+ stride_bb,
85
+ stride_bh,
86
+ stride_bm,
87
+ stride_ob,
88
+ stride_oh,
89
+ stride_om,
90
+ nheads,
91
+ seqlen_q,
92
+ seqlen_k,
93
+ seqlen_q_rounded,
94
+ headdim,
95
+ CACHE_KEY_SEQLEN_Q,
96
+ CACHE_KEY_SEQLEN_K,
97
+ BIAS_TYPE: tl.constexpr,
98
+ IS_CAUSAL: tl.constexpr,
99
+ BLOCK_HEADDIM: tl.constexpr,
100
+ EVEN_M: tl.constexpr,
101
+ EVEN_N: tl.constexpr,
102
+ EVEN_HEADDIM: tl.constexpr,
103
+ BLOCK_M: tl.constexpr,
104
+ BLOCK_N: tl.constexpr,
105
+ ):
106
+ start_m = tl.program_id(0)
107
+ off_hb = tl.program_id(1)
108
+ off_b = off_hb // nheads
109
+ off_h = off_hb % nheads
110
+ # off_b = tl.program_id(1)
111
+ # off_h = tl.program_id(2)
112
+ # off_hb = off_b * nheads + off_h
113
+ # initialize offsets
114
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
115
+ offs_n = tl.arange(0, BLOCK_N)
116
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
117
+ # Initialize pointers to Q, K, V
118
+ # Adding parenthesis around indexing might use int32 math instead of int64 math?
119
+ # https://github.com/openai/triton/issues/741
120
+ # I'm seeing a tiny bit of difference (5-7us)
121
+ q_ptrs = (
122
+ Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
123
+ )
124
+ k_ptrs = (
125
+ K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
126
+ )
127
+ v_ptrs = (
128
+ V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
129
+ )
130
+ if BIAS_TYPE == "vector":
131
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
132
+ elif BIAS_TYPE == "matrix":
133
+ b_ptrs = (
134
+ Bias
135
+ + off_b * stride_bb
136
+ + off_h * stride_bh
137
+ + (offs_m[:, None] * stride_bm + offs_n[None, :])
138
+ )
139
+ # initialize pointer to m and l
140
+ t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
141
+ lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
142
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
143
+ acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
144
+ # load q: it will stay in SRAM throughout
145
+ # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
146
+ # tl.load(q_ptrs), we get the wrong output!
147
+ if EVEN_M & EVEN_N:
148
+ if EVEN_HEADDIM:
149
+ q = tl.load(q_ptrs)
150
+ else:
151
+ q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
152
+ else:
153
+ if EVEN_HEADDIM:
154
+ q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
155
+ else:
156
+ q = tl.load(
157
+ q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
158
+ )
159
+ # loop over k, v and update accumulator
160
+ end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
161
+ for start_n in range(0, end_n, BLOCK_N):
162
+ start_n = tl.multiple_of(start_n, BLOCK_N)
163
+ # -- compute qk ----
164
+ if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
165
+ if EVEN_HEADDIM:
166
+ k = tl.load(k_ptrs + start_n * stride_kn)
167
+ else:
168
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
169
+ else:
170
+ if EVEN_HEADDIM:
171
+ k = tl.load(
172
+ k_ptrs + start_n * stride_kn,
173
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
174
+ other=0.0,
175
+ )
176
+ else:
177
+ k = tl.load(
178
+ k_ptrs + start_n * stride_kn,
179
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
180
+ other=0.0,
181
+ )
182
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
183
+ qk += tl.dot(q, k, trans_b=True)
184
+ # Trying to combine the two masks seems to make the result wrong
185
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
186
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
187
+ if IS_CAUSAL:
188
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
189
+ if BIAS_TYPE != "none":
190
+ if BIAS_TYPE == "vector":
191
+ if EVEN_N:
192
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
193
+ else:
194
+ bias = tl.load(
195
+ b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
196
+ ).to(tl.float32)
197
+ bias = bias[None, :]
198
+ elif BIAS_TYPE == "matrix":
199
+ if EVEN_M & EVEN_N:
200
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
201
+ else:
202
+ bias = tl.load(
203
+ b_ptrs + start_n,
204
+ mask=(offs_m[:, None] < seqlen_q)
205
+ & ((start_n + offs_n)[None, :] < seqlen_k),
206
+ other=0.0,
207
+ ).to(tl.float32)
208
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
209
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
210
+ # to multiply with softmax_scale here.
211
+ qk = qk * softmax_scale + bias
212
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
213
+ p = tl.exp(qk - m_ij[:, None])
214
+ else:
215
+ m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
216
+ p = tl.exp(qk * softmax_scale - m_ij[:, None])
217
+ l_ij = tl.sum(p, 1)
218
+
219
+ # scale acc_o
220
+ acc_o_scale = tl.exp(m_i - m_ij)
221
+
222
+ # # -- update output accumulator --
223
+ # BUG: have to store and immediately load
224
+ tl.store(t_ptrs, acc_o_scale)
225
+ acc_o_scale = tl.load(t_ptrs)
226
+ acc_o = acc_o * acc_o_scale[:, None]
227
+ # update acc_o
228
+ if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
229
+ if EVEN_HEADDIM:
230
+ v = tl.load(v_ptrs + start_n * stride_vn)
231
+ else:
232
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
233
+ else:
234
+ if EVEN_HEADDIM:
235
+ v = tl.load(
236
+ v_ptrs + start_n * stride_vn,
237
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
238
+ other=0.0,
239
+ )
240
+ else:
241
+ v = tl.load(
242
+ v_ptrs + start_n * stride_vn,
243
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
244
+ other=0.0,
245
+ )
246
+ p = p.to(v.dtype)
247
+ acc_o += tl.dot(p, v)
248
+
249
+ # -- update statistics
250
+ m_i = m_ij
251
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
252
+ lse_i = m_ij + tl.log(l_i_new)
253
+
254
+ o_scale = tl.exp(m_i - lse_i)
255
+ # BUG: have to store and immediately load
256
+ tl.store(t_ptrs, o_scale)
257
+ o_scale = tl.load(t_ptrs)
258
+ acc_o = acc_o * o_scale[:, None]
259
+ # rematerialize offsets to save registers
260
+ start_m = tl.program_id(0)
261
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
262
+ # write back l and m
263
+ lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
264
+ tl.store(lse_ptrs, lse_i)
265
+ # initialize pointers to output
266
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
267
+ out_ptrs = (
268
+ Out
269
+ + off_b * stride_ob
270
+ + off_h * stride_oh
271
+ + (offs_m[:, None] * stride_om + offs_d[None, :])
272
+ )
273
+ if EVEN_M:
274
+ if EVEN_HEADDIM:
275
+ tl.store(out_ptrs, acc_o)
276
+ else:
277
+ tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
278
+ else:
279
+ if EVEN_HEADDIM:
280
+ tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
281
+ else:
282
+ tl.store(
283
+ out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
284
+ )
285
+
286
+
287
+ @triton.jit
288
+ def _bwd_preprocess_do_o_dot(
289
+ Out,
290
+ DO,
291
+ Delta,
292
+ stride_ob,
293
+ stride_oh,
294
+ stride_om,
295
+ stride_dob,
296
+ stride_doh,
297
+ stride_dom,
298
+ nheads,
299
+ seqlen_q,
300
+ seqlen_q_rounded,
301
+ headdim,
302
+ BLOCK_M: tl.constexpr,
303
+ BLOCK_HEADDIM: tl.constexpr,
304
+ ):
305
+ start_m = tl.program_id(0)
306
+ off_hb = tl.program_id(1)
307
+ off_b = off_hb // nheads
308
+ off_h = off_hb % nheads
309
+ # initialize offsets
310
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
311
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
312
+ # load
313
+ o = tl.load(
314
+ Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
315
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
316
+ other=0.0,
317
+ ).to(tl.float32)
318
+ do = tl.load(
319
+ DO
320
+ + off_b * stride_dob
321
+ + off_h * stride_doh
322
+ + offs_m[:, None] * stride_dom
323
+ + offs_d[None, :],
324
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
325
+ other=0.0,
326
+ ).to(tl.float32)
327
+ delta = tl.sum(o * do, axis=1)
328
+ # write-back
329
+ tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
330
+
331
+
332
+ @triton.jit
333
+ def _bwd_store_dk_dv(
334
+ dk_ptrs,
335
+ dv_ptrs,
336
+ dk,
337
+ dv,
338
+ offs_n,
339
+ offs_d,
340
+ seqlen_k,
341
+ headdim,
342
+ EVEN_M: tl.constexpr,
343
+ EVEN_N: tl.constexpr,
344
+ EVEN_HEADDIM: tl.constexpr,
345
+ ):
346
+ # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
347
+ # if we just call tl.store(dv_ptrs), there's a race condition
348
+ if EVEN_N & EVEN_M:
349
+ if EVEN_HEADDIM:
350
+ tl.store(dv_ptrs, dv)
351
+ tl.store(dk_ptrs, dk)
352
+ else:
353
+ tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
354
+ tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
355
+ else:
356
+ if EVEN_HEADDIM:
357
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
358
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
359
+ else:
360
+ tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
361
+ tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
362
+
363
+
364
+ @triton.jit
365
+ def _bwd_kernel_one_col_block(
366
+ start_n,
367
+ Q,
368
+ K,
369
+ V,
370
+ Bias,
371
+ DO,
372
+ DQ,
373
+ DK,
374
+ DV,
375
+ LSE,
376
+ D,
377
+ softmax_scale,
378
+ stride_qm,
379
+ stride_kn,
380
+ stride_vn,
381
+ stride_bm,
382
+ stride_dom,
383
+ stride_dqm,
384
+ stride_dkn,
385
+ stride_dvn,
386
+ seqlen_q,
387
+ seqlen_k,
388
+ headdim,
389
+ ATOMIC_ADD: tl.constexpr,
390
+ BIAS_TYPE: tl.constexpr,
391
+ IS_CAUSAL: tl.constexpr,
392
+ BLOCK_HEADDIM: tl.constexpr,
393
+ EVEN_M: tl.constexpr,
394
+ EVEN_N: tl.constexpr,
395
+ EVEN_HEADDIM: tl.constexpr,
396
+ BLOCK_M: tl.constexpr,
397
+ BLOCK_N: tl.constexpr,
398
+ ):
399
+ # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
400
+ begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
401
+ # initialize row/col offsets
402
+ offs_qm = begin_m + tl.arange(0, BLOCK_M)
403
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
404
+ offs_m = tl.arange(0, BLOCK_M)
405
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
406
+ # initialize pointers to value-like data
407
+ q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
408
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
409
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
410
+ do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
411
+ dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
412
+ if BIAS_TYPE == "vector":
413
+ b_ptrs = Bias + offs_n
414
+ elif BIAS_TYPE == "matrix":
415
+ b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
416
+ # initialize dv and dk
417
+ dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
418
+ dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
419
+ # There seems to be some problem with Triton pipelining that makes results wrong for
420
+ # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
421
+ # may have zero step, and pipelining with the bias matrix could screw it up.
422
+ # So we just exit early.
423
+ if begin_m >= seqlen_q:
424
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
425
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
426
+ _bwd_store_dk_dv(
427
+ dk_ptrs,
428
+ dv_ptrs,
429
+ dk,
430
+ dv,
431
+ offs_n,
432
+ offs_d,
433
+ seqlen_k,
434
+ headdim,
435
+ EVEN_M=EVEN_M,
436
+ EVEN_N=EVEN_N,
437
+ EVEN_HEADDIM=EVEN_HEADDIM,
438
+ )
439
+ return
440
+ # k and v stay in SRAM throughout
441
+ # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
442
+ # if we just call tl.load(k_ptrs), we get the wrong output!
443
+ if EVEN_N & EVEN_M:
444
+ if EVEN_HEADDIM:
445
+ k = tl.load(k_ptrs)
446
+ v = tl.load(v_ptrs)
447
+ else:
448
+ k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
449
+ v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
450
+ else:
451
+ if EVEN_HEADDIM:
452
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
453
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
454
+ else:
455
+ k = tl.load(
456
+ k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
457
+ )
458
+ v = tl.load(
459
+ v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
460
+ )
461
+ # loop over rows
462
+ num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
463
+ for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
464
+ start_m = tl.multiple_of(start_m, BLOCK_M)
465
+ offs_m_curr = start_m + offs_m
466
+ # load q, k, v, do on-chip
467
+ # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
468
+ if EVEN_M & EVEN_HEADDIM:
469
+ q = tl.load(q_ptrs)
470
+ else:
471
+ if EVEN_HEADDIM:
472
+ q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
473
+ else:
474
+ q = tl.load(
475
+ q_ptrs,
476
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
477
+ other=0.0,
478
+ )
479
+ # recompute p = softmax(qk, dim=-1).T
480
+ qk = tl.dot(q, k, trans_b=True)
481
+ # Trying to combine the two masks seems to make the result wrong
482
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
483
+ qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
484
+ if IS_CAUSAL:
485
+ qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
486
+ if BIAS_TYPE != "none":
487
+ tl.debug_barrier() # Race condition otherwise
488
+ if BIAS_TYPE == "vector":
489
+ if EVEN_N:
490
+ bias = tl.load(b_ptrs).to(tl.float32)
491
+ else:
492
+ bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
493
+ bias = bias[None, :]
494
+ elif BIAS_TYPE == "matrix":
495
+ if EVEN_M & EVEN_N:
496
+ bias = tl.load(b_ptrs).to(tl.float32)
497
+ else:
498
+ bias = tl.load(
499
+ b_ptrs,
500
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
501
+ other=0.0,
502
+ ).to(tl.float32)
503
+ qk = qk * softmax_scale + bias
504
+ # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
505
+ # Also wrong for headdim=64.
506
+ if not (EVEN_M & EVEN_HEADDIM):
507
+ tl.debug_barrier()
508
+ lse_i = tl.load(LSE + offs_m_curr)
509
+ if BIAS_TYPE == "none":
510
+ p = tl.exp(qk * softmax_scale - lse_i[:, None])
511
+ else:
512
+ p = tl.exp(qk - lse_i[:, None])
513
+ # compute dv
514
+ # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
515
+ # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
516
+ # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
517
+ # the output is correct.
518
+ if EVEN_M & EVEN_HEADDIM:
519
+ do = tl.load(do_ptrs)
520
+ else:
521
+ # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
522
+ do = tl.load(
523
+ do_ptrs,
524
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
525
+ other=0.0,
526
+ )
527
+ # if EVEN_M:
528
+ # if EVEN_HEADDIM:
529
+ # do = tl.load(do_ptrs)
530
+ # else:
531
+ # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
532
+ # else:
533
+ # if EVEN_HEADDIM:
534
+ # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
535
+ # else:
536
+ # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
537
+ # & (offs_d[None, :] < headdim), other=0.0)
538
+ dv += tl.dot(p.to(do.dtype), do, trans_a=True)
539
+ # compute dp = dot(v, do)
540
+ # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
541
+ # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
542
+ # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
543
+ if not (EVEN_M & EVEN_HEADDIM):
544
+ tl.debug_barrier()
545
+ dp = tl.dot(do, v, trans_b=True)
546
+ # There's a race condition for headdim=48
547
+ if not EVEN_HEADDIM:
548
+ tl.debug_barrier()
549
+ # compute ds = p * (dp - delta[:, None])
550
+ # Putting the subtraction after the dp matmul (instead of before) is slightly faster
551
+ Di = tl.load(D + offs_m_curr)
552
+ # Converting ds to q.dtype here reduces register pressure and makes it much faster
553
+ # for BLOCK_HEADDIM=128
554
+ ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
555
+ # compute dk = dot(ds.T, q)
556
+ dk += tl.dot(ds, q, trans_a=True)
557
+ # compute dq
558
+ if not (
559
+ EVEN_M & EVEN_HEADDIM
560
+ ): # Otherwise there's a race condition when BIAS_TYPE='matrix'
561
+ tl.debug_barrier()
562
+ if not ATOMIC_ADD:
563
+ if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
564
+ dq = tl.load(dq_ptrs, eviction_policy="evict_last")
565
+ dq += tl.dot(ds, k)
566
+ tl.store(dq_ptrs, dq, eviction_policy="evict_last")
567
+ else:
568
+ if EVEN_HEADDIM:
569
+ dq = tl.load(
570
+ dq_ptrs,
571
+ mask=offs_m_curr[:, None] < seqlen_q,
572
+ other=0.0,
573
+ eviction_policy="evict_last",
574
+ )
575
+ dq += tl.dot(ds, k)
576
+ tl.store(
577
+ dq_ptrs,
578
+ dq,
579
+ mask=offs_m_curr[:, None] < seqlen_q,
580
+ eviction_policy="evict_last",
581
+ )
582
+ else:
583
+ dq = tl.load(
584
+ dq_ptrs,
585
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
586
+ other=0.0,
587
+ eviction_policy="evict_last",
588
+ )
589
+ dq += tl.dot(ds, k)
590
+ tl.store(
591
+ dq_ptrs,
592
+ dq,
593
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
594
+ eviction_policy="evict_last",
595
+ )
596
+ else: # If we're parallelizing across the seqlen_k dimension
597
+ dq = tl.dot(ds, k)
598
+ if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
599
+ tl.atomic_add(dq_ptrs, dq)
600
+ else:
601
+ if EVEN_HEADDIM:
602
+ tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
603
+ else:
604
+ tl.atomic_add(
605
+ dq_ptrs,
606
+ dq,
607
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
608
+ )
609
+ # increment pointers
610
+ dq_ptrs += BLOCK_M * stride_dqm
611
+ q_ptrs += BLOCK_M * stride_qm
612
+ do_ptrs += BLOCK_M * stride_dom
613
+ if BIAS_TYPE == "matrix":
614
+ b_ptrs += BLOCK_M * stride_bm
615
+ # write-back
616
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
617
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
618
+ _bwd_store_dk_dv(
619
+ dk_ptrs,
620
+ dv_ptrs,
621
+ dk,
622
+ dv,
623
+ offs_n,
624
+ offs_d,
625
+ seqlen_k,
626
+ headdim,
627
+ EVEN_M=EVEN_M,
628
+ EVEN_N=EVEN_N,
629
+ EVEN_HEADDIM=EVEN_HEADDIM,
630
+ )
631
+
632
+
633
+ def init_to_zero(name):
634
+ return lambda nargs: nargs[name].zero_()
635
+
636
+
637
+ @triton.autotune(
638
+ configs=[
639
+ triton.Config(
640
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
641
+ num_warps=8,
642
+ num_stages=1,
643
+ pre_hook=init_to_zero("DQ"),
644
+ ),
645
+ triton.Config(
646
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
647
+ num_warps=8,
648
+ num_stages=1,
649
+ pre_hook=init_to_zero("DQ"),
650
+ ),
651
+ # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
652
+ # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
653
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
654
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
655
+ # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
656
+ # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
657
+ ],
658
+ key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
659
+ )
660
+ @triton.heuristics(
661
+ {
662
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
663
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
664
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
665
+ }
666
+ )
667
+ @triton.jit
668
+ def _bwd_kernel(
669
+ Q,
670
+ K,
671
+ V,
672
+ Bias,
673
+ DO,
674
+ DQ,
675
+ DK,
676
+ DV,
677
+ LSE,
678
+ D,
679
+ softmax_scale,
680
+ stride_qb,
681
+ stride_qh,
682
+ stride_qm,
683
+ stride_kb,
684
+ stride_kh,
685
+ stride_kn,
686
+ stride_vb,
687
+ stride_vh,
688
+ stride_vn,
689
+ stride_bb,
690
+ stride_bh,
691
+ stride_bm,
692
+ stride_dob,
693
+ stride_doh,
694
+ stride_dom,
695
+ stride_dqb,
696
+ stride_dqh,
697
+ stride_dqm,
698
+ stride_dkb,
699
+ stride_dkh,
700
+ stride_dkn,
701
+ stride_dvb,
702
+ stride_dvh,
703
+ stride_dvn,
704
+ nheads,
705
+ seqlen_q,
706
+ seqlen_k,
707
+ seqlen_q_rounded,
708
+ headdim,
709
+ CACHE_KEY_SEQLEN_Q,
710
+ CACHE_KEY_SEQLEN_K,
711
+ BIAS_TYPE: tl.constexpr,
712
+ IS_CAUSAL: tl.constexpr,
713
+ BLOCK_HEADDIM: tl.constexpr,
714
+ SEQUENCE_PARALLEL: tl.constexpr,
715
+ EVEN_M: tl.constexpr,
716
+ EVEN_N: tl.constexpr,
717
+ EVEN_HEADDIM: tl.constexpr,
718
+ BLOCK_M: tl.constexpr,
719
+ BLOCK_N: tl.constexpr,
720
+ ):
721
+ off_hb = tl.program_id(1)
722
+ off_b = off_hb // nheads
723
+ off_h = off_hb % nheads
724
+ # offset pointers for batch/head
725
+ Q += off_b * stride_qb + off_h * stride_qh
726
+ K += off_b * stride_kb + off_h * stride_kh
727
+ V += off_b * stride_vb + off_h * stride_vh
728
+ DO += off_b * stride_dob + off_h * stride_doh
729
+ DQ += off_b * stride_dqb + off_h * stride_dqh
730
+ DK += off_b * stride_dkb + off_h * stride_dkh
731
+ DV += off_b * stride_dvb + off_h * stride_dvh
732
+ if BIAS_TYPE != "none":
733
+ Bias += off_b * stride_bb + off_h * stride_bh
734
+ # pointer to row-wise quantities in value-like data
735
+ D += off_hb * seqlen_q_rounded
736
+ LSE += off_hb * seqlen_q_rounded
737
+ if not SEQUENCE_PARALLEL:
738
+ num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
739
+ for start_n in range(0, num_block_n):
740
+ _bwd_kernel_one_col_block(
741
+ start_n,
742
+ Q,
743
+ K,
744
+ V,
745
+ Bias,
746
+ DO,
747
+ DQ,
748
+ DK,
749
+ DV,
750
+ LSE,
751
+ D,
752
+ softmax_scale,
753
+ stride_qm,
754
+ stride_kn,
755
+ stride_vn,
756
+ stride_bm,
757
+ stride_dom,
758
+ stride_dqm,
759
+ stride_dkn,
760
+ stride_dvn,
761
+ seqlen_q,
762
+ seqlen_k,
763
+ headdim,
764
+ ATOMIC_ADD=False,
765
+ BIAS_TYPE=BIAS_TYPE,
766
+ IS_CAUSAL=IS_CAUSAL,
767
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
768
+ EVEN_M=EVEN_M,
769
+ EVEN_N=EVEN_N,
770
+ EVEN_HEADDIM=EVEN_HEADDIM,
771
+ BLOCK_M=BLOCK_M,
772
+ BLOCK_N=BLOCK_N,
773
+ )
774
+ else:
775
+ start_n = tl.program_id(0)
776
+ _bwd_kernel_one_col_block(
777
+ start_n,
778
+ Q,
779
+ K,
780
+ V,
781
+ Bias,
782
+ DO,
783
+ DQ,
784
+ DK,
785
+ DV,
786
+ LSE,
787
+ D,
788
+ softmax_scale,
789
+ stride_qm,
790
+ stride_kn,
791
+ stride_vn,
792
+ stride_bm,
793
+ stride_dom,
794
+ stride_dqm,
795
+ stride_dkn,
796
+ stride_dvn,
797
+ seqlen_q,
798
+ seqlen_k,
799
+ headdim,
800
+ ATOMIC_ADD=True,
801
+ BIAS_TYPE=BIAS_TYPE,
802
+ IS_CAUSAL=IS_CAUSAL,
803
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
804
+ EVEN_M=EVEN_M,
805
+ EVEN_N=EVEN_N,
806
+ EVEN_HEADDIM=EVEN_HEADDIM,
807
+ BLOCK_M=BLOCK_M,
808
+ BLOCK_N=BLOCK_N,
809
+ )
810
+
811
+
812
+ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
813
+ # shape constraints
814
+ batch, seqlen_q, nheads, d = q.shape
815
+ _, seqlen_k, _, _ = k.shape
816
+ assert k.shape == (batch, seqlen_k, nheads, d)
817
+ assert v.shape == (batch, seqlen_k, nheads, d)
818
+ assert d <= 128, "FlashAttention only supports head dimensions up to 128"
819
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
820
+ assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
821
+ assert q.is_cuda and k.is_cuda and v.is_cuda
822
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
823
+
824
+ has_bias = bias is not None
825
+ bias_type = "none"
826
+ if has_bias:
827
+ assert bias.dtype in [q.dtype, torch.float]
828
+ assert bias.is_cuda
829
+ assert bias.dim() == 4
830
+ if bias.stride(-1) != 1:
831
+ bias = bias.contiguous()
832
+ if bias.shape[2:] == (1, seqlen_k):
833
+ bias_type = "vector"
834
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
835
+ bias_type = "matrix"
836
+ else:
837
+ raise RuntimeError(
838
+ "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
839
+ )
840
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
841
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
842
+
843
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
844
+ lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
845
+ tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
846
+ o = torch.empty_like(q)
847
+
848
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
849
+ BLOCK = 128
850
+ num_warps = 4 if d <= 64 else 8
851
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
852
+ _fwd_kernel[grid](
853
+ q,
854
+ k,
855
+ v,
856
+ bias,
857
+ o,
858
+ lse,
859
+ tmp,
860
+ softmax_scale,
861
+ q.stride(0),
862
+ q.stride(2),
863
+ q.stride(1),
864
+ k.stride(0),
865
+ k.stride(2),
866
+ k.stride(1),
867
+ v.stride(0),
868
+ v.stride(2),
869
+ v.stride(1),
870
+ *bias_strides,
871
+ o.stride(0),
872
+ o.stride(2),
873
+ o.stride(1),
874
+ nheads,
875
+ seqlen_q,
876
+ seqlen_k,
877
+ seqlen_q_rounded,
878
+ d,
879
+ seqlen_q // 32,
880
+ seqlen_k // 32, # key for triton cache (limit number of compilations)
881
+ # Can't use kwargs here because triton autotune expects key to be args, not kwargs
882
+ # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
883
+ bias_type,
884
+ causal,
885
+ BLOCK_HEADDIM,
886
+ BLOCK_M=BLOCK,
887
+ BLOCK_N=BLOCK,
888
+ num_warps=num_warps,
889
+ num_stages=1,
890
+ )
891
+ return o, lse, softmax_scale # softmax_scale could have been updated
892
+
893
+
894
+ def _flash_attn_backward(
895
+ do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
896
+ ):
897
+ # Make sure that the last dimension is contiguous
898
+ if do.stride(-1) != 1:
899
+ do = do.contiguous()
900
+ batch, seqlen_q, nheads, d = q.shape
901
+ _, seqlen_k, _, _ = k.shape
902
+ # assert d in {16, 32, 64, 128}
903
+ assert d <= 128
904
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
905
+ assert lse.shape == (batch, nheads, seqlen_q_rounded)
906
+ assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
907
+ assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
908
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
909
+ # dq_accum = torch.zeros_like(q, dtype=torch.float32)
910
+ dq_accum = torch.empty_like(q, dtype=torch.float32)
911
+ delta = torch.empty_like(lse)
912
+ # delta = torch.zeros_like(lse)
913
+
914
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
915
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
916
+ _bwd_preprocess_do_o_dot[grid](
917
+ o,
918
+ do,
919
+ delta,
920
+ o.stride(0),
921
+ o.stride(2),
922
+ o.stride(1),
923
+ do.stride(0),
924
+ do.stride(2),
925
+ do.stride(1),
926
+ nheads,
927
+ seqlen_q,
928
+ seqlen_q_rounded,
929
+ d,
930
+ BLOCK_M=128,
931
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
932
+ )
933
+
934
+ has_bias = bias is not None
935
+ bias_type = "none"
936
+ if has_bias:
937
+ assert bias.dtype in [q.dtype, torch.float]
938
+ assert bias.is_cuda
939
+ assert bias.dim() == 4
940
+ assert bias.stride(-1) == 1
941
+ if bias.shape[2:] == (1, seqlen_k):
942
+ bias_type = "vector"
943
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
944
+ bias_type = "matrix"
945
+ else:
946
+ raise RuntimeError(
947
+ "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
948
+ )
949
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
950
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
951
+
952
+ # BLOCK_M = 128
953
+ # BLOCK_N = 64
954
+ # num_warps = 4
955
+ grid = lambda META: (
956
+ triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
957
+ batch * nheads,
958
+ )
959
+ _bwd_kernel[grid](
960
+ q,
961
+ k,
962
+ v,
963
+ bias,
964
+ do,
965
+ dq_accum,
966
+ dk,
967
+ dv,
968
+ lse,
969
+ delta,
970
+ softmax_scale,
971
+ q.stride(0),
972
+ q.stride(2),
973
+ q.stride(1),
974
+ k.stride(0),
975
+ k.stride(2),
976
+ k.stride(1),
977
+ v.stride(0),
978
+ v.stride(2),
979
+ v.stride(1),
980
+ *bias_strides,
981
+ do.stride(0),
982
+ do.stride(2),
983
+ do.stride(1),
984
+ dq_accum.stride(0),
985
+ dq_accum.stride(2),
986
+ dq_accum.stride(1),
987
+ dk.stride(0),
988
+ dk.stride(2),
989
+ dk.stride(1),
990
+ dv.stride(0),
991
+ dv.stride(2),
992
+ dv.stride(1),
993
+ nheads,
994
+ seqlen_q,
995
+ seqlen_k,
996
+ seqlen_q_rounded,
997
+ d,
998
+ seqlen_q // 32,
999
+ seqlen_k // 32, # key for triton cache (limit number of compilations)
1000
+ # Can't use kwargs here because triton autotune expects key to be args, not kwargs
1001
+ # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
1002
+ bias_type,
1003
+ causal,
1004
+ BLOCK_HEADDIM,
1005
+ # SEQUENCE_PARALLEL=False,
1006
+ # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
1007
+ # num_warps=num_warps,
1008
+ # num_stages=1,
1009
+ )
1010
+ dq.copy_(dq_accum)
1011
+
1012
+
1013
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
1014
+ @staticmethod
1015
+ def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
1016
+ """
1017
+ qkv: (batch, seqlen, 3, nheads, headdim)
1018
+ bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
1019
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
1020
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
1021
+ """
1022
+ # Make sure that the last dimension is contiguous
1023
+ if qkv.stride(-1) != 1:
1024
+ qkv = qkv.contiguous()
1025
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
1026
+ qkv[:, :, 0],
1027
+ qkv[:, :, 1],
1028
+ qkv[:, :, 2],
1029
+ bias=bias,
1030
+ causal=causal,
1031
+ softmax_scale=softmax_scale,
1032
+ )
1033
+ ctx.save_for_backward(qkv, o, lse, bias)
1034
+ ctx.causal = causal
1035
+ return o
1036
+
1037
+ @staticmethod
1038
+ def backward(ctx, do):
1039
+ qkv, o, lse, bias = ctx.saved_tensors
1040
+ assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
1041
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
1042
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
1043
+ with torch.inference_mode():
1044
+ dqkv = torch.empty_like(qkv)
1045
+ _flash_attn_backward(
1046
+ do,
1047
+ qkv[:, :, 0],
1048
+ qkv[:, :, 1],
1049
+ qkv[:, :, 2],
1050
+ o,
1051
+ lse,
1052
+ dqkv[:, :, 0],
1053
+ dqkv[:, :, 1],
1054
+ dqkv[:, :, 2],
1055
+ bias=bias,
1056
+ causal=ctx.causal,
1057
+ softmax_scale=ctx.softmax_scale,
1058
+ )
1059
+ return dqkv, None, None, None
1060
+
1061
+
1062
+ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
1063
+
1064
+
1065
+ class FlashAttnKVPackedFunc(torch.autograd.Function):
1066
+ @staticmethod
1067
+ def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
1068
+ """
1069
+ q: (batch, seqlen_q, nheads, headdim)
1070
+ kv: (batch, seqlen_k, 2, nheads, headdim)
1071
+ bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
1072
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
1073
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
1074
+ """
1075
+ # Make sure that the last dimension is contiguous
1076
+ q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
1077
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
1078
+ q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
1079
+ )
1080
+ ctx.save_for_backward(q, kv, o, lse, bias)
1081
+ ctx.causal = causal
1082
+ return o
1083
+
1084
+ @staticmethod
1085
+ def backward(ctx, do):
1086
+ q, kv, o, lse, bias = ctx.saved_tensors
1087
+ if len(ctx.needs_input_grad) >= 3:
1088
+ assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
1089
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
1090
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
1091
+ with torch.inference_mode():
1092
+ dq = torch.empty_like(q)
1093
+ dkv = torch.empty_like(kv)
1094
+ _flash_attn_backward(
1095
+ do,
1096
+ q,
1097
+ kv[:, :, 0],
1098
+ kv[:, :, 1],
1099
+ o,
1100
+ lse,
1101
+ dq,
1102
+ dkv[:, :, 0],
1103
+ dkv[:, :, 1],
1104
+ bias=bias,
1105
+ causal=ctx.causal,
1106
+ softmax_scale=ctx.softmax_scale,
1107
+ )
1108
+ return dq, dkv, None, None, None
1109
+
1110
+
1111
+ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
1112
+
1113
+
1114
+ class FlashAttnFunc(torch.autograd.Function):
1115
+ @staticmethod
1116
+ def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
1117
+ """
1118
+ q: (batch_size, seqlen_q, nheads, headdim)
1119
+ k, v: (batch_size, seqlen_k, nheads, headdim)
1120
+ bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
1121
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
1122
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
1123
+ """
1124
+ # Make sure that the last dimension is contiguous
1125
+ q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
1126
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
1127
+ q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
1128
+ )
1129
+ ctx.save_for_backward(q, k, v, o, lse, bias)
1130
+ ctx.causal = causal
1131
+ return o
1132
+
1133
+ @staticmethod
1134
+ def backward(ctx, do):
1135
+ q, k, v, o, lse, bias = ctx.saved_tensors
1136
+ assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
1137
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
1138
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
1139
+ with torch.inference_mode():
1140
+ dq = torch.empty_like(q)
1141
+ dk = torch.empty_like(k)
1142
+ dv = torch.empty_like(v)
1143
+ _flash_attn_backward(
1144
+ do,
1145
+ q,
1146
+ k,
1147
+ v,
1148
+ o,
1149
+ lse,
1150
+ dq,
1151
+ dk,
1152
+ dv,
1153
+ bias=bias,
1154
+ causal=ctx.causal,
1155
+ softmax_scale=ctx.softmax_scale,
1156
+ )
1157
+ return dq, dk, dv, None, None, None
1158
+
1159
+
1160
+ flash_attn_func = FlashAttnFunc.apply
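
To make the expected call signature concrete, here is a hedged sketch (not part of the commit) of calling `flash_attn_func` directly. The shapes, dtypes, and bias layout follow the module's own docstrings and assertions: fp16/bf16 CUDA tensors of shape (batch, seqlen, nheads, headdim), head dimension at most 128, and an optional bias broadcastable to (batch, nheads, seqlen_q, seqlen_k). The concrete sizes below are arbitrary.

# Illustrative only: exercise flash_attn_func with docstring-conforming inputs.
# Assumes flash_attn_triton.py is importable and a CUDA GPU plus the Triton
# version noted in the module docstring are available.
import torch
from flash_attn_triton import flash_attn_func

batch, seqlen, nheads, headdim = 2, 512, 12, 64  # headdim must be <= 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Optional additive bias (e.g. ALiBi-style), broadcastable to
# (batch, nheads, seqlen_q, seqlen_k); a full "matrix" bias of zeros here.
bias = torch.zeros(1, nheads, seqlen, seqlen, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, bias)  # causal and softmax_scale keep their defaults
print(out.shape)  # torch.Size([2, 512, 12, 64])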
modeling_bert.py CHANGED
@@ -55,6 +55,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from .configuration_bert import JinaBertConfig
+try:
+    from .flash_attn_triton import flash_attn_func
+except Exception:
+    flash_attn_func = None
 
 try:
     from tqdm.autonotebook import trange
@@ -281,7 +285,7 @@ class JinaBertEmbeddings(nn.Module):
 
 
 class JinaBertSelfAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
+    def __init__(self, config: JinaBertConfig, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
             config, "embedding_size"
@@ -290,6 +294,13 @@ class JinaBertSelfAttention(nn.Module):
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads})"
             )
+
+        self.with_flash = config.with_flash
+        if self.with_flash:
+            if flash_attn_func is None:
+                raise ValueError(
+                    f"flash_attn_func is None, please install flash_attn_triton"
+                )
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -333,6 +344,17 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
+        if self.with_flash:
+            b, s, h = hidden_states.shape
+            q = self.query(hidden_states)
+            k = self.key(hidden_states)
+            v = self.value(hidden_states)
+            # B x S x hidden_dim -> B x S x num_heads x head_dim
+            q = q.view(b, s, self.num_attention_heads, self.attention_head_size)
+            k = k.view(b, s, self.num_attention_heads, self.attention_head_size)
+            v = v.view(b, s, self.num_attention_heads, self.attention_head_size)
+            attn = flash_attn_func(q, k, v, bias)
+            return (attn.view(b, s, h),)
         mixed_query_layer = self.query(hidden_states)
 
         # If this is instantiated as a cross-attention module, the keys
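
Finally, a hedged end-to-end sketch (not part of the commit) of the new branch in `JinaBertSelfAttention.forward`: the layer is built from a config with `with_flash=True` and run in fp16 on CUDA so the kernel's dtype and device assertions hold. The package name `jina_bert` is a placeholder for however this repository is packaged (the modules use relative imports), and the hidden size comes from the usual BERT-style config defaults rather than anything shown in this diff.

# Hypothetical smoke test for the flash-attention branch added above.
# "jina_bert" is a placeholder package name; a compatible triton build must be
# installed, otherwise with_flash=True raises ValueError in __init__.
import torch
from jina_bert.configuration_bert import JinaBertConfig
from jina_bert.modeling_bert import JinaBertSelfAttention

config = JinaBertConfig(with_flash=True)
attn = JinaBertSelfAttention(config).to(device="cuda", dtype=torch.float16).eval()

hidden_states = torch.randn(2, 128, config.hidden_size, device="cuda", dtype=torch.float16)
with torch.no_grad():
    (context,) = attn(hidden_states)  # flash path returns a 1-tuple; bias defaults to None
print(context.shape)  # (2, 128, config.hidden_size)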