OpenNLPLab committed
Commit c7fb4c5
1 Parent(s): e03779e

Upgrade to lightning att2

Files changed (2)
  1. lightning_attention2.py +540 -0
  2. modeling_transnormer.py +4 -3
lightning_attention2.py ADDED
@@ -0,0 +1,540 @@
+ # Copyright 2024 OpenNLPLab
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # coding=utf-8
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def _fwd_kernel(
+     Q,
+     K,
+     V,
+     Out,
+     S,
+     stride_qz,
+     stride_qh,
+     stride_qm,
+     stride_qk,
+     stride_kz,
+     stride_kh,
+     stride_kn,
+     stride_kk,
+     stride_vz,
+     stride_vh,
+     stride_vn,
+     stride_ve,
+     stride_oz,
+     stride_oh,
+     stride_om,
+     stride_oe,
+     stride_sh,
+     Z,
+     H,
+     N_CTX,
+     BLOCK_M: tl.constexpr,
+     BLOCK_DMODEL_QK: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_DMODEL_V: tl.constexpr,
+     IS_CAUSAL: tl.constexpr,
+     USE_DECAY: tl.constexpr,
+ ):
+     start_m = tl.program_id(0)
+     off_hz = tl.program_id(1)
+     off_h = off_hz % H
+     # initialize offsets
+     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_n = tl.arange(0, BLOCK_N)
+     offs_k = tl.arange(0, BLOCK_DMODEL_QK)
+     offs_e = tl.arange(0, BLOCK_DMODEL_V)
+     # get current offset of q k v
+     off_q = (off_hz * stride_qh + offs_m[:, None] * stride_qm
+              + offs_k[None, :] * stride_qk)
+     off_k = (off_hz * stride_kh + offs_n[:, None] * stride_kn
+              + offs_k[None, :] * stride_kk)
+     off_v = (off_hz * stride_vh + offs_n[:, None] * stride_vn
+              + offs_e[None, :] * stride_ve)
+     off_o = (off_hz * stride_oh + offs_m[:, None] * stride_om
+              + offs_e[None, :] * stride_oe)
+
+     # Initialize pointers to Q, K, V
+     q_ptrs = Q + off_q
+     k_ptrs = K + off_k
+     v_ptrs = V + off_v
+
+     # initialize the output accumulator
+     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=tl.float32)
+     # load q: it will stay in SRAM throughout
+     q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)
+     # loop over k, v and update accumulator
+     lo = 0
+     hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
+     for start_n in range(lo, hi, BLOCK_N):
+         # -- load k, v --
+         k = tl.load(
+             k_ptrs + start_n * stride_kn,
+             mask=(start_n + offs_n)[:, None] < N_CTX,
+             other=0.0,
+         )
+         v = tl.load(
+             v_ptrs + start_n * stride_vn,
+             mask=(start_n + offs_n)[:, None] < N_CTX,
+             other=0.0,
+         )
+         # -- compute qk --
+         # qk = tl.dot(q, k)
+         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+         # qk += tl.dot(q, k, trans_b=True)
+         qk += tl.dot(q, tl.trans(k))
+         if IS_CAUSAL:
+             index = offs_m[:, None] - (start_n + offs_n[None, :])
+             if USE_DECAY:
+                 S_block_ptr = S + off_h * stride_sh
+                 s = tl.load(S_block_ptr)
+                 s_index = s * index
+                 s_index = tl.where(s_index >= 0, -s_index, float("-inf"))
+                 qk = tl.exp(s_index) * qk
+             else:
+                 qk = tl.where(index >= 0, qk, 0)
+         acc += tl.dot(qk, v.to(qk.dtype))
+
+     out_ptrs = Out + off_o
+     tl.store(out_ptrs, acc.to(q.dtype), mask=offs_m[:, None] < N_CTX)
+
+
+ @triton.jit
+ def _bwd_kernel_kv(
+     Q,
+     K,
+     V,
+     S,
+     DO,
+     DQ,
+     DK,
+     DV,
+     stride_qz,
+     stride_qh,
+     stride_qm,
+     stride_qk,
+     stride_kz,
+     stride_kh,
+     stride_kn,
+     stride_kk,
+     stride_vz,
+     stride_vh,
+     stride_vn,
+     stride_ve,
+     stride_oz,
+     stride_oh,
+     stride_om,
+     stride_oe,
+     stride_sh,
+     Z,
+     H,
+     N_CTX,
+     num_block,
+     BLOCK_M: tl.constexpr,
+     BLOCK_DMODEL_QK: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_DMODEL_V: tl.constexpr,
+     CAUSAL: tl.constexpr,
+     USE_DECAY: tl.constexpr,
+ ):
+     start_n = tl.program_id(0)
+     off_hz = tl.program_id(1)
+
+     off_z = off_hz // H
+     off_h = off_hz % H
+     # offset pointers for batch/head
+     Q += off_z * stride_qz + off_h * stride_qh
+     K += off_z * stride_kz + off_h * stride_kh
+     V += off_z * stride_vz + off_h * stride_vh
+     DO += off_z * stride_oz + off_h * stride_oh
+     DQ += off_z * stride_qz + off_h * stride_qh
+     DK += off_z * stride_kz + off_h * stride_kh
+     DV += off_z * stride_vz + off_h * stride_vh
+
+     # start of q
+     if CAUSAL:
+         lo = start_n * BLOCK_M
+     else:
+         lo = 0
+     # initialize row/col offsets
+     # sequence offset
+     offs_qm = lo + tl.arange(0, BLOCK_M)
+     offs_kvn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+     # feature offset
+     offs_qkk = tl.arange(0, BLOCK_DMODEL_QK)
+     offs_ve = tl.arange(0, BLOCK_DMODEL_V)
+     # row block index
+     offs_m = tl.arange(0, BLOCK_M)
+     # initialize pointers to value-like data
+     q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_qkk[None, :] * stride_qk)
+     k_ptrs = K + (offs_kvn[:, None] * stride_kn
+                   + offs_qkk[None, :] * stride_kk)
+     v_ptrs = V + (offs_kvn[:, None] * stride_vn + offs_ve[None, :] * stride_ve)
+     do_ptrs = DO + (offs_qm[:, None] * stride_om
+                     + offs_ve[None, :] * stride_oe)
+     dq_ptrs = DQ + (offs_qm[:, None] * stride_qm
+                     + offs_qkk[None, :] * stride_qk)
+     # initialize dv and dk
+     dv = tl.zeros([BLOCK_N, BLOCK_DMODEL_V], dtype=tl.float32)
+     dk = tl.zeros([BLOCK_N, BLOCK_DMODEL_QK], dtype=tl.float32)
+     # k and v stay in SRAM throughout
+     k = tl.load(k_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
+     v = tl.load(v_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
+     # loop over rows
+     for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
+         offs_m_curr = start_m + offs_m
+         # load q, k, v, do on-chip
+         q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0)
+         qk = tl.dot(q, tl.trans(k))
+         # qk = tl.dot(q, k, trans_b=True)
+         if CAUSAL:
+             index = offs_m_curr[:, None] - offs_kvn[None, :]
+             if USE_DECAY:
+                 S_block_ptr = S + off_h * stride_sh
+                 s = tl.load(S_block_ptr)
+                 s_index = s * index
+                 s_index = tl.where(s_index >= 0, -s_index, float("-inf"))
+                 s = tl.exp(s_index)
+                 qk = qk * s
+             else:
+                 qk = tl.where(index >= 0, qk, 0)
+
+         p = qk
+         # compute dv
+         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0)
+         dv += tl.dot(tl.trans(p.to(do.dtype)), do)
+         dp = tl.dot(do, tl.trans(v).to(do.dtype))
+         if CAUSAL:
+             if USE_DECAY:
+                 dp = dp * s
+             else:
+                 dp = tl.where(index >= 0, dp, 0)
+
+         dk += tl.dot(tl.trans(dp.to(q.dtype)), q).to(tl.float32)
+
+         # increment pointers
+         q_ptrs += BLOCK_M * stride_qm
+         do_ptrs += BLOCK_M * stride_om
+     # write-back
+     dv_ptrs = DV + (offs_kvn[:, None] * stride_vn
+                     + offs_ve[None, :] * stride_ve)
+     dk_ptrs = DK + (offs_kvn[:, None] * stride_kn
+                     + offs_qkk[None, :] * stride_kk)
+     tl.store(dv_ptrs, dv, mask=offs_kvn[:, None] < N_CTX)
+     tl.store(dk_ptrs, dk, mask=offs_kvn[:, None] < N_CTX)
+
+
+ @triton.jit
+ def _bwd_kernel_q(
+     Q,
+     K,
+     V,
+     S,
+     DO,
+     DQ,
+     DK,
+     DV,
+     stride_qz,
+     stride_qh,
+     stride_qm,
+     stride_qk,
+     stride_kz,
+     stride_kh,
+     stride_kn,
+     stride_kk,
+     stride_vz,
+     stride_vh,
+     stride_vn,
+     stride_ve,
+     stride_oz,
+     stride_oh,
+     stride_om,
+     stride_oe,
+     stride_sh,
+     Z,
+     H,
+     N_CTX,
+     num_block,
+     BLOCK_M: tl.constexpr,
+     BLOCK_DMODEL_QK: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_DMODEL_V: tl.constexpr,
+     CAUSAL: tl.constexpr,
+     USE_DECAY: tl.constexpr,
+ ):
+     start_m = tl.program_id(0)
+     off_hz = tl.program_id(1)
+     off_z = off_hz // H
+     off_h = off_hz % H
+     # offset pointers for batch/head
+     K += off_z * stride_kz + off_h * stride_kh
+     V += off_z * stride_vz + off_h * stride_vh
+     DO += off_z * stride_oz + off_h * stride_oh
+     DQ += off_z * stride_qz + off_h * stride_qh
+     # feature offset
+     offs_qkk = tl.arange(0, BLOCK_DMODEL_QK)
+     offs_ve = tl.arange(0, BLOCK_DMODEL_V)
+     # row block index
+     offs_m = tl.arange(0, BLOCK_M)
+     # row block index
+     offs_qm = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     # do
+     do_ptrs = DO + (offs_qm[:, None] * stride_om
+                     + offs_ve[None, :] * stride_oe)
+     dq_ptrs = DQ + (offs_qm[:, None] * stride_qm
+                     + offs_qkk[None, :] * stride_qk)
+
+     do = tl.load(do_ptrs, mask=offs_qm[:, None] < N_CTX, other=0.0)
+
+     dq = tl.zeros([BLOCK_M, BLOCK_DMODEL_QK], dtype=tl.float32)
+     lo = 0
+     hi = (start_m + 1) * BLOCK_M if CAUSAL else N_CTX
+
+     offs_m_curr = start_m * BLOCK_M + offs_m
+
+     for start_n in range(0, num_block):
+         offs_kvn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+         k_ptrs = K + (offs_kvn[:, None] * stride_kn
+                       + offs_qkk[None, :] * stride_kk)
+         v_ptrs = V + (offs_kvn[:, None] * stride_vn
+                       + offs_ve[None, :] * stride_ve)
+         # k and v stay in SRAM throughout
+         k = tl.load(k_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
+         v = tl.load(v_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
+         # dp = do vT
+         dp = tl.dot(do, tl.trans(v).to(do.dtype))
+         if CAUSAL:
+             index = offs_m_curr[:, None] - offs_kvn[None, :]
+             if USE_DECAY:
+                 S_block_ptr = S + off_h * stride_sh
+                 s = tl.load(S_block_ptr)
+                 s_index = s * index
+                 s_index = tl.where(s_index >= 0, -s_index, float("-inf"))
+                 s = tl.exp(s_index)
+                 dp = dp * s
+             else:
+                 dp = tl.where(index >= 0, dp, 0)
+         # dq = dq + dp k
+         dq += tl.dot(dp.to(k.dtype), k)
+
+     tl.store(dq_ptrs, dq, mask=offs_qm[:, None] < N_CTX)
+
+
+ class _attention(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, q, k, v, causal, s):
+         q = q.contiguous()
+         k = k.contiguous()
+         v = v.contiguous()
+         s = s.contiguous()
+         # only support for Ampere now
+         capability = torch.cuda.get_device_capability()
+         if capability[0] < 8:
+             raise RuntimeError(
+                 "Lightning attention currently only supported for compute capability >= 80"
+             )
+         # shape constraints
+         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+         # right
+         o = torch.empty(
+             (q.shape[0], q.shape[1], q.shape[2], v.shape[-1]),
+             dtype=q.dtype,
+             device=q.device,
+         )
+
+         BLOCK_M = 128
+         BLOCK_N = 64
+         num_warps = 4 if Lk <= 64 else 8
+         num_stages = 1
+
+         grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
+         use_decay = s.shape[0] > 0
+         _fwd_kernel[grid](
+             q,
+             k,
+             v,
+             o,
+             s,
+             q.stride(0),
+             q.stride(1),
+             q.stride(2),
+             q.stride(3),
+             k.stride(0),
+             k.stride(1),
+             k.stride(2),
+             k.stride(3),
+             v.stride(0),
+             v.stride(1),
+             v.stride(2),
+             v.stride(3),
+             o.stride(0),
+             o.stride(1),
+             o.stride(2),
+             o.stride(3),
+             s.stride(0),
+             q.shape[0],
+             q.shape[1],
+             q.shape[2],
+             BLOCK_M=BLOCK_M,
+             BLOCK_DMODEL_QK=Lk,
+             BLOCK_N=BLOCK_N,
+             BLOCK_DMODEL_V=Lv,
+             IS_CAUSAL=causal,
+             USE_DECAY=use_decay,
+             num_warps=num_warps,
+             num_stages=num_stages,
+         )
+
+         ctx.save_for_backward(q, k, v, s)
+         ctx.grid = grid
+         ctx.BLOCK_M = BLOCK_M
+         ctx.BLOCK_DMODEL_QK = Lk
+         ctx.BLOCK_N = BLOCK_N
+         ctx.BLOCK_DMODEL_V = Lv
+         ctx.causal = causal
+         ctx.use_decay = use_decay
+         return o
+
+     @staticmethod
+     def backward(ctx, do):
+         q, k, v, s = ctx.saved_tensors
+         BLOCK_M = 32
+         BLOCK_N = 32
+         num_warps = 4
+         num_stages = 1
+
+         do = do.contiguous()
+         dq = torch.zeros_like(q, dtype=torch.float32)
+         dk = torch.empty_like(k)
+         dv = torch.empty_like(v)
+
+         grid_kv = (triton.cdiv(k.shape[2],
+                                BLOCK_N), k.shape[0] * k.shape[1], 1)
+         _bwd_kernel_kv[grid_kv](
+             q,
+             k,
+             v,
+             s,
+             do,
+             dq,
+             dk,
+             dv,
+             q.stride(0),
+             q.stride(1),
+             q.stride(2),
+             q.stride(3),
+             k.stride(0),
+             k.stride(1),
+             k.stride(2),
+             k.stride(3),
+             v.stride(0),
+             v.stride(1),
+             v.stride(2),
+             v.stride(3),
+             do.stride(0),
+             do.stride(1),
+             do.stride(2),
+             do.stride(3),
+             s.stride(0),
+             q.shape[0],
+             q.shape[1],
+             q.shape[2],
+             grid_kv[0],
+             BLOCK_M=BLOCK_M,
+             BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK,
+             BLOCK_N=BLOCK_N,
+             BLOCK_DMODEL_V=ctx.BLOCK_DMODEL_V,
+             CAUSAL=ctx.causal,
+             USE_DECAY=ctx.use_decay,
+             num_warps=num_warps,
+             num_stages=num_stages,
+         )
+
+         grid_q = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
+
+         _bwd_kernel_q[grid_q](
+             q,
+             k,
+             v,
+             s,
+             do,
+             dq,
+             dk,
+             dv,
+             q.stride(0),
+             q.stride(1),
+             q.stride(2),
+             q.stride(3),
+             k.stride(0),
+             k.stride(1),
+             k.stride(2),
+             k.stride(3),
+             v.stride(0),
+             v.stride(1),
+             v.stride(2),
+             v.stride(3),
+             do.stride(0),
+             do.stride(1),
+             do.stride(2),
+             do.stride(3),
+             s.stride(0),
+             q.shape[0],
+             q.shape[1],
+             q.shape[2],
+             grid_q[0],
+             BLOCK_M=BLOCK_M,
+             BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK,
+             BLOCK_N=BLOCK_N,
+             BLOCK_DMODEL_V=ctx.BLOCK_DMODEL_V,
+             CAUSAL=ctx.causal,
+             USE_DECAY=ctx.use_decay,
+             num_warps=num_warps,
+             num_stages=num_stages,
+         )
+
+         return dq.to(q.dtype), dk, dv, None, None
+
+
+ attention = _attention.apply
+
+
+ def lightning_attention(q, k, v, causal, ed):
+     d = q.shape[-1]
+     e = v.shape[-1]
+     # arr = f(d)
+     if d >= 128:
+         m = 128
+     else:
+         m = 64
+     arr = [m * i for i in range(d // m + 1)]
+     if arr[-1] != d:
+         arr.append(d)
+     n = len(arr)
+     output = 0
+     for i in range(n - 1):
+         s = arr[i]
+         e = arr[i + 1]
+         q1 = q[..., s:e]
+         k1 = k[..., s:e]
+         o = attention(q1, k1, v, causal, ed)
+         output = output + o
+
+     return output
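
For reference, a minimal usage sketch of the new module (not part of the commit). It assumes 4-D inputs of shape (batch, heads, seq_len, head_dim) on an Ampere-or-newer GPU, since _attention.forward rejects compute capability < 8.0; the decay argument ed is read per head in the kernels, and an empty tensor disables decay because the forward pass sets use_decay = s.shape[0] > 0. Shapes, dtypes, and the import path below are illustrative assumptions:

import torch
from lightning_attention2 import lightning_attention  # inside the repo it is imported as .lightning_attention2

b, h, n, d, e = 2, 8, 1024, 128, 64
device = "cuda"  # compute capability >= 8.0 is required

# illustrative tensors; bfloat16 is an assumption, not a requirement stated in the commit
q = torch.randn(b, h, n, d, device=device, dtype=torch.bfloat16)
k = torch.randn(b, h, n, d, device=device, dtype=torch.bfloat16)
v = torch.randn(b, h, n, e, device=device, dtype=torch.bfloat16)

# one positive decay slope per head; torch.empty(0, device=device) would disable decay
ed = torch.rand(h, device=device, dtype=torch.float32)

o = lightning_attention(q, k, v, causal=True, ed=ed)  # -> (b, h, n, e)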
modeling_transnormer.py CHANGED
@@ -63,7 +63,7 @@ BLOCK = 256
 
 if use_triton:
     try:
-        from .lightning_attention import lightning_attention
+        from .lightning_attention2 import lightning_attention
 
         has_lightning_attention = True
     except (ImportError, ModuleNotFoundError):
@@ -345,8 +345,9 @@ class NormLinearAttention(nn.Module):
                     k[:, :, i:i + 1],
                     v[:, :, i:i + 1],
                 )
-                qkv = torch.einsum("... n e, ... e d -> ... n d",
-                                   q[:, :, i:i + 1], kv)
+                qkv = torch.einsum("... n e, ... e d -> ... n d", q[:, :,
+                                                                    i:i + 1],
+                                   kv.to(q.dtype))
                 output.append(qkv)
             output = torch.concat(output, dim=-2)
 
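
Besides switching the import to lightning_attention2, the change adds a kv.to(q.dtype) cast in the token-by-token decoding path, presumably because the running kv state is kept in higher precision than q. A minimal sketch of that single contraction step; the shapes and dtypes below are illustrative assumptions, not taken from the model code:

import torch

b, h, dk, dv = 2, 8, 128, 64
q_i = torch.randn(b, h, 1, dk, dtype=torch.bfloat16)  # current query token
kv = torch.randn(b, h, dk, dv, dtype=torch.float32)   # running key-value state kept in fp32

# cast kv so both einsum operands share q's dtype (mixed half/float matmuls
# generally fail in PyTorch, and the cast also keeps the output in q.dtype)
qkv = torch.einsum("... n e, ... e d -> ... n d", q_i, kv.to(q_i.dtype))
print(qkv.shape)  # torch.Size([2, 8, 1, 64])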