larryvrh committed on
Commit
2a566c9
1 Parent(s): 3c083f2

Initial commit

Files changed (5)
  1. flash_attn_triton.py +861 -0
  2. llama_vocab_pruned_32k.json +0 -0
  3. modeling.py +255 -0
  4. tokenizers.py +244 -0
  5. webui.py +142 -0
flash_attn_triton.py ADDED
@@ -0,0 +1,861 @@
1
+ """
2
+ Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
3
+ update imports to use 'triton_pre_mlir'
4
+
5
+ *Experimental* implementation of FlashAttention in Triton.
6
+ Tested with triton==2.0.0.dev20221202.
7
+ Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
8
+ other than 64:
9
+ https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
10
+ We'll update this implementation with the new Triton backend once this is fixed.
11
+
12
+ We use the FlashAttention implementation from Phil Tillet as a starting point.
13
+ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
14
+
15
+ Changes:
16
+ - Implement both causal and non-causal attention.
17
+ - Implement both self-attention and cross-attention.
18
+ - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
19
+ - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
20
+ - Support attention bias.
21
+ - Speed up the forward pass a bit, and only store the LSE instead of m and l.
22
+ - Make the backward for d=128 much faster by reducing register spilling.
23
+ - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
24
+ small batch size * nheads.
25
+
26
+ Caution:
27
+ - This is an *experimental* implementation. The forward pass should be quite robust but
28
+ I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
29
+ - This implementation has only been tested on A100.
30
+ - If you plan to use headdim other than 64 and 128, you should test for race conditions
31
+ (due to the Triton compiler), as done in tests/test_flash_attn.py
32
+ "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
33
+ for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
34
+ that there are none left for other head dimensions.
35
+
36
+ Differences between this Triton version and the CUDA version:
37
+ - Triton version doesn't support dropout.
38
+ - Triton forward is generally faster than CUDA forward, while Triton backward is
39
+ generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
40
+ than CUDA forward + backward.
41
+ - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
42
+ - Triton version supports attention bias, while CUDA version doesn't.
43
+ """
44
+
45
+ import math
46
+
47
+ import torch
48
+ import os
49
+
50
+ import triton_pre_mlir as triton
51
+ import triton_pre_mlir.compiler
52
+ import triton_pre_mlir.language as tl
53
+ import functools
54
+ import subprocess
55
+
56
+ if 'CONDA_PREFIX' in os.environ and 'CUDA_HOME' not in os.environ:
57
+ os.environ['CUDA_HOME'] = os.environ['CONDA_PREFIX']
58
+
59
+
60
+ @functools.lru_cache()
61
+ def libcuda_dirs():
62
+ libs = subprocess.check_output(["ldconfig", "-p"]).decode()
63
+ # each line looks like the following:
64
+ # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
65
+ locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so" in line]
66
+ dirs = [os.path.dirname(loc) for loc in locs]
67
+ msg = 'libcuda.so cannot be found!\n'
68
+ if locs:
69
+ msg += 'Possible files are located at %s.' % str(locs)
70
+ msg += 'Please create a symlink of libcuda.so to one of these files.'
71
+ assert any(os.path.exists(os.path.join(path, 'libcuda.so')) for path in dirs), msg
72
+ return dirs
73
+
74
+
75
+ triton_pre_mlir.compiler.libcuda_dirs = libcuda_dirs
76
+
77
+
78
+ # Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
79
+ # @triton.autotune(
80
+ # configs=[
81
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
82
+ # # This config has a race condition when EVEN_M == False, disabling it for now.
83
+ # # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
84
+ # ],
85
+ # key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
86
+ # )
87
+ @triton.heuristics(
88
+ {
89
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
90
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
91
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
92
+ }
93
+ )
94
+ @triton.jit
95
+ def _fwd_kernel(
96
+ Q, K, V, Bias, Out,
97
+ Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
98
+ softmax_scale,
99
+ stride_qb, stride_qh, stride_qm,
100
+ stride_kb, stride_kh, stride_kn,
101
+ stride_vb, stride_vh, stride_vn,
102
+ stride_bb, stride_bh, stride_bm,
103
+ stride_ob, stride_oh, stride_om,
104
+ nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
105
+ CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
106
+ BIAS_TYPE: tl.constexpr,
107
+ IS_CAUSAL: tl.constexpr,
108
+ BLOCK_HEADDIM: tl.constexpr,
109
+ EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
110
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
111
+ ):
112
+ start_m = tl.program_id(0)
113
+ off_hb = tl.program_id(1)
114
+ off_b = off_hb // nheads
115
+ off_h = off_hb % nheads
116
+ # off_b = tl.program_id(1)
117
+ # off_h = tl.program_id(2)
118
+ # off_hb = off_b * nheads + off_h
119
+ # initialize offsets
120
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
121
+ offs_n = tl.arange(0, BLOCK_N)
122
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
123
+ # Initialize pointers to Q, K, V
124
+ # Adding parentheses around indexing might use int32 math instead of int64 math?
125
+ # https://github.com/openai/triton/issues/741
126
+ # I'm seeing a tiny bit of difference (5-7us)
127
+ q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
128
+ k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
129
+ v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
130
+ if BIAS_TYPE == 'vector':
131
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
132
+ elif BIAS_TYPE == 'matrix':
133
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
134
+ # initialize pointer to m and l
135
+ t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
136
+ lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
137
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
138
+ acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
139
+ # load q: it will stay in SRAM throughout
140
+ # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
141
+ # tl.load(q_ptrs), we get the wrong output!
142
+ if EVEN_M & EVEN_N:
143
+ if EVEN_HEADDIM:
144
+ q = tl.load(q_ptrs)
145
+ else:
146
+ q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
147
+ else:
148
+ if EVEN_HEADDIM:
149
+ q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
150
+ else:
151
+ q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
152
+ other=0.0)
153
+ # loop over k, v and update accumulator
154
+ end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
155
+ for start_n in range(0, end_n, BLOCK_N):
156
+ start_n = tl.multiple_of(start_n, BLOCK_N)
157
+ # -- compute qk ----
158
+ if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
159
+ if EVEN_HEADDIM:
160
+ k = tl.load(k_ptrs + start_n * stride_kn)
161
+ else:
162
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
163
+ else:
164
+ if EVEN_HEADDIM:
165
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
166
+ other=0.0)
167
+ else:
168
+ k = tl.load(k_ptrs + start_n * stride_kn,
169
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
170
+ other=0.0)
171
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
172
+ qk += tl.dot(q, k, trans_b=True)
173
+ # Trying to combine the two masks seems to make the result wrong
174
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
175
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
176
+ if IS_CAUSAL:
177
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
178
+ if BIAS_TYPE != 'none':
179
+ if BIAS_TYPE == 'vector':
180
+ if EVEN_N:
181
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
182
+ else:
183
+ bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
184
+ bias = bias[None, :]
185
+ elif BIAS_TYPE == 'matrix':
186
+ if EVEN_M & EVEN_N:
187
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
188
+ else:
189
+ bias = tl.load(b_ptrs + start_n,
190
+ mask=(offs_m[:, None] < seqlen_q)
191
+ & ((start_n + offs_n)[None, :] < seqlen_k),
192
+ other=0.0).to(tl.float32)
193
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
194
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
195
+ # multiply with softmax_scale here.
196
+ qk = qk * softmax_scale + bias
197
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
198
+ p = tl.exp(qk - m_ij[:, None])
199
+ else:
200
+ m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
201
+ p = tl.exp(qk * softmax_scale - m_ij[:, None])
202
+ l_ij = tl.sum(p, 1)
203
+
204
+ # scale acc_o
205
+ acc_o_scale = tl.exp(m_i - m_ij)
206
+
207
+ # # -- update output accumulator --
208
+ # BUG: have to store and immediately load
209
+ tl.store(t_ptrs, acc_o_scale)
210
+ acc_o_scale = tl.load(t_ptrs)
211
+ acc_o = acc_o * acc_o_scale[:, None]
212
+ # update acc_o
213
+ if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
214
+ if EVEN_HEADDIM:
215
+ v = tl.load(v_ptrs + start_n * stride_vn)
216
+ else:
217
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
218
+ else:
219
+ if EVEN_HEADDIM:
220
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
221
+ other=0.0)
222
+ else:
223
+ v = tl.load(v_ptrs + start_n * stride_vn,
224
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
225
+ other=0.0)
226
+ p = p.to(v.dtype)
227
+ acc_o += tl.dot(p, v)
228
+
229
+ # -- update statistics
230
+ m_i = m_ij
231
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
232
+ lse_i = m_ij + tl.log(l_i_new)
233
+
234
+ o_scale = tl.exp(m_i - lse_i)
235
+ # BUG: have to store and immediately load
236
+ tl.store(t_ptrs, o_scale)
237
+ o_scale = tl.load(t_ptrs)
238
+ acc_o = acc_o * o_scale[:, None]
239
+ # rematerialize offsets to save registers
240
+ start_m = tl.program_id(0)
241
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
242
+ # write back l and m
243
+ lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
244
+ tl.store(lse_ptrs, lse_i)
245
+ # initialize pointers to output
246
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
247
+ out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
248
+ if EVEN_M:
249
+ if EVEN_HEADDIM:
250
+ tl.store(out_ptrs, acc_o)
251
+ else:
252
+ tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
253
+ else:
254
+ if EVEN_HEADDIM:
255
+ tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
256
+ else:
257
+ tl.store(out_ptrs, acc_o,
258
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
259
+
260
+
261
+ @triton.jit
262
+ def _bwd_preprocess_do_o_dot(
263
+ Out, DO, Delta,
264
+ stride_ob, stride_oh, stride_om,
265
+ stride_dob, stride_doh, stride_dom,
266
+ nheads, seqlen_q, seqlen_q_rounded, headdim,
267
+ BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
268
+ ):
269
+ start_m = tl.program_id(0)
270
+ off_hb = tl.program_id(1)
271
+ off_b = off_hb // nheads
272
+ off_h = off_hb % nheads
273
+ # initialize offsets
274
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
275
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
276
+ # load
277
+ o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
278
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
279
+ do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
280
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
281
+ delta = tl.sum(o * do, axis=1)
282
+ # write-back
283
+ tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
284
+
285
+
286
+ @triton.jit
287
+ def _bwd_store_dk_dv(
288
+ dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
289
+ EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
290
+ ):
291
+ # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
292
+ # if we just call tl.store(dv_ptrs), there's a race condition
293
+ if EVEN_N & EVEN_M:
294
+ if EVEN_HEADDIM:
295
+ tl.store(dv_ptrs, dv)
296
+ tl.store(dk_ptrs, dk)
297
+ else:
298
+ tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
299
+ tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
300
+ else:
301
+ if EVEN_HEADDIM:
302
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
303
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
304
+ else:
305
+ tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
306
+ tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
307
+
308
+
309
+ @triton.jit
310
+ def _bwd_kernel_one_col_block(
311
+ start_n,
312
+ Q, K, V, Bias,
313
+ DO, DQ, DK, DV,
314
+ LSE, D,
315
+ softmax_scale,
316
+ stride_qm, stride_kn, stride_vn, stride_bm,
317
+ stride_dom, stride_dqm, stride_dkn, stride_dvn,
318
+ seqlen_q, seqlen_k, headdim,
319
+ ATOMIC_ADD: tl.constexpr,
320
+ BIAS_TYPE: tl.constexpr,
321
+ IS_CAUSAL: tl.constexpr,
322
+ BLOCK_HEADDIM: tl.constexpr,
323
+ EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
324
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
325
+ ):
326
+ # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
327
+ begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
328
+ # initialize row/col offsets
329
+ offs_qm = begin_m + tl.arange(0, BLOCK_M)
330
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
331
+ offs_m = tl.arange(0, BLOCK_M)
332
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
333
+ # initialize pointers to value-like data
334
+ q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
335
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
336
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
337
+ do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
338
+ dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
339
+ if BIAS_TYPE == 'vector':
340
+ b_ptrs = Bias + offs_n
341
+ elif BIAS_TYPE == 'matrix':
342
+ b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
343
+ # initialize dv and dk
344
+ dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
345
+ dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
346
+ # There seems to be some problem with Triton pipelining that makes results wrong for
347
+ # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
348
+ # may have zero step, and pipelining with the bias matrix could screw it up.
349
+ # So we just exit early.
350
+ if begin_m >= seqlen_q:
351
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
352
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
353
+ _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
354
+ EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
355
+ return
356
+ # k and v stay in SRAM throughout
357
+ # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
358
+ # if we just call tl.load(k_ptrs), we get the wrong output!
359
+ if EVEN_N & EVEN_M:
360
+ if EVEN_HEADDIM:
361
+ k = tl.load(k_ptrs)
362
+ v = tl.load(v_ptrs)
363
+ else:
364
+ k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
365
+ v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
366
+ else:
367
+ if EVEN_HEADDIM:
368
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
369
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
370
+ else:
371
+ k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
372
+ other=0.0)
373
+ v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
374
+ other=0.0)
375
+ # loop over rows
376
+ num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
377
+ for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
378
+ start_m = tl.multiple_of(start_m, BLOCK_M)
379
+ offs_m_curr = start_m + offs_m
380
+ # load q, k, v, do on-chip
381
+ # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
382
+ if EVEN_M & EVEN_HEADDIM:
383
+ q = tl.load(q_ptrs)
384
+ else:
385
+ if EVEN_HEADDIM:
386
+ q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
387
+ else:
388
+ q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
389
+ & (offs_d[None, :] < headdim), other=0.0)
390
+ # recompute p = softmax(qk, dim=-1).T
391
+ qk = tl.dot(q, k, trans_b=True)
392
+ # Trying to combine the two masks seems to make the result wrong
393
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
394
+ qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
395
+ if IS_CAUSAL:
396
+ qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
397
+ if BIAS_TYPE != 'none':
398
+ tl.debug_barrier() # Race condition otherwise
399
+ if BIAS_TYPE == 'vector':
400
+ if EVEN_N:
401
+ bias = tl.load(b_ptrs).to(tl.float32)
402
+ else:
403
+ bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
404
+ bias = bias[None, :]
405
+ elif BIAS_TYPE == 'matrix':
406
+ if EVEN_M & EVEN_N:
407
+ bias = tl.load(b_ptrs).to(tl.float32)
408
+ else:
409
+ bias = tl.load(b_ptrs,
410
+ mask=(offs_m_curr[:, None] < seqlen_q)
411
+ & (offs_n[None, :] < seqlen_k),
412
+ other=0.0).to(tl.float32)
413
+ qk = qk * softmax_scale + bias
414
+ # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
415
+ # Also wrong for headdim=64.
416
+ if not (EVEN_M & EVEN_HEADDIM):
417
+ tl.debug_barrier()
418
+ lse_i = tl.load(LSE + offs_m_curr)
419
+ if BIAS_TYPE == 'none':
420
+ p = tl.exp(qk * softmax_scale - lse_i[:, None])
421
+ else:
422
+ p = tl.exp(qk - lse_i[:, None])
423
+ # compute dv
424
+ # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
425
+ # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
426
+ # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
427
+ # the output is correct.
428
+ if EVEN_M & EVEN_HEADDIM:
429
+ do = tl.load(do_ptrs)
430
+ else:
431
+ # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
432
+ do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
433
+ & (offs_d[None, :] < headdim), other=0.0)
434
+ # if EVEN_M:
435
+ # if EVEN_HEADDIM:
436
+ # do = tl.load(do_ptrs)
437
+ # else:
438
+ # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
439
+ # else:
440
+ # if EVEN_HEADDIM:
441
+ # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
442
+ # else:
443
+ # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
444
+ # & (offs_d[None, :] < headdim), other=0.0)
445
+ dv += tl.dot(p.to(do.dtype), do, trans_a=True)
446
+ # compute dp = dot(v, do)
447
+ # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
448
+ # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
449
+ # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
450
+ if not (EVEN_M & EVEN_HEADDIM):
451
+ tl.debug_barrier()
452
+ dp = tl.dot(do, v, trans_b=True)
453
+ # There's a race condition for headdim=48
454
+ if not EVEN_HEADDIM:
455
+ tl.debug_barrier()
456
+ # compute ds = p * (dp - delta[:, None])
457
+ # Putting the subtraction after the dp matmul (instead of before) is slightly faster
458
+ Di = tl.load(D + offs_m_curr)
459
+ # Converting ds to q.dtype here reduces register pressure and makes it much faster
460
+ # for BLOCK_HEADDIM=128
461
+ ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
462
+ # compute dk = dot(ds.T, q)
463
+ dk += tl.dot(ds, q, trans_a=True)
464
+ # compute dq
465
+ if not (EVEN_M & EVEN_HEADDIM): # Otherwise there's a race condition when BIAS_TYPE='matrix'
466
+ tl.debug_barrier()
467
+ if not ATOMIC_ADD:
468
+ if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
469
+ dq = tl.load(dq_ptrs, eviction_policy="evict_last")
470
+ dq += tl.dot(ds, k)
471
+ tl.store(dq_ptrs, dq, eviction_policy="evict_last")
472
+ else:
473
+ if EVEN_HEADDIM:
474
+ dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
475
+ eviction_policy="evict_last")
476
+ dq += tl.dot(ds, k)
477
+ tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
478
+ eviction_policy="evict_last")
479
+ else:
480
+ dq = tl.load(dq_ptrs,
481
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
482
+ other=0.0, eviction_policy="evict_last")
483
+ dq += tl.dot(ds, k)
484
+ tl.store(dq_ptrs, dq,
485
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
486
+ eviction_policy="evict_last")
487
+ else: # If we're parallelizing across the seqlen_k dimension
488
+ dq = tl.dot(ds, k)
489
+ if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
490
+ tl.atomic_add(dq_ptrs, dq)
491
+ else:
492
+ if EVEN_HEADDIM:
493
+ tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
494
+ else:
495
+ tl.atomic_add(dq_ptrs, dq,
496
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
497
+ # increment pointers
498
+ dq_ptrs += BLOCK_M * stride_dqm
499
+ q_ptrs += BLOCK_M * stride_qm
500
+ do_ptrs += BLOCK_M * stride_dom
501
+ if BIAS_TYPE == 'matrix':
502
+ b_ptrs += BLOCK_M * stride_bm
503
+ # write-back
504
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
505
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
506
+ _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
507
+ EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
508
+
509
+
510
+ def init_to_zero(name):
511
+ return lambda nargs: nargs[name].zero_()
512
+
513
+
514
+ # TODO: Change BLOCK_M and BLOCK_N according to your GPU and num_warps according to headdim
515
+ @triton.autotune(
516
+ configs=[
517
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
518
+ triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
519
+ # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
520
+ # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
521
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
522
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
523
+ # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
524
+ # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
525
+ ],
526
+ key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
527
+ )
528
+ @triton.heuristics(
529
+ {
530
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
531
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
532
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
533
+ }
534
+ )
535
+ @triton.jit
536
+ def _bwd_kernel(
537
+ Q, K, V, Bias,
538
+ DO, DQ, DK, DV,
539
+ LSE, D,
540
+ softmax_scale,
541
+ stride_qb, stride_qh, stride_qm,
542
+ stride_kb, stride_kh, stride_kn,
543
+ stride_vb, stride_vh, stride_vn,
544
+ stride_bb, stride_bh, stride_bm,
545
+ stride_dob, stride_doh, stride_dom,
546
+ stride_dqb, stride_dqh, stride_dqm,
547
+ stride_dkb, stride_dkh, stride_dkn,
548
+ stride_dvb, stride_dvh, stride_dvn,
549
+ nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
550
+ CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
551
+ BIAS_TYPE: tl.constexpr,
552
+ IS_CAUSAL: tl.constexpr,
553
+ BLOCK_HEADDIM: tl.constexpr,
554
+ SEQUENCE_PARALLEL: tl.constexpr,
555
+ EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
556
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
557
+ ):
558
+ off_hb = tl.program_id(1)
559
+ off_b = off_hb // nheads
560
+ off_h = off_hb % nheads
561
+ # offset pointers for batch/head
562
+ Q += off_b * stride_qb + off_h * stride_qh
563
+ K += off_b * stride_kb + off_h * stride_kh
564
+ V += off_b * stride_vb + off_h * stride_vh
565
+ DO += off_b * stride_dob + off_h * stride_doh
566
+ DQ += off_b * stride_dqb + off_h * stride_dqh
567
+ DK += off_b * stride_dkb + off_h * stride_dkh
568
+ DV += off_b * stride_dvb + off_h * stride_dvh
569
+ if BIAS_TYPE != 'none':
570
+ Bias += off_b * stride_bb + off_h * stride_bh
571
+ # pointer to row-wise quantities in value-like data
572
+ D += off_hb * seqlen_q_rounded
573
+ LSE += off_hb * seqlen_q_rounded
574
+ if not SEQUENCE_PARALLEL:
575
+ num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
576
+ for start_n in range(0, num_block_n):
577
+ _bwd_kernel_one_col_block(
578
+ start_n,
579
+ Q, K, V, Bias,
580
+ DO, DQ, DK, DV,
581
+ LSE, D,
582
+ softmax_scale,
583
+ stride_qm, stride_kn, stride_vn, stride_bm,
584
+ stride_dom, stride_dqm, stride_dkn, stride_dvn,
585
+ seqlen_q, seqlen_k, headdim,
586
+ ATOMIC_ADD=False,
587
+ BIAS_TYPE=BIAS_TYPE,
588
+ IS_CAUSAL=IS_CAUSAL,
589
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
590
+ EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
591
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
592
+ )
593
+ else:
594
+ start_n = tl.program_id(0)
595
+ _bwd_kernel_one_col_block(
596
+ start_n,
597
+ Q, K, V, Bias,
598
+ DO, DQ, DK, DV,
599
+ LSE, D,
600
+ softmax_scale,
601
+ stride_qm, stride_kn, stride_vn, stride_bm,
602
+ stride_dom, stride_dqm, stride_dkn, stride_dvn,
603
+ seqlen_q, seqlen_k, headdim,
604
+ ATOMIC_ADD=True,
605
+ BIAS_TYPE=BIAS_TYPE,
606
+ IS_CAUSAL=IS_CAUSAL,
607
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
608
+ EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
609
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
610
+ )
611
+
612
+
613
+ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
614
+ # shape constraints
615
+ batch, seqlen_q, nheads, d = q.shape
616
+ _, seqlen_k, _, _ = k.shape
617
+ assert k.shape == (batch, seqlen_k, nheads, d)
618
+ assert v.shape == (batch, seqlen_k, nheads, d)
619
+ assert d <= 128, 'FlashAttention only supports head dimensions up to 128'
620
+ assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
621
+ assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
622
+ assert q.is_cuda and k.is_cuda and v.is_cuda
623
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
624
+
625
+ has_bias = bias is not None
626
+ bias_type = 'none'
627
+ if has_bias:
628
+ assert bias.dtype in [q.dtype, torch.float]
629
+ assert bias.is_cuda
630
+ assert bias.dim() == 4
631
+ if bias.stride(-1) != 1:
632
+ bias = bias.contiguous()
633
+ if bias.shape[2:] == (1, seqlen_k):
634
+ bias_type = 'vector'
635
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
636
+ bias_type = 'matrix'
637
+ else:
638
+ raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
639
+ ' or (seqlen_q, seqlen_k)')
640
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
641
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
642
+
643
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
644
+ lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
645
+ tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
646
+ o = torch.empty_like(q)
647
+
648
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
649
+ BLOCK = 128
650
+ num_warps = 4 if d <= 64 else 8
651
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
652
+ _fwd_kernel[grid](
653
+ q, k, v, bias, o,
654
+ lse, tmp,
655
+ softmax_scale,
656
+ q.stride(0), q.stride(2), q.stride(1),
657
+ k.stride(0), k.stride(2), k.stride(1),
658
+ v.stride(0), v.stride(2), v.stride(1),
659
+ *bias_strides,
660
+ o.stride(0), o.stride(2), o.stride(1),
661
+ nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
662
+ seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
663
+ # Can't use kwargs here because triton autotune expects key to be args, not kwargs
664
+ # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
665
+ bias_type, causal, BLOCK_HEADDIM,
666
+ BLOCK_M=BLOCK, BLOCK_N=BLOCK,
667
+ num_warps=num_warps,
668
+ num_stages=1,
669
+ )
670
+ return o, lse, softmax_scale # softmax_scale could have been updated
671
+
672
+
673
+ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
674
+ # Make sure that the last dimension is contiguous
675
+ if do.stride(-1) != 1:
676
+ do = do.contiguous()
677
+ batch, seqlen_q, nheads, d = q.shape
678
+ _, seqlen_k, _, _ = k.shape
679
+ # assert d in {16, 32, 64, 128}
680
+ assert d <= 128
681
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
682
+ assert lse.shape == (batch, nheads, seqlen_q_rounded)
683
+ assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
684
+ assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
685
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
686
+ # dq_accum = torch.zeros_like(q, dtype=torch.float32)
687
+ dq_accum = torch.empty_like(q, dtype=torch.float32)
688
+ delta = torch.empty_like(lse)
689
+ # delta = torch.zeros_like(lse)
690
+
691
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
692
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
693
+ _bwd_preprocess_do_o_dot[grid](
694
+ o, do, delta,
695
+ o.stride(0), o.stride(2), o.stride(1),
696
+ do.stride(0), do.stride(2), do.stride(1),
697
+ nheads, seqlen_q, seqlen_q_rounded, d,
698
+ BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
699
+ )
700
+
701
+ has_bias = bias is not None
702
+ bias_type = 'none'
703
+ if has_bias:
704
+ assert bias.dtype in [q.dtype, torch.float]
705
+ assert bias.is_cuda
706
+ assert bias.dim() == 4
707
+ assert bias.stride(-1) == 1
708
+ if bias.shape[2:] == (1, seqlen_k):
709
+ bias_type = 'vector'
710
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
711
+ bias_type = 'matrix'
712
+ else:
713
+ raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
714
+ ' or (seqlen_q, seqlen_k)')
715
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
716
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
717
+
718
+ # BLOCK_M = 128
719
+ # BLOCK_N = 64
720
+ # num_warps = 4
721
+ grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
722
+ batch * nheads)
723
+ _bwd_kernel[grid](
724
+ q, k, v, bias,
725
+ do, dq_accum, dk, dv,
726
+ lse, delta,
727
+ softmax_scale,
728
+ q.stride(0), q.stride(2), q.stride(1),
729
+ k.stride(0), k.stride(2), k.stride(1),
730
+ v.stride(0), v.stride(2), v.stride(1),
731
+ *bias_strides,
732
+ do.stride(0), do.stride(2), do.stride(1),
733
+ dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
734
+ dk.stride(0), dk.stride(2), dk.stride(1),
735
+ dv.stride(0), dv.stride(2), dv.stride(1),
736
+ nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
737
+ seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
738
+ # Can't use kwargs here because triton autotune expects key to be args, not kwargs
739
+ # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
740
+ bias_type, causal, BLOCK_HEADDIM,
741
+ # SEQUENCE_PARALLEL=False,
742
+ # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
743
+ # num_warps=num_warps,
744
+ # num_stages=1,
745
+ )
746
+ dq.copy_(dq_accum)
747
+
748
+
749
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
750
+
751
+ @staticmethod
752
+ def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
753
+ """
754
+ qkv: (batch, seqlen, 3, nheads, headdim)
755
+ bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
756
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
757
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
758
+ """
759
+ # Make sure that the last dimension is contiguous
760
+ if qkv.stride(-1) != 1:
761
+ qkv = qkv.contiguous()
762
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
763
+ qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal,
764
+ softmax_scale=softmax_scale
765
+ )
766
+ ctx.save_for_backward(qkv, o, lse, bias)
767
+ ctx.causal = causal
768
+ return o
769
+
770
+ @staticmethod
771
+ def backward(ctx, do):
772
+ qkv, o, lse, bias = ctx.saved_tensors
773
+ assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
774
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
775
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
776
+ with torch.inference_mode():
777
+ dqkv = torch.empty_like(qkv)
778
+ _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
779
+ dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
780
+ bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
781
+ return dqkv, None, None, None
782
+
783
+
784
+ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
785
+
786
+
787
+ class FlashAttnKVPackedFunc(torch.autograd.Function):
788
+
789
+ @staticmethod
790
+ def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
791
+ """
792
+ q: (batch, seqlen_q, nheads, headdim)
793
+ kv: (batch, seqlen_k, 2, nheads, headdim)
794
+ bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
795
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
796
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
797
+ """
798
+ # Make sure that the last dimension is contiguous
799
+ q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
800
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
801
+ q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
802
+ )
803
+ ctx.save_for_backward(q, kv, o, lse, bias)
804
+ ctx.causal = causal
805
+ return o
806
+
807
+ @staticmethod
808
+ def backward(ctx, do):
809
+ q, kv, o, lse, bias = ctx.saved_tensors
810
+ if len(ctx.needs_input_grad) >= 3:
811
+ assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
812
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
813
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
814
+ with torch.inference_mode():
815
+ dq = torch.empty_like(q)
816
+ dkv = torch.empty_like(kv)
817
+ _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
818
+ dq, dkv[:, :, 0], dkv[:, :, 1],
819
+ bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
820
+ return dq, dkv, None, None, None
821
+
822
+
823
+ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
824
+
825
+
826
+ class FlashAttnFunc(torch.autograd.Function):
827
+
828
+ @staticmethod
829
+ def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
830
+ """
831
+ q: (batch_size, seqlen_q, nheads, headdim)
832
+ k, v: (batch_size, seqlen_k, nheads, headdim)
833
+ bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
834
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
835
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
836
+ """
837
+ # Make sure that the last dimension is contiguous
838
+ q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
839
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
840
+ q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
841
+ )
842
+ ctx.save_for_backward(q, k, v, o, lse, bias)
843
+ ctx.causal = causal
844
+ return o
845
+
846
+ @staticmethod
847
+ def backward(ctx, do):
848
+ q, k, v, o, lse, bias = ctx.saved_tensors
849
+ assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
850
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
851
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
852
+ with torch.inference_mode():
853
+ dq = torch.empty_like(q)
854
+ dk = torch.empty_like(k)
855
+ dv = torch.empty_like(v)
856
+ _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
857
+ bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
858
+ return dq, dk, dv, None, None, None
859
+
860
+
861
+ flash_attn_func = FlashAttnFunc.apply
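 
A minimal usage sketch of the flash_attn_func exported above (not part of the committed file): it assumes a CUDA GPU with triton_pre_mlir installed, fp16 inputs in (batch, seqlen, nheads, headdim) layout, and purely illustrative sizes. Arguments are shown positionally, matching how modeling.py calls the Triton path.

    import torch
    from flash_attn_triton import flash_attn_func

    # Illustrative shapes; headdim must be <= 128 and inputs must be fp16/bf16 on CUDA.
    batch, seqlen, nheads, headdim = 2, 256, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    # (q, k, v, bias, causal); softmax_scale defaults to 1/sqrt(headdim) inside the kernel.
    out = flash_attn_func(q, k, v, None, True)
    out.sum().backward()  # exercises the Triton backward kernel; q.grad, k.grad, v.grad are populated
 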
llama_vocab_pruned_32k.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling.py ADDED
@@ -0,0 +1,255 @@
1
+ import torch
2
+ from torch import nn
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import *
6
+ from flash_attn import flash_attn_func
7
+ from flash_attn_triton import flash_attn_func as flash_attn_func_triton
8
+ from math import ceil
9
+
10
+
11
+ class AttentionBackend(Enum):
12
+ Naive = 0
13
+ FlashAttentionCuda = 1
14
+ FlashAttentionTriton = 2
15
+
16
+
17
+ global_config = {
18
+ 'attn_backend': AttentionBackend.Naive
19
+ }
20
+
21
+
22
+ @dataclass
23
+ class TransformerConfig:
24
+ vocab_size: int = -1
25
+ num_layers: int = -1
26
+ num_heads: int = -1
27
+ hidden_size: int = -1
28
+ max_seq_len: int = -1
29
+ root_model: 'ToyTransformer' = None
30
+ device: torch.device = torch.device('cpu')
31
+ dtype: torch.dtype = torch.float32
32
+
33
+
34
+ def expand_attn_mask(custom_attn_mask: torch.Tensor):
35
+ B, T = custom_attn_mask.shape
36
+ mask = custom_attn_mask.unsqueeze(1).repeat((1, T, 1))
37
+ seq_index_mask = (mask == custom_attn_mask[:, torch.arange(T)].view(B, T, 1))
38
+ return seq_index_mask & (torch.tril(mask) > 0)
39
+
40
+
41
+ # expand attn mask to cu_seqlens for flash attn
42
+ def expand_attn_mask_to_seq_lengths(attn_mask: torch.Tensor):
43
+ attn_mask = attn_mask.to('cpu')
44
+ seq_len = attn_mask.shape[0] * attn_mask.shape[1]
45
+ disjoint_point = torch.cat([torch.tensor([[True]] * attn_mask.shape[0]), attn_mask[:, 1:] != attn_mask[:, :-1]], dim=1)
46
+ return torch.cat([torch.nonzero(disjoint_point.view((-1,))), torch.tensor([[seq_len]])]).to(dtype=torch.int32)
47
+
48
+
49
+ # naive RoPE implementation following https://arxiv.org/pdf/2104.09864.pdf
50
+ def get_rope_cache_slow(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
51
+ assert dim % 2 == 0
52
+ freqs = theta ** (-2 * torch.arange(0, dim // 2, 1.) / dim)
53
+ freqs = torch.repeat_interleave(freqs, 2)
54
+ v1 = torch.cos(torch.arange(seq_len, dtype=torch.float).view((seq_len, 1)) * freqs)
55
+ v2 = torch.sin(torch.arange(seq_len, dtype=torch.float).view((seq_len, 1)) * freqs)
56
+ v2 = v2 * torch.tensor([1, -1] * (dim // 2))
57
+ indices = torch.tensor([j for i in range(0, dim, 2) for j in (i + 1, i)])
58
+ return v1.to(device, dtype=dtype), v2.to(device, dtype=dtype), indices.to(device)
59
+
60
+
61
+ def apply_rope_slow(x, rope_cache, positions: Optional[torch.Tensor] = None):
62
+ v1, v2, indices = rope_cache
63
+ seq_len, dim = x.shape[1:]
64
+ if positions is None:
65
+ v1 = v1[:seq_len, :]
66
+ v2 = v2[:seq_len, :]
67
+ else:
68
+ v1 = v1[positions, torch.arange(dim)].view((-1, dim))
69
+ v2 = v2[positions, torch.arange(dim)].view((-1, dim))
70
+ applied_x = x * v1 + (x * v2)[:, :, indices]
71
+ return applied_x
72
+
73
+
74
+ # Optimized RoPE implementation adapted from https://github.com/facebookresearch/llama/blob/main/llama/model.py
75
+ def get_rope_cache_fast(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
76
+ freqs = (1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)))
77
+ t = torch.arange(seq_len, device=freqs.device)
78
+ freqs = torch.outer(t, freqs).float()
79
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
80
+ return freqs_cis.to(device)
81
+
82
+
83
+ def apply_rope_fast(x, rope_cache, positions: Optional[torch.Tensor] = None) -> torch.Tensor:
84
+ x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
85
+ if positions is None and x.shape[1] < rope_cache.shape[0]:
86
+ freqs_cis = rope_cache[:x.shape[1], :]
87
+ elif positions is not None:
88
+ freqs_cis = rope_cache[positions, :]
89
+ else:
90
+ freqs_cis = rope_cache
91
+ freqs_cis = freqs_cis.view([d if i == 1 or i == x_.ndim - 1 else 1 for i, d in enumerate(x_.shape)])
92
+
93
+ applied_x = torch.view_as_real(x_ * freqs_cis).flatten(2)
94
+ return applied_x.type_as(x)
95
+
96
+
97
+ # RMSNorm implementation following https://arxiv.org/pdf/1910.07467.pdf
98
+ class RMSNorm(nn.Module):
99
+ def __init__(self, hidden_size, dtype, eps=1e-6):
100
+ super().__init__()
101
+ self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype))
102
+ self.eps = eps
103
+
104
+ def forward(self, x: torch.Tensor):
105
+ x_ = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
106
+ return self.weight * x_
107
+
108
+
109
+ class AttentionHead(nn.Module):
110
+ def __init__(self, config: TransformerConfig):
111
+ super().__init__()
112
+ self.config = config
113
+ self.head_size = config.hidden_size // config.num_heads
114
+ self.dtype = config.dtype
115
+ self.q_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
116
+ self.k_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
117
+ self.v_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
118
+
119
+ def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
120
+ kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
121
+ q = self.q_proj(x)
122
+ k = self.k_proj(x)
123
+ v = self.v_proj(x)
124
+
125
+ # if global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton:
126
+ # padding the position indices for alignment
127
+ # positions = torch.tensor([kv_cache[0].shape[1]] * q.shape[1]).to(q.device) if kv_cache is not None else torch.arange(0, x.shape[1], 1).to(q.device)
128
+
129
+ positions = torch.tensor([kv_cache[0].shape[1]]).to(q.device) if kv_cache is not None else None
130
+ q = apply_rope_fast(q, self.config.root_model.rope_cache, positions)
131
+ k = apply_rope_fast(k, self.config.root_model.rope_cache, positions)
132
+
133
+ if kv_cache is not None:
134
+ k = torch.concat([kv_cache[0], k], dim=1)
135
+ v = torch.concat([kv_cache[1], v], dim=1)
136
+
137
+ if global_config['attn_backend'] == AttentionBackend.FlashAttentionCuda:
138
+ q, k, v, = q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2)
139
+ attn_result = flash_attn_func(q, k, v, causal=True)
140
+ q, k, v, attn_result = q.squeeze(2), k.squeeze(2), v.squeeze(2), attn_result.squeeze(2)
141
+ elif global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton:
142
+ q, k, v, = q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2)
143
+ attn_result = flash_attn_func_triton(q, k, v, attn_masked_bias.unsqueeze(1) if attn_masked_bias is not None else None,
144
+ True if kv_cache is None else False)
145
+ q, k, v, attn_result = q.squeeze(2), k.squeeze(2), v.squeeze(2), attn_result.squeeze(2)
146
+ else:
147
+ attn_score = (q @ k.permute(0, 2, 1) / (self.head_size ** 0.5)) + attn_masked_bias
148
+ attn_result = torch.softmax(attn_score, dim=2) @ v
149
+
150
+ return attn_result, [k, v]
151
+
152
+
153
+ class MultiHeadAttention(nn.Module):
154
+ def __init__(self, config: TransformerConfig):
155
+ super().__init__()
156
+ self.config = config
157
+ self.attn_heads = nn.ModuleList([AttentionHead(config) for _ in range(config.num_heads)])
158
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, dtype=config.dtype)
159
+
160
+ def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
161
+ kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
162
+ head_outputs = [head(x, attn_masked_bias, kv_cache[idx] if kv_cache is not None else None) for idx, head in
163
+ enumerate(self.attn_heads)]
164
+ return self.o_proj(torch.concat([o[0] for o in head_outputs], dim=2)), [o[1] for o in head_outputs]
165
+
166
+
167
+ class DecoderLayer(nn.Module):
168
+ def __init__(self, config: TransformerConfig):
169
+ super().__init__()
170
+ self.config = config
171
+ self.mha = MultiHeadAttention(config)
172
+ self.up_proj = nn.Linear(config.hidden_size, config.hidden_size * 4, dtype=config.dtype)
173
+ self.down_proj = nn.Linear(config.hidden_size * 4, config.hidden_size, dtype=config.dtype)
174
+ self.ln_mha = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
175
+ self.ln_ffn = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
176
+ self.act = nn.GELU()
177
+
178
+ def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
179
+ kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
180
+ mha_output, new_kv_cache = self.mha(self.ln_mha(x), attn_masked_bias, kv_cache)
181
+ mha_output = x + mha_output
182
+ ffn_output = self.down_proj(self.act(self.up_proj(self.ln_ffn(mha_output))))
183
+ return mha_output + ffn_output, new_kv_cache
184
+
185
+
186
+ class ToyTransformer(nn.Module):
187
+ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, max_seq_len: int,
188
+ device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float32):
189
+ super().__init__()
190
+ self.config = TransformerConfig(vocab_size, num_layers, num_heads, hidden_size, max_seq_len, self, device,
191
+ dtype)
192
+
193
+ self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)
194
+
195
+ self.rope_cache = get_rope_cache_fast(max_seq_len, hidden_size // num_heads, 10000, device, dtype)
196
+
197
+ self.decoder_layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(num_layers)])
198
+ self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)
199
+ self.to(device)
200
+
201
+ def forward(self, seq: torch.Tensor,
202
+ attn_mask: Optional[torch.Tensor] = None,
203
+ kv_cache: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:
204
+ # sanity checks
205
+ assert attn_mask is None or kv_cache is None # No support for attn_mask and kv_cache both enabled
206
+ if kv_cache is not None:
207
+ assert seq.shape[0] == 1, 'kv_cache is not supported for batch inference'
208
+ # handle flash-attn triton alignment requirement (actually only needed for backward)
209
+ seq_length = seq.shape[1]
210
+ if kv_cache is None and global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton and seq_length % 128 != 0:
211
+ if attn_mask is None: # forcibly enable attn_mask due to padding
212
+ attn_mask = torch.ones(seq.shape, device=self.device)
213
+ pad_length = (ceil(seq_length / 128) * 128) - seq_length
214
+ seq = nn.functional.pad(seq, (0, pad_length))
215
+ attn_mask = nn.functional.pad(attn_mask, (0, pad_length))
216
+
217
+ # handle attn_bias
218
+ if global_config['attn_backend'] == AttentionBackend.FlashAttentionCuda:
219
+ assert attn_mask is None, 'FlashAttn-Cuda does not support custom attn_mask'
220
+ attn_masked_bias = None
221
+ elif global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton and attn_mask is None:
222
+ attn_masked_bias = None
223
+ elif attn_mask is not None:
224
+ attn_masked_bias = expand_attn_mask(attn_mask)
225
+ elif attn_mask is None and kv_cache is None:
226
+ attn_masked_bias = expand_attn_mask(torch.ones(seq.shape, device=self.device))
227
+ elif kv_cache is not None:
228
+ attn_masked_bias = torch.ones((1, seq.shape[1], seq.shape[1]), dtype=torch.bool, device=self.device)
229
+ else:
230
+ attn_masked_bias = None
231
+
232
+ if attn_masked_bias is not None:
233
+ mask_zero = torch.tensor(0, dtype=self.config.dtype)
234
+ mask_val = torch.tensor(torch.finfo(self.config.dtype).min / 2, dtype=self.config.dtype)
235
+ attn_masked_bias = torch.where(attn_masked_bias, mask_zero, mask_val).to(self.device)
236
+
237
+ hidden = self.sem_embed(seq)
238
+
239
+ new_kv_cache = []
240
+ for idx, decoder in enumerate(self.decoder_layers):
241
+ hidden, layer_kv_cache = decoder(hidden, attn_masked_bias, kv_cache[idx] if kv_cache is not None else None)
242
+ new_kv_cache.append(layer_kv_cache)
243
+
244
+ logits = self.lm_head(hidden)
245
+
246
+ # remove padding for flash-attn triton
247
+ if kv_cache is None and global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton and seq_length % 128 != 0:
248
+ logits = logits[:, :seq_length, :]
249
+ new_kv_cache = [[[cache[:, :seq_length, :] for cache in head] for head in layer] for layer in new_kv_cache]
250
+
251
+ return logits, new_kv_cache
252
+
253
+ @property
254
+ def device(self):
255
+ return next(self.parameters()).device
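 
For orientation, a minimal sketch of driving the ToyTransformer defined above with the naive attention backend (hyperparameters and vocab size below are illustrative, and the module-level flash-attn imports in modeling.py must still resolve):

    import torch
    from modeling import ToyTransformer, AttentionBackend, global_config

    global_config['attn_backend'] = AttentionBackend.Naive  # naive backend runs on CPU in fp32

    model = ToyTransformer(vocab_size=32000, num_layers=2, num_heads=4, hidden_size=256, max_seq_len=512)
    seq = torch.randint(0, 32000, (1, 16))

    logits, kv_cache = model(seq)                               # prefill: returns per-layer, per-head [k, v] caches
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick, shape (1, 1)
    logits, kv_cache = model(next_token, kv_cache=kv_cache)     # single-token decode step reusing the cache
 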
tokenizers.py ADDED
@@ -0,0 +1,244 @@
1
+ import time
2
+ from typing import *
3
+ import re
4
+ import json
5
+ import numba
6
+
7
+
8
+ def sample_vocab(tokens: Iterable[str], vocab_size: Optional[int] = None,
9
+ vocab_coverage: Optional[float] = None) -> List[str]:
10
+ assert (vocab_size is not None and vocab_coverage is None) or \
11
+ (vocab_size is None and vocab_coverage is not None), "exactly one of vocab_size or vocab_coverage must be specified"
12
+
13
+ token_count = {}
14
+ for c in tokens:
15
+ token_count[c] = token_count.get(c, 0) + 1
16
+
17
+ if vocab_size is not None:
18
+ token_count = list(token_count.items())
19
+ token_count.sort(key=lambda i: i[1], reverse=True)
20
+ vocab = [c[0] for c in token_count[:vocab_size]]
21
+ else:
22
+ total_count = sum(token_count.values())
23
+ token_freq = [(c, i / total_count) for c, i in token_count.items()]
24
+ token_freq.sort(key=lambda i: i[1], reverse=True)
25
+ freq_sum = 0.0
26
+ split = 0
27
+ for split in range(len(token_freq)):
28
+ freq_sum += token_freq[split][1]
29
+ if freq_sum >= vocab_coverage:
30
+ break
31
+ vocab = [c[0] for c in token_freq[:split + 1]]
32
+ return vocab
33
+
34
+
35
+ class CharTokenizer:
36
+ def __init__(self, corpus: str, vocab_size: Optional[int] = None, vocab_coverage: Optional[float] = None,
37
+ reserved_vocab: Optional[List[str]] = None, unk_literal: str = '<unk>'):
38
+ if reserved_vocab is not None:
39
+ assert len(reserved_vocab) == len(set(reserved_vocab)), 'no duplicate is allowed in reserved vocab'
40
+ assert unk_literal not in reserved_vocab, f'unk literal "{unk_literal}" cannot be in reserved vocab'
41
+ else:
42
+ reserved_vocab = []
43
+ vocab = reserved_vocab.copy() if reserved_vocab is not None else []
44
+ vocab += sample_vocab(corpus, vocab_size - len(vocab) - 1 if vocab_size is not None else None, vocab_coverage)
45
+ self.s2i = {s: i + 1 for i, s in enumerate(vocab)}
46
+ self.s2i[unk_literal] = 0
47
+ self.i2s = {i: s for s, i in self.s2i.items()}
48
+ self.special_vocab = set(reserved_vocab + [unk_literal])
49
+ self.unk_literal = unk_literal
50
+
51
+ def encode(self, text: str) -> List[int]:
52
+ cursor, ids = 0, []
53
+ while cursor < len(text):
54
+ for s in self.special_vocab:
55
+ if text[cursor:].startswith(s):
56
+ ids.append(self.s2i[s])
57
+ cursor += len(s)
58
+ break
59
+ else:
60
+ ids.append(self.s2i.get(text[cursor], self.s2i.get(self.unk_literal)))
61
+ cursor += 1
62
+ return ids
63
+
64
+ def decode(self, ids: List[int]) -> str:
65
+ return ''.join(self.i2s[i] for i in ids)
66
+
67
+ def get_vocab_mapping(self):
68
+ return self.s2i
69
+
70
+
71
+ class WordTokenizer:
72
+ def __init__(self, corpus: str, vocab_size: Optional[int] = None, vocab_coverage: Optional[float] = None,
73
+ reserved_vocab: Optional[List[str]] = None, unk_literal: str = '<unk>'):
74
+ if reserved_vocab is not None:
75
+ assert len(reserved_vocab) == len(set(reserved_vocab)), 'no duplicate is allowed in reserved vocab'
76
+ assert unk_literal not in reserved_vocab, f'unk literal "{unk_literal}" cannot be in reserved vocab'
77
+ else:
78
+ reserved_vocab = []
79
+ vocab = reserved_vocab.copy() if reserved_vocab is not None else []
80
+
81
+ tokens = (c[0] if c[0] != '' else c[1] for c in re.finditer(r'(\w+)|(\W)', corpus))
82
+ vocab += sample_vocab(tokens, vocab_size - len(vocab) - 1 if vocab_size is not None else None, vocab_coverage)
83
+
84
+ self.s2i = {s: i + 1 for i, s in enumerate(vocab)}
85
+ self.s2i[unk_literal] = 0
86
+ self.i2s = {i: s for s, i in self.s2i.items()}
87
+ self.special_vocab = set(reserved_vocab + [unk_literal])
88
+ self.unk_literal = unk_literal
89
+
90
+ def encode(self, text: str) -> List[int]:
91
+ specials = '|'.join(f'{i}' for i in self.special_vocab)
92
+ tokens = (c[0] if c[0] != '' else c[1] for c in re.finditer(rf'({specials}|\w+)|(\W)', text))
93
+ return [self.s2i.get(t, self.s2i[self.unk_literal]) for t in tokens]
94
+
95
+ def decode(self, ids: List[int]) -> str:
96
+ return ''.join(self.i2s[i] for i in ids)
97
+
98
+ def get_vocab_mapping(self):
99
+ return self.s2i
100
+
101
+ def get_vocab_size(self):
102
+ return len(self.s2i)
103
+
104
+ def eval_vocab_coverage(self, corpus: str):
105
+ encoded = self.encode(corpus)
106
+ return 1 - (len([i for i in encoded if i == 0]) / len(encoded))
107
+
108
+
109
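+ # TRIETokenizer: byte-level tokenizer that loads a vocabulary of byte sequences from a
+ # JSON file into a trie and encodes text by greedy longest-prefix matching.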
+ class TRIETokenizer:
+     @staticmethod
+     def split_bytes(data: bytes):
+         return [b'%c' % i for i in data]
+ 
+     def __init__(self, vocab_file: str):
+         self.nodes = [(b'', -1, -1, [-1 for _ in range(256)])]  # node value, parent index, token id, children
+         with open(vocab_file, 'r') as file:
+             vocabs = json.load(file)
+         vocabs.sort(key=lambda i: len(i['bytes']))
+         for entry in vocabs:
+             self.add_vocab(bytes(entry['bytes']), entry['id'])
+ 
+         self.id_to_bytes = {i['id']: i['bytes'] for i in vocabs}
+ 
+     def add_vocab(self, vocab_bytes: bytes, vocab_id: int):
+         cur_node_idx = 0
+         for i, b in enumerate(vocab_bytes):
+             cur_node = self.nodes[cur_node_idx]
+             if cur_node[3][b] != -1:
+                 cur_node_idx = cur_node[3][b]
+             else:
+                 new_node_idx = len(self.nodes)
+                 self.nodes.append((vocab_bytes, cur_node_idx, vocab_id if i == len(vocab_bytes) - 1 else -1,
+                                    [-1 for _ in range(256)]))
+                 cur_node[3][b] = new_node_idx
+                 cur_node_idx = new_node_idx
+ 
+     def attempt_match(self, match_bytes: bytes):
+         match_length, match_token_id = -1, -1
+         cur_node_idx, depth = 0, 0
+         for i, b in enumerate(match_bytes):
+             match_node_idx = self.nodes[cur_node_idx][3][b]
+             if match_node_idx == -1:
+                 break
+             cur_node = self.nodes[match_node_idx]
+             if cur_node[2] != -1:
+                 match_length = depth
+                 match_token_id = cur_node[2]
+             cur_node_idx = match_node_idx
+             depth += 1
+         return match_length, match_token_id
+ 
+     def encode(self, text: str):
+         text_bytes = text.encode('utf-8')
+         tokens, length = [], 0
+         while length < len(text_bytes):
+             offset, token_id = self.attempt_match(text_bytes[length:])
+             assert offset >= 0
+             tokens.append(token_id)
+             length += offset + 1
+         return tokens
+ 
+     def decode(self, token_ids: List[int]):
+         return bytes([t for i in token_ids for t in self.id_to_bytes[i]]).decode('utf-8', errors='replace')
+ 
+     def get_vocab_size(self):
+         return len(self.id_to_bytes)
+ 
+ 
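+ # Numba-compiled counterparts of TRIETokenizer.attempt_match/encode, used by
+ # TRIETokenizerFast below to speed up the byte-matching loop.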
+ @numba.njit
+ def trie_attempt_match_jit(trie_nodes, match_bytes: bytes):
+     match_length, match_token_id = -1, -1
+     cur_node_idx, depth = 0, 0
+     for i, b in enumerate(match_bytes):
+         match_node_idx = trie_nodes[cur_node_idx][3][int(b)]
+         if match_node_idx == -1:
+             break
+         cur_node = trie_nodes[match_node_idx]
+         if cur_node[2] != -1:
+             match_length = depth
+             match_token_id = cur_node[2]
+         cur_node_idx = match_node_idx
+         depth += 1
+     return match_length, match_token_id
+ 
+ 
+ @numba.njit
+ def trie_encode_jit(trie_nodes, text_bytes: bytes):
+     tokens, length = [], 0
+     while length < len(text_bytes):
+         offset, token_id = trie_attempt_match_jit(trie_nodes, text_bytes[length:])
+         assert offset >= 0
+         tokens.append(token_id)
+         length += offset + 1
+     return tokens
+ 
+ 
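+ # TRIETokenizerFast: same trie construction as TRIETokenizer, but encode() dispatches to
+ # the numba-compiled loop over a numba.typed.List copy of the nodes.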
+ class TRIETokenizerFast:
+     def __init__(self, vocab_file: str):
+         self.nodes = [(b'', -1, -1, [-1 for _ in range(256)])]  # node value, parent index, token id, children
+         with open(vocab_file, 'r') as file:
+             vocabs = json.load(file)
+         vocabs.sort(key=lambda i: len(i['bytes']))
+         for entry in vocabs:
+             self.add_vocab(bytes(entry['bytes']), entry['id'])
+ 
+         self.id_to_bytes = {i['id']: i['bytes'] for i in vocabs}
+ 
+         self.nodesJit = numba.typed.List(self.nodes)
+ 
+     def add_vocab(self, vocab_bytes: bytes, vocab_id: int):
+         cur_node_idx = 0
+         for i, b in enumerate(vocab_bytes):
+             cur_node = self.nodes[cur_node_idx]
+             if cur_node[3][b] != -1:
+                 cur_node_idx = cur_node[3][b]
+             else:
+                 new_node_idx = len(self.nodes)
+                 self.nodes.append((vocab_bytes, cur_node_idx, vocab_id if i == len(vocab_bytes) - 1 else -1,
+                                    [-1 for _ in range(256)]))
+                 cur_node[3][b] = new_node_idx
+                 cur_node_idx = new_node_idx
+ 
+     def encode(self, text: str):
+         return trie_encode_jit(self.nodesJit, text.encode('utf-8'))
+ 
+     def decode(self, token_ids: List[int]):
+         return bytes([t for i in token_ids for t in self.id_to_bytes[i]]).decode('utf-8', errors='replace')
+ 
+     def get_vocab_size(self):
+         return len(self.id_to_bytes)
+ 
+ # if __name__ == '__main__':
+ #     tokenizer = TRIETokenizerFast('llama_vocab_pruned_20k.json')
+ #     with open('corpus/TinyStoriesV2-GPT4-valid.txt', 'r') as file:
+ #         text = file.read()[:10240]
+ #
+ #     total_tokens = 0
+ #     s = time.time()
+ #     for i in range(1000):
+ #         encoded = tokenizer.encode(text)
+ #         total_tokens += len(encoded)
+ #         print(len(encoded))
+ #     e = time.time()
+ #     print(f'{e - s:.3f} secs, {total_tokens / (e - s):.3f} tps')
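+ 
+ # Illustrative usage of the corpus-driven tokenizers (a minimal sketch; the corpus
+ # string and parameter values below are placeholders, not settings used elsewhere
+ # in this repo):
+ #
+ # corpus = 'the quick brown fox jumps over the lazy dog'
+ # char_tok = CharTokenizer(corpus, vocab_size=64, reserved_vocab=['<s>', '</s>'])
+ # ids = char_tok.encode('<s>the lazy fox</s>')
+ # assert char_tok.decode(ids) == '<s>the lazy fox</s>'
+ #
+ # word_tok = WordTokenizer(corpus, vocab_coverage=0.95)
+ # print(word_tok.get_vocab_size(), word_tok.eval_vocab_coverage('the quick dog'))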
webui.py ADDED
@@ -0,0 +1,142 @@
+ import gradio as gr
+ from modeling import global_config, ToyTransformer, AttentionBackend
+ import torch
+ from tokenizers import TRIETokenizer
+ from threading import Thread
+ import bisect
+ 
+ # use the GPU when available; otherwise fall back to CPU with the naive attention backend
+ if torch.cuda.is_available():
+     g_device = torch.device('cuda')
+ else:
+     g_device = torch.device('cpu')
+     global_config['attn_backend'] = AttentionBackend.Naive
+ 
+ g_SEQ_LEN = 1024
+ g_HIDDEN_SIZE = 768
+ g_NUM_HEADS = 12
+ g_NUM_LAYERS = 12
+ g_DTYPE = torch.float32
+ 
+ g_tokenizer = TRIETokenizer('llama_vocab_pruned_32k.json')
+ g_model = ToyTransformer(g_tokenizer.get_vocab_size(), g_NUM_LAYERS, g_NUM_HEADS, g_HIDDEN_SIZE, g_SEQ_LEN, g_device, g_DTYPE)
+ 
+ g_model.load_state_dict(torch.load('model.pt', map_location='cpu'))
+ 
+ 
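+ # generate(): autoregressive sampling loop. Each step feeds the model (optionally with a
+ # KV cache), applies repetition penalty, temperature scaling, and top-p filtering to the
+ # last-position logits, then samples and yields the new token.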
+ def generate(model, tokenizer, prompt, temperature, top_p, rep_penalty,
+              max_new_tokens=20, total_tokens=None,
+              end_tokens=None,
+              enable_kv_cache=True):
+     model.eval()
+ 
+     feed_tokens = tokenizer.encode(prompt) if isinstance(prompt, str) else prompt
+ 
+     all_tokens = feed_tokens.copy()
+     if total_tokens is not None:
+         max_new_tokens = max(0, total_tokens - len(feed_tokens))
+ 
+     with torch.no_grad():
+         kv_cache = None
+         for _ in range(max_new_tokens):
+             logits, kv_cache = model.forward(
+                 torch.tensor([feed_tokens if enable_kv_cache else all_tokens]).to(model.device),
+                 kv_cache=kv_cache)
+             logits = logits[0][-1].cpu()
+             if not enable_kv_cache:
+                 kv_cache = None
+ 
+             # apply repetition penalty
+             logits_rep = torch.gather(logits, 0, torch.tensor(all_tokens))
+             logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
+             logits.scatter_(0, torch.tensor(all_tokens), logits_rep)
+ 
+             # apply temperature
+             logits /= max(temperature, 1e-6)
+ 
+             probs = torch.softmax(logits, dim=0)
+ 
+             # apply top-p
+             ordered_probs, ordered_indices = torch.sort(probs, descending=True)
+             cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
+             top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
+             ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
+             sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()
+ 
+             all_tokens.append(sampled_index)
+             feed_tokens = [sampled_index]
+ 
+             if end_tokens is not None and sampled_index in end_tokens:
+                 break
+ 
+             yield feed_tokens
+     return
+ 
+ 
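+ # predict(): chat handler. Rebuilds the prompt from as many recent rounds as fit in
+ # g_SEQ_LEN minus the response budget, then streams decoded tokens into the last
+ # history entry.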
+ def predict(user_input, history, max_length, top_p, temperature, rep_penalty, retry):
+     if retry and len(history) == 0:
+         yield []
+         return
+     elif retry:
+         user_input = history[-1][0]
+         history = history[:-1]
+ 
+     history.append((user_input, ""))
+ 
+     encoded_inputs = [(g_tokenizer.encode('User:' + h[0]), g_tokenizer.encode('Assistant:' + h[1])) for h in history]
+     taken_rounds, taken_rounds_length = [], 0
+     while len(taken_rounds) < len(encoded_inputs):
+         round_pair = encoded_inputs[len(encoded_inputs) - 1 - len(taken_rounds)]
+         if len(round_pair[0]) + len(round_pair[1]) + taken_rounds_length >= g_SEQ_LEN - max_length:
+             break
+         taken_rounds.append(round_pair)
+         taken_rounds_length += len(round_pair[0]) + len(round_pair[1])
+     taken_rounds = taken_rounds[::-1]
+ 
+     input_tokens = g_tokenizer.encode('<s>A chat between User and Assistant.')
+     for round_pair in taken_rounds:
+         input_tokens += g_tokenizer.encode('\n') + round_pair[0] + g_tokenizer.encode('\n') + round_pair[1]
+     # print(taken_rounds, g_tokenizer.decode(input_tokens))
+     for response in generate(g_model, g_tokenizer, input_tokens, temperature, top_p, rep_penalty, max_length, end_tokens=g_tokenizer.encode('</s>')):
+         history[-1] = (history[-1][0], history[-1][1] + g_tokenizer.decode(response))
+         yield history
+ 
+ 
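+ # main(): Gradio Blocks UI with chatbot pane, input box, send/retry/undo/clear buttons,
+ # and sampling-parameter sliders wired to predict().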
+ def main():
+     css = '''
+     .contain {max-width:50}
+ 
+     #chatbot {min-height:500px}
+     '''
+ 
+     with gr.Blocks(css=css) as demo:
+         gr.HTML('<h1 align="center">ToyTransformer</h1>')
+ 
+         chatbot = gr.Chatbot(elem_id='chatbot')
+         with gr.Column():
+             user_input = gr.Textbox(show_label=False, placeholder="Input", lines=1, container=False)
+             with gr.Row():
+                 submitBtn = gr.Button("Send", variant="primary")
+                 retryBtn = gr.Button("Retry")
+                 cancelBtn = gr.Button('Undo')
+                 emptyBtn = gr.Button("Clear")
+             with gr.Row():
+                 max_length = gr.Slider(0, 512, value=200, step=1, label="Max Response Tokens", interactive=True)
+                 top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top-P", interactive=True)
+                 temperature = gr.Slider(0, 1, value=0.5, step=0.01, label="Temperature", interactive=True)
+                 rep_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label='Repetition Penalty', interactive=True)
+ 
+         submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(False)],
+                         [chatbot], show_progress=False)
+         submitBtn.click(lambda: '', [], [user_input], show_progress=False)
+ 
+         retryBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(True)],
+                        [chatbot], show_progress=False)
+ 
+         cancelBtn.click(lambda m: m[:-1], [chatbot], [chatbot], show_progress=False)
+ 
+         emptyBtn.click(lambda: [], outputs=[chatbot], show_progress=False)
+ 
+     demo.queue().launch(share=False, inbrowser=True)
+ 
+ 
+ if __name__ == '__main__':
+     main()